Kirill Simonov avatar Kirill Simonov committed a188d5e

Add support for pickling/unpickling python objects.

Comments (0)

Files changed (7)

lib/yaml/__init__.py

     """
     Dumper.add_representer(data_type, representer)
 
+def add_multi_representer(data_type, multi_representer, Dumper=Dumper):
+    """
+    Add a representer for the given type.
+    Multi-representer is a function accepting a Dumper instance
+    and an instance of the given data type or subtype
+    and producing the corresponding representation node.
+    """
+    Dumper.add_multi_representer(data_type, multi_representer)
+
 class YAMLObjectMetaclass(type):
     """
     The metaclass for YAMLObject.

lib/yaml/constructor.py

                     node.start_mark)
         return self.find_python_module(suffix, node.start_mark)
 
+    class classobj: pass
+
+    def make_python_instance(self, suffix, node,
+            args=None, kwds=None, newobj=False):
+        if not args:
+            args = []
+        if not kwds:
+            kwds = {}
+        cls = self.find_python_name(suffix, node.start_mark)
+        if newobj and isinstance(cls, type(self.classobj))  \
+                and not args and not kwds:
+            instance = self.classobj()
+            instance.__class__ = cls
+            return instance
+        elif newobj and isinstance(cls, type):
+            return cls.__new__(cls, *args, **kwds)
+        else:
+            return cls(*args, **kwds)
+
+    def set_python_instance_state(self, instance, state):
+        if hasattr(instance, '__setstate__'):
+            instance.__setstate__(state)
+        else:
+            slotstate = {}
+            if isinstance(state, tuple) and len(state) == 2:
+                state, slotstate = state
+            if hasattr(instance, '__dict__'):
+                instance.__dict__.update(state)
+            elif state:
+                slotstate.update(state)
+            for key, value in slotstate.items():
+                setattr(object, key, value)
+
+    def construct_python_object(self, suffix, node):
+        # Format:
+        #   !!python/object:module.name { ... state ... }
+        instance = self.make_python_instance(suffix, node, newobj=True)
+        state = self.construct_mapping(node)
+        self.set_python_instance_state(instance, state)
+        return instance
+
+    def construct_python_object_apply(self, suffix, node, newobj=False):
+        # Format:
+        #   !!python/object/apply       # (or !!python/object/new)
+        #   args: [ ... arguments ... ]
+        #   kwds: { ... keywords ... }
+        #   state: ... state ...
+        #   listitems: [ ... listitems ... ]
+        #   dictitems: { ... dictitems ... }
+        # or short format:
+        #   !!python/object/apply [ ... arguments ... ]
+        # The difference between !!python/object/apply and !!python/object/new
+        # is how an object is created, check make_python_instance for details.
+        if isinstance(node, SequenceNode):
+            args = self.construct_sequence(node)
+            kwds = {}
+            state = {}
+            listitems = []
+            dictitems = {}
+        else:
+            value = self.construct_mapping(node)
+            args = value.get('args', [])
+            kwds = value.get('kwds', {})
+            state = value.get('state', {})
+            listitems = value.get('listitems', [])
+            dictitems = value.get('dictitems', {})
+        instance = self.make_python_instance(suffix, node, args, kwds, newobj)
+        if state:
+            self.set_python_instance_state(instance, state)
+        if listitems:
+            instance.extend(listitems)
+        if dictitems:
+            for key in dictitems:
+                instance[key] = dictitems[key]
+        return instance
+
+    def construct_python_object_new(self, suffix, node):
+        return self.construct_python_object_apply(suffix, node, newobj=True)
+
+
 Constructor.add_constructor(
     u'tag:yaml.org,2002:python/none',
     Constructor.construct_yaml_null)
     u'tag:yaml.org,2002:python/module:',
     Constructor.construct_python_module)
 
+Constructor.add_multi_constructor(
+    u'tag:yaml.org,2002:python/object:',
+    Constructor.construct_python_object)
+
+Constructor.add_multi_constructor(
+    u'tag:yaml.org,2002:python/object/apply:',
+    Constructor.construct_python_object_apply)
+
+Constructor.add_multi_constructor(
+    u'tag:yaml.org,2002:python/object/new:',
+    Constructor.construct_python_object_new)
+

lib/yaml/representer.py

 except NameError:
     from sets import Set as set
 
-import sys
+import sys, copy_reg
 
 class RepresenterError(YAMLError):
     pass
 class BaseRepresenter:
 
     yaml_representers = {}
+    yaml_multi_representers = {}
 
     def __init__(self):
         self.represented_objects = {}
 
     def represent(self, data):
-        node = self.represent_object(data)
+        node = self.represent_data(data)
         self.serialize(node)
         self.represented_objects = {}
 
             bases.extend(self.get_classobj_bases(base))
         return bases
 
-    def represent_object(self, data):
+    def represent_data(self, data):
         if self.ignore_aliases(data):
             alias_key = None
         else:
         data_types = type(data).__mro__
         if type(data) is self.instance_type:
             data_types = self.get_classobj_bases(data.__class__)+list(data_types)
-        for data_type in data_types:
-            if data_type in self.yaml_representers:
-                node = self.yaml_representers[data_type](self, data)
-                break
+        if data_types[0] in self.yaml_representers:
+            node = self.yaml_representers[data_types[0]](self, data)
         else:
-            if None in self.yaml_representers:
-                node = self.yaml_representers[None](self, data)
+            for data_type in data_types:
+                if data_type in self.yaml_multi_representers:
+                    node = self.yaml_multi_representers[data_type](self, data)
+                    break
             else:
-                node = ScalarNode(None, unicode(data))
+                if None in self.yaml_multi_representers:
+                    node = self.yaml_multi_representers[None](self, data)
+                elif None in self.yaml_representers:
+                    node = self.yaml_representers[None](self, data)
+                else:
+                    node = ScalarNode(None, unicode(data))
         if alias_key is not None:
             self.represented_objects[alias_key] = node
         return node
         cls.yaml_representers[data_type] = representer
     add_representer = classmethod(add_representer)
 
+    def add_multi_representer(cls, data_type, representer):
+        if not 'yaml_multi_representers' in cls.__dict__:
+            cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
+        cls.yaml_multi_representers[data_type] = representer
+    add_multi_representer = classmethod(add_multi_representer)
+
     def represent_scalar(self, tag, value, style=None):
         return ScalarNode(tag, value, style=style)
 
     def represent_sequence(self, tag, sequence, flow_style=None):
+        best_style = True
         value = []
         for item in sequence:
-            value.append(self.represent_object(item))
+            node_item = self.represent_data(item)
+            if not (isinstance(node_item, ScalarNode) and not node_item.style):
+                best_style = False
+            value.append(self.represent_data(item))
+        if flow_style is None:
+            flow_style = best_style
         return SequenceNode(tag, value, flow_style=flow_style)
 
     def represent_mapping(self, tag, mapping, flow_style=None):
+        best_style = True
         if hasattr(mapping, 'keys'):
             value = {}
             for item_key in mapping.keys():
                 item_value = mapping[item_key]
-                value[self.represent_object(item_key)] =    \
-                        self.represent_object(item_value)
+                node_key = self.represent_data(item_key)
+                node_value = self.represent_data(item_value)
+                if not (isinstance(node_key, ScalarNode) and not node_key.style):
+                    best_style = False
+                if not (isinstance(node_value, ScalarNode) and not node_value.style):
+                    best_style = False
+                value[node_key] = node_value
         else:
             value = []
             for item_key, item_value in mapping:
-                value.append((self.represent_object(item_key),
-                        self.represent_object(item_value)))
+                node_key = self.represent_data(item_key)
+                node_value = self.represent_data(item_value)
+                if not (isinstance(node_key, ScalarNode) and not node_key.style):
+                    best_style = False
+                if not (isinstance(node_value, ScalarNode) and not node_value.style):
+                    best_style = False
+                value.append((node_key, node_value))
+        if flow_style is None:
+            flow_style = best_style
         return MappingNode(tag, value, flow_style=flow_style)
 
     def ignore_aliases(self, data):
         SafeRepresenter.represent_undefined)
 
 class Representer(SafeRepresenter):
-    
+
     def represent_str(self, data):
         tag = None
         style = None
         return self.represent_scalar(
                 u'tag:yaml.org,2002:python/module:'+data.__name__, u'')
 
+    def represent_instance(self, data):
+        # For instances of classic classes, we use __getinitargs__ and
+        # __getstate__ to serialize the data.
+
+        # If data.__getinitargs__ exists, the object must be reconstructed by
+        # calling cls(**args), where args is a tuple returned by
+        # __getinitargs__. Otherwise, the cls.__init__ method should never be
+        # called and the class instance is created by instantiating a trivial
+        # class and assigning to the instance's __class__ variable.
+
+        # If data.__getstate__ exists, it returns the state of the object.
+        # Otherwise, the state of the object is data.__dict__.
+
+        # We produce either a !!python/object or !!python/object/new node.
+        # If data.__getinitargs__ does not exist and state is a dictionary, we
+        # produce a !!python/object node . Otherwise we produce a
+        # !!python/object/new node.
+
+        cls = data.__class__
+        class_name = u'%s.%s' % (cls.__module__, cls.__name__)
+        args = None
+        state = None
+        if hasattr(data, '__getinitargs__'):
+            args = list(data.__getinitargs__())
+        if hasattr(data, '__getstate__'):
+            state = data.__getstate__()
+        else:
+            state = data.__dict__
+        if args is None and isinstance(state, dict):
+            return self.represent_mapping(
+                    u'tag:yaml.org,2002:python/object:'+class_name, state)
+        if isinstance(state, dict) and not state:
+            return self.represent_sequence(
+                    u'tag:yaml.org,2002:python/object/new:'+class_name, args)
+        value = {}
+        if args:
+            value['args'] = args
+        value['state'] = state
+        return self.represent_mapping(
+                u'tag:yaml.org,2002:python/object/new:'+class_name, value)
+
+    def represent_object(self, data):
+        # We use __reduce__ API to save the data. data.__reduce__ returns
+        # a tuple of length 2-5:
+        #   (function, args, state, listitems, dictitems)
+
+        # For reconstructing, we calls function(*args), then set its state,
+        # listitems, and dictitems if they are not None.
+
+        # A special case is when function.__name__ == '__newobj__'. In this
+        # case we create the object with args[0].__new__(*args).
+
+        # Another special case is when __reduce__ returns a string - we don't
+        # support it.
+
+        # We produce a !!python/object, !!python/object/new or
+        # !!python/object/apply node.
+
+        cls = type(data)
+        if cls in copy_reg.dispatch_table:
+            reduce = copy_reg.dispatch_table[cls]
+        elif hasattr(data, '__reduce_ex__'):
+            reduce = data.__reduce_ex__(2)
+        elif hasattr(data, '__reduce__'):
+            reduce = data.__reduce__()
+        else:
+            raise RepresenterError("cannot represent object: %r" % data)
+        reduce = (list(reduce)+[None]*5)[:5]
+        function, args, state, listitems, dictitems = reduce
+        args = list(args)
+        if state is None:
+            state = {}
+        if listitems is not None:
+            listitems = list(listitems)
+        if dictitems is not None:
+            dictitems = dict(dictitems)
+        if function.__name__ == '__newobj__':
+            function = args[0]
+            args = args[1:]
+            tag = u'tag:yaml.org,2002:python/object/new:'
+            newobj = True
+        else:
+            tag = u'tag:yaml.org,2002:python/object/apply:'
+            newobj = False
+        function_name = u'%s.%s' % (function.__module__, function.__name__)
+        if not args and not listitems and not dictitems \
+                and isinstance(state, dict) and newobj:
+            return self.represent_mapping(
+                    u'tag:yaml.org,2002:python/object:'+function_name, state)
+        if not listitems and not dictitems  \
+                and isinstance(state, dict) and not state:
+            return self.represent_sequence(tag+function_name, args)
+        value = {}
+        if args:
+            value['args'] = args
+        if state or not isinstance(state, dict):
+            value['state'] = state
+        if listitems:
+            value['listitems'] = listitems
+        if dictitems:
+            value['dictitems'] = dictitems
+        return self.represent_mapping(tag+function_name, value)
+
 Representer.add_representer(str,
         Representer.represent_str)
 
 Representer.add_representer(Representer.module_type,
         Representer.represent_module)
 
+Representer.add_multi_representer(Representer.instance_type,
+        Representer.represent_instance)
+
+Representer.add_multi_representer(object,
+        Representer.represent_object)
+

tests/data/construct-python-object.code

+[
+AnObject(1, 'two', [3,3,3]),
+AnInstance(1, 'two', [3,3,3]),
+
+AnObject(1, 'two', [3,3,3]),
+AnInstance(1, 'two', [3,3,3]),
+
+AState(1, 'two', [3,3,3]),
+ACustomState(1, 'two', [3,3,3]),
+
+InitArgs(1, 'two', [3,3,3]),
+InitArgsWithState(1, 'two', [3,3,3]),
+
+NewArgs(1, 'two', [3,3,3]),
+NewArgsWithState(1, 'two', [3,3,3]),
+
+Reduce(1, 'two', [3,3,3]),
+ReduceWithState(1, 'two', [3,3,3]),
+
+MyInt(3),
+MyList(3),
+MyDict(3),
+]

tests/data/construct-python-object.data

+- !!python/object:test_constructor.AnObject { foo: 1, bar: two, baz: [3,3,3] }
+- !!python/object:test_constructor.AnInstance { foo: 1, bar: two, baz: [3,3,3] }
+
+- !!python/object/new:test_constructor.AnObject { args: [1, two], kwds: {baz: [3,3,3]} }
+- !!python/object/apply:test_constructor.AnInstance { args: [1, two], kwds: {baz: [3,3,3]} }
+
+- !!python/object:test_constructor.AState { _foo: 1, _bar: two, _baz: [3,3,3] }
+- !!python/object/new:test_constructor.ACustomState { state: !!python/tuple [1, two, [3,3,3]] }
+
+- !!python/object/new:test_constructor.InitArgs [1, two, [3,3,3]]
+- !!python/object/new:test_constructor.InitArgsWithState { args: [1, two], state: [3,3,3] }
+
+- !!python/object/new:test_constructor.NewArgs [1, two, [3,3,3]]
+- !!python/object/new:test_constructor.NewArgsWithState { args: [1, two], state: [3,3,3] }
+
+- !!python/object/apply:test_constructor.Reduce [1, two, [3,3,3]]
+- !!python/object/apply:test_constructor.ReduceWithState { args: [1, two], state: [3,3,3] }
+
+- !!python/object/new:test_constructor.MyInt [3]
+- !!python/object/new:test_constructor.MyList { listitems: [~, ~, ~] }
+- !!python/object/new:test_constructor.MyDict { dictitems: {0, 1, 2} }

tests/test_constructor.py

         else:
             return False
 
+class AnObject(object):
+
+    def __new__(cls, foo=None, bar=None, baz=None):
+        self = object.__new__(cls)
+        self.foo = foo
+        self.bar = bar
+        self.baz = baz
+        return self
+
+    def __cmp__(self, other):
+        return cmp((type(self), self.foo, self.bar, self.baz),
+                (type(other), other.foo, other.bar, other.baz))
+
+    def __eq__(self, other):
+        return type(self) is type(other) and    \
+                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
+
+class AnInstance:
+
+    def __init__(self, foo=None, bar=None, baz=None):
+        self.foo = foo
+        self.bar = bar
+        self.baz = baz
+
+    def __cmp__(self, other):
+        return cmp((type(self), self.foo, self.bar, self.baz),
+                (type(other), other.foo, other.bar, other.baz))
+
+    def __eq__(self, other):
+        return type(self) is type(other) and    \
+                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
+
+class AState(AnInstance):
+
+    def __getstate__(self):
+        return {
+            '_foo': self.foo,
+            '_bar': self.bar,
+            '_baz': self.baz,
+        }
+
+    def __setstate__(self, state):
+        self.foo = state['_foo']
+        self.bar = state['_bar']
+        self.baz = state['_baz']
+
+class ACustomState(AnInstance):
+
+    def __getstate__(self):
+        return (self.foo, self.bar, self.baz)
+
+    def __setstate__(self, state):
+        self.foo, self.bar, self.baz = state
+
+class InitArgs(AnInstance):
+
+    def __getinitargs__(self):
+        return (self.foo, self.bar, self.baz)
+
+    def __getstate__(self):
+        return {}
+
+class InitArgsWithState(AnInstance):
+
+    def __getinitargs__(self):
+        return (self.foo, self.bar)
+
+    def __getstate__(self):
+        return self.baz
+
+    def __setstate__(self, state):
+        self.baz = state
+
+class NewArgs(AnObject):
+
+    def __getnewargs__(self):
+        return (self.foo, self.bar, self.baz)
+
+    def __getstate__(self):
+        return {}
+
+class NewArgsWithState(AnObject):
+
+    def __getnewargs__(self):
+        return (self.foo, self.bar)
+
+    def __getstate__(self):
+        return self.baz
+
+    def __setstate__(self, state):
+        self.baz = state
+
+class Reduce(AnObject):
+
+    def __reduce__(self):
+        return self.__class__, (self.foo, self.bar, self.baz)
+
+class ReduceWithState(AnObject):
+
+    def __reduce__(self):
+        return self.__class__, (self.foo, self.bar), self.baz
+
+    def __setstate__(self, state):
+        self.baz = state
+
+class MyInt(int):
+
+    def __eq__(self, other):
+        return type(self) is type(other) and int(self) == int(other)
+
+class MyList(list):
+
+    def __init__(self, n=1):
+        self.extend([None]*n)
+
+    def __eq__(self, other):
+        return type(self) is type(other) and list(self) == list(other)
+
+class MyDict(dict):
+
+    def __init__(self, n=1):
+        for k in range(n):
+            self[k] = None
+
+    def __eq__(self, other):
+        return type(self) is type(other) and dict(self) == dict(other)
+
 class TestConstructorTypes(test_appliance.TestAppliance):
 
     def _testTypes(self, test_name, data_filename, code_filename):

tests/test_representer.py

                     data2 = data2.items()
                     data2.sort()
                     data2 = repr(data2)
-                if data1 != data2:
+                    if data1 != data2:
+                        raise
+                elif isinstance(data1, list):
+                    self.failUnlessEqual(type(data1), type(data2))
+                    self.failUnlessEqual(len(data1), len(data2))
+                    for item1, item2 in zip(data1, data2):
+                        self.failUnlessEqual(item1, item2)
+                else:
                     raise
         except:
             print
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.