Commits

Luke Plant  committed c34ff02

Changed the way that DateTimeFilter handled remove links to be more robust

  • Participants
  • Parent commits 495bd73

Comments (0)

Files changed (2)

File django_easyfilters/filters.py

 class DrillDownMixin(object):
 
     def get_choices_remove(self, qs):
-        # Due to drill down, if an earlier param is removed,
-        # the later params must be removed too.
+        # Due to drill down, if a broader param is removed, the more specific
+        # params must be removed too. We assume we can do an ordering on
+        # whatever 'choice' objects are in chosen, and 'greater' means 'more
+        # specific'.
         chosen = list(self.chosen)
         out = []
         for i, choice in enumerate(chosen):
+            to_remove = [c for c in chosen if c >= choice]
             out.append(FilterChoice(self.display_choice(choice),
                                     None,
-                                    self.build_params(remove=chosen[i:]),
+                                    self.build_params(remove=to_remove),
                                     FILTER_REMOVE))
         return out
 
 month_match = re.compile(r'^\d{4}-\d{2}$')
 day_match = re.compile(r'^\d{4}-\d{2}-\d{2}$')
 
+DateRangeType = namedtuple('DateRangeType', 'order label')
+YEAR  = DateRangeType(1, 'year')
+MONTH = DateRangeType(2, 'month')
+DAY   = DateRangeType(3, 'day')
+
 class DateChoice(object):
     """
     Represents a choice of date. Params are converted to this, and this is used
         # This is called when converting to URL
         return '..'.join(self.values)
 
+    def __repr__(self):
+        return '<DateChoice %s %s>' % (self.range_type, self.__unicode__())
+
+    def __cmp__(self, other):
+        return cmp((self.range_type, self.values),
+                   (other.range_type, other.values))
+
     def display(self):
         # Called for user presentable string
         if len(self.values) == 1:
             value = self.values[0]
             parts = value.split('-')
-            if self.range_type == 'year':
+            if self.range_type == YEAR:
                 return value
-            elif self.range_type == 'month':
+            elif self.range_type == MONTH:
                 month = date(int(parts[0]), int(parts[1]), 1)
                 return capfirst(formats.date_format(month, 'YEAR_MONTH_FORMAT'))
-            elif self.range_type == 'day':
+            elif self.range_type == DAY:
                 return str(int(parts[-1]))
         else:
             return u'-'.join([DateChoice(self.range_type,
 
     @staticmethod
     def datetime_to_value(range_type, dt):
-        if range_type == 'year':
+        if range_type == YEAR:
             return '%04d' % dt.year
-        elif range_type == 'month':
+        elif range_type == MONTH:
             return '%04d-%02d' % (dt.year, dt.month)
         else:
             return '%04d-%02d-%02d' % (dt.year, dt.month, dt.day)
     @staticmethod
     def range_type_from_param(param):
         if year_match.match(param):
-            return 'year'
+            return YEAR
         elif month_match.match(param):
-            return 'month'
+            return MONTH
         elif day_match.match(param):
-            return 'day'
+            return DAY
 
     @staticmethod
     def from_param(param):
             # yyyy-mm-dd
             # Need to look up last part, converted to int
             parts = val.split('-')
-            return {field_name + '__' + self.range_type: int(parts[-1])}
+            return {field_name + '__' + self.range_type.label : int(parts[-1])}
         else:
             # Should be just two values. First is lower bound, second is upper
             # bound. Need to convert to datetime objects.
             start_parts = map(int, self.values[0].split('-'))
             end_parts = map(int, self.values[1].split('-'))
-            if self.range_type == 'year':
+            if self.range_type == YEAR:
                 return {field_name + '__gte': date(start_parts[0], 1, 1),
                         field_name + '__lt': date(end_parts[0] + 1, 1, 1)}
             else:
                 return {}
 
-    def __eq__(self, other):
-        return (other is not None and
-                self.range_type == other.range_type and
-                self.values == other.values)
-
 
 class DateTimeFilter(MultiValueFilterMixin, DrillDownMixin, Filter):
 
             raise ValueError()
         return choice
 
+    def choices_from_params(self):
+        choices = super(DateTimeFilter, self).choices_from_params()
+        choices.sort()
+        return choices
+
     def lookup_from_choice(self, choice):
         return choice.make_lookup(self.field)
 
 
         if len(chosen) > 0:
             last = chosen[-1]
-            if last.range_type == 'year':
+            if last.range_type == YEAR:
                 if len(last.values) == 1:
                     # One year, drill down
-                    range_type = 'month'
+                    range_type = MONTH
                 else:
                     # Range, stay on year
-                    range_type = 'year'
-            elif last.range_type == 'month':
+                    range_type = YEAR
+            elif last.range_type == MONTH:
                 if len(last.values) == 1:
-                    range_type = 'day'
+                    range_type = DAY
                 else:
-                    range_type = 'month'
-            elif last.range_type == 'day':
+                    range_type = MONTH
+            elif last.range_type == DAY:
                 if len(last.values) == 1:
                     # Already down to one day, can't drill any further.
                     return []
                 else:
-                    range_type = 'day'
+                    range_type = DAY
 
         if range_type is None:
             # Get some initial idea of range
                 return []
             if first.year == last.year:
                 if first.month == last.month:
-                    range_type = 'day'
+                    range_type = DAY
                 else:
-                    range_type = 'month'
+                    range_type = MONTH
             else:
-                range_type = 'year'
+                range_type = YEAR
 
-        date_qs = qs.dates(self.field, range_type)
+        date_qs = qs.dates(self.field, range_type.label)
         results = date_aggregation(date_qs)
 
         if len(results) > self.max_links:

File django_easyfilters/tests/filterset.py

         self.assertEqual(len(choices), 0)
         self.assertEqual(len(qs_filtered), 0)
 
+    def test_datetime_filter_remove_broad(self):
+        """
+        If we remove a broader choice (e.g. year), the more specific choices
+        (e.g. day) should be removed too.
+        """
+        # This should hold whichever order the params are defined:
+        params1 = MultiValueDict({'date_published': ['1818',
+                                                     '1818-08',
+                                                     '1818-08-24']})
+        params2 = MultiValueDict({'date_published': ['1818-08-24',
+                                                     '1818-08',
+                                                     '1818']})
+        for p in [params1, params2]:
+            f = DateTimeFilter('date_published', Book, p)
+            qs = Book.objects.all()
+            qs_filtered = f.apply_filter(qs)
+            choices = f.get_choices(qs_filtered)
+            # First choice should be for '1818' and remove all 'date_published'
+            self.assertEqual(choices[0].label, '1818')
+            self.assertEqual(choices[0].link_type, FILTER_REMOVE)
+            self.assertEqual(choices[0].params.getlist('date_published'), [])
+            # Second choice should remove all but one 'date_published'
+            self.assertEqual(choices[1].link_type, FILTER_REMOVE)
+            self.assertEqual(choices[1].params.getlist('date_published'), ['1818'])
+
+            self.assertEqual(choices[2].link_type, FILTER_REMOVE)
+            self.assertEqual(choices[2].params.getlist('date_published'), ['1818',
+                                                                           '1818-08'])
+
     def test_order_by_count(self):
         """
         Tests the 'order_by_count' option.