Commits

Carl Meyer committed 4a5f4bc

Updates to multi-level support for InheritanceManager.

Comments (0)

Files changed (7)

 Jannis Leidel <jannis@leidel.info>
 Javier García Sogo <jgsogo@gmail.com>
 Jeff Elmore <jeffelmore.org>
+ivirabyan
 Paul McLanahan <paul@mclanahan.net>
 Rinat Shigapov <rinatshigapov@gmail.com>
 Ryan Kaskel <dev@ryankaskel.com>
 tip (unreleased)
 ----------------
 
+- Added support for arbitrary levels of model inheritance in
+  InheritanceManager. Thanks ivirabyan. (This feature only works in Django
+  1.6+ due to https://code.djangoproject.com/ticket/16572).
+
+
 1.2.0 (2013.01.27)
 ------------------
 
 it's safe to use as your default manager for the model.
 
 .. note::
-    ``InheritanceManager`` currently only supports a single level of model
-    inheritance; it won't work for grandchild models.
+
+    Due to `Django bug #16572`_, on Django versions prior to 1.6
+    ``InheritanceManager`` only supports a single level of model inheritance;
+    it won't work for grandchild models.
 
 .. note::
     The implementation of ``InheritanceManager`` uses ``select_related``
 
 .. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/
 .. _Django bug #16855: https://code.djangoproject.com/ticket/16855
+.. _Django bug #16572: https://code.djangoproject.com/ticket/16572
 
 
 TimeStampedModel
 TODO
 ====
 
-* Add support for multiple levels of inheritance to ``InheritanceManager``.
+* Switch to proper test skips once Django 1.3 is minimum supported.

model_utils/managers.py

-import sys
-
-from django.db import IntegrityError, models, transaction
+import django
+from django.db import models
 from django.db.models.fields.related import OneToOneField
 from django.db.models.query import QuerySet
 from django.core.exceptions import ObjectDoesNotExist
 
 try:
     from django.db.models.constants import LOOKUP_SEP
-except ImportError: # Django <= 1.5
+except ImportError: # Django < 1.5
     from django.db.models.sql.constants import LOOKUP_SEP
 
+
+
 class InheritanceQuerySet(QuerySet):
     def select_subclasses(self, *subclasses):
         if not subclasses:
-            subclasses = self._get_subclasses_recurse(self.model)
+            # only recurse one level on Django < 1.6 to avoid triggering
+            # https://code.djangoproject.com/ticket/16572
+            levels = None
+            if django.VERSION < (1, 6, 0):
+                levels = 1
+            subclasses = self._get_subclasses_recurse(self.model, levels=levels)
         new_qs = self.select_related(*subclasses)
         new_qs.subclasses = subclasses
         return new_qs
 
+
     def _clone(self, klass=None, setup=False, **kwargs):
         for name in ['subclasses', '_annotated']:
             if hasattr(self, name):
                 kwargs[name] = getattr(self, name)
         return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs)
 
+
     def annotate(self, *args, **kwargs):
         qset = super(InheritanceQuerySet, self).annotate(*args, **kwargs)
         qset._annotated = [a.default_alias for a in args] + kwargs.keys()
         return qset
 
+
     def iterator(self):
         iter = super(InheritanceQuerySet, self).iterator()
         if getattr(self, 'subclasses', False):
             for obj in iter:
                 yield obj
 
-    def _get_subclasses_recurse(self, model):
+
+    def _get_subclasses_recurse(self, model, levels=None):
         rels = [rel for rel in model._meta.get_all_related_objects()
                       if isinstance(rel.field, OneToOneField)
                       and issubclass(rel.field.model, model)]
         subclasses = []
+        if levels:
+            levels -= 1
         for rel in rels:
-            for subclass in self._get_subclasses_recurse(rel.field.model):
-                subclasses.append(rel.var_name + LOOKUP_SEP + subclass)
+            if levels or levels is None:
+                for subclass in self._get_subclasses_recurse(
+                        rel.field.model, levels=levels):
+                    subclasses.append(rel.var_name + LOOKUP_SEP + subclass)
             subclasses.append(rel.var_name)
         return subclasses
 
+
     def _get_sub_obj_recurse(self, obj, s):
         rel, _, s = s.partition(LOOKUP_SEP)
         try:
             return node
 
 
+
 class InheritanceManager(models.Manager):
     use_for_related_fields = True
 

model_utils/tests/models.py

-import django
 from django.db import models
 from django.utils.translation import ugettext_lazy as _
 
     pass
 
 
-if django.VERSION >= (1, 6, 0):
-    class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1):
-        text_field = models.TextField()
+class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1):
+    text_field = models.TextField()
 
 
 class InheritanceManagerTestChild2(InheritanceManagerTestParent):

model_utils/tests/tests.py

 from __future__ import with_statement
-import unittest
 import pickle
 
 from datetime import datetime, timedelta
 from model_utils.managers import QueryManager
 from model_utils.models import StatusModel, TimeFramedModel
 from model_utils.tests.models import (
-    InheritanceManagerTestRelated,
+    InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1,
     InheritanceManagerTestParent, InheritanceManagerTestChild1,
     InheritanceManagerTestChild2, TimeStamp, Post, Article, Status,
     StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
     TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot)
 
-if django.VERSION >= (1, 6, 0):
-    from model_utils.tests.models import InheritanceManagerTestGrandChild1
 
 
 class GetExcerptTests(TestCase):
     def setUp(self):
         self.child1 = InheritanceManagerTestChild1.objects.create()
         self.child2 = InheritanceManagerTestChild2.objects.create()
-        if django.VERSION >= (1, 6, 0):
-            self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
+        self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
+
 
     def get_manager(self):
         return InheritanceManagerTestParent.objects
 
 
     def test_normal(self):
-        children = set([InheritanceManagerTestParent(pk=self.child1.pk),
-                        InheritanceManagerTestParent(pk=self.child2.pk)])
-        if django.VERSION >= (1, 6, 0):
-            children.add(InheritanceManagerTestParent(pk=self.grandchild1.pk))
+        children = set([
+                InheritanceManagerTestParent(pk=self.child1.pk),
+                InheritanceManagerTestParent(pk=self.child2.pk),
+                InheritanceManagerTestParent(pk=self.grandchild1.pk),
+                ])
         self.assertEquals(set(self.get_manager().all()), children)
 
 
         children = set([self.child1, self.child2])
         if django.VERSION >= (1, 6, 0):
             children.add(self.grandchild1)
+        else:
+            children.add(InheritanceManagerTestChild1(pk=self.grandchild1.pk))
         self.assertEquals(
             set(self.get_manager().select_subclasses()), children)
 
 
     def test_select_specific_subclasses(self):
-        children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])
+        children = set([
+                self.child1,
+                InheritanceManagerTestParent(pk=self.child2.pk),
+                InheritanceManagerTestChild1(pk=self.grandchild1.pk),
+                ])
+        self.assertEquals(
+            set(
+                self.get_manager().select_subclasses(
+                    "inheritancemanagertestchild1")
+                ),
+            children,
+            )
+
+
+    def test_select_specific_grandchildren(self):
         if django.VERSION >= (1, 6, 0):
-            children.add(InheritanceManagerTestChild1(pk=self.grandchild1.pk))
-        self.assertEquals(
-            set(self.get_manager().select_subclasses(
-                    "inheritancemanagertestchild1")), children)
+            children = set([
+                    self.child1,
+                    InheritanceManagerTestParent(pk=self.child2.pk),
+                    self.grandchild1,
+                    ])
+            self.assertEquals(
+                set(
+                    self.get_manager().select_subclasses(
+                        "inheritancemanagertestchild1__"
+                        "inheritancemanagertestgrandchild1"
+                        )
+                    ),
+                children,
+                )
 
-    @unittest.skipIf(django.VERSION < (1, 6, 0), "not supported in this django version")
-    def test_select_specific_grandchildren(self):
-        children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])
-        if django.VERSION >= (1, 6, 0):
-            children.add(InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk))
-        self.assertEquals(
-            set(self.get_manager().select_subclasses(
-                    "inheritancemanagertestchild1__inheritancemanagertestgrandchild1")), children)
 
     def test_get_subclass(self):
         self.assertEquals(
             related=self.related)
         self.child2 = InheritanceManagerTestChild2.objects.create(
             related=self.related)
-        if django.VERSION >= (1, 6, 0):
-            self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
+        self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
 
 
     def get_manager(self):