Source

starry / starry / transport.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import threading
import uuid
import zmq

from protorpc.transport import Transport
from protorpc.transport import Rpc
from protorpc import protobuf
from protorpc import remote

from .rpc_pb import Request
from .rpc_pb import Response
from .rpc_pb import _Req
from .rpc_pb import _Resp

logger = logging.getLogger('starry.transport')


class TcpTransport(Transport):
    """protorpc transport that sends RPCs over a ZeroMQ XREQ socket.

    Every call is wrapped in a ``Request`` envelope tagged with a random
    message id.  Replies are matched back to their caller by that id, so
    several calls may be outstanding on the same socket at once.
    """

    def __init__(self, client_id, address, service_name,
                 timeout=1, protocol=protobuf):
        """Create the transport and connect its socket.

        Args:
            client_id: Identifier copied into every outgoing ``Request``.
            address: ZeroMQ endpoint passed to ``socket.connect``.
            service_name: Remote service name copied into every ``Request``.
            timeout: Seconds to wait for socket activity while waiting on a
                reply before the call fails with NETWORK_ERROR.
            protocol: protorpc protocol used to encode/decode messages.
        """
        super(TcpTransport, self).__init__(protocol=protocol)
        self.client_id = client_id
        self.address = address
        self.__service_name = service_name
        self.timeout = timeout

        # msgid -> Future for requests that were sent but not yet waited on.
        self._request_table = {}
        self._generator = lambda: uuid.uuid4().bytes
        self.ctx = zmq.Context()
        self.poller = zmq.Poller()
        self._reconnect()

    def _reconnect(self):
        """Open a fresh XREQ socket to ``self.address`` and register it for polling."""
        self.sock = self.ctx.socket(zmq.XREQ)
        self.sock.connect(self.address)
        self.poller.register(self.sock, zmq.POLLIN)

    def _start_rpc(self, remote_info, request):
        """Start a remote procedure call (protorpc ``Transport`` hook).

        Args:
            remote_info: A RemoteInfo instance for this RPC.
            request: The request message for this RPC.

        Returns:
            An Rpc instance initialized with a Request.
        """
        # ``__name__`` exists on Python 2 and 3; ``func_name`` is 2-only.
        method_name = remote_info.method.__name__
        return self.call_remote(method_name, request, remote_info.response_type)

    def close(self):
        """Close the socket without lingering and forget all pending requests.

        Unsent messages are dropped (LINGER 0).  A request that has not been
        waited on yet fails with NETWORK_ERROR when it is waited on, unless
        its reply had already been received.
        """
        self.sock.setsockopt(zmq.LINGER, 0)
        self.poller.unregister(self.sock)
        self.sock.close()
        self._request_table = {}

    def ping(self):
        """Send ``'ping'`` through the remote ``_send`` method and return the reply's ``msg``."""
        return self._send('ping')

    def _send(self, msg):
        """Call the remote ``_send`` method with ``msg`` and return the reply's ``msg``.

        Accessing ``rpc.response`` presumably waits for the reply and raises
        on an error status (protorpc ``Rpc`` semantics) - confirm.
        """
        request = _Req(msg=msg)
        rpc = self.call_remote('_send', request, _Resp)
        return rpc.response.msg

    def _receive_reply(self, msgid, future):
        """Wait for one reply and hand it to the Future it answers.

        Args:
            msgid: Id of the request currently being waited on.
            future: Future of that request.  It has already been removed from
                ``_request_table``, so it is matched explicitly.

        Returns:
            False if nothing arrived within ``self.timeout`` seconds,
            True otherwise.
        """
        events = dict(self.poller.poll(timeout=self.timeout * 1000))
        if self.sock not in events:
            return False
        if events[self.sock] & zmq.POLLIN:
            rsp = self.protocol.decode_message(Response, self.sock.recv())
            if rsp.msgid == msgid:
                future.set_response(rsp)
            else:
                # The reply belongs to another request that has not been
                # waited on yet; park it on that request's Future.
                pending = self._request_table.get(rsp.msgid)
                if pending is None:
                    # Nobody is waiting for this id (e.g. a duplicate).
                    # Dropping it beats tearing down the connection.
                    logger.warning('%s: dropping reply with unknown msgid %r',
                                   self.__service_name, rsp.msgid)
                else:
                    pending.set_response(rsp)
        return True

    def call_remote(self, method_name, request, response_type):
        """Send a request for ``method_name`` and return its Rpc.

        The request is sent without blocking.  The reply is read only when
        the returned Rpc is waited on.

        Args:
            method_name: Name of the remote method to call.
            request: The request message for this RPC.
            response_type: Message class used to decode the reply payload.

        Returns:
            An Rpc instance initialized with ``request``.
        """
        req = Request()
        req.client_id = self.client_id
        req.service_name = self.__service_name
        req.method_name = method_name
        req.request = self.protocol.encode_message(request)
        msgid = self._generator()
        req.msgid = msgid

        self.sock.send(self.protocol.encode_message(req), flags=zmq.NOBLOCK)
        future = Future(self.protocol, response_type)
        self._request_table[msgid] = future
        rpc = Rpc(request)

        def set_network_error():
            rpc.set_status(remote.RpcStatus(
                state=remote.RpcState.NETWORK_ERROR,
                error_message='Network Error : {0} timeout'.format(self.__service_name)))

        def wait_impl():
            # Stop tracking the request now that it is being waited on.
            if self._request_table.pop(msgid, None) is None and not future.finished:
                # close() discarded this request before its reply was read;
                # the reply can no longer be received.
                set_network_error()
                return
            while not future.finished:
                if not self._receive_reply(msgid, future):
                    logger.error('%s: no reply within %s seconds',
                                 self.__service_name, self.timeout)
                    # Drop the stuck socket and open a fresh one so later
                    # calls on this transport still work.
                    self.close()
                    self._reconnect()
                    set_network_error()
                    return
            future.set_rpc(rpc)

        rpc._wait_impl = wait_impl
        return rpc


class Future(object):
    """Holds the raw ``Response`` for one request until its Rpc is completed.

    A reply may be read while a different request is being waited on, so the
    envelope is parked here.  It is only decoded when ``set_rpc`` is called.
    """

    def __init__(self, protocol, response_type):
        """
        Args:
            protocol: Protocol object whose ``decode_message`` decodes the
                reply payload.
            response_type: Message class the payload is decoded into.
        """
        self.protocol = protocol
        self.response_type = response_type
        # The received Response envelope; None until set_response() is called.
        self.rsp = None
        self._finished = False

    @property
    def finished(self):
        """True once a reply has been stored with ``set_response``."""
        return self._finished

    def set_response(self, rsp):
        """Store the received ``Response`` envelope and mark the future finished."""
        self.rsp = rsp
        self._finished = True

    def set_rpc(self, rpc):
        """Complete ``rpc`` from the stored reply.

        A truthy ``status`` on the reply is forwarded via ``rpc.set_status``.
        Otherwise the payload is decoded as ``response_type`` and passed to
        ``rpc.set_response``.

        Raises:
            RuntimeError: if no reply has been stored yet.
        """
        if not self._finished:
            raise RuntimeError('Future.set_rpc called before a response was received')
        if self.rsp.status:
            rpc.set_status(self.rsp.status)
        else:
            rpc.set_response(
                self.protocol.decode_message(self.response_type, self.rsp.response))