Commits

Anonymous committed 881cfd6

Handling of extensions parameters; basic extension processing

Comments (0)

Files changed (3)

+class ExtensionRefused(Exception):
+    """
+    Returned by the server or client if the extension should not be accepted
+    """
+    pass
+
+
+class Extension:
+    """
+    Base class for extension
+
+    @ivar name: Extension name
+    @type name: C{str}
+
+    @ivar params: Token-value extension parameters, sent by the client in the HTTP headers
+                  It can be passed as class or constructor attribute
+    @type name: C{dict}
+    """
+
+    name = ""
+    params = {}
+
+    def __init__(self, params=None):
+        if params:
+            self.params = params
+        if self.params:
+            self.requestHeader = "%s;%s" % (self.name,
+                                            ";".join("%s=%s" % (p, v)
+                                                     for p, v in self.params.items()))
+        else:
+            self.requestHeader = self.name
+
+
+    def dataLength(self):
+        """
+        Return the extension data length (server only)
+        """
+        return 0
+
+
+    def negociate(self, params):
+        """
+        Return a token-value parameter dict to be sent back in the response headers
+        by the server, in the case the server proposes parameters for this extension.
+        Thus this should be used only on client.
+        Raises C{ExtensionRefused} if the extension is not accepted.
+
+        @param params: Parameters for the extension as sent in the request headers
+        @type params: C{dict}
+        """
+        return {}
+
+
+    def processIncomingFrame(self, frame):
+        """
+        Process the given incoming frame
+        before it is processed by the handler.
+        Return the processed frame.
+        """
+        return frame
+
+
+    def processIncomingFragment(self, fragment):
+        """
+        Process the given incoming fragment
+        before it is processed by the handler.
+        Return the processed fragment.
+        """
+        return fragment
+
+
+    def processOutgoingFrame(self, frame):
+        """
+        Process the given frame before it is send.
+        Return the processed frame.
+        """
+        return frame
+
+
+    def processOutgoingFragment(self, fragment):
+        """
+        Process the given fragment before it is send.
+        Return the processed fragment.
+        """
+        return fragment
+
+
+
+class ZipExtension(Extension):
+    """
+    Basic extension for zip compression
+    """
+    name = 'zip'
+    def processIncomingFrame(self, frame):
+        from zlib import decompress
+        return decompress(frame)
+    processBinaryIncomingFrame = processIncomingFrame
+
+    def processOutgoingFrame(self, frame):
+        from zlib import compress
+        return compress(frame)
+    processOutgoingBinaryFrame = processOutgoingFrame
+

test_websocket_client.py

 from websocket import OPCODE_PING
 
 from websocket_client import WebSocketClientHandler, WebSocketClient, WebSocketClientFactory, HandshakeResponseError
+from extension import Extension, ZipExtension
 
 #from twisted.internet.base import DelayedCall
 #DelayedCall.debug = True
 class TestHandlerWithSubProtocols(TestHandler):
     def beepbeep(self):
         pass
+
     subprotocolsAvailable = {'beepbeep.acme.com': beepbeep}
 
 
+class RoadRunnerExtension(Extension):
+    name = 'roadrunner.acme.com'
+    params = {'arg1': 'foo'}
+
+
 class TestHandlerWithExtensions(TestHandler):
-    def beepbeep(self):
-        pass
-    extensionsAvailable = {'beepbeep.acme.com': beepbeep}
+    extensionsAvailable = [ZipExtension]
 
 
 class TestClientFactory(WebSocketClientFactory):
 
 
 class WebSocketClientTestCase(unittest.TestCase):
+    simpleData = "bot"
     def __init__(self, *args, **kwargs):
         self.connectionEstablishedDeferred = defer.Deferred()
         self.frameReceivedDeferred = defer.Deferred()
     def clean_disconnection(self, reason):
         self.assertIsInstance(reason.value, ConnectionDone)
 
+    def send(self, client, method=None, data=None):
+        ## Client is got from the callback chain, actually ignored since it is
+        ## already set to self.client by got_tcp_connection
+        if method == None:
+            method = WebSocketClient.write
+        if data == None:
+            data = self.simpleData
+        self.sent_data = data
+        method(client, data)
+
+    def check_response(self, data):
+        self.assertEqual(data, self.sent_data)
+        self.assertEqual(self.client.factory.subprotocol, None)
+        self.client.close()
 
 
 class WebSocketSimpleClientTestCase(WebSocketClientTestCase):
     """
     Tests for L{WebSocketClient}.
     """
-    simpleData = "bot"
     def setUp(self):
         self.site = WebSocketSite(Resource())
         self.site.addHandler("/bar", TestHandler)
     def tearDown(self):
         self.p.stopListening()
 
-    def check_response(self, data):
-        self.assertEqual(data, self.sent_data)
-        self.assertEqual(self.client.factory.subprotocol, None)
-        self.client.close()
-
-    def send(self, client, method=None, data=None):
-        ## Client is got from the callback chain, actually ignored since it is
-        ## already set to self.client by got_tcp_connection
-        if method == None:
-            method = WebSocketClient.write
-        if data == None:
-            data = self.simpleData
-        self.sent_data = data
-        #self.client.sendFrame(OPCODE_PING, self.data)
-        method(client, data)
-
     def test_SimpleWrite(self):
         self.connection_deferred.addCallback(self.send)
         self.frameReceivedDeferred.addCallback(self.check_response)
         url = 'ws://localhost:%d/bogus' % self.p.getHost().port
         self.factory = TestClientFactory(self, url)
         self.handshakeErrorDeferred.addCallback(self.got_response)
+
         point = TCP4ClientEndpoint(reactor, "localhost", self.p.getHost().port)
         point.connect(self.factory).addCallback(self.got_tcp_connection)
         return self.handshakeErrorDeferred
             self.client.close()
 
         url = 'ws://localhost:%d/bar' % self.p.getHost().port
-        self.factory = TestClientFactory(self, url, subprotocolsAvailable={'roadrunner.acme.com': self.roadrunner,
-                                                                  'beepbeep.acme.com': self.beepbeep})
+        self.factory = TestClientFactory(self, url,
+                                         subprotocolsAvailable={'roadrunner.acme.com': self.roadrunner,
+                                                                'beepbeep.acme.com': self.beepbeep})
+        self.connectionEstablishedDeferred.addCallback(got_connection)
+
         point = TCP4ClientEndpoint(reactor, "localhost", self.p.getHost().port)
-        self.connectionEstablishedDeferred.addCallback(got_connection)
         point.connect(self.factory).addCallback(self.got_tcp_connection)
         return self.connectionLostDeferred.addErrback(self.clean_disconnection)
 
         url = 'ws://localhost:%d/bar' % self.p.getHost().port
         self.factory = TestClientFactory(self, url, subprotocolsAvailable={'bogus.acme.com': None,
                                                                   'bogos.acme.com': None})
+        self.handshakeErrorDeferred.addCallback(got_response)
 
-        self.handshakeErrorDeferred.addCallback(got_response)
         point = TCP4ClientEndpoint(reactor, "localhost", self.p.getHost().port)
         point.connect(self.factory).addCallback(self.got_tcp_connection)
         return self.handshakeErrorDeferred
     def beepbeep(self):
         pass
 
-    def test_AcceptExtensions(self):
+    def test_RefuseExtensions(self):
+        """
+        Test that the connection is accepted even when the extention proposed by the client is not
+        by the server, and that the extension is not selected.
+        """
         def got_connection(response):
-            self.assertEqual(self.client.factory.extensions, [self.beepbeep])
+            self.assertEqual(self.client.factory.extensions, [])
             self.client.close()
 
         url = 'ws://localhost:%d/bar' % self.p.getHost().port
-        self.factory = TestClientFactory(self, url, extensionsAvailable={'roadrunner.acme.com': self.roadrunner,
-                                                                'beepbeep.acme.com': self.beepbeep})
+        self.factory = TestClientFactory(self, url,
+                                         extensionsAvailable=[RoadRunnerExtension()])
+        self.connectionEstablishedDeferred.addCallback(got_connection)
+
         point = TCP4ClientEndpoint(reactor, "localhost", self.p.getHost().port)
-        self.connectionEstablishedDeferred.addCallback(got_connection)
         point.connect(self.factory).addCallback(self.got_tcp_connection)
         return self.connectionLostDeferred.addErrback(self.clean_disconnection)
 
-    def test_RefuseExtensions(self):
+
+    def test_AcceptExtensions(self):
+        """
+        Test that the connection is accepted when at least one extension is accepted
+        Also, test that the parameters are duely selected.
+        """
         def got_connection(response):
-            self.assertEqual(self.client.factory.extensions, [])
-            self.client.trace_lost_connection = lambda _: None
+            self.assertEqual(len(self.client.factory.extensions), 1)
+            self.assertIsInstance(self.client.factory.extensions[0], ZipExtension)
+            self.assertEqual(self.client.factory.extensions[0].params, {'arg1': 'foo', 'arg2': 'bar'})
             self.client.close()
 
         url = 'ws://localhost:%d/bar' % self.p.getHost().port
-        self.factory = TestClientFactory(self, url, extensionsAvailable={'bogus.acme.com': lambda _:None,
-                                                                  'bogos.acme.com': lambda _:None})
+        self.factory = TestClientFactory(self, url,
+                                         extensionsAvailable=[
+                                                              RoadRunnerExtension(),
+                                                              ZipExtension(params={'arg1':'foo', 'arg2':'bar'})
+                                                              ])
+        self.connectionEstablishedDeferred.addCallback(got_connection)
+
         point = TCP4ClientEndpoint(reactor, "localhost", self.p.getHost().port)
-        self.connectionEstablishedDeferred.addCallback(got_connection)
         point.connect(self.factory).addCallback(self.got_tcp_connection)
         return self.connectionLostDeferred.addErrback(self.clean_disconnection)
 
+
+    def test_ProcessExtension(self):
+        """
+        Test the processing of the selected extensions
+        """
+        def send(_):
+            self.sent_data = self.simpleData*1000
+            self.client.write(self.sent_data)
+        url = 'ws://localhost:%d/bar' % self.p.getHost().port
+        self.factory = TestClientFactory(self, url,
+                                         extensionsAvailable=[ZipExtension()])
+        self.connectionEstablishedDeferred.addCallback(send)
+        self.frameReceivedDeferred.addCallback(self.check_response)
+
+        point = TCP4ClientEndpoint(reactor, "localhost", self.p.getHost().port)
+        point.connect(self.factory).addCallback(self.got_tcp_connection)
+        return self.connectionLostDeferred.addErrback(self.clean_disconnection)
+
 from twisted.internet.protocol import ClientFactory, Protocol
 from twisted.protocols.policies import TimeoutMixin
 from twisted.protocols.basic import _PauseableMixin
+from extension import ExtensionRefused
 
 try:
     from twisted.internet import ssl
                   ]
 
         if self.factory.subprotocolsAvailable:
-            field_list.append("Sec-WebSocket-Protocol: %s\r\n" % ",".join(self.factory.subprotocolsAvailable))
+            field_list.append("Sec-WebSocket-Protocol: %s\r\n"
+                              % ",".join(self.factory.subprotocolsAvailable))
 
         if self.factory.extensionsAvailable:
-            field_list.append("Sec-WebSocket-Extensions: %s\r\n" % ",".join(self.factory.extensionsAvailable))
+            field_list.append("Sec-WebSocket-Extensions: %s\r\n"
+                              % ",".join(ext.requestHeader
+                                         for ext in self.factory.extensionsAvailable))
 
         if self.factory.extra_headers != None:
             field_list.extend(self.factory.extra_headers)
         if opcode not in ALL_OPCODES:
             raise ValueError("Invalid opcode 0x%X" % opcode)
 
+        for ext in self.factory.extensions:
+            payload = ext.processOutgoingFrame(payload)
+
         length = len(payload)
 
         # there's always the header and at least one length field
         self.factory.extensions = []
         if "Sec-WebSocket-Extensions" in headers:
             for ext in headers["Sec-WebSocket-Extensions"].split(','):
-                if ext in self.factory.extensionsAvailable:
-                    self.factory.extensions.append(self.factory.extensionsAvailable[ext])
-                else:
-                    raise HandshakeResponseError, ("Invalid extension: %s"
-                                                   % ext)
+                ext_params = ext.split(';')
+                ext_name = ext_params.pop(0)
+                try:
+                    if ext_name in self.factory._extensionsAvailableByName:
+                        self.factory._addExtension(ext_name, ext_params)
+                    else:
+                        raise HandshakeResponseError, (
+                                        "Invalid extension: %s" % ext_name)
+                except ExtensionRefused, error:
+                    raise HandshakeResponseError, ("Extension %s does not "
+                                        "accept such parameters" % ext_name)
 
         ## Check protocols
         if "Sec-WebSocket-Protocol" in headers:
     @type extra_headers: C{dict}
     """
     protocol = WebSocketClient
+    extensions = []
 
     def __init__(self,
                  url,
         """
         If not present, origin will be set to the host
         """
-        self.setURL(url)
+        self._setURL(url)
         if origin == None:
             self.origin = self.host
         else:
             self.extra_headers = None
         else:
             self.extra_headers = ["%s: %s\r\n" % (k, v) for k, v in extra_headers.items()]
+        self._extensionsAvailableByName = dict((e.name, e) for e in self.extensionsAvailable)
 
-    def setURL(self, url):
+    def _setURL(self, url):
         self.url = url
         parsed_url = urlparse(url)
         if parsed_url.scheme and parsed_url.hostname:
                     self.port = 80
 
 
+    def _addExtension(self, name, params):
+        """
+        Add the extension which name is in parameter name (must be in
+        extensionsAvailable)
+        """
+        self._extensionsAvailableByName[name].negociate(params=dict((p.split('=')) for p in params))
+        self.extensions.append(self._extensionsAvailableByName[name])
+
+
 __all__ = ["WebSocketClient", "WebSocketClientFactory", "WebSocketClientHandler"]