Commits

Luke Plant committed 117e5d0

Small code cleanups.

Comments (0)

Files changed (1)

django_easyfilters/filters.py

 
 class ManyToManyFilter(ChooseAgainMixin, RelatedObjectMixin, Filter):
 
-    def get_choices_add(self, qs):
+    def get_values_counts(self, qs):
         # It is easiest to base queries around the intermediate table, in order
         # to get counts.
         through = self.field_obj.rel.through
         rel_model = self.rel_model
 
         assert rel_model != self.model, "Can't cope with this yet..."
-
-        fkey_to_this_table = [f for f in through._meta.fields
-                              if f.rel is not None and f.rel.to is self.model][0]
-        fkey_to_other_table = [f for f in through._meta.fields
-                               if f.rel is not None and f.rel.to is rel_model][0]
+        fkey_this = [f for f in through._meta.fields
+                     if f.rel is not None and f.rel.to is self.model][0]
+        fkey_other = [f for f in through._meta.fields
+                      if f.rel is not None and f.rel.to is rel_model][0]
 
         # We need to limit items by what is in the main QuerySet (which might
         # already be filtered).
-        main_filter = {fkey_to_this_table.name + '__in':qs}
-        m2m_objs = through.objects.filter(**main_filter)
+        m2m_objs = through.objects.filter(**{fkey_this.name + '__in':qs})
 
         # We need to exclude items in other table that we have already filtered
         # on, because they are not interesting.
-        exclude_filter = {fkey_to_other_table.name + '__in': self.chosen}
-        m2m_objs = m2m_objs.exclude(**exclude_filter)
+        m2m_objs = m2m_objs.exclude(**{fkey_other.name + '__in': self.chosen})
 
         # Now get counts:
-        field_name = fkey_to_other_table.name
+        field_name = fkey_other.name
         values_counts = m2m_objs.values_list(field_name).order_by(field_name).annotate(models.Count(field_name))
 
         count_dict = SortedDict()
         for val, count in values_counts:
             count_dict[val] = count
 
+        return count_dict
+
+    def get_choices_add(self, qs):
+        count_dict = self.get_values_counts(qs)
         # Now, need to lookup objects on related table, to display them.
-        objs = rel_model.objects.filter(pk__in=count_dict.keys())
+        objs = self.rel_model.objects.filter(pk__in=count_dict.keys())
 
-        choices = []
-        for o in objs:
-            pk = o.pk
-            choices.append(FilterChoice(unicode(o),
-                                        count_dict[pk],
-                                        self.build_params(add=pk),
-                                        FILTER_ADD))
-        return choices
+        return [FilterChoice(unicode(o),
+                             count_dict[o.pk],
+                             self.build_params(add=o.pk),
+                             FILTER_ADD)
+                for o in objs]
 
     def get_choices_remove(self, qs):
         chosen = self.chosen
         # Do a query in bulk to get objs corresponding to choices.
         objs = self.rel_model.objects.filter(pk__in=chosen)
 
-        # We want to preserve order of items in params, so use a dict:
+        # We want to preserve order of items in params, so use the original
+        # 'chosen' list, rather than objs.
         obj_dict = dict([(obj.pk, obj) for obj in objs])
         return [FilterChoice(unicode(obj_dict[choice]),
                              None, # Don't need count for removing
             m = drt.regex.match(param)
             if m is not None:
                 return DateChoice(drt, list(m.groups()))
+        raise ValueError()
 
     def make_lookup(self, field_name):
         # It's easier to do this all using datetime comparisons than have a
         super(DateTimeFilter, self).__init__(*args, **kwargs)
 
     def choice_from_param(self, param):
-        choice = DateChoice.from_param(param)
-        if choice is None:
-            raise ValueError()
-        return choice
+        return DateChoice.from_param(param)
 
     def choices_from_params(self):
         choices = super(DateTimeFilter, self).choices_from_params()