Commits

Christian Boos  committed 4c91758

Move `with_transaction` and `get_read_db` from `trac.db.util` to `trac.db.api`.

  • Participants
  • Parent commits 3a9049d
  • Branches trunk

Comments (0)

Files changed (6)

File trac/db/api.py

 from trac.config import BoolOption, IntOption, Option
 from trac.core import *
 from trac.db.pool import ConnectionPool
+from trac.util.concurrency import ThreadLocal
 from trac.util.text import unicode_passwd
 from trac.util.translation import _
 
+_transaction_local = ThreadLocal(db=None)
 
-def get_column_names(cursor):
-    return cursor.description and \
-           [(isinstance(d[0], str) and [unicode(d[0], 'utf-8')] or [d[0]])[0]
-            for d in cursor.description] or []
+def with_transaction(env, db=None):
+    """Function decorator to emulate a context manager for database
+    transactions.
+    
+    >>> def api_method(p1, p2):
+    >>>     result[0] = value1
+    >>>     @with_transaction(env)
+    >>>     def implementation(db):
+    >>>         # implementation
+    >>>         result[0] = value2
+    >>>     return result[0]
+    
+    In this example, the `implementation()` function is called automatically
+    right after its definition, with a database connection as an argument.
+    If the function completes, a COMMIT is issued on the connection. If the
+    function raises an exception, a ROLLBACK is issued and the exception is
+    re-raised. Nested transactions are supported, and a COMMIT will only be
+    issued when the outermost transaction block in a thread exits.
+    
+    This mechanism is intended to replace the current practice of getting a
+    database connection with `env.get_db_cnx()` and issuing an explicit commit
+    or rollback, for mutating database accesses. Its automatic handling of
+    commit, rollback and nesting makes it much more robust.
+    
+    This decorator will be replaced by a context manager once python 2.4
+    support is dropped.
+
+    The optional `db` argument is intended for legacy code and should not
+    be used in new code.
+    """
+    def transaction_wrapper(fn):
+        ldb = _transaction_local.db
+        if db is not None:
+            if ldb is None:
+                _transaction_local.db = db
+                try:
+                    fn(db)
+                finally:
+                    _transaction_local.db = None
+            else:
+                assert ldb is db, "Invalid transaction nesting"
+                fn(db)
+        elif ldb:
+            fn(ldb)
+        else:
+            ldb = _transaction_local.db = env.get_db_cnx()
+            try:
+                fn(ldb)
+                ldb.commit()
+                _transaction_local.db = None
+            except:
+                _transaction_local.db = None
+                ldb.rollback()
+                ldb = None
+                raise
+    return transaction_wrapper
+
+
+def get_read_db(env):
+    """Get a database connection for reading only."""
+    return _transaction_local.db or DatabaseManager(env).get_connection()
 
 
 class IDatabaseConnector(Interface):
     _get_connector = get_connector  # For 0.11 compatibility
 
 
+def get_column_names(cursor):
+    return cursor.description and \
+           [(isinstance(d[0], str) and [unicode(d[0], 'utf-8')] or [d[0]])[0]
+            for d in cursor.description] or []
+
+
 def _parse_db_str(db_str):
     scheme, rest = db_str.split(':', 1)
 

File trac/db/tests/__init__.py

     suite.addTest(api.suite())
     suite.addTest(mysql_test.suite())
     suite.addTest(postgres_test.suite())
-    suite.addTest(util.suite())
+    #suite.addTest(util.suite())
     return suite
 
 if __name__ == '__main__':

File trac/db/tests/api.py

 import os
 import unittest
 
-from trac.db.api import _parse_db_str
-from trac.test import EnvironmentStub
+from trac.db.api import _parse_db_str, with_transaction
+from trac.test import EnvironmentStub, Mock
+
+
+class Connection(object):
+    
+    committed = False
+    rolledback = False
+    
+    def commit(self):
+        self.committed = True
+    
+    def rollback(self):
+        self.rolledback = True
+
+
+class Error(Exception):
+    pass
+
+
+class WithTransactionTest(unittest.TestCase):
+
+    def test_successful_transaction(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: db)
+        @with_transaction(env)
+        def do_transaction(db):
+            self.assertTrue(not db.committed and not db.rolledback)
+        self.assertTrue(db.committed and not db.rolledback)
+        
+    def test_failed_transaction(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: db)
+        try:
+            @with_transaction(env)
+            def do_transaction(db):
+                self.assertTrue(not db.committed and not db.rolledback)
+                raise Error()
+            self.fail()
+        except Error:
+            pass
+        self.assertTrue(not db.committed and db.rolledback)
+        
+    def test_implicit_nesting_success(self):
+        env = Mock(get_db_cnx=lambda: Connection())
+        dbs = [None, None]
+        @with_transaction(env)
+        def level0(db):
+            dbs[0] = db
+            @with_transaction(env)
+            def level1(db):
+                dbs[1] = db
+                self.assertTrue(not db.committed and not db.rolledback)
+            self.assertTrue(not db.committed and not db.rolledback)
+        self.assertTrue(dbs[0] is not None)
+        self.assertTrue(dbs[0] is dbs[1])
+        self.assertTrue(dbs[0].committed and not dbs[0].rolledback)
+
+    def test_implicit_nesting_failure(self):
+        env = Mock(get_db_cnx=lambda: Connection())
+        dbs = [None, None]
+        try:
+            @with_transaction(env)
+            def level0(db):
+                dbs[0] = db
+                try:
+                    @with_transaction(env)
+                    def level1(db):
+                        dbs[1] = db
+                        self.assertTrue(not db.committed and not db.rolledback)
+                        raise Error()
+                    self.fail()
+                except Error:
+                    self.assertTrue(not db.committed and not db.rolledback)
+                    raise
+            self.fail()
+        except Error:
+            pass
+        self.assertTrue(dbs[0] is not None)
+        self.assertTrue(dbs[0] is dbs[1])
+        self.assertTrue(not dbs[0].committed and dbs[0].rolledback)
+
+    def test_explicit_success(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: None)
+        @with_transaction(env, db)
+        def do_transaction(idb):
+            self.assertTrue(idb is db)
+            self.assertTrue(not db.committed and not db.rolledback)
+        self.assertTrue(not db.committed and not db.rolledback)
+
+    def test_explicit_failure(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: None)
+        try:
+            @with_transaction(env, db)
+            def do_transaction(idb):
+                self.assertTrue(idb is db)
+                self.assertTrue(not db.committed and not db.rolledback)
+                raise Error()
+            self.fail()
+        except Error:
+            pass
+        self.assertTrue(not db.committed and not db.rolledback)
+
+    def test_implicit_in_explicit_success(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: Connection())
+        dbs = [None, None]
+        @with_transaction(env, db)
+        def level0(db):
+            dbs[0] = db
+            @with_transaction(env)
+            def level1(db):
+                dbs[1] = db
+                self.assertTrue(not db.committed and not db.rolledback)
+            self.assertTrue(not db.committed and not db.rolledback)
+        self.assertTrue(dbs[0] is not None)
+        self.assertTrue(dbs[0] is dbs[1])
+        self.assertTrue(not dbs[0].committed and not dbs[0].rolledback)
+
+    def test_implicit_in_explicit_failure(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: Connection())
+        dbs = [None, None]
+        try:
+            @with_transaction(env, db)
+            def level0(db):
+                dbs[0] = db
+                @with_transaction(env)
+                def level1(db):
+                    dbs[1] = db
+                    self.assertTrue(not db.committed and not db.rolledback)
+                    raise Error()
+                self.fail()
+            self.fail()
+        except Error:
+            pass
+        self.assertTrue(dbs[0] is not None)
+        self.assertTrue(dbs[0] is dbs[1])
+        self.assertTrue(not dbs[0].committed and not dbs[0].rolledback)
+
+    def test_explicit_in_implicit_success(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: Connection())
+        dbs = [None, None]
+        @with_transaction(env)
+        def level0(db):
+            dbs[0] = db
+            @with_transaction(env, db)
+            def level1(db):
+                dbs[1] = db
+                self.assertTrue(not db.committed and not db.rolledback)
+            self.assertTrue(not db.committed and not db.rolledback)
+        self.assertTrue(dbs[0] is not None)
+        self.assertTrue(dbs[0] is dbs[1])
+        self.assertTrue(dbs[0].committed and not dbs[0].rolledback)
+
+    def test_explicit_in_implicit_failure(self):
+        db = Connection()
+        env = Mock(get_db_cnx=lambda: Connection())
+        dbs = [None, None]
+        try:
+            @with_transaction(env)
+            def level0(db):
+                dbs[0] = db
+                @with_transaction(env, db)
+                def level1(db):
+                    dbs[1] = db
+                    self.assertTrue(not db.committed and not db.rolledback)
+                    raise Error()
+                self.fail()
+            self.fail()
+        except Error:
+            pass
+        self.assertTrue(dbs[0] is not None)
+        self.assertTrue(dbs[0] is dbs[1])
+        self.assertTrue(not dbs[0].committed and dbs[0].rolledback)
+
+    def test_invalid_nesting(self):
+        env = Mock(get_db_cnx=lambda: Connection())
+        try:
+            @with_transaction(env)
+            def level0(db):
+                @with_transaction(env, Connection())
+                def level1(db):
+                    raise Error()
+                raise Error()
+            raise Error()
+        except AssertionError:
+            pass
+
 
 
 class ParseConnectionStringTestCase(unittest.TestCase):
     suite.addTest(unittest.makeSuite(ParseConnectionStringTestCase, 'test'))
     suite.addTest(unittest.makeSuite(StringsTestCase, 'test'))
     suite.addTest(unittest.makeSuite(ConnectionTestCase, 'test'))
+    suite.addTest(unittest.makeSuite(WithTransactionTest, 'test'))
     return suite
 
+
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main(defaultTest='suite')

File trac/db/tests/util.py

 
 import unittest
 
-from trac.db.util import with_transaction
-from trac.test import Mock
-
-
-class Connection(object):
-    
-    committed = False
-    rolledback = False
-    
-    def commit(self):
-        self.committed = True
-    
-    def rollback(self):
-        self.rolledback = True
-
-
-class Error(Exception):
-    pass
-
-
-class WithTransactionTest(unittest.TestCase):
-
-    def test_successful_transaction(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: db)
-        @with_transaction(env)
-        def do_transaction(db):
-            self.assertTrue(not db.committed and not db.rolledback)
-        self.assertTrue(db.committed and not db.rolledback)
-        
-    def test_failed_transaction(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: db)
-        try:
-            @with_transaction(env)
-            def do_transaction(db):
-                self.assertTrue(not db.committed and not db.rolledback)
-                raise Error()
-            self.fail()
-        except Error:
-            pass
-        self.assertTrue(not db.committed and db.rolledback)
-        
-    def test_implicit_nesting_success(self):
-        env = Mock(get_db_cnx=lambda: Connection())
-        dbs = [None, None]
-        @with_transaction(env)
-        def level0(db):
-            dbs[0] = db
-            @with_transaction(env)
-            def level1(db):
-                dbs[1] = db
-                self.assertTrue(not db.committed and not db.rolledback)
-            self.assertTrue(not db.committed and not db.rolledback)
-        self.assertTrue(dbs[0] is not None)
-        self.assertTrue(dbs[0] is dbs[1])
-        self.assertTrue(dbs[0].committed and not dbs[0].rolledback)
-
-    def test_implicit_nesting_failure(self):
-        env = Mock(get_db_cnx=lambda: Connection())
-        dbs = [None, None]
-        try:
-            @with_transaction(env)
-            def level0(db):
-                dbs[0] = db
-                try:
-                    @with_transaction(env)
-                    def level1(db):
-                        dbs[1] = db
-                        self.assertTrue(not db.committed and not db.rolledback)
-                        raise Error()
-                    self.fail()
-                except Error:
-                    self.assertTrue(not db.committed and not db.rolledback)
-                    raise
-            self.fail()
-        except Error:
-            pass
-        self.assertTrue(dbs[0] is not None)
-        self.assertTrue(dbs[0] is dbs[1])
-        self.assertTrue(not dbs[0].committed and dbs[0].rolledback)
-
-    def test_explicit_success(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: None)
-        @with_transaction(env, db)
-        def do_transaction(idb):
-            self.assertTrue(idb is db)
-            self.assertTrue(not db.committed and not db.rolledback)
-        self.assertTrue(not db.committed and not db.rolledback)
-
-    def test_explicit_failure(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: None)
-        try:
-            @with_transaction(env, db)
-            def do_transaction(idb):
-                self.assertTrue(idb is db)
-                self.assertTrue(not db.committed and not db.rolledback)
-                raise Error()
-            self.fail()
-        except Error:
-            pass
-        self.assertTrue(not db.committed and not db.rolledback)
-
-    def test_implicit_in_explicit_success(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: Connection())
-        dbs = [None, None]
-        @with_transaction(env, db)
-        def level0(db):
-            dbs[0] = db
-            @with_transaction(env)
-            def level1(db):
-                dbs[1] = db
-                self.assertTrue(not db.committed and not db.rolledback)
-            self.assertTrue(not db.committed and not db.rolledback)
-        self.assertTrue(dbs[0] is not None)
-        self.assertTrue(dbs[0] is dbs[1])
-        self.assertTrue(not dbs[0].committed and not dbs[0].rolledback)
-
-    def test_implicit_in_explicit_failure(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: Connection())
-        dbs = [None, None]
-        try:
-            @with_transaction(env, db)
-            def level0(db):
-                dbs[0] = db
-                @with_transaction(env)
-                def level1(db):
-                    dbs[1] = db
-                    self.assertTrue(not db.committed and not db.rolledback)
-                    raise Error()
-                self.fail()
-            self.fail()
-        except Error:
-            pass
-        self.assertTrue(dbs[0] is not None)
-        self.assertTrue(dbs[0] is dbs[1])
-        self.assertTrue(not dbs[0].committed and not dbs[0].rolledback)
-
-    def test_explicit_in_implicit_success(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: Connection())
-        dbs = [None, None]
-        @with_transaction(env)
-        def level0(db):
-            dbs[0] = db
-            @with_transaction(env, db)
-            def level1(db):
-                dbs[1] = db
-                self.assertTrue(not db.committed and not db.rolledback)
-            self.assertTrue(not db.committed and not db.rolledback)
-        self.assertTrue(dbs[0] is not None)
-        self.assertTrue(dbs[0] is dbs[1])
-        self.assertTrue(dbs[0].committed and not dbs[0].rolledback)
-
-    def test_explicit_in_implicit_failure(self):
-        db = Connection()
-        env = Mock(get_db_cnx=lambda: Connection())
-        dbs = [None, None]
-        try:
-            @with_transaction(env)
-            def level0(db):
-                dbs[0] = db
-                @with_transaction(env, db)
-                def level1(db):
-                    dbs[1] = db
-                    self.assertTrue(not db.committed and not db.rolledback)
-                    raise Error()
-                self.fail()
-            self.fail()
-        except Error:
-            pass
-        self.assertTrue(dbs[0] is not None)
-        self.assertTrue(dbs[0] is dbs[1])
-        self.assertTrue(not dbs[0].committed and dbs[0].rolledback)
-
-    def test_invalid_nesting(self):
-        env = Mock(get_db_cnx=lambda: Connection())
-        try:
-            @with_transaction(env)
-            def level0(db):
-                @with_transaction(env, Connection())
-                def level1(db):
-                    raise Error()
-                raise Error()
-            raise Error()
-        except AssertionError:
-            pass
-
-
-def suite():
-    suite = unittest.TestSuite()
-    suite.addTest(unittest.makeSuite(WithTransactionTest, 'test'))
-    return suite
-
-
-if __name__ == '__main__':
-    unittest.main(defaultTest='suite')
+# TODO: test sql_escape_percent, IterableCursor, ConnectionWrapper ...

File trac/db/util.py

 #
 # Author: Christopher Lenz <cmlenz@gmx.de>
 
-from trac.util.concurrency import ThreadLocal
-
-
-_transaction_local = ThreadLocal(db=None)
-
-def with_transaction(env, db=None):
-    """Function decorator to emulate a context manager for database
-    transactions.
-    
-    >>> def api_method(p1, p2):
-    >>>     result[0] = value1
-    >>>     @with_transaction(env)
-    >>>     def implementation(db):
-    >>>         # implementation
-    >>>         result[0] = value2
-    >>>     return result[0]
-    
-    In this example, the `implementation()` function is called automatically
-    right after its definition, with a database connection as an argument.
-    If the function completes, a COMMIT is issued on the connection. If the
-    function raises an exception, a ROLLBACK is issued and the exception is
-    re-raised. Nested transactions are supported, and a COMMIT will only be
-    issued when the outermost transaction block in a thread exits.
-    
-    This mechanism is intended to replace the current practice of getting a
-    database connection with `env.get_db_cnx()` and issuing an explicit commit
-    or rollback, for mutating database accesses. Its automatic handling of
-    commit, rollback and nesting makes it much more robust.
-    
-    This decorator will be replaced by a context manager once python 2.4
-    support is dropped.
-
-    The optional `db` argument is intended for legacy code and should not
-    be used in new code.
-    """
-    def transaction_wrapper(fn):
-        ldb = _transaction_local.db
-        if db is not None:
-            if ldb is None:
-                _transaction_local.db = db
-                try:
-                    fn(db)
-                finally:
-                    _transaction_local.db = None
-            else:
-                assert ldb is db, "Invalid transaction nesting"
-                fn(db)
-        elif ldb:
-            fn(ldb)
-        else:
-            ldb = _transaction_local.db = env.get_db_cnx()
-            try:
-                fn(ldb)
-                ldb.commit()
-                _transaction_local.db = None
-            except:
-                _transaction_local.db = None
-                ldb.rollback()
-                ldb = None
-                raise
-    return transaction_wrapper
-
-
-def get_read_db(env):
-    """Get a database connection for reading only."""
-    from trac.db.api import DatabaseManager
-    return _transaction_local.db or DatabaseManager(env).get_connection()
-
 
 def sql_escape_percent(sql):
     import re
 from trac.config import *
 from trac.core import Component, ComponentManager, implements, Interface, \
                       ExtensionPoint, TracError
-from trac.db import DatabaseManager
-from trac.db.util import get_read_db, with_transaction
+from trac.db.api import DatabaseManager, get_read_db, with_transaction
 from trac.util import copytree, create_file, get_pkginfo, makedirs
 from trac.util.compat import any
 from trac.util.concurrency import threading
     def with_transaction(self, db=None):
         """Decorator for transaction functions.
 
-        See `trac.db.util.with_transaction` for detailed documentation."""
+        See `trac.db.api.with_transaction` for detailed documentation."""
         return with_transaction(self, db)
 
     def get_read_db(self):
         """Return a database connection for read purposes.
 
-        See `trac.db.util.get_read_db` for detailed documentation."""
+        See `trac.db.api.get_read_db` for detailed documentation."""
         return get_read_db(self)
 
     def shutdown(self, tid=None):