Commits

Remy Blank committed 51b0405

1.0dev: Merged from 0.12-stable.

Comments (0)

Files changed (2)

 from .util import ConnectionWrapper
 
 
-_transaction_local = ThreadLocal(wdb=None, rdb=None)
-
 def with_transaction(env, db=None):
     """Function decorator to emulate a context manager for database
     transactions.
     >>>     return result
 
     """
+    dbm = DatabaseManager(env)
+    _transaction_local = dbm._transaction_local
+    
     def transaction_wrapper(fn):
         ldb = _transaction_local.wdb
         if db is not None:
         elif ldb:
             fn(ldb)
         else:
-            ldb = _transaction_local.wdb = DatabaseManager(env).get_connection()
+            ldb = _transaction_local.wdb = dbm.get_connection()
             try:
                 fn(ldb)
                 ldb.commit()
     db = None
 
     def __init__(self, env):
-        self.env = env
+        self.dbmgr = DatabaseManager(env)
 
     def execute(self, query, params=None):
         """Shortcut for directly executing a query."""
     """
 
     def __enter__(self):
-        db = _transaction_local.wdb # outermost writable db
+        db = self.dbmgr._transaction_local.wdb # outermost writable db
         if not db:
-            db = _transaction_local.rdb # reuse wrapped connection
+            db = self.dbmgr._transaction_local.rdb # reuse wrapped connection
             if db:
                 db = ConnectionWrapper(db.cnx, db.log)
             else:
-                db = DatabaseManager(self.env).get_connection()
-            _transaction_local.wdb = self.db = db
+                db = self.dbmgr.get_connection()
+            self.dbmgr._transaction_local.wdb = self.db = db
         return db
 
     def __exit__(self, et, ev, tb): 
         if self.db: 
-            _transaction_local.wdb = None
+            self.dbmgr._transaction_local.wdb = None
             if et is None: 
                 self.db.commit()
             else: 
                 self.db.rollback()
-            if not _transaction_local.rdb:
+            if not self.dbmgr._transaction_local.rdb:
                 self.db.close()
 
 
     """
 
     def __enter__(self):
-        db = _transaction_local.rdb # outermost readonly db
+        db = self.dbmgr._transaction_local.rdb # outermost readonly db
         if not db:
-            db = _transaction_local.wdb # reuse wrapped connection
+            db = self.dbmgr._transaction_local.wdb # reuse wrapped connection
             if db:
                 db = ConnectionWrapper(db.cnx, db.log, readonly=True)
             else:
-                db = DatabaseManager(self.env).get_connection(readonly=True)
-            _transaction_local.rdb = self.db = db
+                db = self.dbmgr.get_connection(readonly=True)
+            self.dbmgr._transaction_local.rdb = self.db = db
         return db
 
     def __exit__(self, et, ev, tb): 
         if self.db:
-            _transaction_local.rdb = None
-            if not _transaction_local.wdb:
+            self.dbmgr._transaction_local.rdb = None
+            if not self.dbmgr._transaction_local.wdb:
                 self.db.close()
 
 
 
     def __init__(self):
         self._cnx_pool = None
+        self._transaction_local = ThreadLocal(wdb=None, rdb=None)
 
     def init_db(self):
         connector, args = self.get_connector()

trac/db/tests/api.py

 import os
 import unittest
 
-from trac.db.api import DatabaseManager, _parse_db_str, with_transaction, \
-                        get_column_names
+from trac.db.api import DatabaseManager, _parse_db_str, get_column_names, \
+                        with_transaction
 from trac.test import EnvironmentStub, Mock
+from trac.util.concurrency import ThreadLocal
 
 
 class Connection(object):
     pass
 
 
+def make_env(get_cnx):
+    return Mock(components={DatabaseManager:
+             Mock(get_connection=get_cnx,
+                  _transaction_local=ThreadLocal(wdb=None, rdb=None))})
+
+
 class WithTransactionTest(unittest.TestCase):
-
+                      
     def test_successful_transaction(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: db)})
+        env = make_env(lambda: db)
         @with_transaction(env)
         def do_transaction(db):
             self.assertTrue(not db.committed and not db.rolledback)
         
     def test_failed_transaction(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: db)})
+        env = make_env(lambda: db)
         try:
             @with_transaction(env)
             def do_transaction(db):
         self.assertTrue(not db.committed and db.rolledback)
         
     def test_implicit_nesting_success(self):
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=Connection)})
+        env = make_env(Connection)
         dbs = [None, None]
         @with_transaction(env)
         def level0(db):
         self.assertTrue(dbs[0].committed and not dbs[0].rolledback)
 
     def test_implicit_nesting_failure(self):
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=Connection)})
+        env = make_env(Connection)
         dbs = [None, None]
         try:
             @with_transaction(env)
 
     def test_explicit_success(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: None)})
-        env = Mock(get_db_cnx=lambda: None)
+        env = make_env(lambda: None)
         @with_transaction(env, db)
         def do_transaction(idb):
             self.assertTrue(idb is db)
 
     def test_explicit_failure(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: None)})
-        env = Mock(get_db_cnx=lambda: None)
+        env = make_env(lambda: None)
         try:
             @with_transaction(env, db)
             def do_transaction(idb):
 
     def test_implicit_in_explicit_success(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: db)})
+        env = make_env(lambda: db)
         dbs = [None, None]
         @with_transaction(env, db)
         def level0(db):
 
     def test_implicit_in_explicit_failure(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: db)})
+        env = make_env(lambda: db)
         dbs = [None, None]
         try:
             @with_transaction(env, db)
 
     def test_explicit_in_implicit_success(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: db)})
+        env = make_env(lambda: db)
         dbs = [None, None]
         @with_transaction(env)
         def level0(db):
 
     def test_explicit_in_implicit_failure(self):
         db = Connection()
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=lambda: db)})
+        env = make_env(lambda: db)
         dbs = [None, None]
         try:
             @with_transaction(env)
         self.assertTrue(not dbs[0].committed and dbs[0].rolledback)
 
     def test_invalid_nesting(self):
-        env = Mock(components={
-                DatabaseManager: Mock(get_connection=Connection)})
+        env = make_env(Connection)
         try:
             @with_transaction(env)
             def level0(db):