Commits

coady committed 992be14

Release 0.1.

Comments (0)

Files changed (4)

+include test.py
+"""
+Multiple argument dispacthing.
+
+Call multimethod on a variable number of types.
+It returns a decorator which finds the multimethod of the same name,
+creating it if necessary, and adds that function to it.  For example:
+
+    @multimethod(*types)
+    def func(*args):
+        ...
+
+'func' is now a multimethod which will delegate to the above function,
+when called with arguments of the specified types.  If an exact match
+can't be found, the next closest method will be called (and cached).
+A function can have more than one multimethod decorator.
+
+See tests for more example usage.
+"""
+
+import sys
+from itertools import imap as map, izip as zip
+
+class signature(tuple):
+    "A tuple of types that supports partial ordering."
+    __slots__ = ()
+    def __le__(self, other):
+        return len(self) <= len(other) and all(map(issubclass, other, self))
+    def __lt__(self, other):
+        return self != other and self <= other
+    def __sub__(self, other):
+        "Return relative distances, assuming self >= other."
+        return [list(left.__mro__).index(right) for left, right in zip(self, other)]
+
+class multimethod(dict):
+    "A callable directed acyclic graph of methods."
+    def __new__(cls, *types):
+        "Return a decorator which will add the function."
+        namespace = sys._getframe(1).f_locals
+        def decorator(func):
+            if isinstance(func, cls):
+                self, func = func, func.last
+            elif func.__name__ in namespace:
+                self = namespace[func.__name__]
+            else:
+                self = dict.__new__(cls)
+                self.__name__, self.cache = func.__name__, {}
+            self[types] = self.last = func
+            return self
+        return decorator
+    def parents(self, types):
+        "Find immediate parents of potential key."
+        parents, ancestors = set(), set()
+        for key, (value, superkeys) in self.items():
+            if key < types:
+                parents.add(key)
+                ancestors |= superkeys
+        return parents - ancestors
+    def __getitem__(self, types):
+        return dict.__getitem__(self, types)[0]
+    def __setitem__(self, types, func):
+        self.cache.clear()
+        parents = self.parents(types)
+        for key, (value, superkeys) in self.items():
+            if types < key and (not parents or parents & superkeys):
+                superkeys -= parents
+                superkeys.add(types)
+        dict.__setitem__(self, signature(types), (func, parents))
+    def __delitem__(self, types):
+        self.cache.clear()
+        dict.__delitem__(self, types)
+        for key, (value, superkeys) in self.items():
+            if types in superkeys:
+                dict.__setitem__(self, key, (value, self.parents(key)))
+    def super(self, *types):
+        "Return the next applicable method of given types."
+        keys = self.parents(types)
+        if keys:
+            return self[min(keys, key=signature(types).__sub__)]
+        raise TypeError("%s%s: no methods found" % (self.__name__, types))
+    def __call__(self, *args, **kwargs):
+        "Resolve and dispatch to best method."
+        types = tuple(map(type, args))
+        if types in self:
+            func = self[types]
+        elif types in self.cache:
+            func = self.cache[types]
+        else:
+            func = self.cache[types] = self.super(*types)
+        return func(*args, **kwargs)
+from distutils.core import setup
+
+setup(
+    name='multimethod',
+    version='0.1',
+    description='Multiple argument dispacthing.',
+    long_description='''
+    Multimethod is a simple pure python 2.5 module for dispatching functions on the types of multiple arguments.
+    It supports resolving to the next applicable method (super) and caching for fast dispatch.
+    It has more features than simplegeneric, but is lighter weight than PEAK.
+    ''',
+    author='Aric Coady',
+    author_email='aric.coady@gmail.com',
+    py_modules=['multimethod'],
+    classifiers=[
+        'Development Status :: 3 - Alpha',
+        'License :: OSI Approved :: Python Software Foundation License',
+    ],
+)
+import unittest
+from multimethod import multimethod
+
+# roshambo
+class rock(object):
+    pass
+
+class paper(object):
+    pass
+
+class scissors(object):
+    pass
+
+@multimethod(object, object)
+def roshambo(left, right):
+    return 'tie'
+
+@multimethod(scissors, rock)
+@multimethod(rock, scissors)
+def roshambo(left, right):
+    return 'rock smashes scissors'
+
+@multimethod(paper, scissors)
+@multimethod(scissors, paper)
+def roshambo(left, right):
+    return 'scissors cut paper'
+
+@multimethod(rock, paper)
+@multimethod(paper, rock)
+def roshambo(left, right):
+    return 'paper covers rock'
+
+# string join
+class tree(list):
+    def walk(self):
+        for value in self:
+            if isinstance(value, type(self)):
+                for subvalue in value.walk():
+                    yield subvalue
+            else:
+                yield value
+
+class bracket(tuple):
+    def __new__(cls, left, right):
+        return tuple.__new__(cls, (left, right))
+
+@multimethod(object, str)
+def join(seq, sep):
+    return sep.join(map(str, seq))
+
+@multimethod(object, bracket)
+def join(seq, sep):
+    return sep[0] + join(seq, sep[1]+sep[0]) + sep[1]
+
+@multimethod(tree, object)
+def join(seq, sep):
+    return join(seq.walk(), sep)
+
+class TestCase(unittest.TestCase):
+    def testRoshambo(self):
+        r, p, s = rock(), paper(), scissors()
+        assert roshambo(r, p) == 'paper covers rock'
+        assert roshambo(p, r) == 'paper covers rock'
+        assert roshambo(r, s) == 'rock smashes scissors'
+        assert roshambo(p, s) == 'scissors cut paper'
+        assert len(roshambo) == 7 and not roshambo.cache
+        assert roshambo(r, r) == 'tie'
+        assert roshambo.cache
+        del roshambo[object, object]
+        assert not roshambo.cache
+        self.assertRaises(TypeError, roshambo, r, r)
+    def testJoin(self):
+        sep = '<>'
+        seq = [0, tree([1]), 2]
+        assert list(tree(seq).walk()) == range(3)
+        assert join(seq, sep) == '0<>[1]<>2'
+        assert join(tree(seq), sep) == '0<>1<>2'
+        assert join(seq, bracket(*sep)) == '<0><[1]><2>'
+        assert join(tree(seq), bracket(*sep)) == '<0><1><2>'
+
+if __name__ == '__main__':
+    unittest.main()