Commits

stoneleaf  committed d822e94

convenience function added
iteration order tested

  • Participants
  • Parent commits 541950f

Comments (0)

Files changed (2)

 
 import sys
 
-__all__ = ('Enum', 'IntEnum', 'InvalidEnum')
+__all__ = ('Enum', 'IntEnum', 'EnumError')
 
-class InvalidEnum(Exception):
+class EnumError(ValueError):
     'general purpose exception'
 
 class EnumDict(dict):
         self._enums.append(key)
         dict.__setitem__(self, key, something)
 
+Enum = None  # dummy value until replaced
+
 class EnumType(type):
 
     @classmethod
     def __prepare__(metacls, cls, bases):
+        if Enum in bases and bases[-1] is not Enum:
+            raise EnumError('Enum must be last in mro when subclassing')
         return EnumDict()
 
     def __new__(metacls, cls, bases, classdict):
                 else:
                     enum_item = obj_type.__new__(result, value)
             except Exception as exc:
-                raise InvalidEnum(*exc.args) from None
+                raise EnumError(*exc.args) from None
             enum_item.value = value
             enum_item.name = e
             for name, stored_value in enum_map.items():
                 setattr(result, name, getattr(Enum, name))
         return result
 
-    def __call__(cls, enum_lookup):
-        for enum_name in cls._enums:
-            enum = cls._enum_map[enum_name]
-            if enum.value == enum_lookup:
-                return enum
-        raise ValueError("%s is not a valid %s" % (enum_lookup, cls.__name__))
+    def __call__(cls, value, names=None):
+        if names is None:  # simple value lookup
+            return cls.__new__(cls, value)
+        # create a new Enum type
+        class_name = value  # better name for a name than value ;)
+        metacls = cls.__class__
+        classdict = metacls.__prepare__(class_name, (cls, ))
+        if isinstance(names, str):
+            names = [(e, i) for (i, e) in enumerate(names.replace(',',' ').split(), 1)]
+        for e, v in names:
+            classdict[e] = v
+        result = metacls.__new__(metacls, class_name, (cls, ), classdict)
+        return result
 
     def __contains__(cls, enum_item):
         if type(enum_item) is cls:
             enum = cls._enum_map[name]
             if enum.value == value:
                 return enum
-        raise InvalidEnum("nothing matches %r" % value)
+        raise EnumError("nothing matches %r" % value)
 
     def __iter__(cls):
         enums = []
             enums.append(cls._enum_map[name])
         return iter(enums)
 
-    #def __init__(cls, *args, **kws):
-    #    super().__init__(cls)
-
     def __len__(cls):
         return len(cls._enums)
 
                 ', '.join('%s:%r' % (k, v) for k, v in enums)
                 )
 
-    def create(cls, name, enums, type=None, register=None, export=None):
-        """
-        creates a class of enumerations, optionally registering the class
-        itself, or the enumerations, or both
-        """
-        metacls = cls.__class__
-        bases = (type, ) if type is not None else (cls, )
-        classdict = metacls.__prepare__(name, bases)
-        if isinstance(enums, str):
-            enums = [(e, i) for (i, e) in enumerate(enums.replace(',',' ').split(), 1)]
-        for e, v in enums:
-            classdict[e] = v
-        result = metacls.__new__(metacls, name, bases, classdict)
-        return result
-
 class Enum(metaclass=EnumType):
     "valueless, unordered Enum class"
 
             enum = cls._enum_map[enum_name]
             if enum.value == value:
                 return enum
-        raise ValueError("%s is not a valid %s" % (value, cls.__name__))
+        raise EnumError("%s is not a valid %s" % (value, cls.__name__))
 
     def __repr__(self):
         return "%s.%s [value=%s]" % (self.__class__.__name__, self.name, self.value)

File test_ref435.py

 #!/usr/bin/python3.3
 from pickle import dumps, loads
 import unittest
-from ref435 import Enum, IntEnum, InvalidEnum
+from ref435 import Enum, IntEnum, EnumError
 
 # for pickle tests
 try:
         self.assertTrue(Period[2] is Period.noon)
         self.assertTrue(Period['night'] is Period.night)
 
+    def test_iteration_order(self):
+        class Season(Enum):
+            SUMMER = 2
+            WINTER = 4
+            AUTUMN = 3
+            SPRING = 1
+        self.assertEqual(
+                list(Season),
+                [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING],
+                )
+
+    def test_convenience_function(self):
+        SummerMonth = Enum('SummerMonth', 'june july august')
+        lst = list(SummerMonth)
+        self.assertEqual(len(lst), len(SummerMonth))
+        self.assertEqual(len(SummerMonth), 3, SummerMonth)
+        self.assertEqual([SummerMonth.june, SummerMonth.july, SummerMonth.august], lst)
+        for i, month in enumerate('june july august'.split(), 1):
+            e = SummerMonth(i)
+            self.assertEqual(int(e), i)
+            self.assertNotEqual(e, i)
+            self.assertEqual(e.name, month)
+            self.assertIn(e, SummerMonth)
+            self.assertTrue(type(e) is SummerMonth)
+
     def test_subclassing(self):
         self.assertEqual(Name.BDFL, 'Guido van Rossum')
         self.assertTrue(Name.BDFL is Name['BDFL'])
         self.assertTrue(Name.BDFL is loads(dumps(Name.BDFL)))
 
+    def test_wrong_mro(self):
+        with self.assertRaises(EnumError):
+            class Wrong(Enum, str):
+                NotHere = 'error before this point'
+
+    def test_no_such_enum_item(self):
+        class Color(Enum):
+            red = 1
+            green = 2
+            blue = 3
+        with self.assertRaises(EnumError):
+            Color(4)
+
 if __name__ == '__main__':
     unittest.main()