Szymon Wróblewski avatar Szymon Wróblewski committed 7248eac

updated message module, added 2 basic unit tests

Comments (0)

Files changed (8)

doc/source/api/index.rst

       
       .. automethod:: get_params
       
+      .. automethod:: get_type_id
+      
       .. automethod:: pack
       
       .. automethod:: register(name[, field_names[, **kwargs]])

pygnetic/__init__.py

 """Network library for Pygame."""
 
 import logging
+import discovery
 import event
 import message
 import network

pygnetic/connection.py

 
     def _send_message(self, message, *args, **kwargs):
         name = message.__name__
-        params = self.message_factory.get_params(message)[1]
+        params = self.message_factory.get_params(message)
         try:
             message_ = message(*args, **kwargs)
         except TypeError, e:

pygnetic/discovery.py

     TIMEOUT = 1
 
 
-class ServerHandler(SocketServer.BaseRequestHandler):
+class ServerHandler(SocketServer.BaseRequestHandler, object):
     def handle(self):
         data, socket = self.request
         _logger.debug('Received data: %r', data)
         name = message.__class__.__name__
         _logger.info('Received %s message from %s', name, self.client_address)
         ret_val, err = getattr(self, 'net_' + name)(message)
-        mid = message_factory.get_params(message.__class__)[0]
+        mid = message_factory.get_type_id(message.__class__)
         ack_data = message_factory.pack(response(message.oid, mid, ret_val, err))
         cnt = socket.sendto(ack_data, self.client_address)
         _logger.info('Sent response to %s', self.client_address)

pygnetic/message.py

         if self._frozen == True:
             _logger.warning("Can't register new messages after connection "
                             "establishment")
+            return
         type_id = self._type_id_cnt = self._type_id_cnt + 1
         packet = namedtuple(name, field_names)
         self._message_names[name] = packet
         """Returns message class with given name.
 
         :param name: name of message
-        :return: message class (namedtuple) or None if not found
+        :return: message class (namedtuple)
         """
-        return self._message_names.get(name)
+        try:
+            return self._message_names[name]
+        except KeyError:
+            raise ValueError('Unknown message name')
+
 
     def get_by_type(self, type_id):
         """Returns message class with given type_id.
 
         :param type_id: type identifier of message
-        :return: message class (namedtuple) or None if not found
+        :return: message class (namedtuple)
         """
-        return self._message_types.get(type_id)
+        try:
+            return self._message_types[type_id]
+        except KeyError:
+            raise ValueError('Unknown message type_id')
 
     def get_params(self, message_cls):
-        """Return tuple containing type_id, and sending keyword arguments
+        """Return dict containing sending keyword arguments
 
         :param message_cls: message class created by register
-        :return: int, dict or None if not found
+        :return: dict
         """
-        return self._message_params.get(message_cls)
+        try:
+            return self._message_params[message_cls][1]
+        except KeyError:
+            raise ValueError('Unregistered message')
+
+    def get_type_id(self, message_cls):
+        """Return message class type_id
+
+        :param message_cls: message class created by register
+        :return: int
+        """
+        try:
+            return self._message_params[message_cls][1]
+        except KeyError:
+            raise ValueError('Unregistered message')
 
     def get_hash(self):
         """Calculate and return hash.

Empty file added.

tests/message_test.py

+if __name__ == '__main__':
+    import sys
+    import os
+    pkg_dir = os.path.dirname(os.path.abspath(__file__))
+    parent_dir, pkg_name = os.path.split(pkg_dir)
+    sys.path.insert(0, parent_dir)
+import unittest
+import pygnetic
+
+
+class MessageTests(unittest.TestCase):
+
+    def setUp(self):
+        pygnetic.serialization.select_adapter('msgpack')
+        self.message_factory = pygnetic.message.MessageFactory()
+
+    def test_register(self):
+        message_factory = pygnetic.message.MessageFactory()
+        name = 'test_01'
+        fields = ('name_01', 'name_02', 'name_03')
+        parameters = {'arg_1': 1, 'arg_2': 2, 'arg_3': 3}
+        test_01 = message_factory.register(name, fields, **parameters)
+        self.assertIsNotNone(test_01)
+        self.assertEqual(test_01.__name__, name)
+        self.assertTupleEqual(test_01._fields, fields)
+        self.assertDictEqual(message_factory.get_params(test_01), parameters)
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)

tests/serialization_test.py

+if __name__ == '__main__':
+    import sys
+    import os
+    pkg_dir = os.path.dirname(os.path.abspath(__file__))
+    parent_dir, pkg_name = os.path.split(pkg_dir)
+    sys.path.insert(0, parent_dir)
+import unittest
+import pygnetic
+
+
+class CommonTests(object):
+    def test_select_adapter(self):
+        pygnetic.serialization.select_adapter(self.adapter_lib_name)
+        self.assertEqual(pygnetic.serialization.selected_adapter,
+                         self.adapter,
+                         'incorrect selected adapter')
+        self.assertEqual(pygnetic.serialization.pack,
+                         self.adapter.pack,
+                         'incorrect selected adapter pack function')
+        self.assertEqual(pygnetic.serialization.unpack,
+                         self.adapter.unpack,
+                         'incorrect selected adapter unpack function')
+        self.assertEqual(pygnetic.serialization.unpacker,
+                         self.adapter.unpacker,
+                         'incorrect selected adapter unpacker class')
+
+    def test_get_adapter(self):
+        adapter = pygnetic.serialization.get_adapter(self.adapter_lib_name)
+        self.assertEqual(adapter, self.adapter,
+                         'incorrect selected adapter')
+
+
+class MsgpackAdapterTests(unittest.TestCase, CommonTests):
+    def setUp(self):
+        import pygnetic.serialization.msgpack_adapter as adapter
+        self.adapter = adapter
+        self.adapter_lib_name = 'msgpack'
+
+
+class JsonAdapterTests(unittest.TestCase, CommonTests):
+    def setUp(self):
+        import pygnetic.serialization.json_adapter as adapter
+        self.adapter = adapter
+        self.adapter_lib_name = 'json'
+
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.