Matt Joiner avatar Matt Joiner committed b0cc020

New, much simpler model

Comments (0)

Files changed (4)

examples/child_proc.py

 from fcntl import fcntl, F_SETFL
-from gthread import Event, spawn, run
+from gthread import spawn, run
 from os import O_NONBLOCK
+from pdb import set_trace
 from signal import SIGCHLD
 from subprocess import Popen, PIPE
+import sys
 
 def find_var():
-    sigchld = Event('signal', SIGCHLD)
     proc = Popen(['find', '/var'], stdout=PIPE, stderr=PIPE)
-    stdout = Event('read', proc.stdout)
-    stderr = Event('read', proc.stderr)
-    fcntl(proc.stdout, F_SETFL, O_NONBLOCK)
     lines = 0
-    stderr_buf = b''
-    while any([sigchld, stdout, stderr]):
-        event = yield
-        if event is sigchld:
-            rc = proc.poll()
-            if rc is not None:
-                print('process terminated with returncode', rc)
-                sigchld.remove()
-        elif event is stdout:
-            buf = proc.stdout.read(100)
-            #~ print(buf)
-            if buf:
-                lines += buf.count(b'\n')
-            else:
-                stdout.remove()
-                print('got', lines, 'lines')
-        elif event is stderr:
-            buf = proc.stderr.read(1)
+    @spawn
+    def sigchld():
+        while proc.poll() is None:
+            yield 'signal', SIGCHLD
+        print('process returncode:', proc.returncode)
+    @spawn
+    def read_stdout():
+        nonlocal lines
+        while True:
+            yield 'read', proc.stdout
+            buf = proc.stdout.read(0x1000)
             if not buf:
-                stderr.remove()
-            stderr_buf += buf
+                break
+            lines += buf.count(b'\n')
+        print(lines, 'lines')
+    @spawn
+    def read_stderr():
+        while True:
+            yield 'read', proc.stderr
+            line = proc.stderr.readline()
+            if not line:
+                break
+            sys.stderr.buffer.raw.write(line)
 
-spawn(find_var)
+find_var()
 run()

examples/http_server.py

+import socket
 import gthread
-import socket
 
 def handler(sock, addr):
     print('Handling connection from', addr)
     yield from sock.until_sendall('Hello {}!\n'.format(addr).encode())
     sock.close()
 
-def serve():
+@gthread.spawn
+def server():
     sock = gthread.socket()
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
     sock.bind(('', 8080))
         new_sock, addr = yield from sock.until_accept()
         gthread.spawn(handler, new_sock, addr)
 
-gthread.spawn(serve)
-
 gthread.run()

examples/prime_sieve.py

         yield from ch.put(i)
 
 def filter(p, in_, out):
-    while True:
+    while 1:
         n = yield from in_.get()
         if n % p:
             yield from out.put(n)
         spawn(filter, p, ch, ch1)
         ch = ch1
 
-spawn(sieve, 1000)
+spawn(sieve, 10000)
 run()
         self.setblocking(False)
 
     def until_recv_term(self, term):
-        with scheduler.event('read', self):
-            buf = b''
-            while not buf.endswith(term):
-                yield
-                buf1 = self.recv(0x10000, _socket.MSG_PEEK)
-                if not buf1:
-                    break
-                index = (buf + buf1).find(term)
-                buf2 = self.recv(len(buf1) if index == -1 else index - len(buf) + len(term))
-                buf += buf2
-            return buf
+        buf = b''
+        while not buf.endswith(term):
+            yield 'read', self
+            buf1 = self.recv(0x10000, _socket.MSG_PEEK)
+            if not buf1:
+                break
+            index = (buf + buf1).find(term)
+            buf2 = self.recv(len(buf1) if index == -1 else index - len(buf) + len(term))
+            buf += buf2
+        return buf
 
     def until_recv(self, count):
         yield
             buf = buf[sent:]
 
     def until_send(self, buf):
-        with Event('write', self) as can_send:
-            event = yield
-            if event is can_send:
-                return self.send(buf)
-            yield event
+        yield 'write', self
+        return self.send(buf)
 
     def until_accept(self):
-        with event('read', self):
-            yield
-            return self.accept()
+        yield 'read', self
+        return self.accept()
 
     def accept(self):
         sock, addr = super().accept()
         self._senders = set()
 
     def get(self):
-        #~ if self._items:
-            #~ return self._items.popleft()
         if self._senders:
-            item, sender = self._senders.pop()
+            sender = self._senders.pop()
+            item = next(sender)
             schedule(sender)
             return item
-        self._receivers.add(current)
-        return (yield)
+        else:
+            self._receivers.add(current)
+            item = yield
+            yield
+            return item
 
     def put(self, item):
         if self._receivers:
-            schedule(self._receivers.pop(), item)
+            receiver = self._receivers.pop()
+            receiver.send(item)
+            schedule(receiver)
         else:
-            self._senders.add((item, current))
+            self._senders.add(current)
             yield
+            yield item
 
 def spawn(*args, **kwargs):
-    scheduler.spawn(*args, **kwargs)
+    return scheduler.spawn(*args, **kwargs)
 
+import fcntl
+import logging
+import os
+import pdb
+import signal
 
-class timeout:
-
-    def __init__(self, timeout):
-        self.deadline = timeout + time()
-
-    def __enter__(self):
-        scheduler().add_deadline(self.deadline)
-
-    def __exit__(self):
-        scheduler().remove_deadline(self.deadline)
-
-import os as _os
+#~ logging.root.setLevel(logging.NOTSET)
 
 from fileno import fileno
 
-
-class _Event:
-
-    def __init__(self, scheduler, filter, data):
-        self.scheduler = scheduler
-        self.filter = filter
-        self.data = data
-        self.fiber = current
-        self.add()
-
-    def __enter__(self):
-        self.add()
-        return self
-
-    def __exit__(self, *args):
-        self.remove()
-
-    def add(self):
-        self.scheduler.add_event(self)
-        self.active = True
-
-    def remove(self):
-        self.scheduler.remove_event(self)
-        self.active = False
-
-    def __bool__(self):
-        return self.active
-
-import signal
-import pdb
-import os
-
 class _Scheduler:
 
     _close_mask = _select.EPOLLHUP|_select.EPOLLERR
 
     def __init__(self):
         self._poll_obj = _select.epoll()
-        open_max = _os.sysconf(_os.sysconf_names['SC_OPEN_MAX'])
+        open_max = os.sysconf(os.sysconf_names['SC_OPEN_MAX'])
         self._readers = [set() for _ in range(open_max)]
         self._writers = [set() for _ in range(open_max)]
-        from collections import defaultdict
-        self._signals = defaultdict(set)
+        self._signals = collections.defaultdict(set)
         self._registered = set()
         self._deadlines = []
-        import os
         self._signalfd, self._wakeup_fd = os.pipe()
-        from fcntl import fcntl, F_SETFL
-        from os import O_NONBLOCK
-        fcntl(self._wakeup_fd, F_SETFL, O_NONBLOCK)
+        fcntl.fcntl(self._wakeup_fd, fcntl.F_SETFL, os.O_NONBLOCK)
         signal.set_wakeup_fd(self._wakeup_fd)
         self._ready = []
         self._original_signal_handlers = {}
         self.spawn(self._signalfd_reader)
+        self._hook_default_signals()
+
+    def _hook_default_signals(self):
+        if signal.getsignal(signal.SIGCHLD) == signal.SIG_DFL:
+            logging.debug('Hooked signal SIGCHLD')
+            if signal.signal(signal.SIGCHLD, self._handle_signal) != signal.SIG_DFL:
+                raise RuntimeError
 
     def run_fiber(self, fiber, arg=None):
         global current
         current = fiber
         try:
+            if __debug__:
+                logging.debug('running %s', fiber)
             return fiber.send(arg)
         except StopIteration:
-            return arg
+            pass
         finally:
             current = None
 
-    def handle_event(self, event):
-        event = self.run_fiber(event.fiber, event)
-        if event:
-            event.remove()
+    def handle_event(self, fiber, event):
+        new_event = self.run_fiber(fiber)
+        if new_event != event:
+            if event is not None:
+                self.remove_event(fiber, *event)
+            if new_event is not None:
+                self.add_event(fiber, *new_event)
 
     def _signalfd_reader(self):
-        self.event('read', self._signalfd)
         while True:
-            yield
-            for signo in os.read(self._signalfd, 0x100):
-                for event in self._signals[signo]:
-                    self.handle_event(event)
+            yield 'read', self._signalfd
+            buf = os.read(self._signalfd, 0x100)
+            if not buf:
+                break
+            for signum in buf:
+                logging.debug('Got signal: %s', signum)
+                for fiber in self._signals[signum].copy():
+                    self.handle_event(fiber, ('signal', signum))
 
     def run(self):
-        while any([self._registered - {self._signalfd}, self._ready, self._deadlines]):
-            if self._registered - {self._signalfd}:
+        while any([self._registered - {self._signalfd}, self._ready, self._deadlines, any(self._signals.values())]):
+            if self._registered:
                 if self._ready:
                     timeout = 0
                 elif self._deadlines:
                     timeout = max(0, self._deadlines[0][0])
                 else:
                     timeout = -1
-                for fd, mask in self._poll_obj.poll(timeout):
+                while True:
+                    try:
+                        fd_masks = self._poll_obj.poll(timeout)
+                    except InterruptedError:
+                        pass
+                    else:
+                        break
+                for fd, mask in fd_masks:
                     #~ if mask & self._close_mask:
                         #~ self._registered.remove(fd)
                     if mask & self._read_mask:
                         readers = self._readers[fd]
-                        for event in readers.copy():
-                            if event in readers:
-                                self.handle_event(event)
+                        for fiber in readers.copy():
+                            if fiber in readers:
+                                self.handle_event(fiber, ('read', fd))
                     if mask & self._write_mask:
                         writers = self._writers[fd]
-                        for event in writers.copy():
-                            if event in writers:
-                                self.handle_event(event)
+                        for fiber in writers.copy():
+                            if fiber in writers:
+                                self.handle_event(fiber, ('write', fd))
             ready = self._ready
             self._ready = []
-            for gen, arg in ready:
-                self.run_fiber(gen, arg)
+            for fiber in ready:
+                self.handle_event(fiber, None)
 
     def spawn(self, func, *args, **kwargs):
-        self.schedule(func(*args, **kwargs))
+        return self.schedule(func(*args, **kwargs))
 
-    def schedule(self, fiber, arg=None):
-        self._ready.append((fiber, arg))
-
-    def event(self, filter, data):
-        return _Event(self, filter, data)
+    def schedule(self, fiber):
+        if fiber is None:
+            assert False
+        self._ready.append(fiber)
+        return fiber
 
     def _update_poll(self, fd):
         mask = _select.EPOLLIN if self._readers[fd] else 0
                 self._poll_obj.register(fd, mask)
                 self._registered.add(fd)
 
-    def _signal_handler(self, signum, frame):
+    def _handle_signal(self, signum, frame):
+        #~ logging.critical('signal handler: %s', signum)
         pass
 
-    def add_event(self, event):
-        if event.filter == 'read':
-            fd = fileno(event.data)
-            self._readers[fd].add(event)
+    def add_event(self, fiber, filter, data):
+        if filter in {'read', 'write'}:
+            fd = fileno(data)
+            watchers = {'read': self._readers, 'write': self._writers}[filter][fd]
+            assert fiber not in watchers
+            watchers.add(fiber)
             self._update_poll(fd)
-        elif event.filter == 'write':
-            fd = fileno(event.data)
-            self._writers[fd].add(event)
-            self._update_poll(fd)
-        elif event.filter == 'signal':
-            signum = event.data
-            self._signals[event.data].add(event)
-            handler = signal.getsignal(signum)
-            def meh(signum, frame):
-                pass
-            if handler in (signal.SIG_IGN, signal.SIG_DFL, meh):
-                signal.signal(signum, meh)
-            else:
-                assert False, (signum, handler)
+        elif filter == 'signal':
+            self._signals[data].add(fiber)
+            assert signal.getsignal(data) not in {signal.SIG_IGN, signal.SIG_DFL, None}
         else:
             raise ValueError('Unknown filter', event.filter)
 
-    def remove_event(self, event):
-        if event.filter == 'read':
-            fd = fileno(event.data)
-            self._readers[fd].remove(event)
+    def remove_event(self, fiber, filter, data):
+        if filter in {'read', 'write'}:
+            fd = fileno(data)
+            watchers = {'read': self._readers, 'write': self._writers}[filter][fd]
+            watchers.remove(fiber)
             self._update_poll(fd)
-        elif event.filter == 'write':
-            fd = fileno(event.data)
-            self._writers[fd].remove(event)
-            self._update_poll(fd)
-        elif event.filter == 'signal':
-            signum = event.data
-            if signal.getsignal(signum) == self._signal_handler:
-                if signal.signal(signum, self._original_signal_handlers.pop(signum)) != self._signal_handler:
-                    raise RuntimeError('signal handler changed unexpectedly')
+        elif filter == 'signal':
+            self._signals[data].remove(fiber)
         else:
             raise ValueError('unknown filter', event.filter)
 
 schedule = scheduler.schedule
 
 run = scheduler.run
-
-event = scheduler.event
-Event = event
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.