Jeffrey Gelens avatar Jeffrey Gelens committed 3e0ecdf

Integrated Lon Ingram's version 7 of the protocol. Working on version 8.
TODO:
refactor loading if the websocket classes

Comments (0)

Files changed (4)

examples/websocket.html

 var iets = "";
 window.onload = function() {
     var data = {};
-    var s = new WebSocket("ws://localhost:8000/data");
+    var s = new MozWebSocket("ws://localhost:8000/data");
     s.onopen = function() {
         //alert('open');
         s.send('hi');

geventwebsocket/__init__.py

 __version__ =  ".".join(map(str, version_info))
 
 try:
-    from geventwebsocket.websocket import WebSocket, WebSocketLegacy
+    from geventwebsocket.websocket import WebSocketVersion7, WebSocketLegacy
 except ImportError:
     import traceback
     traceback.print_exc()

geventwebsocket/handler.py

 import re
 import struct
 from hashlib import md5, sha1
-from base64 import b64encode
+from base64 import b64encode, b64decode
 
 from gevent.pywsgi import WSGIHandler
-from geventwebsocket import WebSocket, WebSocketLegacy
+from geventwebsocket import WebSocketVersion7, WebSocketLegacy
 
 
 PROTOCOL_VERSIONS = (
     "0",
     "6",
 )
-MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
 
 class HandShakeError(ValueError):
     """ Hand shake challenge can't be parsed """
 
 class WebSocketHandler(WSGIHandler):
     """ Automatically upgrades the connection to websockets. """
+
+    GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+
     def __init__(self, *args, **kwargs):
         self.websocket_connection = False
         self.allowed_paths = []
     def handle_one_response(self, call_wsgi_app=True):
         # In case the client doesn't want to initialize a WebSocket connection
         # we will proceed with the default PyWSGI functionality.
+        print self.environ.get("HTTP_CONNECTION", "").lower().split(",")
 
-        if "Upgrade" in self.environ.get("HTTP_CONNECTION", "").split(",") and \
-             "WebSocket" in self.environ.get("HTTP_UPGRADE") and \
+        if "upgrade" in self.environ.get("HTTP_CONNECTION", "").lower(). \
+             replace(" ", "").split(",") and \
+             "websocket" in self.environ.get("HTTP_UPGRADE").lower() and \
              self.upgrade_allowed():
             self.websocket_connection = True
         else:
+            print "NORMAL"
+            from pprint import pprint
+            pprint(self.environ)
             return super(WebSocketHandler, self).handle_one_response()
 
         self.init_websocket()
 
     def init_websocket(self):
         version = self.environ.get("HTTP_SEC_WEBSOCKET_VERSION")
+        print "VERSION", version
 
         if self.environ.get("HTTP_ORIGIN"):
+            print "OLD ", version
             self.websocket = WebSocketLegacy(self.socket, self.rfile, self.environ)
 
             if "HTTP_SEC_WEBSOCKET_KEY1" in self.environ:
             else:
                 self._handshake_hixie75()
         else:
-            self.websocket = WebSocket(self.socket, self.rfile, self.environ)
+            print "NEW ", version
+            self.websocket = WebSocketVersion7(self.socket, self.rfile, self.environ)
 
             if version and int(version) in PROTOCOL_VERSIONS:
                 pass
         self.start_response("101 Web Socket Protocol Handshake", headers)
         self.write(challenge)
 
-    def handshake_hybi06(self):
+    def _handshake_version7(self):
+        environ = self.environ
+
+        protocol, version = self.request_version.split("/")
+        key = environ.get("HTTP_SEC_WEBSOCKET_KEY")
+
+        # check client handshake for validity
+        if not environ.get("REQUEST_METHOD") == "GET":
+            # 5.2.1 (1)
+            self._close_connection()
+            return False
+        elif not protocol == "HTTP":
+            # 5.2.1 (1)
+            self._close_connection()
+            return False
+        elif float(version) < 1.1:
+            # 5.2.1 (1)
+            self._close_connection()
+            return False
+        # XXX: nobody seems to set SERVER_NAME correctly. check the spec
+        #elif not environ.get("HTTP_HOST") == environ.get("SERVER_NAME"):
+             # 5.2.1 (2)
+             #self._close_connection()
+             #return False
+        elif not key:
+            # 5.2.1 (3)
+            self._close_connection()
+            return False
+        elif len(b64decode(key)) != 16:
+            # 5.2.1 (3)
+            self._close_connection()
+            return False
+
+
+    def _handshake_hybi06(self):
         raise Exception("Version not yet supported")
         challenge = self._get_challange_hybi06()
         headers = [
         self.start_response("101 Switching Protocols", headers)
         self.write(challenge)
 
+    def _close_connection(self, reason=None):
+        # based on gevent/pywsgi.py
+        # see http://pypi.python.org/pypi/gevent#downloads
+
+        if reason:
+            print "Closing the connection because %s!" % reason
+        if self.socket is not None:
+            try:
+                self.socket._sock.close()
+                self.socket.close()
+            except socket.error:
+                pass
+
 
     def upgrade_allowed(self):
         """
             return True
 
     def write(self, data):
-        if data:
-            if self.websocket_connection:
+        if self.websocket_connection:
+            if data:
                 self.socket.sendall(data)
             else:
-                super(WebSocketHandler, self).write(data)
+                raise Exception("No data to send")
         else:
-            raise Exception("No data to send")
+            super(WebSocketHandler, self).write(data)
 
     def start_response(self, status, headers, exc_info=None):
         if self.websocket_connection:

geventwebsocket/websocket.py

+import struct
+
+
 class WebSocket(object):
     pass
 
                         self.rfile.read(length) # discard the bytes
             else:
                 raise IOError("Reveiced an invalid message")
+
+
+class WebSocketVersion7(WebSocketLegacy):
+    FIN = int("10000000", 2)
+    RSV = int("01110000", 2)
+    OPCODE = int("00001111", 2)
+    MASK = int("10000000", 2)
+    PAYLOAD = int("01111111", 2)
+
+    OPCODE_FRAG = 0x0
+    OPCODE_TEXT = 0x1
+    OPCODE_BINARY = 0x2
+    OPCODE_CLOSE = 0x8
+    OPCODE_PING = 0x9
+    OPCODE_PONG = 0xA
+
+    REASON_NORMAL = 1000
+    REASON_GOING_AWAY = 1001
+    REASON_PROTOCOL_ERROR = 1002
+    REASON_UNSUPPORTED_DATA_TYPE = 1003
+    REASON_TOO_LARGE = 1004
+
+    LEN_16 = 126
+    LEN_64 = 127
+
+    def __init__(self, sock, rfile, environ, compatibility_mode=True):
+        self.rfile = rfile
+        self.socket = sock
+        self.origin = environ.get('HTTP_SEC_WEBSOCKET_ORIGIN')
+        self.protocol = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'unknown')
+        self.path = environ.get('PATH_INFO')
+        self.websocket_closed = False
+        self.compatibility_mode = compatibility_mode
+        self._fragments = []
+        self._original_opcode = -1
+
+    def _read_from_socket(self, count):
+        return self.rfile.read(count)
+
+    def wait(self):
+        """Return the next frame from the socket
+
+        If the next frame is invalid, wait closes the socket and returns None.
+
+        If the next frame is valid and the websocket instance's
+        compatibility_mode attribute is True, then wait ignores PING and PONG
+        frames, returns None when sent a CLOSE frame and returns the payload
+        for data frames.
+
+        If the next frame is valid and the websocket instance's
+        compatibility_mode attribute is False, it returns a tuple of the form
+        (opcode, payload).
+        """
+
+        while True:
+            payload = ''
+            if self.websocket_closed:
+                return None
+
+            opcode_octet, length_octet = struct.unpack('!BB', self._read_from_socket(2))
+
+            if self.RSV & opcode_octet:
+                self.close(self.REASON_PROTOCOL_ERROR, 'Reserved bits cannot be set')
+                return None
+
+            opcode = opcode_octet & self.OPCODE
+            is_final_frag = (self.FIN & opcode_octet) != 0
+
+            if self._is_opcode_invalid(opcode):
+                self.close(self.REASON_PROTOCOL_ERROR, 'Invalid opcode %x' % opcode)
+                return None
+
+            if not is_final_frag and self.OPCODE_CLOSE <= opcode <= self.OPCODE_PONG:
+                self.close(self.REASON_PROTOCOL_ERROR, 'Control frames cannot be fragmented')
+                return None
+
+            if len(self._fragments) > 0 and not is_final_frag and opcode != self.OPCODE_FRAG:
+                self.close(self.REASON_PROTOCOL_ERROR,
+                        'Received new fragment frame with non-zero opcode')
+                return None
+
+            if len(self._fragments) > 0 and is_final_frag and (
+                    self.OPCODE_TEXT <= opcode <= self.OPCODE_BINARY):
+                self.close(self.REASON_PROTOCOL_ERROR,
+                        'Received new unfragmented data frame during fragmented message')
+                return None
+
+            if not self.MASK & length_octet:
+                self.close(self.REASON_PROTOCOL_ERROR, 'MASK must be set')
+                return None
+
+            length_code = length_octet & self.PAYLOAD
+
+            if length_code >= self.LEN_16 and (self.OPCODE_CLOSE <= opcode <= self.OPCODE_PONG):
+                self.close(self.REASON_PROTOCOL_ERROR,
+                        'Control frame payload cannot be larger than 125 bytes')
+                return None
+
+            if length_code < self.LEN_16:
+                length = length_code
+            elif length_code == self.LEN_16:
+                length = struct.unpack('!H', self._read_from_socket(2))[0]
+            elif length_code == self.LEN_64:
+                length = struct.unpack('!Q', self._read_from_socket(8))[0]
+            else:
+                raise Exception('Calculated invalid length')
+
+            mask_octets = struct.unpack('!BBBB', self._read_from_socket(4))
+            masked_payload = self._read_from_socket(length)
+
+            payload = ''
+
+            j = 0
+            for c in masked_payload:
+                # TODO: optimize me? http://www.skymind.com/~ocrow/python_string/
+                payload += chr(ord(c) ^ mask_octets[j])
+                j = (j + 1) % 4
+
+            if opcode == self.OPCODE_TEXT:
+                payload = payload.decode('utf-8')
+            elif opcode == self.OPCODE_CLOSE:
+                if length >= 2:
+                    reason, message = struct.unpack('!H%ds' % (length - 2), payload)
+                else:
+                    reason = message = None
+
+                self.close(self.REASON_NORMAL, '')
+                if not self.compatibility_mode:
+                    return (self.OPCODE_CLOSE, (reason, message))
+                else:
+                    return None
+            elif opcode == self.OPCODE_PING:
+                self.send(payload, opcode=self.OPCODE_PONG)
+                if not self.compatibility_mode:
+                    return (self.OPCODE_PING, payload)
+                else:
+                    continue
+            elif opcode == self.OPCODE_PONG:
+                if not self.compatibility_mode:
+                    return (self.OPCODE_PONG, payload)
+                else:
+                    continue
+
+            if is_final_frag:
+                if len(self._fragments) > 0:
+                    opcode = self._original_opcode
+                    self._original_opcode = -1
+                    payload = ''.join(self._fragments) + payload
+                    self._fragments = []
+                if not self.compatibility_mode:
+                    return (opcode, payload)
+                else:
+                    return payload
+            else:
+                if len(self._fragments) == 0:
+                    self._original_opcode = opcode
+                self._fragments.append(payload)
+
+    def _encode_text(self, s):
+        if isinstance(s, unicode):
+            return s.encode('utf-8')
+        elif isinstance(s, str):
+            return unicode(s).encode('utf-8')
+        else:
+            raise Exception('Invalid encoding')
+
+    def _is_opcode_invalid(self, opcode):
+        return opcode < self.OPCODE_FRAG or (opcode > self.OPCODE_BINARY and
+                opcode < self.OPCODE_CLOSE) or opcode > self.OPCODE_PONG
+
+    def send(self, message, opcode=OPCODE_TEXT):
+        """Send a frame over the websocket with message as its payload
+
+        Keyword args:
+        opcode -- the opcode to use (default OPCODE_TEXT)
+        """
+
+        if self.websocket_closed:
+            raise Exception('Connection was terminated')
+
+        if self._is_opcode_invalid(opcode):
+            raise Exception('Invalid opcode %d' % opcode)
+
+        if opcode == self.OPCODE_TEXT:
+            message = self._encode_text(message)
+
+        length = len(message)
+
+        if opcode == self.OPCODE_TEXT:
+            message = struct.pack('!%ds' % length, message)
+
+        if length < self.LEN_16:
+            preamble = struct.pack('!BB', self.FIN | opcode, length)
+        elif length < 2 ** 16:
+            preamble = struct.pack('!BBH', self.FIN | opcode, self.LEN_16, length)
+        elif length < 2 ** 64:
+            preamble = struct.pack('!BBQ', self.FIN | opcode, self.LEN_64, length)
+        else:
+            # this can't really happen, but for correctness sake...
+            raise Exception('Message is too long')
+
+        self.socket.sendall(preamble + message)
+
+    def close(self, reason, message):
+        """Close the websocket, sending the specified reason and message
+        """
+
+        message = self._encode_text(message)
+        self.send(struct.pack('!H%ds' % len(message), reason, message), opcode=self.OPCODE_CLOSE)
+        self.websocket_closed = True
+
+        # based on gevent/pywsgi.py
+        # see http://pypi.python.org/pypi/gevent#downloads
+        if self.socket is not None:
+            try:
+                self.socket._sock.close()
+                self.socket.close()
+            except socket.error:
+                pass
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.