Jeffrey Gelens avatar Jeffrey Gelens committed 6dac69f

Accepted Sardar Yumatov's patch to accept only predefined paths to upgrade the
socket

Comments (0)

Files changed (1)

geventwebsocket/handler.py

 import re
 import struct
-import time
-import traceback
-import sys
 from hashlib import md5
+
 from gevent.pywsgi import WSGIHandler
 from geventwebsocket import WebSocket
 
+
+class HandShakeError(ValueError):
+    """ Hand shake challenge can't be parsed """
+    pass
+
+
 class WebSocketHandler(WSGIHandler):
+    """ Automatically upgrades the connection to websockets. """
     def __init__(self, *args, **kwargs):
         self.websocket_connection = False
+        self.allowed_paths = []
+
+        for expression in kwargs.pop('allowed_paths', []):
+            if isinstance(expression, basestring):
+                self.allowed_paths.append(re.compile(expression))
+            else:
+                self.allowed_paths.append(expression)
+
         super(WebSocketHandler, self).__init__(*args, **kwargs)
 
     def handle_one_response(self, call_wsgi_app=True):
         # we will proceed with the default PyWSGI functionality.
         if self.environ.get("HTTP_CONNECTION") != "Upgrade" or \
            self.environ.get("HTTP_UPGRADE") != "WebSocket" or \
-           not self.environ.get("HTTP_ORIGIN"):
+           not self.environ.get("HTTP_ORIGIN") or \
+           not self.accept_upgrade():
             return super(WebSocketHandler, self).handle_one_response()
         else:
             self.websocket_connection = True
         else:
             return
 
+    def accept_upgrade(self):
+        """
+        Returns True if request is allowed to be upgraded.
+        If self.allowed_paths is non-empty, self.environ['PATH_INFO'] will
+        be matched against each of the regular expressions.
+        """
+
+        if self.allowed_paths:
+            path_info = self.environ.get('PATH_INFO', '')
+
+            for regexps in self.allowed_paths:
+                return regexps.match(path_info)
+        else:
+            return True
+
     def write(self, data):
         if self.websocket_connection:
             self.wfile.writelines(data)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.