Commits

Luke Plant  committed c223189

Implemented ManyToManyFilter

This involved lots of moving stuff around, and changing API of methods in
Filter and subclasses.

  • Participants
  • Parent commits 60e0755

Comments (0)

Files changed (4)

File django_easyfilters/filterset.py

         super(Filter, self).__init__(**kwargs)
 
     def apply_filter(self, qs, params):
-        p_val = params.get(self.query_param, None)
-        if p_val is None:
-            return qs
+        p_val = self.choices_from_params(params)
+        while len(p_val) > 0:
+            qs = qs.filter(**{self.field: p_val.pop()})
+        return qs
+
+    def choices_from_params(self, params):
+        """
+        For the params passed in (i.e. from query string), retrive a list of
+        already 'chosen' options.
+        """
+        raise NotImplementedError()
+
+    def param_from_choices(self, choice):
+        """
+        For a list of choices, return the parameter that should be created.
+        """
+        raise NotImplementedError()
+
+    def build_params(self, params, add=None, remove=None):
+        params = params.copy()
+        chosen = self.choices_from_params(params)
+        if remove:
+            chosen.remove(remove)
         else:
-            return qs.filter(**{self.field: p_val})
-
-    def build_params(self, qs, params, add=None, remove=False):
-        params = params.copy()
-        if remove:
+            if add not in chosen:
+                chosen.append(add)
+        if chosen:
+            params[self.query_param] = self.param_from_choices(chosen)
+        else:
             del params[self.query_param]
-        else:
-            params[self.query_param] = add
         params.pop('page', None) # links should reset paging
         return params
 
-    def get_values_counts(self, qs, params):
-        """
-        Returns a SortedDict dictionary of {value: count}
-        """
-        values_counts = qs.values_list(self.field).order_by(self.field).annotate(models.Count(self.field))
-
-        count_dict = SortedDict()
-        for val, count in values_counts:
-            count_dict[val] = count
-        return count_dict
-
     def sort_choices(self, qs, params, choices):
         """
         Sorts the choices by applying order_by_count if applicable.
         """
         return choices
 
-    def display_choice(self, qs, params, choice):
+    def display_choice(self, choice):
         retval = unicode(choice)
         if retval == u'':
             return u'(empty)'
         else:
             return retval
 
+    def get_choices(self, qs, params):
+        """
+        Returns a list of namedtuples containing (label (as a string), count,
+        params)
+        """
+        raise NotImplementedError()
+
+
+class SingleValueFilterMixin(object):
+
+    def choices_from_params(self, params):
+        if self.query_param in params:
+            return [params[self.query_param]]
+        else:
+            return []
+
+    def param_from_choices(self, choices):
+        # There can be only one
+        return unicode(choices[0])
+
+    def get_values_counts(self, qs, params):
+        """
+        Returns a SortedDict dictionary of {value: count}.
+
+        The order is the underlying order produced by sorting ascending on the
+        DB field.
+        """
+        values_counts = qs.values_list(self.field).order_by(self.field).annotate(models.Count(self.field))
+
+        count_dict = SortedDict()
+        for val, count in values_counts:
+            count_dict[val] = count
+        return count_dict
+
+    def get_choices(self, qs, params):
+        choices_remove = self.get_choices_remove(qs, params)
+        if len(choices_remove) > 0:
+            return choices_remove
+        else:
+            choices_add = self.get_choices_add(qs, params)
+            if len(choices_add) == 1:
+                # No point giving people a choice of one
+                return []
+            else:
+                return self.sort_choices(qs, params, choices_add)
+
+    def get_choices_add(self, qs, params):
+        raise NotImplementedError()
+
+    def get_choices_remove(self, qs, params):
+        choices = self.choices_from_params(params)
+        return [FilterChoice(self.display_choice(choice),
+                             None, # Don't need count for removing
+                             self.build_params(params, remove=choice),
+                             FILTER_REMOVE)
+                for choice in choices]
+
+
+class ValuesFilter(SingleValueFilterMixin, Filter):
+    """
+    Fallback Filter for various kinds of simple values.
+    """
     def get_choices_add(self, qs, params):
         """
         Called by 'get_choices', this is usually the one to override.
         """
         count_dict = self.get_values_counts(qs, params)
-        return [FilterChoice(self.display_choice(qs, params, val),
+        return [FilterChoice(self.display_choice(val),
                              count,
-                             self.build_params(qs, params, add=val),
+                             self.build_params(params, add=val),
                              FILTER_ADD)
                 for val, count in count_dict.items()]
 
-    def get_choices(self, qs, params):
-        raise NotImplementedError()
-
-    def choice_from_params(self, qs, params):
-        return params[self.query_param]
-
-    def get_choice_remove(self, qs, params):
-        choice = self.choice_from_params(qs, params)
-        return FilterChoice(self.display_choice(qs, params, choice),
-                            None, # Don't need count for removing
-                            self.build_params(qs, params, remove=True),
-                            FILTER_REMOVE)
-
-
-class SingleValueFilterMixin(object):
-
-    def get_choices(self, qs, params):
-        """
-        Returns a list of namedtuples containing
-        (label (as a string), count, url)
-        """
-        if self.query_param in params:
-            # Already filtered on this, we just display a remove link.
-            return [self.get_choice_remove(qs, params)]
-        else:
-            choices = self.get_choices_add(qs, params)
-
-        return self.sort_choices(qs, params, choices)
-
-
-class ValuesFilter(SingleValueFilterMixin, Filter):
-    """
-    Fallback Filter for various kinds of simple values.
-    """
-    pass
-
 
 class ChoicesFilter(ValuesFilter):
     """
         # For performance we cache this rather than build in
         self.choices_dict = dict(self.field_obj.flatchoices)
 
-    def display_choice(self, qs, params, choice):
+    def display_choice(self, choice):
         # 3) above
         return self.choices_dict.get(choice, choice)
 
             if val in count_dict:
                 # We could use the value 'display' here, but for consistency
                 # call display_choice() in case it is overriden.
-                choices.append(FilterChoice(self.display_choice(qs, params, val),
+                choices.append(FilterChoice(self.display_choice(val),
                                             count_dict[val],
-                                            self.build_params(qs, params, add=val),
+                                            self.build_params(params, add=val),
                                             FILTER_ADD))
         return choices
 
         self.rel_model = self.field_obj.rel.to
         self.rel_field = self.field_obj.rel.get_related_field()
 
-    def choice_from_params(self, qs, params):
-        lookup = {self.rel_field.attname: params[self.query_param]}
+    def display_choice(self, choice):
+        lookup = {self.rel_field.name: choice}
         return self.rel_model.objects.get(**lookup)
 
     def get_choices_add(self, qs, params):
         count_dict = self.get_values_counts(qs, params)
-        lookup = {self.rel_field.attname + '__in': count_dict.keys()}
+        lookup = {self.rel_field.name + '__in': count_dict.keys()}
         objs = self.rel_model.objects.filter(**lookup)
         choices = []
 
             pk = getattr(o, self.rel_field.attname)
             choices.append(FilterChoice(unicode(o),
                                         count_dict[pk],
-                                        self.build_params(qs, params, add=pk),
+                                        self.build_params(params, add=pk),
                                         FILTER_ADD))
         return choices
 
 
+class MultiValueFilterMixin(object):
+
+    def choices_from_params(self, params):
+        if self.query_param in params:
+            return map(int, params[self.query_param].split(','))
+        else:
+            return []
+
+    def param_from_choices(self, choices):
+        return ','.join(map(unicode, choices))
+
+    def get_choices(self, qs, params):
+        # In general, can filter multiple times, so we can have multiple remove
+        # links, and multiple add links, at the same time.
+        choices_remove = self.get_choices_remove(qs, params)
+        choices_add = self.get_choices_add(qs, params)
+        if len(choices_add) == 1:
+            # No point adding a filter of nothing
+            choices_add = []
+        else:
+            choices_add = self.sort_choices(qs, params, choices_add)
+        return choices_remove + choices_add
+
+
+class ManyToManyFilter(MultiValueFilterMixin, Filter):
+    def __init__(self, *args, **kwargs):
+        super(ManyToManyFilter, self).__init__(*args, **kwargs)
+        self.rel_model = self.field_obj.rel.to
+
+    def get_choices_add(self, qs, params):
+        # It is easiest to base queries around the intermediate table, in order
+        # to get counts.
+        through = self.field_obj.rel.through
+        rel_model = self.rel_model
+
+        assert rel_model != self.model, "Can't cope with this yet..."
+
+        fkey_to_this_table = [f for f in through._meta.fields
+                              if f.rel is not None and f.rel.to is self.model][0]
+        fkey_to_other_table = [f for f in through._meta.fields
+                               if f.rel is not None and f.rel.to is rel_model][0]
+
+        # We need to limit items by what is in the main QuerySet (which might
+        # already be filtered).
+        main_filter = {fkey_to_this_table.name + '__in':qs}
+        m2m_objs = through.objects.filter(**main_filter)
+
+        # We need to exclude items in other table that we have already filtered
+        # on, because they are not interesting.
+        exclude_filter = {fkey_to_other_table.name + '__in': self.choices_from_params(params)}
+        m2m_objs = m2m_objs.exclude(**exclude_filter)
+
+        # Now get counts:
+        field_name = fkey_to_other_table.name
+        values_counts = m2m_objs.values_list(field_name).order_by(field_name).annotate(models.Count(field_name))
+
+        count_dict = SortedDict()
+        for val, count in values_counts:
+            count_dict[val] = count
+
+        # Now, need to lookup objects on related table, to display them.
+        objs = rel_model.objects.filter(pk__in=count_dict.keys())
+
+        choices = []
+        for o in objs:
+            pk = o.pk
+            choices.append(FilterChoice(unicode(o),
+                                        count_dict[pk],
+                                        self.build_params(params, add=pk),
+                                        FILTER_ADD))
+        return choices
+
+
+    def get_choices_remove(self, qs, params):
+        choices = self.choices_from_params(params)
+        # Do a query in bulk to get objs corresponding to choices.
+        objs = self.rel_model.objects.filter(pk__in=choices)
+
+        # We want to preserve order of items in params, so use a dict:
+        obj_dict = dict([(obj.pk, obj) for obj in objs])
+
+        return [FilterChoice(unicode(obj_dict[choice]),
+                             None, # Don't need count for removing
+                             self.build_params(params, remove=choice),
+                             FILTER_REMOVE)
+                for choice in choices]
+
+
 def non_breaking_spaces(val):
     return u' '.join(escape(part) for part in val.split(u' '))
 
 class FilterSet(object):
 
     def __init__(self, queryset, params, request=None):
-        self.params = params
+        self.params = dict(params.items())
         self.initial_queryset = queryset
         self.model = queryset.model
         self.filters = self.setup_filters()
         return self.fields
 
     def get_filter_for_field(self, field, **kwargs):
-        f = self.model._meta.get_field(field)
+        f, model, direct, m2m = self.model._meta.get_field_by_name(field)
         if f.rel is not None:
-            return RelatedFilter(field, self.model, **kwargs)
+            if m2m:
+                klass = ManyToManyFilter
+            else:
+                klass = RelatedFilter
         elif f.choices:
-            return ChoicesFilter(field, self.model, **kwargs)
+            klass = ChoicesFilter
         else:
-            return ValuesFilter(field, self.model, **kwargs)
+            klass = ValuesFilter
+        return klass(field, self.model, **kwargs)
 
     def setup_filters(self):
         filters = []

File django_easyfilters/tests/filterset.py

+# -*- coding: utf-8; -*-
+
 import decimal
 import operator
 
 from django.test import TestCase
 from django_easyfilters.filterset import FilterSet, FilterOptions, FILTER_ADD, FILTER_REMOVE, \
-    RelatedFilter, ValuesFilter, ChoicesFilter
+    RelatedFilter, ValuesFilter, ChoicesFilter, ManyToManyFilter
 
-from models import Book, Genre, BINDING_CHOICES
+from models import Book, Genre, Author, BINDING_CHOICES
 
 
 class TestFilterSet(TestCase):
                 'genre',
                 'edition',
                 'binding',
+                'authors',
                 ]
 
         fs = BookFilterSet(Book.objects.all(), {})
         self.assertEqual(RelatedFilter, type(fs.filters[0]))
         self.assertEqual(ValuesFilter, type(fs.filters[1]))
         self.assertEqual(ChoicesFilter, type(fs.filters[2]))
+        self.assertEqual(ManyToManyFilter, type(fs.filters[3]))
 
 
 class TestFilters(TestCase):
                 self.assertEqual(unicode(book.genre), choice.label)
         self.assertTrue(reached)
 
-
     def test_foreignkey_remove_link(self):
         """
         Ensure that a ForeignKey Filter will turn into a 'remove' link when an
         for c in choices:
             self.assertTrue(c.params.values()[0] in binding_choices_db)
 
+    def test_manytomany_filter(self):
+        """
+        Tests for ManyToManyFilter
+        """
+        filter_ = ManyToManyFilter('authors', Book)
+        qs = Book.objects.all()
+
+        # ManyToMany can have 'drill down', i.e. multiple levels of filtering,
+        # which can be removed individually.
+
+        # First level:
+        choices = filter_.get_choices(qs, {})
+
+        # Check list is full, and in right order
+        self.assertEqual([unicode(v) for v in Author.objects.all()],
+                         [choice.label for choice in choices])
+
+        param_to_list = lambda param: map(int, param.split(','))
+
+        for choice in choices:
+            # For single choice, param will be single integer:
+            param = int(choice.params[filter_.query_param])
+
+            # Check the count
+            count = Book.objects.filter(authors=int(param)).count()
+            self.assertEqual(choice.count, count)
+
+            author = Author.objects.get(id=param)
+
+            # Check the label
+            self.assertEqual(unicode(author),
+                             choice.label)
+
+            # Check the filtering
+            qs_filtered = filter_.apply_filter(qs, choice.params)
+            self.assertEqual(len(qs_filtered), choice.count)
+
+            for book in qs_filtered:
+                self.assertTrue(author in book.authors.all())
+
+            # Check we've got a 'remove link' on filtered.
+            choices_filtered = filter_.get_choices(qs, choice.params)
+            self.assertEqual(choices_filtered[0].link_type, FILTER_REMOVE)
+
+
+    def test_manytomany_filter_multiple(self):
+        filter_ = ManyToManyFilter('authors', Book)
+        qs = Book.objects.all()
+
+        # Specific example - multiple filtering
+        emily = Author.objects.get(name='Emily Brontë')
+        charlotte = Author.objects.get(name='Charlotte Brontë')
+        anne = Author.objects.get(name='Anne Brontë')
+
+        # If we select 'emily' as an author:
+
+        data =  {'authors':str(emily.pk)}
+        qs_emily = filter_.apply_filter(qs, data)
+
+        # ...we should get a qs that includes Poems and Wuthering Heights.
+        self.assertTrue(qs_emily.filter(name='Poems').exists())
+        self.assertTrue(qs_emily.filter(name='Wuthering Heights').exists())
+        # ...and excludes Jane Eyre
+        self.assertFalse(qs_emily.filter(name='Jane Eyre').exists())
+
+        # We should get a 'choices' that includes charlotte and anne
+        choices = filter_.get_choices(qs_emily, data)
+        self.assertTrue(unicode(anne) in [c.label for c in choices if c.link_type is FILTER_ADD])
+        self.assertTrue(unicode(charlotte) in [c.label for c in choices if c.link_type is FILTER_ADD])
+
+        # ... but not emily, because that is obvious and boring
+        self.assertTrue(unicode(emily) not in [c.label for c in choices if c.link_type is FILTER_ADD])
+        # emily should be in 'remove' links, however.
+        self.assertTrue(unicode(emily) in [c.label for c in choices if c.link_type is FILTER_REMOVE])
+
+        # If we select again:
+        data =  {'authors': ','.join([str(emily.pk), str(anne.pk)])}
+
+        qs_emily_anne = filter_.apply_filter(qs, data)
+
+        # ...we should get a qs that includes Poems
+        self.assertTrue(qs_emily_anne.filter(name='Poems').exists())
+        # ... but not Wuthering Heights
+        self.assertFalse(qs_emily_anne.filter(name='Wuthering Heights').exists())
+
+        # The choices should contain just emily and anne, to remove, but not
+        # charlotte to add, because there is no point adding a filter
+        # when it is the only choice.
+        choices = filter_.get_choices(qs_emily_anne, data)
+        self.assertEqual([(c.label, c.link_type) for c in choices],
+                         [(unicode(emily), FILTER_REMOVE),
+                          (unicode(anne), FILTER_REMOVE)])
+
     def test_order_by_count(self):
         """
         Tests the 'order_by_count' option.

File django_easyfilters/tests/fixtures/django_easyfilters_tests.json

 [
   {
+    "pk": 4, 
+    "model": "tests.author", 
+    "fields": {
+      "name": "A. A. Milne"
+    }
+  }, 
+  {
+    "pk": 6, 
+    "model": "tests.author", 
+    "fields": {
+      "name": "Anne Bront\u00eb"
+    }
+  }, 
+  {
+    "pk": 2, 
+    "model": "tests.author", 
+    "fields": {
+      "name": "Charlotte Bront\u00eb"
+    }
+  }, 
+  {
+    "pk": 5, 
+    "model": "tests.author", 
+    "fields": {
+      "name": "E. H. Shepard"
+    }
+  }, 
+  {
+    "pk": 7, 
+    "model": "tests.author", 
+    "fields": {
+      "name": "Emily Bront\u00eb"
+    }
+  }, 
+  {
     "pk": 1, 
     "model": "tests.author", 
     "fields": {
     }
   }, 
   {
-    "pk": 2, 
-    "model": "tests.author", 
-    "fields": {
-      "name": "Charlotte Bronte"
-    }
-  }, 
-  {
     "pk": 3, 
     "model": "tests.author", 
     "fields": {
     }
   }, 
   {
-    "pk": 4, 
-    "model": "tests.author", 
-    "fields": {
-      "name": "A. A. Milne"
-    }
-  }, 
-  {
-    "pk": 5, 
-    "model": "tests.author", 
-    "fields": {
-      "name": "E. H. Shepard"
-    }
-  }, 
-  {
     "pk": 3, 
     "model": "tests.genre", 
     "fields": {
     }
   }, 
   {
+    "pk": 4, 
+    "model": "tests.genre", 
+    "fields": {
+      "name": "Poetry"
+    }
+  }, 
+  {
     "pk": 1, 
     "model": "tests.genre", 
     "fields": {
       ], 
       "genre": 3
     }
+  }, 
+  {
+    "pk": 6, 
+    "model": "tests.book", 
+    "fields": {
+      "name": "Poems", 
+      "edition": 1, 
+      "price": "4.5", 
+      "binding": "H", 
+      "date_published": "1846-05-01", 
+      "authors": [
+        6, 
+        2, 
+        7
+      ], 
+      "genre": 4
+    }
+  }, 
+  {
+    "pk": 7, 
+    "model": "tests.book", 
+    "fields": {
+      "name": "Wuthering Heights", 
+      "edition": 1, 
+      "price": "44.99", 
+      "binding": "C", 
+      "date_published": "1847-06-22", 
+      "authors": [
+        7
+      ], 
+      "genre": 1
+    }
   }
 ]

File django_easyfilters/tests/models.py

     def __unicode__(self):
         return self.name
 
+    class Meta:
+        ordering = ['name']
 
 class Genre(models.Model):
     name = models.CharField(max_length=50)