Commits

rafek  committed cf41dac

Implement a local transport

  • Participants
  • Parent commits d9b360d

Comments (0)

Files changed (3)

File python/protorpc/remote.py

 
   __metaclass__ = _ServiceClass
 
+  __request_state = None
+
   @classmethod
   def all_remote_methods(cls):
     """Get all remote methods for service class.

File python/protorpc/transport.py

 
 import httplib
 import logging
+import os
 import sys
 import urllib2
 
   'RpcStateError',
 
   'HttpTransport',
+  'LocalTransport',
   'Rpc',
   'Transport',
 ]
     rpc._wait_impl = wait_impl
 
     return rpc
+
+
+class LocalTransport(Transport):
+  """Local transport that sends messages directly to services.
+
+  Useful in tests or creating code that can work with either local or remote
+  services.  Using LocalTransport is preferrable to simply instantiating a
+  single instance of a service and reusing it.  The entire request process
+  involves instantiating a new instance of a service, initializing it with
+  request state and then invoking the remote method for every request.
+  """
+
+  def __init__(self, service_factory):
+    """Constructor.
+
+    Args:
+      service_factory: Service factory or class.
+    """
+    super(LocalTransport, self).__init__()
+    self.__service_class = getattr(service_factory,
+                                   'service_class',
+                                   service_factory)
+    self.__service_factory = service_factory
+
+  @property
+  def service_class(self):
+    return self.__service_class
+
+  @property
+  def service_factory(self):
+    return self.__service_factory
+
+  def _start_rpc(self, remote_info, request):
+    """Start a remote procedure call.
+
+    Args:
+      remote_info: RemoteInfo instance describing remote method.
+      request: Request message to send to service.
+
+    Returns:
+      An Rpc instance initialized with the request.
+    """
+    rpc = Rpc(request)
+    def wait_impl():
+      instance = self.__service_factory()
+      try:
+        initalize_request_state = instance.initialize_request_state
+      except AttributeError:
+        pass
+      else:
+        host = unicode(os.uname()[1])
+        initalize_request_state(remote.RequestState(remote_host=host,
+                                                    remote_address=u'127.0.0.1',
+                                                    server_host=host,
+                                                    server_port=-1))
+      try:
+        response = remote_info.method(instance, request)
+        assert isinstance(response, remote_info.response_type)
+      except remote.ApplicationError:
+        raise
+      except:
+        exc_type, exc_value, traceback = sys.exc_info()
+        message = 'Unexpected error %s: %s' % (exc_type.__name__, exc_value)
+        raise remote.ServerError, message, traceback
+      rpc.set_response(response)
+    rpc._wait_impl = wait_impl
+    return rpc

File python/protorpc/transport_test.py

 # limitations under the License.
 #
 
+import os
 import StringIO
 import types
 import unittest
 
 import mox
 
+package = 'transport_test'
+
 
 def reset_urlfetch():
   """Configure urlfetch library on transport module."""
     self.assertEquals(None, rpc.error_name)
 
 
+class SimpleRequest(messages.Message):
+
+  content = messages.StringField(1)
+
+
+class SimpleResponse(messages.Message):
+
+  content = messages.StringField(1)
+  factory_value = messages.StringField(2)
+  remote_host = messages.StringField(3)
+  remote_address = messages.StringField(4)
+  server_host = messages.StringField(5)
+  server_port = messages.IntegerField(6)
+
+
+class LocalService(remote.Service):
+
+  def __init__(self, factory_value='default'):
+    self.factory_value = factory_value
+
+  @remote.method(SimpleRequest, SimpleResponse)
+  def call_method(self, request):
+    return SimpleResponse(content=request.content,
+                          factory_value=self.factory_value,
+                          remote_host=self.request_state.remote_host,
+                          remote_address=self.request_state.remote_address,
+                          server_host=self.request_state.server_host,
+                          server_port=self.request_state.server_port)
+
+  @remote.method()
+  def raise_totally_unexpected(self, request):
+    raise TypeError('Kablam')
+
+  @remote.method()
+  def raise_unexpected(self, request):
+    raise remote.RequestError('Huh?')
+
+  @remote.method()
+  def raise_application_error(self, request):
+    raise remote.ApplicationError('App error', 10)
+
+
+class LocalTransportTest(test_util.TestCase):
+
+  def CreateService(self, factory_value='default'):
+    return 
+
+  def testBasicCallWithClass(self):
+    stub = LocalService.Stub(transport.LocalTransport(LocalService))
+    response = stub.call_method(content='Hello')
+    self.assertEquals(SimpleResponse(content='Hello',
+                                     factory_value='default',
+                                     remote_host=os.uname()[1],
+                                     remote_address='127.0.0.1',
+                                     server_host=os.uname()[1],
+                                     server_port=-1),
+                      response)
+
+  def testBasicCallWithFactory(self):
+    stub = LocalService.Stub(
+      transport.LocalTransport(LocalService.new_factory('assigned')))
+    response = stub.call_method(content='Hello')
+    self.assertEquals(SimpleResponse(content='Hello',
+                                     factory_value='assigned',
+                                     remote_host=os.uname()[1],
+                                     remote_address='127.0.0.1',
+                                     server_host=os.uname()[1],
+                                     server_port=-1),
+                      response)
+
+  def testTotallyUnexpectedError(self):
+    stub = LocalService.Stub(transport.LocalTransport(LocalService))
+    self.assertRaisesWithRegexpMatch(
+      remote.ServerError,
+      'Unexpected error TypeError: Kablam',
+      stub.raise_totally_unexpected)
+
+  def testUnexpectedError(self):
+    stub = LocalService.Stub(transport.LocalTransport(LocalService))
+    self.assertRaisesWithRegexpMatch(
+      remote.ServerError,
+      'Unexpected error RequestError: Huh?',
+      stub.raise_unexpected)
+
+  def testApplicationError(self):
+    stub = LocalService.Stub(transport.LocalTransport(LocalService))
+    self.assertRaisesWithRegexpMatch(
+      remote.ApplicationError,
+      'App error',
+      stub.raise_application_error)
+
+
 def main():
   unittest.main()