Commits

Kevin Wetzels committed 79bd881

Add Enumeration, EnumCharField and EnumIntegerField to make working with choices easier

Comments (0)

Files changed (8)

 ================
 
 Fields I wish were standard in Django. At the moment this is limited to the
-``MinutesField``.
+``MinutesField``, ``EnumIntegerField`` and ``EnumCharField``.
 
 Contents
 ========
 
+* ``stdfields.forms.MinutesField``: use an integer to represent a duration of 
+  minutes and hours
+* ``stdfields.fields.EnumIntegerField``: makes working with ``choices`` a bit 
+  easier
+* ``stdfields.fields.EnumCharField``: the same, but for ``choices`` with a char 
+  key
+
 MinutesField
 ------------
 Is an extension of Django's standard ``django.forms.IntegerField``.
 
 This field will accept values for a duration in minutes in the formats 
 ``hh:mm`` or ``h.fraction``, similar to the way BaseCamp allows you to specify 
-your time spent on a task as either '8:30' or ``8.5``. In the latter case only 
+your time spent on a task as either ``8:30`` or ``8.5``. In the latter case only 
 ``8.25``, ``8.5``, ``8.50`` and ``8.75`` are considered valid inputs.
 
 Example
 such a field in the format ``8:30``::
 
     {% load stdfieldstags %}
-    It took me {{ task.time_spent|minutes }} to complete this task.
+    It took me {{ task.time_spent|minutes }} to complete this task.
+
+
+Enumeration
+-----------
+I always end up with ugly code when using Django's ``choices`` argument for 
+fields. With the ``stdfields.models.Enumeration`` class, I've got a handy base 
+class that allows me to keep things tidy::
+
+    # models.py
+    class Color(Enumeration):
+        RED = 'R'
+        GREEN = 'G'
+        BLUE = 'B'
+    
+        @classmethod
+        def all(cls):
+            return [
+                (cls.RED, _(u'Red')),
+                (cls.GREEN, _(u'Green')),
+                (cls.BLUE, _(u'Blue'))
+            ]
+            
+    class Pencil(models.Model):
+        color = models.CharField(choices=Color.all(), max_length=Color.max_length())
+        
+    # views.py
+    def red_pencils(request):
+        pencils = Pencil.objects.filter(color=Color.RED)
+        ...
+        # Prints 'Showing the Red pencils'
+        logging.info('Showing the %s pencils' % (Color.as_display(Color.RED)))
+
+
+EnumCharField and EnumIntegerField
+----------------------------------
+And now we can make working with an ``Enumeration`` easier with the 
+``EnumCharField`` and ``EnumIntegerField`` models fields::
+
+    # models.py
+    class Color(Enumeration):
+        # same as above
+        
+    class Pencil(models.Model):
+        color = models.EnumCharField(enum=Color, max_length=Color.max_length())
+        
+This example is basically the same as the above since ``EnumCharField`` is a 
+subclass of the regular Django ``CharField``. By using the ``enum`` keyword 
+argument of the enum field, the choices will be automatically updated when you
+update the enumeration object. And since you're using the provided 
+``max_length`` method of ``Enumeration``, the ``max_length`` will be updated
+when needed. Just like in the previous example. The enum fields simply offer 
+some more clarity when reading the code.
+
+``EnumIntegerField`` works exactly the same, but for enumerations with integer
+keys. Both fields can be used with South.
+# -*- coding: utf-8 -*-
+from django.db import models
+
+
+class EnumIntegerField(models.PositiveIntegerField):
+    """
+    Extension of a standard Django ``PositiveIntegerField`` that takes an
+    optional ``enum`` argument which should point to an implementation of
+    ``stdfields.models.Enumeration``.
+
+    The results of the implementation's ``all`` method will be used as the
+    possible choices.
+    """
+
+    def __init__(self, *args, **kwargs):
+        if 'enum' in kwargs:
+            self.enum = kwargs.pop('enum')
+            kwargs['choices'] = self.enum.all()
+        super(EnumIntegerField, self).__init__(*args, **kwargs)
+
+
+class EnumCharField(models.CharField):
+    """
+    Extension of a standard Django ``CharField`` that takes an optional
+    ``enum`` argument which should point to an implementation of
+    ``stdfields.models.Enumeration``.
+
+    The results of the implementation's ``all`` method will be used as the
+    possible choices.
+    """
+
+    def __init__(self, *args, **kwargs):
+        if 'enum' in kwargs:
+            self.enum = kwargs.pop('enum')
+            choices = self.enum.all()
+            kwargs['choices'] = choices
+        else:
+            choices = kwargs.get('choices', [])
+        super(EnumCharField, self).__init__(*args, **kwargs)
+
+
+try:
+    # Let South know it should be able to handle these fields
+    from south.modelsinspector import add_introspection_rules
+    add_introspection_rules([], ["^stdfields\.fields\.EnumIntegerField"])
+    add_introspection_rules([], ["^stdfields\.fields\.EnumCharField"])
+except ImportError:
+    # You're not using South?!
+    pass
 
 from django import forms
 from django.utils.encoding import smart_str
+from django.utils.translation import ugettext_lazy as _
 
 from widgets import MinutesWidget
 
+
 class MinutesField(forms.IntegerField):
-    
+    """
+    A form field representing a duration in minutes.
+
+    Accepts formats ``hh:mm`` and ``hh.fraction``, meaning ``8:30`` and ``8.5``
+    are equivalent, meaning 8 hours and 30 minutes.
+    """
     widget = MinutesWidget
-    
+
+    def __init__(self, *args, **kwargs):
+        # Override the default 'invalid' error message from IntegerField
+        if 'error_messages' in kwargs:
+            error_messages = kwargs.pop('error_messages')
+        else:
+            error_messages = {}
+        if not 'invalid' in error_messages:
+            error_messages['invalid'] = _(u'Enter a valid value.')
+        kwargs['error_messages'] = error_messages
+        super(MinutesField, self).__init__(*args, **kwargs)
+
     def clean(self, value):
         value = smart_str(value).strip()
         match = re.search(r'^(\d+):(\d{1,2})$', value)
             hours = int(groups[0])
             minutes = int(groups[1])
             if minutes > 59:
-                raise forms.ValidationError(self.error_messages['invalid'])
+                msg = self.error_messages['invalid']
+                raise forms.ValidationError(msg)
             value = (hours * 60) + minutes
         else:
             value = value.replace(',', '.')
                     hours = int(parts[0])
                     fraction = int(parts[1])
                     if not fraction in (5, 25, 50, 75):
-                        raise forms.ValidationError(self.error_messages['invalid'])
+                        msg = self.error_messages['invalid']
+                        raise forms.ValidationError(msg)
                     if fraction == 5:
                         fraction = 50
                     value = int((hours * 60) + (60 / 100 * fraction))
                 except (ValueError, TypeError):
-                    raise forms.ValidationError(self.error_messages['invalid'])
+                    msg = self.error_messages['invalid']
+                    raise forms.ValidationError(msg)
         return super(MinutesField, self).clean(value)
-from django.db import models
+# -*- coding: utf-8 -*-
+from django.utils.encoding import smart_str
 
-# Create your models here.
+
+class Enumeration(object):
+    """
+    Simple enumeration object - subclasses should implement ``all``,
+    mapping keys to values.
+    """
+
+    @classmethod
+    def as_dict(cls):
+        """
+        Returns the key-label pairs of the enumeration.
+        """
+        return dict((k, v) for (k, v) in cls.all())
+
+    @classmethod
+    def as_display(cls, key):
+        """Returns the label or display value of the key."""
+        return cls.as_dict().get(key, None)
+
+    @classmethod
+    def all(cls):
+        """
+        Returns all key-label pairs as a list of tuples.
+
+        Useful to pass on to a field as the possible choices. Will be used
+        internally by the customer enum fields. Should be *implemented* by 
+        subclasses.
+        """
+        return []
+
+    @classmethod
+    def max_length(cls):
+        """
+        Calculates the maximum length of the key.
+
+        You can set the ``max_length`` value of a ``CharField`` or
+        ``EnumCharField`` to the result of this method. That way South will be
+        able to pick up any changes to the maximum key length automatically.
+        """
+        keys = [smart_str(x[0]) for x in cls.all()]
+        value = max(keys, key=lambda x: len(x))
+        return 1 if not value else len(value)

stdfields/templatetags/stdfieldstags.py

 
 register = template.Library()
 
-@register.filter 
+
+@register.filter
 def minutes(value):
-    return format_minutes(value)
+    return format_minutes(value)
 """
 from django.test import TestCase
 from django.forms import ValidationError
+from django.db import models
 
+from stdfields.models import Enumeration
+from stdfields.fields import EnumIntegerField, EnumCharField
 from stdfields.forms import MinutesField
 from stdfields.widgets import MinutesWidget
 
+# -- --------------------------------------------------------------------------
+# -- MinutesField + MinutesWidget test
+
 class MinutesFieldTest(TestCase):
-    
+
     def test_minutes_field(self):
         f = MinutesField()
         self.assertEqual(121, f.clean(121))
         minutes = 480
         for i in range(60):
             self.assertEqual(minutes + i, f.clean('8:%d' % (i)))
-        
+
     def test_minutes_field_invalid(self):
         f = MinutesField()
         self._should_raise_validation_error(f, '2:60')
         self._should_raise_validation_error(f, 'x:y')
         self._should_raise_validation_error(f, '2;30')
         self._should_raise_validation_error(f, '2;30')
-        
+
     def _should_raise_validation_error(self, f, value):
         try:
             f.clean(value)
             self.fail('%s should raise a ValidationError' % (value))
-        except ValidationError:
+        except ValidationError, e:
             pass
 
 
 class MinutesWidgetTest(TestCase):
-    
+
     def test_minutes_widget(self):
         w = MinutesWidget()
-        self.assertEqual(w.render('hi', '121'), '<input type="text" name="hi" value="2:01" />')
-        self.assertEqual(w.render('hi', '2:1'), '<input type="text" name="hi" value="2:1" />')
-        self.assertEqual(w.render('hi', '2:60'), '<input type="text" name="hi" value="2:60" />')
-        self.assertEqual(w.render('hi', '493'), '<input type="text" name="hi" value="8:13" />')
+        tpl = '<input type="text" name="hi" value="%s" />'
+        self.assertEqual(w.render('hi', '121'), tpl % '2:01')
+        self.assertEqual(w.render('hi', '2:1'), tpl % '2:1')
+        self.assertEqual(w.render('hi', '2:60'), tpl % '2:60')
+        self.assertEqual(w.render('hi', '493'), tpl % '8:13')
+
+
+# -- --------------------------------------------------------------------------
+# -- EnumIntegerField tests
+
+class ExampleIntegerEnum(Enumeration):
+    FIRST = 1
+    SECOND = 2
+    THIRD = 3
+
+    @classmethod
+    def all(cls):
+        return [
+            (cls.FIRST, 'First'),
+            (cls.SECOND, 'Second'),
+            (cls.THIRD, 'Third'),
+        ]
+
+
+class EnumIntegerModel(models.Model):
+    c = EnumIntegerField(enum=ExampleIntegerEnum)
+
+
+class EnumIntegerFieldTest(TestCase):
+
+    def test_enum_integer_field(self):
+        f = EnumIntegerField(enum=ExampleIntegerEnum,
+                            default=ExampleIntegerEnum.FIRST, blank=False)
+        self.assertEqual(ExampleIntegerEnum.all(), f.formfield().choices)
+        f = EnumIntegerField(enum=ExampleIntegerEnum)
+        expected = [('', '---------')] + ExampleIntegerEnum.all()
+        self.assertEqual(expected, f.formfield().choices)
+
+    def test_display(self):
+        first = ExampleIntegerEnum.FIRST
+        label = ExampleIntegerEnum.as_display(first)
+        self.assertEqual(label, EnumIntegerModel(c=first).get_c_display())
+        self.assertEqual(5, EnumIntegerModel(c=5).get_c_display())
+        self.assertTrue(EnumIntegerModel(c=None).get_c_display() is None)
+        self.assertEqual('', EnumIntegerModel(c='').get_c_display())
+
+
+# -- --------------------------------------------------------------------------
+# -- EnumCharField tests
+
+class ExampleCharEnum(Enumeration):
+    FIRST = 'A'
+    SECOND = 'Boo'
+    THIRD = 'Circus'
+
+    @classmethod
+    def all(cls):
+        return [
+            (cls.FIRST, 'First'),
+            (cls.SECOND, 'Second'),
+            (cls.THIRD, 'Third'),
+        ]
+
+
+class EnumCharModel(models.Model):
+    c = EnumCharField(enum=ExampleCharEnum,
+                    max_length=ExampleCharEnum.max_length())
+
+
+class EnumCharFieldTest(TestCase):
+
+    def test_enum_char_field(self):
+        f = EnumCharField(enum=ExampleCharEnum, default=ExampleCharEnum.FIRST,
+                        blank=False, max_length=ExampleCharEnum.max_length())
+        # 'Circus' is the longest key at 6 characters
+        self.assertEqual(6, f.max_length)
+        self.assertEqual(ExampleCharEnum.all(), f.formfield().choices)
+        f = EnumCharField(enum=ExampleCharEnum)
+        expected = [('', '---------')] + ExampleCharEnum.all()
+        self.assertEqual(expected, f.formfield().choices)
+
+    def test_display(self):
+        first = ExampleCharEnum.FIRST
+        label = ExampleCharEnum.as_display(first)
+        self.assertEqual(label, EnumCharModel(c=first).get_c_display())
+        self.assertEqual('E', EnumCharModel(c='E').get_c_display())
+        self.assertTrue(EnumCharModel(c=None).get_c_display() is None)
+        self.assertEqual('', EnumCharModel(c='').get_c_display())

stdfields/views.py

-# Create your views here.
 # -*- coding: utf-8 -*-
 from django.forms.widgets import TextInput
 
+
 def format_minutes(value):
+    """Formats an integer as hours and minutes: ``hh:mm``."""
     if value is None or value == '':
         return u''
     try:
     divided = MinutesWidget.divide(value)
     return u'%d:%02d' % divided
 
+
 class MinutesWidget(TextInput):
-    
+    """Widgets for minute fields."""
+
     def _format_value(self, value):
         return format_minutes(value)
-        
+
     @classmethod
     def divide(cls, total):
         if total <= 0:
             return (0, 0)
         hours = 0 if total <= 60 else total / 60
-        return (hours, total - (hours * 60))
+        return (hours, total - (hours * 60))