Commits

Anonymous committed 0fe22f6

Refactored the green version of zmq to use a factory function for Context

  • Participants
  • Parent commits 7bcb557

Comments (0)

Files changed (4)

File eventlet/green/zmq.py

 __zmq__ = __import__('zmq')
 from eventlet import sleep
-from eventlet.hubs import trampoline
+from eventlet.hubs import trampoline, get_hub
 
 __patched__ = ['Context', 'Socket']
 globals().update(dict([(var, getattr(__zmq__, var))
                               var in __patched__)
                        ]))
 
-class Context(__zmq__.Context):
+
+def get_hub_name_from_instance(hub):
+    return hub.__class__.__module__.rsplit('.',1)[-1]
+
+def Context(io_threads=1):
+    hub = get_hub()
+    hub_name = get_hub_name_from_instance(hub)
+    if hub_name != 'zeromq':
+        raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name)
+    return hub.get_context(io_threads)
+
+class _Context(__zmq__.Context):
 
     def socket(self, socket_type):
         return Socket(self, socket_type)

File eventlet/hubs/zeromq.py

 class Hub(poll.Hub):
 
 
-
     def __init__(self, clock=time.time):
         BaseHub.__init__(self, clock)
         self.poll = zmq.Poller()
 
-    def get_context(self):
+    def get_context(self, io_threads=1):
         """zmq's Context must be unique within a hub
 
         The zeromq API documentation states:
         try:
             return _threadlocal.context
         except AttributeError:
-            _threadlocal.context = zmq.Context()
+            _threadlocal.context = zmq._Context(io_threads)
             return _threadlocal.context
 
     def register(self, fileno, new=False):

File examples/distributed_websocket_chat.py

 from uuid import uuid1
 
 use_hub('zeromq')
-hub = get_hub()
-ctx = hub.get_context()
+ctx = zmq.Context()
 
 class IDName(object):
 

File tests/zmq_test.py

 from eventlet import event, spawn, sleep, patcher
-from eventlet.hubs import use_hub, get_hub, _threadlocal
-from eventlet.hubs.hub import READ, WRITE
+from eventlet.hubs import get_hub, _threadlocal, use_hub
 from eventlet.green import zmq
 from nose.tools import *
 from tests import mock, LimitedTestCase, skip_unless
 from unittest import TestCase
 
 from threading import Thread
+from eventlet.hubs.zeromq import Hub
 
 def using_zmq(_f):
     return 'zeromq' in type(get_hub()).__module__
 
 class TestUpstreamDownStream(LimitedTestCase):
 
+    sockets = []
+
     def tearDown(self):
         self.clear_up_sockets()
         super(TestUpstreamDownStream, self).tearDown()
 
     def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
         """Create a bound socket pair using a random port."""
-        self.context = context = get_hub().get_context()
+        self.context = context = zmq.Context()
         s1 = context.socket(type1)
         port = s1.bind_to_random_port(interface)
         s2 = context.socket(type2)
     """
 
     @skip_unless_zmq
+    @mock.patch('eventlet.green.zmq.get_hub_name_from_instance')
+    @mock.patch('eventlet.green.zmq.get_hub', spec=Hub)
+    def test_context_factory_funtion(self, get_hub_mock, hub_name_mock):
+        hub_name_mock.return_value = 'zeromq'
+        ctx = zmq.Context()
+        self.assertTrue(get_hub_mock().get_context.called)
+
+    @skip_unless_zmq
     def test_threadlocal_context(self):
         hub = get_hub()
-        context = hub.get_context()
+        context = zmq.Context()
         self.assertEqual(context, _threadlocal.context)
         next_context = hub.get_context()
         self.assertTrue(context is next_context)
 
     @skip_unless_zmq
     def test_different_context_in_different_thread(self):
-        context = get_hub().get_context()
+        context = zmq.Context()
         test_result = []
         def assert_different(ctx):
 #            assert not hasattr(_threadlocal, 'hub')
 #            os.environ['EVENTLET_HUB'] = 'zeromq'
             hub = get_hub()
             try:
-                this_thread_context = hub.get_context()
+                this_thread_context = zmq.Context()
             except:
                 test_result.append('fail')
                 raise
             sleep(0.1)
         self.assertFalse(test_result[0])
 
+class TestCheckingForZMQHub(TestCase):
 
+    def setUp(self):
+        self.orig_hub = zmq.get_hub_name_from_instance(get_hub())
+        use_hub('poll')
 
+    def tearDown(self):
+        use_hub(self.orig_hub)
 
+    def test_assertionerror_raise_by_context(self):
+        self.assertRaises(RuntimeError, zmq.Context)
 
+
+
+
+