Commits

Anonymous committed 2f974a8

decorators: refactored @cached to allow usages such as @cached(cacheattr='_cachename') while keeping bw compat

Comments (0)

Files changed (3)

 ============================
 
 	--
-* date: new datetime/delta <-> seconds/days conversion function
-	
+    * date: new datetime/delta <-> seconds/days conversion function
+
+    * decorators: refactored @cached to allow usages such as
+      @cached(cacheattr='_cachename') while keeping bw compat
+
 2011-04-01  --  0.55.2
     * new function for password generation in shellutils
 
 __docformat__ = "restructuredtext en"
 
 import types
+import sys, re
 from time import clock, time
-import sys, re
 
 # XXX rewrite so we can use the decorator syntax when keyarg has to be specified
 
 def _is_generator_function(callableobj):
     return callableobj.func_code.co_flags & 0x20
 
-def cached(callableobj, keyarg=None):
-    """Simple decorator to cache result of method call."""
-    assert not _is_generator_function(callableobj), 'cannot cache generator function: %s' % callableobj
-    if callableobj.func_code.co_argcount == 1 or keyarg == 0:
+class cached_decorator(object):
+    def __init__(self, cacheattr=None, keyarg=None):
+        self.cacheattr = cacheattr
+        self.keyarg = keyarg
+    def __call__(self, callableobj=None):
+        assert not _is_generator_function(callableobj), \
+               'cannot cache generator function: %s' % callableobj
+        if callableobj.func_code.co_argcount == 1 or self.keyarg == 0:
+            cache = _SingleValueCache(callableobj, self.cacheattr)
+        elif self.keyarg:
+            cache = _MultiValuesKeyArgCache(callableobj, self.keyarg, self.cacheattr)
+            print 'hop'
+        else:
+            cache = _MultiValuesCache(callableobj, self.cacheattr)
+        return cache.closure()
 
-        def cache_wrapper1(self, *args):
-            cache = '_%s_cache_' % callableobj.__name__
-            #print 'cache1?', cache
-            try:
-                return self.__dict__[cache]
-            except KeyError:
-                #print 'miss'
-                value = callableobj(self, *args)
-                setattr(self, cache, value)
-                return value
+class _SingleValueCache(object):
+    def __init__(self, callableobj, cacheattr=None):
+        self.callable = callableobj
+        if cacheattr is None:
+            self.cacheattr = '_%s_cache_' % callableobj.__name__
+        else:
+            assert cacheattr != callableobj.__name__
+            self.cacheattr = cacheattr
+
+    def __call__(__me, self, *args):
         try:
-            cache_wrapper1.__doc__ = callableobj.__doc__
-            cache_wrapper1.func_name = callableobj.func_name
+            return self.__dict__[__me.cacheattr]
+        except KeyError:
+            value = __me.callable(self, *args)
+            setattr(self, __me.cacheattr, value)
+            return value
+
+    def closure(self):
+        def wrapped(*args, **kwargs):
+            return self.__call__(*args, **kwargs)
+        wrapped.clear = self.clear
+        try:
+            wrapped.__doc__ = self.callable.__doc__
+            wrapped.__name__ = self.callable.__name__
+            wrapped.func_name = self.callable.func_name
         except:
             pass
-        return cache_wrapper1
+        return wrapped
 
-    elif keyarg:
+    def clear(self, holder):
+        holder.__dict__.pop(self.cacheattr, None)
 
-        def cache_wrapper2(self, *args, **kwargs):
-            cache = '_%s_cache_' % callableobj.__name__
-            key = args[keyarg-1]
-            #print 'cache2?', cache, self, key
-            try:
-                _cache = self.__dict__[cache]
-            except KeyError:
-                #print 'init'
-                _cache = {}
-                setattr(self, cache, _cache)
-            try:
-                return _cache[key]
-            except KeyError:
-                #print 'miss', self, cache, key
-                _cache[key] = callableobj(self, *args, **kwargs)
-            return _cache[key]
+
+class _MultiValuesCache(_SingleValueCache):
+    def _get_cache(self, holder):
         try:
-            cache_wrapper2.__doc__ = callableobj.__doc__
-            cache_wrapper2.func_name = callableobj.func_name
-        except:
-            pass
-        return cache_wrapper2
+            _cache = holder.__dict__[self.cacheattr]
+        except KeyError:
+            _cache = {}
+            setattr(holder, self.cacheattr, _cache)
+        return _cache
 
-    def cache_wrapper3(self, *args):
-        cache = '_%s_cache_' % callableobj.__name__
-        #print 'cache3?', cache, self, args
-        try:
-            _cache = self.__dict__[cache]
-        except KeyError:
-            #print 'init'
-            _cache = {}
-            setattr(self, cache, _cache)
+    def __call__(__me, self, *args, **kwargs):
+        _cache = __me._get_cache(self)
         try:
             return _cache[args]
         except KeyError:
-            #print 'miss'
-            _cache[args] = callableobj(self, *args)
-        return _cache[args]
-    try:
-        cache_wrapper3.__doc__ = callableobj.__doc__
-        cache_wrapper3.func_name = callableobj.func_name
-    except:
-        pass
-    return cache_wrapper3
+            _cache[args] = __me.callable(self, *args)
+            return _cache[args]
+
+class _MultiValuesKeyArgCache(_MultiValuesCache):
+    def __init__(self, callableobj, keyarg, cacheattr=None):
+        super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr)
+        self.keyarg = keyarg
+
+    def __call__(__me, self, *args, **kwargs):
+        _cache = __me._get_cache(self)
+        key = args[__me.keyarg-1]
+        try:
+            return _cache[key]
+        except KeyError:
+            _cache[key] = __me.callable(self, *args, **kwargs)
+            return _cache[key]
+
+
+def cached(callableobj=None, keyarg=None, **kwargs):
+    """Simple decorator to cache result of method call."""
+    kwargs['keyarg'] = keyarg
+    decorator = cached_decorator(**kwargs)
+    if callableobj is None:
+        return decorator
+    else:
+        return decorator(callableobj)
 
 def clear_cache(obj, funcname):
     """Function to clear a cache handled by the cached decorator."""
-    try:
-        del obj.__dict__['_%s_cache_' % funcname]
-    except KeyError:
-        pass
+    getattr(obj, funcname).clear(obj)
 
 def copy_cache(obj, funcname, cacheobj):
     """Copy cache for <funcname> from cacheobj to obj."""
-    cache = '_%s_cache_' % funcname
+    cache = getattr(obj, funcname).cacheattr
     try:
         setattr(obj, cache, cacheobj.__dict__[cache])
     except KeyError:
         pass
 
+
 class wproperty(object):
     """Simple descriptor expecting to take a modifier function as first argument
     and looking for a _<function name> to retrieve the attribute.

test/unittest_decorators.py

 """
 
 from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.decorators import monkeypatch, cached
+from logilab.common.decorators import monkeypatch, cached, clear_cache
 
 class DecoratorsTC(TestCase):
 
             def quux(self, zogzog):
                 """ what's up doc ? """
         self.assertEqual(Foo.foo.__doc__, """ what's up doc ? """)
+        self.assertEqual(Foo.foo.__name__, 'foo')
         self.assertEqual(Foo.foo.func_name, 'foo')
         self.assertEqual(Foo.bar.__doc__, """ what's up doc ? """)
+        self.assertEqual(Foo.bar.__name__, 'bar')
         self.assertEqual(Foo.bar.func_name, 'bar')
         self.assertEqual(Foo.quux.__doc__, """ what's up doc ? """)
+        self.assertEqual(Foo.quux.__name__, 'quux')
         self.assertEqual(Foo.quux.func_name, 'quux')
 
+    def test_cached_single_cache(self):
+        class Foo(object):
+            @cached(cacheattr=u'_foo')
+            def foo(self):
+                """ what's up doc ? """
+        foo = Foo()
+        foo.foo()
+        self.assertTrue(hasattr(foo, '_foo'))
+        clear_cache(foo, 'foo')
+        self.assertFalse(hasattr(foo, '_foo'))
+
+    def test_cached_multi_cache(self):
+        class Foo(object):
+            @cached(cacheattr=u'_foo')
+            def foo(self, args):
+                """ what's up doc ? """
+        foo = Foo()
+        foo.foo(1)
+        self.assertEqual(foo._foo, {(1,): None})
+        clear_cache(foo, 'foo')
+        self.assertFalse(hasattr(foo, '_foo'))
+
+    def test_cached_keyarg_cache(self):
+        class Foo(object):
+            @cached(cacheattr=u'_foo', keyarg=1)
+            def foo(self, other, args):
+                """ what's up doc ? """
+        foo = Foo()
+        foo.foo(2, 1)
+        self.assertEqual(foo._foo, {2: None})
+        clear_cache(foo, 'foo')
+        self.assertFalse(hasattr(foo, '_foo'))
+
 if __name__ == '__main__':
     unittest_main()