# Source: amqpev / protocol.py (full commit)
import itertools
import logging
from eventlet.coros import semaphore, queue, event, api as ev_api
from spec.spec import spec_0_8 as spec

from frame import MethodFrame, HeaderFrame, BodyFrame
import channel_fsm
import errors

log = logging.getLogger("amqpev.protocol")


class BaseConnection(object):
    """AMQP connection over ``transport`` with frame (de)multiplexing.

    Construction opens the transport, spawns two eventlet greenlets -- one
    reading frames from the transport and routing them to per-channel
    queues, one draining the per-channel send queues onto the transport --
    and then runs the connection handshake on channel 0 via
    ``channel_fsm.connection_init``.

    Attributes:
        tune_params: the parameters returned by the handshake; channels read
            ``tune_params['frame_max']`` when splitting content bodies.
        runner_exception: exception waiting to be broadcast to every channel
            queue by a runner greenlet at the end of its current loop
            iteration (the slot is then reset to ``None``). ``close()``
            stores ``errors.ClientClosedConnection`` here.
    """

    def __init__(self, transport, vhost="/", insist=False):
        self._transport = transport
        # Shared with the Multiplexer: channels added or dropped here are
        # immediately visible to it.
        self._channel_queues = {}
        self._mux = Multiplexer(self._channel_queues)
        self._chan_zero = self._create_channel(0)
        self.runner_exception = None
        # Both runner loops test ``_state``; give it a value before they are
        # spawned.
        self._state = 'open'

        self._transport.open()
        ev_api.spawn(self._demux_runner)
        ev_api.spawn(self._mux_runner)

        (result, params) = channel_fsm.connection_init(self._chan_zero,
                vhost=vhost, insist=insist)

        self.tune_params = params

        # Any other result (including 'open') leaves the state 'open'.
        if result == 'redirected':
            self._state = 'broker-closed'

    def close(self):
        """Close the connection.

        While the connection is still open, first performs the
        ``connection.close`` / ``connection.close_ok`` exchange on channel 0.
        The transport is closed, the state set to ``'client-closed'`` and
        ``errors.ClientClosedConnection`` stored in ``runner_exception`` even
        if that exchange raises.
        """
        try:
            # After a transport failure (state 'closed') the demux runner has
            # exited, so a close_ok reply could never be delivered; only
            # attempt the handshake while the connection is open.
            if self._state == 'open':
                self._chan_zero.sync_method(spec.connection.close,
                        (spec.connection.close_ok,), reply_code=0,
                        reply_text="Closing", class_id=0, method_id=0)
        finally:
            self._transport.close()
            self._state = 'client-closed'
            self.runner_exception = errors.ClientClosedConnection()

    def _run_loop(self, step, name):
        """Call ``step()`` repeatedly while the connection is open.

        ``errors.ConnectionClosedError`` marks the connection ``'closed'``,
        which ends this loop (and the other runner's loop at its next check).
        Any other exception is logged and the loop continues. Whatever is in
        ``runner_exception`` after an iteration -- including a value stored
        by ``close()`` -- is broadcast to every channel queue and cleared.

        Args:
            step: zero-argument callable performing one unit of work.
            name: runner name used in log messages ("demux" or "mux").
        """
        while self._state == 'open':
            try:
                step()
            except errors.ConnectionClosedError as exc:
                self.runner_exception = exc
                log.exception("Connection down in %s runner, exiting.", name)
                self._state = 'closed'
            except Exception as exc:
                # NOTE(review): a failure that is not ConnectionClosedError
                # is retried for as long as it keeps happening.
                self.runner_exception = exc
                log.exception("Caught exception in %s runner.", name)

            if self.runner_exception is not None:
                self._mux.all_channel_exception(self.runner_exception)
                self.runner_exception = None

    def _demux_step(self):
        """Read one frame from the transport and route it to its channel."""
        frame = self._transport.recv_frame()
        self._mux.demux_frame(frame)

    def _mux_step(self):
        """Send the next multiplexed outgoing frame over the transport."""
        frame = self._mux.get_next_muxed_frame()
        self._transport.send_frame(frame)

    def _demux_runner(self):
        """Greenlet body: receive and demultiplex frames until closed."""
        self._run_loop(self._demux_step, "demux")

    def _mux_runner(self):
        """Greenlet body: multiplex and transmit frames until closed."""
        self._run_loop(self._mux_step, "mux")

    def _create_channel(self, id=None):
        """Register a frame queue for a new channel and return its Channel.

        Args:
            id: channel number to use; ``None`` picks the lowest unused
                number.

        Raises:
            errors.ChannelCreateError: ``id`` is already in use.
        """
        if id is None:
            # Deterministically take the lowest free channel number (a
            # set-difference pick would depend on set iteration order).
            id = next(n for n in itertools.count()
                      if n not in self._channel_queues)
        elif id in self._channel_queues:
            # NOTE(review): the bare name ChannelCreateError is not imported
            # in this module; assumes ``errors`` defines it alongside the
            # other exception types -- confirm.
            raise errors.ChannelCreateError("Channel already exists")

        cfq = ChannelFrameQueue(self._mux)
        self._channel_queues[id] = cfq
        return Channel(id, self, cfq)

    def _drop_channel(self, id):
        """Unregister channel ``id``'s frame queue (KeyError if unknown)."""
        self._channel_queues.pop(id)


class Connection(BaseConnection):
    """Connection that opens and closes numbered channels for callers."""

    def channel(self):
        """Allocate a channel number and open a channel on it.

        Sends ``channel.open`` and waits for ``channel.open_ok``. If that
        exchange raises, the channel number is released again before the
        exception propagates.
        """
        chan = self._create_channel()
        opened = False
        try:
            chan.sync_method(spec.channel.open, (spec.channel.open_ok,),
                    out_of_band="")
            opened = True
        finally:
            if not opened:
                self._drop_channel(chan.id)
        return chan

    def close_channel(self, ch):
        """Close ``ch`` on the broker and release its channel number.

        Sends ``channel.close`` and waits for ``channel.close_ok`` before
        dropping the channel's frame queue. This does not mark ``ch`` closed;
        ``Channel.close`` does that after calling this method.
        """
        expected = (spec.channel.close_ok,)
        ch.sync_method(spec.channel.close, expected,
                reply_code=0, reply_text="Closing", class_id=0, method_id=0)
        self._drop_channel(ch.id)


class BaseChannel(object):
    """Frame-level view of one AMQP channel.

    Wraps a per-channel frame queue and offers method- and content-level
    send/receive helpers on top of it. Every frame operation raises
    ``errors.ChannelClosedError`` once the channel is marked closed
    (``self._open`` false).
    """

    def __init__(self, id, connection, channel_queue):
        self._id = id
        self._frame_queue = channel_queue
        self._open = True
        self._connection = connection

    def _check_open(self):
        """Raise ``errors.ChannelClosedError`` if the channel is closed."""
        if not self._open:
            raise errors.ChannelClosedError()

    def _recv_frame(self):
        """Dequeue the next received frame for this channel."""
        self._check_open()

        frame = self._frame_queue.recv_frame()
        assert frame.channel == self._id, ("Frame delivered to"
                " incorrect ChannelFrameQueue")

        return frame

    def _peek_frames(self):
        """Iterate ``(index, frame)`` over received frames without removing them."""
        self._check_open()
        return self._frame_queue.peek_frames()

    def _drop_frame(self, index):
        """Remove the frame at a ``_peek_frames`` index from the queue."""
        self._check_open()
        return self._frame_queue.drop_frame(index)

    def _send_frame(self, frame):
        """Stamp ``frame`` with this channel's id and queue it for sending."""
        self._check_open()

        frame.channel = self._id
        self._frame_queue.send_frame(frame)

    def send_method(self, method, **kwargs):
        """Send ``method`` (with ``kwargs`` as its arguments) as a method frame."""
        self._send_frame(MethodFrame(method, **kwargs))

    def recv_method(self, discard=False):
        """Receive the next method frame.

        Args:
            discard: if true, non-method frames are consumed and skipped
                until a method frame arrives.

        Raises:
            errors.FrameTypeUnexpectedError: ``discard`` is false and the next
                frame is not a method frame (that frame has been consumed).
        """
        # Each iteration must fetch a fresh frame; re-testing the same frame
        # would loop forever when discarding.
        while True:
            frame = self._recv_frame()
            if frame.type == 'method':
                return frame
            if not discard:
                raise errors.FrameTypeUnexpectedError()

    def wait_method(self, *methods):
        """Remove and return the first queued method frame in ``methods``.

        Non-matching frames stay queued. With ``ChannelFrameQueue`` the peek
        iterator blocks for new frames, so this waits until a match arrives.
        """
        for index, frame in self._peek_frames():
            if frame.type == 'method' and frame.method in methods:
                self._drop_frame(index)
                return frame

    def sync_method(self, method, reply_methods, **kwargs):
        """Send ``method`` and wait for a reply that is one of ``reply_methods``."""
        self.send_method(method, **kwargs)
        return self.wait_method(*reply_methods)

    def send_content(self, class_, headers, body, body_len=None):
        """Send a content header frame followed by body frames.

        Args:
            class_: content class passed to ``HeaderFrame``.
            headers: content properties passed to ``HeaderFrame``.
            body: a ``str``, or an iterable of ``str`` chunks.
            body_len: total body length. When given with an iterable body,
                the chunks are streamed without being joined first; ignored
                (recomputed) when ``body`` is a ``str``.

        Body frames carry at most ``tune_params['frame_max']`` characters.
        """
        if isinstance(body, str):
            body = [body]
            body_len = len(body[0])
        elif body_len is None:
            body = [''.join(body)]
            body_len = len(body[0])

        self._send_frame(HeaderFrame(class_, body_len, headers))

        # A frame_max of 0 means "no limit" in AMQP connection tuning; it must
        # not be used as a split size (``len(buf) >= 0`` never becomes false).
        # NOTE(review): AMQP's frame_max counts the frame header and end octet
        # too; confirm whether BodyFrame/the transport account for that.
        mtu = self._connection.tune_params['frame_max']
        buf = ""

        for hunk in body:
            buf += hunk

            while mtu > 0 and len(buf) >= mtu:
                self._send_frame(BodyFrame(buf[:mtu]))
                buf = buf[mtu:]

        if buf:
            self._send_frame(BodyFrame(buf))

    def recv_content(self):
        """Receive a content header frame and the body frames after it.

        Body frames are consumed while fewer than ``body_len`` characters have
        been read and the next queued frame is a body frame, so the returned
        body may be shorter than ``body_len`` if another frame type comes
        first.

        Returns:
            ``(headers, body_len, body)`` taken from the header frame plus the
            concatenated body payloads.

        Raises:
            errors.FrameTypeUnexpectedError: the next frame is not a header.
        """
        def peek_type():
            (_, upcoming) = next(self._peek_frames())
            return upcoming.type

        first_type = peek_type()
        if first_type != 'header':
            raise errors.FrameTypeUnexpectedError("recv_content expecting a"
                    " header frame, got %s instead." % first_type)

        header_frame = self._recv_frame()
        chunks = []
        received = 0

        while received < header_frame.body_len and peek_type() == 'body':
            payload = self._recv_frame().payload
            chunks.append(payload)
            received += len(payload)

        return (header_frame.headers, header_frame.body_len, ''.join(chunks))

    @property
    def id(self):
        """The channel number."""
        return self._id


class Channel(BaseChannel):
    """Channel object returned by ``BaseConnection._create_channel``."""

    def close(self):
        """Close this channel through its connection, then mark it closed.

        Once marked closed, further frame operations on this object raise
        ``errors.ChannelClosedError``.
        """
        conn = self._connection
        conn.close_channel(self)
        self._open = False


class Multiplexer(object):
    """
    Implements channel frame multiplexing for a connection.

    ``channel_queues`` is the connection's dict mapping channel number to
    frame queue. It is shared, not copied, so channels added or dropped by
    the connection are seen here. ``ChannelFrameQueue.send_frame`` calls
    ``tx_enqueue_notify`` (``tx_sem.release``) once per queued frame, and
    ``get_next_muxed_frame`` calls ``tx_wait_if_empty`` (``tx_sem.acquire``)
    before each polling pass.
    """

    def __init__(self, channel_queues):
        self._channel_queues = channel_queues
        self._last_channel_idx = 0

        self.tx_sem = semaphore()
        self.tx_enqueue_notify = self.tx_sem.release
        self.tx_wait_if_empty = self.tx_sem.acquire

    def get_next_muxed_frame(self):
        """
        Get the next frame to transmit in the multiplex sequence. Blocks if
        there are no frames queued for sending.

        Channels are polled round-robin, starting after the position polled
        last. The returned frame has its ``channel`` attribute set to the
        number of the channel it was taken from.
        """
        while True:
            self.tx_wait_if_empty()
            chan_nrs = sorted(self._channel_queues)
            assert len(chan_nrs) > 0, "We should have at least one channel by now."
            count = len(chan_nrs)

            for step in range(1, count + 1):
                chan_idx = (self._last_channel_idx + step) % count
                chan_nr = chan_nrs[chan_idx]
                frame = self._channel_queues[chan_nr].get_next_tx_frame()
                if frame is not None:
                    frame.channel = chan_nr
                    self._last_channel_idx = chan_idx
                    return frame

            # A full pass found nothing. The notification belonged to a frame
            # left on a channel that has since been dropped. Wait for the next
            # notification instead of polling again: polling empty queues
            # never yields, so it would spin forever and starve every other
            # greenlet.

    def demux_frame(self, frame):
        """
        Demultiplex a received frame.

        Raises:
            errors.UnknownChannelReceivedError: no queue is registered for
                ``frame.channel``.
        """
        # Only the lookup is guarded, so a KeyError raised inside the queue
        # is not misreported as an unknown channel.
        try:
            chan_queue = self._channel_queues[frame.channel]
        except KeyError:
            raise errors.UnknownChannelReceivedError("Recieved on unknown"
                    " channel %i" % frame.channel)
        chan_queue.put_next_rx_frame(frame)

    def all_channel_exception(self, exc):
        """Deliver ``exc`` to the receive side of every registered channel."""
        # Iterate over a snapshot so the dict may change while raising.
        for cq in list(self._channel_queues.values()):
            cq.raise_(exc)


class ChannelFrameQueue(object):
    """
    Send and receive frame queues for a single channel.

    Outgoing frames are buffered in a list and handed to the multiplexer one
    at a time. Incoming frames arrive through an eventlet ``queue``; frames
    pulled out of it early by ``peek_frames`` are parked in a pushback list,
    which is always served before the queue itself.
    """

    def __init__(self, multiplexer, send_queue_max=10):
        self.mux = multiplexer
        self.send_queue_sem = semaphore(limit=send_queue_max)
        self._send_frame_queue = []
        self._recv_frame_queue = queue()
        self._recv_pushback = []

        # The semaphore is used "upside down": release() on enqueue and
        # acquire() on dequeue, so its count tracks queued outgoing frames.
        sem = self.send_queue_sem
        self._send_queue_wait = sem.release
        self._send_dequeue_notify = sem.acquire

    def send_frame(self, frame):
        """
        Queue ``frame`` for sending on this channel and notify the
        multiplexer. Waits if the send queue length exceeds
        ``send_queue_max`` frames.
        """
        self._send_queue_wait()
        self._send_frame_queue.append(frame)
        self.mux.tx_enqueue_notify()

    def get_next_tx_frame(self):
        """
        Pop the oldest frame from the transmit queue, or return None if the
        queue is empty.
        """
        if not self._send_frame_queue:
            return None
        self._send_dequeue_notify()
        return self._send_frame_queue.pop(0)

    def put_next_rx_frame(self, frame):
        """
        Append ``frame`` to the receive queue.
        """
        self._recv_frame_queue.send(frame)

    def raise_(self, exc):
        """
        Deliver ``exc`` through the receive queue, so a receiver waiting on
        it gets the exception.
        """
        self._recv_frame_queue.send(exc=exc)

    def recv_frame(self):
        """
        Dequeue and return the next received frame, waiting if none is
        available. Pushed-back frames are returned first.
        """
        pushback = self._recv_pushback
        if not pushback:
            return self._recv_frame_queue.wait()
        return pushback.pop(0)

    def peek_frames(self):
        """
        Iterate over received frames without consuming them, so higher
        level APIs can pick frames selectively. Yields ``(index, frame)``
        tuples, where ``drop_frame(index)`` removes that frame from the
        receive queue. After the already-received frames are exhausted, the
        iterator waits for new ones.

        Calling drop_frame() or recv_frame() invalidates indices previously
        yielded by this method.
        """
        incoming = self._recv_frame_queue
        # Move everything that has already arrived into the pushback buffer.
        while incoming.ready():
            self._recv_pushback.append(incoming.wait())

        position = 0
        for buffered in self._recv_pushback:
            yield position, buffered
            position += 1

        # Then wait for new arrivals, buffering each one so drop_frame and
        # recv_frame can still reach it.
        for position in itertools.count(position):
            arrived = incoming.wait()
            self._recv_pushback.append(arrived)
            yield position, arrived

    def drop_frame(self, index):
        """
        Remove a frame from the receive queue, using an index yielded by
        peek_frames. Behaviour with an invalidated index is undefined (see
        peek_frames).
        """
        del self._recv_pushback[index]