Commits

Kirill Simonov  committed 74c12ac

Added unary +/- operators, `round()` function.

Also binary operators no longer need to be decorated with underscores.

  • Participants
  • Parent commits fa949c8

Comments (0)

Files changed (5)

File src/htsql/tr/binder.py

 
     def bind(self, parent):
         name = self.syntax.symbol
-        if self.syntax.left is not None:
+        if self.syntax.left is None:
+            name = name+'_'
+        if self.syntax.right is None:
             name = '_'+name
-        if self.syntax.right is not None:
-            name = name+'_'
         function = self.binder.find_function(name)
         return function.bind_operator(self.syntax, parent)
 

File src/htsql/tr/fn/function.py

 
 class EqualityOperator(ProperFunction):
 
-    adapts(named['_=_'])
+    adapts(named['='])
 
     parameters = [
             Parameter('left'),
 
 class InequalityOperator(ProperFunction):
 
-    adapts(named['_!=_'])
+    adapts(named['!='])
 
     parameters = [
             Parameter('left'),
 
 class TotalEqualityOperator(ProperFunction):
 
-    adapts(named['_==_'])
+    adapts(named['=='])
 
     parameters = [
             Parameter('left'),
 
 class TotalInequalityOperator(ProperFunction):
 
-    adapts(named['_!==_'])
+    adapts(named['!=='])
 
     parameters = [
             Parameter('left'),
 
 class ConjunctionOperator(ProperFunction):
 
-    adapts(named['_&_'])
+    adapts(named['&'])
 
     parameters = [
             Parameter('left'),
 
 class DisjunctionOperator(ProperFunction):
 
-    adapts(named['_|_'])
+    adapts(named['|'])
 
     parameters = [
             Parameter('left'),
 
 class LessThanOperator(ComparisonOperator):
 
-    adapts(named['_<_'])
+    adapts(named['<'])
     direction = '<'
 
 
 class LessThanOrEqualOperator(ComparisonOperator):
 
-    adapts(named['_<=_'])
+    adapts(named['<='])
     direction = '<='
 
 
 class GreaterThanOperator(ComparisonOperator):
 
-    adapts(named['_>_'])
+    adapts(named['>'])
     direction = '>'
 
 
 class GreaterThanOrEqualOperator(ComparisonOperator):
 
-    adapts(named['_>=_'])
+    adapts(named['>='])
     direction = '>='
 
 
                                  direction=self.direction)
 
 
-class AdditionOperator(ProperFunction):
-
-    adapts(named['_+_'])
+class UnaryPlusOperator(ProperFunction):
+
+    adapts(named['+_'])
+
+    parameters = [
+            Parameter('value'),
+    ]
+
+    def correlate(self, value, syntax, parent):
+        Implementation = UnaryPlus.realize(value.domain)
+        plus = Implementation(value, self.binder, syntax, parent)
+        yield plus()
+
+
+class UnaryMinusOperator(ProperFunction):
+
+    adapts(named['-_'])
+
+    parameters = [
+            Parameter('value'),
+    ]
+
+    def correlate(self, value, syntax, parent):
+        Implementation = UnaryMinus.realize(value.domain)
+        minus = Implementation(value, self.binder, syntax, parent)
+        yield minus()
+
+
+class SubtractionOperator(ProperFunction):
+
+    adapts(named['-'])
 
     parameters = [
             Parameter('left'),
     ]
 
     def correlate(self, left, right, syntax, parent):
-        Implementation = Add.realize(left.domain, right.domain)
-        add = Implementation(left, right, self.binder, syntax, parent)
-        yield add()
-
-
-class SubtractionOperator(ProperFunction):
-
-    adapts(named['_-_'])
+        Implementation = Subtract.realize(left.domain, right.domain)
+        subtract = Implementation(left, right, self.binder, syntax, parent)
+        yield subtract()
+
+
+class AdditionOperator(ProperFunction):
+
+    adapts(named['+'])
 
     parameters = [
             Parameter('left'),
     ]
 
     def correlate(self, left, right, syntax, parent):
-        Implementation = Subtract.realize(left.domain, right.domain)
-        subtract = Implementation(left, right, self.binder, syntax, parent)
-        yield subtract()
-
-
-class MultiplicationOperator(ProperFunction):
-
-    adapts(named['_*_'])
+        Implementation = Add.realize(left.domain, right.domain)
+        add = Implementation(left, right, self.binder, syntax, parent)
+        yield add()
+
+
+class SubtractionOperator(ProperFunction):
+
+    adapts(named['-'])
 
     parameters = [
             Parameter('left'),
     ]
 
     def correlate(self, left, right, syntax, parent):
-        Implementation = Multiply.realize(left.domain, right.domain)
-        multiply = Implementation(left, right, self.binder, syntax, parent)
-        yield multiply()
-
-
-class DivisionOperator(ProperFunction):
-
-    adapts(named['_/_'])
+        Implementation = Subtract.realize(left.domain, right.domain)
+        subtract = Implementation(left, right, self.binder, syntax, parent)
+        yield subtract()
+
+
+class MultiplicationOperator(ProperFunction):
+
+    adapts(named['*'])
 
     parameters = [
             Parameter('left'),
     ]
 
     def correlate(self, left, right, syntax, parent):
+        Implementation = Multiply.realize(left.domain, right.domain)
+        multiply = Implementation(left, right, self.binder, syntax, parent)
+        yield multiply()
+
+
+class DivisionOperator(ProperFunction):
+
+    adapts(named['/'])
+
+    parameters = [
+            Parameter('left'),
+            Parameter('right'),
+    ]
+
+    def correlate(self, left, right, syntax, parent):
         Implementation = Divide.realize(left.domain, right.domain)
         divide = Implementation(left, right, self.binder, syntax, parent)
         yield divide()
 
 
+class UnaryPlus(Adapter):
+
+    adapts(Domain)
+
+    def __init__(self, value, binder, syntax, parent):
+        self.value = value
+        self.binder = binder
+        self.syntax = syntax
+        self.parent = parent
+
+    def __call__(self):
+        raise InvalidArgumentError("unexpected argument type",
+                                   self.syntax.mark)
+
+
+class UnaryMinus(Adapter):
+
+    adapts(Domain)
+
+    def __init__(self, value, binder, syntax, parent):
+        self.value = value
+        self.binder = binder
+        self.syntax = syntax
+        self.parent = parent
+
+    def __call__(self):
+        raise InvalidArgumentError("unexpected argument type",
+                                   self.syntax.mark)
+
+
 class Add(Adapter):
 
     adapts(Domain, Domain)
     adapts(UntypedDomain, UntypedDomain)
 
 
+UnaryPlusBinding = GenericBinding.factory(UnaryPlusOperator)
+UnaryPlusExpression = GenericExpression.factory(UnaryPlusOperator)
+UnaryPlusPhrase = GenericPhrase.factory(UnaryPlusOperator)
+
+
+EncodeUnaryPlus = GenericEncode.factory(UnaryPlusOperator,
+        UnaryPlusBinding, UnaryPlusExpression)
+EvaluateUnaryPlus = GenericEvaluate.factory(UnaryPlusOperator,
+        UnaryPlusExpression, UnaryPlusPhrase)
+SerializeUnaryPlus = GenericSerialize.factory(UnaryPlusOperator,
+        UnaryPlusPhrase, "(+ %(value)s)")
+
+
+class UnaryPlusForNumber(UnaryPlus):
+
+    adapts(NumberDomain)
+
+    def __call__(self):
+        return UnaryPlusBinding(self.parent, self.value.domain, self.syntax,
+                                value=self.value)
+
+
+UnaryMinusBinding = GenericBinding.factory(UnaryMinusOperator)
+UnaryMinusExpression = GenericExpression.factory(UnaryMinusOperator)
+UnaryMinusPhrase = GenericPhrase.factory(UnaryMinusOperator)
+
+
+EncodeUnaryMinus = GenericEncode.factory(UnaryMinusOperator,
+        UnaryMinusBinding, UnaryMinusExpression)
+EvaluateUnaryMinus = GenericEvaluate.factory(UnaryMinusOperator,
+        UnaryMinusExpression, UnaryMinusPhrase)
+SerializeUnaryMinus = GenericSerialize.factory(UnaryMinusOperator,
+        UnaryMinusPhrase, "(- %(value)s)")
+
+
+class UnaryMinusForNumber(UnaryMinus):
+
+    adapts(NumberDomain)
+
+    def __call__(self):
+        return UnaryMinusBinding(self.parent, self.value.domain, self.syntax,
+                                value=self.value)
+
+
 AdditionBinding = GenericBinding.factory(AdditionOperator)
 AdditionExpression = GenericExpression.factory(AdditionOperator)
 AdditionPhrase = GenericPhrase.factory(AdditionOperator)
     domain = FloatDomain()
 
 
+class RoundFunction(ProperFunction):
+
+    adapts(named['round'])
+
+    parameters = [
+            Parameter('value'),
+            Parameter('digits', IntegerDomain, is_mandatory=False),
+    ]
+
+    def correlate(self, value, digits, syntax, parent):
+        Implementation = Round.realize(value.domain)
+        round = Implementation(value, digits, self.binder, syntax, parent)
+        yield round()
+
+
+class Round(Adapter):
+
+    adapts(Domain)
+
+    def __init__(self, value, digits, binder, syntax, parent):
+        self.value = value
+        self.digits = digits
+        self.binder = binder
+        self.syntax = syntax
+        self.parent = parent
+
+    def __call__(self):
+        raise InvalidArgumentError("unexpected argument types",
+                                   self.syntax.mark)
+
+
+class RoundDecimal(Round):
+
+    adapts_none()
+
+    def __call__(self):
+        value = self.binder.cast(self.value, DecimalDomain(),
+                                 parent=self.parent)
+        digits = self.digits
+        if digits is None:
+            digits = LiteralBinding(self.parent, 0, IntegerDomain(),
+                                    self.syntax)
+        return RoundBinding(self.parent, DecimalDomain(), self.syntax,
+                            value=value, digits=digits)
+
+
+class RoundDecimalFromInteger(RoundDecimal):
+
+    adapts(IntegerDomain)
+
+
+class RoundDecimalFromDecimal(RoundDecimal):
+
+    adapts(DecimalDomain)
+
+
+class RoundFloat(Round):
+
+    adapts(FloatDomain)
+
+    def __call__(self):
+        if self.digits is not None:
+            raise InvalidArgumentError("unexpected argument", self.digits.mark)
+        return RoundBinding(self.parent, FloatDomain(), self.syntax,
+                            value=self.value, digits=None)
+
+
+RoundBinding = GenericBinding.factory(RoundFunction)
+RoundExpression = GenericExpression.factory(RoundFunction)
+RoundPhrase = GenericPhrase.factory(RoundFunction)
+
+
+EncodeRound = GenericEncode.factory(RoundFunction,
+        RoundBinding, RoundExpression)
+EvaluateRound = GenericEvaluate.factory(RoundFunction,
+        RoundExpression, RoundPhrase)
+
+
+class SerializeRound(Serialize):
+
+    adapts(RoundPhrase, Serializer)
+
+    def serialize(self):
+        value = self.serializer.serialize(self.phrase.value)
+        digits = None
+        if self.phrase.digits is not None:
+            digits = self.serializer.serialize(self.phrase.digits)
+        return self.format.round_fn(value, digits)
+
+
 class IsNullFunction(ProperFunction):
 
     adapts(named['is_null'])
     def coalesce_fn(self, arguments):
         return "COALESCE(%s)" % ", ".join(arguments)
 
+    def round_fn(self, value, digits=None):
+        if digits is None:
+            return "ROUND(%s)" % value
+        else:
+            return "ROUND(%s, %s)" % (value, digits)
+
     def if_fn(self, predicates, values):
         assert len(predicates) >= 1
         assert len(values)-1 <= len(predicates) <= len(values)

File src/htsql/tr/parser.py

         while symbol_tokens:
             symbol_token = symbol_tokens.pop()
             symbol = symbol_token.value
-            mark = Mark.union(symbol_token, test)
+            mark = Mark.union(symbol_token, expression)
             expression = OperatorSyntax(symbol, None, expression, mark)
         return expression
 

File test/input/pgsql.yaml

         # Invalid String->Float cast.
         - uri: /{float(string('X'))}
           expect: 409
+        # Unary plus and minus.
+        - uri: /{+2,+2.0,+2e0,-2,-2.0,-2e0,++1,+-1,-+1,--1}
         # Addition.
         - uri: /{2+2,2+2.0,2+2e0,2.0+2.0,2.0+2e0,2e0+2e0}
         # Subtraction.
         # Multiplication: overflow.
         - uri: /{65536*65536}
           expect: 409
+        # Round for decimal values.
+        - uri: /{round(65.536),round(65.536,0),
+                 round(65.536,1),round(65.536,-1)}
+        # Round with integer values (implicitly cast to decimal).
+        - uri: /{round(65535),round(65536,-3)}
+        # Round for float values.
+        - uri: /{round(35536e-3)}
+        # Invalid Round call with float values and digits indicator.
+        - uri: /{round(35536e-3,1)}
+          expect: 400
 
   # Simple (non-aggregate) filters.
   - title: Simple filters

File test/output/pgsql.yaml

             :
                 /{float(string('X'))}
                 ^^^^^^^^^^^^^^^^^^^^^
+        - uri: /{+2,+2.0,+2e0,-2,-2.0,-2e0,++1,+-1,-+1,--1}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{+2,+2.0,+2e0,-2,-2.0,-2e0,++1,+-1,-+1,--1}                |
+            -+-------------------------------------------------------------+-
+             | +2 | +2.0 | +2e0 | -2 | -2.0 | -2e0 | ++1 | +-1 | -+1 | --1 |
+            -+----+------+------+----+------+------+-----+-----+-----+-----+-
+             |  2 |  2.0 |  2.0 | -2 | -2.0 | -2.0 |   1 |  -1 |  -1 |   1 |
+                                                                     (1 row)
+
+             ----
+             /{+2,+2.0,+2e0,-2,-2.0,-2e0,++1,+-1,-+1,--1}
+             SELECT (+ 2), (+ 2.0), (+ 2.0::float8), (- 2), (- 2.0), (- 2.0::float8), (+ (+ 1)), (+ (- 1)), (- (+ 1)), (- (- 1))
         - uri: /{2+2,2+2.0,2+2e0,2.0+2.0,2.0+2e0,2e0+2e0}
           status: 200 OK
           headers:
             :
                 /{65536*65536}
                 ^^^^^^^^^^^^^^
+        - uri: /{round(65.536),round(65.536,0), round(65.536,1),round(65.536,-1)}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{round(65.536),round(65.536,0),round(65.536,1),round(65.536,-1)}    |
+            -+----------------------------------------------------------------------+-
+             | round(65.536) | round(65.536,0) | round(65.536,1) | round(65.536,-1) |
+            -+---------------+-----------------+-----------------+------------------+-
+             |            66 |              66 |            65.5 |               70 |
+                                                                              (1 row)
+
+             ----
+             /{round(65.536),round(65.536,0),round(65.536,1),round(65.536,-1)}
+             SELECT ROUND(65.536, 0), ROUND(65.536, 0), ROUND(65.536, 1), ROUND(65.536, (- 1))
+        - uri: /{round(65535),round(65536,-3)}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{round(65535),round(65536,-3)} |
+            -+---------------------------------+-
+             | round(65535)  | round(65536,-3) |
+            -+---------------+-----------------+-
+             |         65535 |           66000 |
+                                         (1 row)
+
+             ----
+             /{round(65535),round(65536,-3)}
+             SELECT ROUND(CAST(65535 AS NUMERIC), 0), ROUND(CAST(65536 AS NUMERIC), (- 3))
+        - uri: /{round(35536e-3)}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{round(35536e-3)} |
+            -+--------------------+-
+             | round(35536e-3)    |
+            -+--------------------+-
+             |               36.0 |
+                            (1 row)
+
+             ----
+             /{round(35536e-3)}
+             SELECT ROUND(35.536::float8)
+        - uri: /{round(35536e-3,1)}
+          status: 400 Bad Request
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |
+            invalid argument: unexpected argument:
+                /{round(35536e-3,1)}
+                                 ^
   - id: simple-filters
     tests:
     - uri: /school?code='ns'