Commits

Carl Meyer committed 781afbc

Fixed annotation of InheritanceQuerySets. Thanks Jeff Elmore.

  • Participants
  • Parent commits ab7312d

Comments (0)

Files changed (3)

 tip (unreleased)
 ----------------
 
+- Fixed annotation of InheritanceQuerysets. Thanks Jeff Elmore.
+
 - Dropped support for Python 2.5 and Django 1.1. Both are no longer supported
   even for security fixes, and should not be used.
 

File model_utils/managers.py

 from django.db.models.manager import Manager
 from django.db.models.query import QuerySet
 
+
 class InheritanceQuerySet(QuerySet):
+    def __init__(self, *args, **kwargs):
+        self._annotated = None
+        super(InheritanceQuerySet, self).__init__(*args, **kwargs)
+
     def select_subclasses(self, *subclasses):
         if not subclasses:
             subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects()
 
     def _clone(self, klass=None, setup=False, **kwargs):
         try:
-            kwargs.update({'subclasses': self.subclasses})
+            kwargs.update({'subclasses': self.subclasses,
+                           '_annotated': self._annotated})
         except AttributeError:
             pass
         return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs)
 
+    def annotate(self, *args, **kwargs):
+        qset = super(InheritanceQuerySet, self).annotate(*args, **kwargs)
+        qset._annotated = [a.default_alias for a in args] + kwargs.keys()
+        return qset
+
     def iterator(self):
         iter = super(InheritanceQuerySet, self).iterator()
         if getattr(self, 'subclasses', False):
             for obj in iter:
-                obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj]
-                yield obj[0]
+                sub_obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj]
+                sub_obj = sub_obj[0]
+                if self._annotated:
+                    for k in self._annotated:
+                        setattr(sub_obj, k, getattr(obj, k))
+
+                yield sub_obj
         else:
             for obj in iter:
                 yield obj
 
+
 class InheritanceManager(models.Manager):
     use_for_related_fields = True
 

File model_utils/tests/tests.py

             self.child1)
 
 
+    def test_annotate_with_select_subclasses(self):
+        qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
+            models.Count('id'))
+        self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
+
+
+    def test_annotate_with_named_arguments_with_select_subclasses(self):
+        qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
+            test_count=models.Count('id'))
+        self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
+
+
+
 class TimeStampedModelTests(TestCase):
     def test_created(self):
         t1 = TimeStamp.objects.create()