Commits

Mike Bayer  committed a3e1eeb

- added has(), like any() but for scalars
- added **kwargs to has() and any(), criterion is optional; generate equality criterion
against the related table (since we know the related property when has() and any() are used),
i.e. filter(Address.user.has(name='jack')) equivalent to filter(Address.user.has(User.name=='jack'))
- added "from_joinpoint=False" arg to join()/outerjoin(). yes, I know join() is getting a little
crazy, but this flag is needed when you want to keep building along a line of aliased joins,
adding query criterion for each alias in the chain. self-referential unit test added.
- fixed basic_tree example a little bit

  • Participants
  • Parent commits ae2340a
  • Branches rel_0_4

Comments (0)

Files changed (7)

File examples/adjacencytree/basic_tree.py

 """a basic Adjacency List model tree."""
 
 from sqlalchemy import *
+from sqlalchemy.orm import *
 from sqlalchemy.util import OrderedDict
+from sqlalchemy.orm.collections import attribute_mapped_collection
 
 metadata = MetaData('sqlite:///', echo=True)
 
     Column('node_name', String(50), nullable=False),
     )
 
-class NodeList(OrderedDict):
-    """subclasses OrderedDict to allow usage as a list-based property."""
-    def append(self, node):
-        self[node.name] = node
-    def __iter__(self):
-        return iter(self.values())
 
 class TreeNode(object):
     """a rich Tree class which includes path-based operations"""
     def __init__(self, name):
-        self.children = NodeList()
         self.name = name
         self.parent = None
         self.id = None
         if isinstance(node, str):
             node = TreeNode(node)
         node.parent = self
-        self.children.append(node)
+        self.children[node.name] = node
     def __repr__(self):
         return self._getstring(0, False)
     def __str__(self):
     id=trees.c.node_id,
     name=trees.c.node_name,
     parent_id=trees.c.parent_node_id,
-    children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=NodeList),
+    children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=attribute_mapped_collection('name')),
 ))
 
 print "\n\n\n----------------------------"

File lib/sqlalchemy/orm/attributes.py

     def clause_element(self):
         return self.comparator.clause_element()
         
-    def operate(self, op, other):
-        return op(self.comparator, other)
+    def operate(self, op, other, **kwargs):
+        return op(self.comparator, other, **kwargs)
 
-    def reverse_operate(self, op, other):
-        return op(other, self.comparator)
+    def reverse_operate(self, op, other, **kwargs):
+        return op(other, self.comparator, **kwargs)
         
     def hasparent(self, item, optimistic=False):
         """Return the boolean value of a `hasparent` flag attached to the given item.

File lib/sqlalchemy/orm/interfaces.py

         The return value of this method is used as the result of
         ``query.get_by()`` if the value is anything other than
         EXT_PASS.
+        
+        DEPRECATED.
         """
 
         return EXT_PASS
         The return value of this method is used as the result of
         ``query.select_by()`` if the value is anything other than
         EXT_PASS.
+        
+        DEPRECATED.
         """
 
         return EXT_PASS
         The return value of this method is used as the result of
         ``query.select()`` if the value is anything other than
         EXT_PASS.
+        
+        DEPRECATED.
         """
 
         return EXT_PASS
         return a.contains(b)
     contains_op = staticmethod(contains_op)
     
-    def any_op(a, b):
-        return a.any(b)
+    def any_op(a, b, **kwargs):
+        return a.any(b, **kwargs)
     any_op = staticmethod(any_op)
     
+    def has_op(a, b, **kwargs):
+        return a.has(b, **kwargs)
+    has_op = staticmethod(has_op)
+    
     def __init__(self, prop):
         self.prop = prop
 
         """return true if this collection contains other"""
         return self.operate(PropComparator.contains_op, other)
 
-    def any(self, criterion):
-        """return true if this collection contains any member that meets the given criterion"""
-        return self.operate(PropComparator.any_op, criterion)
+    def any(self, criterion=None, **kwargs):
+        """return true if this collection contains any member that meets the given criterion.
+        
+            criterion
+                an optional ClauseElement formulated against the member class' table or attributes.
+                
+            \**kwargs
+                key/value pairs corresponding to member class attribute names which will be compared
+                via equality to the corresponding values.
+        """
+
+        return self.operate(PropComparator.any_op, criterion, **kwargs)
+    
+    def has(self, criterion=None, **kwargs):
+        """return true if this element references a member which meets the given criterion.
+        
+    
+        criterion
+            an optional ClauseElement formulated against the member class' table or attributes.
+            
+        \**kwargs
+            key/value pairs corresponding to member class attribute names which will be compared
+            via equality to the corresponding values.
+        """
+
+        return self.operate(PropComparator.has_op, criterion, **kwargs)
         
 class StrategizedProperty(MapperProperty):
     """A MapperProperty which uses selectable strategies to affect

File lib/sqlalchemy/orm/properties.py

                 return ~sql.exists([1], self.prop.primaryjoin)
             elif self.prop.uselist:
                 if not hasattr(other, '__iter__'):
-                    raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+                    raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.")
                 else:
                     j = self.prop.primaryjoin
                     if self.prop.secondaryjoin:
             else:  
                 return self.prop._optimized_compare(other)
         
-        def any(self, criterion):
+        def any(self, criterion=None, **kwargs):
             if not self.prop.uselist:
-                raise exceptions.InvalidRequestError("'any' not implemented for scalar attributes")
+                raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
             j = self.prop.primaryjoin
             if self.prop.secondaryjoin:
                 j = j & self.prop.secondaryjoin
+            for k in kwargs:
+                crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+                if criterion is None:
+                    criterion = crit
+                else:
+                    criterion = criterion & crit
+            return sql.exists([1], j & criterion)
+        
+        def has(self, criterion=None, **kwargs):
+            if self.prop.uselist:
+                raise exceptions.InvalidRequestError("'has()' not implemented for collections.  Use any().")
+            j = self.prop.primaryjoin
+            if self.prop.secondaryjoin:
+                j = j & self.prop.secondaryjoin
+            for k in kwargs:
+                crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+                if criterion is None:
+                    criterion = crit
+                else:
+                    criterion = criterion & crit
             return sql.exists([1], j & criterion)
                 
         def contains(self, other):
             if not self.prop.uselist:
-                raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes")
+                raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes.  Use ==")
             clause = self.prop._optimized_compare(other)
 
             j = self.prop.primaryjoin

File lib/sqlalchemy/orm/query.py

     def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
         if start is None:
             start = self._joinpoint
-        
+            
         clause = self._from_obj[-1]
 
         currenttables = [clause]
             
         
         mapper = start
-        alias = None
+        alias = self._aliases
         for key in util.to_list(keys):
             prop = mapper.get_property(key, resolve_synonyms=True)
             if prop._is_self_referential() and not create_aliases:
             q._group_by = q._group_by + util.to_list(criterion)
         return q
 
-    def join(self, prop, id=None, aliased=False):
+    def join(self, prop, id=None, aliased=False, from_joinpoint=False):
         """create a join of this ``Query`` object's criterion
         to a relationship and return the newly resulting ``Query``.
 
         property names.
         """
 
-        return self._join(prop, id=id, outerjoin=False, aliased=aliased)
+        return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
         
-    def outerjoin(self, prop, id=None, aliased=False):
+    def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
         """create a left outer join of this ``Query`` object's criterion
         to a relationship and return the newly resulting ``Query``.
         
         property names.
         """
 
-        return self._join(prop, id=id, outerjoin=True, aliased=aliased)
+        return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
 
-    def _join(self, prop, id, outerjoin, aliased):
-        (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=self.mapper, create_aliases=aliased)
+    def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
+        (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
         q = self._clone()
         q._from_obj = [clause]
         q._joinpoint = mapper

File lib/sqlalchemy/sql.py

     def clause_element(self):
         raise NotImplementedError()
 
-    def operate(self, op, *other):
+    def operate(self, op, *other, **kwargs):
         raise NotImplementedError()
 
-    def reverse_operate(self, op, *other):
+    def reverse_operate(self, op, *other, **kwargs):
         raise NotImplementedError()
 
 class ColumnOperators(Operators):

File test/orm/query.py

 from testlib import *
 from fixtures import *
 
-class Base(object):
-    def __init__(self, **kwargs):
-        for k in kwargs:
-            setattr(self, k, kwargs[k])
-            
-    def __ne__(self, other):
-        return not self.__eq__(other)
-        
-    def __eq__(self, other):
-        """'passively' compare this object to another.
-        
-        only look at attributes that are present on the source object.
-        
-        """
-        # use __dict__ to avoid instrumented properties
-        for attr in self.__dict__.keys():
-            if attr[0] == '_':
-                continue
-            value = getattr(self, attr)
-            if hasattr(value, '__iter__') and not isinstance(value, basestring):
-                if len(value) == 0:
-                    continue
-                for (us, them) in zip(value, getattr(other, attr)):
-                    if us != them:
-                        return False
-                else:
-                    continue
-            else:
-                if value is not None:
-                    if value != getattr(other, attr):
-                        return False
-        else:
-            return True
-
 class QueryTest(ORMTest):
     keep_mappers = True
     keep_data = True
     
     def test_any(self):
         sess = create_session()
-        address = sess.query(Address).get(3)
+
         assert [User(id=8), User(id=9)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).all()
+
+        assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'), id=4)).all()
+
+        assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all()
+    
+    def test_has(self):
+        sess = create_session()
+        assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all()
         
+        assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all()
+        
+        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+            
     def test_contains_m2m(self):
         sess = create_session()
         item = sess.query(Item).get(3)
 
         assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all()
 
-    def test_has(self):
+    def test_comparison(self):
         """test scalar comparison to an object instance"""
         
         sess = create_session()
                 self.children.append(node)
 
         mapper(Node, nodes, properties={
-            'children':relation(Node, lazy=True, join_depth=3)
+            'children':relation(Node, lazy=True, join_depth=3, 
+                backref=backref('parent', remote_side=[nodes.c.id])
+            )
         })
         sess = create_session()
         n1 = Node(data='n1')
 
         node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first()
         assert node.data=='n1'
+        
+        node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\
+            join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
+        assert node.data == 'n122'
 
 class ExternalColumnsTest(QueryTest):
     keep_mappers = False