Mike Bayer avatar Mike Bayer committed d4f167b

- very rudimental support for OUT parameters added; use sql.outparam(name, type)
to set up an OUT parameter, just like bindparam(); after execution, values are
avaiable via result.out_parameters dictionary. [ticket:507]
- dialect.get_type_map() apparently never worked, not sure why unit test seemed
to work the first time around.
- OracleText doesn't seem to return cx_oracle.LOB.

Comments (0)

Files changed (9)

     from SelectResults isn't present anymore, need to use join(). 
 - postgres
   - Added PGArray datatype for using postgres array datatypes
+- oracle
+  - very rudimental support for OUT parameters added; use sql.outparam(name, type)
+    to set up an OUT parameter, just like bindparam(); after execution, values are
+    avaiable via result.out_parameters dictionary. [ticket:507]
 
 0.3.11
 - orm

lib/sqlalchemy/databases/oracle.py

     def get_col_spec(self):
         return "CLOB"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        else:
-            return super(OracleText, self).convert_result_value(value.read(), dialect)
+   # def convert_result_value(self, value, dialect):
+   #     if value is None:
+   #         return None
+   #     else:
+   #         return super(OracleText, self).convert_result_value(value.read(), dialect)
 
 
 class OracleRaw(sqltypes.Binary):
         super(OracleExecutionContext, self).pre_exec()
         if self.dialect.auto_setinputsizes:
             self.set_input_sizes()
+        if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list):
+            for key in self.compiled_parameters:
+                (bindparam, name, value) = self.compiled_parameters.get_parameter(key)
+                if bindparam.isoutparam:
+                    dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+                    if not hasattr(self, 'out_parameters'):
+                        self.out_parameters = {}
+                    self.out_parameters[name] = self.cursor.var(dbtype)
+                    self.parameters[name] = self.out_parameters[name]
 
     def get_result_proxy(self):
+        if hasattr(self, 'out_parameters'):
+            if self.compiled_parameters is not None:
+                 for k in self.out_parameters:
+                     type = self.compiled_parameters.get_type(k)
+                     self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+            else:
+                 for k in self.out_parameters:
+                     self.out_parameters[k] = self.out_parameters[k].getvalue()
+
         if self.cursor.description is not None:
             for column in self.cursor.description:
                 type_code = column[1]

lib/sqlalchemy/engine/base.py

             return self.context.get_rowcount()
     rowcount = property(_get_rowcount)
     lastrowid = property(lambda s:s.cursor.lastrowid)
+    out_parameters = property(lambda s:s.context.out_parameters)
     
     def _init_metadata(self):
         if hasattr(self, '_ResultProxy__props'):

lib/sqlalchemy/engine/default.py

         dialect_module = sys.modules[self.__class__.__module__]
         map = {}
         for obj in dialect_module.__dict__.values():
-            if isinstance(obj, types.TypeEngine):
-                map[obj().get_dbapi_type(self.dialect)] = obj
+            if isinstance(obj, type) and issubclass(obj, types.TypeEngine):
+                obj = obj()
+                map[obj.get_dbapi_type(self.dbapi)] = obj
         self._dbapi_type_map = map
     
     def decode_result_columnname(self, name):

lib/sqlalchemy/sql.py

            'between', 'bindparam', 'case', 'cast', 'column', 'delete',
            'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
            'insert', 'intersect', 'intersect_all', 'join', 'literal',
-           'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
+           'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select',
            'subquery', 'table', 'text', 'union', 'union_all', 'update',]
 
 BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
     attribute, which returns a dictionary containing the values.
     """
     
-    return _BindParamClause(key, type_=type_, unique=False, isoutparam=True)
+    return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
     
 def text(text, bind=None, *args, **kwargs):
     """Create literal text to be inserted into a query.
         self.__binds = {}
         self.positional = positional or []
 
+    def get_parameter(self, key):
+        return self.__binds[key]
+
     def set_parameter(self, bindparam, value, name):
         self.__binds[name] = [bindparam, name, value]
         
    
     def keys(self):
         return self.__binds.keys()
+
+    def __iter__(self):
+        return iter(self.keys())
  
     def __getitem__(self, key):
         return self.get_processed(key)

test/dialect/alltests.py

     modules_to_test = (
         'dialect.mysql',
         'dialect.postgres',
+        'dialect.oracle',
         )
     alltests = unittest.TestSuite()
     for name in modules_to_test:

test/dialect/oracle.py

+import testbase, testing
+from sqlalchemy import *
+from sqlalchemy.databases import mysql
+from testlib import *
+
+
+class OutParamTest(AssertMixin):
+    @testing.supported('oracle')
+    def setUpAll(self):
+        testbase.db.execute("""
+create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number) IS
+  retval number;
+    begin
+    retval := 6;
+    x_out := 10;
+    y_out := x_in * 15;
+    end;
+        """)
+
+    @testing.supported('oracle')
+    def test_out_params(self):
+        result = testbase.db.execute(text("begin foo(:x, :y, :z); end;", bindparams=[bindparam('x', Numeric), outparam('y', Numeric), outparam('z', Numeric)]), x=5)
+        assert result.out_parameters == {'y':10, 'z':75}, result.out_parameters
+        print result.out_parameters
+
+    @testing.supported('oracle')
+    def tearDownAll(self):
+         testbase.db.execute("DROP PROCEDURE foo")
+
+if __name__ == '__main__':
+    testbase.main()

test/sql/query.py

             if result.lastrow_has_defaults():
                 criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
                 row = table.select(criterion).execute().fetchone()
-                ret.update(row)
+                for c in table.c:
+                    ret[c.key] = row[c]
             return ret
 
         for supported, table, values, assertvalues in [
             (
                 {'unsupported':['sqlite']},
                 Table("t1", metadata, 
-                    Column('id', Integer, primary_key=True),
+                    Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
                     Column('foo', String(30), primary_key=True)),
                 {'foo':'hi'},
                 {'id':1, 'foo':'hi'}
             (
                 {'unsupported':['sqlite']},
                 Table("t2", metadata, 
-                    Column('id', Integer, primary_key=True),
+                    Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
                     Column('foo', String(30), primary_key=True),
                     Column('bar', String(30), PassiveDefault('hi'))
                 ),
             (
                 {'unsupported':[]},
                 Table("t4", metadata, 
-                    Column('id', Integer, primary_key=True),
+                    Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
                     Column('foo', String(30), primary_key=True),
                     Column('bar', String(30), PassiveDefault('hi'))
                 ),

test/sql/testtypes.py

         from decimal import Decimal
         numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78)
         numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78"))
-        print numeric_table.select().execute().fetchall()
-        assert numeric_table.select().execute().fetchall() == [
+        l = numeric_table.select().execute().fetchall()
+        print l
+        rounded = [
+            (l[0][0], l[0][1], round(l[0][2], 5), l[0][3], l[0][4]),
+            (l[1][0], l[1][1], round(l[1][2], 5), l[1][3], l[1][4]),
+        ]
+        assert rounded == [
             (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
             (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
         ]
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.