Commits

Mike Steder  committed 64ebbf0

so close

  • Participants
  • Parent commits 2337f7b

Comments (0)

Files changed (6)

File txmysql/enum.py

+def Enum(*names):
+   ##assert names, "Empty enums are not supported" # <- Don't like empty enums? Uncomment!
+
+   class EnumClass(object):
+      __slots__ = names
+      def __iter__(self):        return iter(constants)
+      def __len__(self):         return len(constants)
+      def __getitem__(self, i):  return constants[i]
+      def __repr__(self):        return 'Enum' + str(names)
+      def __str__(self):         return 'enum ' + str(constants)
+
+   class EnumValue(object):
+      __slots__ = ('__value')
+      def __init__(self, value): self.__value = value
+      Value = property(lambda self: self.__value)
+      EnumType = property(lambda self: EnumType)
+      def __hash__(self):        return hash(self.__value)
+      def __cmp__(self, other):
+         # C fans might want to remove the following assertion
+         # to make all enums comparable by ordinal value {;))
+         assert self.EnumType is other.EnumType, "Only values from the same enum are comparable"
+         return cmp(self.__value, other.__value)
+      def __invert__(self):      return constants[maximum - self.__value]
+      def __nonzero__(self):     return bool(self.__value)
+      def __repr__(self):        return str(names[self.__value])
+
+   maximum = len(names) - 1
+   constants = [None] * len(names)
+   for i, each in enumerate(names):
+      val = EnumValue(i)
+      setattr(EnumClass, each, val)
+      constants[i] = val
+   constants = tuple(constants)
+   EnumType = EnumClass()
+   return EnumType
+
+
+if __name__ == '__main__':
+   print '\n*** Enum Demo ***'
+   print '--- Days of week ---'
+   Days = Enum('Mo', 'Tu', 'We', 'Th', 'Fr', 'Sa', 'Su')
+   print Days
+   print Days.Mo
+   print Days.Fr
+   print Days.Mo < Days.Fr
+   print list(Days)
+   for each in Days:
+      print 'Day:', each
+   print '--- Yes/No ---'
+   Confirmation = Enum('No', 'Yes')
+   answer = Confirmation.No
+   print 'Your answer is not', ~answer

File txmysql/imysql.py

+"""
+Interfaces for TxMySQL
+
+"""
+
+from zope import component
+from zope.interface import Attribute, Interface
+
+registry = component.getGlobalSiteManager()
+
+
+class IPacket(Interface):
+    length = Attribute("24bit unsigned integer representing the"
+                       "length of this packet")
+    order = Attribute("8bit unsigned integer sequence number of this packet")
+
+    def fromBytes(bytes):
+        """This method returns an object implementing `IPacket`
+        """
+
+    def toBytes():
+        """This method returns a stream of bytes representing
+        this packet.
+        """
+
+
+class IHandshakeInitializationPacket(IPacket):
+    protocolVersion = Attribute("1 Byte version number of the protocol"
+                                "(e.g.: 10)")
+    serverVersion = Attribute("Null terminated string representing the server"
+                              "version.  For example: 5.1beta16")
+    threadId = Attribute("4 Byte ID of server thread for this connection")
+    scrambleBuffer = Attribute("8 byte scramble buffer, used by password"
+                               "authentication mechanism")
+    serverCapabilities = Attribute("2 byte integer of flags of capabilities"
+                                   "supported by this server.")
+    serverLanguage = Attribute("1 byte language code for this server")
+    serverStatus = Attribute("2 byte server status codes")
+    restOfScrambleBuffer = Attribute("13 byte remainder of scramble buffer")
+    
+
+class IAuthenticationPacket(IPacket):
+    clientFlags = Attribute("")
+    maxPacketSize = Attribute("")
+    charsetNumber = Attribute("")
+    user = Attribute("")
+    scrambleBuffer = Attribute("")
+    databaseName = Attribute("")

File txmysql/packet.py

 
 """
 
-
 import struct
 
+from zope import interface as zinterface
+
+from txmysql import imysql
+
+
 # packers
+packetHeaderLength = 4
+
 def pack_uint8(x):
     bytes = struct.pack("B", x)
     return bytes
 
 # Packet Class
 class Packet(object):
-    def __init__(self, bytes):
-        self.length, self.order = unpackHeader(bytes[0:4])
-        self.bytes = bytes[4:]
+    zinterface.implements(imysql.IPacket)
 
+    def __init__(self, length, order, remainingBytes=""):
+        self.length = length
+        self.order = order
+        self.bytes = remainingBytes
 
-handshakeInitializationFormat = "B%ssL8sxHBHx13s"
+    @classmethod
+    def fromBytes(cls, bytes):
+        length, order = unpackHeader(bytes[0:4])
+        remainingBytes = bytes[4:]
+        return Packet(length, order, remainingBytes)
+
+    def toBytes(self):
+        return packHeader(self.length, self.order)
+
+
+handshakeInitializationFormat = "=B%ssxL8sxHBHxxxxxxxxxxxxx13s"
 
 
 def packHandshakeInitialization(protocolVersion,
 
 
 def unpackHandshakeInitialization(bytes):
-    index = bytes.find("\0")
+    index = bytes.find("\x00")
+    print "bytes:", bytes
+    #import pdb; pdb.set_trace()
     values = struct.unpack(handshakeInitializationFormat % (index-1,), bytes)
     return values
 
 
-class HandshakeInitializationPacket(Packet):
-    def __init__(self, bytes):
-        super(HandshakeInitializationPacket, self).__init__(self)
+# class HandshakeInitializationPacket(Packet):
+#     def __init__(self, bytes):
+#         super(HandshakeInitializationPacket, self).__init__(bytes)
+#         handshakeValues = unpackHandshakeInitialization(self.bytes)
+#         self.serverVersion = handshakeValues[1]
+
+class HandshakeInitializationAdapter(Packet):
+    zinterface.implements(imysql.IHandshakeInitializationPacket)
+
+    __used_for__ = imysql.IPacket
+
+    def __init__(self, packet):
+        Packet.__init__(self, packet.length, packet.order, packet.bytes)
+        self.context = packet
         handshakeValues = unpackHandshakeInitialization(self.bytes)
+        self.protocolVersion = handshakeValues[0]
+        self.serverVersion = handshakeValues[1]
+        self.threadId = handshakeValues[2]
+        self.scrambleBuffer = handshakeValues[3]
+        self.serverCapabilities = handshakeValues[4]
+        self.serverLanguage = handshakeValues[5]
+        self.serverStatus = handshakeValues[6]
+        self.restOfScrambleBuffer = handshakeValues[7]
+
+imysql.registry.registerAdapter(HandshakeInitializationAdapter,
+                                (imysql.IPacket,),
+                                imysql.IHandshakeInitializationPacket)
+
+
+class AuthenticationAdapter(Packet):
+    zinterface.implements(imysql.IAuthenticationPacket)
+
+    __used_for__ = imysql.IHandshakeInitializationPacket
+
+    def __init__(self, packet):
+        Packet.__init__(self, packet.length, packet.order, packet.bytes)
+        self.context = packet
+
+    def toBytes(self):
+        pass
+
+imysql.registry.registerAdapter(AuthenticationAdapter,
+                                (imysql.IHandshakeInitializationPacket,),
+                                imysql.IAuthenticationPacket)

File txmysql/protocol.py

 """ protocol.py
 """
-import sys
+import sha
 try:
     import cStringIO as stringio
 except ImportError:
     import StringIO as stringio
-    
+import struct
+import sys    
 
 from twisted.internet import defer
 from twisted.internet import protocol
 from twisted.python import log
 
+from txmysql import enum
+from txmysql import imysql
+from txmysql import packet
+
 # CONSTANTS:
+# * CLIENT CAPABILITIES:
+CLIENT_LONG_PASSWORD = 1
+CLIENT_FOUND_ROWS = 1 << 1
+CLIENT_LONG_FLAG = 1 << 2
+CLIENT_CONNECT_WITH_DB = 1 << 3
+CLIENT_NO_SCHEMA = 1 << 4
+CLIENT_COMPRESS = 1 << 5
+CLIENT_ODBC = 1 << 6
+CLIENT_LOCAL_FILES = 1 << 7
+CLIENT_IGNORE_SPACE = 1 << 8
+CLIENT_PROTOCOL_41 = 1 << 9
+CLIENT_INTERACTIVE = 1 << 10
+CLIENT_SSL = 1 << 11
+CLIENT_IGNORE_SIGPIPE = 1 << 12
+CLIENT_TRANSACTIONS  = 1 << 13
+CLIENT_SECURE_CONNECTION = 1 << 15
+CLIENT_MULTI_STATEMENTS = 1 << 16
+CLIENT_MULTI_RESULTS = 1 << 17
+CLIENT_CAPABILITIES = (CLIENT_LONG_PASSWORD|CLIENT_LONG_FLAG|CLIENT_TRANSACTIONS| 
+                        CLIENT_PROTOCOL_41|CLIENT_SECURE_CONNECTION)
+
 # * SERVER CAPABILITIES:
 CLIENT_LONG_PASSWORD      = 1 # /* new more secure passwords */
 CLIENT_FOUND_ROWS         = 2 # /* Found instead of affected rows */
 CLIENT_MULTI_STATEMENTS   = 65536    # /* Enable/disable multi-stmt support */
 CLIENT_MULTI_RESULTS      = 131072   # /* Enable/disable multi-results */
 
-# PACKETS:
-packetHeaderLength = 4 # bytes
-
-def unpackPacketHeader(bytes):
-    return 0, 0
-
 # PROTOCOLS:
+protocolStates = enum.Enum("AWAITING_HANDSHAKE", "AUTHENTICATING")
 
 class MysqlProtocol(protocol.Protocol):
     def __init__(self):
         self.buffer = ""
+        self.state = protocolStates.AWAITING_HANDSHAKE
 
     def msg(self, *msg):
-        sys.stdout.write("%s"%(" ".join(msg),))
+        sys.stdout.write("%s\n"%(" ".join([str(x) for x in msg]),))
         sys.stdout.flush()
 
     def dataReceived(self, data):
         self.buffer += data
         self.msg(data)
-        
+
+        if self.state == protocolStates.AWAITING_HANDSHAKE:
+            packet = self.getHandshakeInitialization(self.buffer)
+            if packet is not None:
+                self.state = protocolStates.AUTHENTICATING
+                self.sendAuthentication(packet)
+        elif self.state == protocolStates.AUTHENTICATING:
+            self.msg("GOT RESPONSE AFTER AUTH:", str(len(data)), data)
+            field_count = struct.unpack("B", data[0])
+            if field_count == 0:
+                self.msg("OK:")
+            else:
+                field_count, errno, marker, sqlstate = struct.unpack("BHB5s", data[:10])
+                self.msg("ERROR:", field_count, errno, marker, sqlstate, data[11:])
+
+    def sendAuthentication(self, greetingPacket):
+        salt = (greetingPacket.scrambleBuffer +
+                greetingPacket.restOfScrambleBuffer)
+        self.factory.clientFlags |= CLIENT_CAPABILITIES
+        if greetingPacket.serverVersion.startswith('5'):
+            self.factory.clientFlags |= CLIENT_MULTI_RESULTS
+        data = (struct.pack('=i', self.factory.clientFlags) +
+                "\x00\x00\x00\x01" +  '\x08' + '\x00'*23 + 
+                self.factory.username + "\x00" +
+                self._scramble(self.factory.password, salt))
+        self.msg("authData:", "\"", data, "\"", str(len(data)))
+        if self.factory.database:
+            data += self.factory.database + "\x00"
+        self.msg("authData:", "\"", data, "\"", str(len(data)))
+        data = packet.pack_uint24(len(data)) + "\x01" + data
+        self.msg("authData:", "\"", data, "\"", str(len(data)))
+        self.transport.write(data)
+
+    def _scramble(self, password, message):
+        if password == None or len(password) == 0:
+            return '\0'
+        stage1 = sha.new(password).digest()
+        stage2 = sha.new(stage1).digest()
+        s = sha.new()
+        s.update(message)
+        s.update(stage2)
+        result = s.digest()
+        return self._my_crypt(result, stage1)
+
+    def _my_crypt(self, message1, message2):
+        length = len(message1)
+        result = struct.pack('B', length)
+        for i in xrange(length):
+            x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0])
+            result += struct.pack('B', x)
+        return result
+
+    def getHandshakeInitialization(self, buff):
+        p = None
+        if len(buff) > packet.packetHeaderLength:
+            length, order = packet.unpackHeader(buff[:4])
+            if len(buff[4:]) >= length:
+                # create the packet
+                p = imysql.IHandshakeInitializationPacket(
+                    packet.Packet.fromBytes(buff[:length+4]))
+
+                # remove packet bytes from buffer
+                self.buffer = buff[length+4:]
+        return p
+
     def connectionMade(self):
         self.msg("Connection made")
         #self.transport.write("CONNECT")
                  password,
                  database=None,
                  character_set="utf-8",
-                 flags=None,
+                 flags=CLIENT_CAPABILITIES,
                  max_packet_size=None):
         self.username = username
         self.password = password
         self.database = database
         self.characterSet = character_set
-        self.flags = flags
+        self.clientFlags = flags
         self.maxPacketSize = max_packet_size
         self.deferred = defer.Deferred()
 
     def msg(self, *msg):
-        sys.stdout.write("%s"%(" ".join(msg),))
+        sys.stdout.write("%s\n"%(" ".join(msg),))
         sys.stdout.flush()
 
     def getDeferred(self):
         self.deferred.errback(reason)
 
     def clientConnectionLost(self, connector, reason):
-        self.msg("connection log.")
+        self.msg("connection lost.")
         self.deferred.callback(reason)

File txmysql/test/test_packet.py

 from twisted.trial import unittest
 
 from txmysql import packet
+from txmysql import imysql
 
 
 class TestUint8(unittest.TestCase):
                                                    serverLanguage,
                                                    serverStatus,
                                                    restOfScramble)
-        self.assertEqual(bytes, "\x085.1b1\x00\x00\xe8\x03\x00\x00\x00\x00\x00\x00abcdefgh\x00\x00\x00\x00!\x00\x00\x00\x00abcdefghijklm")
+        self.assertEqual(bytes, "\x085.1b1\x00\xe8\x03\x00\x00abcdefgh\x00\x00\x00!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00abcdefghijklm")
 
     def test_unpack(self):
-        bytes = "\x085.1b1\x00\x00\xe8\x03\x00\x00\x00\x00\x00\x00abcdefgh\x00\x00\x00\x00!\x00\x00\x00\x00abcdefghijklm"
+        bytes = "\x085.1b1\x00\xe8\x03\x00\x00abcdefgh\x00\x00\x00!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00abcdefghijklm"
         values = packet.unpackHandshakeInitialization(bytes)
         self.assertEqual(len(values), 8)
         self.assertEqual(values, (8, "5.1b1",
         """See if we can parse a message that just consists of the packet header
         """
         bytes = "\x11\x00\x00\x01"
-        p = packet.Packet(bytes)
+        p = packet.Packet.fromBytes(bytes)
         self.assertEqual(p.length, 17)
         self.assertEqual(p.order, 1)
 
+    def test_handshakeInitializationPacket(self):
+        """See if we can parse a message that consists of the handshake initialization packet
+        """
+        handshakeMessage = "\x085.1b1\x00\xe8\x03\x00\x00abcdefgh\x00\x00\x00!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00abcdefghijklm"
+        bytes = packet.packHeader(len(handshakeMessage), 0)
+        bytes += handshakeMessage
+        p = imysql.IHandshakeInitializationPacket(packet.Packet.fromBytes(bytes))
+        self.assertEqual(p.serverVersion, "5.1b1")
+        self.assertEqual(p.scrambleBuffer, "")
+        self.assertEqual(p.restOfScrambleBuffer, "")
+        

File txmysql/test/test_protocol.py

 from twisted.internet import reactor
 from twisted.trial import unittest
 
+from txmysql import packet
 from txmysql import protocol as myprotocols
 
 
-
-
 class TestMysqlProtocolConnection(unittest.TestCase):
     def setUp(self):
         self.buff = StringIO.StringIO()
         self.p.connectionMade()
         self.assertEqual(self.buff.getvalue(), "")
 
-    # def test_handshakeInitializationPacket(self):
-    #     self.p.dataReceived()
-    #     self.assertEqual(len(self.p.packets), 1)
-    #     packet = self.p.packets[0]
-    #     self.assertEqual(packet, None)
+    def _receiveHandshakeInitializationPacket(self):
+        handshakeMessage = "\x085.1b1\x00\xe8\x03\x00\x00abcdefgh\x00\x00\x00!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00abcdefghijklm"
+        bytes = packet.packHeader(len(handshakeMessage), 0)
+        bytes += handshakeMessage
+        self.p.dataReceived(bytes)
+
+    def test_awaitingToAuthenticating(self):
+        self.assertEqual(self.p.state, myprotocols.protocolStates.AWAITING_HANDSHAKE)
+        self._receiveHandshakeInitializationPacket()
+        self.assertEqual(self.p.state, myprotocols.protocolStates.AUTHENTICATING)
+
+    def test_sendAuthenticationPacket(self):
+        self._receiveHandshakeInitializationPacket()
+        self.assertEqual(self.buff.getvalue(), ".\x00\x00\x00")