Commits

Kirill Simonov committed b1a611f

Added functions `if()` and `switch()`.

Comments (0)

Files changed (3)

src/htsql/tr/fn/function.py

         return self.format.coalesce_fn(arguments)
 
 
+class IfFunction(ProperFunction):
+
+    adapts(named['if'])
+
+    parameters = [
+            Parameter('conditions', is_list=True),
+            Parameter('values', is_list=True),
+    ]
+
+    def bind_arguments(self, arguments, parent, mark):
+        conditions = []
+        values = []
+        for index, argument in enumerate(arguments):
+            argument = self.binder.bind_one(argument, parent)
+            if (index % 2 == 0) and index < len(arguments)-1:
+                conditions.append(argument)
+            else:
+                values.append(argument)
+        arguments = [conditions, values]
+        return self.check_arguments(arguments, mark)
+
+    def correlate(self, conditions, values, syntax, parent):
+        conditions = [self.binder.cast(condition, BooleanDomain())
+                      for condition in conditions]
+        domain = values[0].domain
+        for value in values[1:]:
+            domain = self.binder.coerce(domain, value.domain)
+            if domain is None:
+                raise InvalidArgumentError("unexpected domain", value.mark)
+        domain = self.binder.coerce(domain)
+        if domain is None:
+            raise InvalidArgumentError("unexpected domain", syntax.mark)
+        values = [self.binder.cast(value, domain) for value in values]
+        yield IfBinding(parent, domain, syntax,
+                        conditions=conditions, values=values)
+
+
+IfBinding = GenericBinding.factory(IfFunction)
+IfExpression = GenericExpression.factory(IfFunction)
+IfPhrase = GenericPhrase.factory(IfFunction)
+
+
+EncodeIf = GenericEncode.factory(IfFunction,
+        IfBinding, IfExpression)
+EvaluateIf = GenericEvaluate.factory(IfFunction,
+        IfExpression, IfPhrase,
+        is_null_regular=False)
+
+
+class SerializeIf(Serialize):
+
+    adapts(IfPhrase, Serializer)
+
+    def serialize(self):
+        conditions = [self.serializer.serialize(condition)
+                      for condition in self.phrase.conditions]
+        values = [self.serializer.serialize(value)
+                  for value in self.phrase.values]
+        return self.format.if_fn(conditions, values)
+
+
+class SwitchFunction(ProperFunction):
+
+    adapts(named['switch'])
+
+    parameters = [
+            Parameter('token'),
+            Parameter('items', is_list=True),
+            Parameter('values', is_list=True),
+    ]
+
+    def bind_arguments(self, arguments, parent, mark):
+        if not arguments:
+            return self.check_arguments([], mark)
+        token = self.binder.bind_one(arguments[0], parent)
+        items = []
+        values = []
+        for index, argument in enumerate(arguments[1:]):
+            argument = self.binder.bind_one(argument, parent)
+            if (index % 2 == 0) and index < len(arguments)-2:
+                items.append(argument)
+            else:
+                values.append(argument)
+        arguments = [[token], items, values]
+        return self.check_arguments(arguments, mark)
+
+    def correlate(self, token, items, values, syntax, parent):
+        token_domain = token.domain
+        for item in items:
+            token_domain = self.binder.coerce(token_domain, item.domain)
+            if token_domain is None:
+                raise InvalidArgumentError("unexpected domain", item.mark)
+        token_domain = self.binder.coerce(token_domain)
+        if token_domain is None:
+            raise InvalidArgumentError("unexpected domain", token.mark)
+        token = self.binder.cast(token, token_domain)
+        items = [self.binder.cast(item, token_domain) for item in items]
+        domain = values[0].domain
+        for value in values[1:]:
+            domain = self.binder.coerce(domain, value.domain)
+            if domain is None:
+                raise InvalidArgumentError("unexpected domain", value.mark)
+        domain = self.binder.coerce(domain)
+        if domain is None:
+            raise InvalidArgumentError("unexpected domain", syntax.mark)
+        values = [self.binder.cast(value, domain) for value in values]
+        yield SwitchBinding(parent, domain, syntax,
+                            token=token, items=items, values=values)
+
+
+SwitchBinding = GenericBinding.factory(SwitchFunction)
+SwitchExpression = GenericExpression.factory(SwitchFunction)
+SwitchPhrase = GenericPhrase.factory(SwitchFunction)
+
+
+EncodeSwitch = GenericEncode.factory(SwitchFunction,
+        SwitchBinding, SwitchExpression)
+EvaluateSwitch = GenericEvaluate.factory(SwitchFunction,
+        SwitchExpression, SwitchPhrase,
+        is_null_regular=False)
+
+
+class SerializeSwitch(Serialize):
+
+    adapts(SwitchPhrase, Serializer)
+
+    def serialize(self):
+        token = self.serializer.serialize(self.phrase.token)
+        items = [self.serializer.serialize(item)
+                 for item in self.phrase.items]
+        values = [self.serializer.serialize(value)
+                  for value in self.phrase.values]
+        return self.format.switch_fn(token, items, values)
+
+
 class FormatFunctions(Format):
 
     weights(0)
     def coalesce_fn(self, arguments):
         return "COALESCE(%s)" % ", ".join(arguments)
 
+    def if_fn(self, predicates, values):
+        assert len(predicates) >= 1
+        assert len(values)-1 <= len(predicates) <= len(values)
+        default = None
+        if len(predicates) == len(values)-1:
+            default = values.pop()
+        chunks = []
+        chunks.append('CASE')
+        for predicate, value in zip(predicates, values):
+            chunks.append('WHEN')
+            chunks.append(predicate)
+            chunks.append('THEN')
+            chunks.append(value)
+        if default is not None:
+            chunks.append('ELSE')
+            chunks.append(default)
+        chunks.append('END')
+        return "(%s)" % ' '.join(chunks)
+
+    def switch_fn(self, token, items, values):
+        assert len(items) >= 1
+        assert len(values)-1 <= len(items) <= len(values)
+        default = None
+        if len(items) == len(values)-1:
+            default = values.pop()
+        chunks = []
+        chunks.append('CASE')
+        chunks.append(token)
+        for item, value in zip(items, values):
+            chunks.append('WHEN')
+            chunks.append(item)
+            chunks.append('THEN')
+            chunks.append(value)
+        if default is not None:
+            chunks.append('ELSE')
+            chunks.append(default)
+        chunks.append('END')
+        return "(%s)" % ' '.join(chunks)
+
 
 class CountFunction(ProperFunction):
 

test/input/pgsql.yaml

 
     - title: Scalar functions
       tests:
-      - title: Boolean functions and operators
+      - title: Boolean constants and logical operators
         tests:
         # Boolean constants.
         - uri: /{true(),false()}
         - uri: /{!true(),!false(),!null()}
         # Auto-cast of arguments (true,false,false,true).
         - uri: /{!string(''),!string('X'),!1,!integer(null())}
+
+      - title: Comparison functions and operators
+        tests:
         # Is NULL function (null => true, otherwise => false).
         - uri: /{is_null(null()),is_null(true()),is_null(''),is_null(0)}
         # Null If method (`this` is equal to one of the arguments => null, otherwise => `this`).
                  null().if_null(null(),null(),null()),
                  null().if_null(null(),null(),0),
                  null().if_null(0,1,2,3,null())}
+        # If function (if(`cond1`,`then1`,[`cond2`,`then2`,...],[`else`])).
+        - uri: /{if(true(),1),if(false(),1),if(null(),1),
+                 if(true(),1,0),if(false(),1,0),if(null(),1,0),
+                 if(true(),1,true(),2),if(true(),1,false(),2),
+                 if(false(),1,true(),2),if(false(),1,false(),2),
+                 if(false(),1,false(),2,0)}
+        # Switch function (switch(`token`,`case1`,`then1`,[`case2`,`then2`,...],[`else`])).
+        - uri: /{switch('Y','X',1),switch('Y','Y',1),
+                 switch('Y','X',1,0),switch('Y','Y',1,0),
+                 switch(null(),null(),1,0),
+                 switch('Y','X',1,'Y',2,'Z',3),
+                 switch('Y','A',1,'B',2,'C',3,0)}
 
   # Simple (non-aggregate) filters.
   - title: Simple filters

test/output/pgsql.yaml

                 ^^^^^^
     - id: scalar-functions
       tests:
-      - id: boolean-functions-and-operators
+      - id: boolean-constants-and-logical-operators
         tests:
         - uri: /{true(),false()}
           status: 200 OK
              ----
              /{!string(''),!string('X'),!1,!integer(null())}
              SELECT (NOT (NULLIF('', '') IS NOT NULL)), (NOT (NULLIF('X', '') IS NOT NULL)), (NOT (1 IS NOT NULL)), (NOT (NULL IS NOT NULL))
+      - id: comparison-functions-and-operators
+        tests:
         - uri: /{is_null(null()),is_null(true()),is_null(''),is_null(0)}
           status: 200 OK
           headers:
              ----
              /{'X'.if_null('Y'),null().if_null('X'),null().if_null(null()),null().if_null(null(),null(),null()),null().if_null(null(),null(),0),null().if_null(0,1,2,3,null())}
              SELECT COALESCE('X', 'Y'), COALESCE(NULL, 'X'), COALESCE(NULL, NULL), COALESCE(NULL, NULL, NULL, NULL), COALESCE(NULL, NULL, NULL, 0), COALESCE(NULL, 0, 1, 2, 3, NULL)
+        - uri: /{if(true(),1),if(false(),1),if(null(),1), if(true(),1,0),if(false(),1,0),if(null(),1,0),
+            if(true(),1,true(),2),if(true(),1,false(),2), if(false(),1,true(),2),if(false(),1,false(),2),
+            if(false(),1,false(),2,0)}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{if(true(),1),if(false(),1),if(null(),1),if(true(),1,0),if(false(),1,0),if(null(),1,0),if(true(),1,true(),2),if(true(),1,false(),2),if(false(),1,true(),2),if(false(),1,false(),2),if(false(),1,false(),2,0)}                  |
+            -+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-
+             | if(true(),1) | if(false(),1) | if(null(),1) | if(true(),1,0) | if(false(),1,0) | if(null(),1,0) | if(true(),1,true(),2) | if(true(),1,false(),2) | if(false(),1,true(),2) | if(false(),1,false(),2) | if(false(),1,false(),2,0) |
+            -+--------------+---------------+--------------+----------------+-----------------+----------------+-----------------------+------------------------+------------------------+-------------------------+---------------------------+-
+             |            1 |               |              |              1 |               0 |              0 |                     1 |                      1 |                      2 |                         |                         0 |
+                                                                                                                                                                                                                                         (1 row)
+
+             ----
+             /{if(true(),1),if(false(),1),if(null(),1),if(true(),1,0),if(false(),1,0),if(null(),1,0),if(true(),1,true(),2),if(true(),1,false(),2),if(false(),1,true(),2),if(false(),1,false(),2),if(false(),1,false(),2,0)}
+             SELECT (CASE WHEN TRUE THEN 1 END), (CASE WHEN FALSE THEN 1 END), (CASE WHEN NULL THEN 1 END), (CASE WHEN TRUE THEN 1 ELSE 0 END), (CASE WHEN FALSE THEN 1 ELSE 0 END), (CASE WHEN NULL THEN 1 ELSE 0 END), (CASE WHEN TRUE THEN 1 WHEN TRUE THEN 2 END), (CASE WHEN TRUE THEN 1 WHEN FALSE THEN 2 END), (CASE WHEN FALSE THEN 1 WHEN TRUE THEN 2 END), (CASE WHEN FALSE THEN 1 WHEN FALSE THEN 2 END), (CASE WHEN FALSE THEN 1 WHEN FALSE THEN 2 ELSE 0 END)
+        - uri: /{switch('Y','X',1),switch('Y','Y',1), switch('Y','X',1,0),switch('Y','Y',1,0),
+            switch(null(),null(),1,0), switch('Y','X',1,'Y',2,'Z',3), switch('Y','A',1,'B',2,'C',3,0)}
+          status: 200 OK
+          headers:
+          - [Content-Type, text/plain; charset=UTF-8]
+          body: |2
+             | /{switch('Y','X',1),switch('Y','Y',1),switch('Y','X',1,0),switch('Y','Y',1,0),switch(null(),null(),1,0),switch('Y','X',1,'Y',2,'Z',3),switch('Y','A',1,'B',2,'C',3,0)}          |
+            -+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-
+             | switch('Y','X',1) | switch('Y','Y',1) | switch('Y','X',1,0) | switch('Y','Y',1,0) | switch(null(),null(),1,0) | switch('Y','X',1,'Y',2,'Z',3) | switch('Y','A',1,'B',2,'C',3,0) |
+            -+-------------------+-------------------+---------------------+---------------------+---------------------------+-------------------------------+---------------------------------+-
+             |                   |                 1 |                   0 |                   1 |                         0 |                             2 |                               0 |
+                                                                                                                                                                                         (1 row)
+
+             ----
+             /{switch('Y','X',1),switch('Y','Y',1),switch('Y','X',1,0),switch('Y','Y',1,0),switch(null(),null(),1,0),switch('Y','X',1,'Y',2,'Z',3),switch('Y','A',1,'B',2,'C',3,0)}
+             SELECT (CASE 'Y' WHEN 'X' THEN 1 END), (CASE 'Y' WHEN 'Y' THEN 1 END), (CASE 'Y' WHEN 'X' THEN 1 ELSE 0 END), (CASE 'Y' WHEN 'Y' THEN 1 ELSE 0 END), (CASE NULL WHEN NULL THEN 1 ELSE 0 END), (CASE 'Y' WHEN 'X' THEN 1 WHEN 'Y' THEN 2 WHEN 'Z' THEN 3 END), (CASE 'Y' WHEN 'A' THEN 1 WHEN 'B' THEN 2 WHEN 'C' THEN 3 ELSE 0 END)
   - id: simple-filters
     tests:
     - uri: /school?code='ns'