# Source: wsgi3k / byte_server.py

#!/usr/bin/env python3
"""
    WSGI for Python 3 Server
    ~~~~~~~~~~~~~~~~~~~~~~~~

    This module implements a byte-based server for Python 3.

    :copyright: (c) 2009 by Armin Ronacher.
    :license: MIT
"""
import os
import socket
import sys
import time
import traceback
from urllib.parse import urlparse, unquote_to_bytes
from itertools import chain
from socketserver import ThreadingMixIn, ForkingMixIn
from http.server import HTTPServer, BaseHTTPRequestHandler


def internal_server_error(environ, start_response):
    """WSGI application that answers any request with a bare 500 page.

    The request handler falls back to this application when the real one
    raises.  Status, header names/values and the body are all ``bytes``.
    """
    body = b'<h1>Internal Server Error</h1>'
    length = str(len(body)).encode('ascii')
    headers = [
        (b'Content-Type', b'text/html; charset=utf-8'),
        (b'Content-Length', length),
    ]
    start_response(b'500 INTERNAL SERVER ERROR', headers)
    return [body]


class WSGIRequestHandler(BaseHTTPRequestHandler):
    """A request handler that implements WSGI dispatching.

    This handler deliberately diverges from PEP 3333's native strings.  The
    environ it builds carries ``bytes`` for the CGI-style values.
    Applications must pass ``bytes`` for the status line, for header
    names/values, and for every body chunk.
    """

    def make_environ(self):
        """Build the WSGI environ dict for the current request.

        Request-line and header text is re-encoded with ISO-8859-1, which
        maps every code point 0-255 to one byte, so the values round-trip.
        """
        path = self.path.encode('iso-8859-1')
        path_info, _, query = path.partition(b'?')
        server_host, server_port = self.server.server_address[:2]
        environ = {
            'wsgi.version':         (1, 1),
            'wsgi.url_scheme':      b'http',  # bytes for consistency
            'wsgi.input':           self.rfile,
            'wsgi.errors':          sys.stderr,
            'wsgi.multithread':     self.server.multithread,
            'wsgi.multiprocess':    self.server.multiprocess,
            'wsgi.run_once':        False,
            'SERVER_SOFTWARE':      self.server_version.encode('ascii'),
            'REQUEST_METHOD':       self.command.encode('iso-8859-1'),
            'SCRIPT_NAME':          b'',
            'PATH_INFO':            unquote_to_bytes(path_info),
            'QUERY_STRING':         query,
            'CONTENT_TYPE':         self.headers.get('Content-Type', '')
                                        .encode('iso-8859-1'),
            'CONTENT_LENGTH':       self.headers.get('Content-Length', '')
                                        .encode('iso-8859-1'),
            'REMOTE_ADDR':          self.client_address[0].encode('ascii'),
            'REMOTE_PORT':          str(self.client_address[1]).encode('ascii'),
            # Encoded like every other CGI variable in this environ.
            'SERVER_NAME':          server_host.encode('iso-8859-1'),
            'SERVER_PORT':          str(server_port).encode('ascii'),
            'SERVER_PROTOCOL':      self.request_version.encode('iso-8859-1'),
        }

        for name, value in self.headers.items():
            key = 'HTTP_' + name.upper().replace('-', '_')
            # Content-Type / Content-Length already have dedicated keys.
            if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
                environ[key] = value.encode('iso-8859-1')

        return environ

    def run_wsgi(self):
        """Run ``self.server.app`` for this request and stream its response.

        Connection errors go to :meth:`connection_dropped`.  Any other
        application error is re-raised when ``server.passthrough_errors`` is
        set.  Otherwise it is logged to stderr and, if possible, answered
        with :func:`internal_server_error`.
        """
        app = self.server.app
        environ = self.make_environ()
        headers_set = []    # [status, headers] from the latest start_response
        headers_sent = []   # the same pair, once written to the socket

        def write(data):
            if not headers_set:
                raise AssertionError('write() before start_response')
            # Validate before anything hits the wire so a 500 page can
            # still be sent when the application yields the wrong type.
            if not isinstance(data, bytes):
                raise AssertionError('applications must write bytes')
            if not headers_sent:
                status, response_headers = headers_sent[:] = headers_set
                code, msg = status.split(None, 1)
                self.send_response(int(code), msg)
                header_keys = set()
                for key, value in response_headers:
                    name = key.decode('iso-8859-1')
                    self.send_header(name, value.decode('iso-8859-1'))
                    # Store as lower-case str: the checks below use str keys.
                    header_keys.add(name.lower())
                if 'content-length' not in header_keys:
                    # Without a length the body can only be framed by EOF.
                    self.close_connection = True
                    self.send_header('Connection', 'close')
                if 'server' not in header_keys:
                    self.send_header('Server', self.version_string())
                if 'date' not in header_keys:
                    self.send_header('Date', self.date_time_string())
                self.end_headers()

            self.wfile.write(data)
            self.wfile.flush()

        def start_response(status, response_headers, exc_info=None):
            if exc_info:
                try:
                    if headers_sent:
                        # Too late to replace the response: re-raise the
                        # application's original error.
                        raise exc_info[1].with_traceback(exc_info[2])
                finally:
                    exc_info = None  # drop the traceback reference
            elif headers_set:
                raise AssertionError('Headers already set')
            headers_set[:] = [status, response_headers]
            return write

        def execute(app):
            application_iter = app(environ, start_response)
            try:
                for data in application_iter:
                    write(data)
                # Make sure the headers are sent even for an empty body.
                if not headers_sent:
                    write(b'')
            finally:
                if hasattr(application_iter, 'close'):
                    application_iter.close()

        try:
            execute(app)
        except (ConnectionError, socket.timeout) as e:
            self.connection_dropped(e, environ)
        except Exception:
            if self.server.passthrough_errors:
                raise
            exc = traceback.format_exc()
            try:
                # If the headers were set but not yet sent, roll them back
                # so the error page can set its own.
                if not headers_sent:
                    del headers_set[:]
                execute(internal_server_error)
            except Exception:
                # Best effort only: part of a response may already be out.
                pass
            print('Error on request:\n%s' % exc, file=sys.stderr)

    def handle(self):
        """Handles a request ignoring dropped connections."""
        try:
            return BaseHTTPRequestHandler.handle(self)
        except (ConnectionError, socket.timeout) as e:
            self.connection_dropped(e)

    def connection_dropped(self, error, environ=None):
        """Called if the connection was closed by the client.  By default
        nothing happens.
        """

    def handle_one_request(self):
        """Read a single HTTP request line and dispatch it to WSGI."""
        # Bound the request line (same 64 KiB limit CPython's http.server
        # uses) so a client cannot make us buffer without limit.
        self.raw_requestline = self.rfile.readline(65537)
        if len(self.raw_requestline) > 65536:
            self.requestline = ''
            self.request_version = ''
            self.command = ''
            self.send_error(414)
            return
        if not self.raw_requestline:
            self.close_connection = True
        elif self.parse_request():
            return self.run_wsgi()

    def send_response(self, code, message=None):
        """Log the request and write the HTTP status line.

        ``message`` may be ``bytes`` (the application's reason phrase) or
        ``str`` (as the base class may pass).  It defaults to the standard
        reason phrase for ``code``.  Nothing is written for HTTP/0.9.
        """
        self.log_request(code)
        if message is None:
            message = self.responses[code][0] if code in self.responses else ''
        if isinstance(message, str):
            message = message.encode('iso-8859-1', 'replace')
        if self.request_version != 'HTTP/0.9':
            # int() so enum-like status codes render as plain digits.
            status_line = b' '.join((
                self.protocol_version.encode('ascii'),
                str(int(code)).encode('ascii'),
                message,
            ))
            self.wfile.write(status_line + b'\r\n')

    def version_string(self):
        return BaseHTTPRequestHandler.version_string(self).strip()

    def address_string(self):
        return self.client_address[0]


#: backwards compatible name if someone is subclassing it
BaseRequestHandler = WSGIRequestHandler


class BaseWSGIServer(HTTPServer):
    """Simple single-threaded, single-process WSGI server.

    Each request is handled by ``handler`` (``WSGIRequestHandler`` by
    default), which calls ``app``.
    """
    # Exposed to applications as wsgi.multithread / wsgi.multiprocess.
    multithread = False
    multiprocess = False

    def __init__(self, host, port, app, handler=None,
                 passthrough_errors=False):
        if handler is None:
            handler = WSGIRequestHandler
        HTTPServer.__init__(self, (host, int(port)), handler)
        self.app = app
        # When true, application errors propagate instead of being logged
        # and answered with a 500 page.
        self.passthrough_errors = passthrough_errors

    def serve_forever(self, poll_interval=0.5):
        """Serve requests until interrupted; Ctrl-C returns quietly.

        :param poll_interval: forwarded to ``HTTPServer.serve_forever``;
            seconds between checks for a shutdown request.
        """
        try:
            HTTPServer.serve_forever(self, poll_interval)
        except KeyboardInterrupt:
            pass

    def handle_error(self, request, client_address):
        """Re-raise the current error if ``passthrough_errors`` is set.

        The bare ``raise`` presumably relies on being called while the
        failing request's exception is still being handled (as the
        socketserver machinery does).  Otherwise the default handler is used.
        """
        if self.passthrough_errors:
            raise
        return HTTPServer.handle_error(self, request, client_address)


class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer):
    """WSGI server that serves each request on its own thread."""

    # Reported to applications through the environ.
    multithread = True


class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer):
    """WSGI server that serves each request in a forked child process.

    ``processes`` is stored as ``max_children``, the attribute
    ``ForkingMixIn`` uses to cap concurrently running children.
    """

    multiprocess = True

    def __init__(self, host, port, app, processes=40, handler=None,
                 passthrough_errors=False):
        BaseWSGIServer.__init__(self, host, port, app,
                                handler=handler,
                                passthrough_errors=passthrough_errors)
        self.max_children = processes


def make_server(host, port, app=None, threaded=False, processes=1,
                request_handler=None, passthrough_errors=False):
    """Construct a WSGI server for ``app`` listening on ``host``:``port``.

    ``threaded`` selects a thread-per-request server and ``processes > 1``
    a forking one.  Otherwise requests are served one after another.
    Requesting both threading and multiple processes raises ``ValueError``.
    """
    wants_forking = processes > 1
    if threaded and wants_forking:
        raise ValueError("cannot have a multithreaded and "
                         "multi process server.")
    if threaded:
        return ThreadedWSGIServer(host, port, app, request_handler,
                                  passthrough_errors)
    if wants_forking:
        return ForkingWSGIServer(host, port, app, processes, request_handler,
                                 passthrough_errors)
    return BaseWSGIServer(host, port, app, request_handler,
                          passthrough_errors)


def run_simple(hostname, port, application, threaded=False,
               processes=1, request_handler=None, passthrough_errors=False):
    """Serve ``application`` on ``hostname``:``port`` until interrupted.

    Prints the URL being served, runs ``serve_forever`` (which returns on
    Ctrl-C for the servers in this module), then closes the listening socket.
    """
    srv = make_server(hostname, port, application, threaded,
                      processes, request_handler,
                      passthrough_errors)
    display_hostname = hostname or '127.0.0.1'
    # Report the port actually bound: ``port`` may be a str (the servers
    # accept anything int() takes) or 0 (an ephemeral port).
    bound_port = srv.server_address[1]
    print(' * Running on http://%s:%d/' % (display_hostname, bound_port))
    try:
        srv.serve_forever()
    finally:
        srv.server_close()


if __name__ == '__main__':
    def application(environ, start_response):
        """Demo app: echo the environ back as plain text, one pair per line."""
        start_response(b'200 OK', [(b'Content-Type', b'text/plain')])
        lines = ['%s=%s' % pair for pair in environ.items()]
        return ['\n'.join(lines).encode('utf-8')]

    run_simple('localhost', 3000, application)