1. Andriy Kornatskyy
  2. wheezy.caching

Commits

Andriy Kornatskyy  committed ef25578

Added ability to override make_key for wraps_get_or_*.

  • Participants
  • Parent commits 610472e
  • Branches default

Comments (0)

Files changed (2)

File src/wheezy/caching/patterns.py

View file
  • Ignore whitespace
                 self.dependency.add(dependency_key_factory(), key)
         return result
 
-    def wraps_get_or_add(self, wrapped):
+    def wraps_get_or_add(self, wrapped=None, make_key=None):
         """ Returns specialized decorator for `get_or_add` cache
             pattern.
 
                 def list_items(self, locale):
                     pass
         """
-        make_key = self.key_builder(wrapped)
 
-        def get_or_add_wrapper(*args, **kwargs):
-            key = make_key(*args, **kwargs)
-            result = self.cache.get(key, self.namespace)
-            if result is not None:
+        def decorate(func):
+            mk = self.adapt(func, make_key)
+
+            def get_or_add_wrapper(*args, **kwargs):
+                key = mk(*args, **kwargs)
+                result = self.cache.get(key, self.namespace)
+                if result is not None:
+                    return result
+                result = func(*args, **kwargs)
+                if result is not None:
+                    self.cache.add(key, result, self.time, self.namespace)
                 return result
-            result = wrapped(*args, **kwargs)
-            if result is not None:
-                self.cache.add(key, result, self.time, self.namespace)
-            return result
-        return get_or_add_wrapper
+            return get_or_add_wrapper
+        if wrapped is None:
+            return decorate
+        else:
+            return decorate(wrapped)
 
     def get_or_set(self, key, create_factory, dependency_key_factory=None):
         """ Cache Pattern: get an item by *key* from *cache* and
                 self.dependency.add(dependency_key_factory(), key)
         return result
 
-    def __call__(self, wrapped):
-        return self.wraps_get_or_set(wrapped)
+    def __call__(self, wrapped=None, make_key=None):
+        return self.wraps_get_or_set(wrapped, make_key)
 
-    def wraps_get_or_set(self, wrapped):
+    def wraps_get_or_set(self, wrapped=None, make_key=None):
         """ Returns specialized decorator for `get_or_set` cache
             pattern.
 
                 def list_items(self, locale):
                     pass
         """
-        make_key = self.key_builder(wrapped)
 
-        def get_or_set_wrapper(*args, **kwargs):
-            key = make_key(*args, **kwargs)
-            result = self.cache.get(key, self.namespace)
-            if result is not None:
+        def decorate(func):
+            mk = self.adapt(func, make_key)
+
+            def get_or_set_wrapper(*args, **kwargs):
+                key = mk(*args, **kwargs)
+                result = self.cache.get(key, self.namespace)
+                if result is not None:
+                    return result
+                result = func(*args, **kwargs)
+                if result is not None:
+                    self.cache.set(key, result, self.time, self.namespace)
                 return result
-            result = wrapped(*args, **kwargs)
-            if result is not None:
-                self.cache.set(key, result, self.time, self.namespace)
-            return result
-        return get_or_set_wrapper
+            return get_or_set_wrapper
+        if wrapped is None:
+            return decorate
+        else:
+            return decorate(wrapped)
 
     def one_pass_create(self, key, create_factory,
                         dependency_key_factory=None):
         return self.one_pass_create(key, create_factory,
                                     dependency_key_factory)
 
-    def wraps_get_or_create(self, wrapped):
+    def wraps_get_or_create(self, wrapped=None, make_key=None):
         """ Returns specialized decorator for `get_or_create` cache
             pattern.
 
                 def list_items(self, locale):
                     pass
         """
-        make_key = self.key_builder(wrapped)
+        def decorate(func):
+            mk = self.adapt(func, make_key)
 
-        def get_or_create_wrapper(*args, **kwargs):
-            key = make_key(*args, **kwargs)
-            result = self.cache.get(key, self.namespace)
-            if result is not None:
-                return result
-            return self.one_pass_create(key, lambda: wrapped(*args, **kwargs))
-        return get_or_create_wrapper
+            def get_or_create_wrapper(*args, **kwargs):
+                key = mk(*args, **kwargs)
+                result = self.cache.get(key, self.namespace)
+                if result is not None:
+                    return result
+                return self.one_pass_create(
+                    key,
+                    lambda: func(*args, **kwargs))
+            return get_or_create_wrapper
+        if wrapped is None:
+            return decorate
+        else:
+            return decorate(wrapped)
+
+    # region: internal details
+
+    def adapt(self, func, make_key=None):
+        if make_key:
+            argnames = getargspec(func)[0]
+            if argnames and argnames[0] in ('self', 'cls', 'klass'):
+                return lambda ignore, *args, **kwargs: make_key(
+                    *args, **kwargs)
+            else:
+                return make_key
+        else:
+            return self.key_builder(func)
 
 
 class OnePass(object):

File src/wheezy/caching/tests/test_patterns.py

View file
  • Ignore whitespace
         assert cached.time == d.time
         assert cached.namespace == d.namespace
 
+    def test_adapt_make_key(self):
+        """ Adapts make_key to function args.
+        """
+
+        make_key = lambda: 'key'
+
+        def my_func():
+            pass
+        mk = self.cached.adapt(my_func, make_key)
+        assert 'key' == mk()
+
+    def test_adapt_make_key_cls(self):
+        """ Ignore 'cls' argument.
+        """
+
+        make_key = lambda: 'key'
+
+        def my_func(cls):
+            pass
+        mk = self.cached.adapt(my_func, make_key)
+        assert 'key' == mk('cls')
+
 
 class OnePassTestCase(unittest.TestCase):
 
 
 class WrapsGetOrAddTestCase(GetOrAddTestCase):
 
-    def setUp(self):
-        self.mock_cache = Mock()
-        self.mock_create_factory = Mock()
+    def test_has_dependency(self):
+        pass
 
     def get_or_add(self, dependency_factory=None):
         from wheezy.caching.patterns import Cached
         cached = Cached(self.mock_cache, kb, time=10, namespace='ns')
         return cached.wraps_get_or_add(self.mock_create_factory)()
 
-    def test_has_dependency(self):
-        """ Not supported.
-        """
-        pass
+
+class WrapsGetOrAddMakeKeyTestCase(WrapsGetOrAddTestCase):
+
+    def get_or_add(self, dependency_factory=None):
+        from wheezy.caching.patterns import Cached
+        cached = Cached(self.mock_cache, time=10, namespace='ns')
+
+        @cached.wraps_get_or_add(make_key=lambda: 'key')
+        def create_factory():
+            return self.mock_create_factory()
+        return create_factory()
 
 
 class GetOrSetTestCase(unittest.TestCase):
 
 class WrapsGetOrSetTestCase(GetOrSetTestCase):
 
-    def setUp(self):
-        self.mock_cache = Mock()
-        self.mock_create_factory = Mock()
+    def test_has_dependency(self):
+        pass
 
     def get_or_set(self, dependency_factory=None):
         from wheezy.caching.patterns import Cached
         cached = Cached(self.mock_cache, kb, time=10, namespace='ns')
         return cached.wraps_get_or_set(self.mock_create_factory)()
 
-    def test_has_dependency(self):
-        """ Not supported.
-        """
-        pass
+
+class WrapsGetOrSetMakeKeyTestCase(WrapsGetOrSetTestCase):
+
+    def get_or_set(self, dependency_factory=None):
+        from wheezy.caching.patterns import Cached
+        cached = Cached(self.mock_cache, time=10, namespace='ns')
+
+        @cached.wraps_get_or_set(make_key=lambda: 'key')
+        def create_factory():
+            return self.mock_create_factory()
+        return create_factory()
 
 
 class CachedCallTestCase(GetOrSetTestCase):
 
 class WrapsGetOrCreateTestCase(GetOrCreateTestCase):
 
-    def setUp(self):
-        self.mock_cache = Mock()
-        self.mock_create_factory = Mock()
-
     def get_or_create(self, dependency_factory=None):
         from wheezy.caching.patterns import Cached
         kb = lambda f: lambda *args, **kwargs: 'key'
         return cached.wraps_get_or_create(self.mock_create_factory)()
 
 
+class WrapsGetOrCreateMakeKeyTestCase(WrapsGetOrCreateTestCase):
+
+    def get_or_create(self, dependency_factory=None):
+        from wheezy.caching.patterns import Cached
+        cached = Cached(self.mock_cache, time=10, namespace='ns')
+
+        @cached.wraps_get_or_create(make_key=lambda: 'key')
+        def create_factory():
+            return self.mock_create_factory()
+        return create_factory()
+
+
 class KeyBuilderTestCase(unittest.TestCase):
 
     def setUp(self):