Commits

Marcin Kuzminski committed 18ae6ad

Don't monkeypatch importend json modules, use imp() to make a copy
- added test of custom encoder

  • Participants
  • Parent commits fa3d620

Comments (0)

Files changed (2)

 import datetime
 import functools
 import decimal
+import imp
 
-__all__ = ['json', 'simplejson', 'stdjson']
+__all__ = ['json', 'simplejson', 'stdlibjson']
 
 
 def _is_aware(value):
 # Import simplejson
 try:
     # import simplejson initially
-    import simplejson as _sj
+    _sj = imp.load_module('_sj', *imp.find_module('simplejson'))
 
     def extended_encode(obj):
         try:
         except NotImplementedError:
             pass
         raise TypeError("%r is not JSON serializable" % (obj,))
-    # we handle decimals our own it makes unified behavior of json vs 
+    # we handle decimals our own it makes unified behavior of json vs
     # simplejson
-    _sj.dumps = functools.partial(_sj.dumps, default=extended_encode,
-                                  use_decimal=False)
-    _sj.dump = functools.partial(_sj.dump, default=extended_encode,
-                                 use_decimal=False)
+    sj_version = [int(x) for x in _sj.__version__.split('.')]
+    major, minor = sj_version[0], sj_version[1]
+    if major < 2 or (major == 2 and minor < 1):
+        # simplejson < 2.1 doesnt support use_decimal
+        _sj.dumps = functools.partial(_sj.dumps,
+                                             default=extended_encode)
+        _sj.dump = functools.partial(_sj.dump,
+                                            default=extended_encode)
+    else:
+        _sj.dumps = functools.partial(_sj.dumps,
+                                             default=extended_encode,
+                                             use_decimal=False)
+        _sj.dump = functools.partial(_sj.dump,
+                                            default=extended_encode,
+                                            use_decimal=False)
     simplejson = _sj
 
 except ImportError:
     # no simplejson set it to None
-    _sj = None
+    simplejson = None
 
 
-# simplejson not found try out regular json module
-import json as _json
+try:
+    # simplejson not found try out regular json module
+    _json = imp.load_module('_json', *imp.find_module('json'))
 
+    # extended JSON encoder for json
+    class ExtendedEncoder(_json.JSONEncoder):
+        def default(self, obj):
+            try:
+                return _obj_dump(obj)
+            except NotImplementedError:
+                pass
+            raise TypeError("%r is not JSON serializable" % (obj,))
+    # monkey-patch JSON encoder to use extended version
+    _json.dumps = functools.partial(_json.dumps, cls=ExtendedEncoder)
+    _json.dump = functools.partial(_json.dump, cls=ExtendedEncoder)
 
-# extended JSON encoder for json
-class ExtendedEncoder(_json.JSONEncoder):
-    def default(self, obj):
-        try:
-            return _obj_dump(obj)
-        except NotImplementedError:
-            pass
-        return _json.JSONEncoder.default(self, obj)
-# monkey-patch JSON encoder to use extended version
-_json.dumps = functools.partial(_json.dumps, cls=ExtendedEncoder)
-_json.dump = functools.partial(_json.dump, cls=ExtendedEncoder)
-stdlib = _json
+except ImportError:
+    json = None
+
+stdlibjson = _json
 
 # set all available json modules
-simplejson = _sj
-stdjson = _json
-json = _sj if _sj else _json
+if _sj:
+    json = _sj
+elif json:
+    json = _json
+else:
+    raise ImportError('Could not find any json modules')
 class TestJSONEncoder(unittest.TestCase):
 
     def setUp(self):
-        from ext_json import stdjson
-        self.json = stdjson
+        from ext_json import stdlibjson
+        self.json = stdlibjson
 
 
 class TestSIMPLEJSONEncoder(unittest.TestCase):
         from ext_json import simplejson
         self.json = simplejson
 
+
+class Promise(object):
+    def __init__(self, title):
+        self.title = title
+
+    def __repr__(self):
+        return 'ImPromise'
+
+
+class TestOverrideEncoder(unittest.TestCase):
+
+    def test_override_json(self):
+        import simplejson
+
+        class LazyEncoder(simplejson.JSONEncoder):
+            """Encodes django's lazy i18n strings.
+            """
+            def default(self, obj):
+                if isinstance(obj, Promise):
+                    return unicode(obj)
+                return obj
+
+        result = simplejson.dumps({
+            "html": '<span></span>',
+            "message": Promise(u"Data has been saved."),
+        }, cls=LazyEncoder)
+
+        assert result == '{"message": "ImPromise", "html": "<span></span>"}'
+
+    def test_override_simplejson(self):
+        import json
+
+        def default_enc(obj):
+            if isinstance(obj, Promise):
+                return unicode(obj)
+            return obj
+
+        result = json.dumps({
+            "html": '<span></span>',
+            "message": Promise(u"Data has been saved."),
+        }, default=default_enc)
+
+        assert result == '{"message": "ImPromise", "html": "<span></span>"}'
+
 for name, orgval, jsonval in test_cases:
     setattr(TestJSONEncoder, "test_JSON_%s" % name,
             _make_dump_test(name, orgval, jsonval))