Commits

Kirill Simonov committed 8fad76b

Added total equality/inequality operator.

Also added some tests for the regular equality operator.

  • Participants
  • Parent commits b1a611f

Comments (0)

Files changed (9)

File src/htsql/tr/binding.py

         self.right = right
 
 
+class TotalEqualityBinding(Binding):
+
+    def __init__(self, parent, left, right, syntax):
+        assert isinstance(left, Binding)
+        assert isinstance(right, Binding)
+        domain = BooleanDomain()
+        super(TotalEqualityBinding, self).__init__(parent, domain, syntax)
+        self.left = left
+        self.right = right
+
+
+class TotalInequalityBinding(Binding):
+
+    def __init__(self, parent, left, right, syntax):
+        assert isinstance(left, Binding)
+        assert isinstance(right, Binding)
+        domain = BooleanDomain()
+        super(TotalInequalityBinding, self).__init__(parent, domain, syntax)
+        self.left = left
+        self.right = right
+
+
 class ConjunctionBinding(Binding):
 
     def __init__(self, parent, terms, syntax):

File src/htsql/tr/code.py

         return self.left.get_units()+self.right.get_units()
 
 
+class TotalEqualityExpression(Expression):
+
+    def __init__(self, left, right, mark):
+        assert isinstance(left, Expression)
+        assert isinstance(right, Expression)
+        domain = BooleanDomain()
+        super(TotalEqualityExpression, self).__init__(domain, mark,
+                                hash=(self.__class__, left.hash, right.hash))
+        self.left = left
+        self.right = right
+
+    def get_units(self):
+        return self.left.get_units()+self.right.get_units()
+
+
+class TotalInequalityExpression(Expression):
+
+    def __init__(self, left, right, mark):
+        assert isinstance(left, Expression)
+        assert isinstance(right, Expression)
+        domain = BooleanDomain()
+        super(TotalInequalityExpression, self).__init__(domain, mark,
+                                hash=(self.__class__, left.hash, right.hash))
+        self.left = left
+        self.right = right
+
+    def get_units(self):
+        return self.left.get_units()+self.right.get_units()
+
+
 class ConjunctionExpression(Expression):
 
     def __init__(self, terms, mark):

File src/htsql/tr/compiler.py

 from ..util import listof
 from ..adapter import Adapter, adapts, find_adapters
 from .code import (Expression, LiteralExpression, EqualityExpression,
-                   InequalityExpression, ConjunctionExpression,
+                   InequalityExpression, TotalEqualityExpression,
+                   TotalInequalityExpression, ConjunctionExpression,
                    DisjunctionExpression, NegationExpression,
                    CastExpression, TupleExpression, Unit)
 from .sketch import (Sketch, LeafSketch, ScalarSketch, BranchSketch,
                      LeafAppointment, BranchAppointment, FrameAppointment)
 from .frame import (LeafFrame, ScalarFrame, BranchFrame, CorrelatedFrame,
                     SegmentFrame, QueryFrame, Link, Phrase, EqualityPhrase,
-                    InequalityPhrase, ConjunctionPhrase, DisjunctionPhrase,
-                    NegationPhrase, LiteralPhrase, CastPhrase, TuplePhrase,
+                    InequalityPhrase, TotalEqualityPhrase,
+                    TotalInequalityPhrase, ConjunctionPhrase,
+                    DisjunctionPhrase, NegationPhrase, LiteralPhrase,
+                    CastPhrase, TuplePhrase,
                     LeafReferencePhrase, BranchReferencePhrase)
 
 
         return InequalityPhrase(left, right, self.expression.mark)
 
 
+class EvaluateTotalEquality(Evaluate):
+
+    adapts(TotalEqualityExpression, Compiler)
+
+    def evaluate(self, references):
+        left = self.compiler.evaluate(self.expression.left, references)
+        right = self.compiler.evaluate(self.expression.right, references)
+        return TotalEqualityPhrase(left, right, self.expression.mark)
+
+
+class EvaluateTotalInequality(Evaluate):
+
+    adapts(TotalInequalityExpression, Compiler)
+
+    def evaluate(self, references):
+        left = self.compiler.evaluate(self.expression.left, references)
+        right = self.compiler.evaluate(self.expression.right, references)
+        return TotalInequalityPhrase(left, right, self.expression.mark)
+
+
 class EvaluateConjunction(Evaluate):
 
     adapts(ConjunctionExpression, Compiler)

File src/htsql/tr/encoder.py

                       TableBinding, FreeTableBinding, JoinedTableBinding,
                       ColumnBinding, LiteralBinding, SieveBinding,
                       OrderedBinding, EqualityBinding, InequalityBinding,
+                      TotalEqualityBinding, TotalInequalityBinding,
                       ConjunctionBinding, DisjunctionBinding,
                       NegationBinding, CastBinding, TupleBinding)
 from .code import (ScalarSpace, FreeTableSpace, JoinedTableSpace,
                    ScreenSpace, OrderedSpace, LiteralExpression, ColumnUnit,
                    TupleExpression, QueryCode, SegmentCode, ElementExpression,
                    EqualityExpression, InequalityExpression,
+                   TotalEqualityExpression, TotalInequalityExpression,
                    ConjunctionExpression, DisjunctionExpression,
                    NegationExpression, CastExpression)
 from .lookup import Lookup
         return InequalityExpression(left, right, self.binding.mark)
 
 
+class EncodeTotalEquality(Encode):
+
+    adapts(TotalEqualityBinding, Encoder)
+
+    def encode(self):
+        left = self.encoder.encode(self.binding.left)
+        right = self.encoder.encode(self.binding.right)
+        return TotalEqualityExpression(left, right, self.binding.mark)
+
+
+class EncodeTotalInequality(Encode):
+
+    adapts(TotalInequalityBinding, Encoder)
+
+    def encode(self):
+        left = self.encoder.encode(self.binding.left)
+        right = self.encoder.encode(self.binding.right)
+        return TotalInequalityExpression(left, right, self.binding.mark)
+
+
 class EncodeConjunction(Encode):
 
     adapts(ConjunctionBinding, Encoder)

File src/htsql/tr/fn/function.py

                        IntegerDomain, DecimalDomain, FloatDomain, DateDomain)
 from ..binding import (LiteralBinding, OrderedBinding, FunctionBinding,
                        EqualityBinding, InequalityBinding,
+                       TotalEqualityBinding, TotalInequalityBinding,
                        ConjunctionBinding, DisjunctionBinding, NegationBinding)
 from ..encoder import Encoder, Encode
 from ..code import FunctionExpression, AggregateUnit
         domain = self.binder.coerce(left.domain,
                                     right.domain)
         if domain is None:
-            raise InvalidArgumentError("invalid arguments", syntax.mark)
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
         domain = self.binder.coerce(domain)
         if domain is None:
-            raise InvalidArgumentError("invalid arguments", syntax.mark)
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
         left = self.binder.cast(left, domain)
         right = self.binder.cast(right, domain)
         yield EqualityBinding(parent, left, right, syntax)
         domain = self.binder.coerce(left.domain,
                                     right.domain)
         if domain is None:
-            raise InvalidArgumentError("invalid arguments", syntax.mark)
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
         domain = self.binder.coerce(domain)
         if domain is None:
-            raise InvalidArgumentError("invalid arguments", syntax.mark)
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
         left = self.binder.cast(left, domain)
         right = self.binder.cast(right, domain)
         yield InequalityBinding(parent, left, right, syntax)
 
 
+class TotalEqualityOperator(ProperFunction):
+
+    adapts(named['_==_'])
+
+    parameters = [
+            Parameter('left'),
+            Parameter('right'),
+    ]
+
+    def correlate(self, left, right, syntax, parent):
+        domain = self.binder.coerce(left.domain,
+                                    right.domain)
+        if domain is None:
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
+        domain = self.binder.coerce(domain)
+        if domain is None:
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
+        left = self.binder.cast(left, domain)
+        right = self.binder.cast(right, domain)
+        yield TotalEqualityBinding(parent, left, right, syntax)
+
+
+class TotalInequalityOperator(ProperFunction):
+
+    adapts(named['_!==_'])
+
+    parameters = [
+            Parameter('left'),
+            Parameter('right'),
+    ]
+
+    def correlate(self, left, right, syntax, parent):
+        domain = self.binder.coerce(left.domain,
+                                    right.domain)
+        if domain is None:
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
+        domain = self.binder.coerce(domain)
+        if domain is None:
+            raise InvalidArgumentError("incompatible types",
+                                       syntax.mark)
+        left = self.binder.cast(left, domain)
+        right = self.binder.cast(right, domain)
+        yield TotalInequalityBinding(parent, left, right, syntax)
+
+
 class ConjunctionOperator(ProperFunction):
 
     adapts(named['_&_'])

File src/htsql/tr/frame.py

         self.right = right
 
 
+class TotalEqualityPhrase(Phrase):
+
+    def __init__(self, left, right, mark):
+        assert isinstance(left, Phrase)
+        assert isinstance(right, Phrase)
+        domain = BooleanDomain()
+        super(TotalEqualityPhrase, self).__init__(domain, False, mark)
+        self.left = left
+        self.right = right
+
+
+class TotalInequalityPhrase(Phrase):
+
+    def __init__(self, left, right, mark):
+        assert isinstance(left, Phrase)
+        assert isinstance(right, Phrase)
+        domain = BooleanDomain()
+        super(TotalInequalityPhrase, self).__init__(domain, False, mark)
+        self.left = left
+        self.right = right
+
+
 class ConjunctionPhrase(Phrase):
 
     def __init__(self, terms, mark):

File src/htsql/tr/serializer.py

 
 
 from ..adapter import Adapter, Utility, adapts, find_adapters
+from ..error import InvalidArgumentError
 from ..domain import (Domain, BooleanDomain, NumberDomain, IntegerDomain,
                       DecimalDomain, FloatDomain, StringDomain, DateDomain)
 from .frame import (Clause, Frame, LeafFrame, ScalarFrame,
                     BranchFrame, CorrelatedFrame, SegmentFrame,
                     QueryFrame, Phrase, EqualityPhrase, InequalityPhrase,
+                    TotalEqualityPhrase, TotalInequalityPhrase,
                     ConjunctionPhrase, DisjunctionPhrase, NegationPhrase,
                     CastPhrase, LiteralPhrase, LeafReferencePhrase,
                     BranchReferencePhrase, CorrelatedFramePhrase, TuplePhrase)
             op = "!="
         return self.binary_op(left, op, right)
 
+    def total_equal_op(self, left, right, is_negative=False):
+        op = "IS NOT DISTINCT FROM"
+        if is_negative:
+            op = "IS DISTINCT FROM"
+        return self.binary_op(left, op, right)
+
     def to_boolean(self, value):
         return "(%s IS NOT NULL)" % value
 
     def to_boolean_from_string(self, value):
         return "(NULLIF(%s, '') IS NOT NULL)" % value
 
+    def to_integer(self, value):
+        return "CAST(%s AS INTEGER)" % value
+
+    def to_decimal(self, value):
+        return "CAST(%s AS NUMERIC)" % value
+
+    def to_float(self, value):
+        return "CAST(%s AS FLOAT)" % value
+
     def is_null(self, arg):
         return "(%s IS NULL)" % arg
 
         return self.format.equal_op(left, right, is_negative=True)
 
 
+class SerializeTotalEquality(SerializePhrase):
+
+    adapts(TotalEqualityPhrase, Serializer)
+
+    def serialize(self):
+        left = self.serializer.serialize(self.phrase.left)
+        right = self.serializer.serialize(self.phrase.right)
+        return self.format.total_equal_op(left, right)
+
+
+class SerializeTotalInequality(SerializePhrase):
+
+    adapts(TotalInequalityPhrase, Serializer)
+
+    def serialize(self):
+        left = self.serializer.serialize(self.phrase.left)
+        right = self.serializer.serialize(self.phrase.right)
+        return self.format.total_equal_op(left, right, is_negative=True)
+
+
 class SerializeConjunction(SerializePhrase):
 
     adapts(ConjunctionPhrase, Serializer)
         self.serializer = serializer
         self.format = serializer.format
 
+    def serialize(self, phrase):
+        raise InvalidArgumentError("unable to cast", phrase.mark)
+
 
 class SerializeToBooleanFromString(SerializeTo):
 
         return self.format.to_boolean(value)
 
 
+class SerializeToInteger(SerializeTo):
+
+    adapts(IntegerDomain, Domain, Serializer)
+
+    def serialize(self, phrase):
+        value = self.serializer.serialize(phrase)
+        return self.format.to_integer(value)
+
+
+class SerializeToDecimal(SerializeTo):
+
+    adapts(DecimalDomain, Domain, Serializer)
+
+    def serialize(self, phrase):
+        value = self.serializer.serialize(phrase)
+        return self.format.to_decimal(value)
+
+
+class SerializeToFloat(SerializeTo):
+
+    adapts(FloatDomain, Domain, Serializer)
+
+    def serialize(self, phrase):
+        value = self.serializer.serialize(phrase)
+        return self.format.to_float(value)
+
+
 class SerializeLiteral(SerializePhrase):
 
     adapts(LiteralPhrase, Serializer)

File test/input/pgsql.yaml

                  switch(null(),null(),1,0),
                  switch('Y','X',1,'Y',2,'Z',3),
                  switch('Y','A',1,'B',2,'C',3,0)}
+        # Equality/Inequality
+        - uri: /{1=1,1=0,1=null(),null()=null(),
+                 1!=1,1!=0,1!=null(),null()!=null(),
+                 1==1,1==0,1==null(),null()==null(),
+                 1!==1,1!==0,1!==null(),null()!==null()}
+        - uri: /{'X'='X',1=1.0,1=1e0,1.0=1e0,1='1'}
+        - uri: /{integer('1')=string('1')}
+          expect: 400
 
   # Simple (non-aggregate) filters.
   - title: Simple filters

File test/output/pgsql.yaml

              ----
              /{switch('Y','X',1),switch('Y','Y',1),switch('Y','X',1,0),switch('Y','Y',1,0),switch(null(),null(),1,0),switch('Y','X',1,'Y',2,'Z',3),switch('Y','A',1,'B',2,'C',3,0)}
              SELECT (CASE 'Y' WHEN 'X' THEN 1 END), (CASE 'Y' WHEN 'Y' THEN 1 END), (CASE 'Y' WHEN 'X' THEN 1 ELSE 0 END), (CASE 'Y' WHEN 'Y' THEN 1 ELSE 0 END), (CASE NULL WHEN NULL THEN 1 ELSE 0 END), (CASE 'Y' WHEN 'X' THEN 1 WHEN 'Y' THEN 2 WHEN 'Z' THEN 3 END), (CASE 'Y' WHEN 'A' THEN 1 WHEN 'B' THEN 2 WHEN 'C' THEN 3 ELSE 0 END)
+        - uri: /{1=1,1=0,1=null(),null()=null(), 1!=1,1!=0,1!=null(),null()!=null(),
+            1==1,1==0,1==null(),null()==null(), 1!==1,1!==0,1!==null(),null()!==null()}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{1=1,1=0,1=null(),null()=null(),1!=1,1!=0,1!=null(),null()!=null(),1==1,1==0,1==null(),null()==null(),1!==1,1!==0,1!==null(),null()!==null()}                                 |
+            -+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-
+             | 1=1  | 1=0   | 1=null() | null()=null() | 1!=1  | 1!=0 | 1!=null() | null()!=null() | 1==1 | 1==0  | 1==null() | null()==null() | 1!==1 | 1!==0 | 1!==null() | null()!==null() |
+            -+------+-------+----------+---------------+-------+------+-----------+----------------+------+-------+-----------+----------------+-------+-------+------------+-----------------+-
+             | true | false |          |               | false | true |           |                | true | false | false     | true           | false | true  | true       | false           |
+                                                                                                                                                                                        (1 row)
+
+             ----
+             /{1=1,1=0,1=null(),null()=null(),1!=1,1!=0,1!=null(),null()!=null(),1==1,1==0,1==null(),null()==null(),1!==1,1!==0,1!==null(),null()!==null()}
+             SELECT (1 = 1), (1 = 0), (1 = NULL), (NULL = NULL), (1 != 1), (1 != 0), (1 != NULL), (NULL != NULL), (1 IS NOT DISTINCT FROM 1), (1 IS NOT DISTINCT FROM 0), (1 IS NOT DISTINCT FROM NULL), (NULL IS NOT DISTINCT FROM NULL), (1 IS DISTINCT FROM 1), (1 IS DISTINCT FROM 0), (1 IS DISTINCT FROM NULL), (NULL IS DISTINCT FROM NULL)
+        - uri: /{'X'='X',1=1.0,1=1e0,1.0=1e0,1='1'}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{'X'='X',1=1.0,1=1e0,1.0=1e0,1='1'}      |
+            -+-------------------------------------------+-
+             | 'X'='X' | 1=1.0 | 1=1e0 | 1.0=1e0 | 1='1' |
+            -+---------+-------+-------+---------+-------+-
+             | true    | true  | true  | true    | true  |
+                                                   (1 row)
+
+             ----
+             /{'X'='X',1=1.0,1=1e0,1.0=1e0,1='1'}
+             SELECT ('X' = 'X'), (CAST(1 AS NUMERIC) = 1.0), (CAST(1 AS FLOAT) = 1.0::float8), (CAST(1.0 AS FLOAT) = 1.0::float8), (1 = 1)
+        - uri: /{integer('1')=string('1')}
+          status: 400 Bad Request
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |
+            invalid argument: incompatible types:
+                /{integer('1')=string('1')}
+                  ^^^^^^^^^^^^^^^^^^^^^^^^
   - id: simple-filters
     tests:
     - uri: /school?code='ns'