1. Ben Trofatter
  2. sqlalchemy-2663

Commits

Mike Bayer  committed 5befdf2

- broke pool tests out into QueuePoolTest/SingletonThreadPoolTest
- added test for r5061/r5062 [ticket:1157]

  • Participants
  • Parent commits 51c32cf
  • Branches default

Comments (0)

Files changed (1)

File test/engine/pool.py

View file
  • Ignore whitespace
 class MockDBAPI(object):
     def __init__(self):
         self.throw_error = False
-    def connect(self, argument, delay=0):
+    def connect(self, *args, **kwargs):
         if self.throw_error:
             raise Exception("couldnt connect !")
+        delay = kwargs.pop('delay', 0)
         if delay:
             time.sleep(delay)
         return MockConnection()
         pass
 mock_dbapi = MockDBAPI()
 
-class PoolTest(TestBase):
 
+class PoolTestBase(TestBase):    
     def setUp(self):
         pool.clear_managers()
+        
+    def tearDownAll(self):
+       pool.clear_managers()
 
+class PoolTest(PoolTestBase):
     def testmanager(self):
         manager = pool.manage(mock_dbapi, use_threadlocal=True)
 
         self.assert_(connection.cursor() is not None)
         self.assert_(connection is not connection2)
 
-    def testqueuepool_del(self):
-        self._do_testqueuepool(useclose=False)
-
-    def testqueuepool_close(self):
-        self._do_testqueuepool(useclose=True)
-
-    def _do_testqueuepool(self, useclose=False):
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = False)
-
-        def status(pool):
-            tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
-            print "Pool size: %d  Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
-            return tup
-
-        c1 = p.connect()
-        self.assert_(status(p) == (3,0,-2,1))
-        c2 = p.connect()
-        self.assert_(status(p) == (3,0,-1,2))
-        c3 = p.connect()
-        self.assert_(status(p) == (3,0,0,3))
-        c4 = p.connect()
-        self.assert_(status(p) == (3,0,1,4))
-        c5 = p.connect()
-        self.assert_(status(p) == (3,0,2,5))
-        c6 = p.connect()
-        self.assert_(status(p) == (3,0,3,6))
-        if useclose:
-            c4.close()
-            c3.close()
-            c2.close()
-        else:
-            c4 = c3 = c2 = None
-        self.assert_(status(p) == (3,3,3,3))
-        if useclose:
-            c1.close()
-            c5.close()
-            c6.close()
-        else:
-            c1 = c5 = c6 = None
-        self.assert_(status(p) == (3,3,0,0))
-        c1 = p.connect()
-        c2 = p.connect()
-        self.assert_(status(p) == (3, 1, 0, 2), status(p))
-        if useclose:
-            c2.close()
-        else:
-            c2 = None
-        self.assert_(status(p) == (3, 2, 0, 1))
-
-    def test_timeout(self):
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
-        c1 = p.connect()
-        c2 = p.connect()
-        c3 = p.connect()
-        now = time.time()
-        try:
-            c4 = p.connect()
-            assert False
-        except tsa.exc.TimeoutError, e:
-            assert int(time.time() - now) == 2
-
-    def test_timeout_race(self):
-        # test a race condition where the initial connecting threads all race
-        # to queue.Empty, then block on the mutex.  each thread consumes a
-        # connection as they go in.  when the limit is reached, the remaining
-        # threads go in, and get TimeoutError; even though they never got to
-        # wait for the timeout on queue.get().  the fix involves checking the
-        # timeout again within the mutex, and if so, unlocking and throwing
-        # them back to the start of do_get()
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db', delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
-        timeouts = []
-        def checkout():
-            for x in xrange(1):
-                now = time.time()
-                try:
-                    c1 = p.connect()
-                except tsa.exc.TimeoutError, e:
-                    timeouts.append(int(time.time()) - now)
-                    continue
-                time.sleep(4)
-                c1.close()
-
-        threads = []
-        for i in xrange(10):
-            th = threading.Thread(target=checkout)
-            th.start()
-            threads.append(th)
-        for th in threads:
-            th.join()
-
-        print timeouts
-        assert len(timeouts) > 0
-        for t in timeouts:
-            assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
-
-    def _test_overflow(self, thread_count, max_overflow):
-        def creator():
-            time.sleep(.05)
-            return mock_dbapi.connect('foo.db')
-
-        p = pool.QueuePool(creator=creator,
-                           pool_size=3, timeout=2,
-                           max_overflow=max_overflow)
-        peaks = []
-        def whammy():
-            for i in range(10):
-                try:
-                    con = p.connect()
-                    time.sleep(.005)
-                    peaks.append(p.overflow())
-                    con.close()
-                    del con
-                except tsa.exc.TimeoutError:
-                    pass
-        threads = []
-        for i in xrange(thread_count):
-            th = threading.Thread(target=whammy)
-            th.start()
-            threads.append(th)
-        for th in threads:
-            th.join()
-
-        self.assert_(max(peaks) <= max_overflow)
-
-    def test_no_overflow(self):
-        self._test_overflow(40, 0)
-
-    def test_max_overflow(self):
-        self._test_overflow(40, 5)
-
-    def test_mixed_close(self):
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)
-        c1 = p.connect()
-        c2 = p.connect()
-        assert c1 is c2
-        c1.close()
-        c2 = None
-        assert p.checkedout() == 1
-        c1 = None
-        assert p.checkedout() == 0
-
-    def test_weakref_kaboom(self):
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)
-        c1 = p.connect()
-        c2 = p.connect()
-        c1.close()
-        c2 = None
-        del c1
-        del c2
-        gc.collect()
-        assert p.checkedout() == 0
-        c3 = p.connect()
-        assert c3 is not None
-
-    def test_trick_the_counter(self):
-        """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
-        with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
-        ambiguous counter.  i.e. its not true reference counting."""
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)
-        c1 = p.connect()
-        c2 = p.connect()
-        assert c1 is c2
-        c1.close()
-        c2 = p.connect()
-        c2.close()
-        self.assert_(p.checkedout() != 0)
-
-        c2.close()
-        self.assert_(p.checkedout() == 0)
-
-    def test_recycle(self):
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
-
-        c1 = p.connect()
-        c_id = id(c1.connection)
-        c1.close()
-        c2 = p.connect()
-        assert id(c2.connection) == c_id
-        c2.close()
-        time.sleep(4)
-        c3= p.connect()
-        assert id(c3.connection) != c_id
-
-    def test_invalidate(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-        c1 = p.connect()
-        c_id = c1.connection.id
-        c1.close(); c1=None
-        c1 = p.connect()
-        assert c1.connection.id == c_id
-        c1.invalidate()
-        c1 = None
-
-        c1 = p.connect()
-        assert c1.connection.id != c_id
-
-    def test_recreate(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-        p2 = p.recreate()
-        assert p2.size() == 1
-        assert p2._use_threadlocal is False
-        assert p2._max_overflow == 0
-
-    def test_reconnect(self):
-        """tests reconnect operations at the pool level.  SA's engine/dialect includes another
-        layer of reconnect support for 'database was lost' errors."""
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-        c1 = p.connect()
-        c_id = c1.connection.id
-        c1.close(); c1=None
-
-        c1 = p.connect()
-        assert c1.connection.id == c_id
-        dbapi.raise_error = True
-        c1.invalidate()
-        c1 = None
-
-        c1 = p.connect()
-        assert c1.connection.id != c_id
-
-    def test_detach(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-
-        c1 = p.connect()
-        c1.detach()
-        c_id = c1.connection.id
-
-        c2 = p.connect()
-        assert c2.connection.id != c1.connection.id
-        dbapi.raise_error = True
-
-        c2.invalidate()
-        c2 = None
-
-        c2 = p.connect()
-        assert c2.connection.id != c1.connection.id
-
-        con = c1.connection
-
-        assert not con.closed
-        c1.close()
-        assert con.closed
-
-    def test_threadfairy(self):
-        p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)
-        c1 = p.connect()
-        c1.close()
-        c2 = p.connect()
-        assert c2.connection is not None
 
     def testthreadlocal_del(self):
         self._do_testthreadlocal(useclose=False)
 
     def _do_testthreadlocal(self, useclose=False):
         for p in (
-            pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True),
-            pool.SingletonThreadPool(creator = lambda: mock_dbapi.connect('foo.db'), use_threadlocal = True)
+            pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True),
+            pool.SingletonThreadPool(creator = mock_dbapi.connect, use_threadlocal = True)
         ):
             c1 = p.connect()
             c2 = p.connect()
         c.close()
         assert counts == [1, 2, 3]
 
+class QueuePoolTest(PoolTestBase):
 
+   def testqueuepool_del(self):
+       self._do_testqueuepool(useclose=False)
 
-    def tearDown(self):
-       pool.clear_managers()
+   def testqueuepool_close(self):
+       self._do_testqueuepool(useclose=True)
 
+   def _do_testqueuepool(self, useclose=False):
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
 
+       def status(pool):
+           tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
+           print "Pool size: %d  Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
+           return tup
+
+       c1 = p.connect()
+       self.assert_(status(p) == (3,0,-2,1))
+       c2 = p.connect()
+       self.assert_(status(p) == (3,0,-1,2))
+       c3 = p.connect()
+       self.assert_(status(p) == (3,0,0,3))
+       c4 = p.connect()
+       self.assert_(status(p) == (3,0,1,4))
+       c5 = p.connect()
+       self.assert_(status(p) == (3,0,2,5))
+       c6 = p.connect()
+       self.assert_(status(p) == (3,0,3,6))
+       if useclose:
+           c4.close()
+           c3.close()
+           c2.close()
+       else:
+           c4 = c3 = c2 = None
+       self.assert_(status(p) == (3,3,3,3))
+       if useclose:
+           c1.close()
+           c5.close()
+           c6.close()
+       else:
+           c1 = c5 = c6 = None
+       self.assert_(status(p) == (3,3,0,0))
+       c1 = p.connect()
+       c2 = p.connect()
+       self.assert_(status(p) == (3, 1, 0, 2), status(p))
+       if useclose:
+           c2.close()
+       else:
+           c2 = None
+       self.assert_(status(p) == (3, 2, 0, 1))
+
+   def test_timeout(self):
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
+       c1 = p.connect()
+       c2 = p.connect()
+       c3 = p.connect()
+       now = time.time()
+       try:
+           c4 = p.connect()
+           assert False
+       except tsa.exc.TimeoutError, e:
+           assert int(time.time() - now) == 2
+
+   def test_timeout_race(self):
+       # test a race condition where the initial connecting threads all race
+       # to queue.Empty, then block on the mutex.  each thread consumes a
+       # connection as they go in.  when the limit is reached, the remaining
+       # threads go in, and get TimeoutError; even though they never got to
+       # wait for the timeout on queue.get().  the fix involves checking the
+       # timeout again within the mutex, and if so, unlocking and throwing
+       # them back to the start of do_get()
+       p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
+       timeouts = []
+       def checkout():
+           for x in xrange(1):
+               now = time.time()
+               try:
+                   c1 = p.connect()
+               except tsa.exc.TimeoutError, e:
+                   timeouts.append(int(time.time()) - now)
+                   continue
+               time.sleep(4)
+               c1.close()
+
+       threads = []
+       for i in xrange(10):
+           th = threading.Thread(target=checkout)
+           th.start()
+           threads.append(th)
+       for th in threads:
+           th.join()
+
+       print timeouts
+       assert len(timeouts) > 0
+       for t in timeouts:
+           assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
+
+   def _test_overflow(self, thread_count, max_overflow):
+       def creator():
+           time.sleep(.05)
+           return mock_dbapi.connect()
+
+       p = pool.QueuePool(creator=creator,
+                          pool_size=3, timeout=2,
+                          max_overflow=max_overflow)
+       peaks = []
+       def whammy():
+           for i in range(10):
+               try:
+                   con = p.connect()
+                   time.sleep(.005)
+                   peaks.append(p.overflow())
+                   con.close()
+                   del con
+               except tsa.exc.TimeoutError:
+                   pass
+       threads = []
+       for i in xrange(thread_count):
+           th = threading.Thread(target=whammy)
+           th.start()
+           threads.append(th)
+       for th in threads:
+           th.join()
+
+       self.assert_(max(peaks) <= max_overflow)
+
+   def test_no_overflow(self):
+       self._test_overflow(40, 0)
+
+   def test_max_overflow(self):
+       self._test_overflow(40, 5)
+
+   def test_mixed_close(self):
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+       c1 = p.connect()
+       c2 = p.connect()
+       assert c1 is c2
+       c1.close()
+       c2 = None
+       assert p.checkedout() == 1
+       c1 = None
+       assert p.checkedout() == 0
+
+   def test_weakref_kaboom(self):
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+       c1 = p.connect()
+       c2 = p.connect()
+       c1.close()
+       c2 = None
+       del c1
+       del c2
+       gc.collect()
+       assert p.checkedout() == 0
+       c3 = p.connect()
+       assert c3 is not None
+
+   def test_trick_the_counter(self):
+       """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
+       with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
+       ambiguous counter.  i.e. its not true reference counting."""
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+       c1 = p.connect()
+       c2 = p.connect()
+       assert c1 is c2
+       c1.close()
+       c2 = p.connect()
+       c2.close()
+       self.assert_(p.checkedout() != 0)
+
+       c2.close()
+       self.assert_(p.checkedout() == 0)
+
+   def test_recycle(self):
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
+
+       c1 = p.connect()
+       c_id = id(c1.connection)
+       c1.close()
+       c2 = p.connect()
+       assert id(c2.connection) == c_id
+       c2.close()
+       time.sleep(4)
+       c3= p.connect()
+       assert id(c3.connection) != c_id
+
+   def test_invalidate(self):
+       dbapi = MockDBAPI()
+       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+       c1 = p.connect()
+       c_id = c1.connection.id
+       c1.close(); c1=None
+       c1 = p.connect()
+       assert c1.connection.id == c_id
+       c1.invalidate()
+       c1 = None
+
+       c1 = p.connect()
+       assert c1.connection.id != c_id
+
+   def test_recreate(self):
+       dbapi = MockDBAPI()
+       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+       p2 = p.recreate()
+       assert p2.size() == 1
+       assert p2._use_threadlocal is False
+       assert p2._max_overflow == 0
+
+   def test_reconnect(self):
+       """tests reconnect operations at the pool level.  SA's engine/dialect includes another
+       layer of reconnect support for 'database was lost' errors."""
+       dbapi = MockDBAPI()
+       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+       c1 = p.connect()
+       c_id = c1.connection.id
+       c1.close(); c1=None
+
+       c1 = p.connect()
+       assert c1.connection.id == c_id
+       dbapi.raise_error = True
+       c1.invalidate()
+       c1 = None
+
+       c1 = p.connect()
+       assert c1.connection.id != c_id
+
+   def test_detach(self):
+       dbapi = MockDBAPI()
+       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+
+       c1 = p.connect()
+       c1.detach()
+       c_id = c1.connection.id
+
+       c2 = p.connect()
+       assert c2.connection.id != c1.connection.id
+       dbapi.raise_error = True
+
+       c2.invalidate()
+       c2 = None
+
+       c2 = p.connect()
+       assert c2.connection.id != c1.connection.id
+
+       con = c1.connection
+
+       assert not con.closed
+       c1.close()
+       assert con.closed
+
+   def test_threadfairy(self):
+       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+       c1 = p.connect()
+       c1.close()
+       c2 = p.connect()
+       assert c2.connection is not None
+
+class SingletonThreadPoolTest(PoolTestBase):
+    def test_cleanup(self):
+        """test that the pool's connections are OK after cleanup() has been called."""
+        
+        p = pool.SingletonThreadPool(creator = mock_dbapi.connect, pool_size=3)
+        
+        def checkout():
+            for x in xrange(10):
+                c = p.connect()
+                assert c
+                c.cursor()
+                c.close()
+                
+                time.sleep(.1)
+                
+        threads = []
+        for i in xrange(10):
+            th = threading.Thread(target=checkout)
+            th.start()
+            threads.append(th)
+        for th in threads:
+            th.join()
+            
+        assert len(p._all_conns) == 3
+    
 if __name__ == "__main__":
     testenv.main()