Szymon Wróblewski avatar Szymon Wróblewski committed e31a25c

reworked message module, added streaming feature to json decoder, removed Connection.state and State, added Connection.connected, expanded socket_adapter.Connection

Comments (0)

Files changed (14)

pygame_network/__init__.py

 import message
 import syncobject
 from handler import Handler
-from network.base_adapter import State
 
 _logger = logging.getLogger(__name__)
 _network_module = None
 _serialization_module = None
-register = message.MessageFactory.register
+register = message.message_factory.register
 
 
 class Client(object):

pygame_network/message.py

         player = received_msg.player
         msg = received_msg.msg
     """
-    _message_names = {}  # mapping name -> message
-    _message_types = WeakValueDictionary()  # mapping type_id -> message
-    _message_params = WeakKeyDictionary()  # mapping message -> type_id, send par
-    _type_id_cnt = 0
-    _frozen = False
-    _hash = None
-
     def __init__(self):
-        # override class variables with instance variables
-        cls = self.__class__
-        self._message_names = cls._message_names.copy()
-        self._message_types = cls._message_types.copy()
-        self._message_params = cls._message_params.copy()
-        self._type_id_cnt = cls._type_id_cnt
+        self._message_names = {}  # mapping name -> message
+        self._message_types = WeakValueDictionary()  # mapping type_id -> message
+        self._message_params = WeakKeyDictionary()  # mapping message -> type_id, send par
+        self._type_id_cnt = 0
         self._frozen = False
         self._hash = None
 
-    @classmethod
-    def register(cls, name, field_names=tuple(), **kwargs):
+    def register(self, name, field_names=tuple(), **kwargs):
         """Register new message type
 
         MessageFactory.register(name, field_names[, **kwargs]): return class
 
         Returns namedtuple class.
         """
-        if cls._frozen == True:
-            raise MessageError("Can't register new messages after "\
-                              "connection establishment")
-        type_id = cls._type_id_cnt = cls._type_id_cnt + 1
+        if self._frozen == True:
+            raise MessageError("Can't register new messages after "
+                               "connection establishment")
+        type_id = self._type_id_cnt = self._type_id_cnt + 1
         packet = namedtuple(name, field_names)
-        cls._message_names[name] = packet
-        cls._message_types[type_id] = packet
-        cls._message_params[packet] = (type_id, kwargs)
+        self._message_names[name] = packet
+        self._message_types[type_id] = packet
+        self._message_params[packet] = (type_id, kwargs)
         return packet
 
-    @classmethod
-    def pack(cls, message, sys_data=[]):
+    def pack(self, message):
         """Pack data to string
 
         MessageFactory.pack(packet_id, packet_obj): return string
 
         Returns message packed in string, ready to send.
         """
-        type_id = cls._message_params[message.__class__][0]
-        message = (type_id,) + tuple(sys_data) + message
+        type_id = self._message_params[message.__class__][0]
+        message = (type_id,) + message
         data = s_lib.pack(message)
         _logger.debug("Packing message (length: %d)", len(data))
         return data
 
-    @classmethod
-    def unpack(cls, data, sys_data=[]):
-        """Unpack data from string, return message_id and message
+    def set_frozen(self):
+        self._frozen = True
 
-        MessageFactory.unpack(data): return (message_id, message)
+    def reset_context(self, context):
+        context._unpacker = s_lib.unpacker()
+
+    def _process_message(self, message):
+        try:
+            type_id = message[0]
+            return self._message_types[type_id](*message[1:])
+        except KeyError:
+            _logger.error('Unknown message type_id: %s', type_id)
+        except:
+            _logger.error('Message unpacking error: %s', message)
+
+    def unpack(self, data):
+        """Unpack data from string and buffer, return message
+
+        MessageFactory.unpack(data): return message
 
         data - packed message data as a string
         """
         _logger.debug("Unpacking message (length: %d)", len(data))
         try:
             message = s_lib.unpack(data)
-        except StopIteration:  # end of stream
-            _logger.warning('Not enough data to unpack')
         except:
-            _logger.error('Data corrupted (No streaming support in library)')
+            _logger.error('Data corrupted')
+            self._reset_unpacker()  # prevent from corrupting next data
             return
-        sys_data_l = len(sys_data)
+        return self._process_message(message)
+
+    def unpack_all(self, data, context):
+        """Unpack all data from string and buffer, return message generator
+
+        MessageFactory.unpack(data): return message generator
+
+        data - packed data as a string
+        """
+        _logger.debug("Unpacking data (length: %d)", len(data))
+        context._unpacker.feed(data)
         try:
-            type_id = message[0]
-            sys_data[:] = message[1:1 + sys_data_l]
-            return cls._message_types[type_id](*message[1 + sys_data_l:])
-        except KeyError:
-            # should not happen
-            # prevented by hash control during connection
-            _logger.error('Unknown message type_id: %s', type_id)
+            for message in context._unpacker:
+                yield self._process_message(message)
         except:
-            _logger.error('Message unpacking error: %s', message)
-        return
+            _logger.error('Data corrupted')
+            self._reset_unpacker()  # prevent from corrupting next data
+            return
 
-    @classmethod
-    def get_by_name(cls, name):
+    def get_by_name(self, name):
         """Returns message class with given name
 
         MessageFactory.get_by_name(name): return class
 
         Returns namedtuple class.
         """
-        return cls._message_names[name]
+        return self._message_names[name]
 
-    @classmethod
-    def get_by_type(cls, type_id):
+    def get_by_type(self, type_id):
         """Returns message class with given type_id
 
         MessageFactory.get_by_type(name): return class
 
         Returns namedtuple class.
         """
-        return cls._message_types[type_id]
+        return self._message_types[type_id]
 
-    @classmethod
-    def get_params(cls, message):
+    def get_params(self, message):
         """Return tuple containing type_id, and sending keyword args
 
         MessageFactory.get_params(message): return (int, dict)
 
         message - message class created by register
         """
-        return cls._message_params[message]
+        return self._message_params[message]
 
-    @classmethod
-    def get_hash(cls):
-        if cls._frozen:
-            if cls._hash is None:
-                ids = cls._message_types.keys()
+    def get_hash(self):
+        """Calculate and return hash.
+
+        Hash depends on registered messages and used serializing library.
+        """
+        if self._frozen:
+            if self._hash is None:
+                ids = self._message_types.keys()
                 ids.sort()
                 l = list()
                 l.append(s_lib.__name__)
                 for i in ids:
-                    p = cls._message_types[i]
+                    p = self._message_types[i]
                     l.append((i, p.__name__, p._fields))
                 # should be the same on 32 & 64 platforms
-                cls._hash = hash(tuple(l)) & 0xffffffff
-            return cls._hash
+                self._hash = hash(tuple(l)) & 0xffffffff
+            return self._hash
         else:
             _logger.warning('Attempt to get hash of not frozen MessageFactory')
 
 
-update_remoteobject = MessageFactory.register('update_remoteobject', (
+message_factory = MessageFactory()
+update_remoteobject = message_factory.register('update_remoteobject', (
     'type_id',
     'obj_id',
     'variables'
 ), channel=1, flags=0)
-chat_msg = MessageFactory.register('chat_msg', (
+chat_msg = message_factory.register('chat_msg', (
     'player',
     'msg'
 ))

pygame_network/network/base_adapter.py

 import logging
 from weakref import proxy
 from functools import partial
-from ..message import MessageFactory
+from .. import message
 from .. import event
 
 __all__ = ('Connection', 'State')
 _logger = logging.getLogger(__name__)
 
 
-class State(object):
-    (CONNECTED,
-    CONNECTING,
-    DISCONNECTED,
-    DISCONNECTING) = range(4)
-
-
 class Connection(object):
     """Class allowing to send messages
 
     * using send method with message as argument
 
     """
-    message_factory = MessageFactory
+    message_factory = message.message_factory
 
     def __init__(self, parent, message_factory=None, *args, **kwargs):
         super(Connection, self).__init__(*args, **kwargs)
         self.parent = proxy(parent)
         if message_factory is not None:
             self.message_factory = message_factory
+        self.message_factory.reset_context(self)
         self._handlers = []
         self.data_sent = 0
         self.data_received = 0
 
     def __getattr__(self, name):
         parts = name.split('_', 1)
-        if len(parts) == 2 and parts[0] == 'net' and\
-                parts[1] in self._message_factory._message_names:
-            p = partial(self._send_message, self._message_factory.get_by_name(parts[1]))
+        if (len(parts) == 2 and parts[0] == 'net' and
+                parts[1] in self.message_factory._message_names):
+            p = partial(self._send_message, self.message_factory.get_by_name(parts[1]))
             p.__doc__ = "Send %s message to remote host\n\nHost.net_%s" % (
                 parts[1],
-                self._message_factory._message_names[parts[1]].__doc__
+                self.message_factory._message_names[parts[1]].__doc__
             )
             # add new method so __getattr__ is no longer needed
             setattr(self, name, p)
         Pygame event queue if sending was successful.
         """
         if isinstance(message, basestring):
-            message = self._message_factory.get_by_name(message)
+            message = self.message_factory.get_by_name(message)
         self._send_message(message, *args, **kwargs)
 
     def _send_message(self, message, *args, **kwargs):
         name = message.__name__
-        params = self._message_factory.get_params(message)[1]
+        params = self.message_factory.get_params(message)[1]
         try:
             message_ = message(*args, **kwargs)
         except TypeError, e:
             e, f = re.findall(r'\d', e.message)
             raise TypeError('%s takes exactly %d arguments (%d given)' %
                 (message.__doc__, int(e) - 1, int(f) - 1))
-        data = self._message_factory.pack(message_)
+        data = self.message_factory.pack(message_)
         _logger.info('Sent %s message to %s:%s', name, *self.address)
         self.data_sent += len(data)
         self.messages_sent += 1
         return self._send_data(data, **params)
 
     def _receive(self, data, channel=0):
-        message = self._message_factory.unpack(data)
-        if message is None:
-            return
-        name = message.__class__.__name__
-        _logger.info('Received %s message from %s:%s', name, *self.address)
-        event.received(self, message, channel)
-        for h in self._handlers:
-            getattr(h, 'net_' + name, h.on_recive)(message, channel)
+        for message in self.message_factory.unpack_all(data, self):
+            name = message.__class__.__name__
+            _logger.info('Received %s message from %s:%s', name, *self.address)
+            event.received(self, message, channel)
+            for h in self._handlers:
+                getattr(h, 'net_' + name, h.on_recive)(message, channel)
 
     def _connect(self):
         _logger.info('Connected to %s:%s', *self.address)
         handler.connection = proxy(self)
 
     @property
-    def state(self):
-        """Connection state."""
-        return State.DISCONNECTED
-
-    @property
     def address(self):
         """Connection address."""
         return None, None
 
 
 class Server(object):
-    message_factory = MessageFactory
+    message_factory = message.message_factory
     handler = None
 
     def __init__(self, address='', port=0, connections_limit=4, *args, **kwargs):
         _logger.debug('Server created %s, connections limit: %d', address, connections_limit)
-        self.message_factory._frozen = True
+        self.message_factory.set_frozen()
         _logger.debug('MessageFactory frozen')
         self.conn_map = {}
 

pygame_network/network/enet_adapter/client.py

 import logging
 import enet
-from ...message import MessageFactory
+from ... import message
 from connection import Connection
 
 _logger = logging.getLogger(__name__)
         while True:
             client.step()
     """
-    def __init__(self, connections_limit=1, channel_limit=0, in_bandwidth=0, out_bandwidth=0):
-        self.host = enet.Host(None, connections_limit, channel_limit, in_bandwidth, out_bandwidth)
+    def __init__(self, connections_limit=1, *args, **kwargs):
+        super(Client, self).__init__(*args, **kwargs)
+        self.host = enet.Host(None, connections_limit)
         self._peers = {}
         self._peer_cnt = 0
         _logger.debug('Client created, connections limit: %d', connections_limit)
 
-    def connect(self, address, port, channels=2, message_factory=MessageFactory):
+    def connect(self, address, port, channels=2, message_factory=message.message_factory):
         address = enet.Address(address, port)
         _logger.info('Connecting to %s', address)
         peer_id = self._peer_cnt = self._peer_cnt + 1
         peer_id = str(peer_id)
         # Can't register messages after connection
-        message_factory._frozen = True
+        message_factory.set_frozen()
         _logger.debug('MessageFactory frozen')
         peer = self.host.connect(address, channels, message_factory.get_hash())
         peer.data = peer_id

pygame_network/network/enet_adapter/connection.py

 import enet
 from .. import base_adapter
 
-_state_mapping = {
-    enet.PEER_STATE_ACKNOWLEDGING_CONNECT: base_adapter.State.CONNECTING,
-    enet.PEER_STATE_ACKNOWLEDGING_DISCONNECT: base_adapter.State.DISCONNECTING,
-    enet.PEER_STATE_CONNECTED: base_adapter.State.CONNECTED,
-    enet.PEER_STATE_CONNECTING: base_adapter.State.CONNECTING,
-    enet.PEER_STATE_CONNECTION_PENDING: base_adapter.State.CONNECTING,
-    enet.PEER_STATE_CONNECTION_SUCCEEDED: base_adapter.State.CONNECTING,
-    enet.PEER_STATE_DISCONNECTED: base_adapter.State.DISCONNECTED,
-    enet.PEER_STATE_DISCONNECTING: base_adapter.State.DISCONNECTING,
-    enet.PEER_STATE_DISCONNECT_LATER: base_adapter.State.DISCONNECTING,
-    enet.PEER_STATE_ZOMBIE: base_adapter.State.DISCONNECTING
-}
-
 
 class Connection(base_adapter.Connection):
     """Class allowing to send messages
             self.peer.disconnect_now()
 
     @property
-    def state(self):
+    def connected(self):
         """Connection state."""
-        return _state_mapping[self.peer.state]
+        return self.peer.state == enet.PEER_STATE_CONNECTED
 
     @property
     def address(self):

pygame_network/network/enet_adapter/server.py

 from weakref import proxy
 import enet
 from ...handler import Handler
-from ...message import MessageFactory
+from ... import message
 from connection import Connection
 
 _logger = logging.getLogger(__name__)
 
 
 class Server(object):
-    message_factory = MessageFactory
+    message_factory = message.message_factory
     handler = None
 
     def __init__(self, address='', port=0, con_limit=4, *args, **kwargs):
+        super(Server, self).__init__(*args, **kwargs)
         address = enet.Address(address, port)
         self.host = enet.Host(address, con_limit, *args, **kwargs)
         self.conn_map = {}
         self._peer_cnt = 0
         _logger.debug('Server created %s, connections limit: %d', address, con_limit)
-        self.message_factory._frozen = True
+        self.message_factory.set_frozen()
         _logger.debug('MessageFactory frozen')
 
 
                     peer_id = str(peer_id)
                     event.peer.data = peer_id
                     connection = Connection(self, event.peer, self.message_factory)
-                    if issubclass(self.handler, Handler):
+                    if self.handler is not None and issubclass(self.handler, Handler):
                         handler = self.handler()
                         handler.server = proxy(self)
                         connection.add_handler(handler)

pygame_network/network/socket_adapter/__init__.py

+from connection import Connection

pygame_network/network/socket_adapter/connection.py

 import logging
+import socket
+import asyncore
 from collections import deque
-from asyncore import dispatcher
 from .. import base_adapter
 from ...message import MessageFactory
 
 _logger = logging.getLogger(__name__)
 
 
-class Connection(base_adapter.Connection, dispatcher):
-    __send = dispatcher.send
-    rcvr_buffer_size = 2048
+class Connection(base_adapter.Connection, asyncore.dispatcher):
+    __send = asyncore.dispatcher.send
+    # maximum amount of data received / sent at once
+    recv_buffer_size = 4096
 
     def __init__(self, parent, socket, message_factory=MessageFactory):
         super(Connection, self).__init__(parent, message_factory)
-        self._queue = deque()
-        self.state = base_adapter.State.DISCONNECTED
+        #self.send_queue = deque()
+        self.send_buffer = bytearray()
+        self.recv_buffer = bytearray(b'\0' * self.recv_buffer_size)
 
     def _send_data(self, data, **kwargs):
-        self._message_queue.append(data)
+        self.send_buffer.extend(data)
+        self._send_part()
+
+#    def _send_data2(self, data, **kwargs):
+#        if len(self.send_buffer) == 0:
+#            self.send_buffer = data
+#        else:
+#            self.send_queue.append(data)
+#        self._send_part()
+
+    def handle_write(self):
+        self._send_part()
+
+    def _send_part(self):
+        try:
+            num_sent = self.__send(self.send_buffer)
+        except socket.error:
+            self.handle_error()
+            return
+        self.send_buffer = self.send_buffer[num_sent:]
+
+    def writable(self):
+        return (not self.connected) or len(self.send_buffer)
+
+    def handle_read(self):
+        # tinkering with dispatcher internal variables,
+        # because it doesn't support socket.recv_into
+        try:
+            data = self.socket.recv_into(self.recv_buffer)
+            if not data:
+                self.handle_close()
+            else:
+                self._receive(data)
+        except socket.error, why:
+            if why.args[0] in asyncore._DISCONNECTED:
+                self.handle_close()
+            else:
+                self.handle_error()
+            return
+        self._receive(self.recv(self.recv_buffer_size))
+
+    def handle_connect(self):
+        self._connect()
+
+    def handle_close(self):
+        self.log_info('unhandled close event', 'warning')
+        self.close()
+
+    def log_info(self, message, type='info'):
+        return getattr(_logger, type)(message)
 
     def disconnect(self, *args):
         """Request a disconnection."""
     @property
     def address(self):
         """Connection address."""
-        return None, None
-
-    def handle_read(self):
-        self._receive(self.recv(self.rcvr_buffer_size))
-
-    def handle_write(self):
-        #queue, self._queue = self._queue, deque()
-        for data in self._queue:
-            self.__send(data)
-        self._queue = deque()
-
-    def handle_connect(self):
-        self.state = base_adapter.State.CONNECTED
-        self._connect()
-
-    def handle_close(self):
-        self.log_info('unhandled close event', 'warning')
-        self.close()
-
-    def log_info(self, message, type='info'):
-        return getattr(_logger, type)(message)
+        return self.getpeername()

pygame_network/serialization/json_adapter.py

 import json
-_packer = json.JSONEncoder()
-_unpacker = json.JSONDecoder()
-pack = _packer.encode
-unpack = _unpacker.decode
+
+
+class JSONDecoder(json.JSONDecoder):
+    """JSONDecoder with streaming feature"""
+    def __init__(self):
+        super(JSONDecoder, self).__init__()
+        self.buffer = bytearray()
+
+    def feed(self, data):
+        self.buffer.extend(data)
+
+    def decode(self):
+        try:
+            obj, end = self.scan_once(self.buffer.decode(), 0)
+            del self.buffer[:end]
+            return obj
+        except:
+            raise StopIteration('No more data to decode.')
+
+    next = decode
+
+
+pack = json.dumps
+unpack = json.loads
+unpacker = JSONDecoder

pygame_network/serialization/msgpack_adapter.py

 import msgpack
-_packer = msgpack.Packer()
-_unpacker = msgpack.Unpacker()
-pack = _packer.pack
-unpack = lambda data: _unpacker.feed(data) or _unpacker.unpack()
+
+pack = msgpack.packb
+unpack = msgpack.unpackb
+unpacker = msgpack.Unpacker

test_client_1.py

-import random
-import logging
-import pygame_network as net
-
-
-def main():
-    net.init(logging_lvl=logging.DEBUG)
-    net.register('echo', ('msg', 'msg_id'))
-    client = net.Client()
-    connection = client.connect("localhost", 54301)
-    counter = 0
-    while connection.state != net.State.DISCONNECTED:
-        client.step()
-        if counter < 10 and connection.state == net.State.CONNECTED:
-            msg = ''.join(random.sample('abcdefghijklmnopqrstuvwxyz', 10))
-            logging.info('Sending: %s', msg)
-            connection.net_echo(msg, counter)
-            counter += 1
-            if counter == 10:
-                connection.disconnect()
-
-
-if __name__ == '__main__':
-    main()
 
 class EchoHandler(net.Handler):
     def __init__(self):
-        self.counter = 10
+        self.out_counter = 0
+        self.in_counter = 0
 
     def net_echo(self, message, channel):
         logging.info('Received message @ch%d: %s', channel, message)
+        self.in_counter += 1
 
     def step(self):
-        if self.counter > 0 and self.connection.state == net.State.CONNECTED:
+        if self.out_counter < 10 and self.connection.connected:
             msg = ''.join(random.sample('abcdefghijklmnopqrstuvwxyz', 10))
             logging.info('Sending: %s', msg)
-            self.connection.net_echo(msg, self.counter)
-            self.counter -= 1
-            if self.counter == 0:
-                self.connection.disconnect()
+            self.connection.net_echo(msg, self.out_counter)
+            self.out_counter += 1
+        if self.out_counter == 10 and self.in_counter == 10:
+            self.connection.disconnect()
 
 
 def main():
             if e.type == KEYDOWN:
                 if e.key == K_SPACE:
                     if connection is not None:
-                        if connection.state == net.State.CONNECTED:
+                        if connection.connected:
                             connection.disconnect()
                             connection_status(screen, (140, 38), False)
                     else:
                         message_status(screen, (110, 62), messages)
             if e.type == QUIT or e.type == KEYDOWN and e.key == K_ESCAPE:
                 run = False
-        if len(messages) < 10 and connection is not None and connection.state == net.State.CONNECTED:
+        if len(messages) < 10 and connection is not None and connection.connected:
             msg = ''.join(random.sample('abcdefghijklmnopqrstuvwxyz', 10))
 
             # Sending messages
 
 class EchoHandler(net.Handler):
     def net_echo(self, message, channel):
+        logging.info('Received message @ch%d: %s', channel, message)
         msg = message.msg.upper()
         self.connection.net_echo(msg, message.msg_id)
-        logging.info('message @ch%d: %s', channel, message)
 
 
 class Server(net.Server):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.