Commits

Robert Brewer  committed 798e0e0

Fix for #783 (File uploads corrupt when using built in SSL).

  • Participants
  • Parent commits 358bb51
  • Branches 598-sendall

Comments (0)

Files changed (2)

File cherrypy/test/test_http.py

         index.exposed = True
         
         def post_multipart(self, file):
-            """Return a summary ("a * 1000000\nb * 1000000") of the uploaded file."""
+            """Return a summary ("a * 65536\nb * 65536") of the uploaded file."""
             contents = file.file.read()
             summary = []
             curchar = ""
         # By not including a Content-Length header, cgi.FieldStorage
         # will hang. Verify that CP times out the socket and responds
         # with 411 Length Required.
-        c = httplib.HTTPConnection("127.0.0.1:%s" % self.PORT)
+        if self.scheme == "https":
+            c = httplib.HTTPSConnection("127.0.0.1:%s" % self.PORT)
+        else:
+            c = httplib.HTTPConnection("127.0.0.1:%s" % self.PORT)
         c.request("POST", "/")
         self.assertEqual(c.getresponse().status, 411)
     
     def test_post_multipart(self):
         alphabet = "abcdefghijklmnopqrstuvwxyz"
         # generate file contents for a large post
-        contents = "".join([c * 1000000 for c in alphabet])
+        contents = "".join([c * 65536 for c in alphabet])
         
         # encode as multipart form data
         files=[('file', 'file.txt', contents)]
         self.assertEqual(errcode, 200)
         
         response_body = c.file.read()
-        self.assertEquals(", ".join(["%s * 1000000" % c for c in alphabet]),
+        self.assertEquals(", ".join(["%s * 65536" % c for c in alphabet]),
                           response_body)
 
 

File cherrypy/wsgiserver/__init__.py

     
     def readline(self, size=None):
         if size is not None:
-            local_bytes_seen = 0
-            seen_data = []
-            while local_bytes_seen < size:
-                data = self.rfile.readline(size-local_bytes_seen)
-                if not data:
-                    break
-                seen_data.append(data)
-                local_bytes_seen += len(data)
-                if '\n' in data:
-                    break
-
-            self.bytes_read += local_bytes_seen
+            data = self.rfile.readline(size)
+            self.bytes_read += len(data)
             self._check_length()
-            return "".join(seen_data)
+            return data
         
         # User didn't specify a size ...
         # We read the line in chunks to make sure it's not a 100MB line !
         self._check_length()
         return data
 
-class HTTPRequestSocketWrapper(object):
-    """
-    A file like wrapper for HTTP on non-blocking sockets.
-
-    IOW this class provides as much as it can of the file like 
-    interface for sockets over which HTTP requests are made.
-    """
-    def __init__(self, sock):
-        self.sock = sock
-        self.incomplete_line_buffer = []
-        self.lines_buffer = []
-    
-    def read(self, size=None):
-        raise NotImplementedError
-    
-    def readline(self):
-        # This doesn't raise the appropriate exceptions
-        # as the result may result in an Index Error?
-
-        # if we can't return their data right away, let's try to read (blocking)
-        if not self.lines_buffer:
-            self._fill_lines_buffer()
-
-        if not self.lines_buffer:
-            return ""
-        
-        return self.lines_buffer.pop(0)
-
-            
-    def _fill_lines_buffer(self):
-        while True:
-            data = self.sock.recv(256)
-
-            lines = data.split("\n")
-            # We remove the last piece of the split to ensure
-            # that subsequent processing happens only on data
-            # Representing complete lines.
-            new_incomplete_line = lines.pop()
-            
-            # If we still have data in the lines list that means
-            # we have received some data that forms a complete line
-            if lines:
-                # Ensure that we take the data left over from previous reads
-                # and join it to our current reads before appending the latest
-                # line seen
-                self.incomplete_line_buffer.append(lines.pop(0))
-                self.lines_buffer.append("".join(self.incomplete_line_buffer))
-
-                self.incomplete_line_buffer = []
-
-                # remember all other complete lines seen in this read
-                self.lines_buffer.extend(lines)
-               
-            # Record the latest new incomplete line
-            self.incomplete_line_buffer.append(new_incomplete_line)
-
-            # If they didn't specify a size and we have a line to send them
-            # stop reading
-            if self.lines_buffer:
-                return
-    
-    def readlines(self, sizehint=0):
-        raise NotImplementedError
-    
-    def close(self):
-        self.sock.close()
-    
-    def __iter__(self):
-        return self
-    
-    def next(self):
-        data = self.sock.next()
-        self.bytes_read += len(data)
-        return data
-
-    def send(self, *args, **kwargs):
-        return self.sock.send(*args, **kwargs)
-
-    def readline(self, size=None):
-        raise NotImplementedError
-
 
 class HTTPRequest(object):
     """An HTTP Request (and response).
     pass
 
 
-def _ssl_wrap_method(method, is_reader=False):
-    """Wrap the given method with SSL error-trapping.
+class SSL_fileobject(socket._fileobject):
+    """Faux file object attached to a socket object."""
     
-    is_reader: if False (the default), EOF errors will be raised.
-        If True, EOF errors will return "" (to emulate normal sockets).
-    """
-    def ssl_method_wrapper(self, *args, **kwargs):
-##        print (id(self), method, args, kwargs)
+    ssl_timeout = 3
+    ssl_retry = .01
+    
+    def _safe_call(self, is_reader, call, *args, **kwargs):
+        """Wrap the given call with SSL error-trapping.
+        
+        is_reader: if False EOF errors will be raised. If True, EOF errors
+            will return "" (to emulate normal sockets).
+        """
         start = time.time()
         while True:
             try:
-                return method(self, *args, **kwargs)
+                return call(*args, **kwargs)
             except (SSL.WantReadError, SSL.WantWriteError):
                 # Sleep and try again. This is dangerous, because it means
                 # the rest of the stack has no way of differentiating
                 raise
             if time.time() - start > self.ssl_timeout:
                 raise socket.timeout("timed out")
-    return ssl_method_wrapper
+    
+    def flush(self):
+        if self._wbuf:
+            buffer = "".join(self._wbuf)
+            self._wbuf = []
+            self._safe_call(False, self._sock.sendall, buffer)
+    
+    def read(self, size=-1):
+        data = self._rbuf
+        if size < 0:
+            # Read until EOF
+            buffers = []
+            if data:
+                buffers.append(data)
+            self._rbuf = ""
+            if self._rbufsize <= 1:
+                recv_size = self.default_bufsize
+            else:
+                recv_size = self._rbufsize
+            
+            while True:
+                data = self._safe_call(True, self._sock.recv, recv_size)
+                if not data:
+                    break
+                buffers.append(data)
+            return "".join(buffers)
+        else:
+            # Read until size bytes or EOF seen, whichever comes first
+            buf_len = len(data)
+            if buf_len >= size:
+                self._rbuf = data[size:]
+                return data[:size]
+            buffers = []
+            if data:
+                buffers.append(data)
+            self._rbuf = ""
+            while True:
+                left = size - buf_len
+                recv_size = max(self._rbufsize, left)
+                data = self._safe_call(True, self._sock.recv, recv_size)
+                if not data:
+                    break
+                buffers.append(data)
+                n = len(data)
+                if n >= left:
+                    self._rbuf = data[left:]
+                    buffers[-1] = data[:left]
+                    break
+                buf_len += n
+            return "".join(buffers)
 
-class SSL_fileobject(socket._fileobject):
-    """Faux file object attached to a socket object."""
-    
-    ssl_timeout = 3
-    ssl_retry = .01
-    
-    close = _ssl_wrap_method(socket._fileobject.close)
-    flush = _ssl_wrap_method(socket._fileobject.flush)
-    write = _ssl_wrap_method(socket._fileobject.write)
-    writelines = _ssl_wrap_method(socket._fileobject.writelines)
-    read = _ssl_wrap_method(socket._fileobject.read, is_reader=True)
-    readline = _ssl_wrap_method(socket._fileobject.readline, is_reader=True)
-    readlines = _ssl_wrap_method(socket._fileobject.readlines, is_reader=True)
+    def readline(self, size=-1):
+        data = self._rbuf
+        if size < 0:
+            # Read until \n or EOF, whichever comes first
+            if self._rbufsize <= 1:
+                # Speed up unbuffered case
+                assert data == ""
+                buffers = []
+                while data != "\n":
+                    data = self._safe_call(True, self._sock.recv, 1)
+                    if not data:
+                        break
+                    buffers.append(data)
+                return "".join(buffers)
+            nl = data.find('\n')
+            if nl >= 0:
+                nl += 1
+                self._rbuf = data[nl:]
+                return data[:nl]
+            buffers = []
+            if data:
+                buffers.append(data)
+            self._rbuf = ""
+            while True:
+                data = self._safe_call(True, self._sock.recv, self._rbufsize)
+                if not data:
+                    break
+                buffers.append(data)
+                nl = data.find('\n')
+                if nl >= 0:
+                    nl += 1
+                    self._rbuf = data[nl:]
+                    buffers[-1] = data[:nl]
+                    break
+            return "".join(buffers)
+        else:
+            # Read until size bytes or \n or EOF seen, whichever comes first
+            nl = data.find('\n', 0, size)
+            if nl >= 0:
+                nl += 1
+                self._rbuf = data[nl:]
+                return data[:nl]
+            buf_len = len(data)
+            if buf_len >= size:
+                self._rbuf = data[size:]
+                return data[:size]
+            buffers = []
+            if data:
+                buffers.append(data)
+            self._rbuf = ""
+            while True:
+                data = self._safe_call(True, self._sock.recv, self._rbufsize)
+                if not data:
+                    break
+                buffers.append(data)
+                left = size - buf_len
+                nl = data.find('\n', 0, left)
+                if nl >= 0:
+                    nl += 1
+                    self._rbuf = data[nl:]
+                    buffers[-1] = data[:nl]
+                    break
+                n = len(data)
+                if n >= left:
+                    self._rbuf = data[left:]
+                    buffers[-1] = data[:left]
+                    break
+                buf_len += n
+            return "".join(buffers)
     
     def send(self, *args, **kwargs):
-        return self._sock.send(*args, **kwargs)
-    send = _ssl_wrap_method(send)
+        return self._safe_call(False, self._sock.send, *args, **kwargs)
+
 
 class HTTPConnection(object):
     """An HTTP connection (active socket).
             errnum = e.args[0]
             if errnum not in socket_errors_to_ignore:
                 if req:
-                    fd = open ("ssl_errors.txt", "a")
+                    fd = open("ssl_errors.txt", "a")
                     fd.write("1" * 80)
                     fd.write("\n")
                     fd.write(str(type(e)))
-                    fd.write( format_exc())
+                    fd.write(format_exc())
                     req.simple_response("500 Internal Server Error",
                                         format_exc())
             return
                                 "this server only speaks HTTPS on this port.")
         except Exception, e:
             if req:
-                fd = open ("ssl_errors.txt", "a")
-                fd.write("2" * 80 )
+                fd = open("ssl_errors.txt", "a")
+                fd.write("2" * 80)
                 fd.write("\n")
-                fd.write(str(type(e)))
-                fd.write( format_exc())
+                fd.write(format_exc())
                 req.simple_response("500 Internal Server Error", format_exc())
     
     def close(self):