Commits

Kirill Simonov  committed cc30973

Permit the second operand of the equality/inequality operator to be a list:
serialized to `IN/NOT IN` clause.

  • Participants
  • Parent commits 07aa4df

Comments (0)

Files changed (3)

File src/htsql/tr/fn/function.py

     named('=')
 
     parameters = [
-            Parameter('left'),
-            Parameter('right'),
+            Parameter('lop'),
+            Parameter('rops', is_list=True),
     ]
 
-    def correlate(self, left, right):
-        domain = coerce(left.domain, right.domain)
+    def correlate(self, lop, rops):
+        domain = coerce(lop.domain, *(rop.domain for rop in rops))
         if domain is None:
             raise InvalidArgumentError("incompatible types",
                                        self.syntax.mark)
-        left = CastBinding(left, domain, left.syntax)
-        right = CastBinding(right, domain, right.syntax)
-        yield EqualityBinding(left, right, self.syntax)
+        lop = CastBinding(lop, domain, lop.syntax)
+        rops = [CastBinding(rop, domain, rop.syntax) for rop in rops]
+        if len(rops) == 1:
+            yield EqualityBinding(lop, rops[0], self.syntax)
+        else:
+            yield AmongBinding(coerce(BooleanDomain()), self.syntax,
+                               lop=lop, rops=rops)
 
 
 class InequalityOperator(ProperFunction):
     named('!=')
 
     parameters = [
-            Parameter('left'),
-            Parameter('right'),
+            Parameter('lop'),
+            Parameter('rops', is_list=True),
     ]
 
-    def correlate(self, left, right):
-        domain = coerce(left.domain, right.domain)
+    def correlate(self, lop, rops):
+        domain = coerce(lop.domain, *(rop.domain for rop in rops))
         if domain is None:
             raise InvalidArgumentError("incompatible types",
                                        self.syntax.mark)
-        left = CastBinding(left, domain, left.syntax)
-        right = CastBinding(right, domain, right.syntax)
-        yield NegationBinding(EqualityBinding(left, right, self.syntax),
-                              self.syntax)
+        lop = CastBinding(lop, domain, lop.syntax)
+        rops = [CastBinding(rop, domain, rop.syntax) for rop in rops]
+        if len(rops) == 1:
+            yield NegationBinding(EqualityBinding(lop, rops[0], self.syntax),
+                                  self.syntax)
+        else:
+            yield NotAmongBinding(coerce(BooleanDomain()), self.syntax,
+                                  lop=lop, rops=rops)
 
 
 class TotalEqualityOperator(ProperFunction):
         " LPAD(CAST(%(day)s AS TEXT), 2, '0') AS DATE)")
 
 
+AmongBinding = GenericBinding.factory(EqualityOperator)
+AmongExpression = GenericExpression.factory(EqualityOperator)
+AmongPhrase = GenericPhrase.factory(EqualityOperator)
+
+
+EncodeAmong = GenericEncode.factory(EqualityOperator,
+        AmongBinding, AmongExpression)
+EvaluateAmong = GenericEvaluate.factory(EqualityOperator,
+        AmongExpression, AmongPhrase)
+ReduceAmong = GenericReduce.factory(EqualityOperator,
+        AmongPhrase)
+
+
+class SerializeAmong(Serialize):
+
+    adapts(AmongPhrase, Serializer)
+
+    def serialize(self):
+        lop = self.serializer.serialize(self.phrase.lop)
+        rops = [self.serializer.serialize(rop) for rop in self.phrase.rops]
+        return self.format.among(lop, rops)
+
+
+NotAmongBinding = GenericBinding.factory(InequalityOperator)
+NotAmongExpression = GenericExpression.factory(InequalityOperator)
+NotAmongPhrase = GenericPhrase.factory(InequalityOperator)
+
+
+EncodeNotAmong = GenericEncode.factory(InequalityOperator,
+        NotAmongBinding, NotAmongExpression)
+EvaluateNotAmong = GenericEvaluate.factory(InequalityOperator,
+        NotAmongExpression, NotAmongPhrase)
+ReduceNotAmong = GenericReduce.factory(InequalityOperator,
+        NotAmongPhrase)
+
+
+class SerializeNotAmong(Serialize):
+
+    adapts(NotAmongPhrase, Serializer)
+
+    def serialize(self):
+        lop = self.serializer.serialize(self.phrase.lop)
+        rops = [self.serializer.serialize(rop) for rop in self.phrase.rops]
+        return self.format.not_among(lop, rops)
+
+
 ComparisonBinding = GenericBinding.factory(ComparisonOperator)
 ComparisonExpression = GenericExpression.factory(ComparisonOperator)
 ComparisonPhrase = GenericPhrase.factory(ComparisonOperator)
     def coalesce_fn(self, arguments):
         return "COALESCE(%s)" % ", ".join(arguments)
 
+    def among(self, lop, rops):
+        return "(%s IN (%s))" % (lop, ", ".join(rops))
+
+    def not_among(self, lop, rops):
+        return "(%s NOT IN (%s))" % (lop, ", ".join(rops))
+
     def round_fn(self, value, digits=None):
         if digits is None:
             return "ROUND(%s)" % value

File test/input/pgsql.yaml

                  1==1,1==0,1==null(),null()==null(),
                  1!==1,1!==0,1!==null(),null()!==null()}
         - uri: /{'X'='X',1=1.0,1=1e0,1.0=1e0,1='1'}
+        - uri: /{0={1,2,3},2={1,2,3},0!={1,2,3},2!={1,2,3}}
+        - uri: /{'X'!={'A','B','C'},1.0={0,1,2}}
         - uri: /{integer('1')=string('1')}
           expect: 400
         # Less Than/Greater Than.
     tests:
     - uri: /school?code='ns'
     - uri: /department?school.code='ns'
+    - uri: /department?school.code={'art','la'}
     - uri: /program?school.code='ns'&code='uchem'
     - uri: /course?credits=5
     # ENUM literal.

File test/output/pgsql.yaml

              ----
              /{'X'='X',1=1.0,1=1e0,1.0=1e0,1='1'}
              SELECT ('X' = 'X'), (CAST(1 AS NUMERIC) = 1.0), (CAST(1 AS FLOAT) = 1.0::float8), (CAST(1.0 AS FLOAT) = 1.0::float8), (1 = 1)
+        - uri: /{0={1,2,3},2={1,2,3},0!={1,2,3},2!={1,2,3}}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             |                                                 |
+            -+-------------------------------------------------+-
+             | 0={1,2,3} | 2={1,2,3} | 0!={1,2,3} | 2!={1,2,3} |
+            -+-----------+-----------+------------+------------+-
+             | false     | true      | true       | false      |
+                                                         (1 row)
+
+             ----
+             /{0={1,2,3},2={1,2,3},0!={1,2,3},2!={1,2,3}}
+             SELECT (0 IN (1, 2, 3)), (2 IN (1, 2, 3)), (0 NOT IN (1, 2, 3)), (2 NOT IN (1, 2, 3))
+        - uri: /{'X'!={'A','B','C'},1.0={0,1,2}}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             |                                  |
+            -+----------------------------------+-
+             | 'X'!={'A','B','C'} | 1.0={0,1,2} |
+            -+--------------------+-------------+-
+             | true               | true        |
+                                          (1 row)
+
+             ----
+             /{'X'!={'A','B','C'},1.0={0,1,2}}
+             SELECT ('X' NOT IN ('A', 'B', 'C')), (1.0 IN (CAST(0 AS NUMERIC), CAST(1 AS NUMERIC), CAST(2 AS NUMERIC)))
         - uri: /{integer('1')=string('1')}
           status: 400 Bad Request
           headers:
          ----
          /department?school.code='ns'
          SELECT "department"."code", "department"."name", "department"."school" FROM "ad"."department" AS "department" LEFT OUTER JOIN "ad"."school" AS "school_2" ON (("department"."school" = "school_2"."code")) WHERE ("school_2"."code" = 'ns') ORDER BY 1 ASC
+    - uri: /department?school.code={'art','la'}
+      status: 200 OK
+      headers:
+      - [Content-Type, text/plain; charset=UTF-8]
+      body: |2
+         | (department?school.code={'art','la'}) |
+        -+---------------------------------------+-
+         | code    | name               | school |
+        -+---------+--------------------+--------+-
+         | arthis  | Art History        | art    |
+         | eng     | English            | la     |
+         | hist    | History            | la     |
+         | lang    | Foreign Languages  | la     |
+         | poli    | Political Science  | la     |
+         | psych   | Psychology         | la     |
+         | stdart  | Studio Art         | art    |
+                                          (7 rows)
+
+         ----
+         /department?school.code={'art','la'}
+         SELECT "department"."code", "department"."name", "department"."school" FROM "ad"."department" AS "department" LEFT OUTER JOIN "ad"."school" AS "school_2" ON (("department"."school" = "school_2"."code")) WHERE ("school_2"."code" IN ('art', 'la')) ORDER BY 1 ASC
     - uri: /program?school.code='ns'&code='uchem'
       status: 200 OK
       headers: