Source

CherryPy / cherrypy / wsgiserver / ssl_pyopenssl.py

Diff from to

cherrypy/wsgiserver/ssl_pyopenssl.py

 import threading
 import time
 
-from cherrypy import wsgiserver
+from cherrypy import wsgiserver, config
 
 try:
     from OpenSSL import SSL
     This is needed for cheaper "chained root" SSL certificates, and should be
     left as None if not required."""
 
-    def __init__(self, certificate, private_key, certificate_chain=None):
+    def __init__(self, certificate, private_key, certificate_chain=None,
+                 client_CA=None):
         if SSL is None:
             raise ImportError("You must install pyOpenSSL to use HTTPS.")
 
         self.certificate = certificate
         self.private_key = private_key
         self.certificate_chain = certificate_chain
+        self.client_CA = client_CA or config.get("server.ssl_client_CA")
         self._environ = None
 
+        self.check_host = config.get("server.ssl_client_check_host", False)
+        check = config.get("server.ssl_client_check", "ignore")
+        if check == "ignore":
+            self.check = SSL.VERIFY_NONE
+        elif check == "optional":
+            self.check = SSL.VERIFY_PEER
+        elif check == "required":
+            self.check = SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT
+        else:
+            raise ValueError("server.ssl_client_check must be one of 'ignore',"
+                             "'optional','required'")
+
     def bind(self, sock):
         """Wrap and return the given socket."""
         if self.context is None:
         if self.certificate_chain:
             c.load_verify_locations(self.certificate_chain)
         c.use_certificate_file(self.certificate)
+
+        if self.client_CA:
+            c.load_client_ca(self.client_CA)
+
+            c.set_verify_depth(2)
+            c.load_verify_locations(self.client_CA)
+
+            def callback(conn, cert, errno, depth, retcode):
+                if retcode and depth < 1 and self.check_host:
+                    try:
+                        assert self.address_matches(conn.getpeername(),
+                                                    cert.get_subject().commonName)
+                    except:
+                        return False
+                return retcode
+
+            c.set_verify(self.check, callback)
         return c
 
     def get_environ(self):