Commits

Jesper Nøhr committed 7e0df24

factoring user and event out so we can add our own

Comments (0)

Files changed (3)

src/rewsfeed/__init__.py

-from rewsfeed.models import Network, User, Event
+from rewsfeed.models import Network, FollowAble, Event

src/rewsfeed/models.py

 
     def __repr__(self):
         return '<'+', '.join([ '%s=%s' % (k, v[0:255]) for k, v in self.kw.iteritems() ])+'>'
-    
-class User(DictAbstracted):    
-    def follow(self, other_user):
+
+class FollowAble(DictAbstracted):
+    def follow(self, other_thing):
         """
-        Follows another user. Adds both forward and reverse
+        Follows another thing. Adds both forward and reverse
         relationships in our set.
         """
         forward = Op.FOLLOWS_KEY+self.key()
-        reverse = Op.FOLLOWERS_KEY+other_user.key()
+        reverse = Op.FOLLOWERS_KEY+other_thing.key()
 
-        return self.network.sadd(forward, other_user.key()) and self.network.sadd(reverse, self.key())
+        return self.network.sadd(forward, other_thing.key()) and self.network.sadd(reverse, self.key())
 
-    def unfollow(self, other_user):
+    def unfollow(self, other_thing):
         """
-        Unfollows another user.
+        Unfollows another thing.
         """
         forward = Op.FOLLOWS_KEY+self.key()
-        reverse = Op.FOLLOWERS_KEY+other_user.key()
+        reverse = Op.FOLLOWERS_KEY+other_thing.key()
 
-        return self.network.srem(forward, other_user.key()) and self.network.srem(reverse, self.key())
+        return self.network.srem(forward, other_thing.key()) and self.network.srem(reverse, self.key())
     
     # -- 
     
         follows_key = Op.FOLLOWS_KEY+self.key()
         blocked_key = Op.BLOCKED_KEY+self.key()
 
-        for user in self.network.sdiff((follows_key, blocked_key)):
-            yield User(self.network, **self.network.hgetall(Op.DATA_KEY+user))
+        for thing in self.network.sdiff((follows_key, blocked_key)):
+            yield self.__class__(self.network, **self.network.hgetall(Op.DATA_KEY+thing))
 
     def followers(self):
         followers_key = Op.FOLLOWERS_KEY+self.key()
         blocks_key = Op.BLOCKS_KEY+self.key()
 
-        for user in self.network.sdiff((followers_key, blocks_key)):
-            yield User(self.network, **self.network.hgetall(Op.DATA_KEY+user))
+        for thing in self.network.sdiff((followers_key, blocks_key)):
+            yield self.__class__(self.network, **self.network.hgetall(Op.DATA_KEY+thing))
         
-    def key(self, extra=''):
-        return md5(self.kw['username']+extra).hexdigest()
-
-    def flush(self):
-        return True
-
     # -- Events
 
     def did(self, event, callback):
         
         return self.network.zadd(Op.EVENTS_KEY+self.key(), event.key(), event.timestamp)
 
-    def done(self, start=0, end=20):
+    def done(self, kls, start=0, end=20):
         for hsh in self.network.zrange(Op.EVENTS_KEY+self.key(), start, end):
-            yield Event(self.network, **self.network.hgetall(Op.DATA_KEY+hsh))
+            yield kls(self.network, **self.network.hgetall(Op.DATA_KEY+hsh))
 
     get_events = done
 
-    def newsfeed(self, start=0, end=20):
+    def newsfeed(self, kls, start=0, end=20):
         for hsh in self.network.zrange(Op.NEWSFEED_KEY+self.key(), start, end):
-            yield Event(self.network, **self.network.hgetall(Op.DATA_KEY+hsh))
-            
+            yield kls(self.network, **self.network.hgetall(Op.DATA_KEY+hsh))
+        
 class Event(DictAbstracted):
     """
     Funky events.
     """
-    def key(self, extra=''):
-        return md5(self.kw['pk']+extra).hexdigest()
-    
     @property
     def timestamp(self):
         return self.kw.get('timestamp')
     
     @classmethod
     def add_event(cls, network, who, callback=always, **kwargs):
-        if not isinstance(who, User):
-            raise ValueError("A User must've done it. You gave me '%r'" % who)
+        if not isinstance(who, FollowAble):
+            raise ValueError("A `FollowAble` must've done it. You gave me '%r'" % who)
 
         if not kwargs.has_key('timestamp'):
             raise ValueError("Events need to have timestamps to be ordered.")
 
 sys.path.insert(0, '../src')
 
-from rewsfeed import Network, User, Event
+try:
+    from hashlib import md5
+except ImportError:
+    from md5 import md5
 
+from rewsfeed import Network, FollowAble, Event
+
+# Set up a simple Event type.
+class BaseEvent(Event):
+    def key(self, extra=''):
+        return md5(self.kw['pk']+extra).hexdigest()
+
+# Set up a simple User.
+class User(FollowAble):
+    def key(self, extra=''):
+        return 'U+'+md5(self.kw['username']+extra).hexdigest()
+
+# And a simple Repository.
+class Repository(FollowAble):
+    def key(self, extra=''):
+        return 'R+'+md5(self.kw['username']+self.kw['slug']+extra).hexdigest()
+    
 class TestNetwork(unittest.TestCase):
     def test_network_init(self):
         n = Network('test_network', '127.0.0.1', '6379')
         self.assertEquals(self.u2.network, self.n)
 
     def test_user_keys(self):
-        self.assertEquals(self.u1.key(), '24c9e15e52afc47c225b757e7bee1f9d')
-        self.assertEquals(self.u2.key(), '7e58d63b60197ceb55a1c487989a3720')
-
-    def test_user_flush(self):
-        self.assertEquals(self.u1.flush(), True)
-        self.assertEquals(self.u2.flush(), True)
+        self.assertEquals(self.u1.key(), 'U+24c9e15e52afc47c225b757e7bee1f9d')
+        self.assertEquals(self.u2.key(), 'U+7e58d63b60197ceb55a1c487989a3720')
 
     def test_follow(self):
         self.u1.follow(self.u2)
         self.u2.save()
 
     def test_inject_event(self):
-        Event.add_event(self.n, self.u1, event="commit", hash="1a2a3a4a",
-                        pk="1", timestamp=12345678901)
+        BaseEvent.add_event(self.n, self.u1, event="commit", hash="1a2a3a4a",
+                            pk="1", timestamp=12345678901)
 
-        self.assertEquals(len(list(self.u1.get_events())), 1)
+        self.assertEquals(len(list(self.u1.get_events(BaseEvent))), 1)
 
     def test_inject_many_events_ordering(self):
         num_events = 200
 
         for i in xrange(0, num_events):
-            Event.add_event(self.n, self.u2, event="commit", hash="c0ff33",
-                            pk=str(i), timestamp=random.randint(0, sys.maxint))
+            BaseEvent.add_event(self.n, self.u2, event="commit", hash="c0ff33",
+                                pk=str(i), timestamp=random.randint(0, sys.maxint))
 
-        self.assertEquals(len(list(self.u2.get_events(0, num_events))), num_events)
+        self.assertEquals(len(list(self.u2.get_events(BaseEvent, 0, num_events))), num_events)
 
         prev_tz = 0
 
         num_events = 200
 
         for i in xrange(0, num_events):
-            Event.add_event(self.n, self.u1, event="commit", hash="c0ff33",
-                            pk=str(i), timestamp=random.randint(0, sys.maxint))
+            BaseEvent.add_event(self.n, self.u1, event="commit", hash="c0ff33",
+                                pk=str(i), timestamp=random.randint(0, sys.maxint))
 
-        self.assertEquals(len(list(self.u2.newsfeed(0, num_events))), num_events)
-        self.assertEquals(len(list(self.u1.newsfeed())), 0)
+        self.assertEquals(len(list(self.u2.newsfeed(BaseEvent, 0, num_events))), num_events)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent))), 0)
 
         self.u1.follow(self.u2)
         
         for i in xrange(0, num_events):
-            Event.add_event(self.n, self.u2, event="commit", hash="c0ff33",
-                            pk=str(i), timestamp=random.randint(0, sys.maxint))
+            BaseEvent.add_event(self.n, self.u2, event="commit", hash="c0ff33",
+                                pk=str(i), timestamp=random.randint(0, sys.maxint))
 
-        self.assertEquals(len(list(self.u1.newsfeed(0, num_events))), num_events)
-        self.assertEquals(len(list(self.u2.newsfeed(0, num_events))), num_events)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent, 0, num_events))), num_events)
+        self.assertEquals(len(list(self.u2.newsfeed(BaseEvent, 0, num_events))), num_events)
 
 class TestEventsFail(unittest.TestCase):
     def setUp(self):
         self.u2.save()
 
     def test_these_should_fail(self):
-        self.assertRaises(ValueError, Event.add_event, (self.n, 'failme'), { 'timestamp': 12345 })
-        self.assertRaises(ValueError, Event.add_event, (self.n, self.u1), { 'lol_timestamp': 12345 })
-        self.assertRaises(ValueError, Event.add_event, (), { 'lol_timestamp': 12345 })
+        self.assertRaises(ValueError, BaseEvent.add_event, (self.n, 'failme'), { 'timestamp': 12345 })
+        self.assertRaises(ValueError, BaseEvent.add_event, (self.n, self.u1), { 'lol_timestamp': 12345 })
+        self.assertRaises(ValueError, BaseEvent.add_event, (), { 'lol_timestamp': 12345 })
 
 class TestEventCallback(unittest.TestCase):
     def setUp(self):
             self.assertEquals(event.pk, "42")
             return True
 
-        Event.add_event(self.n, self.u1, callback=cb, timestamp=1234567, pk="42")
-        self.assertEquals(len(list(self.u2.newsfeed())), 1)
+        BaseEvent.add_event(self.n, self.u1, callback=cb, timestamp=1234567, pk="42")
+        self.assertEquals(len(list(self.u2.newsfeed(BaseEvent))), 1)
 
     def test_callback_refusing(self):
         def cb(network, who, event):
             return False
 
-        Event.add_event(self.n, self.u1, callback=cb, timestamp=1234567, pk="45")
-        self.assertEquals(len(list(self.u2.newsfeed())), 0)
+        BaseEvent.add_event(self.n, self.u1, callback=cb, timestamp=1234567, pk="45")
+        self.assertEquals(len(list(self.u2.newsfeed(BaseEvent))), 0)
 
 class TestComplicatedNetwork(unittest.TestCase):
     def setUp(self):
                 self.u1.follow(users[-1])
 
         for idx, user in enumerate(users):
-            Event.add_event(self.n, user, timestamp=random.randint(0, sys.maxint), pk=str(idx))
+            BaseEvent.add_event(self.n, user, timestamp=random.randint(0, sys.maxint), pk=str(idx))
 
         last_tz = 0
             
-        for event in self.u1.newsfeed(0, num_users):
+        for event in self.u1.newsfeed(BaseEvent, 0, num_users):
             tz = int(event.timestamp)
             self.assertTrue(int(event.pk) % 3 == 0)
             self.assertTrue(tz > last_tz)
                     self.u1.follow(users[-1])
             
                 for repo in repos:
-                    Event.add_event(self.n, users[-1],
-                                    callback=cb,
-                                    timestamp=random.randint(0, sys.maxint),
-                                    pk=str(i+plus),
-                                    repository=repo.slug)
+                    BaseEvent.add_event(self.n, users[-1],
+                                        callback=cb,
+                                        timestamp=random.randint(0, sys.maxint),
+                                        pk=str(i+plus),
+                                        repository=repo.slug)
 
         spamalot(num_users*0)
                     
-        self.assertEquals(len(list(self.u1.newsfeed(0, num_users))), 67)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent, 0, num_users))), 67)
 
         spamalot(num_users*1)
 
-        self.assertEquals(len(list(self.u1.newsfeed(0, num_users))), 67*2)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent, 0, num_users))), 67*2)
         
         can_access.append('repo1')
         can_access.append('repo2')
 
         spamalot(num_users*2)
         
-        self.assertEquals(len(list(self.u1.newsfeed(0, num_users*2))), 67*3)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent, 0, num_users*2))), 67*3)
 
         can_access.remove('repo1')
 
         spamalot(num_users*3)
 
-        self.assertEquals(len(list(self.u1.newsfeed(0, num_users*3))), 67*4)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent, 0, num_users*3))), 67*4)
 
         can_access.remove('repo2')
         can_access.append('bar')
 
         spamalot(num_users*4)
 
-        self.assertEquals(len(list(self.u1.newsfeed(0, num_users*4))), 67*5)
+        self.assertEquals(len(list(self.u1.newsfeed(BaseEvent, 0, num_users*4))), 67*5)
         
 if __name__ == "__main__":
     unittest.main()