Commits

Mike Bayer  committed 77bce4e

repaired oracle savepoint implementation

  • Participants
  • Parent commits ddaea56

Comments (0)

Files changed (3)

File lib/sqlalchemy/ansisql.py

         return text
         
     def visit_savepoint(self, savepoint_stmt):
-        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
 
     def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
     
     def visit_release_savepoint(self, savepoint_stmt):
-        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
     
     def __str__(self):
         return self.string
     def format_alias(self, alias, name=None):
         return self.__generic_obj_format(alias, name or alias.name)
 
-    def format_savepoint(self, savepoint):
-        return self.__generic_obj_format(savepoint, savepoint)
+    def format_savepoint(self, savepoint, name=None):
+        return self.__generic_obj_format(savepoint, name or savepoint.ident)
 
     def format_constraint(self, constraint):
         return self.__generic_obj_format(constraint, constraint.name)

File lib/sqlalchemy/databases/oracle.py

         else:
             return "rowid"
 
+    def do_release_savepoint(self, connection, name):
+        # Oracle does not support RELEASE SAVEPOINT
+        pass
+
     def create_execution_context(self, *args, **kwargs):
         return OracleExecutionContext(self, *args, **kwargs)
 
     def compiler(self, statement, bindparams, **kwargs):
         return OracleCompiler(self, statement, bindparams, **kwargs)
 
+    def preparer(self):
+        return OracleIdentifierPreparer(self)
+
     def schemagenerator(self, *args, **kwargs):
         return OracleSchemaGenerator(self, *args, **kwargs)
 
     def visit_sequence(self, seq):
         return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
 
+class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+    def format_savepoint(self, savepoint):
+        name = re.sub(r'^_+', '', savepoint.ident)
+        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+
+    
 dialect = OracleDialect

File test/engine/transaction.py

         )
         connection.close()
     
-    @testing.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql', 'oracle')
     @testing.exclude('mysql', '<', (5, 0, 3))
     def testtwophasetransaction(self):
         connection = testbase.db.connect()
         tlengine = create_engine(testbase.db.url, strategy='threadlocal')
         metadata = MetaData()
         users = Table('query_users', metadata,
-            Column('user_id', INT, primary_key = True),
+            Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
             Column('user_name', VARCHAR(20)),
             test_needs_acid=True,
         )