Commits

Mike Bayer committed ca1159d Merge

Merge branch 'master' into rel_0_9

  • Participants
  • Parent commits d4ee945, fec03c8

Comments (0)

Files changed (4)

File lib/sqlalchemy/dialects/postgresql/base.py

 
     _backslash_escapes = True
 
-    def __init__(self, isolation_level=None, **kwargs):
+    def __init__(self, isolation_level=None, json_serializer=None,
+                    json_deserializer=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
         self.isolation_level = isolation_level
+        self._json_deserializer = json_deserializer
+        self._json_serializer = json_serializer
 
     def initialize(self, connection):
         super(PGDialect, self).initialize(connection)

File lib/sqlalchemy/dialects/postgresql/json.py

     will be detected by the unit of work.  See the example at :class:`.HSTORE`
     for a simple example involving a dictionary.
 
+    Custom serializers and deserializers are specified at the dialect level,
+    that is using :func:`.create_engine`.  The reason for this is that when
+    using psycopg2, the DBAPI only allows serializers at the per-cursor
+    or per-connection level.   E.g.::
+
+        engine = create_engine("postgresql://scott:tiger@localhost/test",
+                                json_serializer=my_serialize_fn,
+                                json_deserializer=my_deserialize_fn
+                        )
+
+    When using the psycopg2 dialect, the json_deserializer is registered
+    against the database using ``psycopg2.extras.register_default_json``.
+
     .. versionadded:: 0.9
 
     """
 
     __visit_name__ = 'JSON'
 
-    def __init__(self, json_serializer=None, json_deserializer=None):
-        if json_serializer:
-            self.json_serializer = json_serializer
-        else:
-            self.json_serializer = json.dumps
-        if json_deserializer:
-            self.json_deserializer = json_deserializer
-        else:
-            self.json_deserializer = json.loads
-
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for :class:`.JSON`."""
 
                 _adapt_expression(self, op, other_comparator)
 
     def bind_processor(self, dialect):
+        json_serializer = dialect._json_serializer or json.dumps
         if util.py2k:
             encoding = dialect.encoding
             def process(value):
-                return self.json_serializer(value).encode(encoding)
+                return json_serializer(value).encode(encoding)
         else:
             def process(value):
-                return self.json_serializer(value)
+                return json_serializer(value)
         return process
 
     def result_processor(self, dialect, coltype):
+        json_deserializer = dialect._json_deserializer or json.loads
         if util.py2k:
             encoding = dialect.encoding
             def process(value):
-                return self.json_deserializer(value.decode(encoding))
+                return json_deserializer(value.decode(encoding))
         else:
             def process(value):
-                return self.json_deserializer(value)
+                return json_deserializer(value)
         return process
 
 

File lib/sqlalchemy/dialects/postgresql/psycopg2.py

                                         array_oid=array_oid)
             fns.append(on_connect)
 
+        if self.dbapi and self._json_deserializer:
+            def on_connect(conn):
+                extras.register_default_json(conn, loads=self._json_deserializer)
+            fns.append(on_connect)
+
         if fns:
             def on_connect(conn):
                 for fn in fns:

File test/dialect/postgresql/test_types.py

         )
 
     def test_bind_serialize_default(self):
-        from sqlalchemy.engine import default
 
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
         eq_(
             proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
         )
 
     def test_bind_serialize_with_slashes_and_quotes(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
         eq_(
             proc({'\\"a': '\\"1'}),
         )
 
     def test_parse_error(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_result_processor(
                     dialect, None)
         assert_raises_message(
         )
 
     def test_result_deserialize_default(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_result_processor(
                     dialect, None)
         eq_(
         )
 
     def test_result_deserialize_with_slashes_and_quotes(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_result_processor(
                     dialect, None)
         eq_(
         )
 
     def test_bind_serialize_default(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
         eq_(
             proc({"A": [1, 2, 3, True, False]}),
         )
 
     def test_result_deserialize_default(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.test_column.type._cached_result_processor(
                     dialect, None)
         eq_(
         )
         self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
 
-    def _non_native_engine(self):
+    def _non_native_engine(self, json_serializer=None, json_deserializer=None):
+        if json_serializer is not None or json_deserializer is not None:
+            options = {
+                "json_serializer": json_serializer,
+                "json_deserializer": json_deserializer
+            }
+        else:
+            options = {}
+
         if testing.against("postgresql+psycopg2"):
             from psycopg2.extras import register_default_json
-            engine = engines.testing_engine()
+            engine = engines.testing_engine(options=options)
             @event.listens_for(engine, "connect")
             def connect(dbapi_connection, connection_record):
                 engine.dialect._has_native_json = False
                 def pass_(value):
                     return value
                 register_default_json(dbapi_connection, loads=pass_)
+        elif options:
+            engine = engines.testing_engine(options=options)
         else:
             engine = testing.db
         engine.connect()
         engine = self._non_native_engine()
         self._test_insert(engine)
 
+
+    def _test_custom_serialize_deserialize(self, native):
+        import json
+        def loads(value):
+            value = json.loads(value)
+            value['x'] = value['x'] + '_loads'
+            return value
+
+        def dumps(value):
+            value = dict(value)
+            value['x'] = 'dumps_y'
+            return json.dumps(value)
+
+        if native:
+            engine = engines.testing_engine(options=dict(
+                            json_serializer=dumps,
+                            json_deserializer=loads
+                        ))
+        else:
+            engine = self._non_native_engine(
+                            json_serializer=dumps,
+                            json_deserializer=loads
+                        )
+
+        s = select([
+                cast(
+                    {
+                        "key": "value",
+                        "x": "q"
+                    },
+                    JSON
+                )
+            ])
+        eq_(
+            engine.scalar(s),
+            {
+                "key": "value",
+                "x": "dumps_y_loads"
+            },
+        )
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_custom_native(self):
+        self._test_custom_serialize_deserialize(True)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_custom_python(self):
+        self._test_custom_serialize_deserialize(False)
+
+
     @testing.only_on("postgresql+psycopg2")
     def test_criterion_native(self):
         engine = testing.db