Mike Bayer avatar Mike Bayer committed 3102da6

test cases were not fully testing contains_eager() with regards to [ticket:777], fixed contains_eager() for more than one level deep

Comments (0)

Files changed (2)

lib/sqlalchemy/orm/strategies.py

 
     def process_query_property(self, query, paths):
         if self.alias is not None and self.decorator is None:
+            (mapper, propname) = paths[-1][-2:]
+
+            prop = mapper.get_property(propname, resolve_synonyms=True)
             if isinstance(self.alias, basestring):
-                (mapper, propname) = paths[-1]
-                prop = mapper.get_property(propname, resolve_synonyms=True)
                 self.alias = prop.target.alias(self.alias)
-            def decorate(row):
-                d = {}
-                for c in prop.target.columns:
-                    d[c] = row[self.alias.corresponding_column(c)]
-                return d
-            self.decorator = decorate
+
+            self.decorator = mapperutil.create_row_adapter(self.alias, prop.target)
         query._attributes[("eager_row_processor", paths[-1])] = self.decorator
 
 RowDecorateOption.logger = logging.class_logger(RowDecorateOption)

test/orm/query.py

     def test_from_alias(self):
 
         query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.id', addresses.c.id])
-        q = create_session().query(User)
+        sess =create_session()
+        q = sess.query(User)
 
         def go():
             l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
 
-
+        sess.clear()
+        
         def go():
             l = q.options(contains_alias('ulist'), contains_eager('addresses')).from_statement(query).all()
             assert fixtures.user_address_result == l
     def test_contains_eager(self):
 
         selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id])
-        q = create_session().query(User)
+        sess = create_session()
+        q = sess.query(User)
 
         def go():
             l = q.options(contains_eager('addresses')).instances(selectquery.execute())
             assert fixtures.user_address_result[0:3] == l
         self.assert_sql_count(testbase.db, go, 1)
 
+        sess.clear()
+        
         def go():
             l = q.options(contains_eager('addresses')).from_statement(selectquery).all()
             assert fixtures.user_address_result[0:3] == l
     def test_contains_eager_alias(self):
         adalias = addresses.alias('adalias')
         selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id])
-        q = create_session().query(User)
+        sess = create_session()
+        q = sess.query(User)
 
         def go():
             # test using a string alias name
             l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
-
+        sess.clear()
+        
         def go():
             # test using the Alias object itself
             l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
-
+        
+        sess.clear()
+        
         def decorate(row):
             d = {}
             for c in addresses.columns:
             l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
+        sess.clear()
+        
+        oalias = orders.alias('o1')
+        ialias = items.alias('i1')
+        query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True)
+        q = create_session().query(User)
+        # test using string alias with more than one level deep
+        def go():
+            l = q.options(contains_eager('orders', alias='o1'), contains_eager('orders.items', alias='i1')).instances(query.execute())
+            assert fixtures.user_order_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+        sess.clear()
+        
+        # test using Alias with more than one level deep
+        def go():
+            l = q.options(contains_eager('orders', alias=oalias), contains_eager('orders.items', alias=ialias)).instances(query.execute())
+            assert fixtures.user_order_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+        sess.clear()
+
 
     def test_multi_mappers(self):
-        sess = create_session()
 
-        (user7, user8, user9, user10) = sess.query(User).all()
-        (address1, address2, address3, address4, address5) = sess.query(Address).all()
+        test_session = create_session()
+        (user7, user8, user9, user10) = test_session.query(User).all()
+        (address1, address2, address3, address4, address5) = test_session.query(Address).all()
 
         # note the result is a cartesian product
         expected = [(user7, address1),
             (user9, address5),
             (user10, None)]
 
+        sess = create_session()
+
         selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
         q = sess.query(User)
         l = q.instances(selectquery.execute(), Address)
         assert l == expected
-
+        
+        sess.clear()
+        
         for aliased in (False, True):
             q = sess.query(User)
+
             q = q.add_entity(Address).outerjoin('addresses', aliased=aliased)
             l = q.all()
             assert l == expected
+            sess.clear()
 
             q = sess.query(User).add_entity(Address)
             l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all()
             assert l == [(user8, address3)]
+            sess.clear()
 
             q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com')
             assert q.all() == [(user8, address3)]
+            sess.clear()
 
             q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
             assert q.all() == [(user8, address3)]
-
+            sess.clear()
+            
     def test_aliased_multi_mappers(self):
         sess = create_session()
 
         l = q.all()
         assert l == expected
 
+        sess.clear()
+        
         q = sess.query(User).add_entity(Address, alias=adalias)
         l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all()
         assert l == [(user8, address3)]
         
         for add_col in (User.name, users.c.name, User.c.name):
             assert sess.query(User).add_column(add_col).all() == expected
-
+            sess.clear()
+            
         try:
             sess.query(User).add_column(object()).all()
             assert False
             q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
             l = q.all()
             assert l == expected
-
+            sess.clear()
+            
         s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
         q = sess.query(User)
         l = q.add_column("count").from_statement(s).all()
         q = create_session().query(User)
         l = q.add_column("count").add_column("concat").from_statement(s).all()
         assert l == expected
-
+        
+        sess.clear()
+        
         # test with select_from()
         q = create_session().query(User).add_column(func.count(addresses.c.id))\
             .add_column(("Name:" + users.c.name)).select_from(users.outerjoin(addresses))\
             .group_by([c for c in users.c]).order_by(users.c.id)
 
         assert q.all() == expected
-
+        sess.clear()
+        
         # test with outerjoin() both aliased and non
         for aliased in (False, True):
             q = create_session().query(User).add_column(func.count(addresses.c.id))\
                 .group_by([c for c in users.c]).order_by(users.c.id)
 
             assert q.all() == expected
-
+            sess.clear()
+            
 class CustomJoinTest(QueryTest):
     keep_mappers = False
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.