Commits

Mike Bayer committed 98beeb2

- fix to using query.count() with distinct, **kwargs with SelectResults
count() [ticket:287]

Comments (0)

Files changed (4)

 so that control over timezone presence is more controllable (psycopg2
 returns datetimes with tzinfo's if available, which can create confusion
 against datetimes that dont).
+- fix to using query.count() with distinct, **kwargs with SelectResults
+count() [ticket:287]
 
 0.2.7
 - quoting facilities set up so that database-specific quoting can be

lib/sqlalchemy/ext/selectresults.py

 
     def count(self):
         """executes the SQL count() function against the SelectResults criterion."""
-        return self._query.count(self._clause)
+        return self._query.count(self._clause, **self._ops)
 
     def _col_aggregate(self, col, func):
         """executes func() function against the given column

lib/sqlalchemy/orm/query.py

         return self._select_statement(statement, params=params)
 
     def count(self, whereclause=None, params=None, **kwargs):
-        s = self.table.count(whereclause)
+        if self._nestable(**kwargs):
+            s = self.table.select(whereclause, **kwargs).alias('getcount').count()
+        else:
+            s = self.table.count(whereclause)
         return self.session.scalar(self.mapper, s, params=params)
 
     def select_statement(self, statement, **params):
         return self.instances(statement, params=params, **kwargs)
 
     def _should_nest(self, **kwargs):
-        """returns True if the given statement options indicate that we should "nest" the
+        """return True if the given statement options indicate that we should "nest" the
         generated query as a subquery inside of a larger eager-loading query.  this is used
         with keywords like distinct, limit and offset and the mapper defines eager loads."""
         return (
             self.mapper.has_eager()
-            and (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False))
+            and self._nestable(**kwargs)
         )
 
+    def _nestable(self, **kwargs):
+        """return true if the given statement options imply it should be nested."""
+        return (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False))
+        
     def compile(self, whereclause = None, **kwargs):
         order_by = kwargs.pop('order_by', False)
         from_obj = kwargs.pop('from_obj', [])

test/orm/selectresults.py

     def test_offset(self):
         assert len(list(self.res.limit(10))) == 10
 
+class Obj1(object):
+    pass
+class Obj2(object):
+    pass
+
+class SelectResultsTest2(PersistTest):
+    def setUpAll(self):
+        self.install_threadlocal()
+        global metadata, table1, table2
+        metadata = BoundMetaData(testbase.db)
+        table1 = Table('Table1', metadata,
+            Column('id', Integer, primary_key=True),
+            )
+        table2 = Table('Table2', metadata,
+            Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True),
+            Column('num', Integer, primary_key=True),
+            )
+        assign_mapper(Obj1, table1, extension=SelectResultsExt())
+        assign_mapper(Obj2, table2, extension=SelectResultsExt())
+        metadata.create_all()
+        table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
+        table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
+{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
+
+    def setUp(self):
+        self.query = Obj1.mapper.query()
+        #self.orig = self.query.select_whereclause()
+        #self.res = self.query.select()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+        self.uninstall_threadlocal()
+
+    def test_distinctcount(self):
+        res = self.query.select()
+        assert res.count() == 4
+        res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
+        assert res.count() == 3
+        res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True)
+        self.assertEqual(res.count(), 1)
+
+
 
 if __name__ == "__main__":
     testbase.main()