Jeffrey Gelens avatar Jeffrey Gelens committed 6e716a4 Merge

Merged from stable

Comments (0)

Files changed (3)

geventwebsocket/handler.py

     def _handle_websocket(self):
         environ = self.environ
         try:
-            try:
-                if environ.get("HTTP_SEC_WEBSOCKET_VERSION"):
-                    result = self._handle_hybi()
-                elif environ.get("HTTP_ORIGIN"):
-                    result = self._handle_hixie()
-            except:
+            if environ.get("HTTP_SEC_WEBSOCKET_VERSION"):
                 self.close_connection = True
-                raise
+                result = self._handle_hybi()
+            elif environ.get("HTTP_ORIGIN"):
+                self.close_connection = True
+                result = self._handle_hixie()
             self.result = []
             if not result:
                 return
             self.respond('400 Bad Request')
             return
 
-        self.websocket = WebSocketHybi(self.rfile, environ)
+        self.websocket = WebSocketHybi(self.socket, environ)
         environ['wsgi.websocket'] = self.websocket
 
         headers = [
         environ = self.environ
         assert "upgrade" in self.environ.get("HTTP_CONNECTION", "").lower()
 
-        self.websocket = WebSocketHixie(self.rfile, environ)
+        self.websocket = WebSocketHixie(self.socket, environ)
         environ['wsgi.websocket'] = self.websocket
 
         key1 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY1')

geventwebsocket/websocket.py

+import sys
 from socket import error as socket_error
+from errno import EINTR
 import struct
 from gevent.coros import Semaphore
 
 
+if sys.version_info[:2] == (2, 7):
+    # Python 2.7 has a working BufferedReader but socket.makefile() does not use it
+    # Python 2.6's BufferedReader is broken (TypeError: recv_into() argument 1 must be pinned buffer, not bytearray)
+    from io import BufferedReader, RawIOBase
+
+    class SocketIO(RawIOBase):
+
+        def __init__(self, sock):
+            RawIOBase.__init__(self)
+            self._sock = sock
+
+        def readinto(self, b):
+            self._checkClosed()
+            while True:
+                try:
+                    return self._sock.recv_into(b)
+                except socket_error as ex:
+                    if ex.args[0] == EINTR:
+                        continue
+                    raise
+
+        def readable(self):
+            return self._sock is not None
+
+        @property
+        def closed(self):
+            return self._sock is None
+
+        def fileno(self):
+            self._checkClosed()
+            return self._sock.fileno()
+
+        @property
+        def name(self):
+            if not self.closed:
+                return self.fileno()
+            else:
+                return -1
+
+        def close(self):
+            if self.closed:
+                return
+            RawIOBase.close(self)
+            self._sock = None
+
+    def makefile(socket):
+        return BufferedReader(SocketIO(socket))
+
+else:
+
+    def makefile(socket):
+        # XXX on python3 enable buffering
+        return socket.makefile()
+
+
+if sys.version_info[:2] < (2, 7):
+
+    def is_closed(fobj):
+        return fobj._sock is None
+
+else:
+
+    def is_closed(fobj):
+        return fobj.closed
+
+
 class WebSocketError(socket_error):
     pass
 
 
 class WebSocketHixie(WebSocket):
 
-    def __init__(self, fobj, environ):
+    def __init__(self, socket, environ):
         self.origin = environ.get('HTTP_ORIGIN')
         self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL')
         self.path = environ.get('PATH_INFO')
         self._writelock = Semaphore(1)
-        self.fobj = fobj
-        self._write = _get_write(fobj)
+        self.fobj = socket.makefile()
+        self._write = socket.sendall
 
     def send(self, message):
         message = self._encode_text(message)
         if self.fobj is not None:
             self.fobj.close()
             self.fobj = None
+            self._write = None
 
     def _message_length(self):
         length = 0
     OPCODE_PING = 0x9
     OPCODE_PONG = 0xA
 
-    def __init__(self, fobj, environ):
+    def __init__(self, socket, environ):
         self.origin = environ.get('HTTP_SEC_WEBSOCKET_ORIGIN')
         self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'unknown')
         self.path = environ.get('PATH_INFO')
         self._chunks = bytearray()
         self._writelock = Semaphore(1)
-        self.fobj = fobj
-        self._write = _get_write(fobj)
+        self.socket = socket
+        self._write = socket.sendall
+        self.fobj = makefile(socket)
         self.close_code = None
         self.close_message = None
+        self._reading = False
 
     def _parse_header(self, data):
         if len(data) != 2:
-            self.close()
+            self._close()
             raise WebSocketError('Incomplete read while reading header: %r' % data)
 
         first_byte, second_byte = struct.unpack('!BB', data)
         # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
         if rsv1 or rsv2 or rsv3:
             self.close(1002)
-            raise WebSocketError('Received frame with non-zero reserved bits: %r' % data)
+            raise WebSocketError('Received frame with non-zero reserved bits: %r' % str(data))
 
         if opcode > 0x7 and fin == 0:
             self.close(1002)
-            raise WebSocketError('Received fragmented control frame: %r' % data)
+            raise WebSocketError('Received fragmented control frame: %r' % str(data))
 
         if len(self._chunks) > 0 and fin == 0 and not opcode:
             self.close(1002)
-            raise WebSocketError('Received new fragment frame with non-zero opcode: %r' % data)
+            raise WebSocketError('Received new fragment frame with non-zero opcode: %r' % str(data))
 
         if len(self._chunks) > 0 and fin == 1 and (self.OPCODE_TEXT <= opcode <= self.OPCODE_BINARY):
             self.close(1002)
-            raise WebSocketError('Received new unfragmented data frame during fragmented message: %r' % data)
+            raise WebSocketError('Received new unfragmented data frame during fragmented message: %r' % str(data))
 
         has_mask = (second_byte >> 7) & 1
         length = (second_byte) & 0x7f
         # Control frames MUST have a payload length of 125 bytes or less
         if opcode > 0x7 and length > 125:
             self.close(1002)
-            raise FrameTooLargeException("Control frame payload cannot be larger than 125 bytes: %r" % data)
+            raise FrameTooLargeException("Control frame payload cannot be larger than 125 bytes: %r" % str(data))
 
         return fin, opcode, has_mask, length
 
     def receive_frame(self):
         """Return the next frame from the socket."""
-        if self.fobj is None:
+        fobj = self.fobj
+
+        if fobj is None:
+            return
+
+        if is_closed(fobj):
             return
 
         read = self.fobj.read
 
-        data0 = read(2)
-        if not data0:
-            self._close()
-            return
+        assert not self._reading, 'Reading is not possible from multiple greenlets'
+        self._reading = True
+        try:
+            data0 = read(2)
+            if not data0:
+                self._close()
+                return
 
-        fin, opcode, has_mask, length = self._parse_header(data0)
+            fin, opcode, has_mask, length = self._parse_header(data0)
 
-        if length < 126:
-            data1 = ''
-        elif length == 126:
-            data1 = read(2)
-            if len(data1) != 2:
-                self.close()
-                raise WebSocketError('Incomplete read while reading 2-byte length: %r' % (data0 + data1))
-            length = struct.unpack('!H', data1)[0]
-        elif length == 127:
-            data1 = read(8)
-            if len(data1) != 8:
-                self.close()
-                raise WebSocketError('Incomplete read while reading 8-byte length: %r' % (data0 + data1))
-            length = struct.unpack('!Q', data1)[0]
-        else:
-            self.close()
-            raise WebSocketError('Invalid length: %r' % data0)
+            if not has_mask and length:
+                self.close(1002)
+                raise WebSocketError('Message from client is not masked')
 
-        if has_mask:
-            data2 = read(4)
-            if len(data2) != 4:
-                self.close()
-                raise WebSocketError('Incomplete read while reading mask: %r' % (data0 + data1 + data2))
-            masking_key = struct.unpack('!BBBB', data2)
-        else:
-            data2 = ''
+            if length < 126:
+                data1 = ''
+            elif length == 126:
+                data1 = read(2)
+                if len(data1) != 2:
+                    self.close()
+                    raise WebSocketError('Incomplete read while reading 2-byte length: %r' % (data0 + data1))
+                length = struct.unpack('!H', data1)[0]
+            else:
+                assert length == 127, length
+                data1 = read(8)
+                if len(data1) != 8:
+                    self.close()
+                    raise WebSocketError('Incomplete read while reading 8-byte length: %r' % (data0 + data1))
+                length = struct.unpack('!Q', data1)[0]
 
-        if length:
-            payload = read(length)
-            if len(payload) != length:
-                self.close()
-                args = (length, len(payload))
-                raise WebSocketError('Incomplete read: expected message of %s bytes, got %s bytes' % args)
-        else:
-            payload = ''
+            mask = read(4)
+            if len(mask) != 4:
+                self._close()
+                raise WebSocketError('Incomplete read while reading mask: %r' % (data0 + data1 + mask))
 
-        if has_mask and payload:
-            # XXX message from client actually should always be masked
-            masked_payload = bytearray(payload)
+            mask = struct.unpack('!BBBB', mask)
 
-            for i in range(len(masked_payload)):
-                masked_payload[i] = masked_payload[i] ^ masking_key[i % 4]
+            if length:
+                payload = read(length)
+                if len(payload) != length:
+                    self._close()
+                    args = (length, len(payload))
+                    raise WebSocketError('Incomplete read: expected message of %s bytes, got %s bytes' % args)
+            else:
+                payload = ''
 
-            payload = masked_payload
+            if payload:
+                payload = bytearray(payload)
+                for i in xrange(len(payload)):
+                    payload[i] = payload[i] ^ mask[i % 4]
 
-        return fin, opcode, payload
+            return fin, opcode, payload
+        finally:
+            self._reading = False
+            if self.fobj is None:
+                fobj.close()
 
     def _receive(self):
         """Return the next text or binary message from the socket."""
         else:
             raise FrameTooLargeException()
 
-        with self._writelock:
-            self._write(header + message)
+        try:
+            combined = header + message
+        except TypeError:
+            with self._writelock:
+                self._write(header)
+                self._write(message)
+        else:
+            with self._writelock:
+                self._write(combined)
 
     def send(self, message, binary=None):
         """Send a frame over the websocket with message as its payload"""
 
     def close(self, code=1000, message=''):
         """Close the websocket, sending the specified code and message"""
-        if self.fobj is not None:
+        if self.socket is not None:
             message = self._encode_text(message)
             self.send_frame(struct.pack('!H%ds' % len(message), code, message), opcode=self.OPCODE_CLOSE)
-            self.fobj.close()
-            self.fobj = None
+            self._close()
 
     def _close(self):
-        if self.fobj is not None:
-            self.fobj.close()
+        if self.socket is not None:
+            self.socket._sock.close()
+            self.socket = None
+            self._write = None
+            fobj = self.fobj
             self.fobj = None
-
-
-class write_method(object):
-
-    def __init__(self, fobj):
-        self.fobj = fobj
-
-    def __call__(self, data):
-        return self.fobj.write(data)
-
-
-def _get_write(fobj):
-    flush = getattr(fobj, 'flush', None)
-    if flush is not None:
-        flush()
-    sock = getattr(fobj, '_sock', None)
-    if sock is not None:
-        sendall = getattr(sock, 'sendall', None)
-        if sendall is not None:
-            return sendall
-    write = getattr(fobj, 'write', None)
-    if write is not None:
-        return write
-    return write_method(fobj)
-
-
-# XXX avoid small recv()s ?
+            if not self._reading:
+                fobj.close()

run_autobahn_tests.py

 spec = {
    "options": {"failByDrop": False},
    "enable-ssl": False,
-   "servers": [],
-   "cases": ["*"],
-   "exclude-cases": ["7.5.1",
-                     "7.9.3",
-                     "7.9.4",
-                     "7.9.5",
-                     "7.9.6",
-                     "7.9.7",
-                     "7.9.8",
-                     "7.9.9",
-                     "7.9.10",
-                     "7.9.11",
-                     "7.9.12",
-                     "7.9.13"]
-}
+   "servers": []}
+
+
+default_args = ["*",
+         "x7.5.1",
+         "x7.9.3",
+         "x7.9.4",
+         "x7.9.5",
+         "x7.9.6",
+         "x7.9.7",
+         "x7.9.8",
+         "x7.9.9",
+         "x7.9.10",
+         "x7.9.11",
+         "x7.9.12",
+         "x7.9.13"]
 # We ignore 7.5.1 because it checks that close frame has valid utf-8 message
 # we do not validate utf-8.
 
     parser.add_option('--geventwebsocket', default='examples/echoserver.py')
     parser.add_option('--autobahn', default='../../src/Autobahn/testsuite/websockets/servers/test_autobahn.py')
     options, args = parser.parse_args()
-    assert not args, args
+
+    cases = []
+    exclude_cases = []
+
+    for arg in (args or default_args):
+        if arg.startswith('x'):
+            arg = arg[1:]
+            exclude_cases.append(arg)
+        else:
+            cases.append(arg)
+
+    spec['cases'] = cases
+    spec['exclude-cases'] = exclude_cases
+
     if options.autobahn and not os.path.exists(options.autobahn):
         print 'Ignoring %s (not found)' % options.autobahn
         options.autobahn = None
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.