Michael Trier avatar Michael Trier committed f16ec79

Modified query_cls on DynamicAttribteImpl to accept a full mixin version of the AppenderQuery.

Comments (0)

Files changed (3)

 =====
 
 - orm
+    - Modified query_cls on DynamicAttributeImpl to accept a full
+      mixin version of the AppenderQuery, which allows subclassing
+      the AppenderMixin.
+
     - Fixed the evaluator not being able to evaluate IS NULL clauses.
 
     - Fixed the "set collection" function on "dynamic" relations to

lib/sqlalchemy/orm/dynamic.py

         strategies._register_attribute(self,
             mapper,
             useobject=True,
-            impl_class=DynamicAttributeImpl, 
-            target_mapper=self.parent_property.mapper, 
-            order_by=self.parent_property.order_by, 
+            impl_class=DynamicAttributeImpl,
+            target_mapper=self.parent_property.mapper,
+            order_by=self.parent_property.order_by,
             query_class=self.parent_property.query_class
         )
 
     uses_objects = True
     accepts_scalar_loader = False
 
-    def __init__(self, class_, key, typecallable, 
+    def __init__(self, class_, key, typecallable,
                      target_mapper, order_by, query_class=None, **kwargs):
         super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
         self.target_mapper = target_mapper
         self.order_by = order_by
         if not query_class:
             self.query_class = AppenderQuery
+        elif AppenderMixin in query_class.mro():
+            self.query_class = query_class
         else:
             self.query_class = mixin_user_query(query_class)
 
             ext.remove(state, value, initiator or self)
 
     def _modified_event(self, state):
-        
+
         if self.key not in state.committed_state:
             state.committed_state[self.key] = CollectionHistory(self, state)
 
 
         collection_history = self._modified_event(state)
         new_values = list(iterable)
-        
+
         if _state_has_identity(state):
             old_collection = list(self.get(state))
         else:
             c = state.committed_state[self.key]
         else:
             c = CollectionHistory(self, state)
-            
+
         if not passive:
             return CollectionHistory(self, state, apply_to=c)
         else:
             return c
-        
+
     def append(self, state, value, initiator, passive=False):
         if initiator is not self:
             self.fire_append_event(state, value, initiator)
-    
+
     def remove(self, state, value, initiator, passive=False):
         if initiator is not self:
             self.fire_remove_event(state, value, initiator)
 
 class DynCollectionAdapter(object):
     """the dynamic analogue to orm.collections.CollectionAdapter"""
-    
+
     def __init__(self, attr, owner_state, data):
         self.attr = attr
         self.state = owner_state
         self.data = data
-    
+
     def __iter__(self):
         return iter(self.data)
-        
+
     def append_with_event(self, item, initiator=None):
         self.attr.append(self.state, item, initiator)
 
 
     def append_without_event(self, item):
         pass
-    
+
     def remove_without_event(self, item):
         pass
-        
+
 class AppenderMixin(object):
     query_class = None
 
         Query.__init__(self, attr.target_mapper, None)
         self.instance = state.obj()
         self.attr = attr
-    
+
     def __session(self):
         sess = object_session(self.instance)
         if sess is not None and self.autoflush and sess.autoflush and self.instance in sess:
             return None
         else:
             return sess
-    
+
     def session(self):
         return self.__session()
     session = property(session, lambda s, x:None)
-    
+
     def __iter__(self):
         sess = self.__session()
         if sess is None:
                 passive=True).added_items.__getitem__(index)
         else:
             return self._clone(sess).__getitem__(index)
-    
+
     def count(self):
         sess = self.__session()
         if sess is None:
     name = 'Appender' + cls.__name__
     return type(name, (AppenderMixin, cls), {'query_class': cls})
 
-class CollectionHistory(object): 
+class CollectionHistory(object):
     """Overrides AttributeHistory to receive append/remove events directly."""
 
     def __init__(self, attr, state, apply_to=None):
             self.deleted_items = []
             self.added_items = []
             self.unchanged_items = []
-        
+

test/orm/dynamic.py

 from testlib import testing
 from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func
 from testlib.sa.orm import mapper, relation, create_session, Query, attributes
+from sqlalchemy.orm.dynamic import AppenderMixin
 from testlib.testing import eq_
 from testlib.compat import _function_named
 from orm import _base, _fixtures
         assert not hasattr(q, 'append')
         assert type(q).__name__ == 'MyQuery'
 
+    @testing.resolve_artifact_names
+    def test_custom_query_with_custom_mixin(self):
+        class MyAppenderMixin(AppenderMixin):
+            def add(self, items):
+                if isinstance(items, list):
+                    for item in items:
+                        self.append(item)
+                else:
+                    self.append(items)
+
+        class MyQuery(Query):
+            pass
+
+        class MyAppenderQuery(MyAppenderMixin, MyQuery):
+            query_class = MyQuery
+
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses),
+                                       query_class=MyAppenderQuery)
+        })
+        sess = create_session()
+        u = User()
+        sess.add(u)
+
+        col = u.addresses
+        assert isinstance(col, Query)
+        assert isinstance(col, MyQuery)
+        assert hasattr(col, 'append')
+        assert hasattr(col, 'add')
+        assert type(col).__name__ == 'MyAppenderQuery'
+
+        q = col.limit(1)
+        assert isinstance(q, Query)
+        assert isinstance(q, MyQuery)
+        assert not hasattr(q, 'append')
+        assert not hasattr(q, 'add')
+        assert type(q).__name__ == 'MyQuery'
+
 
 class SessionTest(_fixtures.FixtureTest):
     run_inserts = None
         a1 = Address(email_address='foo')
         sess.add_all([u1, a1])
         sess.flush()
-        
+
         assert testing.db.scalar(select([func.count(1)]).where(addresses.c.user_id!=None)) == 0
         u1 = sess.query(User).get(u1.id)
         u1.addresses.append(a1)
         assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [
             (a1.id, u1.id, 'foo')
         ]
-        
+
         u1.addresses.remove(a1)
         sess.flush()
         assert testing.db.scalar(select([func.count(1)]).where(addresses.c.user_id!=None)) == 0
-        
+
         u1.addresses.append(a1)
         sess.flush()
         assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [
         assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [
             (a2.id, u1.id, 'bar')
         ]
-        
+
 
     @testing.resolve_artifact_names
     def test_merge(self):
         a1 = Address(email_address='a1')
         a2 = Address(email_address='a2')
         a3 = Address(email_address='a3')
-        
+
         u1.addresses.append(a2)
         u1.addresses.append(a3)
-        
+
         sess.add_all([u1, a1])
         sess.flush()
-        
+
         u1 = User(id=u1.id, name='jack')
         u1.addresses.append(a1)
         u1.addresses.append(a3)
         u1 = sess.merge(u1)
         assert attributes.get_history(u1, 'addresses') == (
-            [a1], 
-            [a3], 
+            [a1],
+            [a3],
             [a2]
         )
 
         sess.flush()
-        
+
         eq_(
             list(u1.addresses),
             [a1, a3]
         )
-        
+
     @testing.resolve_artifact_names
     def test_flush(self):
         mapper(User, users, properties={
         u1.addresses.append(Address(email_address='lala@hoho.com'))
         sess.add_all((u1, u2))
         sess.flush()
-        
+
         from sqlalchemy.orm import attributes
         self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
-        
+
         sess.expunge_all()
 
         # test the test fixture a little bit
             User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
             User(name='ed', addresses=[Address(email_address='foo@bar.com')])
         ] == sess.query(User).all()
-    
+
     @testing.resolve_artifact_names
     def test_hasattr(self):
         mapper(User, users, properties={
             'addresses':dynamic_loader(mapper(Address, addresses))
         })
         u1 = User(name='jack')
-        
+
         assert 'addresses' not in u1.__dict__.keys()
         u1.addresses = [Address(email_address='test')]
         assert 'addresses' in dir(u1)
-    
+
     @testing.resolve_artifact_names
     def test_collection_set(self):
         mapper(User, users, properties={
         a2 = Address(email_address='a2')
         a3 = Address(email_address='a3')
         a4 = Address(email_address='a4')
-        
+
         sess.add(u1)
         u1.addresses = [a1, a3]
         assert list(u1.addresses) == [a1, a3]
         assert list(u1.addresses) == [a2, a3]
         u1.addresses = []
         assert list(u1.addresses) == []
-        
-        
 
-        
+
+
+
     @testing.resolve_artifact_names
     def test_rollback(self):
         mapper(User, users, properties={
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.