Commits

Mike Steder committed 2337f7b

Adding earliest packet parsing code and tests.
Added simple script to test connecting to MySQL

Comments (0)

Files changed (7)

+#!/usr/bin/env python
+# -*- mode: python -*-
+"""tmysqlc - twisted mysql client
+
+"""
+
+# standard lib
+import sys
+
+# 3rd party
+from twisted.python import filepath
+from twisted.internet import reactor
+
+# 1st party
+from txmysql import protocol
+from txmysql import settings
+
+
+def shutdown(_):
+    reactor.stop()
+
+
+factory = protocol.MysqlProtocolFactory(
+    settings.USERNAME,
+    settings.PASSWORD
+)
+d = factory.getDeferred()
+d.addBoth(shutdown)
+
+
+reactor.connectTCP(settings.HOSTNAME, settings.PORT, factory)
+reactor.run()
+

txmysql/packet.py

+"""
+Includes format strings that describe packets
+and classes to read/write packets.
+
+"""
+
+
+import struct
+
+# packers
+def pack_uint8(x):
+    bytes = struct.pack("B", x)
+    return bytes
+
+def pack_uint24(x):
+    bytes = struct.pack("BBB", x&0xFF, (x>>8)&0xFF, (x>>16)&0xFF)
+    return bytes
+
+def packHeader(length, order):
+    return pack_uint24(length) + pack_uint8(order)
+
+# unpackers
+def unpack_uint8(byte):
+    return struct.unpack("B", byte)[0]
+
+def unpack_uint24(bytes):
+    return ((struct.unpack("B", bytes[0])[0]) +
+            (struct.unpack("B", bytes[1])[0]<<8) +
+            (struct.unpack("B", bytes[2])[0]<<16))
+
+def unpackHeader(bytes):
+    return (unpack_uint24(bytes[0:3]),
+            unpack_uint8(bytes[3]))
+
+
+# Packet Class
+class Packet(object):
+    def __init__(self, bytes):
+        self.length, self.order = unpackHeader(bytes[0:4])
+        self.bytes = bytes[4:]
+
+
+handshakeInitializationFormat = "B%ssL8sxHBHx13s"
+
+
+def packHandshakeInitialization(protocolVersion,
+                                serverVersion,
+                                threadId,
+                                scrambleBuffer,
+                                serverCapabilities,
+                                serverLanguage,
+                                serverStatus,
+                                restOfScramble):
+    bytes = ""
+    bytes += struct.pack(handshakeInitializationFormat % (
+            len(serverVersion),
+            ),
+                         protocolVersion,
+                         serverVersion,
+                         threadId,
+                         scrambleBuffer,
+                         serverCapabilities,
+                         serverLanguage,
+                         serverStatus,
+                         restOfScramble)
+    return bytes
+
+
+def unpackHandshakeInitialization(bytes):
+    index = bytes.find("\0")
+    values = struct.unpack(handshakeInitializationFormat % (index-1,), bytes)
+    return values
+
+
+class HandshakeInitializationPacket(Packet):
+    def __init__(self, bytes):
+        super(HandshakeInitializationPacket, self).__init__(self)
+        handshakeValues = unpackHandshakeInitialization(self.bytes)

txmysql/protocol.py

 """ protocol.py
+"""
+import sys
+try:
+    import cStringIO as stringio
+except ImportError:
+    import StringIO as stringio
+    
 
-------------------------------------------------------------------
-The MySQL Protocol consists of the following structure 'atoms':
-------------------------------------------------------------------
-==================================================================
-Elements (Data values sent and received)
-==================================================================
+from twisted.internet import defer
+from twisted.internet import protocol
+from twisted.python import log
 
-+++++++++++++++++++++++++++++++++++++++++++
-Null-Terminated String: "\0" 0x00 
-+++++++++++++++++++++++++++++++++++++++++++
-
-++++++++++++++++++++++++++++++
-Length Coded Binary
-++++++++++++++++++++++++++++++
-
-Value of      # Of Bytse   Description
-First Byte    Following    
------------   ----------   -----------
-0-250         0            = value of first byte
-251           0            column value = NULL (only appropriate in
-                                                column data packet)
-252           2            = value of following 16-bit word
-253           3            = value of following 24-bit word
-254           8            = value of following 64-bit word
-
-All numbers are stored with least significant bit first and are unsigned.
-
-+++++++++++++++++++++++++++++++++++++++++
-Length Coded String
-+++++++++++++++++++++++++++++++++++++++++
-
-A length coded string is sent as a length coded binary followed by
-the data for that string.
-
-=====================================================
-The Packet Header
-=====================================================
-
-Bytes      Name
------      ----
- 3         Packet Length
- 1         Packet Number
-
- Where packet length is simply the length in bytes.
-
- Maxium packet length is 16MB. (2**(3(bytes)*8(bits)))
-
- Where packet number is an ordering for the packets.  Each query
- will start with packet 0 and increment until the result set
- is received.
-
-Every packet will have one of these first.
-
------------------------------------------------------------------------
-A typical session:
------------------------------------------------------------------------
-
- 1. The Handshake (Client connects):
-   A. Server to Client: Sends Handshake Initialization Packet
-   B. Client to Server: Client Authentication Packet
-   C. Server to Client: OK Packet or Error Packet
- 2. Commands (every action the client wants the server to do):
-   A. Client Sends to Server: Command Packet
-   B. Server Sends to Client: OK Packet, Error Packet, or Result Set Packet
-
-==================================================
-Handshake Initialization Packet
-==================================================
-Bytes                        Name
- -----                        ----
- 1                            protocol_version
- n (Null-Terminated String)   server_version
- 4                            thread_id
- 8                            scramble_buff
- 1                            (filler) always 0x00
- 2                            server_capabilities
- 1                            server_language
- 2                            server_status
- 13                           (filler) always 0x00 ...
- 13                           rest of scramble_buff (4.1)
-
- protocol_version:    The server takes this from PROTOCOL_VERSION
- in /include/mysql_version.h. Example value = 10.
-
- server_version:      The server takes this from MYSQL_SERVER_VERSION
- in /include/mysql_version.h. Example value = "4.1.1-alpha".
-
- thread_number:       ID of the server thread for this connection.
- 
- scramble_buff:       The password mechanism uses this. The second part are the
- last 13 bytes.
- (See "Password functions" section elsewhere in this document.)
- 
- server_capabilities: CLIENT_XXX options. The possible flag values at time of
- writing (taken from  include/mysql_com.h):
-  (SEE SERVER CAPABILITIES BELOW) 
- server_language:     current server character set number
-
- server_status:       SERVER_STATUS_xxx flags: e.g. SERVER_STATUS_AUTOCOMMIT
-"""
-# SERVER CAPABILITIES:
+# CONSTANTS:
+# * SERVER CAPABILITIES:
 CLIENT_LONG_PASSWORD      = 1 # /* new more secure passwords */
 CLIENT_FOUND_ROWS         = 2 # /* Found instead of affected rows */
 CLIENT_LONG_FLAG          = 4 # /* Get all column flags */
 CLIENT_MULTI_STATEMENTS   = 65536    # /* Enable/disable multi-stmt support */
 CLIENT_MULTI_RESULTS      = 131072   # /* Enable/disable multi-results */
 
-"""
-============================================================
-Client Authentication Packet
-============================================================
+# PACKETS:
+packetHeaderLength = 4 # bytes
 
-From client to server during initial handshake
+def unpackPacketHeader(bytes):
+    return 0, 0
 
-Version 4.1
- Bytes                       Name
- -----                       -----
- 4                           client_flags
- 4                           max_packet_size
- 1                           charset_number
- 23                          (filler) always 0x00
- n (Null-terminated string)  user
- n (length coded binary)     scramble_buff (1+x bytes)
- n (Null-terminated string)  databasename (optional)
+# PROTOCOLS:
 
- client_flags:            CLIENT_xxx options. The list of possible flag
- values is in the description of the Handshake
- Initialisation Packet, for server_capabilities.
- For some of the bits, the server passed "what
- it's capable of". The client leaves some of the
- bits on, adds others, and passes back to the server.
- One important flag is: whether compression is desired.
- Another interesting one is CLIENT_CONNECT_WITH_DB,
- which shows the presence of the optional databasename.
- 
- max_packet_size:         the maximum number of bytes in a packet for the client
- 
- charset_number:          in the same domain as the server_language field that
- the server passes in the Handshake Initialization packet.
- 
- user:                    identification
- 
- scramble_buff:           the password, after encrypting using the scramble_buff
- contents passed by the server (see "Password functions"
- section elsewhere in this document)
- if length is zero, no password was given
- 
- databasename:            name of schema to use initially
+class MysqlProtocol(protocol.Protocol):
+    def __init__(self):
+        self.buffer = ""
 
-================================================================
-OK Packet
-================================================================
+    def msg(self, *msg):
+        sys.stdout.write("%s"%(" ".join(msg),))
+        sys.stdout.flush()
 
-Version 4.1+
-
- Bytes                     Name
- -----                     ----
- 1 (Length Coded Binary)   field_count, always = 0
- 1-9 (Length Coded Binary) affected_rows
- 1-9 (Length Coded Binary) insert_id
- 2                         server_status
- 2                         warning_count
- n (until end of packet)   message
-
- field_count:     always = 0
-
- affected_rows:   = number of rows affected by INSERT/UPDATE/DELETE
-
- insert_id:       If the statement generated any AUTO_INCREMENT number,
- it is returned here. Otherwise this field contains 0.
- Note: when using for example a multiple row INSERT the
- insert_id will be from the first row inserted, not from
- last.
- 
- server_status:   = The client can use this to check if the
- command was inside a transaction.
- 
- warning_count:   number of warnings
- 
- message:         For example, after a multi-line INSERT, message might be
- "Records: 3 Duplicates: 0 Warnings: 0"
- 
- The message field is optional.
- Alternative terms: OK Packet is also known as "okay packet" or "ok packet" or "OK-Packet". field_count is also known as "number of rows" or "marker for ok packet". message is also known as "Messagetext". OK Packets (and result set packets) are also called "Result packets".
-
-================================================================
-Error Packet
-================================================================
-"""
-
-
-class MysqlProxyProtocol(object):
-    """For proxying traffic to a real mysql server.
-
-    Allows transforms and modifications of result sets
-    """
-
-
-class MysqlProxyFactory(object):
-    protocol = MysqlProxyProtocol
-
-    def __init__(self, hostname, port):
-        self.hostname = hostname
-        self.port = port
-
-
-class MysqlProtocol(object):
-    def __init__(self):
-        self.connected = False
-
-    def connect(self):
-        pass
-
-    def _cbConnected(self, result):
-        pass
-
-    def _ebConnectionFailed(self, error):
-        pass
+    def dataReceived(self, data):
+        self.buffer += data
+        self.msg(data)
+        
+    def connectionMade(self):
+        self.msg("Connection made")
+        #self.transport.write("CONNECT")
 
     
-class MysqlProtocolFactory(object):
+class MysqlProtocolFactory(protocol.ClientFactory):
     protocol = MysqlProtocol
 
     def __init__(self, username,
         self.characterSet = character_set
         self.flags = flags
         self.maxPacketSize = max_packet_size
+        self.deferred = defer.Deferred()
+
+    def msg(self, *msg):
+        sys.stdout.write("%s"%(" ".join(msg),))
+        sys.stdout.flush()
+
+    def getDeferred(self):
+        return self.deferred
+        
+    def startedConnecting(self, connector):
+        self.msg("started to connect...")
+
+    def clientConnectionFailed(self, connector, reason):
+        self.msg("connection failed...")
+        self.deferred.errback(reason)
+
+    def clientConnectionLost(self, connector, reason):
+        self.msg("connection log.")
+        self.deferred.callback(reason)

txmysql/settings.py

 TXMYSQL_ROOT = filepath.FilePath(__file__).parent().parent()
     
 HOSTNAME = "localhost"
-PORT = 7777
+PORT = 3306
 
 # BENCHMARKS Settings:
-USERNAME = "test"
+USERNAME = "root"
+PASSWORD = ""
 DATABASE = "test"
 ITERATIONS = 1000
 USE_UNICODE = True

txmysql/test/test_factory.py

+""" tests for the mysql protocol factory
+
+"""
+
+from twisted.internet import protocol
+from twisted.internet import reactor
+from twisted.trial import unittest
+
+from txmysql import protocol as myprotocols
+
+
+class MysqlTestProtocol(protocol.Protocol):
+    def connectionMade(self):
+        self.transport.loseConnection()
+
+
+class MysqlTestFactory(protocol.ServerFactory):
+    protocol = MysqlTestProtocol
+
+
+class TestMysqlProtocolConnection(unittest.TestCase):
+    def setUp(self):
+        self.f = myprotocols.MysqlProtocolFactory("test",
+                                                  "test")
+
+    def _ebConnectionShouldFail(self, _):
+        self.assertTrue(True)
+
+    def _cbConnectionShouldFail(self, _):
+        self.fail()
+
+    def test_connectionFailure(self):
+        d = self.f.getDeferred()
+        d.addCallback(self._cbConnectionShouldFail)
+        d.addErrback(self._ebConnectionShouldFail)
+
+        sf = MysqlTestFactory()
+        port = reactor.listenTCP(0, sf)
+        self.addCleanup(port.stopListening)
+        PORT = port.getHost().port
+        reactor.connectTCP("127.0.0.1", PORT, self.f)
+        return d
+        

txmysql/test/test_packet.py

+"""tests for the mysql packets stuff
+
+"""
+import struct
+
+from twisted.trial import unittest
+
+from txmysql import packet
+
+
+class TestUint8(unittest.TestCase):
+    def test_pack(self):
+        n = 10
+        bytes = packet.pack_uint8(n)
+        self.assertEqual(bytes, "\x0a")
+
+    def test_unpack(self):
+        bytes = "\x0a"
+        n = packet.unpack_uint8(bytes)
+        self.assertEqual(n, 10)
+
+    def test_overflow(self):
+        n = 256
+        self.assertRaises(struct.error, packet.pack_uint8, n)
+
+    def test_negative(self):
+        n = -10
+        self.assertRaises(struct.error, packet.pack_uint8, n)
+
+
+class TestPackUint24(unittest.TestCase):
+    def test_pack(self):
+        n = 0xFFFFFF - 1 # 16777214 (biggest signed number in 24bits)
+        bytes = packet.pack_uint24(n)
+        self.assertEqual(bytes, "\xfe\xff\xff")
+
+    def test_unpack(self):
+        bytes = "\xfe\xff\xff"
+        r = packet.unpack_uint24(bytes)
+        self.assertEqual(r, 16777214)
+
+
+class TestPacketHeader(unittest.TestCase):
+    def test_pack(self):
+        """Convert (length, packetOrder) to
+        
+        BBBB
+
+        to packet header
+        """
+        packed = packet.packHeader(12, 0)
+        self.assertEqual(packed, "\x0c\x00\x00\x00")
+
+    def test_unpack(self):
+        bytes = "\x0c\x00\x00\x00"
+        length, order = packet.unpackHeader(bytes)
+        self.assertEqual(length, 12)
+        self.assertEqual(order, 0)
+                         
+
+class TestHandshakeInitialization(unittest.TestCase):
+    def test_pack(self):
+        protocolVersion = 8
+        serverVersion = "5.1b1"
+        threadId = 1000
+        scrambleBuffer = "abcdefgh"
+        serverCapabilities = 0x0000
+        serverLanguage = 33 # language code 33 is unicode
+        serverStatus = 0x0000
+        restOfScramble = "abcdefghijklm"
+        bytes = packet.packHandshakeInitialization(protocolVersion,
+                                                   serverVersion,
+                                                   threadId,
+                                                   scrambleBuffer,
+                                                   serverCapabilities,
+                                                   serverLanguage,
+                                                   serverStatus,
+                                                   restOfScramble)
+        self.assertEqual(bytes, "\x085.1b1\x00\x00\xe8\x03\x00\x00\x00\x00\x00\x00abcdefgh\x00\x00\x00\x00!\x00\x00\x00\x00abcdefghijklm")
+
+    def test_unpack(self):
+        bytes = "\x085.1b1\x00\x00\xe8\x03\x00\x00\x00\x00\x00\x00abcdefgh\x00\x00\x00\x00!\x00\x00\x00\x00abcdefghijklm"
+        values = packet.unpackHandshakeInitialization(bytes)
+        self.assertEqual(len(values), 8)
+        self.assertEqual(values, (8, "5.1b1",
+                                  1000, "abcdefgh",
+                                  0x0000, 33,
+                                  0x0000, "abcdefghijklm"))
+                                  
+
+class TestUnpackingPacket(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_packetHeader(self):
+        """See if we can parse a message that just consists of the packet header
+        """
+        bytes = "\x11\x00\x00\x01"
+        p = packet.Packet(bytes)
+        self.assertEqual(p.length, 17)
+        self.assertEqual(p.order, 1)
+

txmysql/test/test_protocol.py

+""" tests for the mysql protocol factory
+
+"""
+import StringIO
+
+from twisted.internet import protocol
+from twisted.internet import reactor
+from twisted.trial import unittest
+
+from txmysql import protocol as myprotocols
+
+
+
+
+class TestMysqlProtocolConnection(unittest.TestCase):
+    def setUp(self):
+        self.buff = StringIO.StringIO()
+        transport = protocol.FileWrapper(self.buff)
+        f = myprotocols.MysqlProtocolFactory("test",
+                                                  "test")
+        self.p = f.buildProtocol(None)
+        self.p.transport = transport
+        
+    def test_buffering(self):
+        self.p.dataReceived("Hello World")
+        self.assertEqual(self.p.buffer, "Hello World")
+
+    def test_authenticationMessage(self):
+        self.p.connectionMade()
+        self.assertEqual(self.buff.getvalue(), "")
+
+    # def test_handshakeInitializationPacket(self):
+    #     self.p.dataReceived()
+    #     self.assertEqual(len(self.p.packets), 1)
+    #     packet = self.p.packets[0]
+    #     self.assertEqual(packet, None)