Commits

Carl Meyer committed 1c9195f

Added use_for_related_fields=True to InheritanceManager. Fixes #8. Thanks munhitsu for the report.

Note that the tests added for this feature pass even without the change to
InheritanceManager, because use_for_related_fields is broken in Django and
always acts as if True for reverse FKs and M2Ms.
(http://code.djangoproject.com/ticket/14891)

Comments (0)

Files changed (3)

model_utils/managers.py

                 yield obj
 
 class InheritanceManager(models.Manager):
+    use_for_related_fields = True
+
     def get_query_set(self):
         return InheritanceQuerySet(self.model)
 

model_utils/tests/models.py

 from model_utils.fields import SplitField, MonitorField
 from model_utils import Choices
 
+
+
 class InheritParent(InheritanceCastModel):
-    non_related_field_using_descriptor = models.FileField(upload_to='test')
+    non_related_field_using_descriptor = models.FileField(upload_to="test")
     normal_field = models.TextField()
     pass
 
+
+
 class InheritChild(InheritParent):
-    non_related_field_using_descriptor_2 = models.FileField(upload_to='test')
+    non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
     normal_field_2 = models.TextField()
     pass
 
+
+
 class InheritChild2(InheritParent):
-    non_related_field_using_descriptor_3 = models.FileField(upload_to='test')
+    non_related_field_using_descriptor_3 = models.FileField(upload_to="test")
     normal_field_3 = models.TextField()
     pass
 
+
+
+class InheritanceManagerTestRelated(models.Model):
+    pass
+
+
+
 class InheritanceManagerTestParent(models.Model):
-    # test for #6
-    # I'm using FileField, because it will always use descriptor
-    non_related_field_using_descriptor = models.FileField(upload_to='test')
+    # FileField is just a handy descriptor-using field. Refs #6.
+    non_related_field_using_descriptor = models.FileField(upload_to="test")
+    related = models.ForeignKey(
+        InheritanceManagerTestRelated, related_name="imtests", null=True)
     normal_field = models.TextField()
     objects = InheritanceManager()
 
+
+
 class InheritanceManagerTestChild1(InheritanceManagerTestParent):
-    non_related_field_using_descriptor_2 = models.FileField(upload_to='test')
+    non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
     normal_field_2 = models.TextField()
     pass
 
+
+
 class InheritanceManagerTestChild2(InheritanceManagerTestParent):
-    non_related_field_using_descriptor_2 = models.FileField(upload_to='test')
+    non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
     normal_field_2 = models.TextField()
     pass
 
+
+
 class TimeStamp(TimeStampedModel):
     pass
 
+
+
 class TimeFrame(TimeFramedModel):
     pass
 
+
+
 class TimeFrameManagerAdded(TimeFramedModel):
     pass
 
+
+
 class Monitored(models.Model):
     name = models.CharField(max_length=25)
-    name_changed = MonitorField(monitor='name')
+    name_changed = MonitorField(monitor="name")
+
+
 
 class Status(StatusModel):
     STATUS = Choices(
-        ('active', _('active')),
-        ('deleted', _('deleted')),
-        ('on_hold', _('on hold')),
+        ("active", _("active")),
+        ("deleted", _("deleted")),
+        ("on_hold", _("on hold")),
     )
 
+
+
 class StatusPlainTuple(StatusModel):
     STATUS = (
-        ('active', _('active')),
-        ('deleted', _('deleted')),
-        ('on_hold', _('on hold')),
+        ("active", _("active")),
+        ("deleted", _("deleted")),
+        ("on_hold", _("on hold")),
     )
 
+
+
 class StatusManagerAdded(StatusModel):
     STATUS = (
-        ('active', _('active')),
-        ('deleted', _('deleted')),
-        ('on_hold', _('on hold')),
+        ("active", _("active")),
+        ("deleted", _("deleted")),
+        ("on_hold", _("on hold")),
     )
 
+
+
 class Post(models.Model):
     published = models.BooleanField()
     confirmed = models.BooleanField()
     public = QueryManager(published=True)
     public_confirmed = QueryManager(models.Q(published=True) &
                                     models.Q(confirmed=True))
-    public_reversed = QueryManager(published=True).order_by('-order')
+    public_reversed = QueryManager(published=True).order_by("-order")
 
     class Meta:
-        ordering = ('order',)
+        ordering = ("order",)
+
+
 
 class Article(models.Model):
     title = models.CharField(max_length=50)
     body = SplitField()
 
+
+
 class NoRendered(models.Model):
     """
     Test that the no_excerpt_field keyword arg works. This arg should
     """
     body = SplitField(no_excerpt_field=True)
 
+
+
 class AuthorMixin(object):
     def by_author(self, name):
         return self.filter(author=name)
 
+
+
 class PublishedMixin(object):
     def published(self):
         return self.filter(published=True)
 
+
+
 def unpublished(self):
     return self.filter(published=False)
 
+
+
 class ByAuthorQuerySet(models.query.QuerySet, AuthorMixin):
     pass
 
+
+
 class FeaturedManager(models.Manager):
     def get_query_set(self):
         kwargs = {}
-        if hasattr(self, '_db'):
-            kwargs['using'] = self._db
+        if hasattr(self, "_db"):
+            kwargs["using"] = self._db
         return ByAuthorQuerySet(self.model, **kwargs).filter(feature=True)
 
+
+
 class Entry(models.Model):
     author = models.CharField(max_length=20)
     published = models.BooleanField()
                             manager_cls=FeaturedManager,
                             queryset_cls=ByAuthorQuerySet)
 
+
+
 class DudeQuerySet(models.query.QuerySet):
     def abiding(self):
         return self.filter(abides=True)
     def by_name(self, name):
         return self.filter(name__iexact=name)
 
+
+
 class AbidingManager(PassThroughManager):
     def get_query_set(self):
         return DudeQuerySet(self.model).abiding()
 
     def get_stats(self):
         return {
-            'abiding_count': self.count(),
-            'rug_count': self.rug_positive().count(),
+            "abiding_count": self.count(),
+            "rug_count": self.rug_positive().count(),
         }
 
+
+
 class Dude(models.Model):
     abides = models.BooleanField(default=True)
     name = models.CharField(max_length=20)

model_utils/tests/tests.py

+from __future__ import with_statement
+
 import pickle, sys, warnings
 
 from datetime import datetime, timedelta
 
 import django
 from django.test import TestCase
-from django.conf import settings
 from django.db import models
 from django.db.models.fields import FieldDoesNotExist
 from django.core.exceptions import ImproperlyConfigured
 from model_utils.managers import QueryManager, manager_from
 from model_utils.models import StatusModel, TimeFramedModel
 from model_utils.tests.models import (
-    InheritParent, InheritChild, InheritChild2, InheritanceManagerTestParent,
-    InheritanceManagerTestChild1, InheritanceManagerTestChild2,
-    TimeStamp, Post, Article, Status, StatusPlainTuple, TimeFrame, Monitored,
-    StatusManagerAdded, TimeFrameManagerAdded, Entry, Dude)
+    InheritParent, InheritChild, InheritChild2, InheritanceManagerTestRelated,
+    InheritanceManagerTestParent, InheritanceManagerTestChild1,
+    InheritanceManagerTestChild2, TimeStamp, Post, Article, Status,
+    StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
+    TimeFrameManagerAdded, Entry, Dude)
 
 
 
             self.child2 = InheritanceManagerTestChild2.objects.create()
 
 
+        def get_manager(self):
+            return InheritanceManagerTestParent.objects
+
+
         def test_normal(self):
-            self.assertEquals(set(InheritanceManagerTestParent.objects.all()),
+            self.assertEquals(set(self.get_manager().all()),
                               set([
                         InheritanceManagerTestParent(pk=self.child1.pk),
                         InheritanceManagerTestParent(pk=self.child2.pk),
 
         def test_select_all_subclasses(self):
             self.assertEquals(
-                set(InheritanceManagerTestParent.objects.select_subclasses()),
+                set(self.get_manager().select_subclasses()),
                 set([self.child1, self.child2]))
 
 
         def test_select_specific_subclasses(self):
             self.assertEquals(
-                set(InheritanceManagerTestParent.objects.select_subclasses(
+                set(self.get_manager().select_subclasses(
                         "inheritancemanagertestchild1")),
                 set([self.child1,
                      InheritanceManagerTestParent(pk=self.child2.pk)]))
 
 
+    class InheritanceManagerRelatedTests(InheritanceManagerTests):
+        def setUp(self):
+            self.related = InheritanceManagerTestRelated.objects.create()
+            self.child1 = InheritanceManagerTestChild1.objects.create(
+                related=self.related)
+            self.child2 = InheritanceManagerTestChild2.objects.create(
+                related=self.related)
+
+
+        def get_manager(self):
+            return self.related.imtests
+
+
 
 class TimeStampedModelTests(TestCase):
     def test_created(self):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.