Michael Foord avatar Michael Foord committed e9e099b

getTestCaseNames event has access to isTestMethod
moduleloading plugin also provides test generators
config section and command line switch for loading test functions changed

Comments (0)

Files changed (6)

 * ``testMethodPrefix`` - set to None, modify this attribute to *change* the prefix being used for this class
 * ``extraNames`` - a list of extra names to use for this test case as well as the default ones
 * ``excludedNames`` - a list of names to exclude from loading from this class
+* ``isTestMethod`` - the default filter for telling if a name is a valid test method name
 
 This event can be handled. If it is handled it should return a list of strings. Note that if this event returns an empty list (or None which will be replaced with an empty list then ``loadTestsFromTestCase`` will still check to see if the TestCase has a ``runTest`` method.
 
 Even if the event is handled ``extraNames`` will still be added to the list, however *excludedNames`` won't be removed as they are filtered out by the default implementation which looks for all attributes that are methods (or callable) whose name begins with ``loader.testMethodPrefix`` (or ``event.testMethodPrefix`` if that is set) and aren't in the list of excluded names (converted to a set first for efficient lookup).
 
+Note that modifying ``isTestMethod`` has no effect. It is there as a convenience for plugins wanting to be able to use the default check.
+
 The list of names will also be sorted using ``loader.sortTestMethodsUsing``.
 
 
 [doctest]
 always-on = False
 
-[module-loading]
+[functions]
+always-on = False
+
+[generators]
 always-on = False
 
 [checker]

unittest2/case.py

         """
         self._testMethodName = methodName
         self._resultForDoCleanups = None
+
         try:
             testMethod = getattr(self, methodName)
         except AttributeError:
             raise ValueError("no such test method in %s: %s" % \
                   (self.__class__, methodName))
+    
         self._testMethodDoc = testMethod.__doc__
+            
         self._cleanups = []
 
         # Map types to custom assertEqual functions that will compare

unittest2/events.py

         self.extraTests = []
 
 class GetTestCaseNamesEvent(_Event):
-    def __init__(self, loader, testCase):
+    def __init__(self, loader, testCase, isTestMethod):
         _Event.__init__(self)
         self.loader = loader
         self.testCase = testCase
         self.testMethodPrefix = None
         self.extraNames = []
         self.excludedNames = []
+        self.isTestMethod = isTestMethod
 
 class RunnerCreatedEvent(_Event):
     def __init__(self, runner):

unittest2/loader.py

     def getTestCaseNames(self, testCaseClass):
         """Return a sorted sequence of method names found within testCaseClass
         """
-        event = GetTestCaseNamesEvent(self, testCaseClass)
+        excluded = set()
+        def isTestMethod(attrname, testCaseClass=testCaseClass,
+                         excluded=excluded):
+            prefix = event.testMethodPrefix or self.testMethodPrefix
+            return (
+                attrname.startswith(prefix) and 
+                hasattr(getattr(testCaseClass, attrname), '__call__') and
+                attrname not in excluded
+            )
+        event = GetTestCaseNamesEvent(self, testCaseClass, isTestMethod)
+
         result = hooks.getTestCaseNames(event)
         if event.handled:
             testFnNames = result or []
         else:
-            prefix = event.testMethodPrefix or self.testMethodPrefix 
-            excluded = set(event.excludedNames)
-            def isTestMethod(attrname, testCaseClass=testCaseClass,
-                                 prefix=prefix, excluded=excluded):
-                return (
-                    attrname.startswith(prefix) and 
-                    hasattr(getattr(testCaseClass, attrname), '__call__') and
-                    attrname not in excluded
-                )
+            excluded = excluded.update(event.excludedNames)
             testFnNames = filter(isTestMethod, dir(testCaseClass))
         if event.extraNames:
             testFnNames.extend(event.extraNames)

unittest2/plugins/moduleloading.py

-from unittest2 import Plugin, FunctionTestCase
+from unittest2 import Plugin, FunctionTestCase, TestCase
 
 import types
 
-help_text = 'Load test functions from test modules'
-class TestLoading(Plugin):
-    
-    configSection = 'module-loading'
-    commandLineSwitch = (None, 'test-functions', help_text)
-
-
-    def loadTestsFromModule(self, event):
-        loader = event.loader
-        module = event.module
-        
-        def is_test(obj):
-            return obj.__name__.startswith(loader.testMethodPrefix)
-        
-        tests = []
-        for name in dir(module):
-            obj = getattr(module, name)
-            if isinstance(obj, types.FunctionType) and is_test(obj):
-                args = {}
-                setUp = getattr(obj, 'setUp', None)
-                tearDown = getattr(obj, 'tearDown', None)
-                if setUp is not None:
-                    args['setUp'] = setUp
-                if tearDown is not None:
-                    args['tearDown'] = tearDown
-                case = FunctionTestCase(obj, **args)
-                tests.append(case)
-                
-        event.extraTests.extend(tests)
 
 def setUp(setupFunction):
     def decorator(func):
 
 def testGenerator(func):
     func.testGenerator = True
-    return func
+    return func
+
+class Functions(Plugin):
+    
+    generatorsEnabled = False
+    configSection = 'functions'
+    commandLineSwitch = (None, 'functions', 'Load tests from functions')
+
+    def loadTestsFromModule(self, event):
+        loader = event.loader
+        module = event.module
+        
+        def is_test(obj):
+            if obj is testGenerator:
+                return False
+            return obj.__name__.startswith(loader.testMethodPrefix)
+        
+        tests = []
+        for name in dir(module):
+            obj = getattr(module, name)
+            if isinstance(obj, types.FunctionType) and is_test(obj):
+                args = {}
+                setUp = getattr(obj, 'setUp', None)
+                tearDown = getattr(obj, 'tearDown', None)
+                if setUp is not None:
+                    args['setUp'] = setUp
+                if tearDown is not None:
+                    args['tearDown'] = tearDown
+                
+                if (not self.generatorsEnabled or 
+                    getattr(obj, 'testGenerator', None) is None):
+                    case = FunctionTestCase(obj, **args)
+                    tests.append(case)
+                else:
+                    extras = list(obj())
+                    name = '%s.%s' % (obj.__module__, obj.__name__)
+                    def createTest(name):
+                        return GeneratorFunctionCase(name, **args)
+                    tests.extend(testsFromGenerator(name, extras, createTest))
+                
+        event.extraTests.extend(tests)
+
+
+class GeneratorFunctionCase(FunctionTestCase):
+
+    def __init__(self, name, **args):
+        self._name = name
+        FunctionTestCase.__init__(self, None, **args)
+
+    _testFunc = property(lambda self: getattr(self, self._name),
+                         lambda self, func: None)
+
+    def __repr__(self):
+        return self._name
+
+    id = __str__ = __repr__
+
+
+class Generators(Plugin):
+
+    configSection = 'generators'
+    commandLineSwitch = (None, 'generators', 'Load tests from generators')
+
+    def pluginsLoaded(self, event):
+        Functions.generatorsEnabled = True
+        
+    def loadTestsFromTestCase(self, event):
+        testCaseClass = event.testCase
+        for name in dir(testCaseClass):
+            method = getattr(testCaseClass, name)
+            if getattr(method, 'testGenerator', None) is not None:
+                instance = testCaseClass(name)
+                tests = list(method(instance))
+                event.extraTests.extend(
+                    testsFromGenerator(name, tests, testCaseClass)
+                )
+
+    def getTestCaseNames(self, event):
+        names = filter(event.isTestMethod, dir(event.testCase))
+        klass = event.testCase
+        for name in names:
+            method = getattr(klass, name)
+            if getattr(method, 'testGenerator', None) is not None:
+                event.excludedNames.append(name)
+
+def testsFromGenerator(name, tests, testCaseClass):
+    for index, (func, args) in enumerate(tests):
+        summary = ', '.join(repr(arg) for arg in args)
+
+        method_name = '%s_%s\n%s' % (name, index + 1, summary[:79])
+        setattr(testCaseClass, method_name, None)
+        instance = testCaseClass(method_name)
+        delattr(testCaseClass, method_name)
+        def method(func=func, args=args):
+            return func(*args)
+        setattr(instance, method_name, method)
+        yield instance
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.