1. Nam Nguyen
  2. intfprgm

Commits

Nam Nguyen  committed 361877a

Add ``overrides`` decorator.

  • Participants
  • Parent commits 7c313ae
  • Branches default

Comments (0)

Files changed (2)

File intfprgm/__init__.py

View file
  • Ignore whitespace
 
 '''
 
+import collections
 import dis
+import inspect
 import types
 
 
                 orig_class.__name__, ', '.join(f.func_name for f in funcs)))
 
     return orig_class
+
+
+def check_override(clz, func):
+    '''Walk the base classes of ``clz`` and check if ``func`` is declared
+    with the same signature.
+
+    Raise ``RuntimeError`` if ``func`` is not found in any of the base class.
+
+    Args:
+
+        clz (class object): A class object where ``func`` should be defined.
+        func (function object): A function object to check for. This function
+            must have been defined somewhere up in the class hierachy with
+            the exact signature.
+
+    Raises:
+
+        ``RuntimeError`` if ``func`` is not found in any of the base class.
+
+    '''
+
+    queue = collections.deque()
+    queue.extend(clz.__bases__)
+    orig_argspec = inspect.getargspec(func)
+    while queue:
+        base = queue.popleft()
+        # get the function
+        f = getattr(base, func.func_name, None)
+        try:
+            argspec = inspect.getargspec(f)
+            # same signature? good match
+            if orig_argspec == argspec:
+                return
+            # not? continue up the chain
+            else:
+                queue.extend(base.__bases__)
+        except Exception:
+            # f may not be a function, that's okay, carry on
+            queue.extend(base.__bases__)
+    else:
+        raise RuntimeError('%r is not found in any base class of %r.' % (
+            func.func_name, func.im_class.__name__))
+
+
+def overrides(orig):
+    '''A decorator for both class and function that marks and checks if a
+    function is defined in any of the base classes.
+
+    For example::
+
+        @overrides
+        class Derived(Base)
+
+            @overrides
+            def method(signature):
+                pass
+
+    This decorator MUST be applied on both the class and the function. The
+    reason is that during definition, the function is not assigned to a class
+    yet. We apply ``overrides`` to the function so that it can mark that
+    function to be checked later. And we apply ``overrides`` to the class to
+    walk its members after the class has been fully defined.
+    
+    '''
+
+    if type(orig) in (types.TypeType, types.ClassType):
+        for f in dir(orig):
+            f = getattr(orig, f)
+            if type(f) in (types.FunctionType, types.MethodType):
+                if not hasattr(f, '__intfprgm_overrides__'):
+                    continue
+                check_override(orig, f)
+    else:
+        orig.__intfprgm_overrides__ = True
+    return orig

File intfprgm/tests.py

View file
  • Ignore whitespace
 
 import unittest
 
-from intfprgm import interface, abstract, concrete
+from intfprgm import interface, abstract, concrete, overrides
 
 
 class InterfaceTest(unittest.TestCase):
         self.assertRaises(SyntaxError, concrete, test2)
 
 
+class OverridesTest(unittest.TestCase):
+
+    def test_found(self):
+        class Base(object):
+            def method_1(self):
+                pass
+        class Derived(Base):
+            @overrides
+            def method_1(self):
+                pass
+        try:
+            overrides(Derived)
+        except RuntimeError:
+            self.fail()
+
+    def test_not_found(self):
+        class Base(object):
+            pass
+        class Derived(Base):
+            @overrides
+            def method_1(self):
+                pass
+        try:
+            overrides(Derived)
+            self.fail()
+        except RuntimeError:
+            pass
+
+    def test_not_matched_signature(self):
+        class Base(object):
+            def method_1(self, arg1):
+                pass
+        class Derived(Base):
+            @overrides
+            def method_1(self):
+                pass
+        try:
+            overrides(Derived)
+            self.fail()
+        except RuntimeError:
+            pass
+
+    def test_matched_signature_up_there_1(self):
+        class Base(object):
+            def method_1(self, arg1):
+                pass
+        class Derived_1(Base):
+            @overrides
+            def method_1(self, *args):
+                pass
+        class Derived_2(Derived_1):
+            @overrides
+            def method_1(self, arg1):
+                pass
+        try:
+            overrides(Derived_2)
+        except RuntimeError:
+            self.fail()
+
+
 if __name__ == '__main__':
     unittest.main()