Commits

Tomasz Rybarczyk  committed bdd207a

Allow `execute_update` retries

  • Participants
  • Parent commits 9b3b420

Comments (0)

Files changed (2)

File redis_dict.py

     import cPickle as pickle
 except ImportError:
     import pickle
+import redis
+import time
 
 class RedisDict(object):
 
     def items(self):
         return [(k, self._deserialize(v)) for k,v in self._redis.hgetall(self.name).items()]
 
-    def execute_update(self, key, updater):
+    def execute_update(self, key, updater, retry=0, interval=100):
         with self._redis.pipeline() as pipe:
-            pipe.watch(self.name)
-            old_value = self._deserialize(pipe.hget(self.name, key))
-            new_value = updater(old_value)
-            pipe.multi()
-            pipe.hset(self.name, key, self._serialize(new_value))
-            pipe.execute()
+            while retry+1:
+                try:
+                    pipe.watch(self.name)
+                    old_value = self._deserialize(pipe.hget(self.name, key))
+                    new_value = updater(old_value)
+                    pipe.multi()
+                    pipe.hset(self.name, key, self._serialize(new_value))
+                    pipe.execute()
+                except redis.WatchError:
+                    if not retry:
+                        raise
+                    retry -= 1
+                    time.sleep(interval)
+                else:
+                    break
 
     def get(self, key, default=None):
         return self[key] if key in self else default
 
 
 class RedisMapTestCase(unittest.TestCase):
-    """In order to run this test you have to start redis server locally"""
+    """In order to run this test case you have to start redis server locally"""
 
     def setUp(self):
         self.connection = redis.Redis()
         intersect = lambda c: self.redis_dict.__setitem__('two', 'two')
         self.assertRaises(redis.WatchError, lambda: self.redis_dict.execute_update('one', intersect))
 
+    def test_execute_update_retries(self):
+        self.redis_dict['one'] = 1
+        self.redis_dict['two'] = 2
+        class call_counter: count = 0
+        def intersect(c):
+            self.redis_dict.__setitem__('two', 'two')
+            call_counter.count += 1
+        self.assertRaises(redis.WatchError, lambda: self.redis_dict.execute_update('one', intersect, retry=3, interval=0))
+        self.assertEqual(call_counter.count, 4)
+
     def test_execute_update(self):
         self.redis_dict['one'] = 1
         self.redis_dict['two'] = 2