1. Simon Law
  2. python-goroutine

Commits

Simon Law  committed 41737ac

Initial implementation.

go() is a higher-order function that applies parameters to a function,
creating a Goroutine. This is analogous to the "go" keyword in Go.

Channel is a data-structure that represents channels. This is similar
to the "channel" data-structure in Go.

  • Participants
  • Branches default

Comments (0)

Files changed (3)

File examples/print_numbers.py

View file
  • Ignore whitespace
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, print_function
+
+from goroutine import Channel, go
+
+
+def gorange(start, count, output):
+    for i in range(count):
+        output.send(start + i)
+    output.close()
+
+
+def goprint(input, done):
+    for num in input:
+        print(num)
+    done.send(True)
+
+
+if __name__ == '__main__':
+    numbers = Channel()
+    done = Channel()
+    go(gorange, 1, 10, numbers)
+    go(goprint, numbers, done)
+
+    done.recv()

File goroutine/__init__.py

View file
  • Ignore whitespace
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, print_function
+
+from itertools import chain
+import multiprocessing
+import multiprocessing.queues
+try:
+    import queue
+except ImportError:
+    import Queue as queue
+from threading import current_thread, Thread
+
+
+class Channel(object):
+    from multiprocessing.queues import Empty, Full
+
+    def __init__(self, bufsize=0):
+        self.bufsize = bufsize
+        if bufsize == 0:
+            # Control pipe for blocking sends
+            self._block = queue.Queue(1)
+            bufsize = 1
+        if bufsize == 1:
+            # Single thread, so we use queues
+            self._queue = queue.Queue(bufsize)
+        else:
+            # Multiple processes, so we use multiprocessing
+            self._queue = multiprocessing.Queue(bufsize)
+
+        self._closed = False
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return self.recv()
+        except self.Empty:
+            raise StopIteration()
+    next = __next__
+
+    def close(self):
+        self._closed = True
+
+    @property
+    def closed(self):
+        return self._closed
+
+    def send(self, obj):
+        if self._closed:
+            raise self.Full()
+
+        self._queue.put((current_thread().ident, obj))
+        if self.bufsize == 0:
+            # Wait until the goroutine picks it up
+            self._block.get()
+
+    def recv(self):
+        block = not self._closed
+        while True:
+            try:
+                thread, result = self._queue.get(block=block)
+                if self.bufsize == 0:
+                    # Unblock the sender
+                    self._block.put(None)
+                elif thread == current_thread().ident:
+                    # Ignore items in the queue that were placed by this thread
+                    self._queue.put((thread, result))
+                    continue
+                return result
+            except (multiprocessing.queues.Empty, queue.Empty):
+                # Non-blocking Queues sometimes raise Empty when they are not,
+                # so ensure that the queue is actually empty.
+                if self._queue.qsize() == 0:
+                    raise self.Empty()
+
+
+class Goroutine(object):
+    def __init__(self, func, args, kwargs, channels=None):
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs
+        self._channels = channels
+
+        self._pool = None
+
+    def __del__(self):
+        pass
+
+    def _bufsizes(self):
+        if self._channels is None:
+            # Attempt to discover channels
+            self._channels = tuple(chain(
+                (c for c in self.args
+                 if isinstance(c, Channel)),
+                (c for c in self.kwargs.itervalues()
+                 if isinstance(c, Channel))
+            ))
+
+        return (c.bufsize for c in self._channels)
+
+    def _make_pool(self):
+        def wrapped(*args, **kwargs):
+            for channel in self._channels:
+                channel._isparent = False
+            return self.func(*args, **kwargs)
+
+        try:
+            processes = max(min(self._bufsizes()), 1)
+        except ValueError:
+            processes = 1       # No channels were specified
+
+        pool = [Thread(target=wrapped,
+                       args=self.args, kwargs=self.kwargs)
+                for i in range(processes)]
+        for thread in pool:
+            thread.daemon = True
+        return pool
+
+    def _start(self):
+        # Make the pool if it doesn't already exist
+        if self._pool is None:
+            self._pool = self._make_pool()
+
+        for thread in self._pool:
+            thread.start()
+
+
+def go(f, *args, **kwargs):
+    goroutine = Goroutine(func=f, args=args, kwargs=kwargs)
+    goroutine._start()
+    return goroutine

File test.py

View file
  • Ignore whitespace
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, print_function
+
+from time import sleep, time
+import unittest
+
+from goroutine import Channel, go
+
+
+class TestCoroutine(unittest.TestCase):
+    def test_create(self):
+        def func(*args, **kwargs):
+            pass
+
+        # No arguments
+        goroutine = go(func)
+        self.assertEqual(goroutine.func, func)
+        self.assertEqual(goroutine.args, ())
+        self.assertEqual(goroutine.kwargs, {})
+
+        # Arguments
+        goroutine = go(func, 'Hello', 'world', lang='en')
+        self.assertEqual(goroutine.func, func)
+        self.assertEqual(goroutine.args, ('Hello', 'world'))
+        self.assertEqual(goroutine.kwargs, {'lang': 'en'})
+
+    def test_async(self):
+        def func(*args, **kwargs):
+            sleep(1)
+
+        start_time = time()
+        go(func, Channel())
+        # Did not sleep
+        self.assertAlmostEqual(start_time, time(), places=2)
+
+    def test_blocking(self):
+        channel = Channel()
+
+        def slow_add(channel):
+            sleep(0.01)
+            x, y = channel.recv()
+            sleep(0.01)
+            channel.send(x + y)
+
+        start_time = time()
+        goroutine = go(slow_add, channel)
+        self.assertAlmostEqual(start_time - time(), 0, places=2)
+        self.assertEqual(len(goroutine._pool), 1)
+
+        channel.send([1, 2])
+        self.assertAlmostEqual(start_time - time(), -0.01, places=2)
+
+        result = channel.recv()
+        self.assertEqual(result, 1 + 2)
+        self.assertAlmostEqual(start_time - time(), -0.02, places=2)
+
+    def test_single_thread(self):
+        channel = Channel(bufsize=1)
+
+        def slow_add(channel):
+            while True:
+                x, y = channel.recv()
+                channel.send(x + y)
+
+        goroutine = go(slow_add, channel)
+
+        channel.send([1, 2])
+        self.assertEqual(len(goroutine._pool), 1)
+
+        self.assertEqual(channel.recv(), 1 + 2)
+
+
+class TestChannel(unittest.TestCase):
+    def test_create(self):
+        channel = Channel()
+        self.assertEqual(channel.bufsize, 0)
+
+    def test_close(self):
+        channel = Channel(bufsize=1)
+        channel.send(1)
+        self.assertFalse(channel.closed)
+        channel.close()
+        self.assertTrue(channel.closed)
+
+        # Can't send() any more
+        self.assertRaises(Channel.Full, channel.send, 2)
+
+        # recv() empties out
+        def receiver(channel):
+            self.assertEqual(channel.recv(), 1)
+            self.assertRaises(Channel.Empty, channel.recv)
+        go(receiver, channel)
+
+        # close() is idempotent
+        channel.close()
+        self.assertTrue(channel.closed)
+
+    def test_communicate(self):
+        channel = Channel(bufsize=1)
+        channel.send(1)
+        go((lambda channel: self.assertEqual(channel.recv(), 1)),
+           channel)
+
+    def test_iter(self):
+        def sum(channel, done):
+            result = 0
+            for i in channel:
+                result += i
+            done.send(result)
+
+        channel = Channel(bufsize=1)
+        done = Channel()
+        go(sum, channel, done)
+
+        channel.send(1)
+        channel.send(2)
+        channel.send(3)
+        channel.close()
+        self.assertEqual(done.recv(), 1 + 2 + 3)
+
+
+if __name__ == '__main__':
+    unittest.main(module='test', exit=False)