Commits

gsalgado committed 576a5bf

Add a gevent ConnectionPool class, which prevents connections from being
used concurrently by multiple greenlets.

Comments (0)

Files changed (2)

psycogreen/gevent.py

 from __future__ import absolute_import
 
 import psycopg2
+from psycopg2.extras import RealDictConnection
 from psycopg2 import extensions
 
+from gevent.coros import Semaphore
+from gevent.local import local as gevent_local
 from gevent.socket import wait_read, wait_write
 
+
 def patch_psycopg():
     """Configure Psycopg to be used with gevent in non-blocking way."""
     if not hasattr(extensions, 'set_wait_callback'):
         else:
             raise psycopg2.OperationalError(
                 "Bad result from poll: %r" % state)
+
+
+class ConnectionPool(object):
+
+    def __init__(self, dsn, max_con=10, max_idle=3,
+                 connection_factory=RealDictConnection):
+        self.dsn = dsn
+        self.max_con = max_con
+        self.max_idle = max_idle
+        self.connection_factory = connection_factory
+        self._sem = Semaphore(max_con)
+        self._free = []
+        self._local = gevent_local()
+
+    def __enter__(self):
+        self._sem.acquire()
+        try:
+            if getattr(self._local, 'con', None) is not None:
+                raise RuntimeError("Attempting to re-enter connection pool?")
+            if self._free:
+                con = self._free.pop()
+            else:
+                con = psycopg2.connect(
+                    dsn=self.dsn, connection_factory=self.connection_factory)
+            self._local.con = con
+            return con
+        except StandardError:
+            self._sem.release()
+            raise
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            if self._local.con is None:
+                raise RuntimeError("Exit connection pool with no connection?")
+            if exc_type is not None:
+                self.rollback()
+            else:
+                self.commit()
+            if len(self._free) < self.max_idle:
+                self._free.append(self._local.con)
+            self._local.con = None
+        finally:
+            self._sem.release()
+
+    def commit(self):
+        self._local.con.commit()
+
+    def rollback(self):
+        self._local.con.rollback()

tests/test_gevent_connection_pool.py

+import unittest
+
+import gevent
+import psycopg2
+
+from psycogreen.gevent import ConnectionPool
+
+
+class FakeConnection(object):
+
+    rollback_called = False
+    commit_called = False
+
+    def rollback(self):
+        self.rollback_called = True
+
+    def commit(self):
+        self.commit_called = True
+
+
+class TestConnectionPool(unittest.TestCase):
+
+    def setUp(self):
+        super(TestConnectionPool, self).setUp()
+        self._orig_connect = psycopg2.connect
+        self.conn = FakeConnection()
+        self.pool = ConnectionPool('bogus-dsn')
+        psycopg2.connect = lambda **kwargs: self.conn
+        self.addCleanup(self._restore_psycopg2_connect)
+
+    def _restore_psycopg2_connect(self):
+        psycopg2.connect = self._orig_connect
+
+    def test_commit_is_called_on_success(self):
+        with self.pool:
+            self.assertTrue(isinstance(self.pool._local, gevent.local.local))
+            self.assertEqual(self.conn, self.pool._local.con)
+        self.assertTrue(self.conn.commit_called)
+
+    def test_rollback_is_called_on_success(self):
+        try:
+            with self.pool:
+                raise ValueError("anything")
+        except ValueError:
+            pass
+        self.assertTrue(self.conn.rollback_called)
+
+    def test_cannot_reenter_connection_pool(self):
+        try:
+            with self.pool:
+                with self.pool:
+                    pass
+        except RuntimeError:
+            pass
+        else:
+            self.fail("Should not be allowed to re-enter pool")
+
+    def test_pool_locks_when_max_connections_reached(self):
+        pool = ConnectionPool('bogus-dsn', max_con=1)
+        def f():
+            with pool:
+                self.assertTrue(pool._sem.locked())
+
+        gevent.joinall([gevent.spawn(f)])
+
+
+if __name__ == "__main__":
+    unittest.main()