Denis Bilenko avatar Denis Bilenko committed 76b0f9f

multiple fixes and cleanups

- chrome 16 now works
- HTTP proxies, like HaProxy, now work with hixie protocol
- sending messages is now safe from multiple greenlets (added write lock)
- in hixie, WebSocket-Location/Sec-WebSocket-Location is set properly, including the right scheme query string; fixes Safari
- geventwebsocket/__init__.py now contains WebSocketHandler; websockets are not imported there, since they are not usable without handler anyway
- errors at handshake are responded with "400 Bad Request" now
- handler: removed 'websocket_connection' attribute
- renamed WebSocketLegacy to WebSocketHixie
- renamed WebSocketVersion7 to WebSocketHybi
- remove 'compatibility_mode' flag
- when the client closes the connection using "close" frame, Closed() object is now returned
- close() method now has defaults arguments, so websocket.close() now works
- added 'wsgi.websocket_version' to environ (text string)
- all exception related to protocol are subclasses of WebSocketError which is a subclass of socket.error

Comments (0)

Files changed (3)

geventwebsocket/__init__.py

 version_info = (0, 3, 0, 'dev')
 __version__ =  ".".join(map(str, version_info))
 
-try:
-    from geventwebsocket.websocket import WebSocketVersion7, WebSocketLegacy
-except ImportError:
-    import traceback
-    traceback.print_exc()
+__all__ = ['WebSocketHandler']
+
+from geventwebsocket.handler import WebSocketHandler

geventwebsocket/handler.py

 import re
 import struct
 from hashlib import md5, sha1
+from socket import error as socket_error
+from urllib import quote
 
 from gevent.pywsgi import WSGIHandler
-from geventwebsocket import WebSocketVersion7, WebSocketLegacy
-
-
-class HandShakeError(ValueError):
-    """ Hand shake challenge can't be parsed """
-    pass
-
+from geventwebsocket.websocket import WebSocketHybi, WebSocketHixie
 
 
 class WebSocketHandler(WSGIHandler):
     """ Automatically upgrades the connection to websockets. """
 
     GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
-    SUPPORTED_VERSIONS = (7,8)
+    SUPPORTED_VERSIONS = ('13', '8', '7')
 
-    def __init__(self, *args, **kwargs):
-        self.websocket_connection = False
-        self.allowed_paths = []
+    def handle_one_response(self):
+        self.pre_start()
+        environ = self.environ
+        upgrade = environ.get('HTTP_UPGRADE', '').lower()
+        if upgrade == 'websocket':
+            connection = environ.get('HTTP_CONNECTION', '').lower()
+            if 'upgrade' in connection:
+                return self._handle_websocket()
+        return super(WebSocketHandler, self).handle_one_response()
 
-        for expression in kwargs.pop('allowed_paths', []):
-            if isinstance(expression, basestring):
-                self.allowed_paths.append(re.compile(expression))
-            else:
-                self.allowed_paths.append(expression)
+    def pre_start(self):
+        pass
 
-        super(WebSocketHandler, self).__init__(*args, **kwargs)
+    def _handle_websocket(self):
+        environ = self.environ
+        try:
+            try:
+                if environ.get("HTTP_SEC_WEBSOCKET_VERSION"):
+                    result = self._handle_hybi()
+                elif environ.get("HTTP_ORIGIN"):
+                    result = self._handle_hixie()
+            except:
+                self.close_connection = True
+                raise
+            self.result = []
+            if not result:
+                return
+            self.application(environ, None)
+            return []
+        finally:
+            self.log_request()
 
-    def handle_one_response(self, call_wsgi_app=True):
-        if self.environ.get("HTTP_ORIGIN"):
-            self._handle_one_legacy_response()
-        elif self.environ.get("HTTP_SEC_WEBSOCKET_VERSION"):
-            version = int(self.environ.get("HTTP_SEC_WEBSOCKET_VERSION"))
-            if version in self.SUPPORTED_VERSIONS:
-                if not self._handle_one_version7_response():
-                    return
-            else:
-                return
-        else:
-            # not a valid websocket request
-            return super(WebSocketHandler, self).handle_one_response()
+    def _handle_hybi(self):
+        environ = self.environ
+        version = environ.get("HTTP_SEC_WEBSOCKET_VERSION")
 
-        if call_wsgi_app:
-            return self.application(self.environ, self.start_response)
-        else:
+        environ['wsgi.websocket_version'] = 'hybi-%s' % version
+
+        if version not in self.SUPPORTED_VERSIONS:
+            self.log_error('400: Unsupported Version: %r', version)
+            self.respond('400 Unsupported Version', [('Sec-WebSocket-Version', '13, 8, 7')])
             return
 
-    def _close_connection(self, reason=None):
-        # based on gevent/pywsgi.py
-        # see http://pypi.python.org/pypi/gevent#downloads
-
-        if reason:
-            print "Closing the connection because %s!" % reason
-        if self.socket is not None:
-            try:
-                self.socket._sock.close()
-                self.socket.close()
-            except socket.error:
-                pass
-
-    def _handle_one_version7_response(self):
-        environ = self.environ
-
         protocol, version = self.request_version.split("/")
         key = environ.get("HTTP_SEC_WEBSOCKET_KEY")
 
         # check client handshake for validity
         if not environ.get("REQUEST_METHOD") == "GET":
             # 5.2.1 (1)
-            self._close_connection()
-            return False
+            self.respond('400 Bad Request')
+            return
         elif not protocol == "HTTP":
             # 5.2.1 (1)
-            self._close_connection()
-            return False
+            self.respond('400 Bad Request')
+            return
         elif float(version) < 1.1:
             # 5.2.1 (1)
-            self._close_connection()
-            return False
+            self.respond('400 Bad Request')
+            return
         # XXX: nobody seems to set SERVER_NAME correctly. check the spec
         #elif not environ.get("HTTP_HOST") == environ.get("SERVER_NAME"):
             # 5.2.1 (2)
-            #self._close_connection()
-            #return False
+            #self.respond('400 Bad Request')
+            #return
         elif not key:
             # 5.2.1 (3)
-            self._close_connection()
-            return False
+            self.log_error('400: HTTP_SEC_WEBSOCKET_KEY is missing from request')
+            self.respond('400 Bad Request')
+            return
         elif len(base64.b64decode(key)) != 16:
             # 5.2.1 (3)
-            self._close_connection()
-            return False
+            self.log_error('400: Invalid key: %r', key)
+            self.respond('400 Bad Request')
+            return
 
-        #TODO: compare Sec-WebSocket-Origin against self.allowed_paths
-
-        self.websocket_connection = True
-        self.websocket = WebSocketVersion7(self.socket, self.rfile, self.environ)
-        self.environ['wsgi.websocket'] = self.websocket
+        self.websocket = WebSocketHybi(self.rfile, environ)
+        environ['wsgi.websocket'] = self.websocket
 
         headers = [
             ("Upgrade", "websocket"),
             ("Connection", "Upgrade"),
             ("Sec-WebSocket-Accept", base64.b64encode(sha1(key + self.GUID).digest())),
         ]
-        self.start_response("101 Switching Protocols", headers)
+        self._send_reply("101 Switching Protocols", headers)
         return True
 
-    def _handle_one_legacy_response(self):
-        # In case the client doesn't want to initialize a WebSocket connection
-        # we will proceed with the default PyWSGI functionality.
+    def _handle_hixie(self):
+        environ = self.environ
+        assert "upgrade" in self.environ.get("HTTP_CONNECTION", "").lower()
 
-        if "upgrade" in self.environ.get("HTTP_CONNECTION", "").lower(). \
-             replace(" ", "").split(",") and \
-             "websocket" in self.environ.get("HTTP_UPGRADE").lower() and \
-             self.accept_upgrade():
-            self.websocket_connection = True
-        else:
-            print "NORMAL"
-            from pprint import pprint
-            pprint(self.environ)
-            return super(WebSocketHandler, self).handle_one_response()
+        self.websocket = WebSocketHixie(self.rfile, environ)
+        environ['wsgi.websocket'] = self.websocket
 
-        self.websocket = WebSocketLegacy(self.socket, self.rfile, self.environ)
-        self.environ['wsgi.websocket'] = self.websocket
+        key1 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY1')
+        key2 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY2')
 
-        # Detect the Websocket protocol
-        if "HTTP_SEC_WEBSOCKET_KEY1" in self.environ:
-            version = 76
-        else:
-            version = 75
+        if key1 is not None:
+            environ['wsgi.websocket_version'] = 'hixie-76'
+            if not key1:
+                self.log_error("400: SEC-WEBSOCKET-KEY1 header is empty")
+                self.respond('400 Bad Request')
+                return
+            if not key2:
+                self.log_error("400: SEC-WEBSOCKET-KEY2 header is missing or empty")
+                self.respond('400 Bad Request')
+                return
 
-        if version == 75:
+            part1 = self._get_key_value(key1)
+            part2 = self._get_key_value(key2)
+            if part1 is None or part2 is None:
+                self.respond('400 Bad Request')
+                return
+
             headers = [
                 ("Upgrade", "WebSocket"),
                 ("Connection", "Upgrade"),
-                ("WebSocket-Origin", self.websocket.origin),
-                ("WebSocket-Protocol", self.websocket.protocol),
-                ("WebSocket-Location", "ws://" + self.environ.get('HTTP_HOST') + self.websocket.path),
+                ("Sec-WebSocket-Location", reconstruct_url(environ)),
             ]
-            self.start_response("101 Web Socket Protocol Handshake", headers)
-        elif version == 76:
-            challenge = self._get_challenge()
+            if self.websocket.protocol is not None:
+                headers.append(("Sec-WebSocket-Protocol", self.websocket.protocol))
+            if self.websocket.origin:
+                headers.append(("Sec-WebSocket-Origin", self.websocket.origin))
+
+            self._send_reply("101 Web Socket Protocol Handshake", headers)
+
+            # This request should have 8 bytes of data in the body
+            key3 = self.rfile.read(8)
+
+            challenge = md5(struct.pack("!II", part1, part2) + key3).digest()
+
+            self.socket.sendall(challenge)
+            return True
+        else:
+            environ['wsgi.websocket_version'] = 'hixie-75'
             headers = [
                 ("Upgrade", "WebSocket"),
                 ("Connection", "Upgrade"),
-                ("Sec-WebSocket-Origin", self.websocket.origin),
-                ("Sec-WebSocket-Protocol", self.websocket.protocol),
-                ("Sec-WebSocket-Location", "ws://" + self.environ.get('HTTP_HOST') + self.websocket.path),
+                ("WebSocket-Location", reconstruct_url(environ)),
             ]
+            if self.websocket.protocol is not None:
+                headers.append(("WebSocket-Protocol", self.websocket.protocol))
+            if self.websocket.origin:
+                headers.append(("WebSocket-Origin", self.websocket.origin))
 
-            self.start_response("101 Web Socket Protocol Handshake", headers)
-            self.write(challenge)
-        else:
-            raise Exception("Version not supported")
+            self._send_reply("101 Web Socket Protocol Handshake", headers)
 
-    def accept_upgrade(self):
-        """
-        Returns True if request is allowed to be upgraded.
-        If self.allowed_paths is non-empty, self.environ['PATH_INFO'] will
-        be matched against each of the regular expressions.
-        """
+    def _send_reply(self, status, headers):
+        self.status = status
 
-        if self.allowed_paths:
-            path_info = self.environ.get('PATH_INFO', '')
+        towrite = []
+        towrite.append('%s %s\r\n' % (self.request_version, self.status))
 
-            for regexps in self.allowed_paths:
-                return regexps.match(path_info)
-        else:
-            return True
+        for header in headers:
+            towrite.append("%s: %s\r\n" % header)
 
-    def write(self, data):
-        if self.websocket_connection:
-            if data:
-                self.socket.sendall(data)
-            else:
-                raise Exception("No data to send")
-        else:
-            super(WebSocketHandler, self).write(data)
+        towrite.append("\r\n")
+        msg = ''.join(towrite)
+        self.socket.sendall(msg)
+        self.headers_sent = True
 
-    def start_response(self, status, headers, exc_info=None):
-        if self.websocket_connection:
-            self.status = status
-
-            towrite = []
-            towrite.append('%s %s\r\n' % (self.request_version, self.status))
-
-            for header in headers:
-                towrite.append("%s: %s\r\n" % header)
-
-            towrite.append("\r\n")
-            msg = ''.join(towrite)
-            self.socket.sendall(msg)
-            self.headers_sent = True
-        else:
-            super(WebSocketHandler, self).start_response(status, headers, exc_info)
+    def respond(self, status, headers=[]):
+        self.close_connection = True
+        self._send_reply(status, headers)
+        if self.socket is not None:
+            try:
+                self.socket._sock.close()
+                self.socket.close()
+            except socket_error:
+                pass
 
     def _get_key_value(self, key_value):
         key_number = int(re.sub("\\D", "", key_value))
         spaces = re.subn(" ", "", key_value)[1]
 
         if key_number % spaces != 0:
-            raise HandShakeError("key_number %d is not an intergral multiple of"
-                                 " spaces %d" % (key_number, spaces))
+            self.log_error("key_number %d is not an intergral multiple of spaces %d", key_number, spaces)
+        else:
+            return key_number / spaces
 
-        return key_number / spaces
 
-    def _get_challenge(self):
-        key1 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY1')
-        key2 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY2')
+def reconstruct_url(environ):
+    secure = environ['wsgi.url_scheme'] == 'https'
+    if secure:
+        url = 'wss://'
+    else:
+        url = 'ws://'
 
-        if not key1:
-            raise BadRequest("SEC-WEBSOCKET-KEY1 header is missing")
-        if not key2:
-            raise BadRequest("SEC-WEBSOCKET-KEY2 header is missing")
+    if environ.get('HTTP_HOST'):
+        url += environ['HTTP_HOST']
+    else:
+        url += environ['SERVER_NAME']
 
-        part1 = self._get_key_value(self.environ['HTTP_SEC_WEBSOCKET_KEY1'])
-        part2 = self._get_key_value(self.environ['HTTP_SEC_WEBSOCKET_KEY2'])
+        if secure:
+            if environ['SERVER_PORT'] != '443':
+               url += ':' + environ['SERVER_PORT']
+        else:
+            if environ['SERVER_PORT'] != '80':
+               url += ':' + environ['SERVER_PORT']
 
-        # This request should have 8 bytes of data in the body
-        key3 = self.rfile.read(8)
-
-        return md5(struct.pack("!II", part1, part2) + key3).digest()
-
-
-    def wait(self):
-        return self.websocket.wait()
-
-    def send(self, message):
-        return self.websocket.send(message)
+    url += quote(environ.get('SCRIPT_NAME', ''))
+    url += quote(environ.get('PATH_INFO', ''))
+    if environ.get('QUERY_STRING'):
+        url += '?' + environ['QUERY_STRING']
+    return url

geventwebsocket/websocket.py

+from socket import error as socket_error
 import struct
+from gevent.coros import Semaphore
 
-class WebSocket(object):
+
+class Closed(object):
+
+    def __init__(self, reason, message):
+        self.reason = reason
+        self.message = message
+
+    def __nonzero__(self):
+        return False
+
+    def __repr__(self):
+        return '%s(%r, %r)' % (self.__class__.__name__, self.reason, self.message)
+
+
+class WebSocketError(socket_error):
     pass
 
 
-class ProtocolException(Exception):
+class FrameTooLargeException(WebSocketError):
     pass
 
 
-class FrameTooLargeException(Exception):
-    pass
+class WebSocketHixie(object):
 
-
-class WebSocketLegacy(object):
-    def __init__(self, sock, rfile, environ):
-        self.rfile = rfile
-        self.socket = sock
+    def __init__(self, fobj, environ):
         self.origin = environ.get('HTTP_ORIGIN')
         self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL')
         self.path = environ.get('PATH_INFO')
-        self.websocket_closed = False
+        self._writelock = Semaphore(1)
+        self.fobj = fobj
+        self._write = _get_write(fobj)
 
     def send(self, message):
-        if self.websocket_closed:
-            raise Exception("Connection was terminated")
-
         if isinstance(message, unicode):
             message = message.encode('utf-8')
         elif isinstance(message, str):
         else:
             raise TypeError("Invalid message encoding")
 
-        self.socket.sendall("\x00" + message + "\xFF")
+        with self._writelock:
+            self._write("\x00" + message + "\xFF")
 
-    def close_connection(self):
-        if not self.websocket_closed:
-            self.websocket_closed = True
-            self.socket.shutdown(True)
-            self.socket.close()
-        else:
-            return
+    def close(self):
+        if self.fobj is not None:
+            self.fobj.close()
+            self.fobj = None
 
     def _message_length(self):
         length = 0
 
         while True:
-            byte_str = self.rfile.read(1)
+            if self.fobj is None:
+                raise WebSocketError('Connenction closed unexpectedly while reading message length')
+            byte_str = self.fobj.read(1)
 
             if not byte_str:
                 return 0
     def _read_until(self):
         bytes = []
 
+        read = self.fobj.read
+
         while True:
-            byte = self.rfile.read(1)
+            if self.fobj is None:
+                msg = ''.join(bytes)
+                raise WebSocketError('Connenction closed unexpectedly while reading message: %r' % msg)
+            byte = read(1)
             if ord(byte) != 0xff:
                 bytes.append(byte)
             else:
         return ''.join(bytes)
 
     def receive(self):
-        while True:
-            if self.websocket_closed:
-                return None
-
-            frame_str = self.rfile.read(1)
+        read = self.fobj.read
+        while self.fobj is not None:
+            frame_str = read(1)
             if not frame_str:
-                # Connection lost?
-                self.websocket_closed = True
-                continue
+                self.close()
+                return
             else:
                 frame_type = ord(frame_str)
 
+            if frame_type == 0x00:
+                bytes = self._read_until()
+                return bytes.decode("utf-8", "replace")
+            else:
+                raise WebSocketError("Received an invalid frame_type=%r" % frame_type)
 
-            if (frame_type & 0x80) == 0x00: # most significant byte is not set
 
-                if frame_type == 0x00:
-                    bytes = self._read_until()
-                    return bytes.decode("utf-8", "replace")
-                else:
-                    self.websocket_closed = True
-
-            elif (frame_type & 0x80) == 0x80: # most significant byte is set
-                # Read binary data (forward-compatibility)
-                if frame_type != 0xff:
-                    self.websocket_closed = True
-                else:
-                    length = self._message_length()
-                    if length == 0:
-                        self.websocket_closed = True
-                    else:
-                        self.rfile.read(length) # discard the bytes
-            else:
-                raise IOError("Reveiced an invalid message")
-
-
-class WebSocketVersion7(WebSocketLegacy):
+class WebSocketHybi(object):
     FIN = int("10000000", 2)
     RSV = int("01110000", 2)
     OPCODE = int("00001111", 2)
     LEN_16 = 126
     LEN_64 = 127
 
-    def __init__(self, sock, rfile, environ, compatibility_mode=True):
-        self.rfile = rfile
-        self.socket = sock
+    def __init__(self, fobj, environ):
         self.origin = environ.get('HTTP_SEC_WEBSOCKET_ORIGIN')
         self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'unknown')
         self.path = environ.get('PATH_INFO')
-        self.websocket_closed = False
         self._chunks = bytearray()
         self._first_opcode = None
+        self._writelock = Semaphore(1)
+        self.fobj = fobj
+        self._write = _get_write(fobj)
 
-    def _read_from_socket(self, count):
-        return self.rfile.read(count)
+    def _parse_header(self, data):
+        if len(data) != 2:
+            self.close()
+            raise WebSocketError('Incomplete read while reading header: %r' % data)
+        first_byte, second_byte = struct.unpack('!BB', data)
+
+        fin = (first_byte >> 7) & 1
+        rsv1 = (first_byte >> 6) & 1
+        rsv2 = (first_byte >> 5) & 1
+        rsv3 = (first_byte >> 4) & 1
+        opcode = first_byte & 0xf
+
+        # frame-fin = %x0 ; more frames of this message follow
+        #           / %x1 ; final frame of this message
+
+        # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
+        # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
+        # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
+        if rsv1 or rsv2 or rsv3:
+            self.close()
+            raise WebSocketError('Reserved bits cannot be set: %r' % data)
+
+        #if self._is_invalid_opcode(opcode):
+        #    raise WebSocketError('Invalid opcode %x' % opcode)
+
+        # control frames cannot be fragmented
+        if opcode > 0x7 and fin == 0:
+            self.close()
+            raise WebSocketError('Control frames cannot be fragmented: %r' % data)
+
+        if len(self._chunks) > 0 and fin == 0 and opcode != self.OPCODE_CONTINUATION:
+            self.close(self.REASON_PROTOCOL_ERROR, 'Received new fragment frame with non-zero opcode')
+            raise WebSocketError('Received new fragment frame with non-zero opcode: %r' % data)
+
+        if len(self._chunks) > 0 and fin == 1 and (self.OPCODE_TEXT <= opcode <= self.OPCODE_BINARY):
+            self.close(self.REASON_PROTOCOL_ERROR, 'Received new unfragmented data frame during fragmented message')
+            raise WebSocketError('Received new unfragmented data frame during fragmented message: %r' % data)
+
+        mask = (second_byte >> 7) & 1
+        length = (second_byte) & 0x7f
+
+        #if not self.MASK & length_octet: # TODO: where is this in the docs?
+        #    self.close(self.REASON_PROTOCOL_ERROR, 'MASK must be set')
+
+        # Control frames MUST have a payload length of 125 bytes or less
+        if opcode > 0x7 and length > 125:
+            self.close()
+            raise FrameTooLargeException("Control frame payload cannot be larger than 125 bytes: %r" % data)
+
+        return fin, opcode, mask, length
 
     def receive(self):
-        """Return the next frame from the socket
+        """Return the next frame from the socket."""
+        if self.fobj is None:
+            return
 
-        If the next frame is invalid, wait closes the socket and returns None.
-
-        If the next frame is valid and the websocket instance's
-        compatibility_mode attribute is True, then wait ignores PING and PONG
-        frames, returns None when sent a CLOSE frame and returns the payload
-        for data frames.
-
-        If the next frame is valid and the websocket instance's
-        compatibility_mode attribute is False, it returns a tuple of the form
-        (opcode, payload).
-        """
+        read = self.fobj.read
 
         while True:
-            if self.websocket_closed:
-                return None
-
-            payload = ""
-            first_byte, second_byte = struct.unpack('!BB', self._read_from_socket(2))
-
-            fin = (first_byte >> 7) & 1
-            rsv1 = (first_byte >> 6) & 1
-            rsv2 = (first_byte >> 5) & 1
-            rsv3 = (first_byte >> 4) & 1
-            opcode = first_byte & 0xf
-
-            # frame-fin = %x0 ; more frames of this message follow
-            #           / %x1 ; final frame of this message
-            if fin not in (0, 1):
-                raise ProtocolException("")
-
-            # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
-            # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
-            # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
-            if rsv1 or rsv2 or rsv3:
-                raise ProtocolException('Reserved bits cannot be set')
-
-            #if self._is_invalid_opcode(opcode):
-            #    raise ProtocolException('Invalid opcode %x' % opcode)
-
-            # control frames cannot be fragmented
-            if opcode > 0x7 and fin == 0:
-                raise ProtocolException('Control frames cannot be fragmented')
-
-            if len(self._chunks) > 0 and \
-                    fin == 0 and opcode != self.OPCODE_CONTINUATION:
-                self.close(self.REASON_PROTOCOL_ERROR,
-                        'Received new fragment frame with non-zero opcode')
+            data0 = read(2)
+            if not data0:
+                self._close()
                 return
 
-            if len(self._chunks) > 0 and \
-                    fin == 1 and (self.OPCODE_TEXT <= opcode <= self.OPCODE_BINARY):
-                self.close(self.REASON_PROTOCOL_ERROR,
-                        'Received new unfragmented data frame during fragmented message')
+            fin, opcode, mask, length = self._parse_header(data0)
 
-            mask = (second_byte >> 7) & 1
-            payload_length = (second_byte) & 0x7f
-
-            #if not self.MASK & length_octet: # TODO: where is this in the docs?
-            #    self.close(self.REASON_PROTOCOL_ERROR, 'MASK must be set')
-
-            # Control frames MUST have a payload length of 125 bytes or less
-            if opcode > 0x7 and payload_length > 125:
-                raise FrameTooLargeException("Control frame payload cannot be larger than 125 bytes")
-
-            if payload_length < 126:
-                length = payload_length
-            elif payload_length == 126:
-                length = struct.unpack('!H', self._read_from_socket(2))[0]
-            elif payload_length == 127:
-                length = struct.unpack('!Q', self._read_from_socket(8))[0]
+            if length < 126:
+                data1 = ''
+            elif length == 126:
+                data1 = read(2)
+                if len(data1) != 2:
+                    self.close()
+                    raise WebSocketError('Incomplete read while reading 2-byte length: %r' % (data0 + data1))
+                length = struct.unpack('!H', data1)[0]
+            elif length == 127:
+                data1 = read(8)
+                if len(data1) != 8:
+                    self.close()
+                    raise WebSocketError('Incomplete read while reading 8-byte length: %r' % (data0 + data1))
+                length = struct.unpack('!Q', data1)[0]
             else:
-                raise ProtocolException('Calculated invalid length')
-
-            payload = ""
+                self.close()
+                raise WebSocketError('Invalid length: %r' % data0)
 
             # Unmask the payload if necessary
+            if mask and length:
+                data2 = read(4)
+                if len(data2) != 4:
+                    self.close()
+                    raise WebSocketError('Incomplete read while reading mask: %r' % (data0 + data1 + data2))
+                masking_key = struct.unpack('!BBBB', data2)
+            else:
+                data2 = ''
+
+            if length:
+                payload = read(length)
+                if len(payload) != length:
+                    self.close()
+                    args = (length, data0 + data1 + data2, payload)
+                    raise WebSocketError('Incomplete read (expected message of %s bytes): %r %r' % args)
+            else:
+                payload = ''
+
             if mask:
-                masking_key = struct.unpack('!BBBB', self._read_from_socket(4))
-                masked_payload = self._read_from_socket(length)
-
-                masked_payload = bytearray(masked_payload)
+                # XXX message from client actually should always be masked
+                masked_payload = bytearray(payload)
 
                 for i in range(len(masked_payload)):
                     masked_payload[i] = masked_payload[i] ^ masking_key[i%4]
 
                 payload = masked_payload
 
-            # Read application data
             if opcode == self.OPCODE_TEXT:
                 self._first_opcode = opcode
-                self._chunks.extend(payload)
-
+                if payload:
+                    # XXX given that we have OPCODE_CONTINUATION, shouldn't we just reset _chunks here?
+                    self._chunks.extend(payload)
             elif opcode == self.OPCODE_BINARY:
                 self._first_opcode = opcode
+                if payload:
+                    self._chunks.extend(payload)
+            elif opcode == self.OPCODE_CONTINUATION:
                 self._chunks.extend(payload)
-
-            elif opcode == self.OPCODE_CONTINUATION:
-                if len(self._chunks) != 0:
-                    raise ProtocolException("Cannot continue a non started message")
-
-                self._chunks.extend(payload)
-
             elif opcode == self.OPCODE_CLOSE:
                 if length >= 2:
-                    reason, message = struct.unpack('!H%ds' % (length - 2), payload)
+                    reason, message = struct.unpack('!H%ds' % (length - 2), buffer(payload))
                 else:
                     reason = message = None
-
                 self.close(self.REASON_NORMAL, '')
-                if not self.compatibility_mode:
-                    return (self.OPCODE_CLOSE, (reason, message))
-                else:
-                    return None
-
+                return Closed(reason, message)
             elif opcode == self.OPCODE_PING:
                 self.send(payload, opcode=self.OPCODE_PONG)
-
-                if not self.compatibility_mode:
-                    return (self.OPCODE_PING, payload)
-                else:
-                    continue
-
+                continue
             elif opcode == self.OPCODE_PONG:
-                if not self.compatibility_mode:
-                    return (self.OPCODE_PONG, payload)
-                else:
-                    continue
+                continue
             else:
-                raise Exception("Shouldn't happen")
+                self.close()
+                raise WebSocketError("Unexpected opcode=%r" % (opcode, ))
 
             if fin == 1:
                 if self._first_opcode == self.OPCODE_TEXT:
 
                 return msg
 
-
     def _encode_text(self, s):
         if isinstance(s, unicode):
             return s.encode('utf-8')
         return opcode in (self.OPCODE_CONTINUATION, self.OPCODE_TEXT, self.OPCODE_BINARY,
             self.OPCODE_CLOSE, self.OPCODE_PING, self.OPCODE_PONG)
 
-
     def send(self, message, opcode=OPCODE_TEXT):
         """Send a frame over the websocket with message as its payload
 
         opcode -- the opcode to use (default OPCODE_TEXT)
         """
 
-        if self.websocket_closed:
-            raise Exception('Connection was terminated')
-
         if not self._is_valid_opcode(opcode):
-            raise Exception('Invalid opcode %d' % opcode)
+            raise ValueError('Invalid opcode %d' % opcode)
 
         if opcode == self.OPCODE_TEXT:
             message = self._encode_text(message)
 
-        # TODO: implement masking
         # TODO: implement fragmented messages
         mask_bit = 0
         fin = 1
-        masking_key = None
 
         ## +-+-+-+-+-------+
         ## |F|R|R|R| opcode|
         else:
             raise FrameTooLargeException()
 
-        if masking_key:
-            self.socket.sendall(str(header + masking_key + mask(message))) # TODO: implement
-        else:
-            self.socket.sendall(header + message)
+        with self._writelock:
+            self._write(header + message)
 
+    def close(self, reason=1000, message=''):
+        """Close the websocket, sending the specified reason and message"""
+        if self.fobj is not None:
+            message = self._encode_text(message)
+            self.send(struct.pack('!H%ds' % len(message), reason, message), opcode=self.OPCODE_CLOSE)
+            self.fobj.close()
+            self.fobj = None
 
-    def close(self, reason, message):
-        """Close the websocket, sending the specified reason and message"""
+    def _close(self):
+        if self.fobj is not None:
+            self.fobj.close()
+            self.fobj = None
 
-        message = self._encode_text(message)
-        self.send(struct.pack('!H%ds' % len(message), reason, message), opcode=self.OPCODE_CLOSE)
-        self.websocket_closed = True
 
-        # based on gevent/pywsgi.py
-        # see http://pypi.python.org/pypi/gevent#downloads
-        if self.socket is not None:
-            try:
-                self.socket._sock.close()
-                self.socket.close()
-            except socket.error:
-                pass
+class write_method(object):
+
+    def __init__(self, fobj):
+        self.fobj = fobj
+
+    def __call__(self, data):
+        return self.fobj.write(data)
+
+
+def _get_write(fobj):
+    flush = getattr(fobj, 'flush', None)
+    if flush is not None:
+        flush()
+    sock = getattr(fobj, '_sock', None)
+    if sock is not None:
+        sendall = getattr(sock, 'sendall', None)
+        if sendall is not None:
+            return sendall
+    write = getattr(fobj, 'write', None)
+    if write is not None:
+        return write
+    return write_method(fobj)
+
+
+# XXX avoid small recv()s ?
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.