Jeffrey Gelens avatar Jeffrey Gelens committed aa2b374

Refactoring/merging code with Denis Bilenko's websocket repos

Comments (0)

Files changed (2)

geventwebsocket/handler.py

 import re
 import struct
 from hashlib import md5
+from socket import error
 
 from gevent.pywsgi import WSGIHandler
 from geventwebsocket import WebSocket
 
 
-class HandShakeError(ValueError):
-    """ Hand shake challenge can't be parsed """
+class WebSocketError(error):
     pass
 
 
+class BadRequest(WebSocketError):
+    """
+    This error will be raised by meth:`do_handshake` when encountering an invalid request.
+    If left unhandled, it will cause :class:`WebSocketHandler` to log the error and to issue 400 reply.
+    It will also be raised by :meth:`connect` if remote server has replied with 4xx error.
+    """
+
+
 class WebSocketHandler(WSGIHandler):
     """ Automatically upgrades the connection to websockets. """
     def __init__(self, *args, **kwargs):
-        self.websocket_connection = False
         self.allowed_paths = []
 
         for expression in kwargs.pop('allowed_paths', []):
 
         super(WebSocketHandler, self).__init__(*args, **kwargs)
 
-    def handle_one_response(self, call_wsgi_app=True):
+    def run_application(self):
+        if self.websocket:
+            return self.application(self.environ, self.start_response)
+        else:
+            return super(WebSocketHandler, self).run_application()
+
+    def handle_one_response(self):
+        # TODO: refactor to run under run_application
         # In case the client doesn't want to initialize a WebSocket connection
         # we will proceed with the default PyWSGI functionality.
         if self.environ.get("HTTP_CONNECTION") != "Upgrade" or \
            not self.environ.get("HTTP_ORIGIN") or \
            not self.accept_upgrade():
             return super(WebSocketHandler, self).handle_one_response()
-        else:
-            self.websocket_connection = True
 
         self.websocket = WebSocket(self.socket, self.rfile, self.environ)
         self.environ['wsgi.websocket'] = self.websocket
 
+        headers = [
+            ("Upgrade", "WebSocket"),
+            ("Connection", "Upgrade"),
+        ]
+
         # Detect the Websocket protocol
         if "HTTP_SEC_WEBSOCKET_KEY1" in self.environ:
             version = 76
             version = 75
 
         if version == 75:
-            headers = [
-                ("Upgrade", "WebSocket"),
-                ("Connection", "Upgrade"),
+            headers.extend([
                 ("WebSocket-Origin", self.websocket.origin),
                 ("WebSocket-Protocol", self.websocket.protocol),
                 ("WebSocket-Location", "ws://" + self.environ.get('HTTP_HOST') + self.websocket.path),
-            ]
+            ])
             self.start_response("101 Web Socket Protocol Handshake", headers)
         elif version == 76:
             challenge = self._get_challenge()
-            headers = [
-                ("Upgrade", "WebSocket"),
-                ("Connection", "Upgrade"),
+            headers.extend([
                 ("Sec-WebSocket-Origin", self.websocket.origin),
                 ("Sec-WebSocket-Protocol", self.websocket.protocol),
                 ("Sec-WebSocket-Location", "ws://" + self.environ.get('HTTP_HOST') + self.websocket.path),
-            ]
+            ])
 
             self.start_response("101 Web Socket Protocol Handshake", headers)
-            self.write([challenge])
+            self.write(challenge)
         else:
-            raise Exception("Version not supported")
+            raise Exception("WebSocket version not supported")
 
-        if call_wsgi_app:
-            return self.application(self.environ, self.start_response)
-        else:
-            return
+        return self.run_application()
 
     def accept_upgrade(self):
         """
             return True
 
     def write(self, data):
-        if self.websocket_connection:
-            self.wfile.writelines(data)
+        if self.websocket:
+            self.socket.sendall(data)
         else:
             super(WebSocketHandler, self).write(data)
 
                 towrite.append("%s: %s\r\n" % header)
 
             towrite.append("\r\n")
-            self.wfile.writelines(towrite)
+            self.socket.sendall(towrite)
             self.headers_sent = True
         else:
             super(WebSocketHandler, self).start_response(status, headers, exc_info)
         spaces = re.subn(" ", "", key_value)[1]
 
         if key_number % spaces != 0:
-            raise HandShakeError("key_number %d is not an intergral multiple of"
+            raise WebSocketHandler("key_number %d is not an intergral multiple of"
                                  " spaces %d" % (key_number, spaces))
 
         return key_number / spaces
         key1 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY1')
         key2 = self.environ.get('HTTP_SEC_WEBSOCKET_KEY2')
 
-        if not (key1 and key2):
-            message = "Client using old/invalid protocol implementation"
-            headers = [("Content-Length", str(len(message))),]
-            self.start_response("400 Bad Request", headers)
-            self.write([message])
-            self.close_connection = True
-            return
+        if not key1:
+            raise BadRequest("SEC-WEBSOCKET-KEY1 header is missing")
+        if not key2:
+            raise BadRequest("SEC-WEBSOCKET-KEY2 header is missing")
 
         part1 = self._get_key_value(self.environ['HTTP_SEC_WEBSOCKET_KEY1'])
         part2 = self._get_key_value(self.environ['HTTP_SEC_WEBSOCKET_KEY2'])
         # This request should have 8 bytes of data in the body
         key3 = self.rfile.read(8)
 
-        challenge = ""
-        challenge += struct.pack("!I", part1)
-        challenge += struct.pack("!I", part2)
-        challenge += key3
-
-        return md5(challenge).digest()
+        return md5(struct.pack("!II", part1, part2) + key3).digest()
 
     def wait(self):
         return self.websocket.wait()

geventwebsocket/websocket.py

+from gevent.coros import Semaphore
+
 # This class implements the Websocket protocol draft version as of May 23, 2010
 # The version as of August 6, 2010 will be implementend once Firefox or
 # Webkit-trunk support this version.
         self.origin = environ.get('HTTP_ORIGIN')
         self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'unknown')
         self.path = environ.get('PATH_INFO')
-        self.websocket_closed = False
+        self._writelock = Semaphore(1)
 
     def send(self, message):
-        if self.websocket_closed:
-            raise Exception("Connection was terminated")
-
         if isinstance(message, unicode):
             message = message.encode('utf-8')
         elif isinstance(message, str):
         else:
             raise Exception("Invalid message encoding")
 
-        self.socket.sendall("\x00" + message + "\xFF")
+        with self._writelock:
+            self.socket.sendall("\x00" + message + "\xFF")
 
-    def close_connection(self):
-        if not self.websocket_closed:
-            self.websocket_closed = True
-            self.socket.shutdown(True)
-            self.socket.close()
-        else:
-            return
+    def detach(self):
+        self.socket = None
+        self.rfile = None
+        self.handler = None
+
+    def close(self):
+        # TODO implement graceful close with 0xFF frame
+        if self.socket is not None:
+            try:
+                self.socket.close()
+            except Exception:
+                pass
+            self.detach()
+
 
     def _message_length(self):
         # TODO: buildin security agains lengths greater than 2**31 or 2**32
 
         return ''.join(bytes)
 
-    def wait(self):
-        while True:
-            if self.websocket_closed:
-                return None
-
+    def receive(self):
+        while self.socket is not None:
             frame_str = self.rfile.read(1)
             if not frame_str:
                 # Connection lost?
-                self.websocket_closed = True
-                continue
+                self.close()
+                break
             else:
                 frame_type = ord(frame_str)
 
 
             if (frame_type & 0x80) == 0x00: # most significant byte is not set
-
                 if frame_type == 0x00:
                     bytes = self._read_until()
                     return bytes.decode("utf-8", "replace")
                 else:
-                    self.websocket_closed = True
-
+                    self.close()
             elif (frame_type & 0x80) == 0x80: # most significant byte is set
                 # Read binary data (forward-compatibility)
                 if frame_type != 0xff:
-                    self.websocket_closed = True
+                    self.close()
+                    break
                 else:
                     length = self._message_length()
                     if length == 0:
-                        self.websocket_closed = True
+                        self.close()
+                        break
                     else:
                         self.rfile.read(length) # discard the bytes
             else:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.