Source

starry / starry / server.py

Full commit
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import time
from threading import Thread
import signal
import os

import zmq
from protorpc import protobuf
from protorpc import remote

from .rpc_pb import Request
from .rpc_pb import Response

# Module-level logger. Named after the module (not the file path) so it
# participates in the dotted logger hierarchy and inherits package config.
logger = logging.getLogger(__name__)


class Server(object):
    """
    Base class for RPC servers: keeps a registry of named service instances.

    Subclasses must implement :meth:`start` and :meth:`stop`.
    """

    # Class-level default; not assigned by any method of this class.
    channel = None

    def __init__(self, protocol=protobuf):
        """
        :param protocol: object providing ``encode_message``/``decode_message``;
            defaults to ``protorpc.protobuf``.
        """
        self.protocol = protocol
        # Registered service name -> service instance.
        self.service_dict = {}

    def register_service(self, service, name=None):
        '''
        Register a class-instance so that its methods can be called remotely.

        No methods for this instance are stored because they can be reached
        with ``getattr`` on the instance at call time (see :meth:`get_method`).

        :param service: class-instance which should be registered
        :param name: name the instance is registered under. When omitted, the
            result of ``service.definition_name()`` is used if the service has
            such an attribute, otherwise ``"<module>.<ClassName>"``.
        '''
        if not name:
            definition_name_function = getattr(service,
                                               'definition_name', None)
            if definition_name_function:
                name = definition_name_function()
            else:
                # Fallback for services without a definition_name(); the
                # previous unconditional definition_name() call made this dead.
                name = '{0}.{1}'.format(service.__module__,
                                        service.__class__.__name__,
                                        )
        logger.debug('register service with name {0}'.format(name))
        self.service_dict[name] = service

    def get_service(self, service_name):
        """
        Return the service registered under ``service_name``.

        :raises KeyError: if no service has that name.
        """
        return self.service_dict[service_name]

    def get_method(self, service, method_name):
        """
        Return the bound method ``method_name`` of ``service``.

        :raises AttributeError: if the service has no such attribute.
        """
        return getattr(service, method_name)

    def start(self, *args, **kwargs):
        """
        Start this server. Subclasses must override this method.
        """
        raise NotImplementedError()

    def stop(self, *args, **kwargs):
        """
        Stop this server. Subclasses must override this method.
        """
        raise NotImplementedError()


class RPCServer(Server):
    """
    ZeroMQ RPC server that dispatches client requests to worker threads.

    Clients talk to a ROUTER (``zmq.XREP``) socket bound on ``listen``.
    ``worker_num`` :class:`ExecuteMethodThread` workers connect to a second,
    in-process ROUTER socket. :meth:`lru_device` hands each client request to
    the longest-idle worker and forwards worker replies back to clients.
    """

    def __init__(self, name, listen, worker_num=5,
            protocol=protobuf):
        """
        :param name: server name, used as the prefix of ``server_id``.
        :param listen: zmq endpoint the client-facing socket binds to.
        :param worker_num: number of worker threads started by :meth:`start`.
        :param protocol: message encoder/decoder, passed to :class:`Server`.
        """
        super(RPCServer, self).__init__(protocol)
        # "<name>@<hostname>_<pid>"; also used to build the inproc worker URL.
        self.server_id = "{0}@{1}_{2}".format(name, os.uname()[1], os.getpid())
        self.listen = listen
        self.worker_num = worker_num
        self.context = zmq.Context(1)
        self.threads = []
        # Polled by lru_device(); set by stop(), including from signal handlers.
        self.__stop = False

    def stop(self):
        """Ask the LRU loop to exit. Returns immediately; does not wait."""
        logger.info('Stopping server.')
        self.__stop = True

    def register_signal_handlers(self):
        """
        Make SIGTERM, SIGINT, SIGQUIT and SIGUSR1 call :meth:`stop`.

        ``signal.signal`` may only be called from the main thread, so
        :meth:`start` must run on the main thread.
        """
        for signum in (signal.SIGTERM, signal.SIGINT,
                       signal.SIGQUIT, signal.SIGUSR1):
            signal.signal(signum, lambda signum, frame: self.stop())

    def start(self):
        """
        Bind sockets, start the workers and run the LRU loop until stopped.

        Blocks until :meth:`stop` is called or the loop raises. In either case
        the workers are stopped and joined, both sockets are closed and the
        zmq context is terminated before returning (or re-raising).
        """
        # Socket to talk to workers
        worker_url = "inproc://{0}_workers".format(self.server_id)
        logger.debug('bind worker on: {0}'.format(worker_url))
        workers = self.context.socket(zmq.XREP)
        clients = None
        try:
            # inproc endpoints must be bound before workers connect to them.
            workers.bind(worker_url)

            # Socket to talk to clients
            logger.debug('start listen on: {0}'.format(self.listen))
            clients = self.context.socket(zmq.XREP)
            clients.bind(self.listen)

            for _ in range(self.worker_num):
                execute_method_thread = ExecuteMethodThread(self, worker_url)
                execute_method_thread.start()
                # Track only threads that actually started so they can be joined.
                self.threads.append(execute_method_thread)

            self.register_signal_handlers()

            logger.info('start lru device')
            self.lru_device(clients, workers)
        finally:
            # Workers must close their sockets before context.term() can return.
            self.stop_threads()
            logger.info('Shut down.')
            if clients is not None:
                clients.close()
            workers.close()
            self.context.term()

    def stop_threads(self):
        """
        Signal every worker thread to stop, then wait for each to exit.

        Workers re-check their stop flag after each poll timeout, so an idle
        worker exits within about a second. A worker that is in the middle
        of a call finishes that call first.
        """
        for thread in self.threads:
            thread.stop()
        for thread in self.threads:
            thread.join()

    def lru_device(self, xrep_clients, xrep_workers):
        """
        Route messages between clients and idle workers until :meth:`stop`.

        Workers announce themselves with ``READY``, and implicitly again with
        every reply. Each client request goes to the worker that has been
        idle the longest.

        :param xrep_clients: ROUTER socket the clients talk to.
        :param xrep_workers: ROUTER socket the worker threads connect to.
        """
        # Addresses of idle workers, oldest first.
        workers_list = []

        poller = zmq.Poller()
        poller.register(xrep_clients, zmq.POLLIN)
        poller.register(xrep_workers, zmq.POLLIN)

        logger.info('Starting LRU loop')
        while not self.__stop:
            try:
                # 1s timeout so a stop request is noticed even when idle.
                socks = dict(poller.poll(1000))
            except zmq.ZMQError:
                # e.g. an interrupted poll or a terminated context: shut down.
                self.stop()
                break

            if not socks:
                continue

            # handle worker activity on the backend
            if socks.get(xrep_workers) == zmq.POLLIN:
                msg = xrep_workers.recv_multipart()
                # The sending worker is idle again: queue it at the back.
                workers_list.append(msg[0])
                # msg[1] is the empty delimiter; anything other than READY
                # is a reply whose remaining frames are addressed to a client.
                if msg[2] != 'READY':
                    xrep_clients.send_multipart(msg[2:])
                continue

            # Dequeue the longest-idle worker and hand it the client request.
            if workers_list and socks.get(xrep_clients) == zmq.POLLIN:
                worker_addr = workers_list.pop(0)
                msg_in = xrep_clients.recv_multipart()
                xrep_workers.send_multipart([worker_addr, ''] + msg_in)


class ExecuteMethodThread(Thread):
    """
    Worker thread that serves RPC requests forwarded by an RPCServer.

    It connects a REQ socket to the server's worker endpoint, announces
    itself with ``READY``, then decodes each request, calls the registered
    service method and sends the encoded ``Response`` back.
    """

    def __init__(self, server, worker_url):
        """
        :param server: owning server. Its ``context``, ``server_id``,
            ``protocol`` and service registry are used.
        :param worker_url: endpoint of the server's worker-facing socket.
        """
        Thread.__init__(self)
        self.server = server
        self.context = self.server.context
        self.server_id = self.server.server_id
        self.protocol = self.server.protocol
        self.worker_url = worker_url
        # Checked by run() at least once per poll timeout.
        self.__stop = False

    def stop(self):
        """Ask the worker loop to exit. Does not wait for it."""
        self.__stop = True

    def reply(self, response):
        """Encode ``response`` with the server protocol and send it back."""
        self.__reply(self.protocol.encode_message(response))

    def __reply(self, msg):
        # Prefix with the envelope saved in run() so the reply is routed back.
        self.socket.send_multipart(self.return_address + [msg])

    @staticmethod
    def _format_error_message(server_id, error_message):
        """Return the client-facing error text ``"<server_id> - <message>"``."""
        return '{0} - {1}'.format(server_id, error_message)

    def __send_error(self,
                    response,
                    status_state,
                    error_message,
                    error_name=None):
        """
        Attach an error ``RpcStatus`` to ``response`` and return it.

        :param status_state: a ``remote.RpcState`` value.
        :param error_message: human-readable error, prefixed with server_id.
        :param error_name: optional application-defined error name.
        """
        status = remote.RpcStatus(
            state=status_state,
            error_message=self._format_error_message(self.server.server_id,
                                                     error_message),
            error_name=error_name)
        response.status = status
        return response

    def handle_message(self, msg):
        """
        Decode one request, invoke the target service method, build a Response.

        Request-level problems do not raise. Unknown services or methods,
        undecodable requests and exceptions raised by the service method are
        all reported through ``Response.status``.

        :param msg: encoded ``Request`` payload.
        :returns: the ``Response`` to send back to the client.
        """
        # NOTE(review): ``server.app`` is not assigned in this module;
        # presumably set by the embedding application -- verify.
        with self.server.app.request_context():
            resp = Response()
            # Pre-bind so the error handlers can log these even when the
            # request itself fails to decode.
            client_id = service_name = method_name = None
            try:
                req = self.protocol.decode_message(Request, msg)
                client_id = req.client_id
                service_name = req.service_name
                method_name = req.method_name
                if method_name == 'ping':
                    resp.response = "I'm still alive"
                    return resp
                # KeyError if the service is unknown.
                service = self.server.get_service(service_name)
                # AttributeError if the method (or its .remote info) is missing.
                method = self.server.get_method(service, method_name)
                method_info = method.remote
                try:
                    request_msg = self.protocol.decode_message(
                        method_info.request_type, req.request)
                except Exception as e:
                    raise remote.ServiceDefinitionError(e)
            except KeyError as e:
                logger.error("Unrecognized Service: {0}, called from: {1}".format(e, client_id), exc_info=1)
                return self.__send_error(resp,
                        remote.RpcState.METHOD_NOT_FOUND_ERROR,
                        'Unrecognized RPC Service: %s' % service_name,
                         )
            except AttributeError as e:
                logger.error("Unrecognized Method: {0}, called from: {1}".format(e, client_id), exc_info=1)
                return self.__send_error(resp,
                        remote.RpcState.METHOD_NOT_FOUND_ERROR,
                        'Unrecognized RPC method: %s' % method_name,
                         )
            except remote.ServiceDefinitionError as err:
                err_msg = 'Invalid Service define: {0}, called from: {1}'.format(err, client_id)
                logger.error(err_msg, exc_info=1)
                return self.__send_error(resp,
                                remote.RpcState.REQUEST_ERROR,
                                'Invalid Service define: {0}'.format(err)
                                )
            except Exception as err:
                # e.g. an undecodable Request; previously this escaped and
                # killed the worker thread.
                logger.error('Invalid request: {0}, called from: {1}'.format(err, client_id), exc_info=1)
                return self.__send_error(resp,
                                remote.RpcState.REQUEST_ERROR,
                                'Invalid request: {0}'.format(err)
                                )

            try:
                start = time.time()
                response = method(request_msg)
                finished = time.time() - start
                logger.info("<method {method_name}( {request} ) > from {client_id} to {service_name} ,finished in {finished}s".format(
                    client_id=client_id,
                    request=request_msg,
                    service_name=service_name,
                    method_name=method_name,
                    finished=finished,
                    ))
                resp.response = self.protocol.encode_message(response)
                return resp
            except remote.ApplicationError as err:
                logger.error(err.message)
                return self.__send_error(resp,
                                remote.RpcState.APPLICATION_ERROR,
                                err.message,
                                err.error_name)
            except Exception as err:
                logger.error('An unexpected error occured when handling RPC: {0}, called from: {1}'.format(err,
                              client_id), exc_info=1)
                return self.__send_error(resp,
                                    remote.RpcState.SERVER_ERROR,
                                    'Internal Server Error: {0}'.format(err),
                                    )

    def run(self):
        """
        Serve requests until :meth:`stop` is called.

        Each pass of the outer loop creates a fresh REQ socket, sends
        ``READY`` and then handles one request per poll event. If the socket
        ends up in an unexpected state or an error occurs, the socket is
        closed and a new one is created instead of letting the thread die.
        """
        # Socket to talk to dispatcher
        while not self.__stop:
            sock = None
            poller = None
            try:
                sock = self.socket = self.context.socket(zmq.REQ)
                sock.connect(self.worker_url)
                sock.send('READY')

                poller = self.poller = zmq.Poller()
                poller.register(sock, zmq.POLLIN)
                while not self.__stop:
                    # 1s timeout so stop() is noticed even when idle.
                    socks = dict(poller.poll(1000))
                    if not socks:
                        continue

                    # Explicit check instead of assert, which -O would strip.
                    if socks.get(sock) != zmq.POLLIN:
                        logger.error('Could not assert right socket state, creating new socket.')
                        break
                    msg = sock.recv_multipart()
                    # All frames but the last form the reply envelope.
                    self.return_address = msg[:-1]
                    msg_received = msg[-1]
                    result_message = self.handle_message(msg_received)
                    # send reply back to client
                    self.reply(result_message)
            except zmq.ZMQError:
                if not self.__stop:
                    logger.error('Could not send or receive message, creating new socket.', exc_info=1)
                else:
                    logger.info("worker:(%s) will now shutdown" % self.worker_url)
            except Exception:
                # A REQ socket cannot receive again without sending a reply,
                # so drop it and start over rather than killing the worker.
                logger.error('Unexpected error in worker, creating new socket.', exc_info=1)
            finally:
                if poller is not None:
                    poller.unregister(sock)
                if sock is not None:
                    sock.close()