Kirill Simonov avatar Kirill Simonov committed 77c9072

Add constructors for some simple python types.

Comments (0)

Files changed (3)

lib/yaml/constructor.py

 except NameError:
     from sets import Set as set
 
-import binascii, re
+import binascii, re, sys
 
 class ConstructorError(MarkedYAMLError):
     pass
                     tag_suffix = node.tag[len(tag_prefix):]
                     constructor = lambda node:  \
                             self.yaml_multi_constructors[tag_prefix](self, tag_suffix, node)
-                break
+                    break
             else:
                 if None in self.yaml_multi_constructors:
                     constructor = lambda node:  \
                     constructor = self.construct_sequence
                 elif isinstance(node, MappingNode):
                     constructor = self.construct_mapping
+                else:
+                    print node.tag
         data = constructor(node)
         self.constructed_objects[node] = data
         return data
         return self.construct_mapping(node)
 
     def construct_yaml_object(self, node, cls):
-        mapping = self.construct_mapping(node)
-        state = {}
-        for key in mapping:
-            state[key.replace('-', '_')] = mapping[key]
+        state = self.construct_mapping(node)
         data = cls.__new__(cls)
         if hasattr(data, '__setstate__'):
-            data.__setstate__(mapping)
+            data.__setstate__(state)
         else:
-            data.__dict__.update(mapping)
+            data.__dict__.update(state)
         return data
 
     def construct_undefined(self, node):
         SafeConstructor.construct_undefined)
 
 class Constructor(SafeConstructor):
-    pass
 
+    def construct_python_str(self, node):
+        return self.construct_scalar(node).encode('utf-8')
+
+    def construct_python_unicode(self, node):
+        return self.construct_scalar(node)
+
+    def construct_python_long(self, node):
+        return long(self.construct_yaml_int(node))
+
+    def construct_python_complex(self, node):
+       return complex(self.construct_scalar(node))
+
+    def construct_python_tuple(self, node):
+        return tuple(self.construct_yaml_seq(node))
+
+    def find_python_module(self, name, mark):
+        if not name:
+            raise ConstructorError("while constructing a Python module", mark,
+                    "expected non-empty name appended to the tag", mark)
+        try:
+            __import__(name)
+        except ImportError, exc:
+            raise ConstructorError("while constructing a Python module", mark,
+                    "cannot find module %r (%s)" % (name.encode('utf-8'), exc), mark)
+        return sys.modules[name]
+
+    def find_python_name(self, name, mark):
+        if not name:
+            raise ConstructorError("while constructing a Python object", mark,
+                    "expected non-empty name appended to the tag", mark)
+        if u'.' in name:
+            module_name, object_name = name.rsplit('.', 1)
+        else:
+            module_name = '__builtin__'
+            object_name = name
+        try:
+            __import__(module_name)
+        except ImportError, exc:
+            raise ConstructorError("while constructing a Python object", mark,
+                    "cannot find module %r (%s)" % (module_name.encode('utf-8'), exc), mark)
+        module = sys.modules[module_name]
+        if not hasattr(module, object_name):
+            raise ConstructorError("while constructing a Python object", mark,
+                    "cannot find %r in the module %r" % (object_name.encode('utf-8'),
+                        module.__name__), mark)
+        return getattr(module, object_name)
+
+    def construct_python_name(self, suffix, node):
+        value = self.construct_scalar(node)
+        if value:
+            raise ConstructorError("while constructing a Python name", node.start_mark,
+                    "expected the empty value, but found %r" % value.encode('utf-8'),
+                    node.start_mark)
+        return self.find_python_name(suffix, node.start_mark)
+
+    def construct_python_module(self, suffix, node):
+        value = self.construct_scalar(node)
+        if value:
+            raise ConstructorError("while constructing a Python module", node.start_mark,
+                    "expected the empty value, but found %r" % value.encode('utf-8'),
+                    node.start_mark)
+        return self.find_python_module(suffix, node.start_mark)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/none',
+    Constructor.construct_yaml_null)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/bool',
+    Constructor.construct_yaml_bool)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/str',
+    Constructor.construct_python_str)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/unicode',
+    Constructor.construct_python_unicode)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/int',
+    Constructor.construct_yaml_int)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/long',
+    Constructor.construct_python_long)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/float',
+    Constructor.construct_yaml_float)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/complex',
+    Constructor.construct_python_complex)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/list',
+    Constructor.construct_yaml_seq)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/tuple',
+    Constructor.construct_python_tuple)
+
+Constructor.add_constructor(
+    u'tag:yaml.org,2002:python/dict',
+    Constructor.construct_yaml_map)
+
+Constructor.add_multi_constructor(
+    u'tag:yaml.org,2002:python/name:',
+    Constructor.construct_python_name)
+
+Constructor.add_multi_constructor(
+    u'tag:yaml.org,2002:python/module:',
+    Constructor.construct_python_module)
+

lib/yaml/representer.py

 except NameError:
     from sets import Set as set
 
+import sys
+
 class RepresenterError(YAMLError):
     pass
 
         self.serialize(node)
         self.represented_objects = {}
 
+    class C: pass
+    c = C()
+    def f(): pass
+    classobj_type = type(C)
+    instance_type = type(c)
+    function_type = type(f)
+    builtin_function_type = type(abs)
+    module_type = type(sys)
+    del C, c, f
+
+    def get_classobj_bases(self, cls):
+        bases = [cls]
+        for base in cls.__bases__:
+            bases.extend(self.get_classobj_bases(base))
+        return bases
+
     def represent_object(self, data):
         if self.ignore_aliases(data):
             alias_key = None
                     raise RepresenterError("recursive objects are not allowed: %r" % data)
                 return node
             self.represented_objects[alias_key] = None
-        for data_type in type(data).__mro__:
+        data_types = type(data).__mro__
+        if type(data) is self.instance_type:
+            data_types = self.get_classobj_bases(data.__class__)+data_types
+        for data_type in data_types:
             if data_type in self.yaml_representers:
                 node = self.yaml_representers[data_type](self, data)
                 break
         return SequenceNode(tag, value, flow_style=flow_style)
 
     def represent_mapping(self, tag, mapping, flow_style=None):
-        value = {}
         if hasattr(mapping, 'keys'):
+            value = {}
             for item_key in mapping.keys():
                 item_value = mapping[item_key]
                 value[self.represent_object(item_key)] =    \
                         self.represent_object(item_value)
         else:
+            value = []
             for item_key, item_value in mapping:
-                value[self.represent_object(item_key)] =    \
-                        self.represent_object(item_value)
+                value.append((self.represent_object(item_key),
+                        self.represent_object(item_value)))
         return MappingNode(tag, value, flow_style=flow_style)
 
     def ignore_aliases(self, data):
                 u'null')
 
     def represent_str(self, data):
-        encoding = None
+        tag = None
+        style = None
         try:
-            unicode(data, 'ascii')
-            encoding = 'ascii'
+            data = unicode(data, 'ascii')
+            tag = u'tag:yaml.org,2002:str'
         except UnicodeDecodeError:
             try:
-                unicode(data, 'utf-8')
-                encoding = 'utf-8'
+                data = unicode(data, 'utf-8')
+                tag = u'tag:yaml.org,2002:str'
             except UnicodeDecodeError:
-                pass
-        if encoding:
-            return self.represent_scalar(u'tag:yaml.org,2002:str',
-                    unicode(data, encoding))
-        else:
-            return self.represent_scalar(u'tag:yaml.org,2002:binary',
-                    unicode(data.encode('base64')), style='|')
+                data = data.encode('base64')
+                tag = u'tag:yaml.org,2002:binary'
+                style = '|'
+        return self.represent_scalar(tag, data, style=style)
 
     def represent_unicode(self, data):
         return self.represent_scalar(u'tag:yaml.org,2002:str', data)
         elif data == self.nan_value or data != data:
             value = u'.nan'
         else:
-            value = unicode(data)
+            value = unicode(repr(data))
         return self.represent_scalar(u'tag:yaml.org,2002:float', value)
 
     def represent_list(self, data):
-        pairs = (len(data) > 0)
-        for item in data:
-            if not isinstance(item, tuple) or len(item) != 2:
-                pairs = False
-                break
+        pairs = (len(data) > 0 and isinstance(data, list))
+        if pairs:
+            for item in data:
+                if not isinstance(item, tuple) or len(item) != 2:
+                    pairs = False
+                    break
         if not pairs:
             return self.represent_sequence(u'tag:yaml.org,2002:seq', data)
         value = []
             state = data.__getstate__()
         else:
             state = data.__dict__.copy()
-        mapping = state
-        if hasattr(state, 'keys'):
-            mapping = []
-            keys = state.keys()
-            keys.sort()
-            for key in keys:
-                mapping.append((key.replace('_', '-'), state[key]))
-        return self.represent_mapping(tag, mapping, flow_style=flow_style)
+        return self.represent_mapping(tag, state, flow_style=flow_style)
 
     def represent_undefined(self, data):
         raise RepresenterError("cannot represent an object: %s" % data)
 SafeRepresenter.add_representer(list,
         SafeRepresenter.represent_list)
 
+SafeRepresenter.add_representer(tuple,
+        SafeRepresenter.represent_list)
+
 SafeRepresenter.add_representer(dict,
         SafeRepresenter.represent_dict)
 
         SafeRepresenter.represent_undefined)
 
 class Representer(SafeRepresenter):
-    pass
+    
+    def represent_str(self, data):
+        tag = None
+        style = None
+        try:
+            data = unicode(data, 'ascii')
+            tag = u'tag:yaml.org,2002:str'
+        except UnicodeDecodeError:
+            try:
+                data = unicode(data, 'utf-8')
+                tag = u'tag:yaml.org,2002:python/str'
+            except UnicodeDecodeError:
+                data = data.encode('base64')
+                tag = u'tag:yaml.org,2002:binary'
+                style = '|'
+        return self.represent_scalar(tag, data, style=style)
 
+    def represent_unicode(self, data):
+        tag = None
+        try:
+            data.encode('ascii')
+            tag = u'tag:yaml.org,2002:python/unicode'
+        except UnicodeEncodeError:
+            tag = u'tag:yaml.org,2002:str'
+        return self.represent_scalar(tag, data)
+
+    def represent_long(self, data):
+        tag = u'tag:yaml.org,2002:int'
+        if int(data) is not data:
+            tag = u'tag:yaml.org,2002:python/long'
+        return self.represent_scalar(tag, unicode(data))
+
+    def represent_complex(self, data):
+        if data.real != 0.0:
+            data = u'%r+%rj' % (data.real, data.imag)
+        else:
+            data = u'%rj' % data.imag
+        return self.represent_scalar(u'tag:yaml.org,2002:python/complex', data)
+
+    def represent_tuple(self, data):
+        return self.represent_sequence(u'tag:yaml.org,2002:python/tuple', data)
+
+    def represent_name(self, data):
+        name = u'%s.%s' % (data.__module__, data.__name__)
+        return self.represent_scalar(u'tag:yaml.org,2002:python/name:'+name, u'')
+
+    def represent_module(self, data):
+        return self.represent_scalar(
+                u'tag:yaml.org,2002:python/module:'+data.__name__, u'')
+
+Representer.add_representer(str,
+        Representer.represent_str)
+
+Representer.add_representer(unicode,
+        Representer.represent_unicode)
+
+Representer.add_representer(long,
+        Representer.represent_long)
+
+Representer.add_representer(complex,
+        Representer.represent_complex)
+
+Representer.add_representer(tuple,
+        Representer.represent_tuple)
+
+Representer.add_representer(type,
+        Representer.represent_name)
+
+Representer.add_representer(Representer.classobj_type,
+        Representer.represent_name)
+
+Representer.add_representer(Representer.function_type,
+        Representer.represent_name)
+
+Representer.add_representer(Representer.builtin_function_type,
+        Representer.represent_name)
+
+Representer.add_representer(Representer.module_type,
+        Representer.represent_module)
+

lib/yaml/serializer.py

                 for item in node.value:
                     self.anchor_node(item)
             elif isinstance(node, MappingNode):
-                for key in node.value:
-                    self.anchor_node(key)
-                    self.anchor_node(node.value[key])
+                if hasattr(node.value, 'keys'):
+                    for key in node.value.keys():
+                        self.anchor_node(key)
+                        self.anchor_node(node.value[key])
+                else:
+                    for key, value in node.value:
+                        self.anchor_node(key)
+                        self.anchor_node(value)
 
     def generate_anchor(self, node):
         self.last_anchor_id += 1
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.