Commits

Kirill Simonov committed 65541e3

Refactoring functions.

Refactored binding and encoding of aggregate functions to reduce
the number of signatures, use the standard correlate adapter and
unify encoding of exists()/every() and regular aggregates.

Comments (0)

Files changed (5)

src/htsql/tr/fn/assemble.py

 from ...adapter import adapts, adapts_none
 from ..assemble import EvaluateBySignature
 from ..frame import FormulaPhrase
-from .signature import ConcatenateSig, WrapExistsSig, TakeCountSig
+from .signature import ConcatenateSig, ExistsSig, CountSig
 
 
 class EvaluateFunction(EvaluateBySignature):
 
 class EvaluateWrapExists(EvaluateFunction):
 
-    adapts(WrapExistsSig)
+    adapts(ExistsSig)
     is_null_regular = False
     is_nullable = False
 
 
 class EvaluateTakeCount(EvaluateFunction):
 
-    adapts(TakeCountSig)
+    adapts(CountSig)
     is_null_regular = False
     is_nullable = False
 

src/htsql/tr/fn/bind.py

 from ..error import BindError
 from ..coerce import coerce
 from ..lookup import lookup
-from .signature import (Signature, ThisSig, RootSig, DirectSig, FiberSig, AsSig,
-                        SortDirectionSig, LimitSig, SortSig, NullSig, TrueSig,
-                        FalseSig, CastSig, DateSig, IsEqualSig, IsInSig,
-                        IsTotallyEqualSig, AndSig, OrSig, NotSig, CompareSig,
-                        AddSig, ConcatenateSig, DateIncrementSig,
-                        SubtractSig, DateDecrementSig, DateDifferenceSig,
+from .signature import (Signature, UnarySig, BinarySig, ThisSig, RootSig,
+                        DirectSig, FiberSig, AsSig, SortDirectionSig, LimitSig,
+                        SortSig, NullSig, TrueSig, FalseSig, CastSig, DateSig,
+                        IsEqualSig, IsInSig, IsTotallyEqualSig, AndSig, OrSig,
+                        NotSig, CompareSig, AddSig, ConcatenateSig,
+                        DateIncrementSig, SubtractSig, DateDecrementSig,
+                        DateDifferenceSig,
                         MultiplySig, DivideSig, IsNullSig, NullIfSig,
                         IfNullSig, IfSig, SwitchSig, KeepPolaritySig,
                         ReversePolaritySig, RoundSig, RoundToSig, LengthSig,
-                        ContainsSig, ExistsSig, EverySig, BinarySig,
-                        UnarySig, CountSig, MinSig, MaxSig, SumSig, AvgSig)
+                        ContainsSig, ExistsSig, CountSig, MinMaxSig,
+                        SumSig, AvgSig, AggregateSig, QuantifySig)
 import sys
 
 
 class BindFunction(BindByName):
 
     signature = None
+    hint = None
+    help = None
 
     def match(self):
         assert self.signature is not None
     domains = []
     codomain = None
 
+    hint = None
+    help = None
+
     @classmethod
     def dominates(component, other):
+        if component.input_signature is None:
+            return False
+        if other.input_signature is None:
+            return False
         if issubclass(component, other):
             return True
         if issubclass(component.input_signature, other.input_signature):
 
     @classmethod
     def matches(component, dispatch_key):
+        if component.input_signature is None:
+            return False
         key_signature, key_domain_vector = dispatch_key
         if not issubclass(key_signature, component.input_signature):
             return False
     frame.f_locals['input_arity'] = arity
 
 
+def correlates_none():
+    frame = sys._getframe(1)
+    frame.f_locals['input_signature'] = None
+    frame.f_locals['input_domains'] = []
+    frame.f_locals['input_arity'] = 0
+
+
 class BindPolyFunction(BindFunction):
 
     signature = None
 
 class BindExistsBase(BindFunction):
 
-    signature = UnarySig
-    bind_signature = None
+    signature = ExistsSig
+    polarity = None
 
     def correlate(self, op):
         op = CastBinding(op, coerce(BooleanDomain()), op.syntax)
-        return FormulaBinding(self.bind_signature(), op.domain, self.syntax,
-                               base=self.state.base, op=op)
+        return FormulaBinding(QuantifySig(self.polarity), op.domain,
+                              self.syntax, base=self.state.base, op=op)
 
 
 class BindExists(BindExistsBase):
 
     named('exists')
-    bind_signature = ExistsSig
+    polarity = +1
 
 
 class BindEvery(BindExistsBase):
 
     named('every')
-    bind_signature = EverySig
+    polarity = -1
 
 
 class BindCount(BindFunction):
 
     named('count')
-    signature = UnarySig
+    signature = CountSig
 
     def correlate(self, op):
         op = CastBinding(op, coerce(BooleanDomain()), op.syntax)
-        return FormulaBinding(CountSig(), coerce(IntegerDomain()),
-                               self.syntax, base=self.state.base, op=op)
+        op = FormulaBinding(CountSig(), coerce(IntegerDomain()),
+                            self.syntax, op=op)
+        return FormulaBinding(AggregateSig(), op.domain, self.syntax,
+                              base=self.state.base, op=op)
 
 
-class CorrelateAggregate(Adapter):
+class BindPolyAggregate(BindPolyFunction):
 
-    adapts(Domain)
+    signature = UnarySig
+    codomain = UntypedDomain()
+
+    def correlate(self, op):
+        binding = FormulaBinding(self.signature(), self.codomain, self.syntax,
+                                 op=op)
+        correlate = CorrelateAggregate(binding, self.state)
+        return correlate()
+
+
+class CorrelateAggregate(Correlate):
+
+    correlates_none()
     signature = None
-    domain = None
+    domains = []
     codomain = None
 
     def __call__(self):
-        return (self.signature is not None and
-                self.domain is not None and
-                self.codomain is not None)
+        op = super(CorrelateAggregate, self).__call__()
+        return FormulaBinding(AggregateSig(), op.domain, op.syntax,
+                              base=self.state.base, op=op)
 
 
-class BindPolyAggregate(BindFunction):
+class BindMinMaxBase(BindPolyAggregate):
 
-    signature = UnarySig
-    correlation = None
+    signature = MinMaxSig
+    polarity = None
 
     def correlate(self, op):
-        correlate = self.correlation(op.domain)
-        if not correlate():
-            raise BindError("incompatible argument", self.syntax.mark)
-        op = CastBinding(op, coerce(correlate.domain), op.syntax)
-        return FormulaBinding(correlate.signature(),
-                              coerce(correlate.codomain), self.syntax,
-                              base=self.state.base, op=op)
+        binding = FormulaBinding(self.signature(self.polarity), self.codomain,
+                                 self.syntax, op=op)
+        correlate = CorrelateAggregate(binding, self.state)
+        return correlate()
 
 
-class CorrelateMin(CorrelateAggregate):
-
-    signature = MinSig
-
-
-class BindMin(BindPolyAggregate):
+class BindMinMaxBase(BindMinMaxBase):
 
     named('min')
-    correlation = CorrelateMin
+    signature = MinMaxSig
+    polarity = +1
 
 
-class CorrelateIntegerMin(CorrelateMin):
+class BindMinMaxBase(BindMinMaxBase):
 
-    adapts(IntegerDomain)
-    domain = IntegerDomain()
+    named('max')
+    signature = MinMaxSig
+    polarity = -1
+
+
+class CorrelateIntegerMinMax(CorrelateAggregate):
+
+    correlates(MinMaxSig, IntegerDomain)
+    signature = MinMaxSig
+    domains = [IntegerDomain()]
     codomain = IntegerDomain()
 
 
-class CorrelateDecimalMin(CorrelateMin):
+class CorrelateDecimalMinMax(CorrelateAggregate):
 
-    adapts(DecimalDomain)
-    domain = DecimalDomain()
+    correlates(MinMaxSig, DecimalDomain)
+    signature = MinMaxSig
+    domains = [DecimalDomain()]
     codomain = DecimalDomain()
 
 
-class CorrelateFloatMin(CorrelateMin):
+class CorrelateFloatMinMax(CorrelateAggregate):
 
-    adapts(FloatDomain)
-    domain = FloatDomain()
+    correlates(MinMaxSig, FloatDomain)
+    signature = MinMaxSig
+    domains = [FloatDomain()]
     codomain = FloatDomain()
 
 
-class CorrelateStringMin(CorrelateMin):
+class CorrelateStringMinMax(CorrelateAggregate):
 
-    adapts(StringDomain)
-    domain = StringDomain()
+    correlates(MinMaxSig, StringDomain)
+    signature = MinMaxSig
+    domains = [StringDomain()]
     codomain = StringDomain()
 
 
-class CorrelateDateMin(CorrelateMin):
+class CorrelateDateMinMax(CorrelateAggregate):
 
-    adapts(DateDomain)
-    domain = DateDomain()
+    correlates(MinMaxSig, DateDomain)
+    signature = MinMaxSig
+    domains = [DateDomain()]
     codomain = DateDomain()
 
 
-class CorrelateMax(CorrelateAggregate):
-
-    signature = MaxSig
-
-
-class BindMax(BindPolyAggregate):
-
-    named('max')
-    correlation = CorrelateMax
-
-
-class CorrelateIntegerMax(CorrelateMax):
-
-    adapts(IntegerDomain)
-    domain = IntegerDomain()
-    codomain = IntegerDomain()
-
-
-class CorrelateDecimalMax(CorrelateMax):
-
-    adapts(DecimalDomain)
-    domain = DecimalDomain()
-    codomain = DecimalDomain()
-
-
-class CorrelateFloatMax(CorrelateMax):
-
-    adapts(FloatDomain)
-    domain = FloatDomain()
-    codomain = FloatDomain()
-
-
-class CorrelateStringMax(CorrelateMax):
-
-    adapts(StringDomain)
-    domain = StringDomain()
-    codomain = StringDomain()
-
-
-class CorrelateDateMax(CorrelateMax):
-
-    adapts(DateDomain)
-    domain = DateDomain()
-    codomain = DateDomain()
-
-
-class CorrelateSum(CorrelateAggregate):
-
-    signature = SumSig
-
-
 class BindSum(BindPolyAggregate):
 
     named('sum')
-    correlation = CorrelateSum
+    signature = SumSig
 
 
-class CorrelateIntegerSum(CorrelateSum):
+class CorrelateIntegerSum(CorrelateAggregate):
 
-    adapts(IntegerDomain)
-    domain = IntegerDomain()
+    correlates(SumSig, IntegerDomain)
+    signature = SumSig
+    domains = [IntegerDomain()]
     codomain = IntegerDomain()
 
 
-class CorrelateDecimalSum(CorrelateSum):
+class CorrelateDecimalSum(CorrelateAggregate):
 
-    adapts(DecimalDomain)
-    domain = DecimalDomain()
+    correlates(SumSig, DecimalDomain)
+    signature = SumSig
+    domains = [DecimalDomain()]
     codomain = DecimalDomain()
 
 
-class CorrelateFloatSum(CorrelateSum):
+class CorrelateFloatSum(CorrelateAggregate):
 
-    adapts(FloatDomain)
-    domain = FloatDomain()
+    correlates(SumSig, FloatDomain)
+    signature = SumSig
+    domains = [FloatDomain()]
     codomain = FloatDomain()
 
 
-class CorrelateAvg(CorrelateAggregate):
-
-    signature = AvgSig
-
-
 class BindAvg(BindPolyAggregate):
 
     named('avg')
-    correlation = CorrelateAvg
+    signature = AvgSig
 
 
-class CorrelateDecimalAvg(CorrelateAvg):
+class CorrelateDecimalAvg(CorrelateAggregate):
 
-    adapts_many(IntegerDomain,
-                DecimalDomain)
-    domain = DecimalDomain()
+    correlates(AvgSig, IntegerDomain,
+                       DecimalDomain)
+    signature = AvgSig
+    domains = [DecimalDomain()]
     codomain = DecimalDomain()
 
 
-class CorrelateFloatAvg(CorrelateAvg):
+class CorrelateFloatAvg(CorrelateAggregate):
 
-    adapts(FloatDomain)
-    domain = FloatDomain()
+    correlates(AvgSig, FloatDomain)
+    signature = AvgSig
+    domains = [FloatDomain()]
     codomain = FloatDomain()
 
 

src/htsql/tr/fn/dump.py

                         MultiplySig, DivideSig, IfSig, SwitchSig,
                         ReversePolaritySig,
                         RoundSig, RoundToSig, LengthSig,
-                        WrapExistsSig, TakeCountSig, TakeMinSig, TakeMaxSig,
-                        TakeSumSig, TakeAvgSig)
+                        ExistsSig, CountSig, MinMaxSig, SumSig, AvgSig)
 
 
 class DumpFunction(DumpBySignature):
     template = "CHARACTER_LENGTH({op})"
 
 
-class DumpWrapExists(DumpFunction):
+class DumpExists(DumpFunction):
 
-    adapts(WrapExistsSig)
+    adapts(ExistsSig)
     template = "EXISTS{op}"
 
 
-class DumpTakeCount(DumpFunction):
+class DumpCount(DumpFunction):
 
-    adapts(TakeCountSig)
+    adapts(CountSig)
     template = "COUNT({op})"
 
 
-class DumpTakeMin(DumpFunction):
+class DumpMinMax(DumpFunction):
 
-    adapts(TakeMinSig)
-    template = "MIN({op})"
+    adapts(MinMaxSig)
 
+    def __call__(self):
+        if self.signature.polarity > 0:
+            self.state.format("MIN({op})", self.arguments)
+        else:
+            self.state.format("MAX({op})", self.arguments)
 
-class DumpTakeMax(DumpFunction):
 
-    adapts(TakeMaxSig)
-    template = "MAX({op})"
+class DumpSum(DumpFunction):
 
-
-class DumpTakeSum(DumpFunction):
-
-    adapts(TakeSumSig)
+    adapts(SumSig)
     template = "SUM({op})"
 
 
-class DumpTakeAvg(DumpFunction):
+class DumpAvg(DumpFunction):
 
-    adapts(TakeAvgSig)
+    adapts(AvgSig)
     template = "AVG({op})"
 
 

src/htsql/tr/fn/encode.py

 """
 
 
-from ...adapter import Adapter, adapts
-from ...domain import UntypedDomain
-from ..encode import EncodeBySignature
+from ...adapter import Adapter, adapts, adapts_many
+from ...domain import UntypedDomain, BooleanDomain
+from ..encode import EncodeBySignature, EncodingState
 from ..error import EncodeError
 from ..coerce import coerce
 from ..binding import LiteralBinding, CastBinding
 from ..code import (LiteralCode, ScalarUnit, CorrelatedUnit,
                     AggregateUnit, FilteredSpace, FormulaCode)
-from .signature import (NotSig, NullIfSig, IfNullSig, QuantifySig,
-                        WrapExistsSig, AggregateSig, CountSig,TakeCountSig,
-                        MinSig, TakeMinSig, MaxSig, TakeMaxSig,
-                        SumSig, TakeSumSig, AvgSig, TakeAvgSig)
+from .signature import (Signature, NotSig, NullIfSig, IfNullSig, QuantifySig,
+                        ExistsSig, AggregateSig, QuantifySig,
+                        CountSig, SumSig)
 
 
-class EncodeQuantify(EncodeBySignature):
+class EncodeAggregate(EncodeBySignature):
 
-    adapts(QuantifySig)
+    adapts(AggregateSig)
 
-    def __call__(self):
-        op = self.state.encode(self.binding.op)
-        if self.signature.polarity < 0:
-            op = FormulaCode(NotSig(), op.domain, op.binding, op=op)
-        space = self.state.relate(self.binding.base)
+    def aggregate(self, op, space):
         plural_units = [unit for unit in op.units
                              if not space.spans(unit.space)]
         if not plural_units:
         plural_space = plural_spaces[0]
         if not plural_space.spans(space):
             raise EncodeError("invalid plural operand", op.mark)
+        return plural_space
+
+    def __call__(self):
+        op = self.state.encode(self.binding.op)
+        space = self.state.relate(self.binding.base)
+        plural_space = self.aggregate(op, space)
+        aggregate = AggregateUnit(op, plural_space, space, self.binding)
+        wrap = WrapAggregate(aggregate, self.state)
+        wrapper = wrap()
+        wrapper = ScalarUnit(wrapper, space, self.binding)
+        return wrapper
+
+
+class WrapAggregate(Adapter):
+
+    adapts(Signature)
+
+    @classmethod
+    def dispatch(cls, unit, *args, **kwds):
+        assert isinstance(unit, AggregateUnit)
+        if not isinstance(unit.code, FormulaCode):
+            return (Signature,)
+        return (type(unit.code.signature),)
+
+    def __init__(self, unit, state):
+        assert isinstance(unit, AggregateUnit)
+        assert isinstance(state, EncodingState)
+        self.unit = unit
+        self.state = state
+        self.code = unit.code
+
+    def __call__(self):
+        return self.unit
+
+
+class EncodeCount(EncodeBySignature):
+
+    adapts(CountSig)
+
+    def __call__(self):
+        op = self.state.encode(self.binding.op)
+        false_literal = LiteralCode(False, op.domain, op.binding)
+        op = FormulaCode(NullIfSig(), op.domain, op.binding,
+                         lop=op, rop=false_literal)
+        return FormulaCode(CountSig(), self.binding.domain, self.binding,
+                           op=op)
+
+
+class WrapCountSum(WrapAggregate):
+
+    adapts_many(CountSig, SumSig)
+
+    def __call__(self):
+        zero_literal = LiteralBinding('0', UntypedDomain(),
+                                      self.unit.syntax)
+        zero_literal = CastBinding(zero_literal, self.unit.domain,
+                                   self.unit.syntax)
+        zero_literal = self.state.encode(zero_literal)
+        return FormulaCode(IfNullSig(), self.unit.domain, self.unit.binding,
+                           lop=self.unit, rop=zero_literal)
+
+
+class EncodeQuantify(EncodeAggregate):
+
+    adapts(QuantifySig)
+
+    def __call__(self):
+        op = self.state.encode(self.binding.op)
+        space = self.state.relate(self.binding.base)
+        plural_space = self.aggregate(op, space)
+        if self.signature.polarity < 0:
+            op = FormulaCode(NotSig(), op.domain, op.binding, op=op)
         plural_space = FilteredSpace(plural_space, op, self.binding)
-        op = LiteralCode(True, op.domain, self.binding)
-        aggregate = CorrelatedUnit(op, plural_space, space,
+        true_literal = LiteralCode(True, coerce(BooleanDomain()), self.binding)
+        aggregate = CorrelatedUnit(true_literal, plural_space, space,
                                    self.binding)
-        wrapper = FormulaCode(WrapExistsSig(), op.domain, self.binding,
+        wrapper = FormulaCode(ExistsSig(), op.domain, self.binding,
                               op=aggregate)
         if self.signature.polarity < 0:
             wrapper = FormulaCode(NotSig(), wrapper.domain, wrapper.binding,
         return wrapper
 
 
-class EncodeAggregate(EncodeBySignature):
-
-    adapts(AggregateSig)
-
-    def take(self, op):
-        return op
-
-    def wrap(self, op):
-        return op
-
-    def __call__(self):
-        op = self.take(self.state.encode(self.binding.op))
-        space = self.state.relate(self.binding.base)
-        plural_units = [unit for unit in op.units
-                             if not space.spans(unit.space)]
-        if not plural_units:
-            raise EncodeError("a plural operand is required", op.mark)
-        plural_spaces = []
-        for unit in plural_units:
-            if any(plural_space.dominates(unit.space)
-                   for plural_space in plural_spaces):
-                continue
-            plural_spaces = [plural_space
-                             for plural_space in plural_spaces
-                             if not unit.space.dominates(plural_space)]
-            plural_spaces.append(unit.space)
-        if len(plural_spaces) > 1:
-            raise EncodeError("invalid plural operand", op.mark)
-        plural_space = plural_spaces[0]
-        if not plural_space.spans(space):
-            raise EncodeError("invalid plural operand", op.mark)
-        aggregate = AggregateUnit(op, plural_space, space, self.binding)
-        wrapper = self.wrap(aggregate)
-        wrapper = ScalarUnit(wrapper, space, self.binding)
-        return wrapper
-
-
-class EncodeCount(EncodeAggregate):
-
-    adapts(CountSig)
-
-    def take(self, op):
-        false = LiteralCode(False, op.domain, op.binding)
-        op = FormulaCode(NullIfSig(), op.domain, op.binding,
-                         lop=op, rop=false)
-        return FormulaCode(TakeCountSig(), self.binding.domain, self.binding,
-                           op=op)
-
-    def wrap(self, op):
-        zero = LiteralBinding('0', UntypedDomain(), op.syntax)
-        zero = CastBinding(zero, op.domain, op.syntax)
-        zero = self.state.encode(zero)
-        return FormulaCode(IfNullSig(), op.domain, op.binding,
-                           lop=op, rop=zero)
-
-
-class EncodeMin(EncodeAggregate):
-
-    adapts(MinSig)
-
-    def take(self, op):
-        return FormulaCode(TakeMinSig(), self.binding.domain, self.binding,
-                           op=op)
-
-
-class EncodeMax(EncodeAggregate):
-
-    adapts(MaxSig)
-
-    def take(self, op):
-        return FormulaCode(TakeMaxSig(), self.binding.domain, self.binding,
-                           op=op)
-
-
-class EncodeSum(EncodeAggregate):
-
-    adapts(SumSig)
-
-    def take(self, op):
-        return FormulaCode(TakeSumSig(), self.binding.domain, self.binding,
-                           op=op)
-
-    def wrap(self, op):
-        zero = LiteralBinding('0', UntypedDomain(), op.syntax)
-        zero = CastBinding(zero, op.domain, op.syntax)
-        zero = self.state.encode(zero)
-        return FormulaCode(IfNullSig(), op.domain, op.binding,
-                           lop=op, rop=zero)
-
-
-class EncodeAvg(EncodeAggregate):
-
-    adapts(AvgSig)
-
-    def take(self, op):
-        return FormulaCode(TakeAvgSig(), self.binding.domain, self.binding,
-                           op=op)
-
-

src/htsql/tr/fn/signature.py

     pass
 
 
-class QuantifySig(PolarSig):
-
-    slots = [
-            Slot('base'),
-            Slot('op'),
-    ]
-
-
-class ExistsSig(QuantifySig):
-
-    def __init__(self):
-        super(ExistsSig, self).__init__(polarity=+1)
-
-
-class EverySig(QuantifySig):
-
-    def __init__(self):
-        super(EverySig, self).__init__(polarity=-1)
-
-
-class WrapExistsSig(UnarySig):
-    pass
-
-
 class AggregateSig(Signature):
 
     slots = [
     ]
 
 
-class CountSig(AggregateSig):
+class QuantifySig(AggregateSig, PolarSig):
     pass
 
 
-class TakeCountSig(UnarySig):
+class ExistsSig(UnarySig):
     pass
 
 
-class MinSig(AggregateSig):
+class CountSig(UnarySig):
     pass
 
 
-class TakeMinSig(UnarySig):
+class MinMaxSig(UnarySig, PolarSig):
     pass
 
 
-class MaxSig(AggregateSig):
+class SumSig(UnarySig):
     pass
 
 
-class TakeMaxSig(UnarySig):
+class AvgSig(UnarySig):
     pass
 
 
-class SumSig(AggregateSig):
-    pass
-
-
-class TakeSumSig(UnarySig):
-    pass
-
-
-class AvgSig(AggregateSig):
-    pass
-
-
-class TakeAvgSig(UnarySig):
-    pass
-
-