Commits

Kirill Simonov committed cd2c8e4

Completed port to Oracle.

  • Participants
  • Parent commits c6b17c8

Comments (0)

Files changed (18)

File src/htsql/tr/fn/encode.py

         empty = LiteralCode('', old.domain, self.binding)
         old = FormulaCode(IfNullSig(), old.domain, self.binding,
                           lop=old, rop=empty)
+        new = FormulaCode(IfNullSig(), old.domain, self.binding,
+                          lop=new, rop=empty)
         return FormulaCode(self.signature, self.domain, self.binding,
                            op=op, old=old, new=new)
 

File src/htsql_mssql/tr/dump.py

             self.write("CAST(%s AS BIGINT)" % self.value)
 
 
-class MySQLDumpFloat(DumpFloat):
+class MSSQLDumpFloat(DumpFloat):
 
     def __call__(self):
         assert str(self.value) not in ['inf', '-inf', 'nan']
         self.write(value)
 
 
-class MySQLDumpDecimal(DumpDecimal):
+class MSSQLDumpDecimal(DumpDecimal):
 
     def __call__(self):
         assert self.value.is_finite()

File src/htsql_oracle/connect.py

         if with_autocommit:
             connection.autocommit = True
         connection.outputtypehandler = self.outputtypehandler
+        cursor = connection.cursor()
+        cursor.execute("ALTER SESSION SET NLS_SORT = BINARY_CI");
+        cursor.execute("ALTER SESSION SET NLS_COMP = LINGUISTIC");
         return connection
 
     def normalize_error(self, exception):

File src/htsql_oracle/tr/compile.py

+#
+# Copyright (c) 2006-2011, Prometheus Research, LLC
+# Authors: Clark C. Evans <cce@clarkevans.com>,
+#          Kirill Simonov <xi@resolvent.net>
+#
+
+
+from htsql.domain import BooleanDomain, IntegerDomain
+from htsql.tr.term import PermanentTerm, FilterTerm, OrderTerm
+from htsql.tr.code import LiteralCode, FormulaCode, ScalarUnit
+from htsql.tr.coerce import coerce
+from htsql.tr.signature import CompareSig
+from .signature import RowNumSig
+from htsql.tr.compile import CompileOrdered, ordering, spread
+
+
+class OracleCompileOrdered(CompileOrdered):
+
+    def __call__(self):
+        if self.space.limit is None and self.space.offset is None:
+            return super(OracleCompileOrdered, self).__call__()
+        left_limit = None
+        if self.space.offset is not None:
+            left_limit = self.space.offset+1
+        right_limit = None
+        if self.space.limit is not None:
+            if self.space.offset is not None:
+                right_limit = self.space.limit+self.space.offset+1
+            else:
+                right_limit = self.space.limit+1
+        kid = self.state.compile(self.space.base,
+                                  baseline=self.state.scalar,
+                                  mask=self.state.scalar)
+        order = ordering(self.space)
+        codes = [code for code, direction in order]
+        kid = self.state.inject(kid, [code for code, direction in order])
+        routes = kid.routes.copy()
+        for unit in spread(self.space):
+            routes[unit.clone(space=self.space)] = routes[unit]
+        kid = OrderTerm(self.state.tag(), kid, order, None, None,
+                        self.space, kid.baseline, routes)
+        kid = PermanentTerm(self.state.tag(), kid,
+                            kid.space, kid.baseline, kid.routes.copy())
+        row_num_code = FormulaCode(RowNumSig(), coerce(IntegerDomain()),
+                                   self.space.binding)
+        if right_limit is not None:
+            right_limit_code = LiteralCode(right_limit,
+                                           coerce(IntegerDomain()),
+                                           self.space.binding)
+            right_filter = FormulaCode(CompareSig('<'),
+                                       coerce(BooleanDomain()),
+                                       self.space.binding,
+                                       lop=row_num_code,
+                                       rop=right_limit_code)
+            kid = FilterTerm(self.state.tag(), kid, right_filter,
+                             kid.space, kid.baseline, kid.routes.copy())
+        else:
+            kid = WrapperTerm(self.state.tag(), kid,
+                              kid.space, kid.baseline, kid.routes.copy())
+        routes = kid.routes.copy()
+        if left_limit is not None:
+            row_num_unit = ScalarUnit(row_num_code, self.space.base,
+                                      self.space.binding)
+            routes[row_num_unit] = kid.tag
+        kid = PermanentTerm(self.state.tag(), kid,
+                            kid.space, kid.baseline, routes)
+        if left_limit is not None:
+            left_limit_code = LiteralCode(left_limit,
+                                          coerce(IntegerDomain()),
+                                          self.space.binding)
+            left_filter = FormulaCode(CompareSig('>='),
+                                      coerce(BooleanDomain()),
+                                      self.space.binding,
+                                      lop=row_num_unit,
+                                      rop=left_limit_code)
+            kid = FilterTerm(self.state.tag(), kid, left_filter,
+                             kid.space, kid.baseline, kid.routes.copy())
+        return kid
+
+

File src/htsql_oracle/tr/dump.py

 """
 
 
-from htsql.tr.dump import DumpAnchor, DumpLeadingAnchor
+from htsql.adapter import adapts
+from htsql.domain import BooleanDomain, StringDomain, DateDomain
+from htsql.tr.frame import ScalarFrame
+from htsql.tr.dump import (SerializeSegment, Dump, DumpBranch, DumpAnchor,
+                           DumpLeadingAnchor, DumpFromPredicate,
+                           DumpToPredicate, DumpBoolean, DumpInteger,
+                           DumpFloat, DumpToFloat, DumpToDecimal,
+                           DumpToString,
+                           DumpIsTotallyEqual, DumpBySignature)
+from htsql.tr.fn.dump import (DumpLength, DumpSubstring, DumpDateIncrement,
+                              DumpDateDecrement, DumpDateDifference,
+                              DumpMakeDate)
+from .signature import RowNumSig
+
+
+class OracleSerializeSegment(SerializeSegment):
+
+    max_alias_length = 30
+
+
+class OracleDumpScalar(Dump):
+
+    adapts(ScalarFrame)
+
+    def __call__(self):
+        self.write("DUAL")
+
+
+class OracleDumpBranch(DumpBranch):
+
+    def dump_group(self):
+        if not self.frame.group:
+            return
+        self.newline()
+        self.write("GROUP BY ")
+        for index, phrase in enumerate(self.frame.group):
+            self.format("{kernel}", kernel=phrase)
+            if index < len(self.frame.group)-1:
+                self.write(", ")
+
+    def dump_order(self):
+        if not self.frame.order:
+            return
+        self.newline()
+        self.format("ORDER BY ")
+        for index, (phrase, direction) in enumerate(self.frame.order):
+            if phrase in self.frame.select:
+                position = self.frame.select.index(phrase)+1
+                self.write(str(position))
+            else:
+                self.format("{kernel}", kernel=phrase)
+            self.format(" {direction:switch{ASC|DESC}}", direction=direction)
+            if phrase.is_nullable:
+                self.format(" NULLS {direction:switch{FIRST|LAST}}",
+                            direction=direction)
+            if index < len(self.frame.order)-1:
+                self.write(", ")
+
+    def dump_limit(self):
+        assert self.frame.limit is None
+        assert self.frame.offset is None
 
 
 class OracleDumpLeadingAnchor(DumpLeadingAnchor):
         self.dedent()
 
 
+class OracleDumpFromPredicate(DumpFromPredicate):
+
+    def __call__(self):
+        if self.phrase.is_nullable:
+            self.format("(CASE WHEN {op} THEN 1 WHEN NOT {op} THEN 0 END)",
+                        self.arguments)
+        else:
+            self.format("(CASE WHEN {op} THEN 1 ELSE 0 END)",
+                        self.arguments)
+
+
+class OracleDumpToPredicate(DumpToPredicate):
+
+    def __call__(self):
+        self.format("({op} <> 0)", self.arguments)
+
+
+class OracleDumpBoolean(DumpBoolean):
+
+    def __call__(self):
+        if self.value is True:
+            self.write("1")
+        if self.value is False:
+            self.write("0")
+
+
+class OracleDumpInteger(DumpInteger):
+
+    def __call__(self):
+        self.write(str(self.value))
+
+
+class OracleDumpFloat(DumpFloat):
+
+    def __call__(self):
+        assert str(self.value) not in ['inf', '-inf', 'nan']
+        self.write(repr(self.value)+'D')
+
+
+class OracleDumpToFloat(DumpToFloat):
+
+    def __call__(self):
+        self.format("CAST({base} AS BINARY_DOUBLE)", base=self.base)
+
+
+class OracleDumpToDecimal(DumpToDecimal):
+
+    def __call__(self):
+        self.format("CAST({base} AS NUMBER)", base=self.base)
+
+
+class OracleDumpToString(DumpToString):
+
+    def __call__(self):
+        self.format("TO_CHAR({base})", base=self.base)
+
+
+class OracleDumpBooleanToString(DumpToString):
+
+    adapts(BooleanDomain, StringDomain)
+
+    def __call__(self):
+        if self.base.is_nullable:
+            self.format("(CASE WHEN {base} <> 0 THEN 'true'"
+                        " WHEN NOT {base} = 0 THEN 'false' END)",
+                        base=self.base)
+        else:
+            self.format("(CASE WHEN {base} <> 0 THEN 'true' ELSE 'false' END)",
+                        base=self.base)
+
+
+class OracleDumpDateToString(DumpToString):
+
+    adapts(DateDomain, StringDomain)
+
+    def __call__(self):
+        self.format("TO_CHAR({base}, 'YYYY-MM-DD')", base=self.base)
+
+
+class OracleDumpIsTotallyEqual(DumpIsTotallyEqual):
+
+    def __call__(self):
+        self.format("((CASE WHEN ({lop} = {rop}) OR"
+                    " ({lop} IS NULL AND {rop} IS NULL)"
+                    " THEN 1 ELSE 0 END) {polarity:switch{<>|=}} 0)",
+                    self.arguments, self.signature)
+
+
+class OracleDumpRowNumber(DumpBySignature):
+
+    adapts(RowNumSig)
+
+    def __call__(self):
+        self.write("ROWNUM")
+
+
+class OracleDumpLength(DumpLength):
+
+    template = "LENGTH({op})"
+
+
+class OracleDumpSubstring(DumpSubstring):
+
+    def __call__(self):
+        if self.phrase.length is None:
+            self.format("SUBSTR({op}, {start})", self.phrase)
+        else:
+            self.format("SUBSTR({op}, {start}, {length})", self.phrase)
+
+
+class OracleDumpDateIncrement(DumpDateIncrement):
+
+    template = "({lop} + {rop})"
+
+
+class OracleDumpDateDecrement(DumpDateDecrement):
+
+    template = "({lop} - {rop})"
+
+
+class OracleDumpDateDifference(DumpDateDifference):
+
+    template = "({lop} - {rop})"
+
+
+class OracleDumpMakeDate(DumpMakeDate):
+
+    template = ("(DATE '2001-01-01' + ({year} - 2001) * INTERVAL '1' YEAR"
+                " + ({month} - 1) * INTERVAL '1' MONTH"
+                " + ({day} - 1) * INTERVAL '1' DAY)")
+
+

File src/htsql_oracle/tr/encode.py

+#
+# Copyright (c) 2006-2011, Prometheus Research, LLC
+# Authors: Clark C. Evans <cce@clarkevans.com>,
+#          Kirill Simonov <xi@resolvent.net>
+#
+
+
+from htsql.adapter import adapts
+from htsql.domain import StringDomain
+from htsql.tr.code import LiteralCode, FormulaCode
+from htsql.tr.encode import EncodeLiteral
+from htsql.tr.fn.encode import EncodeFunction
+from htsql.tr.signature import IfNullSig
+from htsql.tr.fn.signature import LengthSig
+
+
+class OracleEncodeLength(EncodeFunction):
+
+    adapts(LengthSig)
+
+    def __call__(self):
+        code = super(OracleEncodeLength, self).__call__()
+        zero = LiteralCode(0, code.domain, code.binding)
+        return FormulaCode(IfNullSig(), code.domain, code.binding,
+                           lop=code, rop=zero)
+
+

File src/htsql_oracle/tr/reduce.py

+#
+# Copyright (c) 2006-2011, Prometheus Research, LLC
+# Authors: Clark C. Evans <cce@clarkevans.com>,
+#          Kirill Simonov <xi@resolvent.net>
+#
+
+
+from htsql.domain import StringDomain
+from htsql.tr.signature import isformula, ToPredicateSig, FromPredicateSig
+from htsql.tr.frame import ScalarFrame, NullPhrase, LeadingAnchor
+from htsql.tr.reduce import (ReduceScalar, ReduceBranch, ReduceLiteral,
+                             ReduceFromPredicate, ReduceToPredicate)
+
+
+class OracleReduceScalar(ReduceScalar):
+
+    def __call__(self):
+        return self.frame
+
+
+class OracleReduceBranch(ReduceBranch):
+
+    def reduce_include(self):
+        include = super(OracleReduceBranch, self).reduce_include()
+        if not include:
+            frame = ScalarFrame(self.frame.term)
+            anchor = LeadingAnchor(frame)
+            include = [anchor]
+        return include
+
+
+class OracleReduceLiteral(ReduceLiteral):
+
+    def __call__(self):
+        if (isinstance(self.phrase.domain, StringDomain) and
+            self.phrase.value == ""):
+            return NullPhrase(self.phrase.domain, self.phrase.expression)
+        return super(OracleReduceLiteral, self).__call__()
+
+
+class OracleReduceFromPredicate(ReduceFromPredicate):
+
+    def __call__(self):
+        op = self.state.reduce(self.phrase.op)
+        if isformula(op, ToPredicateSig):
+            return op.op
+        return self.phrase.clone(is_nullable=op.is_nullable, op=op)
+
+
+class OracleReduceToPredicate(ReduceToPredicate):
+
+    def __call__(self):
+        op = self.state.reduce(self.phrase.op)
+        if isformula(op, FromPredicateSig):
+            return op.op
+        return self.phrase.clone(is_nullable=op.is_nullable, op=op)
+
+

File src/htsql_oracle/tr/rewrite.py

+#
+# Copyright (c) 2006-2011, Prometheus Research, LLC
+# Authors: Clark C. Evans <cce@clarkevans.com>,
+#          Kirill Simonov <xi@resolvent.net>
+#
+
+
+from htsql.adapter import adapts
+from htsql.domain import StringDomain
+from htsql.tr.code import LiteralCode
+from htsql.tr.rewrite import Rewrite
+
+
+class OracleRewriteLiteral(Rewrite):
+
+    adapts(LiteralCode)
+
+    def __call__(self):
+        if isinstance(self.code.domain, StringDomain) and self.code.value == "":
+            return self.code.clone(value=None)
+        return self.code
+
+

File src/htsql_oracle/tr/signature.py

+#
+# Copyright (c) 2006-2011, Prometheus Research, LLC
+# Authors: Clark C. Evans <cce@clarkevans.com>,
+#          Kirill Simonov <xi@resolvent.net>
+#
+
+
+from htsql.tr.signature import NullarySig
+
+
+class RowNumSig(NullarySig):
+    pass
+
+

File test/regress/input/library.yaml

   # Out of range
   - uri: /{18446744073709551616}
     expect: 400
+    ifndef: oracle
+  # Oracle does not have a range limitation for the `INTEGER` data type.
+  - uri: /{18446744073709551616, 340282366920938463463374607431768211456}
+    ifdef: oracle
 
   # Decimal values
   - uri: /{1.0, -2.5, 0.875}
   - uri: /{replace('OMGWTFBBQ','WTF','LOL')}
   - uri: /{replace('OMGWTFBBQ','wtf','LOL'),
            replace('OMGWTFBBQ','WTF','lol')}
-    ifndef: mssql
+    ifndef: [mssql, oracle]
   # `REPLACE` in MSSQL respects the database collation, which is
-  # case-insensitive for the regression database.
+  # case-insensitive for the regression database.  Same with Oracle
+  # when `NLS_SORT = BINARY_CI` and `NLS_COMP = LINGUISTIC`.
   - uri: /{replace('OMGWTFBBQ','wtf','LOL')}
-    ifdef: mssql
+    ifdef: [mssql, oracle]
   - uri: /{replace('OMGWTFBBQ','WTF','lol')}
-    ifdef: mssql
+    ifdef: [mssql, oracle]
   - uri: /{replace('floccinaucinihilipilification','ili','LOL')}
   - uri: /{replace('OMGWTFBBQ','','LOL'),
            replace('OMGWTFBBQ','WTF','')}
   tests:
 
   # Exists, Every, Count
+  # (Oracle cannot handle EXISTS and an aggregate in the same SELECT clause)
   - uri: /course?department='lang'
   - uri: /{exists(course?department='lang'),
            every(course?department='lang'),
            count(course?department='lang')}
+    ifndef: oracle
+  - uri: /{exists(course?department='lang'),
+           every(course?department='lang')}
+    ifdef: oracle
+  - uri: /{count(course?department='lang')}
+    ifdef: oracle
   # Applied to an empty set
   - uri: /course?department='str'
   - uri: /{exists(course?department='str'),
            every(course?department='str'),
            count(course?department='str')}
+    ifndef: oracle
   # Applied to all-TRUE, all-FALSE, mixed sets
   - uri: /course{department,no,credits,credits>3}
                 ?department={'me','mth','phys'}
   - uri: /{exists(course{credits>3}?department='me'),
            every(course{credits>3}?department='me'),
            count(course{credits>3}?department='me')}
+    ifndef: oracle
   - uri: /{exists(course{credits>3}?department='mth'),
            every(course{credits>3}?department='mth'),
            count(course{credits>3}?department='mth')}
+    ifndef: oracle
   - uri: /{exists(course{credits>3}?department='phys'),
            every(course{credits>3}?department='phys'),
            count(course{credits>3}?department='phys')}
+    ifndef: oracle
   # Coercion
   - uri: /department{code,school,boolean(school)}
   - uri: /{exists(department{school}),
            every(department{school}),
            count(department{school})}
+    ifndef: oracle
   # Singular operand
   - uri: /{exists(true())}
     expect: 400

File test/regress/input/oracle.yaml

 
 - title: Run the test collection
   id: test-oracle
-  skip: true
   tests:
   - define: oracle
   - db: *connect

File test/regress/input/translation.yaml

     - uri: /count(school?exists(department))
     - uri: /{exists(school?!exists(department)),
              count(school?!exists(department))}
+      ifndef: oracle # Oracle cannot handle EXISTS and an aggregate in
+                     # the same SELECT clause.
     - uri: /{count(course),min(course.credits),
                            max(course.credits),
                            avg(course.credits)}

File test/regress/output/mssql.yaml

             -+-----------------------------------------------------------------------------------------------------+-
              | replace(null(),'WTF','LOL') | replace('OMGWTFBBQ',null(),'LOL') | replace('OMGWTFBBQ','WTF',null()) |
             -+-----------------------------+-----------------------------------+-----------------------------------+-
-             |                             | OMGWTFBBQ                         |                                   |
+             |                             | OMGWTFBBQ                         | OMGBBQ                            |
                                                                                                              (1 row)
 
              ----
              /{replace(null(),'WTF','LOL'),replace('OMGWTFBBQ',null(),'LOL'),replace('OMGWTFBBQ','WTF',null())}
              SELECT REPLACE(NULL, 'WTF', 'LOL'),
                     REPLACE('OMGWTFBBQ', '', 'LOL'),
-                    REPLACE('OMGWTFBBQ', 'WTF', NULL)
+                    REPLACE('OMGWTFBBQ', 'WTF', '')
       - id: date-functions-and-operators
         tests:
         - uri: /{date(null()), date('2010-04-15')}

File test/regress/output/mysql.yaml

             -+-----------------------------------------------------------------------------------------------------+-
              | replace(null(),'WTF','LOL') | replace('OMGWTFBBQ',null(),'LOL') | replace('OMGWTFBBQ','WTF',null()) |
             -+-----------------------------+-----------------------------------+-----------------------------------+-
-             |                             | OMGWTFBBQ                         |                                   |
+             |                             | OMGWTFBBQ                         | OMGBBQ                            |
                                                                                                              (1 row)
 
              ----
              /{replace(null(),'WTF','LOL'),replace('OMGWTFBBQ',null(),'LOL'),replace('OMGWTFBBQ','WTF',null())}
              SELECT REPLACE(NULL, 'WTF', 'LOL'),
                     REPLACE('OMGWTFBBQ', '', 'LOL'),
-                    REPLACE('OMGWTFBBQ', 'WTF', NULL)
+                    REPLACE('OMGWTFBBQ', 'WTF', '')
       - id: date-functions-and-operators
         tests:
         - uri: /{date(null()), date('2010-04-15')}