Mike Bayer avatar Mike Bayer committed 4e9c059

- [bug] Fixed bug whereby objects using
attribute_mapped_collection or
column_mapped_collection could not be
pickled. [ticket:2409]

Comments (0)

Files changed (3)

     invokes common table expression support
     from the Core (see below). [ticket:1859]
 
+  - [bug] Fixed bug whereby objects using
+    attribute_mapped_collection or 
+    column_mapped_collection could not be
+    pickled.  [ticket:2409]
+
   - [bug] Fixed bug whereby MappedCollection
     would not get the appropriate collection
     instrumentation if it were only used

lib/sqlalchemy/orm/collections.py

 from sqlalchemy import schema, util, exc as sa_exc
 
 
+
 __all__ = ['collection', 'collection_adapter',
            'mapped_collection', 'column_mapped_collection',
            'attribute_mapped_collection']
 
 __instrumentation_mutex = util.threading.Lock()
 
+class _SerializableColumnGetter(object):
+    def __init__(self, colkeys):
+        self.colkeys = colkeys
+        self.composite = len(colkeys) > 1
+
+    def __reduce__(self):
+        return _SerializableColumnGetter, (self.colkeys,)
+
+    def __call__(self, value):
+        state = instance_state(value)
+        m = _state_mapper(state)
+        key = [m._get_state_attr_by_column(
+                        state, state.dict, 
+                        m.mapped_table.columns[k])
+                     for k in self.colkeys]
+        if self.composite:
+            return tuple(key)
+        else:
+            return key[0]
 
 def column_mapped_collection(mapping_spec):
     """A dictionary-based collection type with column-based keying.
     after a session flush.
 
     """
+    global _state_mapper, instance_state
     from sqlalchemy.orm.util import _state_mapper
     from sqlalchemy.orm.attributes import instance_state
 
-    cols = [expression._only_column_elements(q, "mapping_spec") 
-                for q in util.to_list(mapping_spec)]
-    if len(cols) == 1:
-        def keyfunc(value):
-            state = instance_state(value)
-            m = _state_mapper(state)
-            return m._get_state_attr_by_column(state, state.dict, cols[0])
-    else:
-        mapping_spec = tuple(cols)
-        def keyfunc(value):
-            state = instance_state(value)
-            m = _state_mapper(state)
-            return tuple(m._get_state_attr_by_column(state, state.dict, c)
-                         for c in mapping_spec)
+    cols = [c.key for c in [
+                expression._only_column_elements(q, "mapping_spec") 
+                for q in util.to_list(mapping_spec)]]
+    keyfunc = _SerializableColumnGetter(cols)
     return lambda: MappedCollection(keyfunc)
 
+class _SerializableAttrGetter(object):
+    def __init__(self, name):
+        self.name = name
+        self.getter = operator.attrgetter(name)
+
+    def __call__(self, target):
+        return self.getter(target)
+
+    def __reduce__(self):
+        return _SerializableAttrGetter, (self.name, )
+
 def attribute_mapped_collection(attr_name):
     """A dictionary-based collection type with attribute-based keying.
 
     after a session flush.
 
     """
-    return lambda: MappedCollection(operator.attrgetter(attr_name))
+    getter = _SerializableAttrGetter(attr_name)
+    return lambda: MappedCollection(getter)
 
 
 def mapped_collection(keyfunc):

test/orm/test_pickled.py

                             clear_mappers, exc as orm_exc,\
                             configure_mappers, Session, lazyload_all,\
                             lazyload, aliased
+from sqlalchemy.orm.collections import attribute_mapped_collection, \
+    column_mapped_collection
 from test.lib import fixtures
 from test.orm import _fixtures
 from test.lib.pickleable import User, Address, Dingaling, Order, \
                 repickled = loads(dumps(sa_exc))
                 eq_(repickled.args[0], sa_exc.args[0])
 
+    def test_attribute_mapped_collection(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        mapper(User, users, properties={
+            'addresses':relationship(
+                            Address, 
+                            collection_class=
+                            attribute_mapped_collection('email_address')
+                        )
+        })
+        mapper(Address, addresses)
+        u1 = User()
+        u1.addresses = {"email1":Address(email_address="email1")}
+        for loads, dumps in picklers():
+            repickled = loads(dumps(u1))
+            eq_(u1.addresses, repickled.addresses)
+            eq_(repickled.addresses['email1'], 
+                    Address(email_address="email1"))
+
+    def test_column_mapped_collection(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        mapper(User, users, properties={
+            'addresses':relationship(
+                            Address, 
+                            collection_class=
+                            column_mapped_collection(
+                                addresses.c.email_address)
+                        )
+        })
+        mapper(Address, addresses)
+        u1 = User()
+        u1.addresses = {
+            "email1":Address(email_address="email1"),
+            "email2":Address(email_address="email2")
+        }
+        for loads, dumps in picklers():
+            repickled = loads(dumps(u1))
+            eq_(u1.addresses, repickled.addresses)
+            eq_(repickled.addresses['email1'], 
+                    Address(email_address="email1"))
+
+    def test_composite_column_mapped_collection(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        mapper(User, users, properties={
+            'addresses':relationship(
+                            Address, 
+                            collection_class=
+                            column_mapped_collection([
+                                addresses.c.id,
+                                addresses.c.email_address])
+                        )
+        })
+        mapper(Address, addresses)
+        u1 = User()
+        u1.addresses = {
+            (1, "email1"):Address(id=1, email_address="email1"),
+            (2, "email2"):Address(id=2, email_address="email2")
+        }
+        for loads, dumps in picklers():
+            repickled = loads(dumps(u1))
+            eq_(u1.addresses, repickled.addresses)
+            eq_(repickled.addresses[(1, 'email1')], 
+                    Address(id=1, email_address="email1"))
+
 class PolymorphicDeferredTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.