Commits

Andriy Kornatskyy committed e0121b4

Introduced lockout module.

Comments (0)

Files changed (2)

src/wheezy/caching/lockout.py

+
+""" ``lockout`` module.
+"""
+
+from warnings import warn
+
+from wheezy.caching.utils import total_seconds
+
+
+class Locker(object):
+    """ Used to define lockout terms.
+    """
+
+    def __init__(self, cache, forbid_action, namespace=None,
+                 key_prefix='c', **terms):
+        self.cache = cache
+        self.forbid_action = forbid_action
+        self.namespace = namespace
+        self.key_prefix = key_prefix
+        self.terms = terms
+
+    def define(self, name, **terms):
+        """ Defines a new lockout with given `name` and `terms`.
+            The `terms` keys must correspond to `known terms` of locker.
+        """
+        if not terms:  # pragma: nocover
+            warn('Locker: no terms', stacklevel=2)
+        key_prefix = '%s:%s:' % (self.key_prefix, name.replace(' ', '_'))
+        counters = [self.terms[t](**terms[t]) for t in terms]
+        return Lockout(name, counters, self.forbid_action,
+                       self.cache, self.namespace, key_prefix)
+
+
+class Counter(object):
+    """ A container of various attributes used by lockout.
+    """
+
+    def __init__(self, key_func, count, period, duration,
+                 reset=True, alert=None):
+        self.key_func = key_func
+        self.count = count
+        self.period = total_seconds(period)
+        self.duration = total_seconds(duration)
+        self.reset = reset
+        self.alert = alert
+
+
+class Lockout(object):
+    """ A lockout is used to enforce terms of use policy.
+    """
+
+    def __init__(self, name, counters, forbid_action,
+                 cache, namespace, key_prefix):
+        self.name = name
+        self.counters = counters
+        self.cache = cache
+        self.namespace = namespace,
+        self.key_prefix = key_prefix
+        self.forbid_action = forbid_action
+
+    def guard(self, func):
+        """ A guard decorator is applied to a `func` which returns a
+            boolean indicating success or failure. Each failure is a
+            subject to increase counter. The counters that supports
+            `reset` are deleted on success.
+        """
+        def guard_wrapper(ctx, *args, **kwargs):
+            succeed = func(ctx, *args, **kwargs)
+            if succeed:
+                keys = [self.key_prefix + c.key_func(ctx)
+                        for c in self.counters if c.reset]
+                keys and self.cache.delete_multi(keys, 0, '', self.namespace)
+            else:
+                for c in self.counters:
+                    key = self.key_prefix + c.key_func(ctx)
+                    max_try = self.cache.add(
+                        key, 1, c.period, self.namespace
+                    ) and 1 or self.cache.incr(key, 1, self.namespace)
+                    #print("%s ~ %d" % (key, max_try))
+                    if max_try >= c.count:
+                        self.cache.delete(key, 0, self.namespace)
+                        self.cache.add('lock:' + key, 1,
+                                       c.duration, self.namespace)
+                        c.alert and c.alert(ctx, self.name, c)
+            return succeed
+        return guard_wrapper
+
+    def forbid_locked(self, func):
+        """ A decorator that forbids access (by a call to `forbid_action`)
+            to `func` once the counter threshold is reached (lock is set).
+        """
+        key_prefix = 'lock:' + self.key_prefix
+
+        def forbid_locked_wrapper(ctx, *args, **kwargs):
+            locks = self.cache.get_multi(
+                [key_prefix + c.key_func(ctx) for c in self.counters],
+                '', self.namespace)
+            if locks:
+                return self.forbid_action(ctx)
+            return func(ctx, *args, **kwargs)
+        return forbid_locked_wrapper

src/wheezy/caching/tests/test_lockout.py

+
+""" Unit tests for ``wheezy.caching.lockout``.
+"""
+
+import unittest
+
+from datetime import timedelta
+
+from wheezy.caching.memory import MemoryCache
+from wheezy.caching.lockout import Counter
+from wheezy.caching.lockout import Locker
+
+
+# region: alerts
+
+alerts = []
+
+
+def send_mail(s, name, counter):
+    assert isinstance(s, MyService)
+    alerts.append('send mail: %s' % name)
+
+
+def send_sms(s, name, counter):
+    assert isinstance(s, MyService)
+    alerts.append('send sms: %s' % name)
+
+
+def ignore_alert(s, name, counter):
+    assert isinstance(s, MyService)
+    alerts.append('ignore: %s' % name)
+
+
+# region: lockouts and defaults
+
+def lockout_by_id(count=10,
+                  period=timedelta(minutes=1),
+                  duration=timedelta(hours=2),
+                  reset=False,
+                  alert=send_mail):
+    key_func = lambda s: 'by_id:%s' % s.user_id
+    return Counter(key_func=key_func, count=count,
+                   period=period, duration=duration,
+                   reset=reset, alert=alert)
+
+
+def lockout_by_ip(count=10,
+                  period=timedelta(minutes=1),
+                  duration=timedelta(hours=2),
+                  reset=True,
+                  alert=send_sms):
+    key_func = lambda s: 'by_ip:%s' % s.user_ip
+    return Counter(key_func=key_func, count=count,
+                   period=period, duration=duration,
+                   reset=reset, alert=alert)
+
+
+def lockout_by_id_ip(count=10,
+                     period=timedelta(minutes=1),
+                     duration=timedelta(hours=2),
+                     reset=True,
+                     alert=ignore_alert):
+    key_func = lambda s: 'by_id_ip:%s:%s' % (s.user_id, s.user_ip)
+    return Counter(key_func=key_func, count=count,
+                   period=period, duration=duration,
+                   reset=reset, alert=alert)
+
+
+# region: config
+
+cache = MemoryCache()
+locker = Locker(cache, key_prefix='my_app',
+                forbid_action=lambda s: 'forbidden',
+                by_id=lockout_by_id,
+                by_ip=lockout_by_ip,
+                by_id_ip=lockout_by_id_ip)
+
+
+# region: service/handler
+
+class MyService(object):
+
+    lockout = locker.define(
+        name='action',
+        by_id_ip=dict(count=4, duration=60),
+        by_id=dict(count=6, duration=timedelta(minutes=2)),
+        by_ip=dict(count=8, duration=timedelta(minutes=5))
+    )
+
+    action_result = False
+    user_id = None
+    user_ip = None
+
+    @lockout.forbid_locked
+    def action(self):
+        if self.do_action():
+            return 'show ok'
+        else:
+            return 'show error'
+
+    @lockout.guard
+    def do_action(self):
+        return self.action_result
+
+
+# region: test case
+
+class LockoutTestCase(unittest.TestCase):
+
+    def test_forbidden(self):
+        s = MyService()
+        s.user_id = 'u1'
+        s.user_ip = 'ip1'
+        for i in range(4):
+            assert 'show error' == s.action()
+        assert ['ignore: action'] == alerts
+        del alerts[:]
+        assert 'forbidden' == s.action(), 'lock by id/ip'
+
+        s.user_ip = 'ip2'
+        for i in range(2):
+            assert 'show error' == s.action()
+        assert ['send mail: action'] == alerts
+        del alerts[:]
+        assert 'forbidden' == s.action(), 'lock by id'
+
+        s.user_id = 'u3'
+        for i in range(3):
+            assert 'show error' == s.action()
+        s.user_id = 'u4'
+        for i in range(3):
+            assert 'show error' == s.action()
+        assert ['send sms: action'] == alerts
+        assert 'forbidden' == s.action(), 'lock by ip'
+
+    def test_reset_on_success(self):
+        s = MyService()
+        s.user_id = 'u0'
+        s.user_ip = 'ip0'
+        for i in range(2):
+            assert 'show error' == s.action()
+
+        s.action_result = True
+        assert 'show ok' == s.action()
+
+        s.action_result = False
+        for i in range(4):
+            assert 'show error' == s.action()
+        assert 'forbidden' == s.action()