Commits

Luke Plant committed f0f72eb

Beginning of DateTimeFilter

Involved a fair amount of API change and code movement in the rest of Filter

  • Participants
  • Parent commits 72f74df

Comments (0)

Files changed (6)

django_easyfilters/filters.py

 from collections import namedtuple
+from datetime import date
 import operator
+import re
 
+from django.core.exceptions import ValidationError
 from django.db import models
+from django.utils import formats
 from django.utils.datastructures import SortedDict
+from django.utils.text import capfirst
+from django_easyfilters.queries import date_aggregation
 
 FILTER_ADD = 'add'
 FILTER_REMOVE = 'remove'
 FilterChoice = namedtuple('FilterChoice', 'label count params link_type')
 
 
+# TODO:
+#
+# - change API of Filter, so that it has an explicit, required 'set_params'
+#   method. This will eliminate the overhead of calling 'choices_from_params'
+#   multiple times per request, because after params have been set, we can cache
+#   the choices. It will also mean we don't pass params around everywhere.
+
+
 class FilterOptions(object):
     """
     Defines some common options for all Filters.
     a FilterSet. The actual choice of Filter subclass will be done by the
     FilterSet in this case.
     """
-    def __init__(self, query_param=None, order_by_count=False):
+    def __init__(self, query_param=None, order_by_count=False,
+                 max_links=10):
         self.query_param = query_param
         self.order_by_count = order_by_count
+        self.max_links = max_links
 
 
 class Filter(FilterOptions):
         Apply the filtering defined in params (request.GET) to the queryset qs,
         returning the new QuerySet.
         """
-        p_val = self.choices_from_params(params)
-        while len(p_val) > 0:
-            qs = qs.filter(**{self.field: p_val.pop()})
+        p_vals = self.choices_from_params(params)
+        while len(p_vals) > 0:
+            lookup = self.lookup_from_choice(p_vals.pop())
+            qs = qs.filter(**lookup)
         return qs
 
     def get_choices(self, qs, params):
 
     ### Methods that are used by base implementation above ###
 
-    def to_python(self, param):
-        return self.field_obj.to_python(param)
+    def choices_from_params(self, params):
+        out = []
+        for p in params.getlist(self.query_param):
+            try:
+                choice = self.choice_from_param(p)
+                out.append(choice)
+            except ValueError:
+                pass
+        return out
 
-    def choices_from_params(self, params):
+    def choice_from_param(self, param):
         """
-        For the params passed in (i.e. from query string), retrive a list of
-        already 'chosen' options.
+        Returns a native Python object representing something that has been
+        chosen for a filter, converted from the string value in param.
         """
-        return [self.to_python(i) for i in params.getlist(self.query_param)]
+        try:
+            return self.field_obj.to_python(param)
+        except ValidationError:
+            raise ValueError()
+
+    def lookup_from_choice(self, choice):
+        """
+        Converts a choice value to a lookup dictionary that can be passed to
+        QuerySet.filter() to do the filtering for that choice.
+        """
+        return {self.field: choice}
 
     ### Utility methods needed by most/all subclasses ###
 
         return map(unicode, choices)
 
     def build_params(self, params, add=None, remove=None):
+        """
+        Builds a new parameter MultiDict.
+        add is an optional item to add,
+        remove is an optional list of items to remove.
+        """
         params = params.copy()
         chosen = self.choices_from_params(params)
         if remove is not None:
-            chosen.remove(remove)
+            for r in remove:
+                chosen.remove(r)
         else:
             if add not in chosen:
                 chosen.append(add)
         choices = self.choices_from_params(params)
         return [FilterChoice(self.display_choice(choice),
                              None, # Don't need count for removing
-                             self.build_params(params, remove=choice),
+                             self.build_params(params, remove=[choice]),
                              FILTER_REMOVE)
                 for choice in choices]
 
         super(ManyToManyFilter, self).__init__(*args, **kwargs)
         self.rel_model = self.field_obj.rel.to
 
-    def to_python(self, param):
-        return self.field_obj.rel.get_related_field().to_python(param)
+    def choice_from_param(self, param):
+        try:
+            return self.field_obj.rel.get_related_field().to_python(param)
+        except ValidationError:
+            raise ValueError()
 
     def get_choices_add(self, qs, params):
         # It is easiest to base queries around the intermediate table, in order
         obj_dict = dict([(obj.pk, obj) for obj in objs])
         return [FilterChoice(unicode(obj_dict[choice]),
                              None, # Don't need count for removing
-                             self.build_params(params, remove=choice),
+                             self.build_params(params, remove=[choice]),
                              FILTER_REMOVE)
                 for choice in choices]
+
+
+class DrillDownMixin(object):
+
+    def get_choices_remove(self, qs, params):
+        # Due to drill down, if an earlier param is removed,
+        # the later params must be removed too.
+        chosen = self.choices_from_params(params)
+        out = []
+        for i, choice in enumerate(chosen):
+            out.append(FilterChoice(self.display_choice(choice),
+                                    None,
+                                    self.build_params(params, remove=chosen[i:]),
+                                    FILTER_REMOVE))
+        return out
+
+
+year_match = re.compile(r'^\d{4}$')
+month_match = re.compile(r'^\d{4}-\d{2}$')
+day_match = re.compile(r'^\d{4}-\d{2}-\d{2}$')
+
+class DateChoice(object):
+    """
+    Represents a choice of date. Params are converted to this, and this is used
+    to build new params and format links.
+
+    It can represent a year, month or day choice, or a range (start, end, both
+    inclusive) of any of these choice.
+    """
+
+    def __init__(self, range_type, values):
+        self.range_type = range_type
+        self.values = values
+
+    def __unicode__(self):
+        # This is called when converting to URL
+        return '..'.join(self.values)
+
+    def display(self):
+        # Called for user presentable string
+        if len(self.values) == 1:
+            value = self.values[0]
+            if self.range_type == 'year':
+                return value
+            elif self.range_type == 'month':
+                parts = value.split('-')
+                month = date(int(parts[0]), int(parts[1]), 1)
+                return capfirst(formats.date_format(month, 'YEAR_MONTH_FORMAT'))
+        else:
+            return u'-'.join([DateChoice(self.range_type,
+                                         [val]).display()
+                              for val in self.values])
+
+
+    @staticmethod
+    def datetime_to_value(range_type, dt):
+        if range_type == 'year':
+            return '%04d' % dt.year
+        elif range_type == 'month':
+            return '%04d-%02d' % (dt.year, dt.month)
+        else:
+            return '%04d-%02d-%02' % (dt.year, dt.month, dt.day)
+
+    @staticmethod
+    def from_datetime(range_type, dt):
+        return DateChoice(range_type, [DateChoice.datetime_to_value(range_type, dt)])
+
+    @staticmethod
+    def from_datetime_range(range_type, dt1, dt2):
+        return DateChoice(range_type,
+                          [DateChoice.datetime_to_value(range_type, dt1),
+                           DateChoice.datetime_to_value(range_type, dt2)])
+
+    @staticmethod
+    def range_type_from_param(param):
+        if year_match.match(param):
+            return 'year'
+        elif month_match.match(param):
+            return 'month'
+        elif day_match.match(param):
+            return 'day'
+
+    @staticmethod
+    def from_param(param):
+        vals = []
+        if '..' in param:
+            params = param.split('..', 1)
+            range_types = [DateChoice.range_type_from_param(p) for p in params]
+            if None in range_types or range_types[0] != range_types[1]:
+                return None
+            else:
+                return DateChoice(range_types[0], params)
+        else:
+            range_type = DateChoice.range_type_from_param(param)
+            if range_type is not None:
+                return DateChoice(range_type, [param])
+
+    def make_lookup(self, field_name):
+        if len(self.values) == 1:
+            val = self.values[0]
+            # val can contain:
+            # yyyy
+            # yyyy-mm
+            # yyyy-mm-dd
+            # Need to look up last part, converted to int
+            parts = val.split('-')
+            return {field_name + '__' + self.range_type: int(parts[-1])}
+        else:
+            # Should be just two values. First is lower bound, second is upper
+            # bound. Need to convert to datetime objects.
+            start_parts = map(int, self.values[0].split('-'))
+            end_parts = map(int, self.values[1].split('-'))
+            if self.range_type == 'year':
+                return {field_name + '__gte': date(start_parts[0], 1, 1),
+                        field_name + '__lt': date(end_parts[0] + 1, 1, 1)}
+            else:
+                return {}
+
+    def __eq__(self, other):
+        return (other is not None and
+                self.range_type == other.range_type and
+                self.values == other.values)
+
+
+class DateTimeFilter(MultiValueFilterMixin, DrillDownMixin, Filter):
+
+    def choice_from_param(self, param):
+        choice = DateChoice.from_param(param)
+        if choice is None:
+            raise ValueError()
+        return choice
+
+    def lookup_from_choice(self, choice):
+        return choice.make_lookup(self.field)
+
+    def display_choice(self, choice):
+        return choice.display()
+
+    def get_choices_add(self, qs, params):
+        choices = self.choices_from_params(params)
+        range_type = None
+
+        if len(choices) > 0:
+            if choices[-1].range_type == 'year':
+                if len(choices[-1].values) == 1:
+                    # One year, drill down
+                    range_type = 'month'
+                else:
+                    # Range, stay on year
+                    range_type = 'year'
+
+        if range_type is None:
+            # Get some initial idea of range
+            date_range = qs.aggregate(first=models.Min(self.field),
+                                      last=models.Max(self.field))
+            first = date_range['first']
+            last = date_range['last']
+            if first.year == last.year:
+                range_type = 'month'
+            else:
+                range_type = 'year'
+
+        date_qs = qs.dates(self.field, range_type)
+        results = date_aggregation(date_qs)
+
+        if len(results) > self.max_links:
+            # Fold results together
+            div, mod = divmod(len(results), self.max_links)
+            if mod != 0:
+                div += 1
+            date_choice_counts = []
+            i = 0
+            while i < len(results):
+                group = results[i:i+div]
+                count = sum(row[1] for row in group)
+                # build range:
+                choice = DateChoice.from_datetime_range(range_type,
+                                                        group[0][0],
+                                                        group[-1][0])
+                date_choice_counts.append((choice, count))
+                i += div
+        else:
+            date_choice_counts = [(DateChoice.from_datetime(range_type, dt), count)
+                                  for dt, count in results]
+
+        choices = []
+        for date_choice, count in date_choice_counts:
+            choices.append(FilterChoice(date_choice.display(),
+                                        count,
+                                        self.build_params(params, add=date_choice),
+                                        FILTER_ADD))
+        return choices

django_easyfilters/filterset.py

 from django.utils.text import capfirst
 
 from django_easyfilters.filters import FILTER_ADD, FILTER_REMOVE, FILTER_ONLY_CHOICE, \
-    ValuesFilter, ChoicesFilter, ForeignKeyFilter, ManyToManyFilter
+    ValuesFilter, ChoicesFilter, ForeignKeyFilter, ManyToManyFilter, DateTimeFilter
 
 
 def non_breaking_spaces(val):
         elif f.choices:
             klass = ChoicesFilter
         else:
-            klass = ValuesFilter
+            type_ = f.get_internal_type()
+            if type_ == 'DateField' or type_ == 'DateTimeField':
+                klass = DateTimeFilter
+            else:
+                klass = ValuesFilter
         return klass(field, self.model, **kwargs)
 
     def setup_filters(self):

django_easyfilters/queries.py

+from django.db.models.sql.datastructures import Date
+from django.db.models.sql.subqueries import AggregateQuery
+from django.db.models.sql.compiler import SQLCompiler
+from django.db.models.sql.constants import MULTI
+
+
+# Some fairly brittle, low level stuff, to get the aggregation
+# queries we need.
+
+
+class DateAggregateQuery(AggregateQuery):
+    # Need to override to return a compiler not in django.db.models.sql.compiler
+    def get_compiler(self, using=None, connection=None):
+        return DateAggregateCompiler(self, connection, using)
+
+    def get_counts(self, using):
+        from django.db import connections
+        connection = connections[using]
+        return list(self.get_compiler(using, connection).results_iter())
+
+
+class DateAggregateCompiler(SQLCompiler):
+    def results_iter(self):
+        resolve_columns = hasattr(self, 'resolve_columns')
+        if resolve_columns:
+            from django.db.models.fields import DateTimeField, IntegerField
+            fields = [DateTimeField(), IntegerField]
+        else:
+            from django.db.backends.util import typecast_timestamp
+            needs_string_cast = self.connection.features.needs_datetime_string_cast
+
+        offset = len(self.query.extra_select)
+        for rows in self.execute_sql(MULTI):
+            for row in rows:
+                if resolve_columns:
+                    vals = self.resolve_columns(row, fields)
+                elif needs_string_cast:
+                    vals = [typecast_timestamp(str(row[0])),
+                            row[1]]
+                yield vals
+
+    def as_sql(self, qn=None):
+        sql = ('SELECT %s, COUNT(%s) FROM (%s) GROUP BY (%s) ORDER BY (%s)' % (
+                DateWithAlias.alias, DateWithAlias.alias, self.query.subquery,
+                DateWithAlias.alias, DateWithAlias.alias)
+               )
+        params = self.query.sub_params
+        return (sql, params)
+
+
+class DateWithAlias(Date):
+    alias = 'easyfilter_date_alias'
+    def as_sql(self, qn, connection):
+        return super(DateWithAlias, self).as_sql(qn, connection) + ' as ' + self.alias
+
+
+def date_aggregation(date_qs):
+    """
+    Performs an aggregation for a supplied DateQuerySet
+
+    The query of date_qs is cloned (date_qs itself is not modified), used as
+    a subquery of a DateAggregateQuery, and the list returned by that query's
+    get_counts() is returned. The query runs on the database of date_qs.
+    """
+    # The DateQuerySet gives us a query that we need to clone and hack
+    date_q = date_qs.query.clone()
+    # Every row must be kept, so switch off DISTINCT on the clone
+    date_q.distinct = False
+
+    # Replace 'select' to add an alias
+    date_obj = date_q.select[0]
+    date_q.select = [DateWithAlias(date_obj.col, date_obj.lookup_type)]
+
+    # Now use as a subquery to do aggregation
+    query = DateAggregateQuery(date_qs.model)
+    query.add_subquery(date_q, date_qs.db)
+    return query.get_counts(date_qs.db)

django_easyfilters/tests/filterset.py

 # -*- coding: utf-8; -*-
 
+from datetime import datetime, date
 import decimal
 import operator
 
 from django_easyfilters.filterset import FilterSet
 from django_easyfilters.filters import FilterOptions, \
     FILTER_ADD, FILTER_REMOVE, FILTER_ONLY_CHOICE, \
-    ForeignKeyFilter, ValuesFilter, ChoicesFilter, ManyToManyFilter
+    ForeignKeyFilter, ValuesFilter, ChoicesFilter, ManyToManyFilter, DateTimeFilter
 
 from models import Book, Genre, Author, BINDING_CHOICES
 
                 'edition',
                 'binding',
                 'authors',
+                'date_published',
                 ]
 
         fs = BookFilterSet(Book.objects.all(), QueryDict(''))
         self.assertEqual(ValuesFilter, type(fs.filters[1]))
         self.assertEqual(ChoicesFilter, type(fs.filters[2]))
         self.assertEqual(ManyToManyFilter, type(fs.filters[3]))
+        self.assertEqual(DateTimeFilter, type(fs.filters[4]))
 
 
 class TestFilters(TestCase):
                           (unicode(anne), FILTER_REMOVE),
                           (unicode(charlotte), FILTER_ADD)])
 
+    def test_datetime_filter_multiple_year_choices(self):
+        """
+        Tests that DateTimeFilter can produce choices spanning a set of years
+        (and limit to max_links)
+        """
+        # This does drill down, and has multiple values.
+        f = DateTimeFilter('date_published', Book, max_links=10)
+
+        qs = Book.objects.all()
+
+        # We have enough data that it will not show a simple list of years.
+        choices = f.get_choices(qs, MultiValueDict())
+        self.assertTrue(len(choices) <= 10)
+
+    def test_datetime_filter_single_year_selected(self):
+        """
+        Tests filtering and the choices offered when a single year is given
+        in the params.
+        """
+        f = DateTimeFilter('date_published', Book, max_links=10)
+        qs = Book.objects.all()
+        params = MultiValueDict({'date_published':['1818']})
+
+        # Should get a number of books in queryset.
+        qs_filtered = f.apply_filter(qs, params)
+
+        # Filtering must be the same as a plain __year lookup
+        self.assertEqual(list(qs_filtered),
+                         list(qs.filter(date_published__year=1818)))
+        # We only need 1 query if we've already told it what year to look at.
+        with self.assertNumQueries(1):
+            choices = f.get_choices(qs_filtered, params)
+
+        # At least two 'add' links, and one 'remove' link for the chosen year
+        self.assertTrue(len([c for c in choices if c.link_type == FILTER_ADD]) >= 2)
+        self.assertEqual(len([c for c in choices if c.link_type == FILTER_REMOVE]), 1)
+
+    def test_datetime_filter_year_range_selected(self):
+        """
+        Tests filtering and the choices offered when a range of years
+        ('1813..1814') is given in the params.
+        """
+        f = DateTimeFilter('date_published', Book, max_links=10)
+        qs = Book.objects.all()
+        params = MultiValueDict({'date_published':['1813..1814']})
+
+        # Should get a number of books in queryset.
+        qs_filtered = f.apply_filter(qs, params)
+
+        # The range includes both years: from the start of 1813 up to, but
+        # not including, the start of 1815.
+        start = date(1813, 1, 1)
+        end = date(1815, 1, 1)
+        self.assertEqual(list(qs_filtered),
+                         list(qs.filter(date_published__gte=start,
+                                        date_published__lt=end)))
+
+        # We only need 1 query if we've already told it what years to look at,
+        # and there is data for both years.
+        with self.assertNumQueries(1):
+            choices = f.get_choices(qs_filtered, params)
+
+        # One 'remove' link for the range, one 'add' link per year in it
+        self.assertEqual(len([c for c in choices if c.link_type == FILTER_REMOVE]), 1)
+        self.assertEqual(len([c for c in choices if c.link_type == FILTER_ADD]), 2)
+        self.assertEqual([c.label for c in choices if c.link_type == FILTER_ADD],
+                         ['1813', '1814'])
+
+
+    def test_datetime_filter_invalid_query(self):
+        f = DateTimeFilter('date_published', Book, max_links=10)
+        qs = Book.objects.all()
+        params = MultiValueDict({'date_published':['1818xx']})
+
+        # invalid param should be ignored
+        qs_filtered = f.apply_filter(qs, params)
+        self.assertEqual(list(qs_filtered),
+                         list(qs))
+
+        self.assertEqual(list(f.get_choices(qs, params)),
+                         list(f.get_choices(qs, MultiValueDict({}))))
+
     def test_order_by_count(self):
         """
         Tests the 'order_by_count' option.

django_easyfilters/tests/fixtures/django_easyfilters_tests.json

       "edition": 1, 
       "price": "3.8", 
       "binding": "P", 
-      "date_published": "1818-08-25", 
+      "date_published": "1818-09-25", 
       "authors": [
         1
       ], 

django_easyfilters/tests/views.py

         'authors',
         'genre',
         'price',
+        'date_published',
         ]
 
 def books(request):