Source

green380 / green380 / _core.py

# Names exported by a star-import of this module (a tuple of strings).
__all__ = 'schedule', 'spawn', 'run', 'Channel',

import collections
import errno
import fcntl
import heapq
import logging
import os
import pdb
import select
import signal
import socket as _socket
import time

from .fileno import fileno

# Set the root logger's level to NOTSET as soon as this module is imported.
# NOTE(review): this is a process-wide side effect of importing the module;
# confirm that overriding the application's logging level here is intended.
logging.root.setLevel(logging.NOTSET)


def timeout_multiplex(event, timeout):
    """Park the calling fiber until it is sent a readiness flag or times out.

    This is a generator intended to be driven from inside a running fiber
    (it captures the module-level ``current`` fiber when it starts).  A
    helper fiber is spawned that waits on ``('timeout', timeout)`` and then
    sends ``False`` into the parked fiber.  Whatever resumes the parked fiber
    first must send it a ``bool``.

    Args:
        event: Accepted for interface compatibility; not used by the
            current implementation.
        timeout: Value passed through as the data of the ``'timeout'``
            event yielded by the helper fiber.

    Returns:
        The ``bool`` that was sent in: ``False`` when the timeout helper
        fired, otherwise the value supplied by the other sender.
    """
    @spawn
    def timeout_fiber():
        yield 'timeout', timeout
        # The deadline passed before anything else resumed the waiter.
        main_fiber.send(False)

    main_fiber = current
    ready = yield
    assert isinstance(ready, bool), ready
    if ready:
        # Resumed by something other than the timeout helper: close the
        # helper so it can no longer send False into this fiber.
        timeout_fiber.close()
    # We are executing inside the sender's ``send()`` call here; queue this
    # fiber on the scheduler and yield so the sender gets control back.
    schedule(main_fiber)
    yield
    return ready


class Channel:
    """Hand-off point where one fiber passes an item directly to another.

    ``get`` and ``put`` are generator functions.  Whichever side arrives
    first records the module-level ``current`` fiber in a queue and yields;
    the side that arrives second resumes that parked fiber directly to
    transfer the item and then hands it back to the scheduler.
    """

    def __init__(self):
        # Fibers parked in put() / get(), oldest first; deques allow the
        # oldest waiter to be popped from the left.
        self._senders = collections.deque()
        self._receivers = collections.deque()

    def get(self):
        """Receive one item; the generator's return value is the item."""
        if not self._senders:
            # Nobody is sending yet: park until a put() sends the item in.
            self._receivers.append(current)
            received = yield
            # Still running inside the sender's send(); yield once more so
            # the sender regains control before this fiber is rescheduled.
            yield
            return received
        # A sender is parked: resuming it makes it yield its item to us.
        waiting = self._senders.popleft()
        received = next(waiting)
        schedule(waiting)
        return received

    def put(self, item):
        """Deliver *item* to a receiver, parking until one is available."""
        if not self._receivers:
            # Nobody is receiving yet: park; the first resumption comes from
            # a get(), which then collects the item from the second yield.
            self._senders.append(current)
            yield
            yield item
            return
        # A receiver is parked: push the item into it and reschedule it.
        waiting = self._receivers.popleft()
        waiting.send(item)
        schedule(waiting)


class _Scheduler:
    """Epoll-driven cooperative scheduler for generator-based fibers.

    A fiber is a generator.  Each time it is resumed it runs until it yields
    either ``None`` or a ``(filter, data)`` pair describing what should wake
    it next:

    * ``('read', obj)`` / ``('write', obj)`` -- ``fileno(obj)`` is polled.
    * ``('signal', signum)`` -- the signal number is read from the wakeup
      pipe.
    * ``('timeout', seconds)`` -- resumed once that many seconds elapsed;
      ``None`` as the data registers nothing.
    """

    _close_mask = select.EPOLLHUP | select.EPOLLERR
    _read_mask = select.EPOLLIN | select.EPOLLPRI | _close_mask
    _read_mask |= getattr(select, 'EPOLLRDHUP', 0)
    _write_mask = select.EPOLLOUT | _close_mask

    def __init__(self):
        self._poll_obj = select.epoll()
        # fd -> set of fibers waiting on it.  Sparse mappings are used
        # instead of lists pre-sized to SC_OPEN_MAX so that memory does not
        # scale with the process's descriptor limit.
        self._readers = collections.defaultdict(set)
        self._writers = collections.defaultdict(set)
        # signum -> set of fibers waiting for that signal.
        self._signals = collections.defaultdict(set)
        # fds currently registered with the epoll object.
        self._registered = set()
        # Heap of (absolute deadline, sequence number, fiber).  The sequence
        # number breaks ties so two fibers are never compared to each other.
        self._deadlines = []
        self._deadline_seq = 0
        # Pipe used for signal delivery: the write end is installed as the
        # interpreter's wakeup fd, the read end is drained by
        # _signalfd_reader.
        self._signalfd, self._wakeup_fd = os.pipe()
        fcntl.fcntl(self._wakeup_fd, fcntl.F_SETFL, os.O_NONBLOCK)
        signal.set_wakeup_fd(self._wakeup_fd)
        # Fibers to resume on the next pass of run().
        self._ready = []
        self._original_signal_handlers = {}
        self.spawn(self._signalfd_reader)
        self._hook_default_signals()

    def _hook_default_signals(self):
        """Install ``_handle_signal`` for SIGCHLD if it has the default handler.

        Raises:
            RuntimeError: if the handler being replaced turns out not to be
                ``SIG_DFL`` after all.
        """
        if signal.getsignal(signal.SIGCHLD) == signal.SIG_DFL:
            logging.debug('Hooked signal SIGCHLD')
            if signal.signal(signal.SIGCHLD, self._handle_signal) != signal.SIG_DFL:
                raise RuntimeError

    def run_fiber(self, fiber, arg=None):
        """Resume *fiber* with *arg* and return what it yields.

        The module-level ``current`` is set to *fiber* for the duration of
        the call and reset to ``None`` afterwards.  Returns ``None`` when
        the fiber finishes (``StopIteration`` is swallowed).
        """
        global current
        current = fiber
        try:
            if __debug__:
                logging.debug('running %s', fiber)
            return fiber.send(arg)
        except StopIteration:
            pass
        finally:
            current = None

    def handle_event(self, fiber, event):
        """Resume *fiber* and update its registration for the next event.

        *event* is the ``(filter, data)`` pair the fiber was woken for, or
        ``None``.  If the fiber yields a different event, the old one is
        removed and the new one added; an identical event is left in place.
        """
        new_event = self.run_fiber(fiber)
        if new_event != event:
            if event is not None:
                self.remove_event(fiber, *event)
            if new_event is not None:
                self.add_event(fiber, *new_event)

    def _signalfd_reader(self):
        """Fiber that reads signal numbers from the pipe and wakes waiters."""
        while True:
            yield 'read', self._signalfd
            buf = os.read(self._signalfd, 0x100)
            if not buf:
                break
            # Each byte read from the pipe is treated as one signal number.
            for signum in buf:
                logging.debug('Got signal: %s', signum)
                # Iterate over a copy: resuming a fiber may change the set.
                for fiber in self._signals[signum].copy():
                    self.handle_event(fiber, ('signal', signum))

    def run(self):
        """Run fibers until nothing but the signal-pipe reader remains.

        The loop continues while there are registered fds other than the
        signal pipe, ready fibers, pending deadlines, or signal waiters.
        """
        while any([self._registered - {self._signalfd}, self._ready,
                   self._deadlines, any(self._signals.values())]):
            if self._registered:
                if self._ready:
                    # Work is already queued: only check for I/O, don't block.
                    timeout = 0
                elif self._deadlines:
                    timeout = max(0, self._deadlines[0][0] - time.time())
                else:
                    timeout = -1
                while True:
                    try:
                        fd_masks = self._poll_obj.poll(timeout)
                    except InterruptedError:
                        pass
                    else:
                        break
                for fd, mask in fd_masks:
                    if mask & self._read_mask:
                        readers = self._readers[fd]
                        # Snapshot, then re-check membership: an earlier
                        # fiber in this pass may have changed the set.
                        for fiber in readers.copy():
                            if fiber in readers:
                                self.handle_event(fiber, ('read', fd))
                    if mask & self._write_mask:
                        writers = self._writers[fd]
                        for fiber in writers.copy():
                            if fiber in writers:
                                self.handle_event(fiber, ('write', fd))
            while self._deadlines and self._deadlines[0][0] < time.time():
                _, _, fiber = heapq.heappop(self._deadlines)
                self.handle_event(fiber, None)
            # Swap the list first so fibers scheduled while these run are
            # deferred to the next pass.
            ready = self._ready
            self._ready = []
            for fiber in ready:
                self.handle_event(fiber, None)

    def spawn(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` and schedule the result as a fiber.

        Returns the object produced by the call.
        """
        return self.schedule(func(*args, **kwargs))

    def schedule(self, fiber):
        """Queue *fiber* to be resumed on the next pass of ``run``.

        Returns *fiber* unchanged.
        """
        assert fiber is not None
        self._ready.append(fiber)
        return fiber

    def _update_poll(self, fd):
        """Bring the epoll registration of *fd* in line with its waiters."""
        mask = select.EPOLLIN if self._readers[fd] else 0
        mask |= select.EPOLLOUT if self._writers[fd] else 0
        if fd in self._registered:
            if mask:
                self._poll_obj.modify(fd, mask)
            else:
                self._poll_obj.unregister(fd)
                self._registered.remove(fd)
        else:
            if mask:
                self._poll_obj.register(fd, mask)
                self._registered.add(fd)

    def _handle_signal(self, signum, frame):
        """Signal handler that does nothing.

        Waiting fibers are woken by ``_signalfd_reader``, which reads the
        signal numbers from the wakeup pipe.
        """
        pass

    def add_event(self, fiber, filter, data):
        """Register *fiber* to be woken by the event ``(filter, data)``.

        Args:
            fiber: The fiber to wake.
            filter: One of ``'read'``, ``'write'``, ``'signal'``,
                ``'timeout'``.
            data: For read/write, something ``fileno()`` accepts; for
                signal, the signal number; for timeout, a non-negative
                delay in seconds, or ``None`` to register nothing.

        Raises:
            ValueError: if *filter* is not one of the known names.
        """
        if filter in {'read', 'write'}:
            fd = fileno(data)
            watchers = {'read': self._readers, 'write': self._writers}[filter][fd]
            assert fiber not in watchers
            watchers.add(fiber)
            self._update_poll(fd)
        elif filter == 'signal':
            self._signals[data].add(fiber)
            assert signal.getsignal(data) not in {signal.SIG_IGN, signal.SIG_DFL, None}
        elif filter == 'timeout':
            if data is not None:
                assert data >= 0, data
                self._deadline_seq += 1
                heapq.heappush(
                    self._deadlines,
                    (time.time() + data, self._deadline_seq, fiber))
        else:
            raise ValueError('Unknown filter', filter)

    def remove_event(self, fiber, filter, data):
        """Undo a registration made by ``add_event``.

        For ``'timeout'`` every pending deadline of *fiber* is dropped;
        *data* is ignored there because entries are stored by absolute
        deadline rather than by the original delay.

        Raises:
            KeyError: if *fiber* was not waiting on the read/write/signal
                event.
            ValueError: if *fiber* has no pending timeout, or *filter* is
                not one of the known names.
        """
        if filter in {'read', 'write'}:
            fd = fileno(data)
            watchers = {'read': self._readers, 'write': self._writers}[filter][fd]
            watchers.remove(fiber)
            self._update_poll(fd)
        elif filter == 'signal':
            self._signals[data].remove(fiber)
        elif filter == 'timeout':
            remaining = [entry for entry in self._deadlines
                         if entry[2] is not fiber]
            if len(remaining) == len(self._deadlines):
                raise ValueError('no pending timeout for fiber', fiber)
            # Replace in place and restore the heap invariant.
            self._deadlines[:] = remaining
            heapq.heapify(self._deadlines)
        else:
            raise ValueError('unknown filter', filter)

# Module-wide scheduler instance, created when the module is imported.
# Constructing it opens an epoll object and a pipe and installs the pipe's
# write end as the signal wakeup fd (see _Scheduler.__init__).
scheduler = _Scheduler()
# Expose the instance's bound methods as the module-level API.
schedule = scheduler.schedule
spawn = scheduler.spawn
run = scheduler.run

def current_fiber():
    """Return the fiber currently being run, or ``None`` if there is none.

    The module-level ``current`` name is only created once the scheduler
    has resumed a fiber, so it is looked up defensively: calling this
    before any fiber has run returns ``None`` instead of raising
    ``NameError``.
    """
    return globals().get('current')