green380 / green380 / _core.py

# Names exported by a star-import of this module. Other non-underscored
# helpers defined below (alt, Timeout, timeout_multiplex, current_fiber)
# are not listed here.
__all__ = 'schedule', 'spawn', 'run', 'Channel',

import collections
import errno
import fcntl
import heapq
import inspect
import logging
import os
import pdb
import select
import signal
import socket as _socket
import sys
import time
import traceback
#~ logging.basicConfig(level=logging.NOTSET, stream=sys.stderr)
#
from .fileno import fileno
from .time import sleep

# Package-wide logger shared by the scheduler and helpers in this module.
# NOTE(review): forcing DEBUG at import time overrides any level previously
# set on 'green380'; consider leaving the level to the application.
logger = logging.getLogger(name='green380')
logger.setLevel(level=logging.DEBUG)


def alt(*gens):
    """Run several generators concurrently and return the first to finish.

    Each generator in ``gens`` is driven in its own fiber. This is itself a
    generator: call it as ``value, gen = yield from alt(g1, g2, ...)``.

    Returns:
        ``(value, gen)`` -- the return value of whichever generator finished
        first, together with that generator.

    NOTE: the losing fibers are not cancelled. When they finish they park
    forever on the internal channel, because nobody receives from it again.
    """
    ch = Channel()

    def wrap(gen):
        ret = yield from gen
        # Channel.send is a generator; it must be driven with ``yield from``
        # or the hand-off never happens.
        yield from ch.send((ret, gen))

    for g in gens:
        # spawn() calls its first argument, so pass the wrapper plus its
        # argument rather than an already-created generator.
        spawn(wrap, g)
    return (yield from ch.recv())


def timeout_multiplex(event, timeout):
    """Park the calling fiber until a helper fiber wakes it; return the bool it sent.

    Must be driven with ``yield from`` from inside a fiber, because it reads
    the module global ``current``.

    NOTE(review): ``event`` is never used. Nothing in this function ever sends
    True, so unless outside code resumes the caller's fiber with True, this
    returns False after ``timeout`` seconds. The commented-out Timeout class
    in this file sketches an event fiber that yields ``event`` and then
    signals True; that may be the intended completion -- confirm.
    """

    @spawn
    def timeout_fiber():
        # Register on the scheduler's deadline heap for ``timeout`` seconds...
        yield 'timeout', timeout
        # ...then resume the waiting fiber directly (nested inside this
        # fiber's step, not via the scheduler) with False = "timed out".
        main_fiber.send(False)
    # The fiber running this generator; read later by timeout_fiber's closure.
    main_fiber = current
    # Yielding None registers no event with the scheduler, so this fiber
    # stays parked until someone calls main_fiber.send(...).
    ready = yield
    assert isinstance(ready, bool), ready
    if ready:
        # Woken by something other than the timer: cancel the pending timer.
        timeout_fiber.close()
    # Requeue ourselves with the scheduler, drop the closure's reference, then
    # yield so that the nested main_fiber.send() call in the waker returns.
    schedule(main_fiber)
    main_fiber = None
    yield
    # Reached when the scheduler resumes us from the ready list.
    return ready


class Timeout(Exception):
    """Exception that doubles as a deadline shared by several operations.

    Creating ``Timeout(seconds)`` spawns a timer fiber. After ``seconds`` it
    marks the instance expired and wakes every call currently waiting in
    :meth:`multiplex`. ``multiplex`` raises this very instance when the
    deadline wins.
    """

    def __init__(self, timeout):
        super().__init__(timeout)
        # Set up state before the timer fiber exists so it can never observe
        # a half-initialised instance.
        self._channels = set()
        self._expired = False
        # Keep the spawned generator. Channel reports the sending *fiber*, so
        # this object (not the bound method) is what multiplex compares to.
        self._routine = spawn(self._timeout_routine, timeout)

    def _timeout_routine(self, timeout):
        yield from sleep(timeout)
        self._expired = True
        # Iterate a snapshot: each woken multiplex call removes its channel
        # from the set when it resumes.
        for ch in list(self._channels):
            # Wake only channels that still have a parked receiver. With a
            # receiver present, Channel.send hands off without suspending, so
            # this loop never blocks on a channel whose operation already
            # finished (its receiver was consumed by the operation's result).
            if ch._receivers:
                yield from ch.send()

    def multiplex(self, gen):
        """Run ``gen`` in its own fiber and wait for it or the deadline.

        Must be driven with ``yield from`` from inside a fiber.

        Returns:
            The value ``gen`` returns, if it finishes before the deadline.

        Raises:
            Timeout: this instance, if the deadline has passed or passes first.
            ``gen`` is not cancelled in that case. When it finishes, its
            fiber parks forever on an unread channel.
        """
        if self._expired:
            raise self
        ch = Channel()
        self._channels.add(ch)

        # spawn (not schedule) calls the function to obtain the generator;
        # ``routine`` is bound to that generator, i.e. the fiber identity
        # that Channel reports as sender.
        @spawn
        def routine():
            result = yield from gen
            # Channel.send is a generator; without ``yield from`` nothing is sent.
            yield from ch.send(result)

        try:
            value, sender = yield from ch.recv_from()
        finally:
            # Also clean up if this waiting fiber is closed or thrown into.
            self._channels.discard(ch)
        if sender is routine:
            return value
        assert sender is self._routine, sender
        raise self

#~ class Timeout:
#~
    #~ def __init__(self, timeout):
        #~ self._timeout = timeout
        #~ spawn(self._timeout_routine)
#~
    #~ def _timeout_routine(self):
        #~ yield from sleep(self._timeout)
        #~ self._expired = True
        #~ while self._channels:
            #~ yield from self._channels.pop().send(False)
#~
    #~ def __call__(self, task):
        #~ if self._expired:
            #~ return False
        #~ ch = Channel()
        #~ self._channels.append(ch)
        #~ @spawn
        #~ def event_fiber():
            #~ yield event
            #~ ch.send(True)
        #~ ready = yield from ch.recv()
        #~ assert self._channels[-1] is ch
        #~ self._channels.pop()
        #~ assert isinstance(ready, bool), ready
        #~ return ready


class Channel:
    """Unbuffered rendezvous channel between fibers.

    A send and a receive complete together. Whichever side arrives first
    parks its fiber in a FIFO queue until the other side shows up.
    ``recv_from``/``recv`` and ``send`` are generator functions: drive them
    with ``yield from`` from inside a fiber. Calling them bare does nothing.
    They also read the module global ``current`` to learn which fiber is
    calling.
    """

    def __init__(self):
        # serious black magic; deque is fastest
        # Parked fibers waiting to receive / waiting to send, oldest first.
        self._receivers = collections.deque()
        self._senders = collections.deque()

    def recv_from(self):
        """Receive one item; return ``(item, sender)``, where ``sender`` is the sending fiber."""
        if self._senders:
            # A sender is already parked at the first ``yield`` in send().
            sender = self._senders.popleft()
            # Resuming it runs send() on to ``yield item``. That value bubbles
            # up out of the sender fiber and comes back as next()'s result.
            item = next(sender)
            # Requeue the sender so it finishes send() on its next turn.
            schedule(sender)
            return item, sender
        else:
            # No sender yet: park. Yielding None registers no scheduler
            # event, so only a later send() resumes this fiber.
            self._receivers.append(current)
            # send() resumes us with (item, sending fiber)...
            item, sender = yield
            # ...and this yield hands control straight back to that send().
            # The scheduler resumes us here after send() reschedules us.
            yield
            return item, sender

    def recv(self):
        """Receive one item, discarding the sender."""
        return (yield from self.recv_from())[0]

    # Queue-style alias.
    get = recv

    def send(self, item=None):
        """Deliver ``item`` to a receiver, parking until one arrives if needed."""
        if self._receivers:
            receiver = self._receivers.popleft()
            # Resume the parked receiver directly (nested, outside the
            # scheduler) with the item and our fiber; it stops again at its
            # second yield. This branch never yields, so the caller's
            # ``yield from`` completes immediately.
            receiver.send((item, current))
            # Let the receiver finish recv_from() on its next turn.
            schedule(receiver)
        else:
            # No receiver: park until recv_from() resumes us with next().
            self._senders.append(current)
            yield
            # This yield carries the item back out to recv_from()'s next()
            # call. We then wait until recv_from() reschedules us.
            yield item

    # Queue-style alias.
    put = send


def fiber_bottom_location(gen):
    """Locate where a fiber is suspended, following its ``yield from`` chain.

    Args:
        gen: a fiber (generator), possibly delegating via ``yield from``.

    Returns:
        ``(code, lineno)`` of the innermost generator frame found, or None
        when ``gen`` is not a generator or has already finished (no frame).
    """
    frame = None
    while inspect.isgenerator(gen):
        if gen.gi_frame is None:
            # Finished or closed generators no longer have a frame.
            break
        frame = gen.gi_frame
        # gi_yieldfrom (Python 3.5+) is the object this generator currently
        # delegates to; on older interpreters, stop at the outer frame.
        gen = getattr(gen, 'gi_yieldfrom', None)
    if frame is None:
        return None
    return frame.f_code, frame.f_lineno

class _Scheduler:
    """Single-threaded fiber scheduler built on epoll, a deadline heap and a
    signal wakeup pipe.

    Fibers are generators. The value a fiber yields says what it waits for:

    * ``('read', f)`` / ``('write', f)`` -- ``f`` is anything ``fileno()``
      accepts; resumed when epoll reports the descriptor ready;
    * ``('signal', signum)`` -- resumed when ``signum`` is delivered;
    * ``('timeout', seconds)`` -- resumed once ``seconds`` have elapsed
      (``('timeout', None)`` registers nothing);
    * ``None`` -- nothing is registered; the fiber stays parked until it is
      ``schedule()``d or resumed directly.
    """

    _close_mask = select.EPOLLHUP | select.EPOLLERR
    _read_mask = select.EPOLLIN | select.EPOLLPRI | _close_mask
    _read_mask |= getattr(select, 'EPOLLRDHUP', 0)
    _write_mask = select.EPOLLOUT | _close_mask

    def __init__(self):
        self._poll_obj = select.epoll()
        # fd -> fibers waiting on it. A defaultdict allocates entries only for
        # descriptors actually used, rather than two sets for every possible
        # descriptor up front (SC_OPEN_MAX can be very large). It also avoids
        # IndexError for fds beyond that limit.
        self._readers = collections.defaultdict(set)
        self._writers = collections.defaultdict(set)
        # signum -> fibers waiting for that signal.
        self._signals = collections.defaultdict(set)
        # fds currently registered with epoll.
        self._registered = set()
        # Heap of (monotonic deadline, id(fiber), fiber). The id breaks ties
        # so generators themselves are never compared.
        self._deadlines = []
        # signal.set_wakeup_fd writes delivered signal numbers into this pipe;
        # _signalfd_reader turns them into ('signal', signum) events.
        self._signalfd, self._wakeup_fd = os.pipe()
        fcntl.fcntl(self._wakeup_fd, fcntl.F_SETFL, os.O_NONBLOCK)
        signal.set_wakeup_fd(self._wakeup_fd)
        self._ready = []
        self._original_signal_handlers = {}
        self.spawn(self._signalfd_reader)
        self._hook_default_signals()

    def _hook_default_signals(self):
        """Install a Python-level SIGCHLD handler if SIGCHLD is at SIG_DFL.

        add_event() requires a non-default handler for any signal a fiber
        waits on; this makes ('signal', SIGCHLD) usable out of the box.
        """
        if signal.getsignal(signal.SIGCHLD) == signal.SIG_DFL:
            logger.debug('Hooked signal SIGCHLD')
            previous = signal.signal(signal.SIGCHLD, self._handle_signal)
            if previous != signal.SIG_DFL:
                # Presumably the handler changed between the check and install.
                raise RuntimeError('SIGCHLD handler changed while hooking', previous)

    def run_fiber(self, fiber, arg=None):
        """Step ``fiber`` once with ``arg``; return what it yields (None if it finished).

        The module global ``current`` names the fiber for the duration of the
        step and is reset to None afterwards.
        """
        global current
        current = fiber
        try:
            if __debug__:
                logger.debug('running %s', fiber)
            return fiber.send(arg)
        except StopIteration:
            pass
        finally:
            current = None

    def handle_event(self, fiber, event):
        """Resume ``fiber`` and re-register whatever it yields next.

        ``event`` is the ``(filter, data)`` registration that woke the fiber,
        or None when it came from the ready list or the deadline heap.
        Exceptions are logged and re-raised, which stops run().
        """
        try:
            new_event = self.run_fiber(fiber)
            try:
                if new_event != event:
                    if event is not None:
                        self.remove_event(fiber, *event)
                    if new_event is not None:
                        self.add_event(fiber, *new_event)
            except Exception:
                logger.exception('Error handling event for fiber %r: %s', fiber, fiber_bottom_location(fiber))
                raise
        except Exception:
            logger.critical('Error handling event %r for fiber %r', event, fiber, exc_info=True)
            raise

    def _signalfd_reader(self):
        """Internal fiber: turn bytes from the wakeup pipe into signal events."""
        while True:
            yield 'read', self._signalfd
            buf = os.read(self._signalfd, 0x100)
            if not buf:
                break
            # Each byte is one delivered signal number.
            for signum in buf:
                logger.debug('Got signal: %s', signum)
                # Copy: handled fibers may unregister themselves.
                for fiber in self._signals[signum].copy():
                    self.handle_event(fiber, ('signal', signum))

    def run(self):
        """Drive fibers until nothing is left to wait for.

        The loop continues while any of these hold: an fd other than the
        internal signal pipe is registered, a fiber is ready, a deadline is
        pending, or a fiber is waiting on a signal.
        """
        while any([self._registered - {self._signalfd}, self._ready,
                   self._deadlines, any(self._signals.values())]):
            if self._registered:
                if self._ready:
                    timeout = 0  # ready fibers pending: don't block
                elif self._deadlines:
                    timeout = max(0, self._deadlines[0][0] - time.monotonic())
                else:
                    timeout = -1  # block until an fd or signal fires
                while True:
                    try:
                        fd_masks = self._poll_obj.poll(timeout)
                    except InterruptedError:
                        # Interpreters without PEP 475 surface EINTR; retry.
                        pass
                    else:
                        break
                for fd, mask in fd_masks:
                    if mask & self._read_mask:
                        readers = self._readers[fd]
                        # Iterate a copy; re-check membership in case a
                        # fiber stopped watching before its turn.
                        for fiber in readers.copy():
                            if fiber in readers:
                                self.handle_event(fiber, ('read', fd))
                    if mask & self._write_mask:
                        writers = self._writers[fd]
                        for fiber in writers.copy():
                            if fiber in writers:
                                self.handle_event(fiber, ('write', fd))
            # <= so a poll that returns exactly at the deadline does not spin.
            while self._deadlines and self._deadlines[0][0] <= time.monotonic():
                _, _, fiber = heapq.heappop(self._deadlines)
                self.handle_event(fiber, None)
            ready = self._ready
            self._ready = []
            for fiber in ready:
                self.handle_event(fiber, None)

    def spawn(self, func, *args, **kwargs):
        """Create a fiber by calling ``func(*args, **kwargs)`` and schedule it.

        Returns the new fiber (the generator ``func`` returned).
        """
        return self.schedule(func(*args, **kwargs))

    def schedule(self, fiber):
        """Queue ``fiber`` to be stepped on the next loop iteration; return it.

        Raises:
            TypeError: if ``fiber`` is None (checked explicitly so the guard
            survives ``python -O``).
        """
        if fiber is None:
            raise TypeError('cannot schedule None; expected a fiber (generator)')
        self._ready.append(fiber)
        return fiber

    def _update_poll(self, fd):
        """Bring the epoll registration for ``fd`` in line with its watchers."""
        mask = select.EPOLLIN if self._readers[fd] else 0
        mask |= select.EPOLLOUT if self._writers[fd] else 0
        if fd in self._registered:
            if mask:
                self._poll_obj.modify(fd, mask)
            else:
                self._poll_obj.unregister(fd)
                self._registered.remove(fd)
        else:
            if mask:
                self._poll_obj.register(fd, mask)
                self._registered.add(fd)

    def _handle_signal(self, signum, frame):
        # Deliberately empty: installing any Python handler is what makes the
        # signal get written to the wakeup fd; dispatch happens in
        # _signalfd_reader.
        pass

    def add_event(self, fiber, filter, data):
        """Register ``fiber`` as waiting for ``(filter, data)``.

        Raises:
            ValueError: for an unknown ``filter``.
        """
        if filter in {'read', 'write'}:
            fd = fileno(data)
            watchers = {'read': self._readers, 'write': self._writers}[filter][fd]
            assert fiber not in watchers
            watchers.add(fiber)
            self._update_poll(fd)
        elif filter == 'signal':
            self._signals[data].add(fiber)
            assert signal.getsignal(data) not in {signal.SIG_IGN, signal.SIG_DFL, None}
        elif filter == 'timeout':
            if data is not None:
                assert data >= 0, data
                # Monotonic clock: wall-clock jumps must not shift deadlines.
                heapq.heappush(self._deadlines, (time.monotonic() + data, id(fiber), fiber))
        else:
            raise ValueError('Unknown filter', filter)

    def remove_event(self, fiber, filter, data):
        """Unregister ``fiber`` from ``(filter, data)``.

        Raises:
            ValueError: for an unknown ``filter``.
        """
        if filter in {'read', 'write'}:
            fd = fileno(data)
            watchers = {'read': self._readers, 'write': self._writers}[filter][fd]
            watchers.remove(fiber)
            self._update_poll(fd)
        elif filter == 'signal':
            self._signals[data].remove(fiber)
        elif filter == 'timeout':
            # Heap entries are (deadline, id(fiber), fiber) and ``data`` is the
            # relative delay, so match on the fiber. Then restore heap order.
            self._deadlines[:] = [
                entry for entry in self._deadlines if entry[2] is not fiber
            ]
            heapq.heapify(self._deadlines)
        else:
            raise ValueError('unknown filter', filter)

# The fiber currently being stepped; _Scheduler.run_fiber rebinds it around
# each step. Bound here so readers (current_fiber(), Channel, ...) see None
# instead of raising NameError before any fiber has run.
current = None

# Process-wide scheduler and its bound entry points.
scheduler = _Scheduler()
schedule = scheduler.schedule
spawn = scheduler.spawn
run = scheduler.run

def current_fiber():
    """Return the fiber the scheduler is currently stepping.

    Reads the module global ``current``, which ``_Scheduler.run_fiber`` sets
    for the duration of each step and resets to None afterwards.
    NOTE(review): if ``current`` is not bound at module level, calling this
    before the scheduler has run any fiber raises NameError.
    """
    return current
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.