Commits

Andriy Kornatskyy committed 4373008

Added ability overried forbid action; on reset delete reset locks as well.

  • Participants
  • Parent commits e0121b4

Comments (0)

Files changed (2)

File src/wheezy/caching/lockout.py

         self.name = name
         self.counters = counters
         self.cache = cache
-        self.namespace = namespace,
+        self.namespace = namespace
         self.key_prefix = key_prefix
         self.forbid_action = forbid_action
 
     def guard(self, func):
         """ A guard decorator is applied to a `func` which returns a
             boolean indicating success or failure. Each failure is a
-            subject to increase counter. The counters that supports
-            `reset` are deleted on success.
+            subject to increase counter. The counters that support
+            `reset` (and related locks) are deleted on success.
         """
         def guard_wrapper(ctx, *args, **kwargs):
             succeed = func(ctx, *args, **kwargs)
+            key_prefix = self.key_prefix
             if succeed:
-                keys = [self.key_prefix + c.key_func(ctx)
+                keys = [key_prefix + c.key_func(ctx)
                         for c in self.counters if c.reset]
+                key_prefix = 'lock:' + key_prefix
+                keys.extend([key_prefix + c.key_func(ctx)
+                             for c in self.counters if c.reset])
                 keys and self.cache.delete_multi(keys, 0, '', self.namespace)
             else:
                 for c in self.counters:
-                    key = self.key_prefix + c.key_func(ctx)
+                    key = key_prefix + c.key_func(ctx)
                     max_try = self.cache.add(
                         key, 1, c.period, self.namespace
                     ) and 1 or self.cache.incr(key, 1, self.namespace)
             return succeed
         return guard_wrapper
 
-    def forbid_locked(self, func):
+    def forbid_locked(self, wrapped=None, action=None):
         """ A decorator that forbids access (by a call to `forbid_action`)
             to `func` once the counter threshold is reached (lock is set).
+
+            You can override default forbid action by `action`.
+
+            See `test_lockout.py` for an example.
         """
-        key_prefix = 'lock:' + self.key_prefix
+        action = action or self.forbid_action
+        assert action
 
-        def forbid_locked_wrapper(ctx, *args, **kwargs):
-            locks = self.cache.get_multi(
-                [key_prefix + c.key_func(ctx) for c in self.counters],
-                '', self.namespace)
-            if locks:
-                return self.forbid_action(ctx)
-            return func(ctx, *args, **kwargs)
-        return forbid_locked_wrapper
+        def decorate(func):
+            key_prefix = 'lock:' + self.key_prefix
+
+            def forbid_locked_wrapper(ctx, *args, **kwargs):
+                locks = self.cache.get_multi(
+                    [key_prefix + c.key_func(ctx) for c in self.counters],
+                    '', self.namespace)
+                if locks:
+                    return action(ctx)
+                return func(ctx, *args, **kwargs)
+            return forbid_locked_wrapper
+        if wrapped is None:
+            return decorate
+        else:
+            return decorate(wrapped)

File src/wheezy/caching/tests/test_lockout.py

         else:
             return 'show error'
 
+    @lockout.forbid_locked(action=lambda s: "show captcha")
+    def action2(self):
+        if self.do_action():
+            return 'show ok'
+        else:
+            return 'show error'
+
     @lockout.guard
     def do_action(self):
         return self.action_result
 
 class LockoutTestCase(unittest.TestCase):
 
+    def setUp(self):
+        del alerts[:]
+
     def test_forbidden(self):
         s = MyService()
         s.user_id = 'u1'
         for i in range(4):
             assert 'show error' == s.action()
         assert 'forbidden' == s.action()
+
+    def test_custom_forbid_action(self):
+        s = MyService()
+        s.user_id = 'cfa-u1'
+        s.user_ip = 'cfa-ip1'
+        for i in range(4):
+            assert 'show error' == s.action2()
+        assert 'show captcha' == s.action2()