Mike Bayer avatar Mike Bayer committed d4205cc

- changed "invalidate" semantics with pooled connection; will
instruct the underlying connection record to reconnect the next
time its called. "invalidate" will also automatically be called
if any error is thrown in the underlying call to connection.cursor().
this will hopefully allow the connection pool to reconnect to a
database that had been stopped and started without restarting
the connecting application [ticket:121]

Comments (0)

Files changed (4)

 defaults to 3600 seconds; connections after this age will be closed and
 replaced with a new one, to handle db's that automatically close 
 stale connections [ticket:274]
+- changed "invalidate" semantics with pooled connection; will 
+instruct the underlying connection record to reconnect the next 
+time its called.  "invalidate" will also automatically be called
+if any error is thrown in the underlying call to connection.cursor().
+this will hopefully allow the connection pool to reconnect to a
+database that had been stopped and started without restarting
+the connecting application [ticket:121] 
 - eesh !  the tutorial doctest was broken for quite some time.
 - add_property() method on mapper does a "compile all mappers"
 step in case the given property references a non-compiled mapper

lib/sqlalchemy/engine/base.py

         try:
             self.__engine.dialect.do_executemany(c, statement, parameters, context=context)
         except Exception, e:
-            self._rollback_impl()
+            self._autorollback()
+            #self._rollback_impl()
             if self.__close_with_result:
                 self.close()
             raise exceptions.SQLError(statement, parameters, e)

lib/sqlalchemy/pool.py

     def return_conn(self, agent):
         self.do_return_conn(agent._connection_record)
 
-    def return_invalid(self, agent):
-        self.do_return_invalid(agent._connection_record)
-        
     def get(self):
         return self.do_get()
     
     def do_return_conn(self, conn):
         raise NotImplementedError()
         
-    def do_return_invalid(self, conn):
-        raise NotImplementedError()
-        
     def status(self):
         raise NotImplementedError()
 
 
 class _ConnectionRecord(object):
     def __init__(self, pool):
-        self.pool = pool
+        self.__pool = pool
         self.connection = self.__connect()
     def close(self):
         self.connection.close()
+    def invalidate(self):
+        self.__pool.log("Invalidate connection %s" % repr(self.connection))
+        self.__close()
+        self.connection = None
     def get_connection(self):
-        if self.pool._recycle > -1 and time.time() - self.starttime > self.pool._recycle:
-            self.pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
-            try:
-                self.connection.close()
-            except Exception, e:
-                self.pool.log("Connection %s threw an error: %s" % (repr(self.connection), str(e)))
+        if self.connection is None:
+            self.connection = self.__connect()
+        elif (self.__pool._recycle > -1 and time.time() - self.starttime > self.__pool._recycle):
+            self.__pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection))
+            self.__close()
             self.connection = self.__connect()
         return self.connection
+    def __close(self):
+        try:
+            self.__pool.log("Closing connection %s" % (repr(self.connection)))
+            self.connection.close()
+        except Exception, e:
+            self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e)))
     def __connect(self):
         try:
             self.starttime = time.time()
-            return self.pool._creator()
-        except:
+            return self.__pool._creator()
+        except Exception, e:
+            self.__pool.log("Error on connect(): %s" % (str(e)))
             raise
-            # TODO: reconnect support here ?
 
 class _ThreadFairy(object):
     """marks a thread identifier as owning a connection, for a thread local pool."""
         except:
             self.connection = None # helps with endless __getattr__ loops later on
             self._connection_record = None
-            self.__pool.return_invalid(self)
             raise
         if self.__pool.echo:
             self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
     def invalidate(self):
-        if self.__pool.echo:
-            self.__pool.log("Invalidate connection %s" % repr(self.connection))
+        self._connection_record.invalidate()
         self.connection = None
-        self._connection_record = None
-        self._threadfairy = None
-        self.__pool.return_invalid(self)
+        self._close()
     def cursor(self, *args, **kwargs):
-        return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+        try:
+            return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
+        except Exception, e:
+            self.invalidate()
+            raise
     def __getattr__(self, key):
         return getattr(self.connection, key)
     def checkout(self):
         self._close()
     def _close(self):
         if self.connection is not None:
-            if self.__pool.echo:
-                self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
             try:
                 self.connection.rollback()
             except:
                 # damn mysql -- (todo look for NotSupportedError)
                 pass
+        if self._connection_record is not None:
+            if self.__pool.echo:
+                self.__pool.log("Connection %s being returned to pool" % repr(self.connection))
             self.__pool.return_conn(self)
-        self.__pool = None
-        self.connection = None
         self._connection_record = None
         self._threadfairy = None
             
     def do_return_conn(self, conn):
         pass
         
-    def do_return_invalid(self, conn):
-        try:
-            del self._conns[thread.get_ident()]
-        except KeyError:
-            pass
-            
     def do_get(self):
         try:
             return self._conns[thread.get_ident()]
         except Queue.Full:
             self._overflow -= 1
 
-    def do_return_invalid(self, conn):
-        if conn is not None:
-            self._overflow -= 1
-        
     def do_get(self):
         try:
             return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout)

test/engine/pool.py

 import sqlalchemy.exceptions as exceptions
 
 class MockDBAPI(object):
+    def __init__(self):
+        self.throw_error = False
     def connect(self, argument):
+        if self.throw_error:
+            raise Exception("couldnt connect !")
         return MockConnection()
 class MockConnection(object):
     def close(self):
         time.sleep(3)
         c3= p.connect()
         assert id(c3.connection) != c_id
+    
+    def test_invalidate(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close(); c1=None
+
+        c1 = p.connect()
+        assert id(c1.connection) == c_id
+        c1.invalidate()
+        c1 = None
+        
+        c1 = p.connect()
+        assert id(c1.connection) != c_id
+
+    def test_reconnect(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, echo=True)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close(); c1=None
+
+        c1 = p.connect()
+        assert id(c1.connection) == c_id
+        dbapi.raise_error = True
+        c1.invalidate()
+        c1 = None
+
+        c1 = p.connect()
+        assert id(c1.connection) != c_id
         
     def testthreadlocal_del(self):
         self._do_testthreadlocal(useclose=False)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.