Commits

Mike Steder committed 04da8f5

Slowly putting together all the pieces of the connection pool

  • Participants
  • Parent commits 1c52405

Comments (0)

Files changed (3)

 
 # 3rd party
 from twisted.python import filepath
+from twisted.python import log
 from twisted.internet import reactor
 
 # 1st party
+from txmysql import adbapi
 from txmysql import protocol
 from txmysql import settings
 
 def shutdown(_):
     reactor.stop()
 
-def runQuery(proto):
-    d = proto.selectdb("test")
-    #proto.query("select * from counter")
-    #proto.quit()
+def run(result):
+    return pool.runQuery("select 1")
 
-factory = protocol.MysqlProtocolFactory(
-    settings.USERNAME,
-    settings.PASSWORD
-)
-d = factory.getDeferred()
-d.addErrback(shutdown)
-d.addCallback(runQuery)
+#factory = protocol.MysqlProtocolFactory(
+#    settings.USERNAME,
+#    settings.PASSWORD
+#)
+#d = factory.getDeferred()
+#d.addErrback(shutdown)
+#d.addCallback(runQuery)   
+#reactor.connectTCP(settings.HOSTNAME, settings.PORT, factory)
+#reactor.run()
 
-reactor.connectTCP(settings.HOSTNAME, settings.PORT, factory)
+pool = adbapi.ConnectionPool(settings.HOSTNAME, settings.PORT,
+                   settings.USERNAME, settings.PASSWORD,
+                   size=2)
+dl = pool.start()
+dl.addCallback(run)
+dl.addErrback(log.err)
+dl.addCallback(shutdown)
 reactor.run()
-

txmysql/adbapi.py

+"""
+Asynchronous DB-API compatible interface for txMysql
+
+"""
+import collections
+import Queue
+
+from twisted.internet import defer
+from twisted.internet import reactor
+from twisted.enterprise import adbapi
+from twisted.python import log
+from zope import interface as zinterface
+
+from txmysql import imysql
+from txmysql import protocol
+
+ConnectionLost = adbapi.ConnectionLost
+
+
+class Connection(object):
+    def __init__(self, proto):
+        self.proto = proto
+
+    def isAvailable(self):
+        return (self.proto.state == self.proto.AVAILABLE)
+
+    def _cbRun(self, results, query):
+        query.deferred.callback(results)
+
+    def _ebRun(self, failure, query):
+        query.deferred.errback(failure)
+
+    def run(self, query):
+        #import pdb; pdb.set_trace()
+        # maybe escape the sql? ;-)
+        d = self.proto.query(query.sql)
+        d.addCallback(self._cbRun, query)
+        d.addErrback(self._ebRun, query)
+        return d
+
+
+class Pool(object):
+    """
+    Pool is responsible for maintaining a pool of connections
+
+    Attributes:
+     - started:
+        * process failures during startup cause us to give up and shutdown,
+        * process failrues after startup cause us to simply restart
+          those processes
+     - shuttingDown:
+        * is set only when shutdown is called and then is true
+          until the service exits.
+        * disables process restarting during shutdown
+     - connections: dict of information describing all processes
+       key is process name.
+
+    Args:
+     - size
+
+    Example:
+    >>> Pool(1)
+    """
+    connectionFactory = protocol.MysqlProtocolFactory
+
+    system = "pool" # for debugging
+
+    def __init__(self, host, port,
+                 username, password,
+                 database=None,
+                 size=1):
+        self.started = False
+        self.shuttingDown = False
+
+        # operating information
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+        self.size = size
+        self.connections = []
+        self.restartCallbacks = []
+        
+    def _cbConnectionStarted(self, proto):
+        c = Connection(proto)
+        self.connections.append(c)
+
+    def start(self):
+        """Start up all the pools connections
+        """
+        dl = []
+        for i in xrange(self.size):
+            f = self.connectionFactory(self.username,
+                                       self.password)
+            d = f.getDeferred()
+            d.addCallback(self._cbConnectionStarted)
+            dl.append(d)
+            reactor.connectTCP(self.host, self.port, f)
+        dl = defer.DeferredList(dl, fireOnOneErrback=True)
+        dl.addCallback(
+            self._cbStart
+        )
+        dl.addErrback(
+            self._ebStart
+        )
+        return dl
+
+    def _cbStart(self, _):
+        log.msg("connection pool ready!", system=self.system)
+        self.started = True
+
+    def _ebStart(self, failure):
+        log.err("failed to start connection pool...", system=self.system)
+        return failure
+
+    def anyFree(self):
+        for conn in self.connections:
+            if conn.isAvailable():
+                return True
+        return False
+    
+    def process(self, query):
+        for connection in self.connections:
+            if connection.isAvailable():
+                connection.run(query)
+                pd = defer.Deferred()
+                pd.addCallback(self._cbProcess)
+                pd.addErrback(self._ebProcess)
+                return pd
+
+    def _cbProcess(self, success):
+        return success
+
+    def _ebProcess(self, failure):
+        return failure
+
+
+class Query(object):
+    def __init__(self, sql):
+        self.sql = sql
+        self.deferred = defer.Deferred()
+
+
+class Dispatcher(object):
+    """Handles interaction with db by transparently queuing
+    requests and dispatching as connections become available
+    in the pool.
+    """
+    poolFactory = Pool
+    queueFactory = Queue.Queue
+
+    def __init__(self, host, port,
+                 username, password,
+                 database=None,
+                 size=1):
+        self.pool = self.poolFactory(host, port,
+                                     username, password,
+                                     database=database,
+                                     size=size)
+        self.queue = self.queueFactory()
+
+    def start(self):
+        return self.pool.start()
+
+    def runQuery(self, sql):
+        q = Query(sql)
+        self.accept(q)
+
+    def accept(self, query):
+        d = query.deferred
+        self.queue.put_nowait(query)
+        if self.pool.anyFree():
+            self.dispatch()
+        return d
+
+    def drain(self):
+        while self.pool.anyFree() and (self.queue.qsize() > 0):
+            self.dispatch()
+
+    def _cbDispatch(self, result, job):
+         self.drain()
+
+    def _ebDispatch(self, failure, job):
+        failure.trap(Exception)
+        log.err("Dispatched job failed with: %s"%(failure), system="dispatcher")
+        self.drain()
+
+    def dispatch(self):
+        query = self.queue.get_nowait()
+        poolDeferred = self.pool.process(query)
+        poolDeferred.addCallback(self._cbDispatch, query)
+        poolDeferred.addErrback(self._ebDispatch, query)
+
+
+        
+
+
+# use the db-api name
+ConnectionPool = Dispatcher
+    

txmysql/protocol.py

 COM_QUERY = 0x03 #(mysql_real_query)
 
 # PROTOCOLS:
-protocolStates = enum.Enum("AWAITING_HANDSHAKE", "AUTHENTICATING", "CONNECTED")
+
 
 def is_ascii(data):
     if data.isalnum():
 
 
 class MysqlProtocol(protocol.Protocol):
+    CONNECTING = "CONNECTING"
+    AUTHENTICATING = "AUTHENTICATING"
+    AVAILABLE = "AVAILABLE"
+    RUNNING = "RUNNING"
+
     def __init__(self):
         self.buffer = ""
-        self.state = protocolStates.AWAITING_HANDSHAKE
+        self.state = self.CONNECTING
 
     def msg(self, *msg):
         sys.stdout.write("%s\n"%(" ".join([str(x) for x in msg]),))
     def dataReceived(self, data):
         self.buffer += data
 
-        if self.state == protocolStates.AWAITING_HANDSHAKE:
+        if self.state == self.CONNECTING:
             pac = self.getHandshakeInitialization(self.buffer)
             if pac is not None:
-                self.state = protocolStates.AUTHENTICATING
+                self.state = self.AUTHENTICATING
                 self.sendAuthentication(pac)
-        elif self.state == protocolStates.AUTHENTICATING:
+        elif self.state == self.AUTHENTICATING:
             p = self.getPacket(self.buffer)
             if p:
                 field_count = struct.unpack("B", p.bytes[0])[0]
                 print "FIELD_COUNT:", field_count
                 if field_count == 0:
                     print "OK!"
-                    self.state = protocolStates.CONNECTED
-                    # fire some callback to let users know the protocol is connected
+                    self.state = self.AVAILABLE
+                    # fire callback to let others know this protocol
+                    # is connected
                     self.factory.deferred.callback(self)
                 else:
                     print "ERROR!"
-        elif self.state == protocolStates.CONNECTED:
+        elif self.state == self.AVAILABLE:
             p = self.getPacket(self.buffer)
-        elif self.state == protocolStates.RUNNING_COMMAND:
+        elif self.state == self.RUNNING:
             p = self.getPacket(self.buffer)
             if p:
-                d = self.deferreds.popleft()
-                self.state = protocolStates.CONNECTED
-                d.callback(self, p)
+                self.state = self.AVAILABLE
+                self.deferred.callback(p)
 
     def quit(self):
         quit_packet = struct.pack("B4s", COM_QUIT, "quit")
         self.write(select_db)
 
     def query(self, sql):
-        select_packet = struct.pack("<i", len(sql)+1) + struct.pack("B", COM_QUERY) + sql
+        select_packet = (struct.pack("<i", len(sql)+1) +
+                         struct.pack("B", COM_QUERY) + sql)
         self.write(select_packet)
+        self.deferred = defer.Deferred()
+        return self.deferred
         
     def sendAuthentication(self, greetingPacket):
         salt = (greetingPacket.scrambleBuffer +