Commits

Alex Ulianytskyi  committed 1ebe749

Added RangeWidget, DecimalRangeField, FloatRangeField

  • Participants
  • Parent commits 8a36b64

Comments (0)

Files changed (5)

 ^build
 ^dist
 \.egg-info
+^fabfile\.py

File django_utils/forms/__init__.py

+from django_utils.forms.widgets import *
+from django_utils.forms.fields import *

File django_utils/forms/fields.py

+from django.forms.fields import DecimalField
+from django.forms.fields import EMPTY_VALUES
+from django.forms.fields import FloatField
+from django.utils.translation import ugettext_lazy as _
+from django_utils.forms.widgets import RangeWidget
+try:
+    from decimal import Decimal, DecimalException
+except ImportError:
+    from django.utils._decimal import Decimal, DecimalException
+from django.utils.encoding import smart_str
+from django.forms import ValidationError
+
+__all__ = ['FloatRangeField', 'DecimalRangeField']
+
+class FloatRangeField(FloatField):
+    widget = RangeWidget
+    default_error_messages = {
+        'invalid': _(u'Enter a numbers.'),
+        'max_value': _(u'Ensure this start number is less than or equal to %s.'),
+        'min_value': _(u'Ensure this end number is greater than or equal to %s.'),
+        'order': _(u'Ensure this value in right order of numbers (start <= end)'),
+        'range': _(u'Ensure this value is an range (start, end)'),
+    }
+    def clean(self, value):
+        try:
+            start = super(FloatField, self).clean(value[0])
+            end = super(FloatField, self).clean(value[1])
+        except IndexError:
+            raise ValidationError(self.error_messages['range'])
+        res = [0, 0]
+        if not self.required:
+            if start in EMPTY_VALUES and end in EMPTY_VALUES:
+                return (None, None)
+            else:
+                if start not in EMPTY_VALUES:
+                    res[0] = start
+                if end not in EMPTY_VALUES:
+                    res[1] = end
+        try:
+            start = float(start)
+            end = float(end)
+        except (ValueError, TypeError):
+            raise ValidationError(self.error_messages['invalid'])
+        if start > end:
+            raise ValidationError(self.error_messages['order'])
+        if self.max_value is not None and end > self.max_value:
+            raise ValidationError(self.error_messages['max_value'] % self.max_value)
+        if self.min_value is not None and start < self.min_value:
+            raise ValidationError(self.error_messages['min_value'] % self.min_value)
+        return (start, end)
+
+class DecimalRangeField(DecimalField):
+    widget = RangeWidget
+    default_error_messages = {
+        'invalid': _(u'Enter a number.'),
+        'max_value': _(u'Ensure this value is less than or equal to %s.'),
+        'min_value': _(u'Ensure this value is greater than or equal to %s.'),
+        'max_digits': _('Ensure that there are no more than %s digits in total.'),
+        'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
+        'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.'),
+        'order': _(u'Ensure this value in right order of numbers (start <= end)'),
+        'range': _(u'Ensure this value is an range (start, end)'),
+    }
+    def clean(self, value):
+        """
+        Validates that the input is a decimal number. Returns a Decimal
+        instance. Returns None for empty values. Ensures that there are no more
+        than max_digits in the number, and no more than decimal_places digits
+        after the decimal point.
+        """
+        try:
+            start = super(DecimalField, self).clean(value[0])
+            end = super(DecimalField, self).clean(value[1])
+        except IndexError:
+            raise ValidationError(self.error_messages['range'])
+        res = [u'0', u'0']
+        if not self.required:
+            if start in EMPTY_VALUES and end in EMPTY_VALUES:
+                return (None, None)
+            else:
+                if start not in EMPTY_VALUES:
+                    res[0] = smart_str(start).strip()
+                if end not in EMPTY_VALUES:
+                    res[1] = smart_str(end).strip()
+        try:
+            start = Decimal(start)
+            end = Decimal(end)
+        except DecimalException:
+            raise ValidationError(self.error_messages['invalid'])
+        if self.max_value is not None and end > self.max_value:
+            raise ValidationError(self.error_messages['max_value'] % self.max_value)
+        if self.min_value is not None and start < self.min_value:
+            raise ValidationError(self.error_messages['min_value'] % self.min_value)
+        if start > end:
+            raise ValidationError(self.error_messages['order'])
+        for i, value in enumerate([start, end]):
+            sign, digittuple, exponent = value.as_tuple()
+            decimals = abs(exponent)
+            # digittuple doesn't include any leading zeros.
+            digits = len(digittuple)
+            if decimals > digits:
+                # We have leading zeros up to or past the decimal point.  Count
+                # everything past the decimal point as a digit.  We do not count
+                # 0 before the decimal point as a digit since that would mean
+                # we would not allow max_digits = decimal_places.
+                digits = decimals
+            whole_digits = digits - decimals
+
+            if self.max_digits is not None and digits > self.max_digits:
+                raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
+            if self.decimal_places is not None and decimals > self.decimal_places:
+                raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
+            if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+                raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
+            res[i] = value
+        return res

File django_utils/forms/widgets.py

+import datetime
+from django.forms.widgets import MultiWidget
+from django.forms.widgets import Select
+from django.forms.widgets import TextInput
+from django.utils.safestring import mark_safe
+
+__all__ = ['SplitDateWidget', 'RangeWidget']
+
+class SplitDateWidget(MultiWidget):
+    reverse = False
+    
+    def __init__(self, attrs=None, reverse=True):
+        self.reverse = reverse
+        days = [(i, i) for i in xrange(1, 32)]
+        months = dates.MONTHS.items()
+        years = [(i, i) for i in xrange(1900, datetime.date.today().year)]
+        widgets = [
+            Select(attrs=attrs, choices=years),
+            Select(attrs=attrs, choices=months),
+            Select(attrs=attrs, choices=days)
+        ]
+        if self.reverse:
+            widgets.reverse()
+        super(SplitDate, self).__init__(widgets=widgets, attrs=attrs)
+
+    def decompress(self, value):
+        if value:
+            if isinstance(value, datetime.date):
+                res = [value.year, value.month, value.day]
+            if isinstance(value, basestring):
+                res = value.split('-')
+            if self.reverse:
+                res.reverse()
+            return res
+        else:
+            return [None, None, None]
+
+    def value_from_datadict(self, data, files, name):
+        parts = super(SplitDate, self).value_from_datadict(data, files, name)
+        if parts is not None:
+            if self.reverse:
+                parts.reverse()
+            return u'-'.join(parts)
+        return None
+
+class RangeWidget(MultiWidget):
+    suffix_list = ['start', 'end']
+    
+    def __init__(self, attrs=None):
+        super(RangeWidget, self).__init__(widgets=[TextInput, TextInput], attrs=attrs)
+
+    def decompress(self, value):
+        return value if value else [None, None]
+
+    def value_from_datadict(self, data, files, name):
+        return [widget.value_from_datadict(data, files, name + '_%s' % i) for i, widget in zip(self.suffix_list, self.widgets)]
+
+    def render(self, name, value, attrs=None):
+        # value is a list of values, each corresponding to a widget
+        # in self.widgets.
+        if not isinstance(value, list):
+            value = self.decompress(value)
+        value_dict = dict(start=value[0], end=value[1])
+        output = []
+        final_attrs = self.build_attrs(attrs)
+        id_ = final_attrs.get('id', None)
+        for i, widget in zip(self.suffix_list, self.widgets):
+            try:
+                widget_value = value_dict[i]
+            except IndexError:
+                widget_value = None
+            if id_:
+                final_attrs = dict(final_attrs, id='%s_%s' % (id_, i))
+            output.append(widget.render(name + '_%s' % i, widget_value, final_attrs))
+        output.insert(1, u'-')
+        return mark_safe(self.format_output(output))

File django_utils/tests.py

 import os
 os.environ['DJANGO_SETTINGS_MODULE'] = 'django_utils.settings'
 
+from django_utils.test import TestCase
+from django_utils.urls import rest_urlconf
+from django.conf.urls.defaults import url
+from django_utils.forms import RangeWidget
 import unittest
-from django_utils import rest_urlconf
-from django.conf.urls.defaults import url
 
-class  TestsTestCase(unittest.TestCase):
-    def test_rest_urlconf(self):
-        urlpatterns = rest_urlconf('object')
-        self.assertTrue(url(r'^object/$', 'object_index', name='index') in urlpatterns)
-        self.assertTrue(url(r'^object/new/$', 'object_new', name='new') in urlpatterns)
-        self.assertTrue(url(r'^object/(?P<id>\d+>/$', 'object_show', name='show') in urlpatterns)
-        self.assertTrue(url(r'^object/(?P<id>\d+>/edit/$', 'object_edit', name='edit') in urlpatterns)
-        self.assertTrue(url(r'^object/(?P<id>\d+>/delete/$', 'object_delete', name='delete') in urlpatterns)
+class RestUrlTest(TestCase):
+    def setUp(self):
+        self.urlpatterns = rest_urlconf('object')
+    def testIndex(self):
+        self.assertTrue(url(r'^object/$', 'object_index', name='index') in self.urlpatterns)
+    def testNew(self):
+        self.assertTrue(url(r'^object/new/$', 'object_new', name='new') in self.urlpatterns)
+    def testShow(self):
+        self.assertTrue(url(r'^object/(?P<id>\d+>/$', 'object_show', name='show') in self.urlpatterns)
+    def testEdit(self):
+        self.assertTrue(url(r'^object/(?P<id>\d+>/edit/$', 'object_edit', name='edit') in self.urlpatterns)
+    def testDelete(self):
+        self.assertTrue(url(r'^object/(?P<id>\d+>/delete/$', 'object_delete', name='delete') in self.urlpatterns)
 
+class TestRangeWidget(TestCase):
+    def setUp(self):
+        self.range_widget = RangeWidget()
+
+    def testRangeFieldData(self):
+        value = self.range_widget.value_from_datadict(data=dict(range_start=u'0', range_end=u'1'), name=u'range')
+        self.assertEqual([u'0', u'1'], value)
+#TODO: Errors on run test suite
 if __name__ == '__main__':
     unittest.main()