Commits

Antoine Pitrou  committed f7a9a10

Issue #14166: Pickler objects now have an optional `dispatch_table` attribute which allows to set custom per-pickler reduction functions.
Patch by sbt.

  • Participants
  • Parent commits 7531b5b

Comments (0)

Files changed (7)

File Doc/library/copyreg.rst

    returned by *function* at pickling time.  :exc:`TypeError` will be raised if
    *object* is a class or *constructor* is not callable.
 
-   See the :mod:`pickle` module for more details on the interface expected of
-   *function* and *constructor*.
-
+   See the :mod:`pickle` module for more details on the interface
+   expected of *function* and *constructor*.  Note that the
+   :attr:`~pickle.Pickler.dispatch_table` attribute of a pickler
+   object or subclass of :class:`pickle.Pickler` can also be used for
+   declaring reduction functions.

File Doc/library/pickle.rst

 
       See :ref:`pickle-persistent` for details and examples of uses.
 
+   .. attribute:: dispatch_table
+
+      A pickler object's dispatch table is a registry of *reduction
+      functions* of the kind which can be declared using
+      :func:`copyreg.pickle`.  It is a mapping whose keys are classes
+      and whose values are reduction functions.  A reduction function
+      takes a single argument of the associated class and should
+      conform to the same interface as a :meth:`~object.__reduce__`
+      method.
+
+      By default, a pickler object will not have a
+      :attr:`dispatch_table` attribute, and it will instead use the
+      global dispatch table managed by the :mod:`copyreg` module.
+      However, to customize the pickling for a specific pickler object
+      one can set the :attr:`dispatch_table` attribute to a dict-like
+      object.  Alternatively, if a subclass of :class:`Pickler` has a
+      :attr:`dispatch_table` attribute then this will be used as the
+      default dispatch table for instances of that class.
+
+      See :ref:`pickle-dispatch` for usage examples.
+
+      .. versionadded:: 3.3
+
    .. attribute:: fast
 
       Deprecated. Enable fast mode if set to a true value.  The fast mode
 
 .. literalinclude:: ../includes/dbpickle.py
 
+.. _pickle-dispatch:
+
+Dispatch Tables
+^^^^^^^^^^^^^^^
+
+If one wants to customize pickling of some classes without disturbing
+any other code which depends on pickling, then one can create a
+pickler with a private dispatch table.
+
+The global dispatch table managed by the :mod:`copyreg` module is
+available as :data:`copyreg.dispatch_table`.  Therefore, one may
+choose to use a modified copy of :data:`copyreg.dispatch_table` as a
+private dispatch table.
+
+For example ::
+
+   f = io.BytesIO()
+   p = pickle.Pickler(f)
+   p.dispatch_table = copyreg.dispatch_table.copy()
+   p.dispatch_table[SomeClass] = reduce_SomeClass
+
+creates an instance of :class:`pickle.Pickler` with a private dispatch
+table which handles the ``SomeClass`` class specially.  Alternatively,
+the code ::
+
+   class MyPickler(pickle.Pickler):
+       dispatch_table = copyreg.dispatch_table.copy()
+       dispatch_table[SomeClass] = reduce_SomeClass
+   f = io.BytesIO()
+   p = MyPickler(f)
+
+does the same, but all instances of ``MyPickler`` will by default
+share the same dispatch table.  The equivalent code using the
+:mod:`copyreg` module is ::
+
+   copyreg.pickle(SomeClass, reduce_SomeClass)
+   f = io.BytesIO()
+   p = pickle.Pickler(f)
 
 .. _pickle-state:
 

File Lib/pickle.py

             f(self, obj) # Call unbound method with explicit self
             return
 
-        # Check copyreg.dispatch_table
-        reduce = dispatch_table.get(t)
+        # Check private dispatch table if any, or else copyreg.dispatch_table
+        reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
         if reduce:
             rv = reduce(obj)
         else:

File Lib/test/pickletester.py

                 self.assertEqual(unpickler.load(), data)
 
 
+# Tests for dispatch_table attribute
+
+REDUCE_A = 'reduce_A'
+
+class AAA(object):
+    def __reduce__(self):
+        return str, (REDUCE_A,)
+
+class BBB(object):
+    pass
+
+class AbstractDispatchTableTests(unittest.TestCase):
+
+    def test_default_dispatch_table(self):
+        # No dispatch_table attribute by default
+        f = io.BytesIO()
+        p = self.pickler_class(f, 0)
+        with self.assertRaises(AttributeError):
+            p.dispatch_table
+        self.assertFalse(hasattr(p, 'dispatch_table'))
+
+    def test_class_dispatch_table(self):
+        # A dispatch_table attribute can be specified class-wide
+        dt = self.get_dispatch_table()
+
+        class MyPickler(self.pickler_class):
+            dispatch_table = dt
+
+        def dumps(obj, protocol=None):
+            f = io.BytesIO()
+            p = MyPickler(f, protocol)
+            self.assertEqual(p.dispatch_table, dt)
+            p.dump(obj)
+            return f.getvalue()
+
+        self._test_dispatch_table(dumps, dt)
+
+    def test_instance_dispatch_table(self):
+        # A dispatch_table attribute can also be specified instance-wide
+        dt = self.get_dispatch_table()
+
+        def dumps(obj, protocol=None):
+            f = io.BytesIO()
+            p = self.pickler_class(f, protocol)
+            p.dispatch_table = dt
+            self.assertEqual(p.dispatch_table, dt)
+            p.dump(obj)
+            return f.getvalue()
+
+        self._test_dispatch_table(dumps, dt)
+
+    def _test_dispatch_table(self, dumps, dispatch_table):
+        def custom_load_dump(obj):
+            return pickle.loads(dumps(obj, 0))
+
+        def default_load_dump(obj):
+            return pickle.loads(pickle.dumps(obj, 0))
+
+        # pickling complex numbers using protocol 0 relies on copyreg
+        # so check pickling a complex number still works
+        z = 1 + 2j
+        self.assertEqual(custom_load_dump(z), z)
+        self.assertEqual(default_load_dump(z), z)
+
+        # modify pickling of complex
+        REDUCE_1 = 'reduce_1'
+        def reduce_1(obj):
+            return str, (REDUCE_1,)
+        dispatch_table[complex] = reduce_1
+        self.assertEqual(custom_load_dump(z), REDUCE_1)
+        self.assertEqual(default_load_dump(z), z)
+
+        # check picklability of AAA and BBB
+        a = AAA()
+        b = BBB()
+        self.assertEqual(custom_load_dump(a), REDUCE_A)
+        self.assertIsInstance(custom_load_dump(b), BBB)
+        self.assertEqual(default_load_dump(a), REDUCE_A)
+        self.assertIsInstance(default_load_dump(b), BBB)
+
+        # modify pickling of BBB
+        dispatch_table[BBB] = reduce_1
+        self.assertEqual(custom_load_dump(a), REDUCE_A)
+        self.assertEqual(custom_load_dump(b), REDUCE_1)
+        self.assertEqual(default_load_dump(a), REDUCE_A)
+        self.assertIsInstance(default_load_dump(b), BBB)
+
+        # revert pickling of BBB and modify pickling of AAA
+        REDUCE_2 = 'reduce_2'
+        def reduce_2(obj):
+            return str, (REDUCE_2,)
+        dispatch_table[AAA] = reduce_2
+        del dispatch_table[BBB]
+        self.assertEqual(custom_load_dump(a), REDUCE_2)
+        self.assertIsInstance(custom_load_dump(b), BBB)
+        self.assertEqual(default_load_dump(a), REDUCE_A)
+        self.assertIsInstance(default_load_dump(b), BBB)
+
+
 if __name__ == "__main__":
     # Print some stuff that can be used to rewrite DATA{0,1,2}
     from pickletools import dis

File Lib/test/test_pickle.py

 import pickle
 import io
+import collections
 
 from test import support
 
 from test.pickletester import AbstractPickleModuleTests
 from test.pickletester import AbstractPersistentPicklerTests
 from test.pickletester import AbstractPicklerUnpicklerObjectTests
+from test.pickletester import AbstractDispatchTableTests
 from test.pickletester import BigmemPickleTests
 
 try:
     unpickler_class = pickle._Unpickler
 
 
+class PyDispatchTableTests(AbstractDispatchTableTests):
+    pickler_class = pickle._Pickler
+    def get_dispatch_table(self):
+        return pickle.dispatch_table.copy()
+
+
+class PyChainDispatchTableTests(AbstractDispatchTableTests):
+    pickler_class = pickle._Pickler
+    def get_dispatch_table(self):
+        return collections.ChainMap({}, pickle.dispatch_table)
+
+
 if has_c_implementation:
     class CPicklerTests(PyPicklerTests):
         pickler = _pickle.Pickler
         pickler_class = _pickle.Pickler
         unpickler_class = _pickle.Unpickler
 
+    class CDispatchTableTests(AbstractDispatchTableTests):
+        pickler_class = pickle.Pickler
+        def get_dispatch_table(self):
+            return pickle.dispatch_table.copy()
+
+    class CChainDispatchTableTests(AbstractDispatchTableTests):
+        pickler_class = pickle.Pickler
+        def get_dispatch_table(self):
+            return collections.ChainMap({}, pickle.dispatch_table)
+
 
 def test_main():
-    tests = [PickleTests, PyPicklerTests, PyPersPicklerTests]
+    tests = [PickleTests, PyPicklerTests, PyPersPicklerTests,
+             PyDispatchTableTests, PyChainDispatchTableTests]
     if has_c_implementation:
         tests.extend([CPicklerTests, CPersPicklerTests,
                       CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                       PyPicklerUnpicklerObjectTests,
                       CPicklerUnpicklerObjectTests,
+                      CDispatchTableTests, CChainDispatchTableTests,
                       InMemoryPickleTests])
     support.run_unittest(*tests)
     support.run_doctest(pickle)
 Library
 -------
 
+- Issue #14166: Pickler objects now have an optional ``dispatch_table``
+  attribute which allows to set custom per-pickler reduction functions.
+  Patch by sbt.
+
 - Issue #14177: marshal.loads() now raises TypeError when given an unicode
   string.  Patch by Guilherme Gonçalves.
 

File Modules/_pickle.c

                                    objects to support self-referential objects
                                    pickling. */
     PyObject *pers_func;        /* persistent_id() method, can be NULL */
+    PyObject *dispatch_table;   /* private dispatch_table, can be NULL */
     PyObject *arg;
 
     PyObject *write;            /* write() method of the output stream. */
         return NULL;
 
     self->pers_func = NULL;
+    self->dispatch_table = NULL;
     self->arg = NULL;
     self->write = NULL;
     self->proto = 0;
     /* XXX: This part needs some unit tests. */
 
     /* Get a reduction callable, and call it.  This may come from
-     * copyreg.dispatch_table, the object's __reduce_ex__ method,
-     * or the object's __reduce__ method.
+     * self.dispatch_table, copyreg.dispatch_table, the object's
+     * __reduce_ex__ method, or the object's __reduce__ method.
      */
-    reduce_func = PyDict_GetItem(dispatch_table, (PyObject *)type);
+    if (self->dispatch_table == NULL) {
+        reduce_func = PyDict_GetItem(dispatch_table, (PyObject *)type);
+        /* PyDict_GetItem() unlike PyObject_GetItem() and
+           PyObject_GetAttr() returns a borrowed ref */
+        Py_XINCREF(reduce_func);
+    } else {
+        reduce_func = PyObject_GetItem(self->dispatch_table, (PyObject *)type);
+        if (reduce_func == NULL) {
+            if (PyErr_ExceptionMatches(PyExc_KeyError))
+                PyErr_Clear();
+            else
+                goto error;
+        }
+    }
     if (reduce_func != NULL) {
-        /* Here, the reference count of the reduce_func object returned by
-           PyDict_GetItem needs to be increased to be consistent with the one
-           returned by PyObject_GetAttr. This is allow us to blindly DECREF
-           reduce_func at the end of the save() routine.
-        */
-        Py_INCREF(reduce_func);
         Py_INCREF(obj);
         reduce_value = _Pickler_FastCall(self, reduce_func, obj);
     }
     Py_XDECREF(self->output_buffer);
     Py_XDECREF(self->write);
     Py_XDECREF(self->pers_func);
+    Py_XDECREF(self->dispatch_table);
     Py_XDECREF(self->arg);
     Py_XDECREF(self->fast_memo);
 
 {
     Py_VISIT(self->write);
     Py_VISIT(self->pers_func);
+    Py_VISIT(self->dispatch_table);
     Py_VISIT(self->arg);
     Py_VISIT(self->fast_memo);
     return 0;
     Py_CLEAR(self->output_buffer);
     Py_CLEAR(self->write);
     Py_CLEAR(self->pers_func);
+    Py_CLEAR(self->dispatch_table);
     Py_CLEAR(self->arg);
     Py_CLEAR(self->fast_memo);
 
     PyObject *proto_obj = NULL;
     PyObject *fix_imports = Py_True;
     _Py_IDENTIFIER(persistent_id);
+    _Py_IDENTIFIER(dispatch_table);
 
     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OO:Pickler",
                                      kwlist, &file, &proto_obj, &fix_imports))
         if (self->pers_func == NULL)
             return -1;
     }
+    self->dispatch_table = NULL;
+    if (_PyObject_HasAttrId((PyObject *)self, &PyId_dispatch_table)) {
+        self->dispatch_table = _PyObject_GetAttrId((PyObject *)self,
+                                                   &PyId_dispatch_table);
+        if (self->dispatch_table == NULL)
+            return -1;
+    }
     return 0;
 }
 
 static PyMemberDef Pickler_members[] = {
     {"bin", T_INT, offsetof(PicklerObject, bin)},
     {"fast", T_INT, offsetof(PicklerObject, fast)},
+    {"dispatch_table", T_OBJECT_EX, offsetof(PicklerObject, dispatch_table)},
     {NULL}
 };