Commits

Jesper Nøhr  committed 33b42c4

adding callbacks to Event.add_event

  • Participants
  • Parent commits 86556a0

Comments (0)

Files changed (2)

File src/rewsfeed/models.py

 except ImportError:
     from md5 import md5
 
+def always(*args, **kwargs):
+    return True
+    
 class Op(object):
     DATA_KEY = 'D+'
     FOLLOWS_KEY = 'F+'
         return self.kw.get('timestamp')
     
     @classmethod
-    def add_event(cls, network, who, **kwargs):
+    def add_event(cls, network, who, callback=always, **kwargs):
         if not isinstance(who, User):
             raise ValueError("A User must've done it. You gave me '%r'" % who)
 
         if not kwargs.has_key('timestamp'):
             raise ValueError("Events need to have timestamps to be ordered.")
-        
-        event = cls(network, **kwargs).save()
 
-        who.did(event)
+        if callback(network, who, **kwargs):
+            event = cls(network, **kwargs).save()
+            
+            who.did(event)

File tests/tests.py

         self.assertRaises(ValueError, Event.add_event, (self.n, 'failme'), { 'timestamp': 12345 })
         self.assertRaises(ValueError, Event.add_event, (self.n, self.u1), { 'lol_timestamp': 12345 })
         self.assertRaises(ValueError, Event.add_event, (), { 'lol_timestamp': 12345 })
+
+class TestEventCallback(unittest.TestCase):
+    def setUp(self):
+        self.n = Network('events')
+        self.n.conn.flushdb()
+
+        self.u1 = User(self.n, username='user1')
+        self.u1.save()
+
+    def test_callback_allowing(self):
+        def cb(network, who, **kwargs):
+            self.assertEquals(who, self.u1)
+            return True
+
+        Event.add_event(self.n, self.u1, callback=cb, timestamp=1234567, pk="42")
+        self.assertEquals(len(list(self.u1.get_events())), 1)
+
+    def test_callback_refusing(self):
+        def cb(network, who, **kwargs):
+            return False
+
+        Event.add_event(self.n, self.u1, callback=cb, timestamp=1234567, pk="45")
+        self.assertEquals(len(list(self.u1.get_events())), 0)
+
         
 if __name__ == "__main__":
     unittest.main()