1. Prometheus Research, LLC
  2. Prometheus
  3. htsql

Commits

Kirill Simonov  committed 65541e3

Refactoring functions.

Refactored binding and encoding of aggregate functions to reduce
the number of signatures, use the standard correlate adapter and
unify encoding of exists()/every() and regular aggregates.

  • Participants
  • Parent commits e9a000a
  • Branches default

Comments (0)

Files changed (5)

File src/htsql/tr/fn/assemble.py

View file
  • Ignore whitespace
 from ...adapter import adapts, adapts_none
 from ..assemble import EvaluateBySignature
 from ..frame import FormulaPhrase
-from .signature import ConcatenateSig, WrapExistsSig, TakeCountSig
+from .signature import ConcatenateSig, ExistsSig, CountSig
 
 
 class EvaluateFunction(EvaluateBySignature):
 
 class EvaluateWrapExists(EvaluateFunction):
 
-    adapts(WrapExistsSig)
+    adapts(ExistsSig)
     is_null_regular = False
     is_nullable = False
 
 
 class EvaluateTakeCount(EvaluateFunction):
 
-    adapts(TakeCountSig)
+    adapts(CountSig)
     is_null_regular = False
     is_nullable = False
 

File src/htsql/tr/fn/bind.py

View file
  • Ignore whitespace
 from ..error import BindError
 from ..coerce import coerce
 from ..lookup import lookup
-from .signature import (Signature, ThisSig, RootSig, DirectSig, FiberSig, AsSig,
-                        SortDirectionSig, LimitSig, SortSig, NullSig, TrueSig,
-                        FalseSig, CastSig, DateSig, IsEqualSig, IsInSig,
-                        IsTotallyEqualSig, AndSig, OrSig, NotSig, CompareSig,
-                        AddSig, ConcatenateSig, DateIncrementSig,
-                        SubtractSig, DateDecrementSig, DateDifferenceSig,
+from .signature import (Signature, UnarySig, BinarySig, ThisSig, RootSig,
+                        DirectSig, FiberSig, AsSig, SortDirectionSig, LimitSig,
+                        SortSig, NullSig, TrueSig, FalseSig, CastSig, DateSig,
+                        IsEqualSig, IsInSig, IsTotallyEqualSig, AndSig, OrSig,
+                        NotSig, CompareSig, AddSig, ConcatenateSig,
+                        DateIncrementSig, SubtractSig, DateDecrementSig,
+                        DateDifferenceSig,
                         MultiplySig, DivideSig, IsNullSig, NullIfSig,
                         IfNullSig, IfSig, SwitchSig, KeepPolaritySig,
                         ReversePolaritySig, RoundSig, RoundToSig, LengthSig,
-                        ContainsSig, ExistsSig, EverySig, BinarySig,
-                        UnarySig, CountSig, MinSig, MaxSig, SumSig, AvgSig)
+                        ContainsSig, ExistsSig, CountSig, MinMaxSig,
+                        SumSig, AvgSig, AggregateSig, QuantifySig)
 import sys
 
 
 class BindFunction(BindByName):
 
     signature = None
+    hint = None
+    help = None
 
     def match(self):
         assert self.signature is not None
     domains = []
     codomain = None
 
+    hint = None
+    help = None
+
     @classmethod
     def dominates(component, other):
+        if component.input_signature is None:
+            return False
+        if other.input_signature is None:
+            return False
         if issubclass(component, other):
             return True
         if issubclass(component.input_signature, other.input_signature):
 
     @classmethod
     def matches(component, dispatch_key):
+        if component.input_signature is None:
+            return False
         key_signature, key_domain_vector = dispatch_key
         if not issubclass(key_signature, component.input_signature):
             return False
     frame.f_locals['input_arity'] = arity
 
 
+def correlates_none():
+    frame = sys._getframe(1)
+    frame.f_locals['input_signature'] = None
+    frame.f_locals['input_domains'] = []
+    frame.f_locals['input_arity'] = 0
+
+
 class BindPolyFunction(BindFunction):
 
     signature = None
 
 class BindExistsBase(BindFunction):
 
-    signature = UnarySig
-    bind_signature = None
+    signature = ExistsSig
+    polarity = None
 
     def correlate(self, op):
         op = CastBinding(op, coerce(BooleanDomain()), op.syntax)
-        return FormulaBinding(self.bind_signature(), op.domain, self.syntax,
-                               base=self.state.base, op=op)
+        return FormulaBinding(QuantifySig(self.polarity), op.domain,
+                              self.syntax, base=self.state.base, op=op)
 
 
 class BindExists(BindExistsBase):
 
     named('exists')
-    bind_signature = ExistsSig
+    polarity = +1
 
 
 class BindEvery(BindExistsBase):
 
     named('every')
-    bind_signature = EverySig
+    polarity = -1
 
 
 class BindCount(BindFunction):
 
     named('count')
-    signature = UnarySig
+    signature = CountSig
 
     def correlate(self, op):
         op = CastBinding(op, coerce(BooleanDomain()), op.syntax)
-        return FormulaBinding(CountSig(), coerce(IntegerDomain()),
-                               self.syntax, base=self.state.base, op=op)
+        op = FormulaBinding(CountSig(), coerce(IntegerDomain()),
+                            self.syntax, op=op)
+        return FormulaBinding(AggregateSig(), op.domain, self.syntax,
+                              base=self.state.base, op=op)
 
 
-class CorrelateAggregate(Adapter):
+class BindPolyAggregate(BindPolyFunction):
 
-    adapts(Domain)
+    signature = UnarySig
+    codomain = UntypedDomain()
+
+    def correlate(self, op):
+        binding = FormulaBinding(self.signature(), self.codomain, self.syntax,
+                                 op=op)
+        correlate = CorrelateAggregate(binding, self.state)
+        return correlate()
+
+
+class CorrelateAggregate(Correlate):
+
+    correlates_none()
     signature = None
-    domain = None
+    domains = []
     codomain = None
 
     def __call__(self):
-        return (self.signature is not None and
-                self.domain is not None and
-                self.codomain is not None)
+        op = super(CorrelateAggregate, self).__call__()
+        return FormulaBinding(AggregateSig(), op.domain, op.syntax,
+                              base=self.state.base, op=op)
 
 
-class BindPolyAggregate(BindFunction):
+class BindMinMaxBase(BindPolyAggregate):
 
-    signature = UnarySig
-    correlation = None
+    signature = MinMaxSig
+    polarity = None
 
     def correlate(self, op):
-        correlate = self.correlation(op.domain)
-        if not correlate():
-            raise BindError("incompatible argument", self.syntax.mark)
-        op = CastBinding(op, coerce(correlate.domain), op.syntax)
-        return FormulaBinding(correlate.signature(),
-                              coerce(correlate.codomain), self.syntax,
-                              base=self.state.base, op=op)
+        binding = FormulaBinding(self.signature(self.polarity), self.codomain,
+                                 self.syntax, op=op)
+        correlate = CorrelateAggregate(binding, self.state)
+        return correlate()
 
 
-class CorrelateMin(CorrelateAggregate):
-
-    signature = MinSig
-
-
-class BindMin(BindPolyAggregate):
+class BindMinMaxBase(BindMinMaxBase):
 
     named('min')
-    correlation = CorrelateMin
+    signature = MinMaxSig
+    polarity = +1
 
 
-class CorrelateIntegerMin(CorrelateMin):
+class BindMinMaxBase(BindMinMaxBase):
 
-    adapts(IntegerDomain)
-    domain = IntegerDomain()
+    named('max')
+    signature = MinMaxSig
+    polarity = -1
+
+
+class CorrelateIntegerMinMax(CorrelateAggregate):
+
+    correlates(MinMaxSig, IntegerDomain)
+    signature = MinMaxSig
+    domains = [IntegerDomain()]
     codomain = IntegerDomain()
 
 
-class CorrelateDecimalMin(CorrelateMin):
+class CorrelateDecimalMinMax(CorrelateAggregate):
 
-    adapts(DecimalDomain)
-    domain = DecimalDomain()
+    correlates(MinMaxSig, DecimalDomain)
+    signature = MinMaxSig
+    domains = [DecimalDomain()]
     codomain = DecimalDomain()
 
 
-class CorrelateFloatMin(CorrelateMin):
+class CorrelateFloatMinMax(CorrelateAggregate):
 
-    adapts(FloatDomain)
-    domain = FloatDomain()
+    correlates(MinMaxSig, FloatDomain)
+    signature = MinMaxSig
+    domains = [FloatDomain()]
     codomain = FloatDomain()
 
 
-class CorrelateStringMin(CorrelateMin):
+class CorrelateStringMinMax(CorrelateAggregate):
 
-    adapts(StringDomain)
-    domain = StringDomain()
+    correlates(MinMaxSig, StringDomain)
+    signature = MinMaxSig
+    domains = [StringDomain()]
     codomain = StringDomain()
 
 
-class CorrelateDateMin(CorrelateMin):
+class CorrelateDateMinMax(CorrelateAggregate):
 
-    adapts(DateDomain)
-    domain = DateDomain()
+    correlates(MinMaxSig, DateDomain)
+    signature = MinMaxSig
+    domains = [DateDomain()]
     codomain = DateDomain()
 
 
-class CorrelateMax(CorrelateAggregate):
-
-    signature = MaxSig
-
-
-class BindMax(BindPolyAggregate):
-
-    named('max')
-    correlation = CorrelateMax
-
-
-class CorrelateIntegerMax(CorrelateMax):
-
-    adapts(IntegerDomain)
-    domain = IntegerDomain()
-    codomain = IntegerDomain()
-
-
-class CorrelateDecimalMax(CorrelateMax):
-
-    adapts(DecimalDomain)
-    domain = DecimalDomain()
-    codomain = DecimalDomain()
-
-
-class CorrelateFloatMax(CorrelateMax):
-
-    adapts(FloatDomain)
-    domain = FloatDomain()
-    codomain = FloatDomain()
-
-
-class CorrelateStringMax(CorrelateMax):
-
-    adapts(StringDomain)
-    domain = StringDomain()
-    codomain = StringDomain()
-
-
-class CorrelateDateMax(CorrelateMax):
-
-    adapts(DateDomain)
-    domain = DateDomain()
-    codomain = DateDomain()
-
-
-class CorrelateSum(CorrelateAggregate):
-
-    signature = SumSig
-
-
 class BindSum(BindPolyAggregate):
 
     named('sum')
-    correlation = CorrelateSum
+    signature = SumSig
 
 
-class CorrelateIntegerSum(CorrelateSum):
+class CorrelateIntegerSum(CorrelateAggregate):
 
-    adapts(IntegerDomain)
-    domain = IntegerDomain()
+    correlates(SumSig, IntegerDomain)
+    signature = SumSig
+    domains = [IntegerDomain()]
     codomain = IntegerDomain()
 
 
-class CorrelateDecimalSum(CorrelateSum):
+class CorrelateDecimalSum(CorrelateAggregate):
 
-    adapts(DecimalDomain)
-    domain = DecimalDomain()
+    correlates(SumSig, DecimalDomain)
+    signature = SumSig
+    domains = [DecimalDomain()]
     codomain = DecimalDomain()
 
 
-class CorrelateFloatSum(CorrelateSum):
+class CorrelateFloatSum(CorrelateAggregate):
 
-    adapts(FloatDomain)
-    domain = FloatDomain()
+    correlates(SumSig, FloatDomain)
+    signature = SumSig
+    domains = [FloatDomain()]
     codomain = FloatDomain()
 
 
-class CorrelateAvg(CorrelateAggregate):
-
-    signature = AvgSig
-
-
 class BindAvg(BindPolyAggregate):
 
     named('avg')
-    correlation = CorrelateAvg
+    signature = AvgSig
 
 
-class CorrelateDecimalAvg(CorrelateAvg):
+class CorrelateDecimalAvg(CorrelateAggregate):
 
-    adapts_many(IntegerDomain,
-                DecimalDomain)
-    domain = DecimalDomain()
+    correlates(AvgSig, IntegerDomain,
+                       DecimalDomain)
+    signature = AvgSig
+    domains = [DecimalDomain()]
     codomain = DecimalDomain()
 
 
-class CorrelateFloatAvg(CorrelateAvg):
+class CorrelateFloatAvg(CorrelateAggregate):
 
-    adapts(FloatDomain)
-    domain = FloatDomain()
+    correlates(AvgSig, FloatDomain)
+    signature = AvgSig
+    domains = [FloatDomain()]
     codomain = FloatDomain()
 
 

File src/htsql/tr/fn/dump.py

View file
  • Ignore whitespace
                         MultiplySig, DivideSig, IfSig, SwitchSig,
                         ReversePolaritySig,
                         RoundSig, RoundToSig, LengthSig,
-                        WrapExistsSig, TakeCountSig, TakeMinSig, TakeMaxSig,
-                        TakeSumSig, TakeAvgSig)
+                        ExistsSig, CountSig, MinMaxSig, SumSig, AvgSig)
 
 
 class DumpFunction(DumpBySignature):
     template = "CHARACTER_LENGTH({op})"
 
 
-class DumpWrapExists(DumpFunction):
+class DumpExists(DumpFunction):
 
-    adapts(WrapExistsSig)
+    adapts(ExistsSig)
     template = "EXISTS{op}"
 
 
-class DumpTakeCount(DumpFunction):
+class DumpCount(DumpFunction):
 
-    adapts(TakeCountSig)
+    adapts(CountSig)
     template = "COUNT({op})"
 
 
-class DumpTakeMin(DumpFunction):
+class DumpMinMax(DumpFunction):
 
-    adapts(TakeMinSig)
-    template = "MIN({op})"
+    adapts(MinMaxSig)
 
+    def __call__(self):
+        if self.signature.polarity > 0:
+            self.state.format("MIN({op})", self.arguments)
+        else:
+            self.state.format("MAX({op})", self.arguments)
 
-class DumpTakeMax(DumpFunction):
 
-    adapts(TakeMaxSig)
-    template = "MAX({op})"
+class DumpSum(DumpFunction):
 
-
-class DumpTakeSum(DumpFunction):
-
-    adapts(TakeSumSig)
+    adapts(SumSig)
     template = "SUM({op})"
 
 
-class DumpTakeAvg(DumpFunction):
+class DumpAvg(DumpFunction):
 
-    adapts(TakeAvgSig)
+    adapts(AvgSig)
     template = "AVG({op})"
 
 

File src/htsql/tr/fn/encode.py

View file
  • Ignore whitespace
 """
 
 
-from ...adapter import Adapter, adapts
-from ...domain import UntypedDomain
-from ..encode import EncodeBySignature
+from ...adapter import Adapter, adapts, adapts_many
+from ...domain import UntypedDomain, BooleanDomain
+from ..encode import EncodeBySignature, EncodingState
 from ..error import EncodeError
 from ..coerce import coerce
 from ..binding import LiteralBinding, CastBinding
 from ..code import (LiteralCode, ScalarUnit, CorrelatedUnit,
                     AggregateUnit, FilteredSpace, FormulaCode)
-from .signature import (NotSig, NullIfSig, IfNullSig, QuantifySig,
-                        WrapExistsSig, AggregateSig, CountSig,TakeCountSig,
-                        MinSig, TakeMinSig, MaxSig, TakeMaxSig,
-                        SumSig, TakeSumSig, AvgSig, TakeAvgSig)
+from .signature import (Signature, NotSig, NullIfSig, IfNullSig, QuantifySig,
+                        ExistsSig, AggregateSig, QuantifySig,
+                        CountSig, SumSig)
 
 
-class EncodeQuantify(EncodeBySignature):
+class EncodeAggregate(EncodeBySignature):
 
-    adapts(QuantifySig)
+    adapts(AggregateSig)
 
-    def __call__(self):
-        op = self.state.encode(self.binding.op)
-        if self.signature.polarity < 0:
-            op = FormulaCode(NotSig(), op.domain, op.binding, op=op)
-        space = self.state.relate(self.binding.base)
+    def aggregate(self, op, space):
         plural_units = [unit for unit in op.units
                              if not space.spans(unit.space)]
         if not plural_units:
         plural_space = plural_spaces[0]
         if not plural_space.spans(space):
             raise EncodeError("invalid plural operand", op.mark)
+        return plural_space
+
+    def __call__(self):
+        op = self.state.encode(self.binding.op)
+        space = self.state.relate(self.binding.base)
+        plural_space = self.aggregate(op, space)
+        aggregate = AggregateUnit(op, plural_space, space, self.binding)
+        wrap = WrapAggregate(aggregate, self.state)
+        wrapper = wrap()
+        wrapper = ScalarUnit(wrapper, space, self.binding)
+        return wrapper
+
+
+class WrapAggregate(Adapter):
+
+    adapts(Signature)
+
+    @classmethod
+    def dispatch(cls, unit, *args, **kwds):
+        assert isinstance(unit, AggregateUnit)
+        if not isinstance(unit.code, FormulaCode):
+            return (Signature,)
+        return (type(unit.code.signature),)
+
+    def __init__(self, unit, state):
+        assert isinstance(unit, AggregateUnit)
+        assert isinstance(state, EncodingState)
+        self.unit = unit
+        self.state = state
+        self.code = unit.code
+
+    def __call__(self):
+        return self.unit
+
+
+class EncodeCount(EncodeBySignature):
+
+    adapts(CountSig)
+
+    def __call__(self):
+        op = self.state.encode(self.binding.op)
+        false_literal = LiteralCode(False, op.domain, op.binding)
+        op = FormulaCode(NullIfSig(), op.domain, op.binding,
+                         lop=op, rop=false_literal)
+        return FormulaCode(CountSig(), self.binding.domain, self.binding,
+                           op=op)
+
+
+class WrapCountSum(WrapAggregate):
+
+    adapts_many(CountSig, SumSig)
+
+    def __call__(self):
+        zero_literal = LiteralBinding('0', UntypedDomain(),
+                                      self.unit.syntax)
+        zero_literal = CastBinding(zero_literal, self.unit.domain,
+                                   self.unit.syntax)
+        zero_literal = self.state.encode(zero_literal)
+        return FormulaCode(IfNullSig(), self.unit.domain, self.unit.binding,
+                           lop=self.unit, rop=zero_literal)
+
+
+class EncodeQuantify(EncodeAggregate):
+
+    adapts(QuantifySig)
+
+    def __call__(self):
+        op = self.state.encode(self.binding.op)
+        space = self.state.relate(self.binding.base)
+        plural_space = self.aggregate(op, space)
+        if self.signature.polarity < 0:
+            op = FormulaCode(NotSig(), op.domain, op.binding, op=op)
         plural_space = FilteredSpace(plural_space, op, self.binding)
-        op = LiteralCode(True, op.domain, self.binding)
-        aggregate = CorrelatedUnit(op, plural_space, space,
+        true_literal = LiteralCode(True, coerce(BooleanDomain()), self.binding)
+        aggregate = CorrelatedUnit(true_literal, plural_space, space,
                                    self.binding)
-        wrapper = FormulaCode(WrapExistsSig(), op.domain, self.binding,
+        wrapper = FormulaCode(ExistsSig(), op.domain, self.binding,
                               op=aggregate)
         if self.signature.polarity < 0:
             wrapper = FormulaCode(NotSig(), wrapper.domain, wrapper.binding,
         return wrapper
 
 
-class EncodeAggregate(EncodeBySignature):
-
-    adapts(AggregateSig)
-
-    def take(self, op):
-        return op
-
-    def wrap(self, op):
-        return op
-
-    def __call__(self):
-        op = self.take(self.state.encode(self.binding.op))
-        space = self.state.relate(self.binding.base)
-        plural_units = [unit for unit in op.units
-                             if not space.spans(unit.space)]
-        if not plural_units:
-            raise EncodeError("a plural operand is required", op.mark)
-        plural_spaces = []
-        for unit in plural_units:
-            if any(plural_space.dominates(unit.space)
-                   for plural_space in plural_spaces):
-                continue
-            plural_spaces = [plural_space
-                             for plural_space in plural_spaces
-                             if not unit.space.dominates(plural_space)]
-            plural_spaces.append(unit.space)
-        if len(plural_spaces) > 1:
-            raise EncodeError("invalid plural operand", op.mark)
-        plural_space = plural_spaces[0]
-        if not plural_space.spans(space):
-            raise EncodeError("invalid plural operand", op.mark)
-        aggregate = AggregateUnit(op, plural_space, space, self.binding)
-        wrapper = self.wrap(aggregate)
-        wrapper = ScalarUnit(wrapper, space, self.binding)
-        return wrapper
-
-
-class EncodeCount(EncodeAggregate):
-
-    adapts(CountSig)
-
-    def take(self, op):
-        false = LiteralCode(False, op.domain, op.binding)
-        op = FormulaCode(NullIfSig(), op.domain, op.binding,
-                         lop=op, rop=false)
-        return FormulaCode(TakeCountSig(), self.binding.domain, self.binding,
-                           op=op)
-
-    def wrap(self, op):
-        zero = LiteralBinding('0', UntypedDomain(), op.syntax)
-        zero = CastBinding(zero, op.domain, op.syntax)
-        zero = self.state.encode(zero)
-        return FormulaCode(IfNullSig(), op.domain, op.binding,
-                           lop=op, rop=zero)
-
-
-class EncodeMin(EncodeAggregate):
-
-    adapts(MinSig)
-
-    def take(self, op):
-        return FormulaCode(TakeMinSig(), self.binding.domain, self.binding,
-                           op=op)
-
-
-class EncodeMax(EncodeAggregate):
-
-    adapts(MaxSig)
-
-    def take(self, op):
-        return FormulaCode(TakeMaxSig(), self.binding.domain, self.binding,
-                           op=op)
-
-
-class EncodeSum(EncodeAggregate):
-
-    adapts(SumSig)
-
-    def take(self, op):
-        return FormulaCode(TakeSumSig(), self.binding.domain, self.binding,
-                           op=op)
-
-    def wrap(self, op):
-        zero = LiteralBinding('0', UntypedDomain(), op.syntax)
-        zero = CastBinding(zero, op.domain, op.syntax)
-        zero = self.state.encode(zero)
-        return FormulaCode(IfNullSig(), op.domain, op.binding,
-                           lop=op, rop=zero)
-
-
-class EncodeAvg(EncodeAggregate):
-
-    adapts(AvgSig)
-
-    def take(self, op):
-        return FormulaCode(TakeAvgSig(), self.binding.domain, self.binding,
-                           op=op)
-
-

File src/htsql/tr/fn/signature.py

View file
  • Ignore whitespace
     pass
 
 
-class QuantifySig(PolarSig):
-
-    slots = [
-            Slot('base'),
-            Slot('op'),
-    ]
-
-
-class ExistsSig(QuantifySig):
-
-    def __init__(self):
-        super(ExistsSig, self).__init__(polarity=+1)
-
-
-class EverySig(QuantifySig):
-
-    def __init__(self):
-        super(EverySig, self).__init__(polarity=-1)
-
-
-class WrapExistsSig(UnarySig):
-    pass
-
-
 class AggregateSig(Signature):
 
     slots = [
     ]
 
 
-class CountSig(AggregateSig):
+class QuantifySig(AggregateSig, PolarSig):
     pass
 
 
-class TakeCountSig(UnarySig):
+class ExistsSig(UnarySig):
     pass
 
 
-class MinSig(AggregateSig):
+class CountSig(UnarySig):
     pass
 
 
-class TakeMinSig(UnarySig):
+class MinMaxSig(UnarySig, PolarSig):
     pass
 
 
-class MaxSig(AggregateSig):
+class SumSig(UnarySig):
     pass
 
 
-class TakeMaxSig(UnarySig):
+class AvgSig(UnarySig):
     pass
 
 
-class SumSig(AggregateSig):
-    pass
-
-
-class TakeSumSig(UnarySig):
-    pass
-
-
-class AvgSig(AggregateSig):
-    pass
-
-
-class TakeAvgSig(UnarySig):
-    pass
-
-