Commits

Mike Bayer committed 11b1fd1

- Got PG server side cursors back into shape, added fixed
unit tests as part of the default test suite. Added
better uniqueness to the cursor ID [ticket:1001]
- update().values() and insert().values() take keyword
arguments.

Comments (0)

Files changed (5)

     - random() is now a generic sql function and will compile to
       the database's random implementation, if any.
 
+    - update().values() and insert().values() take keyword 
+      arguments.
+
     - Fixed an issue in select() regarding its generation of FROM
       clauses, in rare circumstances two clauses could be produced
       when one was intended to cancel out the other.  Some ORM
        property can also be set on individual declarative
        classes using the "__mapper_cls__" property.
 
+- postgres
+    - Got PG server side cursors back into shape, added fixed
+      unit tests as part of the default test suite.  Added
+      better uniqueness to the cursor ID [ticket:1001]
+      
 - oracle
     - The "owner" keyword on Table is now deprecated, and is
       exactly synonymous with the "schema" keyword.  Tables can

lib/sqlalchemy/databases/postgres.py

         ('host',"Hostname", None),
     ]}
 
+SERVER_SIDE_CURSOR_RE = re.compile(
+    r'\s*SELECT',
+    re.I | re.UNICODE)
+
 SELECT_RE = re.compile(
     r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))',
     re.I | re.UNICODE)
         \sRETURNING\s""", re.I | re.UNICODE | re.VERBOSE)
 
 class PGExecutionContext(default.DefaultExecutionContext):
-
     def returns_rows_text(self, statement):
         m = SELECT_RE.match(statement)
         return m and (not m.group(1) or (RETURNING_RE.search(statement)
             )
 
     def create_cursor(self):
-        # executing a default or Sequence standalone creates an execution context without a statement.
-        # so slightly hacky "if no statement assume we're server side" logic
-        # TODO: dont use regexp if Compiled is used ?
         self.__is_server_side = \
             self.dialect.server_side_cursors and \
-            (self.statement is None or \
-            (SELECT_RE.match(self.statement) and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I))
-        )
+            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)) \
+            or \
+            (not self.compiled and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)))
 
         if self.__is_server_side:
             # use server-side cursors:
             # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            ident = "c" + hex(random.randint(0, 65535))[2:]
+            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
             return self._connection.connection.cursor(ident)
         else:
             return self._connection.connection.cursor()
-
+    
     def get_result_proxy(self):
         if self.__is_server_side:
             return base.BufferedRowResultProxy(self)
             self.execute()
 
 class PGDefaultRunner(base.DefaultRunner):
+    def __init__(self, context):
+        base.DefaultRunner.__init__(self, context)
+        # craete cursor which won't conflict with a server-side cursor
+        self.cursor = context._connection.connection.cursor()
+    
     def get_column_default(self, column, isinsert=True):
         if column.primary_key:
             # pre-execute passive defaults on primary keys

lib/sqlalchemy/engine/base.py

     def __init__(self, context):
         self.context = context
         self.dialect = context.dialect
+        self.cursor = context.cursor
 
     def get_column_default(self, column):
         if column.default is not None:
         conn = self.context._connection
         if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
             stmt = stmt.encode(self.dialect.encoding)
-        conn._cursor_execute(self.context.cursor, stmt, params)
-        return self.context.cursor.fetchone()[0]
+        conn._cursor_execute(self.cursor, stmt, params)
+        return self.cursor.fetchone()[0]
 
     def visit_column_onupdate(self, onupdate):
         if isinstance(onupdate.arg, expression.ClauseElement):

lib/sqlalchemy/sql/expression.py

         self._bind = bind
     bind = property(bind, _set_bind)
 
-class Insert(_UpdateBase):
+class _ValuesBase(_UpdateBase):
+    def values(self, *args, **kwargs):
+        """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE.
+
+            \**kwargs
+                key=<somevalue> arguments
+                
+            \*args
+                deprecated.  A single dictionary can be sent as the first positional argument.
+        """
+        
+        if args:
+            v = args[0]
+        else:
+            v = {}
+        if len(v) == 0 and len(kwargs) == 0:
+            return self
+        u = self._clone()
+        
+        if u.parameters is None:
+            u.parameters = u._process_colparams(v)
+            u.parameters.update(kwargs)
+        else:
+            u.parameters = self.parameters.copy()
+            u.parameters.update(u._process_colparams(v))
+            u.parameters.update(kwargs)
+        return u
+
+class Insert(_ValuesBase):
     def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs):
         self._bind = bind
         self.table = table
     def _copy_internals(self, clone=_clone):
         self.parameters = self.parameters.copy()
 
-    def values(self, v):
-        if len(v) == 0:
-            return self
-        u = self._clone()
-        if u.parameters is None:
-            u.parameters = u._process_colparams(v)
-        else:
-            u.parameters = self.parameters.copy()
-            u.parameters.update(u._process_colparams(v))
-        return u
-
     def prefix_with(self, clause):
         """Add a word or expression between INSERT and INTO. Generative.
 
         gen._prefixes = self._prefixes + [clause]
         return gen
 
-class Update(_UpdateBase):
+class Update(_ValuesBase):
     def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
         self._bind = bind
         self.table = table
             s._whereclause = _literal_as_text(whereclause)
         return s
 
-    def values(self, v):
-        if len(v) == 0:
-            return self
-        u = self._clone()
-        if u.parameters is None:
-            u.parameters = u._process_colparams(v)
-        else:
-            u.parameters = self.parameters.copy()
-            u.parameters.update(u._process_colparams(v))
-        return u
 
 class Delete(_UpdateBase):
     def __init__(self, table, whereclause, bind=None):

test/dialect/postgres.py

         result = connection.execute(s).fetchone() 
         self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0)) 
 
-
+class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
+    __only_on__ = 'postgres'
+    
+    def setUpAll(self):
+        global ss_engine
+        ss_engine = engines.testing_engine(options={'server_side_cursors':True})
+        
+    def tearDownAll(self):
+        ss_engine.dispose()
+    
+    def test_roundtrip(self):
+        test_table = Table('test_table', MetaData(ss_engine),
+            Column('id', Integer, primary_key=True),
+            Column('data', String(50))
+        )
+        test_table.create(checkfirst=True)
+        try:
+            test_table.insert().execute(data='data1')
+            
+            nextid = ss_engine.execute(Sequence('test_table_id_seq'))
+            test_table.insert().execute(id=nextid, data='data2')
+            
+            self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')])
+            
+            test_table.update().where(test_table.c.id==2).values(data=test_table.c.data + ' updated').execute()
+            self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')])
+            test_table.delete().execute()
+            self.assertEquals(test_table.count().scalar(), 0)
+        finally:
+            test_table.drop(checkfirst=True)
+            
+    
 if __name__ == "__main__":
     testenv.main()