Tomasz Rybarczyk avatar Tomasz Rybarczyk committed bdd207a

Allow `execute_update` retries

Comments (0)

Files changed (2)

     import cPickle as pickle
 except ImportError:
     import pickle
+import redis
+import time
 
 class RedisDict(object):
 
     def items(self):
         return [(k, self._deserialize(v)) for k,v in self._redis.hgetall(self.name).items()]
 
-    def execute_update(self, key, updater):
+    def execute_update(self, key, updater, retry=0, interval=100):
         with self._redis.pipeline() as pipe:
-            pipe.watch(self.name)
-            old_value = self._deserialize(pipe.hget(self.name, key))
-            new_value = updater(old_value)
-            pipe.multi()
-            pipe.hset(self.name, key, self._serialize(new_value))
-            pipe.execute()
+            while retry+1:
+                try:
+                    pipe.watch(self.name)
+                    old_value = self._deserialize(pipe.hget(self.name, key))
+                    new_value = updater(old_value)
+                    pipe.multi()
+                    pipe.hset(self.name, key, self._serialize(new_value))
+                    pipe.execute()
+                except redis.WatchError:
+                    if not retry:
+                        raise
+                    retry -= 1
+                    time.sleep(interval)
+                else:
+                    break
 
     def get(self, key, default=None):
         return self[key] if key in self else default
 
 
 class RedisMapTestCase(unittest.TestCase):
-    """In order to run this test you have to start redis server locally"""
+    """In order to run this test case you have to start redis server locally"""
 
     def setUp(self):
         self.connection = redis.Redis()
         intersect = lambda c: self.redis_dict.__setitem__('two', 'two')
         self.assertRaises(redis.WatchError, lambda: self.redis_dict.execute_update('one', intersect))
 
+    def test_execute_update_retries(self):
+        self.redis_dict['one'] = 1
+        self.redis_dict['two'] = 2
+        class call_counter: count = 0
+        def intersect(c):
+            self.redis_dict.__setitem__('two', 'two')
+            call_counter.count += 1
+        self.assertRaises(redis.WatchError, lambda: self.redis_dict.execute_update('one', intersect, retry=3, interval=0))
+        self.assertEqual(call_counter.count, 4)
+
     def test_execute_update(self):
         self.redis_dict['one'] = 1
         self.redis_dict['two'] = 2
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.