Commits

Luke Plant committed f465f36

Added initial implementation of NumericRangeFilter

Comments (0)

Files changed (4)

django_easyfilters/filters.py

 from django.core.exceptions import ValidationError
 from django.db import models
 from django.utils import formats
-from django.utils.datastructures import SortedDict
 from django.utils.dates import MONTHS
 from django.utils.text import capfirst
-from django_easyfilters.queries import date_aggregation, value_counts
+from django_easyfilters.queries import date_aggregation, value_counts, numeric_range_counts
 
 try:
     from collections import namedtuple
         return '<DateChoice %s %s>' % (self.range_type, self.__unicode__())
 
     def __cmp__(self, other):
+        # 'greater' means more specific.
         return cmp((self.range_type, self.values),
                    (other.range_type, other.values))
 
                                        FILTER_DISPLAY))
         return retval
 
+
+def make_numeric_range_choice(to_python, to_str):
+    """
+    Returns a Choice class that represents a numeric choice range,
+    using the passed in 'to_python' and 'to_str' callables to do
+    conversion to/from native data types.
+    """
+
+    class NumericRangeChoice(object):
+        def __init__(self, values):
+            self.values = values
+
+        def display(self):
+            return '-'.join(map(str, self.values))
+
+        @classmethod
+        def from_param(cls, param):
+            vals = []
+            for p in param.split('..', 1):
+                try:
+                    val = to_python(p)
+                    vals.append(val)
+                except ValidationError:
+                    raise ValueError()
+            return cls(vals)
+
+        def make_lookup(self, field_name):
+            if len(self.values) == 1:
+                return {field_name: self.values[0]}
+            else:
+                return {field_name + '__gte': self.values[0],
+                        field_name + '__lte': self.values[1]}
+
+        def __unicode__(self):
+            return '..'.join(map(to_str, self.values))
+
+        def __repr__(self):
+            return '<NumericChoice %s>' % self.__unicode__()
+
+        def __cmp__(self, other):
+            # 'greater' means more specific.
+            if other is None:
+                return cmp(self.values, None)
+            else:
+                if len(self.values) != len(other.values):
+                    # one value is more specific than two
+                    return -cmp(len(self.values), len(other.values))
+                elif len(self.values) == 1:
+                    return 0
+                else:
+                    # Larger difference means less specific
+                    return -cmp(self.values[1] - self.values[0],
+                                other.values[1] - other.values[0])
+
+    return NumericRangeChoice
+
+class NumericRangeFilter(RangeFilterMixin, Filter):
+
+    def __init__(self, field, model, params, **kwargs):
+        self.max_links = kwargs.pop('max_links', 5)
+        field_obj = model._meta.get_field(field)
+        self.choice_type = make_numeric_range_choice(field_obj.to_python, str)
+        super(NumericRangeFilter, self).__init__(field, model, params, **kwargs)
+
+    def get_choices_add(self, qs):
+        chosen = list(self.chosen)
+        range_type = None
+
+        all_vals = qs.values_list(self.field).distinct()
+
+        num = all_vals.count()
+
+        choices = []
+        if num <= self.max_links:
+            val_counts = value_counts(qs, self.field)
+            for v, count in val_counts.items():
+                choice = self.choice_type([v])
+                choices.append(FilterChoice(choice.display(),
+                                            count,
+                                            self.build_params(add=choice),
+                                            FILTER_ADD))
+        else:
+            val_range = qs.aggregate(lower=models.Min(self.field),
+                                     upper=models.Max(self.field))
+            lower = val_range['lower']
+            upper = val_range['upper']
+
+            # TODO - round to produce nice looking ranges.
+            step = (upper - lower)/self.max_links
+            ranges = [(lower + step * i, lower + step * (i+1)) for i in xrange(self.max_links)]
+
+            val_counts = numeric_range_counts(qs, self.field, ranges)
+            for vals, count in val_counts.items():
+                choice = self.choice_type(vals)
+                choices.append(FilterChoice(choice.display(),
+                                            count,
+                                            self.build_params(add=choice),
+                                            FILTER_ADD))
+
+        return choices

django_easyfilters/filterset.py

 from django.utils.text import capfirst
 
 from django_easyfilters.filters import FILTER_ADD, FILTER_REMOVE, FILTER_DISPLAY, \
-    ValuesFilter, ChoicesFilter, ForeignKeyFilter, ManyToManyFilter, DateTimeFilter
+    ValuesFilter, ChoicesFilter, ForeignKeyFilter, ManyToManyFilter, DateTimeFilter, NumericRangeFilter
 
 
 def non_breaking_spaces(val):
             type_ = f.get_internal_type()
             if type_ == 'DateField' or type_ == 'DateTimeField':
                 return DateTimeFilter
+            elif type_ == 'DecimalField':
+                return NumericRangeFilter
             else:
                 return ValuesFilter
 

django_easyfilters/queries.py

     for val, count in values_counts:
         count_dict[val] = count
     return count_dict
+
+
+class NumericAggregateQuery(AggregateQuery):
+    # Need to override to return a compiler not in django.db.models.sql.compiler
+    def get_compiler(self, using=None, connection=None):
+        return  NumericAggregateCompiler(self, connection, using)
+
+    def get_counts(self, using):
+        from django.db import connections
+        connection = connections[using]
+        return list(self.get_compiler(using, connection).results_iter())
+
+
+class NumericAggregateCompiler(SQLCompiler):
+    def results_iter(self):
+        for rows in self.execute_sql(MULTI):
+            for row in rows:
+                yield row
+
+    def as_sql(self, qn=None):
+        sql = ('SELECT %s, COUNT(%s) FROM (%s) subquery GROUP BY (%s) ORDER BY (%s)' % (
+                NumericValueRange.alias, NumericValueRange.alias, self.query.subquery,
+                NumericValueRange.alias, NumericValueRange.alias)
+               )
+        params = self.query.sub_params
+        return (sql, params)
+
+
+class NumericValueRange(object):
+    alias = 'easyfilter_number_range_alias'
+    def __init__(self, col, ranges):
+        # ranges is list of (lower, upper) bounds we want to find, where 'lower'
+        # is inclusive and upper is exclusive.
+        self.col = col
+        self.ranges = ranges
+
+    # TODO - do we need 'relabel_aliases', like 'Date'?
+
+    def as_sql(self, qn, connection):
+        if isinstance(self.col, (list, tuple)):
+            col = '%s.%s' % tuple([qn(c) for c in self.col])
+        else:
+            col = self.col
+
+        # Build up case expression.
+        clause = (['CASE '] +
+                  ['WHEN %s >= %s AND %s < %s THEN %s ' % (col, val[0], col, val[1], i)
+                   for i, val in enumerate(self.ranges)] +
+                  ['ELSE %s END ' % len(self.ranges)] +
+                  ['as %s' % self.alias])
+        return ''.join(clause)
+
+
+def numeric_range_counts(qs, fieldname, ranges):
+
+    # Build the query:
+    query = qs.values_list(fieldname).query.clone()
+    query.select[0] = NumericValueRange(query.select[0], ranges)
+
+    agg_query = NumericAggregateQuery(qs.model)
+    agg_query.add_subquery(query, qs.db)
+    results = agg_query.get_counts(qs.db)
+
+    count_dict = SortedDict()
+    for val, count in results:
+        try:
+            r = ranges[val]
+        except IndexError:
+            # Include in the top range - this could be a rounding error
+            r = ranges[-1]
+        count_dict[r] = count
+    return count_dict
+

django_easyfilters/tests/filterset.py

 # -*- coding: utf-8; -*-
 
 from datetime import datetime, date
-import decimal
+from decimal import Decimal
 import operator
 
 from django.http import QueryDict
 from django_easyfilters.filterset import FilterSet
 from django_easyfilters.filters import \
     FILTER_ADD, FILTER_REMOVE, FILTER_DISPLAY, \
-    ForeignKeyFilter, ValuesFilter, ChoicesFilter, ManyToManyFilter, DateTimeFilter
+    ForeignKeyFilter, ValuesFilter, ChoicesFilter, ManyToManyFilter, DateTimeFilter, NumericRangeFilter
 
 from models import Book, Genre, Author, BINDING_CHOICES
 
                 'binding',
                 'authors',
                 'date_published',
+                'price',
                 ]
 
         fs = BookFilterSet(Book.objects.all(), QueryDict(''))
         self.assertEqual(ChoicesFilter, type(fs.filters[2]))
         self.assertEqual(ManyToManyFilter, type(fs.filters[3]))
         self.assertEqual(DateTimeFilter, type(fs.filters[4]))
+        self.assertEqual(NumericRangeFilter, type(fs.filters[5]))
 
 
 class TestFilters(TestCase):
                               '1818-08-24..1818-08-30',
                               ])
 
+
+    def test_numericrange_filter_simple_vals(self):
+        # If data is less than max_links, we should get a simple list of values.
+        filter1 = NumericRangeFilter('price', Book, MultiValueDict(), max_links=20)
+
+        # Limit to single value to force the case
+        qs = Book.objects.filter(price=Decimal('3.50'))
+
+        # Should only take 2 queries - one to find out how many distinct values,
+        # one to get the counts.
+        with self.assertNumQueries(2):
+            choices = filter1.get_choices(qs)
+
+        self.assertEqual(len(choices), 1)
+        self.assertTrue('3.5' in choices[0].label)
+
+    def test_numericrange_filter_range_choices(self):
+        # If data is more than max_links, we should get a range
+        filter1 = NumericRangeFilter('price', Book, MultiValueDict(), max_links=8)
+
+        qs = Book.objects.all()
+        # Should take 3 queries - one to find out how many distinct values,
+        # one to find a range, one to get the counts.
+        with self.assertNumQueries(3):
+            choices = filter1.get_choices(qs)
+
+        self.assertTrue(len(choices) <= 8)
+        total_count = sum(c.count for c in choices)
+        self.assertEqual(total_count, qs.count())
+
+    def test_numericrange_filter_apply_filter(self):
+        params = MultiValueDict({'price': ['3.50..4.00']})
+        filter1 = NumericRangeFilter('price', Book, params)
+        qs = Book.objects.all()
+
+        qs_filtered = filter1.apply_filter(qs)
+        self.assertEqual(list(qs_filtered),
+                         list(qs.filter(price__gte=Decimal('3.50'),
+                                        price__lte=Decimal('4.00'))))
+
+
     def test_order_by_count(self):
         """
         Tests the 'order_by_count' option.