Commits

Carl Meyer committed dccfb7b

Added InheritanceManager, contributed by Jeff Elmore.

Comments (0)

Files changed (7)

 Carl Meyer <carl@dirtcircle.com>
 Jannis Leidel <jannis@leidel.info>
 Gregor Müllegger <gregor@muellegger.de>
+Jeff Elmore <jeffelmore.org>
 tip (unreleased)
 ----------------
 
+- added InheritanceManager, a better approach to selecting subclass instances
+  for Django 1.2+. Thanks Jeff Elmore.
+
 - added InheritanceCastManager and InheritanceCastQuerySet, to allow bulk
   casting of a queryset to child types.  Thanks Gregor Müllegger.
 
 Dependencies
 ------------
 
-``django-model-utils`` requires `Django`_ 1.0 or later.
+Most of ``django-model-utils`` works with `Django`_ 1.0 or later.
+`InheritanceManager`_ requires Django 1.2 or later.
 
 .. _Django: http://www.djangoproject.com/
 
     # this query will only return published articles:
     Article.published.all()
 
-InheritanceCastModel
-====================
+InheritanceManager
+==================
 
-This abstract base class can be inherited by the root (parent) model
-in a model-inheritance tree.  It allows each model in the tree to
-"know" what type it is (via an automatically-set foreign key to
-``ContentType``), allowing for automatic casting of a parent instance
-to its proper leaf (child) type.
+This manager (`contributed by Jeff Elmore`_) should be attached to a base model
+class in a model-inheritance tree.  It allows queries on that base model to
+return heterogenous results of the actual proper subtypes, without any
+additional queries.
 
-For instance, if you have a ``Place`` model with subclasses
-``Restaurant`` and ``Bar``, you may want to query all Places::
+For instance, if you have a ``Place`` model with subclasses ``Restaurant`` and
+``Bar``, you may want to query all Places::
 
     nearby_places = Place.objects.filter(location='here')
 
 But when you iterate over ``nearby_places``, you'll get only ``Place``
-instances back, even for objects that are "really" ``Restaurant`` or
-``Bar``.  If you have ``Place`` inherit from ``InheritanceCastModel``,
-you can just call the ``cast()`` method on each ``Place`` and it will
-return an instance of the proper subtype, ``Restaurant`` or ``Bar``::
+instances back, even for objects that are "really" ``Restaurant`` or ``Bar``.
+If you attach an ``InheritanceManager`` to ``Place``, you can just call the
+``select_subclasses()`` method on the ``InheritanceManager`` or any
+``QuerySet`` from it, and the resulting objects will be instances of
+``Restaurant`` or ``Bar``::
+
+    from model_utils.managers import InheritanceManager
+
+    class Place(models.Model):
+        # ...
+        objects = InheritanceManager()
+
+    class Restaurant(Place):
+        # ...
+
+    class Bar(Place):
+        # ...
+
+    nearby_places = Place.objects.filter(location='here').select_subclasses()
+    for place in nearby_places:
+        # "place" will automatically be an instance of Place, Restaurant, or Bar
+
+The database query performed will have an extra join for each subclass; if you
+want to reduce the number of joins and you only need particular subclasses to
+be returned as their actual type, you can pass subclass names to
+``select_subclasses()``, much like the built-in ``select_related()`` method::
+
+    nearby_places = Place.objects.select_subclasses("restaurant")
+    # restaurants will be Restaurant instances, bars will still be Place instances
+
+If you don't explicitly call ``select_subclasses()``, an ``InheritanceManager``
+behaves identically to a normal ``Manager``; so it's safe to use as your
+default manager for the model.
+
+.. note::
+    ``InheritanceManager`` currently only supports a single level of model
+    inheritance; it won't work for grandchild models.
+
+.. note::
+    ``InheritanceManager`` requires Django 1.2 or later.
+
+.. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/
+
+
+InheritanceCastModel
+====================
+
+This abstract base class can be inherited by the root (parent) model in a
+model-inheritance tree. It solves the same problem as `InheritanceManager`_ in
+a way that requires more database queries and is less convenient to use, but is
+compatible with Django versions prior to 1.2. Whenever possible,
+`InheritanceManager`_ should be used instead.
+
+Usage::
 
     from model_utils.models import InheritanceCastModel
 
     class Place(InheritanceCastModel):
         # ...
-    
+
     class Restaurant(Place):
         # ...
 
+    class Bar(Place):
+        # ...
+
     nearby_places = Place.objects.filter(location='here')
     for place in nearby_places:
-        restaurant_or_bar = place.cast()
-        # ...
+        restaurant_or_bar = place.cast() # ...
 
 This is inefficient for large querysets, as it results in a new query for every
 individual returned object.  You can use the ``cast()`` method on a queryset to
     for place in nearby_places.cast():
         # ...
 
-.. note:: The ``cast()`` queryset method does *not* return another
-    queryset but an already evaluated result of the database query.  This means
-    that you cannot chain additional queryset methods after ``cast()``.
+.. note::
+    The ``cast()`` queryset method does *not* return another queryset but an
+    already evaluated result of the database query.  This means that you cannot
+    chain additional queryset methods after ``cast()``.
 
 TimeStampedModel
 ================
 
 This abstract base class just provides self-updating ``created`` and
-``modified`` fields on any model that inherits from it.        
-  
+``modified`` fields on any model that inherits from it.
 
 QueryManager
 ============
 TODO
 ====
 
-* A version of InheritanceCastModel for 1.2+ (with reverse OneToOne
-  select_related now available) that doesn't require the added real_type
-  field.
+* Add support for multiple levels of inheritance to ``InheritanceManager``.

model_utils/managers.py

 
 from django.contrib.contenttypes.models import ContentType
 from django.db import models
+from django.db.models.fields.related import SingleRelatedObjectDescriptor
 from django.db.models.manager import Manager
 from django.db.models.query import QuerySet
 
+class InheritanceQuerySet(QuerySet):
+    def select_subclasses(self, *subclasses):
+        if not subclasses:
+            subclasses = [o for o in dir(self.model)
+                          if isinstance(getattr(self.model, o), SingleRelatedObjectDescriptor)
+                          and issubclass(getattr(self.model,o).related.model, self.model)]
+        new_qs = self.select_related(*subclasses)
+        new_qs.subclasses = subclasses
+        return new_qs
+
+    def _clone(self, klass=None, setup=False, **kwargs):
+        try:
+            kwargs.update({'subclasses': self.subclasses})
+        except AttributeError:
+            pass
+        return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs)
+
+    def iterator(self):
+        iter = super(InheritanceQuerySet, self).iterator()
+        if getattr(self, 'subclasses', False):
+            for obj in iter:
+                obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj]
+                yield obj[0]
+        else:
+            for obj in iter:
+                yield obj
+
+class InheritanceManager(models.Manager):
+    def get_query_set(self):
+        return InheritanceQuerySet(self.model)
+
+    def select_subclasses(self, *subclasses):
+        return self.get_query_set().select_subclasses(*subclasses)
+
 
 class InheritanceCastMixin(object):
     def cast(self):

model_utils/tests/models.py

 from django.utils.translation import ugettext_lazy as _
 
 from model_utils.models import InheritanceCastModel, TimeStampedModel, StatusModel, TimeFramedModel
-from model_utils.managers import QueryManager, manager_from
+from model_utils.managers import QueryManager, manager_from, InheritanceManager
 from model_utils.fields import SplitField, MonitorField
 from model_utils import Choices
 
 class InheritChild2(InheritParent):
     pass
 
+class InheritanceManagerTestParent(models.Model):
+    objects = InheritanceManager()
+
+class InheritanceManagerTestChild1(InheritanceManagerTestParent):
+    pass
+
+class InheritanceManagerTestChild2(InheritanceManagerTestParent):
+    pass
+
 class TimeStamp(TimeStampedModel):
     pass
 

model_utils/tests/tests.py

 from model_utils.fields import get_excerpt, MonitorField
 from model_utils.managers import QueryManager, manager_from
 from model_utils.models import StatusModel, TimeFramedModel
-from model_utils.tests.models import (InheritParent, InheritChild, InheritChild2,
-                                      TimeStamp, Post, Article, Status,
-                                      StatusPlainTuple, TimeFrame, Monitored,
-                                      StatusManagerAdded, TimeFrameManagerAdded,
-                                      Entry)
+from model_utils.tests.models import (
+    InheritParent, InheritChild, InheritChild2, InheritanceManagerTestParent,
+    InheritanceManagerTestChild1, InheritanceManagerTestChild2,
+    TimeStamp, Post, Article, Status, StatusPlainTuple, TimeFrame, Monitored,
+    StatusManagerAdded, TimeFrameManagerAdded,  Entry)
 
 
 class GetExcerptTests(TestCase):
     def test_middle_of_line(self):
         e = get_excerpt("some text <!-- split --> more text")
         self.assertEquals(e, "some text <!-- split --> more text")
-    
+
 class SplitFieldTests(TestCase):
     full_text = u'summary\n\n<!-- split -->\n\nmore'
     excerpt = u'summary\n'
-    
+
     def setUp(self):
         self.post = Article.objects.create(
             title='example post', body=self.full_text)
         post = Article.objects.create(title='example 2',
                                       body='some text\n\nsome more\n')
         self.failIf(post.body.has_more)
-        
+
     def test_load_back(self):
         post = Article.objects.get(pk=self.post.pk)
         self.assertEquals(post.body.content, self.post.body.content)
     def test_no_monitor_arg(self):
         self.assertRaises(TypeError, MonitorField)
 
-        
+
 class ChoicesTests(TestCase):
     def setUp(self):
         self.STATUS = Choices('DRAFT', 'PUBLISHED')
     def test_wrong_length_tuple(self):
         self.assertRaises(ValueError, Choices, ('a',))
 
-        
+
 class LabelChoicesTests(ChoicesTests):
     def setUp(self):
         self.STATUS = Choices(
                           "('PUBLISHED', 'PUBLISHED', 'is published'), "
                           "('DELETED', 'DELETED', 'DELETED'))")
 
-        
+
 class IdentifierChoicesTests(ChoicesTests):
     def setUp(self):
         self.STATUS = Choices(
 
     def test_getattr(self):
         self.assertEquals(self.STATUS.DRAFT, 0)
-        
+
     def test_repr(self):
         self.assertEquals(repr(self.STATUS),
                           "Choices("
                           "(1, 'PUBLISHED', 'is published'), "
                           "(2, 'DELETED', 'is deleted'))")
 
-        
+
 class InheritanceCastModelTests(TestCase):
     def setUp(self):
         self.parent = InheritParent.objects.create()
         self.child = InheritChild.objects.create()
-    
+
     def test_parent_real_type(self):
         self.assertEquals(self.parent.real_type,
                           ContentType.objects.get_for_model(InheritParent))
                           set([parent, self.child, self.child2]))
 
 
+class InheritanceManagerTests(TestCase):
+    def setUp(self):
+        self.child1 = InheritanceManagerTestChild1.objects.create()
+        self.child2 = InheritanceManagerTestChild2.objects.create()
+
+    def test_normal(self):
+        self.assertEquals(set(InheritanceManagerTestParent.objects.all()),
+                          set([
+                    InheritanceManagerTestParent(pk=self.child1.pk),
+                    InheritanceManagerTestParent(pk=self.child2.pk),
+                    ]))
+
+    def test_select_all_subclasses(self):
+        self.assertEquals(set(InheritanceManagerTestParent.objects.select_subclasses()),
+                          set([self.child1, self.child2]))
+
+    def test_select_specific_subclasses(self):
+        self.assertEquals(set(InheritanceManagerTestParent.objects.select_subclasses(
+                    "inheritancemanagertestchild1")),
+                          set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)]))
+
+
 class TimeStampedModelTests(TestCase):
     def test_created(self):
         t1 = TimeStamp.objects.create()
 
     def setUp(self):
         self.now = datetime.now()
-    
+
     def test_not_yet_begun(self):
         TimeFrame.objects.create(start=self.now+timedelta(days=2))
         self.assertEquals(TimeFrame.timeframed.count(), 0)
             class ErrorModel(TimeFramedModel):
                 timeframed = models.BooleanField()
         self.assertRaises(ImproperlyConfigured, _run)
-        
-                
+
+
 class StatusModelTests(TestCase):
     def setUp(self):
         self.model = Status
         t1.save()
         self.assert_(t1.status_changed > date_active_again)
 
-        
+
 class StatusModelPlainTupleTests(StatusModelTests):
     def setUp(self):
         self.model = StatusPlainTuple
                     )
                 active = models.BooleanField()
         self.assertRaises(ImproperlyConfigured, _run)
-                
+
 
 class QueryManagerTests(TestCase):
     def setUp(self):
             mf = Article._meta.get_field('body')
             args, kwargs = introspector(mf)
             self.assertEquals(kwargs['no_excerpt_field'], 'True')
-        
+
         def test_no_excerpt_field_works(self):
             from models import NoRendered
             self.assertRaises(FieldDoesNotExist,
         Entry.objects.create(author='George', published=True)
         Entry.objects.create(author='George', published=False)
         Entry.objects.create(author='Paul', published=True, feature=True)
-    
+
     def test_chaining(self):
         self.assertEqual(Entry.objects.by_author('George').published().count(),
                          1)
 
     def test_typecheck(self):
         self.assertRaises(TypeError, manager_from, 'somestring')
-        
+
     def test_custom_get_query_set(self):
         self.assertEqual(Entry.featured.published().count(), 1)