1. Nam Nguyen
  2. intfprgm

Commits

Nam Nguyen  committed 62a0b54

Add support for disabling argspec check.

  • Participants
  • Parent commits 9c5740b
  • Branches default

Comments (0)

Files changed (2)

File intfprgm/__init__.py

View file
     return orig_class
 
 
-def check_override(clz, func):
+def check_override(clz, func, check_argspec=True):
     '''Walk the base classes of ``clz`` and check if ``func`` is declared
     with the same signature.
 
         func (function object): A function object to check for. This function
             must have been defined somewhere up in the class hierachy with
             the exact signature.
+        check_argspec (boolean): ``True`` if a function must be matched
+            exactly. Default to ``False``.
 
     Raises:
 
         try:
             argspec = inspect.getargspec(f)
             # same signature? good match
-            if orig_argspec == argspec:
+            if (not check_argspec) or (orig_argspec == argspec):
                 return
             # not? continue up the chain
             else:
             func.func_name, func.im_class.__name__))
 
 
+class overrides_impl(object):
+    '''Actual implementation of ``overrides``.
+
+    We need this to be a class so that we can create instances from it, with
+    different ``check_argspec`` values.
+
+    '''
+
+    def __init__(self, check_argspec=True):
+        self.check_argspec = check_argspec
+
+    def __call__(self, orig):
+        if type(orig) in (types.TypeType, types.ClassType):
+            for f in dir(orig):
+                f = getattr(orig, f)
+                if type(f) in (types.FunctionType, types.MethodType):
+                    if not hasattr(f, '__intfprgm_overrides__'):
+                        continue
+                    check_override(orig, f, f.__intfprgm_overrides__)
+        else:
+            orig.__intfprgm_overrides__ = self.check_argspec
+        return orig
+
+
 def overrides(orig):
     '''A decorator for both class and function that marks and checks if a
     function is defined in any of the base classes.
     yet. We apply ``overrides`` to the function so that it can mark that
     function to be checked later. And we apply ``overrides`` to the class to
     walk its members after the class has been fully defined.
+
+    If a function must be matched only by its name, we can set
+    ``check_argspec`` flag to ``False``::
+
+        @overrides(False)
+        def method(signature):
+            pass
     
+    By the fault, argspec check is set to ``True``.
+
+    (For those who care about the code, this function is basically two
+    overloaded functions::
+
+        def overrides(function_or_class):
+            return overrides_impl(False)(function_or_class)
+
+        def overrides(boolean_value):
+            return overrides_impl(boolean_value)
+
+    When we use ``@overrides``, the ``overrides`` after ``@`` is evaluated  to
+    the ``overrides`` function, and that function is invoked on the passed-in
+    function or class.
+
+    When we use ``@overrides(True)`` (or ``False``), the part after ``@`` is
+    evalulated as a function invocation on ``overrides`` with a boolean
+    argument. The returned value of that invocation is then used to decorate
+    the original function or class.)
+
     '''
 
-    if type(orig) in (types.TypeType, types.ClassType):
-        for f in dir(orig):
-            f = getattr(orig, f)
-            if type(f) in (types.FunctionType, types.MethodType):
-                if not hasattr(f, '__intfprgm_overrides__'):
-                    continue
-                check_override(orig, f)
+    if type(orig) is types.BooleanType:
+        return overrides_impl(orig)
     else:
-        orig.__intfprgm_overrides__ = True
-    return orig
+        return overrides_impl(True)(orig)

File intfprgm/tests.py

View file
             @overrides
             def method_1(self):
                 pass
-        try:
-            overrides(Derived)
-        except RuntimeError:
-            self.fail()
+        overrides(Derived)
 
     def test_not_found(self):
         class Base(object):
             @overrides
             def method_1(self):
                 pass
-        try:
-            overrides(Derived)
-            self.fail()
-        except RuntimeError:
-            pass
+        self.assertRaises(RuntimeError, overrides, Derived)
 
-    def test_not_matched_signature(self):
+    def test_check_argspec_true(self):
         class Base(object):
             def method_1(self, arg1):
                 pass
-        class Derived(Base):
-            @overrides
+        class Derived_1(Base):
+            @overrides(True)
             def method_1(self):
                 pass
-        try:
-            overrides(Derived)
-            self.fail()
-        except RuntimeError:
-            pass
+        self.assertRaises(RuntimeError, overrides, Derived_1)
+
+    def test_check_argspec_false(self):
+        class Base(object):
+            def method_1(self, arg1):
+                pass
+        class Derived_2(Base):
+            @overrides(False)
+            def method_1(self):
+                pass
+        overrides(Derived_2)
 
     def test_matched_signature_up_there_1(self):
         class Base(object):
             @overrides
             def method_1(self, arg1):
                 pass
-        try:
-            overrides(Derived_2)
-        except RuntimeError:
-            self.fail()
+        overrides(Derived_2)
 
 
 if __name__ == '__main__':