Michael Foord  committed 939422c

Making the moduleloading plugin more modular

  • Participants
  • Parent commits 25a4ec0
  • Branches plugins

Comments (0)

Files changed (1)

File unittest2/plugins/

     parametersEnabled = False
     configSection = 'functions'
     commandLineSwitch = (None, 'functions', 'Load tests from functions')
+    unpack = enumerate
     def loadTestsFromName(self, event):
         name =
             args['tearDown'] = tearDown
         paramList = getattr(obj, 'paramList', None)
-        isGenerator = getattr(obj, 'testGenerator', False)
+        isGenerator = self.isGenerator(obj)
         if self.parametersEnabled and paramList is not None:
             for index, argSet in enumerate(paramList):
                 def func(argSet=argSet, obj=obj):
             name = '%s.%s' % (obj.__module__, obj.__name__)
             def createTest(name):
                 return GeneratorFunctionCase(name, **args)
-            tests.extend(testsFromGenerator(name, extras, createTest))
+            tests.extend(testsFromGenerator(name, extras, createTest,
+                                            self.unpack))
             case = FunctionTestCase(obj, **args)
             return [tests[testIndex-1]]
         return tests
+    def isGenerator(self, obj):
+        return getattr(obj, 'testGenerator', None) is not None
 class GeneratorFunctionCase(FunctionTestCase):
     def __init__(self, name, **args):
     configSection = 'generators'
     commandLineSwitch = (None, 'generators', 'Load tests from generators')
+    unpack = enumerate
     def pluginsLoaded(self, event):
         Functions.generatorsEnabled = True
         testCaseClass = event.testCase
         for name in dir(testCaseClass):
             method = getattr(testCaseClass, name)
-            if getattr(method, 'testGenerator', None) is not None:
+            if self.isGenerator(method):
                 instance = testCaseClass(name)
-                    testsFromGenerator(name, method(instance), testCaseClass)
+                    testsFromGenerator(name, method(instance), testCaseClass,
+                                       self.unpack)
+    def isGenerator(self, obj):
+        return getattr(obj, 'testGenerator', None) is not None
     def getTestCaseNames(self, event):
         names = filter(event.isTestMethod, dir(event.testCase))
         klass = event.testCase
         for name in names:
             method = getattr(klass, name)
-            if getattr(method, 'testGenerator', None) is not None:
+            if self.isGenerator(method):
     def loadTestsFromName(self, event):
         parent, obj, name, index = result
         if (index is None or not isinstance(parent, type) or 
             not issubclass(parent, TestCase) or 
-            not getattr(obj, 'testGenerator', False)):
+            not self.isGenerator(obj)):
             # we're only handling TestCase generator methods here
         instance = parent(obj.__name__)
-            test = list(testsFromGenerator(name, obj(instance), parent))[index-1]
+            test = list(testsFromGenerator(name, obj(instance), parent, 
+                                           self.unpack))[index-1]
         except IndexError:
             raise TestNotFoundError(original_name)
         return suite
-def testsFromGenerator(name, generator, testCaseClass):
+def testsFromGenerator(name, generator, testCaseClass, unpack):
-        for index, (func, args) in enumerate(generator):
+        for index, (func, args) in unpack(generator):
             method_name = name_from_args(name, index, args)
             setattr(testCaseClass, method_name, None)
             instance = testCaseClass(method_name)