Anonymous avatar Anonymous committed 5e83690

Add checks of extensions and subprotocols (actual implementation of them pending)

Comments (0)

Files changed (1)

websocket_client.py

     def __getitem__(self, key):
         return super(CaseInsensitiveDict, self).__getitem__(key.lower())
 
+    def __contains__(self, key):
+        return super(CaseInsensitiveDict, self).__contains__(key.lower())
 
 
 class HandshakeResponseError(Exception):
                   "Sec-WebSocket-Version: 13\r\n"
                   ]
 
-        if self.factory.subprotocols != []:
-            field_list.append("Sec-WebSocket-Protocol: %s\r\n" % ",".join(self.factory.subprotocols))
+        if self.factory.subprotocolsAvailable:
+            field_list.append("Sec-WebSocket-Protocol: %s\r\n" % ",".join(self.factory.subprotocolsAvailable))
 
-        if self.factory.extensions != []:
-            field_list.append("Sec-WebSocket-Extensions: %s\r\n" % ",".join(self.factory.extensions))
+        if self.factory.extensionsAvailable:
+            field_list.append("Sec-WebSocket-Extensions: %s\r\n" % ",".join(self.factory.extensionsAvailable))
 
         if self.factory.extra_headers != None:
             field_list.extend(self.factory.extra_headers)
             raise HandshakeResponseError, "No Sec-WebSocket-Accept header"
         expected_key = b64encode(sha1(self.nonce + GUID).digest())
         if sec != expected_key:
-            raise HandshakeResponseError, ("Invalid key: expected %s, received %s"
-                                           % (expected_key, sec))
-        self.factory.extensions_in_use = headers.get("Sec-WebSocket-Extensions", "").split(',')
-        self.factory.protocol_in_use = headers.get("Sec-WebSocket-Protocol", "").split(',')
-        if self.factory.protocol_in_use != [""] and self.factory.protocol_in_use not in self.factory.subprotocols:
-            self.factory.doStop()
-            raise HandshakeResponseError, ("Invalid protocol: %s"
-                                           % (self.factory.protocol_in_use))
+            raise HandshakeResponseError, "Invalid key"
+
+        ## Check extensions
+        self.factory.extensions = []
+        a
+        if "Sec-WebSocket-Extensions" in headers:
+            for ext in headers["Sec-WebSocket-Extensions"].split(','):
+                if ext in self.factory.extensionsAvailable:
+                    self.factory.extensions.append(self.factory.extensionsAvailable[ext])
+                else:
+                    self.factory.doStop()
+                    raise HandshakeResponseError, ("Invalid extension: %s"
+                                                   % ext)
+
+        ## Check protocols
+        if "Sec-WebSocket-Protocol" in headers:
+            subprotocol = headers["Sec-WebSocket-Protocol"]
+            if subprotocol in self.factory.subprotocolsAvailable:
+                self.factory.subprotocol = self.factory.subprotocolsAvailable[subprotocol]
+            else:
+                self.factory.doStop()
+                raise HandshakeResponseError, ("Invalid protocol: %s"
+                                               % (self.factory.subprotocol))
+        else:
+            self.factory.subprotocol = None
+
+        ## Return the remaining data after the handshake
         return handshake_as_list[-1]
 
 
     @type url: C{str}
     @param origin: Host name of the client. Defaults to host
     @type origin: C{str}
-    @param subprotocols: Subprotocols
-    @type subprotocols: C{list}
-    @param extensions: Extensions
-    @type extensions: C{list}
+    @param subprotocolsAvailable: Available subprotocols (TBD)
+    @type subprotocolsAvailable: C{list}
+    @param extensionsAvailable: Extensions
+    @type extensionsAvailable: C{list}
     @param extra_headers: Extra headers to send with the handshake
     @type extra_headers: C{dict}
     """
     def __init__(self,
                  url,
                  origin=None,
-                 subprotocols=[],
-                 extensions=[],
+                 subprotocolsAvailable={},
+                 extensionsAvailable={},
                  extra_headers=None):
         """
         If not present, origin will be set to the host
             self.origin = self.host
         else:
             self.origin = origin
-        self.subprotocols = subprotocols
-        self.extensions = extensions
+        self.subprotocolsAvailable = subprotocolsAvailable
+        self.extensionsAvailable = extensionsAvailable
         if extra_headers == None:
             self.extra_headers = None
         else:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.