Mike Bayer avatar Mike Bayer committed 87a4089

Improvements to Connection auto-invalidation
handling. If a non-disconnect error occurs,
but leads to a delayed disconnect error within error
handling (happens with MySQL), the disconnect condition
is detected. The Connection can now also be closed
when in an invalid state, meaning it will raise "closed"
on next usage, and additionally the "close with result"
feature will work even if the autorollback in an error
handling routine fails and regardless of whether the
condition is a disconnect or not.
[ticket:2695]

Comments (0)

Files changed (3)

doc/build/changelog/changelog_08.rst

 
     .. change::
       :tags: bug, sql
+      :tickets: 2695
+
+      Improvements to Connection auto-invalidation
+      handling.  If a non-disconnect error occurs,
+      but leads to a delayed disconnect error within error
+      handling (happens with MySQL), the disconnect condition
+      is detected.  The Connection can now also be closed
+      when in an invalid state, meaning it will raise "closed"
+      on next usage, and additionally the "close with result"
+      feature will work even if the autorollback in an error
+      handling routine fails and regardless of whether the
+      condition is a disconnect or not.
+
+    .. change::
+      :tags: bug, sql
       :tickets: 2702
 
       A major fix to the way in which a select() object produces

lib/sqlalchemy/engine/base.py

         self.__savepoint_seq = 0
         self.__branch = _branch
         self.__invalid = False
+        self.__can_reconnect = True
         if _dispatch:
             self.dispatch = _dispatch
         elif engine._has_events:
     def closed(self):
         """Return True if this connection is closed."""
 
-        return not self.__invalid and '_Connection__connection' \
-                        not in self.__dict__
+        return '_Connection__connection' not in self.__dict__ \
+            and not self.__can_reconnect
 
     @property
     def invalidated(self):
             return self._revalidate_connection()
 
     def _revalidate_connection(self):
-        if self.__invalid:
+        if self.__can_reconnect and self.__invalid:
             if self.__transaction is not None:
                 raise exc.InvalidRequestError(
                                 "Can't reconnect until invalid "
         and will allow no further operations.
 
         """
-
         try:
             conn = self.__connection
         except AttributeError:
-            return
-        if not self.__branch:
-            conn.close()
-        self.__invalid = False
-        del self.__connection
+            pass
+        else:
+            if not self.__branch:
+                conn.close()
+            del self.__connection
+        self.__can_reconnect = False
         self.__transaction = None
 
     def scalar(self, object, *multiparams, **params):
             if isinstance(e, (SystemExit, KeyboardInterrupt)):
                 raise
 
+    _reentrant_error = False
+    _is_disconnect = False
+
     def _handle_dbapi_exception(self,
                                     e,
                                     statement,
                                     parameters,
                                     cursor,
                                     context):
-        if getattr(self, '_reentrant_error', False):
+
+        if not self._is_disconnect:
+            self._is_disconnect = isinstance(e, self.dialect.dbapi.Error) and \
+                not self.closed and \
+                self.dialect.is_disconnect(e, self.__connection, cursor)
+
+        if self._reentrant_error:
             # Py3K
             #raise exc.DBAPIError.instance(statement, parameters, e,
             #                               self.dialect.dbapi.Error) from e
                                                     e)
                 context.handle_dbapi_exception(e)
 
-            is_disconnect = isinstance(e, self.dialect.dbapi.Error) and \
-                self.dialect.is_disconnect(e, self.__connection, cursor)
-
-            if is_disconnect:
-                dbapi_conn_wrapper = self.connection
-                self.invalidate(e)
-                if not hasattr(dbapi_conn_wrapper, '_pool') or \
-                    dbapi_conn_wrapper._pool is self.engine.pool:
-                    self.engine.dispose()
-            else:
+            if not self._is_disconnect:
                 if cursor:
                     self._safe_close_cursor(cursor)
                 self._autorollback()
-                if self.should_close_with_result:
-                    self.close()
 
             if not should_wrap:
                 return
             #                        parameters,
             #                        e,
             #                        self.dialect.dbapi.Error,
-            #                        connection_invalidated=is_disconnect) \
+            #                        connection_invalidated=self._is_disconnect) \
             #                        from e
             # Py2K
             raise exc.DBAPIError.instance(
                                     parameters,
                                     e,
                                     self.dialect.dbapi.Error,
-                                    connection_invalidated=is_disconnect), \
+                                    connection_invalidated=self._is_disconnect), \
                                     None, sys.exc_info()[2]
             # end Py2K
 
         finally:
             del self._reentrant_error
+            if self._is_disconnect:
+                del self._is_disconnect
+                dbapi_conn_wrapper = self.connection
+                self.invalidate(e)
+                if not hasattr(dbapi_conn_wrapper, '_pool') or \
+                        dbapi_conn_wrapper._pool is self.engine.pool:
+                    self.engine.dispose()
+            if self.should_close_with_result:
+                self.close()
 
     # poor man's multimethod/generic function thingy
     executors = {

test/engine/test_reconnect.py

 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.engines import testing_engine
 
-class MockDisconnect(Exception):
+class MockError(Exception):
+    pass
+
+class MockDisconnect(MockError):
     pass
 
 class MockDBAPI(object):
         self.connections = weakref.WeakKeyDictionary()
     def connect(self, *args, **kwargs):
         return MockConnection(self)
-    def shutdown(self):
+    def shutdown(self, explode='execute'):
         for c in self.connections:
-            c.explode[0] = True
-    Error = MockDisconnect
+            c.explode = explode
+    Error = MockError
 
 class MockConnection(object):
     def __init__(self, dbapi):
         dbapi.connections[self] = True
-        self.explode = [False]
+        self.explode = ""
     def rollback(self):
-        pass
+        if self.explode == 'rollback':
+            raise MockDisconnect("Lost the DB connection on rollback")
+        if self.explode == 'rollback_no_disconnect':
+            raise MockError(
+                "something broke on rollback but we didn't lose the connection")
+        else:
+            return
     def commit(self):
         pass
     def cursor(self):
         self.explode = parent.explode
         self.description = ()
     def execute(self, *args, **kwargs):
-        if self.explode[0]:
-            raise MockDisconnect("Lost the DB connection")
+        if self.explode == 'execute':
+            raise MockDisconnect("Lost the DB connection on execute")
+        elif self.explode in ('execute_no_disconnect', ):
+            raise MockError(
+                "something broke on execute but we didn't lose the connection")
+        elif self.explode in ('rollback', 'rollback_no_disconnect'):
+            raise MockError(
+                "something broke on execute but we didn't lose the connection")
         else:
             return
     def close(self):
 
         dbapi.shutdown()
 
-        # raises error
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError:
-            pass
+        assert_raises(
+            tsa.exc.DBAPIError,
+            conn.execute, select([1])
+        )
 
         assert not conn.closed
         assert conn.invalidated
         assert not conn.invalidated
         assert len(dbapi.connections) == 1
 
+    def test_invalidated_close(self):
+        conn = db.connect()
+
+        dbapi.shutdown()
+
+        assert_raises(
+            tsa.exc.DBAPIError,
+            conn.execute, select([1])
+        )
+
+        conn.close()
+        assert conn.closed
+        assert conn.invalidated
+        assert_raises_message(
+            tsa.exc.StatementError,
+            "This Connection is closed",
+            conn.execute, select([1])
+        )
+
+    def test_noreconnect_execute_plus_closewresult(self):
+        conn = db.connect(close_with_result=True)
+
+        dbapi.shutdown("execute_no_disconnect")
+
+        # raises error
+        assert_raises_message(
+            tsa.exc.DBAPIError,
+            "something broke on execute but we didn't lose the connection",
+            conn.execute, select([1])
+        )
+
+        assert conn.closed
+        assert not conn.invalidated
+
+    def test_noreconnect_rollback_plus_closewresult(self):
+        conn = db.connect(close_with_result=True)
+
+        dbapi.shutdown("rollback_no_disconnect")
+
+        # raises error
+        assert_raises_message(
+            tsa.exc.DBAPIError,
+            "something broke on rollback but we didn't lose the connection",
+            conn.execute, select([1])
+        )
+
+        assert conn.closed
+        assert not conn.invalidated
+
+        assert_raises_message(
+            tsa.exc.StatementError,
+            "This Connection is closed",
+            conn.execute, select([1])
+        )
+
+    def test_reconnect_on_reentrant(self):
+        conn = db.connect()
+
+        conn.execute(select([1]))
+
+        assert len(dbapi.connections) == 1
+
+        dbapi.shutdown("rollback")
+
+        # raises error
+        assert_raises_message(
+            tsa.exc.DBAPIError,
+            "Lost the DB connection on rollback",
+            conn.execute, select([1])
+        )
+
+        assert not conn.closed
+        assert conn.invalidated
+
+    def test_reconnect_on_reentrant_plus_closewresult(self):
+        conn = db.connect(close_with_result=True)
+
+        dbapi.shutdown("rollback")
+
+        # raises error
+        assert_raises_message(
+            tsa.exc.DBAPIError,
+            "Lost the DB connection on rollback",
+            conn.execute, select([1])
+        )
+
+        assert conn.closed
+        assert conn.invalidated
+
+        assert_raises_message(
+            tsa.exc.StatementError,
+            "This Connection is closed",
+            conn.execute, select([1])
+        )
+
 class CursorErrTest(fixtures.TestBase):
 
     def setup(self):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.