Commits

Anonymous committed 8414e16

Fixed #17258 -- Moved `threading.local` from `DatabaseWrapper` to the `django.db.connections` dictionary. This allows connections to be explicitly shared between multiple threads and is particularly useful for enabling the sharing of in-memory SQLite connections. Many thanks to Anssi Kääriäinen for the excellent suggestions and feedback, and to Alex Gaynor for the reviews. Refs #2879.

Comments (0)

Files changed (6)

django/db/__init__.py

 # we manually create the dictionary from the settings, passing only the
 # settings that the database backends care about. Note that TIME_ZONE is used
 # by the PostgreSQL backends.
-# we load all these up for backwards compatibility, you should use
+# We load all these up for backwards compatibility, you should use
 # connections['default'] instead.
-connection = connections[DEFAULT_DB_ALIAS]
+class DefaultConnectionProxy(object):
+    """
+    Proxy for accessing the default DatabaseWrapper object's attributes. If you
+    need to access the DatabaseWrapper object itself, use
+    connections[DEFAULT_DB_ALIAS] instead.
+    """
+    def __getattr__(self, item):
+        return getattr(connections[DEFAULT_DB_ALIAS], item)
+
+    def __setattr__(self, name, value):
+        return setattr(connections[DEFAULT_DB_ALIAS], name, value)
+
+connection = DefaultConnectionProxy()
 backend = load_backend(connection.settings_dict['ENGINE'])
 
 # Register an event that closes the database connection

django/db/backends/__init__.py

+from django.db.utils import DatabaseError
+
 try:
     import thread
 except ImportError:
     import dummy_thread as thread
-from threading import local
 from contextlib import contextmanager
 
 from django.conf import settings
 from django.utils.timezone import is_aware
 
 
-class BaseDatabaseWrapper(local):
+class BaseDatabaseWrapper(object):
     """
     Represents a database connection.
     """
     ops = None
     vendor = 'unknown'
 
-    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
+    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
+                 allow_thread_sharing=False):
         # `settings_dict` should be a dictionary containing keys such as
         # NAME, USER, etc. It's called `settings_dict` instead of `settings`
         # to disambiguate it from Django settings modules.
         self.transaction_state = []
         self.savepoint_state = 0
         self._dirty = None
+        self._thread_ident = thread.get_ident()
+        self.allow_thread_sharing = allow_thread_sharing
 
     def __eq__(self, other):
         return self.alias == other.alias
                 "pending COMMIT/ROLLBACK")
         self._dirty = False
 
+    def validate_thread_sharing(self):
+        """
+        Validates that the connection isn't accessed by another thread than the
+        one which originally created it, unless the connection was explicitly
+        authorized to be shared between threads (via the `allow_thread_sharing`
+        property). Raises an exception if the validation fails.
+        """
+        if (not self.allow_thread_sharing
+            and self._thread_ident != thread.get_ident()):
+                raise DatabaseError("DatabaseWrapper objects created in a "
+                    "thread can only be used in that same thread. The object"
+                    "with alias '%s' was created in thread id %s and this is "
+                    "thread id %s."
+                    % (self.alias, self._thread_ident, thread.get_ident()))
+
     def is_dirty(self):
         """
         Returns True if the current transaction requires a commit for changes to
         """
         Commits changes if the system is not in managed transaction mode.
         """
+        self.validate_thread_sharing()
         if not self.is_managed():
             self._commit()
             self.clean_savepoints()
         """
         Rolls back changes if the system is not in managed transaction mode.
         """
+        self.validate_thread_sharing()
         if not self.is_managed():
             self._rollback()
         else:
         """
         Does the commit itself and resets the dirty flag.
         """
+        self.validate_thread_sharing()
         self._commit()
         self.set_clean()
 
         """
         This function does the rollback itself and resets the dirty flag.
         """
+        self.validate_thread_sharing()
         self._rollback()
         self.set_clean()
 
         Rolls back the most recent savepoint (if one exists). Does nothing if
         savepoints are not supported.
         """
+        self.validate_thread_sharing()
         if self.savepoint_state:
             self._savepoint_rollback(sid)
 
         Commits the most recent savepoint (if one exists). Does nothing if
         savepoints are not supported.
         """
+        self.validate_thread_sharing()
         if self.savepoint_state:
             self._savepoint_commit(sid)
 
         pass
 
     def close(self):
+        self.validate_thread_sharing()
         if self.connection is not None:
             self.connection.close()
             self.connection = None
 
     def cursor(self):
+        self.validate_thread_sharing()
         if (self.use_debug_cursor or
             (self.use_debug_cursor is None and settings.DEBUG)):
             cursor = self.make_debug_cursor(self._cursor())

django/db/backends/sqlite3/base.py

 
 import datetime
 import decimal
+import warnings
 import re
 import sys
 
-from django.conf import settings
 from django.db import utils
 from django.db.backends import *
 from django.db.backends.signals import connection_created
                 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
             }
             kwargs.update(settings_dict['OPTIONS'])
+            # Always allow the underlying SQLite connection to be shareable
+            # between multiple threads. The safe-guarding will be handled at a
+            # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
+            # property. This is necessary as the shareability is disabled by
+            # default in pysqlite and it cannot be changed once a connection is
+            # opened.
+            if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
+                warnings.warn(
+                    'The `check_same_thread` option was provided and set to '
+                    'True. It will be overriden with False. Use the '
+                    '`DatabaseWrapper.allow_thread_sharing` property instead '
+                    'for controlling thread shareability.',
+                    RuntimeWarning
+                )
+            kwargs.update({'check_same_thread': False})
             self.connection = Database.connect(**kwargs)
             # Register extract, date_trunc, and regexp functions.
             self.connection.create_function("django_extract", 2, _sqlite_extract)

django/db/utils.py

 import os
+from threading import local
 
 from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured
 class ConnectionHandler(object):
     def __init__(self, databases):
         self.databases = databases
-        self._connections = {}
+        self._connections = local()
 
     def ensure_defaults(self, alias):
         """
             conn.setdefault(setting, None)
 
     def __getitem__(self, alias):
-        if alias in self._connections:
-            return self._connections[alias]
+        if hasattr(self._connections, alias):
+            return getattr(self._connections, alias)
 
         self.ensure_defaults(alias)
         db = self.databases[alias]
         backend = load_backend(db['ENGINE'])
         conn = backend.DatabaseWrapper(db, alias)
-        self._connections[alias] = conn
+        setattr(self._connections, alias, conn)
         return conn
 
+    def __setitem__(self, key, value):
+        setattr(self._connections, key, value)
+
     def __iter__(self):
         return iter(self.databases)
 

docs/releases/1.4.txt

 :setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime
 object, Django raises an exception.
 
+Database connection's thread-locality
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``DatabaseWrapper`` objects (i.e. the connection objects referenced by
+``django.db.connection`` and ``django.db.connections["some_alias"]``) used to
+be thread-local. They are now global objects in order to be potentially shared
+between multiple threads. While the individual connection objects are now
+global, the ``django.db.connections`` dictionary referencing those objects is
+still thread-local. Therefore if you just use the ORM or
+``DatabaseWrapper.cursor()`` then the behavior is still the same as before.
+Note, however, that ``django.db.connection`` does not directly reference the
+default ``DatabaseWrapper`` object any more and is now a proxy to access that
+object's attributes. If you need to access the actual ``DatabaseWrapper``
+object, use ``django.db.connections[DEFAULT_DB_ALIAS]`` instead.
+
+As part of this change, all underlying SQLite connections are now enabled for
+potential thread-sharing (by passing the ``check_same_thread=False`` attribute
+to pysqlite). ``DatabaseWrapper`` however preserves the previous behavior by
+disabling thread-sharing by default, so this does not affect any existing
+code that purely relies on the ORM or on ``DatabaseWrapper.cursor()``.
+
+Finally, while it is now possible to pass connections between threads, Django
+does not make any effort to synchronize access to the underlying backend.
+Concurrency behavior is defined by the underlying backend implementation.
+Check their documentation for details.
+
 `COMMENTS_BANNED_USERS_GROUP` setting
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

tests/regressiontests/backends/tests.py

 from __future__ import with_statement, absolute_import
 
 import datetime
+import threading
 
 from django.conf import settings
 from django.core.management.color import no_style
         connection_created.connect(receiver)
         connection.close()
         cursor = connection.cursor()
-        self.assertTrue(data["connection"] is connection)
+        self.assertTrue(data["connection"].connection is connection.connection)
 
         connection_created.disconnect(receiver)
         data.clear()
                         connection.check_constraints()
             finally:
                 transaction.rollback()
+
+
+class ThreadTests(TestCase):
+
+    def test_default_connection_thread_local(self):
+        """
+        Ensure that the default connection (i.e. django.db.connection) is
+        different for each thread.
+        Refs #17258.
+        """
+        connections_set = set()
+        connection.cursor()
+        connections_set.add(connection.connection)
+        def runner():
+            from django.db import connection
+            connection.cursor()
+            connections_set.add(connection.connection)
+        for x in xrange(2):
+            t = threading.Thread(target=runner)
+            t.start()
+            t.join()
+        self.assertEquals(len(connections_set), 3)
+        # Finish by closing the connections opened by the other threads (the
+        # connection opened in the main thread will automatically be closed on
+        # teardown).
+        for conn in connections_set:
+            if conn != connection.connection:
+                conn.close()
+
+    def test_connections_thread_local(self):
+        """
+        Ensure that the connections are different for each thread.
+        Refs #17258.
+        """
+        connections_set = set()
+        for conn in connections.all():
+            connections_set.add(conn)
+        def runner():
+            from django.db import connections
+            for conn in connections.all():
+                connections_set.add(conn)
+        for x in xrange(2):
+            t = threading.Thread(target=runner)
+            t.start()
+            t.join()
+        self.assertEquals(len(connections_set), 6)
+        # Finish by closing the connections opened by the other threads (the
+        # connection opened in the main thread will automatically be closed on
+        # teardown).
+        for conn in connections_set:
+            if conn != connection:
+                conn.close()
+
+    def test_pass_connection_between_threads(self):
+        """
+        Ensure that a connection can be passed from one thread to the other.
+        Refs #17258.
+        """
+        models.Person.objects.create(first_name="John", last_name="Doe")
+
+        def do_thread():
+            def runner(main_thread_connection):
+                from django.db import connections
+                connections['default'] = main_thread_connection
+                try:
+                    models.Person.objects.get(first_name="John", last_name="Doe")
+                except DatabaseError, e:
+                    exceptions.append(e)
+            t = threading.Thread(target=runner, args=[connections['default']])
+            t.start()
+            t.join()
+
+        # Without touching allow_thread_sharing, which should be False by default.
+        exceptions = []
+        do_thread()
+        # Forbidden!
+        self.assertTrue(isinstance(exceptions[0], DatabaseError))
+
+        # If explicitly setting allow_thread_sharing to False
+        connections['default'].allow_thread_sharing = False
+        exceptions = []
+        do_thread()
+        # Forbidden!
+        self.assertTrue(isinstance(exceptions[0], DatabaseError))
+
+        # If explicitly setting allow_thread_sharing to True
+        connections['default'].allow_thread_sharing = True
+        exceptions = []
+        do_thread()
+        # All good
+        self.assertEqual(len(exceptions), 0)