Commits

Anonymous committed 5c0af7e

Made the allowed imports list (set) global to each serializer class (mainly because of pyamf), changed tests accordingly

Comments (0)

Files changed (4)

pyhttprpc/serializer/base.py

 
 
 class BaseSerializer(object):
-    def __init__(self, allowed_imports):
-        self._allowed_imports = []
-        self.add_allowed_import(r'exceptions\..*')
-        self.add_allowed_import(r'pyhttprpc\..*')
-
-        for expr in allowed_imports:
-            self.add_allowed_import(expr)
-            if isinstance(expr, basestring):
-                expr = re.compile(expr)
-            self._allowed_imports.append(expr)
-
-    def add_allowed_import(self, expr):
+    @classmethod
+    def add_allowed_import(cls, expr):
         """
         Adds an expression to the list that defines the allowed imports
         when deserializing object instances. Each expression is a regular
         """
         if isinstance(expr, basestring):
             expr = re.compile(expr)
-        self._allowed_imports.append(expr)
+        cls._allowed_imports.add(expr)
 
-    def _check_import_name(self, name):
+    @classmethod
+    def _check_import_name(cls, name):
         from pyhttprpc import DeserializationError
 
-        for expr in self._allowed_imports:
+        for expr in cls._allowed_imports:
             if expr.match(name):
                 return True
         raise DeserializationError('Not allowed to import %s' % name)
 
-    def find_global(self, modulename, classname):
+    @classmethod
+    def find_global(cls, modulename, classname):
         """
         Dynamically imports the given class from the given module,
         but only if the module name passes the permission check.
         :return: the imported class
         """
         name = modulename + '.' + classname
-        self._check_import_name(name)
+        cls._check_import_name(name)
 
         module = __import__(modulename)
         components = modulename.split('.')

pyhttprpc/serializer/pickle_serializer.py

 
 
 class PickleSerializer(BaseSerializer):
-    def __init__(self, protocol=2, allowed_imports=()):
+    _allowed_imports = set()
+
+    def __init__(self, protocol=2):
         """
         Initializes the pickle serializer.
 
         :param find_global: callable that, if defined, is called to load a
             class for deserialization
         """
-        BaseSerializer.__init__(self, allowed_imports)
-
         if protocol > pickle.HIGHEST_PROTOCOL:
             raise ValueError('protocol cannot be higher than %d' %
                              pickle.HIGHEST_PROTOCOL)
         unpickler = pickle.Unpickler(StringIO(data))
         unpickler.find_global = self.find_global
         return unpickler.load()
+
+PickleSerializer.add_allowed_import(r'exceptions\..*')
+PickleSerializer.add_allowed_import(r'pyhttprpc\..*')

pyhttprpc/serializer/pyamf_serializer.py

 from pyhttprpc.serializer.base import BaseSerializer
 
 
-_old_get_class_alias = None
+_original_get_class_alias = pyamf.get_class_alias
 
 def _new_get_class_alias(cls):
     try:
-        return _old_get_class_alias(cls)
+        return _original_get_class_alias(cls)
     except pyamf.UnknownClassAlias:
         alias = '%s.%s' % (cls.__module__, cls.__name__)
         return pyamf.register_class(cls, alias)
 
 
 class PyAMFSerializer(BaseSerializer):
-    def __init__(self, allowed_imports=(), monkey_patch=True):
+    _allowed_imports = set()
+
+    def __init__(self, monkey_patch=True):
         """
         Initializes the PyAMF serializer.
 
             with a new version that auto-registers classes of outgoing
             instances so that they are properly deserialized on the other end. 
         """
-        BaseSerializer.__init__(self, allowed_imports)
-
-        global _old_get_class_alias
-        if monkey_patch and not _old_get_class_alias:
-            _old_get_class_alias = pyamf.get_class_alias
+        if monkey_patch and pyamf.get_class_alias == _original_get_class_alias:
             pyamf.get_class_alias = _new_get_class_alias
 
         try:
-            pyamf.register_class_loader(self._find_class)
+            pyamf.register_class_loader(PyAMFSerializer._find_class)
         except ValueError:
             pass
 
-    def _find_class(self, classname):
+    @classmethod
+    def _find_class(cls, classname):
         lastdot = classname.rfind('.')
         if lastdot >= 0:
             modulename, classname = classname[:lastdot], classname[lastdot + 1:]
-            return self.find_global(modulename, classname)
+            return cls.find_global(modulename, classname)
 
     def serialize(self, obj):
         return pyamf.encode(obj, encoding=pyamf.AMF3)
 
     def deserialize(self, data):
         return pyamf.decode(data, encoding=pyamf.AMF3).next()
+
+PyAMFSerializer.add_allowed_import(r'exceptions\..*')
+PyAMFSerializer.add_allowed_import(r'pyhttprpc\..*')

tests/test_serializer.py

+# coding: utf-8
 """Tests serializers for correct operation."""
-from nose.tools import eq_, raises
+
+from nose.tools import eq_, assert_raises
 
 from pyhttprpc.serializer.pickle_serializer import PickleSerializer
 from pyhttprpc.serializer.pyamf_serializer import PyAMFSerializer
 
 class AllowedClass(object):
     def __init__(self):
-        self.id = id(self)
+        self.id = 345082340955324
+        self.strdata = 'Test String Data'
+        self.ucdata = u'Unicode Data ÅÄÖ'
 
-    def __eq__(self, obj):
-        return type(obj) is AllowedClass and obj.id == self.id
 
+class DisallowedClass(object):
+    def __init__(self):
+        self.id = 993052791289095
+        self.strdata = 'Test 2 String Data'
+        self.ucdata = u'Unicode 2 Data ÅÄÖ'
 
-class DisallowedClass(AllowedClass):
-    pass
 
+def test_pickle_serialize():
+    obj = AllowedClass()
+    serializer = PickleSerializer()
+    serializer.add_allowed_import(__name__ + '\.AllowedClass')
+    data = serializer.serialize(obj)
 
-def find_global(modulename, classname):
-    if classname == 'AllowedClass':
-        return AllowedClass
+    obj = serializer.deserialize(data)
+    assert isinstance(obj, AllowedClass)
+    eq_(obj.id, 345082340955324)
+    eq_(obj.strdata, 'Test String Data')
+    eq_(obj.ucdata, u'Unicode Data ÅÄÖ')
 
 
-def check_serializer(serializer):
+def test_pyamf_serialize():
     obj = AllowedClass()
+    serializer = PyAMFSerializer()
+    serializer.add_allowed_import(__name__ + '\.AllowedClass')
     data = serializer.serialize(obj)
-    obj2 = serializer.deserialize(data)
-    eq_(obj, obj2)
 
+    obj = serializer.deserialize(data)
+    assert isinstance(obj, AllowedClass)
+    eq_(obj.id, 345082340955324)
+    eq_(obj.strdata, 'Test String Data')
+    eq_(obj.ucdata, u'Unicode Data ÅÄÖ')
 
-@raises(DeserializationError)
-def check_serializer_disallowed(serializer):
-    obj = DisallowedClass()
-    data = serializer.serialize(obj)
-    serializer.deserialize(data)
 
+def test_pickle_deserialize():
+    data = '\x80\x02ctest_serializer\nAllowedClass\nq\x01)\x81q\x02}q\x03'\
+           '(U\x06ucdataq\x04X\x13\x00\x00\x00Unicode Data \xc3\x85\xc3\x84'\
+           '\xc3\x96q\x05U\x07strdataq\x06U\x10Test String Dataq\x07U\x02idq'\
+           '\x08I345082340955324\nub.'
+    serializer = PickleSerializer()
+    serializer.add_allowed_import(__name__ + '\.AllowedClass')
+    obj = serializer.deserialize(data)
+    assert isinstance(obj, AllowedClass)
+    eq_(obj.id, 345082340955324)
+    eq_(obj.strdata, 'Test String Data')
+    eq_(obj.ucdata, u'Unicode Data ÅÄÖ')
 
-def test_serializers():
-    for cls in PickleSerializer, PyAMFSerializer:
-        serializer = cls(allowed_imports=[__name__ + '\.AllowedClass'])
-        yield check_serializer, serializer
-        yield check_serializer_disallowed, serializer
+
+def test_pyamf_deserialize():
+    data = "\n\x0b9test_serializer.AllowedClass\rucdata\x06'Unicode Data \xc3"\
+           "\x85\xc3\x84\xc3\x96\x0fstrdata\x06!Test String Data\x05id\x05B"\
+           "\xf3\x9d\x9b\xe5\x9dK\xc0\x01"
+    serializer = PyAMFSerializer()
+    serializer.add_allowed_import(__name__ + '\.AllowedClass')
+    obj = serializer.deserialize(data)
+    assert isinstance(obj, AllowedClass)
+    eq_(obj.id, 345082340955324)
+    eq_(obj.strdata, 'Test String Data')
+    eq_(obj.ucdata, u'Unicode Data ÅÄÖ')
+
+
+def test_pickle_deserialize_disallowed():
+    data = '\x80\x02ctest_serializer\nDisallowedClass\nq\x01)\x81q\x02}q\x03'\
+           '(U\x06ucdataq\x04X\x15\x00\x00\x00Unicode 2 Data \xc3\x85\xc3\x84'\
+           '\xc3\x96q\x05U\x07strdataq\x06U\x12Test 2 String Dataq\x07U\x02'\
+           'idq\x08I993052791289095\nub.'
+    serializer = PickleSerializer()
+    assert_raises(DeserializationError, serializer.deserialize, data)
+
+    serializer.add_allowed_import(__name__ + '\.DisallowedClass')
+    obj = serializer.deserialize(data)
+    assert isinstance(obj, DisallowedClass)
+    eq_(obj.id, 993052791289095)
+    eq_(obj.strdata, 'Test 2 String Data')
+    eq_(obj.ucdata, u'Unicode 2 Data ÅÄÖ')
+
+
+def test_pyamf_deserialize_disallowed():
+    data = '\n\x0b?test_serializer.DisallowedClass\rucdata\x06+Unicode 2 Data'\
+           ' \xc3\x85\xc3\x84\xc3\x96\x0fstrdata\x06%Test 2 String Data\x05id'\
+           '\x05C\x0c9h\xf6\xf1\xa88\x01'
+    serializer = PyAMFSerializer()
+    assert_raises(DeserializationError, serializer.deserialize, data)
+
+    serializer.add_allowed_import(__name__ + '\.DisallowedClass')
+    obj = serializer.deserialize(data)
+    assert isinstance(obj, DisallowedClass)
+    eq_(obj.id, 993052791289095)
+    eq_(obj.strdata, 'Test 2 String Data')
+    eq_(obj.ucdata, u'Unicode 2 Data ÅÄÖ')