Jeffrey Gelens avatar Jeffrey Gelens committed f91cba9

Initial commit

Comments (0)

Files changed (14)

+syntax: glob
+
+.pyc
+tests/testresults.sqlite3
+Gevent is written and maintained by
+
+  Denis Bilenko
+
+and the contributors (ordered alphabetically):
+
+  Daniele Varrazzo
+  Jason Toffaletti
+  Jeffrey Gelens
+  Ludvig Ericson
+  Marcus Cavanaugh
+  Matt Goodall
+  Mike Barton
+  Nicholas Piël
+  Örjan Persson
+  Ralf Schmitt
+  Randall Leeds
+  Ted Suzman
+  Uriel Katz
+
+Gevent is inspired by and uses some code from eventlet which was written by
+
+  Bob Ipollito
+  Donovan Preston
+
+The libevent wrappers are based on pyevent by
+
+  Dug Song
+  Martin Murray
+
+The win32util module is taken from Twisted.
+
+Some modules (local and ssl) are adaptations of the modules from the Python standard library.
+
+If your code is used in gevent and you are not mentioned above, please contact the maintainer.
+Copyright (c) 2010, Noppo (Jeffrey Gelens) <http://www.noppo.pro/>
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list
+of conditions and the following disclaimer.
+Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+Neither the name of the Noppo nor the names of its contributors may be
+used to endorse or promote products derived from this software without specific
+prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+recursive-include test *
+recursive-include geventwebsocket *
+include LICENSE
+include README.rst
+include MANIFEST.in

Empty file added.

geventwebsocket/__init__.py

+
+version_info = (0, 1, 0)
+__version__ =  ".".join(map(str, version_info))
+
+try:
+    from geventwebsocket.websocket import WebSocket
+except ImportError:
+    import traceback
+    traceback.print_exc()

geventwebsocket/handler.py

+import re
+import struct
+import time
+import traceback
+import sys
+from hashlib import md5
+from gevent.pywsgi import WSGIHandler
+from geventwebsocket import WebSocket
+
+
+class WebSocketHandler(WSGIHandler):
+    def handle_one_response(self):
+        self.time_start = time.time()
+        self.status = None
+        self.response_length = 0
+
+        if self.environ.get("HTTP_CONNECTION") != "Upgrade" or \
+           self.environ.get("HTTP_UPGRADE") != "WebSocket" or \
+           not self.environ.get("HTTP_ORIGIN"):
+            message = "Websocket connection expected"
+            headers = [("Content-Length", str(len(message))),]
+            self.start_response("HTTP/1.1 400 Bad Request", headers, message)
+            self.close_connection = True
+            return
+
+        ws = WebSocket(self.rfile, self.wfile, self.socket, self.environ)
+        challenge = self._get_challenge()
+
+        headers = [
+            ("Upgrade", "WebSocket"),
+            ("Connection", "Upgrade"),
+            ("Sec-WebSocket-Origin", ws.origin),
+            ("Sec-WebSocket-Protocol", ws.protocol),
+            ("Sec-WebSocket-Location", "ws://" + self.environ.get('HTTP_HOST') + ws.path),
+        ]
+
+        self.start_response(
+            "HTTP/1.1 101 Web Socket Protocol Handshake", headers, challenge
+        )
+
+        try:
+            self.application(self.environ, self.start_response, ws)
+        except Exception:
+            traceback.print_exc()
+            sys.exc_clear()
+            try:
+                args = (getattr(self, 'server', ''),
+                        getattr(self, 'requestline', ''),
+                        getattr(self, 'client_address', ''),
+                        getattr(self, 'application', ''))
+                msg = '%s: Failed to handle request:\n  request = %s from %s\n  application = %s\n\n' % args
+                sys.stderr.write(msg)
+            except Exception:
+                sys.exc_clear()
+        finally:
+            self.wsgi_input._discard()
+            self.time_finish = time.time()
+            self.log_request()
+
+    def start_response(self, status, headers, body=None):
+        towrite = [status]
+        for header in headers:
+            towrite.append(": ".join(header))
+
+        if body is not None:
+            towrite.append("")
+            towrite.append(body)
+
+        self.wfile.write("\r\n".join(towrite))
+
+    def _get_key_value(self, key_value):
+        key_number = int(re.sub("\\D", "", key_value))
+        spaces = re.subn(" ", "", key_value)[1]
+
+        if key_number % spaces != 0:
+            raise HandShakeError("key_number %d is not an intergral multiple of"
+                                 " spaces %d" % (key_number, spaces))
+
+        return key_number / spaces
+
+    def _get_challenge(self):
+        key1 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY1')
+        key2 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY2')
+        if not (key1 and key2):
+            message = "Client using old protocol implementation"
+            headers = [("Content-Length", str(len(message))),]
+            self.start_response("HTTP/1.1 400 Bad Request", headers, message)
+            self.close_connection = True
+            return
+
+        part1 = self._get_key_value(self.environ['HTTP_SEC_WEBSOCKET_KEY1'])
+        part2 = self._get_key_value(self.environ['HTTP_SEC_WEBSOCKET_KEY2'])
+
+        # This request should have 8 bytes of data in the body
+        key3 = self.rfile.read(8)
+
+        challenge = ""
+        challenge += struct.pack("!I", part1)
+        challenge += struct.pack("!I", part2)
+        challenge += key3
+
+        return md5(challenge).digest()

geventwebsocket/websocket.py

+class WebSocket(object):
+    def __init__(self, rfile, wfile, sock, environ):
+        self.rfile = rfile
+        self.wfile = wfile
+        self.socket = sock
+        self.origin = environ.get('HTTP_ORIGIN')
+        self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'unknown')
+        self.path = environ.get('PATH_INFO')
+        self.websocket_closed = False
+
+    def send(self, message):
+        if self.websocket_closed:
+            raise Exception("Connection was terminated")
+
+        if isinstance(message, unicode):
+            message = message.encode('utf-8')
+        elif isinstance(message, str):
+            message = unicode(message).encode('utf-8')
+        else:
+            raise Exception("Invalid message encoding")
+
+        self.wfile.write("\x00" + message + "\xFF")
+
+    def close_connection(self):
+        if not self.websocket_closed:
+            self.websocket_closed = True
+            self.socket.shutdown(True)
+            self.socket.close()
+        else:
+            return
+
+    def _message_length(self):
+        # TODO: buildin security agains lengths greater than 2**31 or 2**32
+        length = 0
+
+        while True:
+            byte_str = self.rfile.read(1)
+
+            if not byte_str:
+                return 0
+            else:
+                byte = ord(byte_str)
+
+            if byte != 0x00:
+                length = length * 128 + (byte & 0x7f)
+                if (byte & 0x80) != 0x80:
+                    break
+
+        return length
+
+    def _read_until(self):
+        bytes = []
+
+        while True:
+            byte = self.rfile.read(1)
+            if ord(byte) != 0xff:
+                bytes.append(byte)
+            else:
+                break
+
+        return ''.join(bytes)
+
+    def wait(self):
+        while True:
+            if self.websocket_closed:
+                return None
+
+            frame_str = self.rfile.read(1)
+            frame_type = ord(frame_str)
+            if (frame_type & 0x80) == 0x00: # most significant byte is not set
+
+                if frame_type == 0x00:
+                    bytes = self._read_until()
+                    return bytes.decode("utf-8", "replace")
+                else:
+                    self.websocket_closed = True
+
+            elif (frame_type & 0x80) == 0x80: # most significant byte is set
+                # Read binary data (forward-compatibility)
+                if frame_type != 0xff:
+                    self.websocket_closed = True
+                else:
+                    length = self._message_length()
+                    if length == 0:
+                        self.websocket_closed = True
+                    else:
+                        self.rfile.read(length) # discard the bytes
+            else:
+                raise IOError("Reveiced an invalid message")
+from setuptools import setup, find_packages
+
+setup(
+    name="gevent-websocket",
+    version="0.1.0",
+    description="Websocket handler for the gevent pywsgi server, a Python network library",
+    long_description=open("README.rst").read(),
+    author="Jeffrey Gelens",
+    author_email="jeffrey@noppo.org",
+    url="",
+    install_requires=("gevent", "greenlet"),
+    packages=find_packages(exclude=["example","test"]),
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "License :: OSI Approved :: BSD License",
+        "Programming Language :: Python",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: POSIX",
+        "Operating System :: Microsoft :: Windows",
+        "Topic :: Internet",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+        "Intended Audience :: Developers",
+    ],
+)

tests/greentest.py

+# Copyright (c) 2008-2009 AG Projects
+# Author: Denis Bilenko
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+# package is named greentest, not test, so it won't be confused with test in stdlib
+import sys
+import unittest
+import time
+import traceback
+import re
+
+import gevent
+
+disabled_marker = '-*-*-*-*-*- disabled -*-*-*-*-*-'
+def exit_disabled():
+    sys.exit(disabled_marker)
+
+def exit_unless_25():
+    if sys.version_info[:2] < (2, 5):
+        exit_disabled()
+
+VERBOSE = sys.argv.count('-v') > 1
+
+if '--debug-greentest' in sys.argv:
+    sys.argv.remove('--debug-greentest')
+    DEBUG = True
+else:
+    DEBUG = False
+
+
+class TestCase(unittest.TestCase):
+
+    __timeout__ = 1
+    switch_expected = True
+    _switch_count = None
+
+    def setUp(self):
+        gevent.sleep(0) # switch at least once to setup signal handlers
+        if hasattr(gevent.core, '_event_count'):
+            self._event_count = (gevent.core._event_count(), gevent.core._event_count_active())
+        hub = gevent.hub.get_hub()
+        if hasattr(hub, 'switch_count'):
+            self._switch_count = hub.switch_count
+        self._timer = gevent.Timeout.start_new(self.__timeout__, RuntimeError('test is taking too long'))
+
+    def tearDown(self):
+        try:
+            if not hasattr(self, 'stderr'):
+                self.unhook_stderr()
+            if hasattr(self, 'stderr'):
+                sys.__stderr__.write(self.stderr)
+        except:
+            traceback.print_exc()
+        if hasattr(self, '_timer'):
+            self._timer.cancel()
+            hub = gevent.hub.get_hub()
+            if self._switch_count is not None and hasattr(hub, 'switch_count'):
+                msg = ''
+                if hub.switch_count < self._switch_count:
+                    msg = 'hub.switch_count decreased?\n'
+                elif hub.switch_count == self._switch_count:
+                    if self.switch_expected:
+                        msg = '%s.%s did not switch\n' % (type(self).__name__, self.testname)
+                elif hub.switch_count > self._switch_count:
+                    if not self.switch_expected:
+                        msg = '%s.%s switched but expected not to\n' % (type(self).__name__, self.testname)
+                if msg:
+                    print >> sys.stderr, 'WARNING: ' + msg
+
+            if hasattr(gevent.core, '_event_count'):
+                event_count = (gevent.core._event_count(), gevent.core._event_count_active())
+                if event_count > self._event_count:
+                    args = (type(self).__name__, self.testname, self._event_count, event_count)
+                    sys.stderr.write('WARNING: %s.%s event count was %s, now %s\n' % args)
+                    gevent.sleep(0.1)
+        else:
+            sys.stderr.write('WARNING: %s.setUp does not call base class setUp\n' % (type(self).__name__, ))
+
+    @property
+    def testname(self):
+        return getattr(self, '_testMethodName', '') or getattr(self, '_TestCase__testMethodName')
+
+    @property
+    def testcasename(self):
+        return self.__class__.__name__ + '.' + self.testname
+
+    def hook_stderr(self):
+        if VERBOSE:
+            return
+        from cStringIO import StringIO
+        self.new_stderr = StringIO()
+        self.old_stderr = sys.stderr
+        sys.stderr = self.new_stderr
+
+    def unhook_stderr(self):
+        if VERBOSE:
+            return
+        try:
+            value = self.new_stderr.getvalue()
+        except AttributeError:
+            return None
+        sys.stderr = self.old_stderr
+        self.stderr = value
+        return value
+
+    def assert_no_stderr(self):
+        stderr = self.unhook_stderr()
+        assert not stderr, 'Expected no stderr, got:\n__________\n%s\n^^^^^^^^^^\n\n' % (stderr, )
+
+    def assert_stderr_traceback(self, typ, value=None):
+        if VERBOSE:
+            return
+        if isinstance(typ, Exception):
+            if value is None:
+                value = str(typ)
+            typ = typ.__class__.__name__
+        else:
+            typ = getattr(typ, '__name__', typ)
+        stderr = self.unhook_stderr()
+        assert stderr is not None, repr(stderr)
+        traceback_re = '^Traceback \\(most recent call last\\):\n( +.*?\n)+^(?P<type>\w+): (?P<value>.*?)$'
+        self.extract_re(traceback_re, type=typ, value=value)
+
+    def assert_stderr(self, message):
+        if VERBOSE:
+            return
+        exact_re = '^' + message + '.*?\n$.*'
+        if re.match(exact_re, self.stderr):
+            self.extract_re(exact_re)
+        else:
+            words_re = '^' + '.*?'.join(message.split()) + '.*?\n$'
+            if re.match(words_re, self.stderr):
+                self.extract_re(words_re)
+            else:
+                if message.endswith('...'):
+                    another_re = '^' + '.*?'.join(message.split()) + '.*?(\n +.*?$){2,5}\n\n'
+                    self.extract_re(another_re)
+                else:
+                    raise AssertionError('%r did not match:\n%r' % (message, self.stderr))
+
+    def assert_mainloop_assertion(self, message=None):
+        self.assert_stderr_traceback('AssertionError', 'Cannot switch to MAINLOOP from MAINLOOP')
+        if message is not None:
+            self.assert_stderr(message)
+
+    def extract_re(self, regex, **kwargs):
+        assert self.stderr is not None
+        m = re.search(regex, self.stderr, re.DOTALL|re.M)
+        if m is None:
+            raise AssertionError('%r did not match:\n%r' % (regex, self.stderr))
+        for key, expected_value in kwargs.items():
+            real_value = m.group(key)
+            if expected_value is not None:
+                try:
+                    self.assertEqual(real_value, expected_value)
+                except AssertionError:
+                    print 'failed to process: %s' % self.stderr
+                    raise
+        if DEBUG:
+            ate = '\n#ATE#: ' + self.stderr[m.start(0):m.end(0)].replace('\n', '\n#ATE#: ') + '\n'
+            sys.__stderr__.write(ate)
+        self.stderr = self.stderr[:m.start(0)] + self.stderr[m.end(0)+1:]
+
+
+main = unittest.main
+
+_original_Hub = gevent.hub.Hub
+
+class CountingHub(_original_Hub):
+
+    switch_count = 0
+
+    def switch(self):
+        self.switch_count += 1
+        return _original_Hub.switch(self)
+
+gevent.hub.Hub = CountingHub
+
+
+def test_outer_timeout_is_not_lost(self):
+    timeout = gevent.Timeout.start_new(0.01)
+    try:
+        self.wait(timeout=0.02)
+    except gevent.Timeout, ex:
+        assert ex is timeout, (ex, timeout)
+    else:
+        raise AssertionError('must raise Timeout')
+    gevent.sleep(0.02)
+
+
+class GenericWaitTestCase(TestCase):
+
+    def wait(self, timeout):
+        raise NotImplementedError('override me in subclass')
+
+    test_outer_timeout_is_not_lost = test_outer_timeout_is_not_lost
+
+    def test_returns_none_after_timeout(self):
+        start = time.time()
+        result = self.wait(timeout=0.01)
+        # join and wait simply returns after timeout expires
+        delay = time.time() - start
+        assert 0.01 - 0.001 <= delay < 0.01 + 0.01 + 0.1, delay
+        assert result is None, repr(result)
+
+
+class GenericGetTestCase(TestCase):
+
+    def wait(self, timeout):
+        raise NotImplementedError('override me in subclass')
+
+    test_outer_timeout_is_not_lost = test_outer_timeout_is_not_lost
+
+    def test_raises_timeout_number(self):
+        start = time.time()
+        self.assertRaises(gevent.Timeout, self.wait, timeout=0.01)
+        # get raises Timeout after timeout expired
+        delay = time.time() - start
+        assert 0.01 - 0.001 <= delay < 0.01 + 0.01 + 0.1, delay
+
+    def test_raises_timeout_Timeout(self):
+        start = time.time()
+        timeout = gevent.Timeout(0.01)
+        try:
+            self.wait(timeout=timeout)
+        except gevent.Timeout, ex:
+            assert ex is timeout, (ex, timeout)
+        delay = time.time() - start
+        assert 0.01 - 0.001 <= delay < 0.01 + 0.01 + 0.1, delay
+
+    def test_raises_timeout_Timeout_exc_customized(self):
+        start = time.time()
+        error = RuntimeError('expected error')
+        timeout = gevent.Timeout(0.01, exception=error)
+        try:
+            self.wait(timeout=timeout)
+        except RuntimeError, ex:
+            assert ex is error, (ex, error)
+        delay = time.time() - start
+        assert 0.01 - 0.001 <= delay < 0.01 + 0.01 + 0.1, delay
+
+
+class ExpectedException(Exception):
+    """An exception whose traceback should be ignored"""

tests/mysubprocess.py

+import sys
+import os
+import subprocess
+import signal
+from subprocess import *
+
+class Popen(subprocess.Popen):
+
+    def send_signal(self, sig):
+        if sys.platform == 'win32':
+            sig = signal.SIGTERM
+        if hasattr(subprocess.Popen, 'send_signal'):
+            try:
+                return subprocess.Popen.send_signal(self, sig)
+            except Exception:
+                sys.stderr.write('send_signal(%s, %s) failed: %s\n' % (self.pid, sig, ex))
+                self.external_kill()
+        else:
+            if hasattr(os, 'kill'):
+                sys.stderr.write('Sending signal %s to %s\n' % (sig, self.pid))
+                try:
+                    os.kill(self.pid, sig)
+                except Exception, ex:
+                    sys.stderr.write('Error while killing %s: %s\n' % (self.pid, ex))
+                    self.external_kill()
+            else:
+                self.external_kill()
+
+    if not hasattr(subprocess.Popen, 'kill'):
+
+        def kill(self):
+            return self.send_signal(getattr(signal, 'SIGTERM', 15))
+
+    if not hasattr(subprocess.Popen, 'terminate'):
+
+        def terminate(self):
+            return self.send_signal(getattr(signal, 'SIGTERM', 9))
+
+    def interrupt(self):
+        sig = getattr(signal, 'SIGINT', 2)
+        return self.send_signal(sig)
+
+    def external_kill(self):
+        if sys.platform == 'win32':
+            sys.stderr.write('Killing %s: %s\n' % (self.pid, ex))
+            os.system('taskkill /f /pid %s' % self.pid)
+        else:
+            sys.stderr.write('Cannot kill on this platform. Please kill %s\n' % self.pid)
+

tests/test__websocket.py

+# Websocket tests by Jeffrey Gelens, Copyright 2010, Noppo.pro
+# Socket related functions by:
+#
+# @author Donovan Preston
+#
+# Copyright (c) 2007, Linden Research, Inc.
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+from gevent import monkey
+monkey.patch_all(thread=False)
+
+import sys
+import greentest
+import gevent
+from gevent import socket
+from geventwebsocket.handler import WebSocketHandler
+
+
+CONTENT_LENGTH = 'Content-Length'
+CONN_ABORTED_ERRORS = []
+DEBUG = '-v' in sys.argv
+
+try:
+    from errno import WSAECONNABORTED
+    CONN_ABORTED_ERRORS.append(WSAECONNABORTED)
+except ImportError:
+    pass
+
+
+class ConnectionClosed(Exception):
+    pass
+
+
+def read_headers(fd):
+    response_line = fd.readline()
+    if not response_line:
+        raise ConnectionClosed
+    headers = {}
+    while True:
+        line = fd.readline().strip()
+        if not line:
+            break
+        try:
+            key, value = line.split(': ', 1)
+        except:
+            print 'Failed to split: %r' % (line, )
+            raise
+        assert key.lower() not in [x.lower() for x in headers.keys()], 'Header %r:%r sent more than once: %r' % (key, value, headers)
+        headers[key] = value
+    return response_line, headers
+
+
+def iread_chunks(fd):
+    while True:
+        line = fd.readline()
+        chunk_size = line.strip()
+        try:
+            chunk_size = int(chunk_size, 16)
+        except:
+            print 'Failed to parse chunk size: %r' % line
+            raise
+        if chunk_size == 0:
+            crlf = fd.read(2)
+            assert crlf == '\r\n', repr(crlf)
+            break
+        data = fd.read(chunk_size)
+        yield data
+        crlf = fd.read(2)
+        assert crlf == '\r\n', repr(crlf)
+
+
+class Response(object):
+
+    def __init__(self, status_line, headers, body=None, chunks=None):
+        self.status_line = status_line
+        self.headers = headers
+        self.body = body
+        self.chunks = chunks
+        try:
+            version, code, self.reason = status_line[:-2].split(' ', 2)
+        except Exception:
+            print 'Error: %r' % status_line
+            raise
+        self.code = int(code)
+        HTTP, self.version = version.split('/')
+        assert HTTP == 'HTTP', repr(HTTP)
+        assert self.version in ('1.0', '1.1'), repr(self.version)
+
+    def __iter__(self):
+        yield self.status_line
+        yield self.headers
+        yield self.body
+
+    def __str__(self):
+        args = (self.__class__.__name__, self.status_line, self.headers, self.body, self.chunks)
+        return '<%s status_line=%r headers=%r body=%r chunks=%r>' % args
+
+    def assertCode(self, code):
+        if hasattr(code, '__contains__'):
+            assert self.code in code, 'Unexpected code: %r (expected %r)\n%s' % (self.code, code, self)
+        else:
+            assert self.code == code, 'Unexpected code: %r (expected %r)\n%s' % (self.code, code, self)
+
+    def assertReason(self, reason):
+        assert self.reason == reason, 'Unexpected reason: %r (expected %r)\n%s' % (self.reason, reason, self)
+
+    def assertVersion(self, version):
+        assert self.version == version, 'Unexpected version: %r (expected %r)\n%s' % (self.version, version, self)
+
+    def assertHeader(self, header, value):
+        real_value = self.headers.get(header)
+        assert real_value == value, \
+               'Unexpected header %r: %r (expected %r)\n%s' % (header, real_value, value, self)
+
+    def assertBody(self, body):
+        assert self.body == body, \
+               'Unexpected body: %r (expected %r)\n%s' % (self.body, body, self)
+
+    @classmethod
+    def read(cls, fd, code=200, reason='default', version='1.1', body=None):
+        _status_line, headers = read_headers(fd)
+        self = cls(_status_line, headers)
+        if code is not None:
+            self.assertCode(code)
+        if reason == 'default':
+            reason = {200: 'OK'}.get(code)
+        if reason is not None:
+            self.assertReason(reason)
+        if version is not None:
+            self.assertVersion(version)
+        if self.code == 100:
+            return self
+        try:
+            if 'chunked' in headers.get('Transfer-Encoding', ''):
+                if CONTENT_LENGTH in headers:
+                    print "WARNING: server used chunked transfer-encoding despite having Content-Length header (libevent 1.x's bug)"
+                self.chunks = list(iread_chunks(fd))
+                self.body = ''.join(self.chunks)
+            elif CONTENT_LENGTH in headers:
+                num = int(headers[CONTENT_LENGTH])
+                self.body = fd.read(num)
+            else:
+                self.body = fd.read(16)
+        except:
+            print 'Response.read failed to read the body:\n%s' % self
+            raise
+        if body is not None:
+            self.assertBody(body)
+        return self
+
+read_http = Response.read
+
+
+class DebugFileObject(object):
+
+    def __init__(self, obj):
+        self.obj = obj
+
+    def read(self, *args):
+        result = self.obj.read(*args)
+        if DEBUG:
+            print repr(result)
+        return result
+
+    def readline(self, *args):
+        result = self.obj.readline(*args)
+        if DEBUG:
+            print repr(result)
+        return result
+
+    def __getattr__(self, item):
+        assert item != 'obj'
+        return getattr(self.obj, item)
+
+
+def makefile(self, mode='r', bufsize=-1):
+    return DebugFileObject(socket._fileobject(self.dup(), mode, bufsize))
+
+socket.socket.makefile = makefile
+
+class TestCase(greentest.TestCase):
+    __timeout__ = 5
+
+    def get_wsgi_module(self):
+        from gevent import pywsgi
+        return pywsgi
+
+    def init_server(self, application):
+        self.server = self.get_wsgi_module().WSGIServer(('127.0.0.1', 0),
+            application, handler_class=WebSocketHandler)
+
+    def setUp(self):
+        application = self.application
+        self.init_server(application)
+        self.server.start()
+        self.port = self.server.server_port
+        greentest.TestCase.setUp(self)
+
+
+    def tearDown(self):
+        greentest.TestCase.tearDown(self)
+        timeout = gevent.Timeout.start_new(0.5)
+        try:
+            self.server.stop()
+        finally:
+            timeout.cancel()
+
+    def connect(self):
+        return socket.create_connection(('127.0.0.1', self.port))
+
+
+class TestWebSocket(TestCase):
+    message = "\x00Hello world\xff"
+
+    def application(self, environ, start_response, ws):
+        if environ['PATH_INFO'] == "/echo":
+            while True:
+                message = ws.wait()
+                if message is None:
+                    break
+                ws.send(message)
+                ws.close_connection()
+            return []
+
+    def test_basic(self):
+        fd = self.connect().makefile(bufsize=1)
+        headers = "" \
+        "GET /echo HTTP/1.1\r\n" \
+        "Host: localhost\r\n" \
+        "Connection: Upgrade\r\n" \
+        "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n" \
+        "Sec-WebSocket-Protocol: test\r\n" \
+        "Upgrade: WebSocket\r\n" \
+        "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n" \
+        "Origin: http://localhost\r\n\r\n" \
+        "^n:ds[4U"
+
+        fd.write(headers)
+        fd.write(self.message)
+
+        response = read_http(fd, code=101, body="8jKS'y:G*Co,Wxa-", reason="Web Socket Protocol Handshake")
+        response.assertHeader("Upgrade", "WebSocket")
+        response.assertHeader("Connection", "Upgrade")
+        response.assertHeader("Sec-WebSocket-Origin", "http://localhost")
+        response.assertHeader("Sec-WebSocket-Location", "ws://localhost/echo")
+        response.assertHeader("Sec-WebSocket-Protocol", "test")
+
+        message = fd.read()
+        assert message == self.message, \
+               'Unexpected message: %r (expected %r)\n%s' % (message, self.message, self)
+
+        fd.close()
+
+    def test_badrequest(self):
+        fd = self.connect().makefile(bufsize=1)
+        fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n')
+        read_http(fd, code=400, reason='Bad Request', body='Websocket connection expected')
+        fd.close()
+
+    def test_oldprotocol_version(self):
+        fd = self.connect().makefile(bufsize=1)
+        headers = "" \
+        "GET /echo HTTP/1.1\r\n" \
+        "Host: localhost\r\n" \
+        "Connection: Upgrade\r\n" \
+        "WebSocket-Protocol: sample\r\n" \
+        "Upgrade: WebSocket\r\n" \
+        "Origin: http://example.com\r\n\r\n" \
+        "^n:ds[4U"
+
+        fd.write(headers)
+        read_http(fd, code=400, reason='Bad Request', body='Client using old protocol implementation')
+
+        fd.close()
+
+if __name__ == '__main__':
+    greentest.main()

tests/test_support.py

+"""Supporting definitions for the Python regression tests."""
+
+import sys
+
+HOST = 'localhost'
+
+class Error(Exception):
+    """Base class for regression test exceptions."""
+
+class TestFailed(Error):
+    """Test failed."""
+
+class TestSkipped(Error):
+    """Test skipped.
+
+    This can be raised to indicate that a test was deliberatly
+    skipped, but not because a feature wasn't available.  For
+    example, if some resource can't be used, such as the network
+    appears to be unavailable, this should be raised instead of
+    TestFailed.
+    """
+
+class ResourceDenied(TestSkipped):
+    """Test skipped because it requested a disallowed resource.
+
+    This is raised when a test calls requires() for a resource that
+    has not be enabled.  It is used to distinguish between expected
+    and unexpected skips.
+    """
+
+verbose = 1              # Flag set to 0 by regrtest.py
+use_resources = None     # Flag set to [] by regrtest.py
+max_memuse = 0           # Disable bigmem tests (they will still be run with
+                         # small sizes, to make sure they work.)
+
+# _original_stdout is meant to hold stdout at the time regrtest began.
+# This may be "the real" stdout, or IDLE's emulation of stdout, or whatever.
+# The point is to have some flavor of stdout the user can actually see.
+_original_stdout = None
+def record_original_stdout(stdout):
+    global _original_stdout
+    _original_stdout = stdout
+
+def get_original_stdout():
+    return _original_stdout or sys.stdout
+
+def unload(name):
+    try:
+        del sys.modules[name]
+    except KeyError:
+        pass
+
+def unlink(filename):
+    import os
+    try:
+        os.unlink(filename)
+    except OSError:
+        pass
+
+def forget(modname):
+    '''"Forget" a module was ever imported by removing it from sys.modules and
+    deleting any .pyc and .pyo files.'''
+    unload(modname)
+    import os
+    for dirname in sys.path:
+        unlink(os.path.join(dirname, modname + os.extsep + 'pyc'))
+        # Deleting the .pyo file cannot be within the 'try' for the .pyc since
+        # the chance exists that there is no .pyc (and thus the 'try' statement
+        # is exited) but there is a .pyo file.
+        unlink(os.path.join(dirname, modname + os.extsep + 'pyo'))
+
+def is_resource_enabled(resource):
+    """Test whether a resource is enabled.  Known resources are set by
+    regrtest.py."""
+    return use_resources is not None and resource in use_resources
+
+def requires(resource, msg=None):
+    """Raise ResourceDenied if the specified resource is not available.
+
+    If the caller's module is __main__ then automatically return True.  The
+    possibility of False being returned occurs when regrtest.py is executing."""
+    # see if the caller's module is __main__ - if so, treat as if
+    # the resource was set
+    return
+    if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
+        return
+    if not is_resource_enabled(resource):
+        if msg is None:
+            msg = "Use of the `%s' resource not enabled" % resource
+        raise ResourceDenied(msg)
+
+import socket
+
+def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
+    """Returns an unused port that should be suitable for binding.  This is
+    achieved by creating a temporary socket with the same family and type as
+    the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to
+    the specified host address (defaults to 0.0.0.0) with the port set to 0,
+    eliciting an unused ephemeral port from the OS.  The temporary socket is
+    then closed and deleted, and the ephemeral port is returned.
+
+    Either this method or bind_port() should be used for any tests where a
+    server socket needs to be bound to a particular port for the duration of
+    the test.  Which one to use depends on whether the calling code is creating
+    a python socket, or if an unused port needs to be provided in a constructor
+    or passed to an external program (i.e. the -accept argument to openssl's
+    s_server mode).  Always prefer bind_port() over find_unused_port() where
+    possible.  Hard coded ports should *NEVER* be used.  As soon as a server
+    socket is bound to a hard coded port, the ability to run multiple instances
+    of the test simultaneously on the same host is compromised, which makes the
+    test a ticking time bomb in a buildbot environment. On Unix buildbots, this
+    may simply manifest as a failed test, which can be recovered from without
+    intervention in most cases, but on Windows, the entire python process can
+    completely and utterly wedge, requiring someone to log in to the buildbot
+    and manually kill the affected process.
+
+    (This is easy to reproduce on Windows, unfortunately, and can be traced to
+    the SO_REUSEADDR socket option having different semantics on Windows versus
+    Unix/Linux.  On Unix, you can't have two AF_INET SOCK_STREAM sockets bind,
+    listen and then accept connections on identical host/ports.  An EADDRINUSE
+    socket.error will be raised at some point (depending on the platform and
+    the order bind and listen were called on each socket).
+
+    However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE
+    will ever be raised when attempting to bind two identical host/ports. When
+    accept() is called on each socket, the second caller's process will steal
+    the port from the first caller, leaving them both in an awkwardly wedged
+    state where they'll no longer respond to any signals or graceful kills, and
+    must be forcibly killed via OpenProcess()/TerminateProcess().
+
+    The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option
+    instead of SO_REUSEADDR, which effectively affords the same semantics as
+    SO_REUSEADDR on Unix.  Given the propensity of Unix developers in the Open
+    Source world compared to Windows ones, this is a common mistake.  A quick
+    look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when
+    openssl.exe is called with the 's_server' option, for example. See
+    http://bugs.python.org/issue2550 for more info.  The following site also
+    has a very thorough description about the implications of both REUSEADDR
+    and EXCLUSIVEADDRUSE on Windows:
+    http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx)
+
+    XXX: although this approach is a vast improvement on previous attempts to
+    elicit unused ports, it rests heavily on the assumption that the ephemeral
+    port returned to us by the OS won't immediately be dished back out to some
+    other process when we close and delete our temporary socket but before our
+    calling code has a chance to bind the returned port.  We can deal with this
+    issue if/when we come across it."""
+    tempsock = socket.socket(family, socktype)
+    port = bind_port(tempsock)
+    tempsock.close()
+    del tempsock
+    return port
+
+def bind_port(sock, host='', preferred_port=54321):
+    """Try to bind the sock to a port.  If we are running multiple
+    tests and we don't try multiple ports, the test can fails.  This
+    makes the test more robust."""
+
+    import socket, errno
+
+    # Find some random ports that hopefully no one is listening on.
+    # Ideally each test would clean up after itself and not continue listening
+    # on any ports.  However, this isn't the case.  The last port (0) is
+    # a stop-gap that asks the O/S to assign a port.  Whenever the warning
+    # message below is printed, the test that is listening on the port should
+    # be fixed to close the socket at the end of the test.
+    # Another reason why we can't use a port is another process (possibly
+    # another instance of the test suite) is using the same port.
+    for port in [preferred_port, 9907, 10243, 32999, 0]:
+        try:
+            sock.bind((host, port))
+            if port == 0:
+                port = sock.getsockname()[1]
+            return port
+        except socket.error, (err, msg):
+            if err != errno.EADDRINUSE:
+                raise
+            print >>sys.__stderr__, \
+                '  WARNING: failed to listen on port %d, trying another' % port
+    raise TestFailed, 'unable to find port to listen on'
+
+FUZZ = 1e-6
+
+def fcmp(x, y): # fuzzy comparison function
+    if type(x) == type(0.0) or type(y) == type(0.0):
+        try:
+            x, y = coerce(x, y)
+            fuzz = (abs(x) + abs(y)) * FUZZ
+            if abs(x-y) <= fuzz:
+                return 0
+        except:
+            pass
+    elif type(x) == type(y) and type(x) in (type(()), type([])):
+        for i in range(min(len(x), len(y))):
+            outcome = fcmp(x[i], y[i])
+            if outcome != 0:
+                return outcome
+        return cmp(len(x), len(y))
+    return cmp(x, y)
+
+try:
+    unicode
+    have_unicode = 1
+except NameError:
+    have_unicode = 0
+
+is_jython = sys.platform.startswith('java')
+
+import os
+# Filename used for testing
+if os.name == 'java':
+    # Jython disallows @ in module names
+    TESTFN = '$test'
+elif os.name == 'riscos':
+    TESTFN = 'testfile'
+else:
+    TESTFN = '@test'
+    # Unicode name only used if TEST_FN_ENCODING exists for the platform.
+    if have_unicode:
+        # Assuming sys.getfilesystemencoding()!=sys.getdefaultencoding()
+        # TESTFN_UNICODE is a filename that can be encoded using the
+        # file system encoding, but *not* with the default (ascii) encoding
+        if isinstance('', unicode):
+            # python -U
+            # XXX perhaps unicode() should accept Unicode strings?
+            TESTFN_UNICODE = "@test-\xe0\xf2"
+        else:
+            # 2 latin characters.
+            TESTFN_UNICODE = unicode("@test-\xe0\xf2", "latin-1")
+        TESTFN_ENCODING = sys.getfilesystemencoding()
+        # TESTFN_UNICODE_UNENCODEABLE is a filename that should *not* be
+        # able to be encoded by *either* the default or filesystem encoding.
+        # This test really only makes sense on Windows NT platforms
+        # which have special Unicode support in posixmodule.
+        if (not hasattr(sys, "getwindowsversion") or
+                sys.getwindowsversion()[3] < 2): #  0=win32s or 1=9x/ME
+            TESTFN_UNICODE_UNENCODEABLE = None
+        else:
+            # Japanese characters (I think - from bug 846133)
+            TESTFN_UNICODE_UNENCODEABLE = eval('u"@test-\u5171\u6709\u3055\u308c\u308b"')
+            try:
+                # XXX - Note - should be using TESTFN_ENCODING here - but for
+                # Windows, "mbcs" currently always operates as if in
+                # errors=ignore' mode - hence we get '?' characters rather than
+                # the exception.  'Latin1' operates as we expect - ie, fails.
+                # See [ 850997 ] mbcs encoding ignores errors
+                TESTFN_UNICODE_UNENCODEABLE.encode("Latin1")
+            except UnicodeEncodeError:
+                pass
+            else:
+                print \
+                'WARNING: The filename %r CAN be encoded by the filesystem.  ' \
+                'Unicode filename tests may not be effective' \
+                % TESTFN_UNICODE_UNENCODEABLE
+
+# Make sure we can write to TESTFN, try in /tmp if we can't
+fp = None
+try:
+    fp = open(TESTFN, 'w+')
+except IOError:
+    TMP_TESTFN = os.path.join('/tmp', TESTFN)
+    try:
+        fp = open(TMP_TESTFN, 'w+')
+        TESTFN = TMP_TESTFN
+        del TMP_TESTFN
+    except IOError:
+        print ('WARNING: tests will fail, unable to write to: %s or %s' %
+                (TESTFN, TMP_TESTFN))
+if fp is not None:
+    fp.close()
+    unlink(TESTFN)
+del os, fp
+
+def findfile(file, here=__file__):
+    """Try to find a file on sys.path and the working directory.  If it is not
+    found the argument passed to the function is returned (this does not
+    necessarily signal failure; could still be the legitimate path)."""
+    import os
+    if os.path.isabs(file):
+        return file
+    path = sys.path
+    path = [os.path.dirname(here)] + path
+    for dn in path:
+        fn = os.path.join(dn, file)
+        if os.path.exists(fn): return fn
+    return file
+
+def verify(condition, reason='test failed'):
+    """Verify that condition is true. If not, raise TestFailed.
+
+       The optional argument reason can be given to provide
+       a better error text.
+    """
+
+    if not condition:
+        raise TestFailed(reason)
+
+def vereq(a, b):
+    """Raise TestFailed if a == b is false.
+
+    This is better than verify(a == b) because, in case of failure, the
+    error message incorporates repr(a) and repr(b) so you can see the
+    inputs.
+
+    Note that "not (a == b)" isn't necessarily the same as "a != b"; the
+    former is tested.
+    """
+
+    if not (a == b):
+        raise TestFailed, "%r == %r" % (a, b)
+
+def sortdict(dict):
+    "Like repr(dict), but in sorted order."
+    items = dict.items()
+    items.sort()
+    reprpairs = ["%r: %r" % pair for pair in items]
+    withcommas = ", ".join(reprpairs)
+    return "{%s}" % withcommas
+
+def check_syntax(statement):
+    try:
+        compile(statement, '<string>', 'exec')
+    except SyntaxError:
+        pass
+    else:
+        print 'Missing SyntaxError: "%s"' % statement
+
+def open_urlresource(url):
+    import urllib, urlparse
+    import os.path
+
+    filename = urlparse.urlparse(url)[2].split('/')[-1] # '/': it's URL!
+
+    for path in [os.path.curdir, os.path.pardir]:
+        fn = os.path.join(path, filename)
+        if os.path.exists(fn):
+            return open(fn)
+
+    requires('urlfetch')
+    print >> get_original_stdout(), '\tfetching %s ...' % url
+    fn, _ = urllib.urlretrieve(url, filename)
+    return open(fn)
+
+#=======================================================================
+# Decorator for running a function in a different locale, correctly resetting
+# it afterwards.
+
+def run_with_locale(catstr, *locales):
+    def decorator(func):
+        def inner(*args, **kwds):
+            try:
+                import locale
+                category = getattr(locale, catstr)
+                orig_locale = locale.setlocale(category)
+            except AttributeError:
+                # if the test author gives us an invalid category string
+                raise
+            except:
+                # cannot retrieve original locale, so do nothing
+                locale = orig_locale = None
+            else:
+                for loc in locales:
+                    try:
+                        locale.setlocale(category, loc)
+                        break
+                    except:
+                        pass
+
+            # now run the function, resetting the locale on exceptions
+            try:
+                return func(*args, **kwds)
+            finally:
+                if locale and orig_locale:
+                    locale.setlocale(category, orig_locale)
+        inner.func_name = func.func_name
+        inner.__doc__ = func.__doc__
+        return inner
+    return decorator
+
+#=======================================================================
+# Big-memory-test support. Separate from 'resources' because memory use should be configurable.
+
+# Some handy shorthands. Note that these are used for byte-limits as well
+# as size-limits, in the various bigmem tests
+_1M = 1024*1024
+_1G = 1024 * _1M
+_2G = 2 * _1G
+
+# Hack to get at the maximum value an internal index can take.
+class _Dummy:
+    def __getslice__(self, i, j):
+        return j
+MAX_Py_ssize_t = _Dummy()[:]
+
+def set_memlimit(limit):
+    import re
+    global max_memuse
+    sizes = {
+        'k': 1024,
+        'm': _1M,
+        'g': _1G,
+        't': 1024*_1G,
+    }
+    m = re.match(r'(\d+(\.\d+)?) (K|M|G|T)b?$', limit,
+                 re.IGNORECASE | re.VERBOSE)
+    if m is None:
+        raise ValueError('Invalid memory limit %r' % (limit,))
+    memlimit = int(float(m.group(1)) * sizes[m.group(3).lower()])
+    if memlimit > MAX_Py_ssize_t:
+        memlimit = MAX_Py_ssize_t
+    if memlimit < _2G - 1:
+        raise ValueError('Memory limit %r too low to be useful' % (limit,))
+    max_memuse = memlimit
+
+def bigmemtest(minsize, memuse, overhead=5*_1M):
+    """Decorator for bigmem tests.
+
+    'minsize' is the minimum useful size for the test (in arbitrary,
+    test-interpreted units.) 'memuse' is the number of 'bytes per size' for
+    the test, or a good estimate of it. 'overhead' specifies fixed overhead,
+    independant of the testsize, and defaults to 5Mb.
+
+    The decorator tries to guess a good value for 'size' and passes it to
+    the decorated test function. If minsize * memuse is more than the
+    allowed memory use (as defined by max_memuse), the test is skipped.
+    Otherwise, minsize is adjusted upward to use up to max_memuse.
+    """
+    def decorator(f):
+        def wrapper(self):
+            if not max_memuse:
+                # If max_memuse is 0 (the default),
+                # we still want to run the tests with size set to a few kb,
+                # to make sure they work. We still want to avoid using
+                # too much memory, though, but we do that noisily.
+                maxsize = 5147
+                self.failIf(maxsize * memuse + overhead > 20 * _1M)
+            else:
+                maxsize = int((max_memuse - overhead) / memuse)
+                if maxsize < minsize:
+                    # Really ought to print 'test skipped' or something
+                    if verbose:
+                        sys.stderr.write("Skipping %s because of memory "
+                                         "constraint\n" % (f.__name__,))
+                    return
+                # Try to keep some breathing room in memory use
+                maxsize = max(maxsize - 50 * _1M, minsize)
+            return f(self, maxsize)
+        wrapper.minsize = minsize
+        wrapper.memuse = memuse
+        wrapper.overhead = overhead
+        return wrapper
+    return decorator
+
+def bigaddrspacetest(f):
+    """Decorator for tests that fill the address space."""
+    def wrapper(self):
+        if max_memuse < MAX_Py_ssize_t:
+            if verbose:
+                sys.stderr.write("Skipping %s because of memory "
+                                 "constraint\n" % (f.__name__,))
+        else:
+            return f(self)
+    return wrapper
+
+#=======================================================================
+# Preliminary PyUNIT integration.
+
+import unittest
+
+
+class BasicTestRunner:
+    def run(self, test):
+        result = unittest.TestResult()
+        test(result)
+        return result
+
+
+def run_suite(suite, testclass=None):
+    """Run tests from a unittest.TestSuite-derived class."""
+    if verbose:
+        runner = unittest.TextTestRunner(sys.stdout, verbosity=2)
+    else:
+        runner = BasicTestRunner()
+
+    result = runner.run(suite)
+    if not result.wasSuccessful():
+        if len(result.errors) == 1 and not result.failures:
+            err = result.errors[0][1]
+        elif len(result.failures) == 1 and not result.errors:
+            err = result.failures[0][1]
+        else:
+            if testclass is None:
+                msg = "errors occurred; run in verbose mode for details"
+            else:
+                msg = "errors occurred in %s.%s" \
+                      % (testclass.__module__, testclass.__name__)
+            raise TestFailed(msg)
+        raise TestFailed(err)
+
+
+def run_unittest(*classes):
+    """Run tests from unittest.TestCase-derived classes."""
+    suite = unittest.TestSuite()
+    for cls in classes:
+        if isinstance(cls, (unittest.TestSuite, unittest.TestCase)):
+            suite.addTest(cls)
+        else:
+            suite.addTest(unittest.makeSuite(cls))
+    if len(classes)==1:
+        testclass = classes[0]
+    else:
+        testclass = None
+    run_suite(suite, testclass)
+
+
+#=======================================================================
+# doctest driver.
+
+def run_doctest(module, verbosity=None):
+    """Run doctest on the given module.  Return (#failures, #tests).
+
+    If optional argument verbosity is not specified (or is None), pass
+    test_support's belief about verbosity on to doctest.  Else doctest's
+    usual behavior is used (it searches sys.argv for -v).
+    """
+
+    import doctest
+
+    if verbosity is None:
+        verbosity = verbose
+    else:
+        verbosity = None
+
+    # Direct doctest output (normally just errors) to real stdout; doctest
+    # output shouldn't be compared by regrtest.
+    save_stdout = sys.stdout
+    sys.stdout = get_original_stdout()
+    try:
+        f, t = doctest.testmod(module, verbose=verbosity)
+        if f:
+            raise TestFailed("%d of %d doctests failed" % (f, t))
+    finally:
+        sys.stdout = save_stdout
+    if verbose:
+        print 'doctest (%s) ... %d tests with zero failures' % (module.__name__, t)
+    return f, t
+
+#=======================================================================
+# Threading support to prevent reporting refleaks when running regrtest.py -R
+
+def threading_setup():
+    import threading
+    return len(threading._active), len(threading._limbo)
+
+def threading_cleanup(num_active, num_limbo):
+    import threading
+    import time
+
+    _MAX_COUNT = 10
+    count = 0
+    while len(threading._active) != num_active and count < _MAX_COUNT:
+        print threading._active
+        count += 1
+        time.sleep(0.1)
+
+    count = 0
+    while len(threading._limbo) != num_limbo and count < _MAX_COUNT:
+        print threading._limbo
+        count += 1
+        time.sleep(0.1)
+
+def reap_children():
+    """Use this function at the end of test_main() whenever sub-processes
+    are started.  This will help ensure that no extra children (zombies)
+    stick around to hog resources and create problems when looking
+    for refleaks.
+    """
+
+    # Reap all our dead child processes so we don't leave zombies around.
+    # These hog resources and might be causing some of the buildbots to die.
+    import os
+    if hasattr(os, 'waitpid'):
+        any_process = -1
+        while True:
+            try:
+                # This will raise an exception on Windows.  That's ok.
+                pid, status = os.waitpid(any_process, os.WNOHANG)
+                if pid == 0:
+                    break
+            except:
+                break

tests/testrunner.py

+#!/usr/bin/env python
+"""Unit test runner.
+
+This test runner runs each test module isolated in a subprocess, thus allowing them to
+mangle globals freely (i.e. do monkey patching).
+
+To report the results and generate statistics sqlite3 database is used.
+
+Additionally, the subprocess is killed after a timeout has passed. The test case remains
+in the database logged with the result 'TIMEOUT'.
+
+The --db option, when provided, specifies sqlite3 database that holds the test results.
+By default 'testresults.sqlite3' is used in the current directory.
+If the a mercurial repository is detected and the current working copy is "dirty", that is,
+has uncommited changes, then '/tmp/testresults.sqlite3' is used.
+
+The results are stored in the following 2 tables:
+
+testcase:
+
+  runid   | test   | testcase        | result                 | time |
+  --------+--------+-----------------+------------------------+------+
+  abc123  | module | class.function  | PASSED|FAILED|TIMEOUT  | 0.01 |
+
+test:
+
+  runid   | test    | python | output | retcode | changeset   | uname | started_at |
+  --------+---------+--------+--------+---------+-------------+-------+------------+
+  abc123  | module  | 2.6.4  | ...    |       1 | 123_fe43ca+ | Linux |            |
+
+Set runid with --runid option. It must not exists in the database. The random
+one will be selected if not provided.
+"""
+
+# Known issues:
+# - screws up warnings location, causing them to appear as originated from testrunner.py
+
+# the number of seconds each test script is allowed to run
+DEFAULT_TIMEOUT = 60
+
+# the number of bytes of output that is recorded; the rest is thrown away
+OUTPUT_LIMIT = 15*1024
+
+ignore_tracebacks = ['ExpectedException', 'test_support.TestSkipped', 'test.test_support.TestSkipped']
+
+import sys
+import os
+import glob
+import re
+import traceback
+from unittest import _TextTestResult, defaultTestLoader, TextTestRunner
+import platform
+
+try:
+    import sqlite3
+except ImportError:
+    try:
+        import pysqlite2.dbapi2 as sqlite3
+    except ImportError:
+        sqlite3 = None
+
+_column_types = {'time': 'real'}
+
+
+def store_record(database_path, table, dictionary, _added_colums_per_db={}):
+    if sqlite3 is None:
+        return
+    conn = sqlite3.connect(database_path)
+    _added_columns = _added_colums_per_db.setdefault(database_path, set())
+    keys = dictionary.keys()
+    for key in keys:
+        if key not in _added_columns:
+            try:
+                sql = '''alter table %s add column %s %s''' % (table, key, _column_types.get(key))
+                conn.execute(sql)
+                conn.commit()
+                _added_columns.add(key)
+            except sqlite3.OperationalError, ex:
+                if 'duplicate column' not in str(ex).lower():
+                    raise
+    sql = 'insert or replace into %s (%s) values (%s)' % (table, ', '.join(keys), ', '.join(':%s' % key for key in keys))
+    cursor = conn.cursor()
+    try:
+        cursor.execute(sql, dictionary)
+    except sqlite3.Error:
+        print 'sql=%r\ndictionary=%r' % (sql, dictionary)
+        raise
+    conn.commit()
+    return cursor.lastrowid
+
+
+class DatabaseTestResult(_TextTestResult):
+    separator1 = '=' * 70
+    separator2 = '-' * 70
+
+    def __init__(self, database_path, runid, module_name, stream, descriptions, verbosity):
+        _TextTestResult.__init__(self, stream, descriptions, verbosity)
+        self.database_path = database_path
+        self.params = {'runid': runid,
+                       'test': module_name}
+
+    def startTest(self, test):
+        _TextTestResult.startTest(self, test)
+        self.params['testcase'] = test.id().replace('__main__.', '')
+        self.params['result'] = 'TIMEOUT'
+        row_id = store_record(self.database_path, 'testcase', self.params)
+        self.params['id'] = row_id
+        from time import time
+        self.time = time()
+
+    def _store_result(self, test, result):
+        self.params['result'] = result
+        from time import time
+        self.params['time'] = time() - self.time
+        store_record(self.database_path, 'testcase', self.params)
+        self.params.pop('id', None)
+
+    def addSuccess(self, test):
+        _TextTestResult.addSuccess(self, test)
+        self._store_result(test, 'PASSED')
+
+    def addError(self, test, err):
+        _TextTestResult.addError(self, test, err)
+        self._store_result(test, format_exc_info(err))
+
+    def addFailure(self, test, err):
+        _TextTestResult.addFailure(self, test, err)
+        self._store_result(test, format_exc_info(err))
+
+
+def format_exc_info(exc_info):
+    try:
+        return '%s: %s' % (exc_info[0].__name__, exc_info[1])
+    except Exception:
+        return str(exc_info[1]) or str(exc_info[0]) or 'FAILED'
+
+
+class DatabaseTestRunner(TextTestRunner):
+
+    def __init__(self, database_path, runid, module_name, stream=sys.stderr, descriptions=1, verbosity=1):
+        self.database_path = database_path
+        self.runid = runid
+        self.module_name = module_name
+        TextTestRunner.__init__(self, stream=stream, descriptions=descriptions, verbosity=verbosity)
+
+    def _makeResult(self):
+        return DatabaseTestResult(self.database_path, self.runid, self.module_name, self.stream, self.descriptions, self.verbosity)
+
+
+def get_changeset():
+    try:
+        diff = os.popen(r"hg diff 2> /dev/null").read().strip()
+    except Exception:
+        diff = None
+    try:
+        changeset = os.popen(r"hg log -r tip 2> /dev/null | grep changeset").readlines()[0]
+        changeset = changeset.replace('changeset:', '').strip().replace(':', '_')
+        if diff:
+            changeset += '+'
+    except Exception:
+        changeset = ''
+    return changeset
+
+
+def get_libevent_version():
+    from gevent import core
+    libevent_version = core.get_version()
+    if core.get_header_version() != core.get_version() and core.get_header_version() is not None:
+        libevent_version += '/headers=%s' % core.get_header_version()
+    return libevent_version
+
+
+def get_libevent_method():
+    from gevent import core
+    return core.get_method()
+
+
+def get_tempnam():
+    import warnings
+    warnings.filterwarnings('ignore', 'tempnam is a potential security risk to your program')
+    try:
+        tempnam = os.tempnam()
+    finally:
+        del warnings.filters[0]
+    return os.path.join(os.path.dirname(tempnam), 'testresults.sqlite3')
+
+
+def run_tests(options, args):
+    if len(args) != 1:
+        sys.exit('--record requires exactly one test module to run')
+    arg = args[0]
+    module_name = arg
+    if module_name.endswith('.py'):
+        module_name = module_name[:-3]
+    class _runner(object):
+        def __new__(cls, *args, **kawrgs):
+            return DatabaseTestRunner(database_path=options.db, runid=options.runid, module_name=module_name, verbosity=options.verbosity)
+    if options.db:
+        import unittest
+        unittest.TextTestRunner = _runner
+        import test_support
+        test_support.BasicTestRunner = _runner
+    if os.path.exists(arg):
+        sys.argv = args
+        saved_globals = {'__file__': __file__}
+        try:
+            globals()['__file__'] = arg
+            # QQQ this makes tests reported as if they are from __main__ and screws up warnings location
+            execfile(arg, globals())
+        finally:
+            globals().update(saved_globals)
+    else:
+        test = defaultTestLoader.loadTestsFromName(arg)
+        result = _runner().run(test)
+        sys.exit(not result.wasSuccessful())
+
+
+def run_subprocess(arg, options):
+    from threading import Timer
+    from mysubprocess import Popen, PIPE, STDOUT
+
+    popen_args = [sys.executable, sys.argv[0], '--record',
+                  '--runid', options.runid,
+                  '--verbosity', options.verbosity]
+    if options.db:
+        popen_args += ['--db', options.db]
+    popen_args += [arg]
+    popen_args = [str(x) for x in popen_args]
+    if options.capture:
+        popen = Popen(popen_args, stdout=PIPE, stderr=STDOUT, shell=False)
+    else:
+        popen = Popen(popen_args, shell=False)
+
+    retcode = []
+
+    def killer():
+        retcode.append('TIMEOUT')
+        print >> sys.stderr, 'Killing %s (%s) because of timeout' % (popen.pid, arg)
+        popen.kill()
+
+    timeout = Timer(options.timeout, killer)
+    timeout.start()
+    output = ''
+    output_printed = False
+    try:
+        try:
+            if options.capture:
+                while True:
+                    data = popen.stdout.read(1)
+                    if not data:
+                        break
+                    output += data
+                    if options.verbosity >= 2:
+                        sys.stdout.write(data)
+                        output_printed = True
+            retcode.append(popen.wait())
+        except Exception:
+            popen.kill()
+            raise
+    finally:
+        timeout.cancel()
+    # QQQ compensating for run_tests' screw up
+    module_name = arg
+    if module_name.endswith('.py'):
+        module_name = module_name[:-3]
+    output = output.replace(' (__main__.', ' (' + module_name + '.')
+    return retcode[0], output, output_printed
+
+
+def spawn_subprocess(arg, options, base_params):
+    success = False
+    if options.db:
+        module_name = arg
+        if module_name.endswith('.py'):
+            module_name = module_name[:-3]
+        from datetime import datetime
+        params = base_params.copy()
+        params.update({'started_at': datetime.now(),
+                       'test': module_name})
+        row_id = store_record(options.db, 'test', params)
+        params['id'] = row_id
+    retcode, output, output_printed = run_subprocess(arg, options)
+    if len(output) > OUTPUT_LIMIT:
+        output = output[:OUTPUT_LIMIT] + '<AbridgedOutputWarning>'
+    if retcode:
+        if retcode == 1 and 'test_support.TestSkipped' in output:
+            pass
+        else:
+            if not output_printed and options.verbosity >= -1:
+                sys.stdout.write(output)
+            print '%s failed with code %s' % (arg, retcode)
+    elif retcode == 0:
+        if not output_printed and options.verbosity >= 1:
+            sys.stdout.write(output)
+        if options.verbosity >= 0:
+            print '%s passed' % arg
+        success = True
+    else:
+        print '%s timed out' % arg
+    if options.db:
+        params['output'] = output
+        params['retcode'] = retcode
+        store_record(options.db, 'test', params)
+    return success
+
+
+def spawn_subprocesses(options, args):
+    params = {'runid': options.runid,
+              'python': '%s.%s.%s' % sys.version_info[:3],
+              'changeset': get_changeset(),
+              'libevent_version': get_libevent_version(),
+              'libevent_method': get_libevent_method(),
+              'uname': platform.uname()[0],
+              'retcode': 'TIMEOUT'}
+    success = True
+    if not args:
+        args = glob.glob('test_*.py')
+        args.remove('test_support.py')
+    for arg in args:
+        try:
+            success = spawn_subprocess(arg, options, params) and success
+        except Exception:
+            traceback.print_exc()
+    if options.db:
+        try:
+            print '-' * 80
+            if print_stats(options):
+                success = False
+        except sqlite3.OperationalError:
+            traceback.print_exc()
+        print 'To view stats again for this run, use %s --stats --runid %s --db %s' % (sys.argv[0], options.runid, options.db)
+    if not success:
+        sys.exit(1)
+
+
+def get_testcases(cursor, runid, result=None):
+    sql = 'select test, testcase from testcase where runid=?'
+    args = (runid, )
+    if result is not None:
+        sql += ' and result=?'
+        args += (result, )
+    return ['.'.join(x) for x in cursor.execute(sql, args).fetchall()]
+
+
+def get_failed_testcases(cursor, runid):
+    sql = 'select test, testcase, result from testcase where runid=?'
+    args = (runid, )
+    sql += ' and result!="PASSED" and result!="TIMEOUT"'
+    names = []
+    errors = {}
+    for test, testcase, result in cursor.execute(sql, args).fetchall():
+        name = '%s.%s' % (test, testcase)
+        names.append(name)
+        errors[name] = result
+    return names, errors
+
+
+_warning_re = re.compile('\w*warning', re.I)
+_error_re = re.compile(r'(?P<prefix>\s*)Traceback \(most recent call last\):' +
+                       r'(\n(?P=prefix)[ \t]+[^\n]*)+\n(?P=prefix)(?P<error>[\w\.]+)')
+
+
+def get_warnings(output):
+    """
+    >>> get_warnings('hello DeprecationWarning warning: bla DeprecationWarning')
+    ['DeprecationWarning', 'warning', 'DeprecationWarning']
+    """
+    if len(output) <= OUTPUT_LIMIT:
+        return _warning_re.findall(output)
+    else:
+        return _warning_re.findall(output[:OUTPUT_LIMIT]) + ['AbridgedOutputWarning']
+
+
+def get_exceptions(output):
+    """
+    >>> get_exceptions('''test$ python -c "1/0"
+    ... Traceback (most recent call last):
+    ...   File "<string>", line 1, in <module>
+    ... ZeroDivisionError: integer division or modulo by zero''')
+    ['ZeroDivisionError']
+    """
+    return [x.group('error') for x in _error_re.finditer(output)]
+
+
+def get_warning_stats(output):
+    counter = {}
+    for warning in get_warnings(output):
+        counter.setdefault(warning, 0)
+        counter[warning] += 1
+    items = counter.items()
+    items.sort(key=lambda (a, b): -b)
+    result = []
+    for name, count in items:
+        if count == 1:
+            result.append(name)
+        else:
+            result.append('%s %ss' % (count, name))
+    return result
+
+
+def get_ignored_tracebacks(test):
+    if os.path.exists(test + '.py'):
+        data = open(test + '.py').read()
+        m = re.search('Ignore tracebacks: (.*)', data)
+        if m is not None:
+            return m.group(1).split()
+    return []
+
+
+def get_traceback_stats(output, test):
+    ignored = get_ignored_tracebacks(test) or ignore_tracebacks
+    counter = {}
+    traceback_count = output.lower().count('Traceback (most recent call last)')
+    ignored_list = []
+    for error in get_exceptions(output):
+        if error in ignored:
+            ignored_list.append(error)
+        else:
+            counter.setdefault(error, 0)
+            counter[error] += 1
+        traceback_count -= 1
+    items = counter.items()
+    items.sort(key=lambda (a, b): -b)
+    if traceback_count>0:
+        items.append(('other traceback', traceback_count))
+    result = []
+    for name, count in items:
+        if count == 1:
+            result.append('1 %s' % name)
+        else:
+            result.append('%s %ss' % (count, name))
+    return result, ignored_list
+
+
+def get_info(output, test):
+    output = output[:OUTPUT_LIMIT*2]
+    traceback_stats, ignored_list = get_traceback_stats(output, test)
+    warning_stats = get_warning_stats(output)
+    result = traceback_stats + warning_stats
+    skipped = not warning_stats and not traceback_stats and ignored_list in [['test_support.TestSkipped'], ['test.test_support.TestSkipped']]
+    return ', '.join(result), skipped
+
+
+def print_stats(options):
+    db = sqlite3.connect(options.db)
+    cursor = db.cursor()
+    if options.runid is None:
+        options.runid = cursor.execute('select runid from test order by started_at desc limit 1').fetchall()[0][0]
+        print 'Using the latest runid: %s' % options.runid
+    total = len(get_testcases(cursor, options.runid))
+    failed, errors = get_failed_testcases(cursor, options.runid)
+    timedout = get_testcases(cursor, options.runid, 'TIMEOUT')
+    for test, output, retcode in cursor.execute('select test, output, retcode from test where runid=?', (options.runid, )):
+        info, skipped = get_info(output or '', test)
+        if info:
+            print '%s: %s' % (test, info)
+        if retcode == 'TIMEOUT':
+            for testcase in timedout:
+                if testcase.startswith(test + '.'):
+                    break
+            else:
+                timedout.append(test)
+                total += 1
+        elif retcode != 0:
+            for testcase in failed:
+                if testcase.startswith(test + '.'):
+                    break
+            else:
+                if not skipped:
+                    failed.append(test)
+                    total += 1
+    if failed:
+        failed.sort()
+        print 'FAILURES: '
+        for testcase in failed:
+            error = errors.get(testcase)
+            if error:
+                error = repr(error)[1:-1][:100]
+                print ' - %s: %s' % (testcase, error)
+            else:
+                print ' - %s' % (testcase, )
+    if timedout:
+        print 'TIMEOUTS: '
+        print ' - ' + '\n - '.join(timedout)
+    print '%s testcases passed; %s failed; %s timed out' % (total, len(failed), len(timedout))
+    if failed or timedout:
+        return True
+    return False
+
+
+def main():
+    import optparse
+    parser = optparse.OptionParser()
+    parser.add_option('-v', '--verbose', default=0, action='count')
+    parser.add_option('-q', '--quiet', default=0, action='count')
+    parser.add_option('--verbosity', default=0, type='int', help=optparse.SUPPRESS_HELP)
+    parser.add_option('--db')
+    parser.add_option('--runid')
+    parser.add_option('--record', default=False, action='store_true')
+    parser.add_option('--no-capture', dest='capture', default=True, action='store_false')
+    parser.add_option('--stats', default=False, action='store_true')
+    parser.add_option('--timeout', default=DEFAULT_TIMEOUT, type=float, metavar='SECONDS')
+
+    options, args = parser.parse_args()
+    options.verbosity += options.verbose - options.quiet
+
+    if not options.db and sqlite3:
+        if get_changeset().endswith('+'):
+            options.db = get_tempnam()
+        else:
+            options.db = 'testresults.sqlite3'
+        print 'Using the database: %s' % options.db
+    elif options.db and not sqlite3:
+        sys.exit('Cannot access the database %r: no sqlite3 module found.' % (options.db, ))
+
+    if options.db:
+        db = sqlite3.connect(options.db)
+        db.execute('create table if not exists test (id integer primary key autoincrement, runid text)')
+        db.execute('create table if not exists testcase (id integer primary key autoincrement, runid text)')
+        db.commit()
+
+    if options.stats:
+        print_stats(options)
+    else:
+        if not options.runid:
+            try:
+                import uuid
+                options.runid = str(uuid.uuid4())
+            except ImportError:
+                import random
+                options.runid = str(random.random())[2:]
+            print 'Generated runid: %s' % (options.runid, )
+        if options.record:
+            run_tests(options, args)
+        else:
+            spawn_subprocesses(options, args)
+
+
+if __name__ == '__main__':
+    main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.