Commits

Carl Meyer committed 561f79e

Improved fix for annotations and InheritanceQuerySet. Thanks Facundo Gaich.

Comments (0)

Files changed (4)

 Carl Meyer <carl@dirtcircle.com>
 Jannis Leidel <jannis@leidel.info>
+Facundo Gaich <facugaich@gmail.com>
 Gregor Müllegger <gregor@muellegger.de>
 Jeff Elmore <jeffelmore.org>
 Paul McLanahan <paul@mclanahan.net>
 tip (unreleased)
 ----------------
 
-- Fixed annotation of InheritanceQuerysets. Thanks Jeff Elmore.
+- Fixed annotation of InheritanceQuerysets. Thanks Jeff Elmore and Facundo
+  Gaich.
 
 - Dropped support for Python 2.5 and Django 1.1. Both are no longer supported
   even for security fixes, and should not be used.

model_utils/managers.py

 
 
 class InheritanceQuerySet(QuerySet):
-    def __init__(self, *args, **kwargs):
-        self._annotated = None
-        super(InheritanceQuerySet, self).__init__(*args, **kwargs)
-
     def select_subclasses(self, *subclasses):
         if not subclasses:
             subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects()
         return new_qs
 
     def _clone(self, klass=None, setup=False, **kwargs):
-        try:
-            kwargs.update({'subclasses': self.subclasses,
-                           '_annotated': self._annotated})
-        except AttributeError:
-            pass
+        for name in ['subclasses', '_annotated']:
+            if hasattr(self, name):
+                kwargs[name] = getattr(self, name)
         return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs)
 
     def annotate(self, *args, **kwargs):
             for obj in iter:
                 sub_obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj]
                 sub_obj = sub_obj[0]
-                if self._annotated:
+                if getattr(self, '_annotated', False):
                     for k in self._annotated:
                         setattr(sub_obj, k, getattr(obj, k))
 

model_utils/tests/tests.py

         self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
 
 
+    def test_annotate_before_select_subclasses(self):
+        qs = InheritanceManagerTestParent.objects.annotate(
+            models.Count('id')).select_subclasses()
+        self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
+
+
+    def test_annotate_with_named_arguments_before_select_subclasses(self):
+        qs = InheritanceManagerTestParent.objects.annotate(
+            test_count=models.Count('id')).select_subclasses()
+        self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
+
+
 
 class TimeStampedModelTests(TestCase):
     def test_created(self):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.