Commits

Kevin Wetzels  committed 50ce8df

Add models.Enum as a shortcut and improvement of models.Enumeration and provide
a standard enumeration field for Enumeration that can be initialized, rather
than having to override Enumeration.all (backwards compatible).

Enum uses the EnumMetaclass to provide an Enumeration implementation that's a
lot easier to use and more alignment with the way Django models and forms are
constructed. models.EnumValue is also introduced to enable us to use the
specified order (using the same technique Django uses with a creation_counter)
without messing with tuples.

  • Participants
  • Parent commits 133c0bd

Comments (0)

Files changed (4)

         logging.info('Showing the %s pencils' % (Color.as_display(Color.RED)))
 
 
+That could be shorter. Use ``Enum`` instead::
+
+    # models.py
+    class Color(Enum):
+        RED = EnumValue('R', 'Red')
+        GREEN = EnumValue('G', 'Green')
+        BLUE = EnumValue('B', 'Blue')
+            
+    class Pencil(models.Model):
+        color = models.CharField(choices=Color.all(), max_length=Color.max_length())
+        
+    # views.py
+    def red_pencils(request):
+        pencils = Pencil.objects.filter(color=Color.RED)
+        ...
+        # Prints 'Showing the Red pencils'
+        logging.info('Showing the %s pencils' % (Color.RED_display))
+
+
 EnumCharField and EnumIntegerField
 ----------------------------------
-And now we can make working with an ``Enumeration`` easier with the 
+And now we can make working with an ``Enum`` easier with the 
 ``EnumCharField`` and ``EnumIntegerField`` models fields::
 
     # models.py
 
 setup(
     name='django-stdfields',
-    version='0.0.3',
+    version='0.0.10',
     author=u'Kevin Wetzels',
     author_email=u'kevin@roam.be',
     url='https://bitbucket.org/roam/django-stdfields',

File stdfields/models.py

 
         Useful to pass on to a field as the possible choices. Will be used
         internally by the customer enum fields. Should be *implemented* by 
-        subclasses.
+        subclasses if they provide no value for the ``enumeration`` field.
         """
-        return []
+        return cls.enumeration
 
     @classmethod
     def max_length(cls):
         keys = [smart_str(x[0]) for x in cls.all()]
         value = max(keys, key=lambda x: len(x))
         return 1 if not value else len(value)
+
+
+class EnumValue(object):
+    """
+    An EnumValue is basically a wrapper for tuples, adding a counter that 
+    allows us to keep them in the order they were defined.
+    """
+    creation_counter = 0
+
+    def __init__(self, key, label):
+        self.key = key
+        self.label = label
+        # Use the same trick Django uses to keep the order
+        self.creation_counter = EnumValue.creation_counter
+        EnumValue.creation_counter += 1
+
+    def as_tuple(self):
+        return (self.key, self.label)
+
+
+class EnumMetaclass(type):
+    """
+    Metaclass for Enum that will turn an Enum with EnumValue fields into an 
+    Enumeration instance.
+    """
+
+    def __new__(cls, name, bases, attrs):
+        enum_values = [(name, value) for name, value in attrs.items()
+                            if isinstance(value, EnumValue)]
+        enum_values.sort(key=lambda x: x[1].creation_counter)
+        for (name, value) in enum_values:
+            attrs[name] = value.key
+        tuples = [e[1].as_tuple() for e in enum_values]
+        previous_enumerations = []
+        for base in bases[::1]:
+            if hasattr(base, 'enumeration'):
+                previous_enumerations = [x for x in base.enumeration]
+        attrs['enumeration'] = previous_enumerations + tuples
+        for (name, value) in enum_values:
+            attrs['%s_display' % (name)] = value.label
+        return super(EnumMetaclass, cls).__new__(cls, name, bases, attrs)
+
+
+class Enum(Enumeration):
+    """
+    An easier to use version of ``Enumeration``, which will grab the possible
+    values from the defined ``EnumValue`` fields::
+
+        class Color(Enum):
+            RED = EnumValue('R', 'Red')
+            GREEN = EnumValue('G', 'Green')
+            BLUE = EnumValue('B', 'Blue')
+
+        class Pencil(models.Model):
+            color = EnumCharField(choices=Color.all())
+
+        Pencil.objects.filter(color=Color.RED)
+
+    This is the equivalent of::
+
+        class Color(Enumeration):
+            enumeration = [
+                ('R', 'Red'),
+                ('G', 'Green'),
+                ('B', 'Blue')
+            ]
+
+        class Pencil(models.Model):
+            color = EnumCharField(choices=Color.all())
+
+        Pencil.objects.filter(color=Color.RED)
+
+    But ``Enum`` also provides a ``FIELD_display`` field for each value::
+
+        Color.RED_display == 'Red'
+        Color.BLUE_display == 'Blue'
+    """
+
+    __metaclass__ = EnumMetaclass

File stdfields/tests.py

 from django.forms import ValidationError, ModelForm
 from django.db import models
 
-from stdfields.models import Enumeration
+from stdfields.models import Enumeration, Enum, EnumValue
 from stdfields.fields import MinutesField, EnumIntegerField, EnumCharField
 from stdfields.forms import MinutesField as MinutesFormField
 from stdfields.widgets import MinutesWidget
         self.assertEqual(w.render('hi', '2:60'), tpl % '2:60')
         self.assertEqual(w.render('hi', '493'), tpl % '8:13')
 
+# -- --------------------------------------------------------------------------
+# -- Enum tests
+
+class SimpleExampleIntegerEnum(Enum):
+    FIRST = EnumValue(1, 'First')
+    SECOND = EnumValue(2, 'Second')
+    THIRD = EnumValue(3, 'Third')
+
+class ExtendedIntegerEnum(SimpleExampleIntegerEnum):
+    FOURTH = EnumValue(4, 'Fourth')
+    FIFTH = EnumValue(5, 'Fifth')
+    SIXTH = EnumValue(6, 'Sixth')
+    SEVENTH = EnumValue(7, 'Seventh')
+
+class EnumTest(TestCase):
+
+    def test_order(self):
+        values = SimpleExampleIntegerEnum.all()
+        self.assertEqual(SimpleExampleIntegerEnum.FIRST, values[0][0])
+        self.assertEqual(SimpleExampleIntegerEnum.SECOND, values[1][0])
+        self.assertEqual(SimpleExampleIntegerEnum.THIRD, values[2][0])
+
+    def test_order_extended(self):
+        values = ExtendedIntegerEnum.all()
+        self.assertEqual(ExtendedIntegerEnum.FIRST, values[0][0])
+        self.assertEqual(ExtendedIntegerEnum.SECOND, values[1][0])
+        self.assertEqual(ExtendedIntegerEnum.THIRD, values[2][0])
+        self.assertEqual(ExtendedIntegerEnum.FOURTH, values[3][0])
+        self.assertEqual(ExtendedIntegerEnum.FIFTH, values[4][0])
+        self.assertEqual(ExtendedIntegerEnum.SIXTH, values[5][0])
+        self.assertEqual(ExtendedIntegerEnum.SEVENTH, values[6][0])
+
+    def test_display(self):
+        self.assertEqual(SimpleExampleIntegerEnum.THIRD_display, 'Third')
+        self.assertEqual(ExtendedIntegerEnum.THIRD_display, 'Third')
+        self.assertEqual(ExtendedIntegerEnum.SIXTH_display, 'Sixth')
+
 
 # -- --------------------------------------------------------------------------
 # -- EnumIntegerField tests
     FIRST = 1
     SECOND = 2
     THIRD = 3
-
-    @classmethod
-    def all(cls):
-        return [
-            (cls.FIRST, 'First'),
-            (cls.SECOND, 'Second'),
-            (cls.THIRD, 'Third'),
-        ]
+    enumeration = [
+        (FIRST, 'First'),
+        (SECOND, 'Second'),
+        (THIRD, 'Third'),
+    ]
 
 
 class EnumIntegerModel(models.Model):
 
 
 class EnumIntegerFieldTest(TestCase):
+    enum = ExampleIntegerEnum
+    model = EnumIntegerModel
 
     def test_enum_integer_field(self):
-        f = EnumIntegerField(enum=ExampleIntegerEnum,
-                            default=ExampleIntegerEnum.FIRST, blank=False)
-        self.assertEqual(ExampleIntegerEnum.all(), f.formfield().choices)
-        f = EnumIntegerField(enum=ExampleIntegerEnum)
-        expected = [('', '---------')] + ExampleIntegerEnum.all()
+        f = EnumIntegerField(enum=self.enum,
+                            default=self.enum.FIRST, blank=False)
+        self.assertEqual(self.enum.all(), f.formfield().choices)
+        f = EnumIntegerField(enum=self.enum)
+        expected = [('', '---------')] + self.enum.all()
         self.assertEqual(expected, f.formfield().choices)
 
     def test_display(self):
-        first = ExampleIntegerEnum.FIRST
-        label = ExampleIntegerEnum.as_display(first)
-        self.assertEqual(label, EnumIntegerModel(c=first).get_c_display())
-        self.assertEqual(5, EnumIntegerModel(c=5).get_c_display())
-        self.assertTrue(EnumIntegerModel(c=None).get_c_display() is None)
-        self.assertEqual('', EnumIntegerModel(c='').get_c_display())
+        first = self.enum.FIRST
+        label = self.enum.as_display(first)
+        self.assertEqual(label, self.model(c=first).get_c_display())
+        self.assertEqual(5, self.model(c=5).get_c_display())
+        self.assertTrue(self.model(c=None).get_c_display() is None)
+        self.assertEqual('', self.model(c='').get_c_display())
 
 
+class PureIntegerEnum(Enum):
+    FIRST = EnumValue(1, 'First')
+    SECOND = EnumValue(2, 'Second')
+    THIRD = EnumValue(3, 'Third')
+
+
+class PureEnumIntegerModel(models.Model):
+    c = EnumIntegerField(enum=PureIntegerEnum)
+
+
+PureEnumIntegerFieldTest = EnumIntegerFieldTest
+PureEnumIntegerFieldTest.enum = PureIntegerEnum
+PureEnumIntegerFieldTest.model = PureEnumIntegerModel
+
 # -- --------------------------------------------------------------------------
 # -- EnumCharField tests
 
     FIRST = 'A'
     SECOND = 'Boo'
     THIRD = 'Circus'
-
-    @classmethod
-    def all(cls):
-        return [
-            (cls.FIRST, 'First'),
-            (cls.SECOND, 'Second'),
-            (cls.THIRD, 'Third'),
-        ]
+    enumeration = [
+        (FIRST, 'First'),
+        (SECOND, 'Second'),
+        (THIRD, 'Third'),
+    ]
 
 
 class EnumCharModel(models.Model):
 
 
 class EnumCharFieldTest(TestCase):
+    
+    def __init__(self, *args, **kwargs):
+        super(EnumCharFieldTest, self).__init__(*args, **kwargs)
+        self.enum_cls = ExampleCharEnum
+        self.model_cls = EnumCharModel
+        # 'Circus' is the longest key at 6 characters
+        self.max_length = 6
 
     def test_enum_char_field(self):
-        f = EnumCharField(enum=ExampleCharEnum, default=ExampleCharEnum.FIRST,
-                        blank=False, max_length=ExampleCharEnum.max_length())
-        # 'Circus' is the longest key at 6 characters
-        self.assertEqual(6, f.max_length)
-        self.assertEqual(ExampleCharEnum.all(), f.formfield().choices)
-        f = EnumCharField(enum=ExampleCharEnum)
-        expected = [('', '---------')] + ExampleCharEnum.all()
+        enum = self.enum_cls
+        f = EnumCharField(enum=enum, default=enum.FIRST,
+                        blank=False, max_length=enum.max_length())
+        self.assertEqual(self.max_length, f.max_length)
+        self.assertEqual(enum.all(), f.formfield().choices)
+        f = EnumCharField(enum=enum)
+        expected = [('', '---------')] + enum.all()
         self.assertEqual(expected, f.formfield().choices)
 
     def test_display(self):
-        first = ExampleCharEnum.FIRST
-        label = ExampleCharEnum.as_display(first)
-        self.assertEqual(label, EnumCharModel(c=first).get_c_display())
-        self.assertEqual('E', EnumCharModel(c='E').get_c_display())
-        self.assertTrue(EnumCharModel(c=None).get_c_display() is None)
-        self.assertEqual('', EnumCharModel(c='').get_c_display())
+        enum = self.enum_cls
+        model = self.model_cls
+        first = enum.FIRST
+        label = enum.as_display(first)
+        self.assertEqual(label, model(c=first).get_c_display())
+        self.assertEqual('E', model(c='E').get_c_display())
+        self.assertTrue(model(c=None).get_c_display() is None)
+        self.assertEqual('', model(c='').get_c_display())
+
+
+class PureCharEnum(Enum):
+    FIRST = EnumValue('A', 'First')
+    SECOND = EnumValue('Boo', 'Second')
+    THIRD = EnumValue('Circus', 'Third')
+
+
+class PureEnumCharModel(models.Model):
+    c = EnumCharField(enum=PureCharEnum,
+                    max_length=PureCharEnum.max_length())
+
+
+class PureEnumCharFieldTest(EnumCharFieldTest):
+
+    def __init__(self, *args, **kwargs):
+        super(PureEnumCharFieldTest, self).__init__(*args, **kwargs)
+        self.enum_cls = PureCharEnum
+        self.model_cls = PureEnumCharModel
+
+
+class PureCharEnumExtension(PureCharEnum):
+    FOURTH = EnumValue('FICTIONAL', 'Acme Inc')
+    FIFTH = EnumValue('Wtf?', 'Stop the train!')
+
+
+class PureEnumCharExtensionModel(models.Model):
+    c = EnumCharField(enum=PureCharEnumExtension,
+                    max_length=PureCharEnumExtension.max_length())
+
+
+class PureEnumCharExtensionFieldTest(EnumCharFieldTest):
+
+    def __init__(self, *args, **kwargs):
+        super(PureEnumCharExtensionFieldTest, self).__init__(*args, **kwargs)
+        self.enum_cls = PureCharEnumExtension
+        self.model_cls = PureEnumCharExtensionModel
+        self.max_length = 9
+
+    def test_display_extra_values(self):
+        enum = self.enum_cls
+        model = self.model_cls
+        first = enum.FIFTH
+        label = 'Stop the train!'
+        self.assertEqual(label, enum.as_display(first))
+        self.assertEqual(label, model(c=first).get_c_display())
+        self.assertEqual('E', model(c='E').get_c_display())
+        self.assertTrue(model(c=None).get_c_display() is None)
+        self.assertEqual('', model(c='').get_c_display())