# Source
#
# txMysql / txmysql / protocol.py

""" protocol.py
"""
import collections
import hashlib
import sha
try:
    import cStringIO as stringio
except ImportError:
    import StringIO as stringio
import struct
import sys

from twisted.internet import defer
from twisted.internet import protocol
from twisted.python import log

from txmysql import enum
from txmysql import imysql
from txmysql import packet

# CONSTANTS:
# * CLIENT CAPABILITIES:
# Capability bit flags; each name owns exactly one bit.
CLIENT_LONG_PASSWORD = 0x00001
CLIENT_FOUND_ROWS = 0x00002
CLIENT_LONG_FLAG = 0x00004
CLIENT_CONNECT_WITH_DB = 0x00008
CLIENT_NO_SCHEMA = 0x00010
CLIENT_COMPRESS = 0x00020
CLIENT_ODBC = 0x00040
CLIENT_LOCAL_FILES = 0x00080
CLIENT_IGNORE_SPACE = 0x00100
CLIENT_PROTOCOL_41 = 0x00200
CLIENT_INTERACTIVE = 0x00400
CLIENT_SSL = 0x00800
CLIENT_IGNORE_SIGPIPE = 0x01000
CLIENT_TRANSACTIONS = 0x02000
CLIENT_SECURE_CONNECTION = 0x08000
CLIENT_MULTI_STATEMENTS = 0x10000
CLIENT_MULTI_RESULTS = 0x20000
# Default capability set: used as MysqlProtocolFactory's `flags` default and
# ORed into the flags sent during authentication.
CLIENT_CAPABILITIES = (CLIENT_LONG_PASSWORD
                       | CLIENT_LONG_FLAG
                       | CLIENT_TRANSACTIONS
                       | CLIENT_PROTOCOL_41
                       | CLIENT_SECURE_CONNECTION)

# * SERVER CAPABILITIES:
# These re-bind the capability names defined above to the same values;
# CLIENT_RESERVED is the only new name.
CLIENT_LONG_PASSWORD      = 1 # /* new more secure passwords */
CLIENT_FOUND_ROWS         = 2 # /* Found instead of affected rows */
CLIENT_LONG_FLAG          = 4 # /* Get all column flags */
CLIENT_CONNECT_WITH_DB    = 8 # /* One can specify db on connect */
CLIENT_NO_SCHEMA          = 16 # /* Don't allow database.table.column */
CLIENT_COMPRESS           = 32 # /* Can use compression protocol */
CLIENT_ODBC               = 64 # /* Odbc client */
CLIENT_LOCAL_FILES        = 128 # /* Can use LOAD DATA LOCAL */
CLIENT_IGNORE_SPACE       = 256 # /* Ignore spaces before '(' */
CLIENT_PROTOCOL_41        = 512 # /* New 4.1 protocol */
CLIENT_INTERACTIVE        = 1024 # /* This is an interactive client */
CLIENT_SSL                = 2048 # /* Switch to SSL after handshake */
CLIENT_IGNORE_SIGPIPE     = 4096     # /* IGNORE sigpipes */
CLIENT_TRANSACTIONS       = 8192 # /* Client knows about transactions */
CLIENT_RESERVED           = 16384    # /* Old flag for 4.1 protocol  */
CLIENT_SECURE_CONNECTION  = 32768   # /* New 4.1 authentication */
CLIENT_MULTI_STATEMENTS   = 65536    # /* Enable/disable multi-stmt support */
CLIENT_MULTI_RESULTS      = 131072   # /* Enable/disable multi-results */
# Misspelled alias kept so existing references still resolve. It previously
# held the corrupted value 41512 instead of the 4.1-protocol bit (512).
CLIENT_PROTOCOL_          = CLIENT_PROTOCOL_41

# SERVER_STATUS
# Server status bit flags (bit 2, value 4, has no name here).
SERVER_STATUS_IN_TRANS = 1 << 0
SERVER_STATUS_AUTOCOMMIT = 1 << 1
SERVER_MORE_RESULTS_EXISTS = 1 << 3
SERVER_QUERY_NO_GOOD_INDEX_USED = 1 << 4
SERVER_QUERY_NO_INDEX_USED = 1 << 5
SERVER_STATUS_CURSOR_EXISTS = 1 << 6
SERVER_STATUS_LAST_ROW_SENT = 1 << 7
SERVER_STATUS_DB_DROPPED = 1 << 8
SERVER_STATUS_NO_BACKSLASH_ESCAPES = 1 << 9
SERVER_STATUS_METADATA_CHANGED = 1 << 10

# COMMAND CODES:
# First payload byte of a client command packet; the trailing comment names
# the corresponding libmysqlclient call.
COM_SLEEP = 0
COM_QUIT = 1  # (mysql_close)
COM_INIT_DB = 2  # (mysql_select_db)
COM_QUERY = 3  # (mysql_real_query)

# PROTOCOLS:
# Connection states MysqlProtocol moves through. RUNNING_COMMAND is compared
# against in MysqlProtocol's packet handling, so it must be defined here for
# that comparison to resolve. It is appended last so the existing members
# keep their positions.
protocolStates = enum.Enum("AWAITING_HANDSHAKE", "AUTHENTICATING",
                           "CONNECTED", "RUNNING_COMMAND")

def is_ascii(data):
    """Return ``data`` itself if it is alphanumeric, otherwise ``'.'``.

    dump_packet uses this to render the character column of its hex dump.
    """
    return data if data.isalnum() else '.'


def dump_packet(data):
    """Write a debugging hex dump of ``data`` to stdout.

    The header shows the byte count and the names of up to five calling
    functions. Each body row shows 16 bytes as hex, then the same bytes as
    characters, with non-alphanumerics rendered as '.' by is_ascii.
    """
    out = sys.stdout
    out.write("packet length %d\n" % len(data))
    for depth in range(1, 6):
        try:
            caller = sys._getframe(depth).f_code.co_name
        except ValueError:
            # The call stack is shallower than `depth` (e.g. when called
            # from a script's top level); stop listing callers instead of
            # crashing.
            break
        out.write("method call[%d]: %s\n" % (depth, caller))
    out.write("-" * 88 + "\n")
    for start in range(0, len(data), 16):
        row = data[start:start + 16]
        hex_part = ' '.join(["%02X" % ord(c) for c in row])
        char_part = ' '.join([is_ascii(c) for c in row])
        # Pad short final rows so the character column stays aligned.
        out.write(hex_part + '   ' * (16 - len(row)) + ' ' * 2 +
                  char_part + "\n")
    out.write("-" * 88 + "\n")
    out.write("\n")


class MysqlProtocol(protocol.Protocol):
    def __init__(self):
        self.buffer = ""
        self.state = protocolStates.AWAITING_HANDSHAKE

    def msg(self, *msg):
        sys.stdout.write("%s\n"%(" ".join([str(x) for x in msg]),))
        sys.stdout.flush()

    def dataReceived(self, data):
        self.buffer += data

        if self.state == protocolStates.AWAITING_HANDSHAKE:
            pac = self.getHandshakeInitialization(self.buffer)
            if pac is not None:
                self.state = protocolStates.AUTHENTICATING
                self.sendAuthentication(pac)
        elif self.state == protocolStates.AUTHENTICATING:
            p = self.getPacket(self.buffer)
            if p:
                field_count = struct.unpack("B", p.bytes[0])[0]
                print "FIELD_COUNT:", field_count
                if field_count == 0:
                    print "OK!"
                    self.state = protocolStates.CONNECTED
                    # fire some callback to let users know the protocol is connected
                    self.factory.deferred.callback(self)
                else:
                    print "ERROR!"
        elif self.state == protocolStates.CONNECTED:
            p = self.getPacket(self.buffer)
        elif self.state == protocolStates.RUNNING_COMMAND:
            p = self.getPacket(self.buffer)
            if p:
                d = self.deferreds.popleft()
                self.state = protocolStates.CONNECTED
                d.callback(self, p)

    def quit(self):
        quit_packet = struct.pack("B4s", COM_QUIT, "quit")
        self.write(quit_packet)

    def selectdb(self, dbname):
        select_db = struct.pack("<i", len(dbname)+1) + struct.pack("B", COM_INIT_DB) + dbname
        self.write(select_db)

    def query(self, sql):
        select_packet = struct.pack("<i", len(sql)+1) + struct.pack("B", COM_QUERY) + sql
        self.write(select_packet)
        
    def sendAuthentication(self, greetingPacket):
        salt = (greetingPacket.scrambleBuffer +
                greetingPacket.restOfScrambleBuffer)
        self.factory.clientFlags |= CLIENT_CAPABILITIES
        if greetingPacket.serverVersion.startswith('5'):
            self.factory.clientFlags |= CLIENT_MULTI_RESULTS
        data = (struct.pack('=i', self.factory.clientFlags) +
                "\x00\x00\x00\x01" +  '\x21' + '\x00'*23 + 
                self.factory.username + "\x00" +
                self._scramble(self.factory.password, salt))
        if self.factory.database:
            data += self.factory.database + "\x00"
        data = packet.pack_uint24(len(data)) + "\x01" + data
        self.write(data)

    def write(self, p):
        dump_packet(p)
        self.transport.write(p)

    def _scramble(self, password, message):
        if password == None or len(password) == 0:
            return '\0'
        stage1 = sha.new(password).digest()
        stage2 = sha.new(stage1).digest()
        s = sha.new()
        s.update(message)
        s.update(stage2)
        result = s.digest()
        return self._my_crypt(result, stage1)

    def _my_crypt(self, message1, message2):
        length = len(message1)
        result = struct.pack('B', length)
        for i in xrange(length):
            x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0])
            result += struct.pack('B', x)
        return result

    def getPacket(self, buff):
        p = None
        if len(buff) >= packet.packetHeaderLength:
            length, order = packet.unpackHeader(buff[0:4])
            if (len(buff) >= length+packet.packetHeaderLength):
                # create the packet
                dump_packet(buff)
                p = packet.Packet.fromBytes(buff[0:(length+packet.packetHeaderLength)])
                # remove packet bytes from buffer
                self.buffer = buff[(length+packet.packetHeaderLength):]
        return p

    def getHandshakeInitialization(self, buff):
        p = self.getPacket(buff)
        if p:
            p = imysql.IHandshakeInitializationPacket(p)
        return p

    def connectionMade(self):
        self.msg("Connection made")
        #self.transport.write("CONNECT")

    
class MysqlProtocolFactory(protocol.ClientFactory):
    """Client factory carrying the connection settings for MysqlProtocol.

    ``self.deferred`` fires with the connected protocol once authentication
    succeeds. It errbacks if the connection fails, or is lost before the
    deferred has fired.
    """
    protocol = MysqlProtocol

    def __init__(self, username,
                 password,
                 database=None,
                 character_set="utf-8",
                 flags=CLIENT_CAPABILITIES,
                 max_packet_size=None):
        self.username = username
        self.password = password
        self.database = database
        self.characterSet = character_set
        # Capability bits sent during authentication (the protocol ORs more in).
        self.clientFlags = flags
        # None means "use the protocol's default".
        self.maxPacketSize = max_packet_size
        self.deferred = defer.Deferred()

    def msg(self, *msg):
        """Write the str() of each argument, space-separated, to stdout."""
        sys.stdout.write("%s\n" % (" ".join([str(x) for x in msg]),))
        sys.stdout.flush()

    def getDeferred(self):
        """Return the Deferred that reports the connection outcome."""
        return self.deferred

    def startedConnecting(self, connector):
        self.msg("started to connect...")

    def clientConnectionFailed(self, connector, reason):
        self.msg("connection failed...")
        if not self.deferred.called:
            self.deferred.errback(reason)

    def clientConnectionLost(self, connector, reason):
        self.msg("connection lost.")
        # A loss before authentication completed would otherwise leave
        # callers waiting on a Deferred that never fires.
        if not self.deferred.called:
            self.deferred.errback(reason)