Anonymous avatar Anonymous committed d5efd17

Implement most of wait; Fix race in test application

Can now receive unfragmented frames. Added a loop in the test case application
that sleeps while the websocket is open. This fixes a race where the
application and the test case are both in wait() at the same time.

Comments (0)

Files changed (2)

geventwebsocket/websocket.py

         self.path = environ.get('PATH_INFO')
         self.websocket_closed = False
 
+    def _read_from_socket(self, count):
+        return self.rfile.read(count)
+
     def wait(self):
-        msg = ""
+        msg = ''
         while True:
             if self.websocket_closed:
                 return None
 
-            opcode, length = struct.unpack('!BB', self.rfile.read(2))
+            opcode_octet, length_octet = struct.unpack('!BB', self._read_from_socket(2))
 
-            if self.RSV & opcode:
+            if self.RSV & opcode_octet:
                 self.close(1002, 'Reserved bits cannot be set')
                 return None
 
+            opcode = opcode_octet & self.OPCODE
+            if self._is_opcode_invalid(opcode):
+                self.close(1002, 'Invalid opcode %x' % opcode)
+                return None
+
+            if not self.MASK & length_octet:
+                self.close(1002, 'MASK must be set')
+                return None
+
+            length_code = length_octet & self.PAYLOAD
             is_final_frag = (self.FIN & opcode) != 0
 
-    def _encodeText(self, s):
+            if length_code < 126:
+                length = length_code
+            elif length_code == 126:
+                length = struct.unpack('!H', self._read_from_socket(2))[0]
+            elif length_code == 127:
+                length = struct.unpack('!Q', self._read_from_socket(8))[0]
+            else:
+                raise Exception('Calculated invalid length')
+
+            mask_octets = struct.unpack('!BBBB', self._read_from_socket(4))
+            masked_payload = self._read_from_socket(length)
+            payload = ''
+
+            # TODO: optimize me
+            j = 0
+            for c in masked_payload:
+                payload += chr(ord(c) ^ mask_octets[j])
+                j = (j + 1) % 4
+
+            if opcode == self.OPCODE_TEXT:
+                payload = payload.decode('utf-8')
+
+            return payload
+
+    def _encode_text(self, s):
         if isinstance(s, unicode):
             return s.encode('utf-8')
         elif isinstance(s, str):
         else:
             raise Exception('Invalid encoding')
 
+    def _is_opcode_invalid(self, opcode):
+        return opcode < self.OPCODE_TEXT or (opcode > self.OPCODE_BINARY and 
+                opcode < self.OPCODE_CLOSE) or opcode > self.OPCODE_PONG
+
     def send(self, opcode, message):
         if self.websocket_closed:
             raise Exception('Connection was terminated')
 
-        if opcode < self.OPCODE_TEXT or (opcode > self.OPCODE_BINARY and 
-                opcode < self.OPCODE_CLOSE) or opcode > self.OPCODE_PONG:
+        if self._is_opcode_invalid(opcode):
             raise Exception('Invalid opcode %d' % opcode)
 
         if opcode == self.OPCODE_TEXT:
-            message = self._encodeText(message)
+            message = self._encode_text(message)
 
         length = len(message)
 
             preamble = struct.pack('!BB', self.FIN | opcode, length)
         elif length < 2 ** 16:
             preamble = struct.pack('!BBH', self.FIN | opcode, self.LEN_16, length)
+        elif length < 2 ** 64:
+            preamble = struct.pack('!BBQ', self.FIN | opcode, self.LEN_64, length)
         else:
-            preamble = struct.pack('!BBQ', self.FIN | opcode, self.LEN_64, length)
+            # this can't really happen, but for correctness sake...
+            raise Exception('Message is too long')
 
         self.socket.sendall(preamble + message)
 
     def close(self, reason, message):
-        message = self._encodeText(message)
+        message = self._encode_text(message)
         self.send(self.OPCODE_CLOSE, struct.pack('!H%ds' % len(message), reason, message))
         self.websocket_closed = True
 

tests/test__websocket.py

 import sys
 import greentest
 import gevent
+import gevent.local
 from gevent import socket
 from geventwebsocket.handler import WebSocketHandler
 from geventwebsocket.websocket import WebSocketVersion7
 
-
 CONTENT_LENGTH = 'Content-Length'
 CONN_ABORTED_ERRORS = []
 DEBUG = '-v' in sys.argv
 
 socket.socket.makefile = makefile
 
+import gevent.coros
 class TestCase(greentest.TestCase):
     __timeout__ = 5
+    _testlock = gevent.coros.Semaphore(1)
 
     def get_wsgi_module(self):
         from gevent import pywsgi
         return pywsgi
 
     def init_server(self, application):
-        self.server = self.get_wsgi_module().WSGIServer(('127.0.0.1', 0),
+        self.local = gevent.local.local()
+        self.local.server = self.get_wsgi_module().WSGIServer(('127.0.0.1', 0),
             application, handler_class=WebSocketHandler)
 
     def setUp(self):
         application = self.application
         self.init_server(application)
-        self.server.start()
-        self.port = self.server.server_port
+        self.local.server.start()
+        self.port = self.local.server.server_port
         greentest.TestCase.setUp(self)
 
-
     def tearDown(self):
         greentest.TestCase.tearDown(self)
         timeout = gevent.Timeout.start_new(0.5)
         try:
-            self.server.stop()
+            self.local.server.stop()
         finally:
             timeout.cancel()
 
     def connect(self):
         return socket.create_connection(('127.0.0.1', self.port))
 
-
 class TestWebSocket(TestCase):
     message = "\x00Hello world\xff"
 
                 start_response("400 Bad Request", [])
                 return []
 
-            """
-            while True:
-                message = ws.wait()
-                if message is None:
-                    break
-                ws.send(message)
-            """
-
-            return []
+            while not ws.websocket_closed:
+                gevent.sleep()
+            self.close_connection = True
+            return None
 
     def test_bad_handshake_method(self):
         fd = self.connect().makefile(bufsize=1)
         fd.write(self.GOOD_HEADERS)
         read_http(fd, code=101, reason="Switching Protocols")
 
-        msg = 'Hello, websocket' * 4097
+        msg = 'Hello, websocket' * 4098
         self.ws.send(1, msg)
 
         preamble = fd.read(10)
         fd.close()
 
     def test_wait_bad_framing_reserved_bits(self):
+        for reserved_bits in xrange(1, 8):
+            fd = self.connect().makefile(bufsize=1)
+
+            fd.write(self.GOOD_HEADERS)
+            read_http(fd, code=101, reason='Switching Protocols')
+
+            expected_msg = 'Reserved bits cannot be set'
+
+            bad_opcode = WebSocketVersion7.FIN | (reserved_bits << 4) | WebSocketVersion7.OPCODE_TEXT
+            fd.write(struct.pack('!BB', bad_opcode, int('10000000', 2)))
+
+            frame = self.ws.wait()
+            assert self.ws.websocket_closed, \
+                    'Failed to close connection when sent a frame with reserved bits set'
+
+            preamble = fd.read(2)
+
+            opcode, length = struct.unpack('!BB', preamble)
+            assert opcode & WebSocketVersion7.FIN, 'FIN must be set'
+            assert (opcode & WebSocketVersion7.OPCODE) == 8, 'Opcode must be 0x8'
+            assert (length & WebSocketVersion7.MASK) == 0, 'MASK must not be set'
+
+            reason = fd.read(2)
+            reason = struct.unpack('!H', reason)[0]
+            assert reason == 1002, 'Expected reason to be 1002, but got %d' % reason
+
+            rxd_msg = fd.read(length - 2).decode('utf-8', 'replace')
+            assert rxd_msg == expected_msg, 'Wrong message "%s"' % rxd_msg
+
+            fd.close();
+
+    def test_wait_bad_opcode(self):
+        bad_opcodes = range(WebSocketVersion7.OPCODE_BINARY + 1, WebSocketVersion7.OPCODE_CLOSE)
+        bad_opcodes += range(WebSocketVersion7.OPCODE_PONG + 1, 2**4)
+        for bad_opcode in bad_opcodes:
+            fd = self.connect().makefile(bufsize=1)
+
+            fd.write(self.GOOD_HEADERS)
+            read_http(fd, code=101, reason='Switching Protocols')
+
+            expected_msg = 'Invalid opcode %x' % bad_opcode
+
+            bad_opcode = WebSocketVersion7.FIN | bad_opcode
+            fd.write(struct.pack('!BB', bad_opcode, int('10000000', 2)))
+
+            frame = self.ws.wait()
+            assert self.ws.websocket_closed, \
+                    'Failed to close connection when sent a frame with unsupported opcode'
+
+            preamble = fd.read(2)
+
+            opcode, length = struct.unpack('!BB', preamble)
+            assert opcode & WebSocketVersion7.FIN, 'FIN must be set'
+            assert (opcode & WebSocketVersion7.OPCODE) == 8, 'Opcode must be 0x8'
+            assert (length & WebSocketVersion7.MASK) == 0, 'MASK must not be set'
+
+            reason = fd.read(2)
+            reason = struct.unpack('!H', reason)[0]
+            assert reason == 1002, 'Expected reason to be 1002, but got %d' % reason
+
+            rxd_msg = fd.read(length - 2).decode('utf-8', 'replace')
+            assert rxd_msg == expected_msg, 'Wrong message "%s"' % rxd_msg
+
+            fd.close();
+
+    def test_wait_no_mask(self):
         fd = self.connect().makefile(bufsize=1)
 
         fd.write(self.GOOD_HEADERS)
-        read_http(fd, code=101, reason="Switching Protocols")
+        read_http(fd, code=101, reason='Switching Protocols')
 
-        expected_msg = 'Reserved bits cannot be set'
+        expected_msg = 'MASK must be set'
 
-        fd.write(struct.pack("!BB", int("11000001", 2), int("10000000", 2)))
+        fd.write(struct.pack('!BB',
+            WebSocketVersion7.FIN | WebSocketVersion7.OPCODE_TEXT, int('00000000', 2)))
 
         frame = self.ws.wait()
-        assert self.ws.websocket_closed, "Failed to close connection when sent a frame with RSV1 set"
+        assert self.ws.websocket_closed, \
+                'Failed to close connection when sent a frame with MASK not set'
 
         preamble = fd.read(2)
 
 
         fd.close();
 
+    def _get_payload(self, mask, msg):
+        mask_octets = struct.unpack('!BBBB', struct.pack('!L', mask))
+        msg = unicode(msg).encode('utf-8')
+        result = ''
+        j = 0
+        for c in msg:
+            result += chr(ord(c) ^ mask_octets[j])
+            j = (j + 1) % 4
+        return result
+
+    def test_wait_short_frame(self):
+        fd = self.connect().makefile(bufsize=1)
+
+        fd.write(self.GOOD_HEADERS)
+        read_http(fd, code=101, reason='Switching Protocols')
+
+        msg = 'Hello, websocket'
+        mask = 42
+        encoded_msg = self._get_payload(mask, msg)
+        length = len(encoded_msg)
+        fd.write(struct.pack('!BBL%ds' % length,
+            WebSocketVersion7.FIN | WebSocketVersion7.OPCODE_TEXT, 
+            WebSocketVersion7.MASK | length, mask, encoded_msg))
+
+        rxd_msg = self.ws.wait()
+        assert not self.ws.websocket_closed, 'Closed connection when sent a good frame'
+        assert rxd_msg == msg, 'Wrong message "%s"' % rxd_msg
+
+        fd.close();
+
+    def test_wait_med_frame(self):
+        fd = self.connect().makefile(bufsize=1)
+
+        fd.write(self.GOOD_HEADERS)
+        read_http(fd, code=101, reason='Switching Protocols')
+
+        msg = 'Hello, websocket' * 8
+        mask = 42
+        encoded_msg = self._get_payload(mask, msg)
+        length = len(encoded_msg)
+        fd.write(struct.pack('!BBHL%ds' % length,
+            WebSocketVersion7.FIN | WebSocketVersion7.OPCODE_TEXT, 
+            WebSocketVersion7.MASK | 126, length, mask, encoded_msg))
+
+        rxd_msg = self.ws.wait()
+        assert not self.ws.websocket_closed, 'Closed connection when sent a good frame'
+        assert rxd_msg == msg, 'Wrong message "%s"' % rxd_msg
+
+        fd.close();
+
+    def test_wait_long_frame(self):
+        fd = self.connect().makefile(bufsize=1)
+
+        fd.write(self.GOOD_HEADERS)
+        read_http(fd, code=101, reason='Switching Protocols')
+
+        msg = 'Hello, websocket' * 8 #4098
+        mask = 42
+        encoded_msg = self._get_payload(42, msg)
+        length = len(encoded_msg)
+        payload = struct.pack('!BBQL%ds' % length,
+            WebSocketVersion7.FIN | WebSocketVersion7.OPCODE_TEXT, 
+            WebSocketVersion7.MASK | 127, length, mask, encoded_msg)
+        fd.write(payload)
+
+        self.ws._waiter = 'test'
+        rxd_msg = self.ws.wait()
+        assert not self.ws.websocket_closed, 'Closed connection when sent a good frame'
+        assert rxd_msg == msg, 'Wrong message "%s"' % rxd_msg
+
+        fd.close();
+
 if __name__ == '__main__':
     greentest.main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.