Commits

Robert Brewer committed 0d062a1

Fix and tests for http.urljoin. Did not correctly handle blank PATH_INFO's.

Comments (0)

Files changed (2)

cherrypy/lib/http.py

 
 
 def urljoin(*atoms):
-    """Return the given path *atoms, joined into a single URL."""
-    url = "/".join(atoms)
+    """Return the given path *atoms, joined into a single URL.
+    
+    This will correctly join a SCRIPT_NAME and PATH_INFO into the
+    original URL, even if either atom is blank.
+    """
+    url = "/".join([x for x in atoms if x])
     while "//" in url:
         url = url.replace("//", "/")
-    return url
+    # Special-case the final url of "", and return "/" instead.
+    return url or "/"
 
 def protocol_from_http(protocol_str):
     """Return a protocol tuple from the given 'HTTP/x.y' string."""

cherrypy/test/test_httplib.py

+"""Tests for cherrypy/lib/http.py."""
+
+from cherrypy.test import test
+test.prefer_parent_path()
+
+import unittest
+from cherrypy.lib import http
+
+
+class UtilityTests(unittest.TestCase):
+    
+    def test_urljoin(self):
+        # Test all slash+atom combinations for SCRIPT_NAME and PATH_INFO
+        self.assertEqual(http.urljoin("/sn/", "/pi/"), "/sn/pi/")
+        self.assertEqual(http.urljoin("/sn/", "/pi"), "/sn/pi")
+        self.assertEqual(http.urljoin("/sn/", "/"), "/sn/")
+        self.assertEqual(http.urljoin("/sn/", ""), "/sn/")
+        self.assertEqual(http.urljoin("/sn", "/pi/"), "/sn/pi/")
+        self.assertEqual(http.urljoin("/sn", "/pi"), "/sn/pi")
+        self.assertEqual(http.urljoin("/sn", "/"), "/sn/")
+        self.assertEqual(http.urljoin("/sn", ""), "/sn")
+        self.assertEqual(http.urljoin("/", "/pi/"), "/pi/")
+        self.assertEqual(http.urljoin("/", "/pi"), "/pi")
+        self.assertEqual(http.urljoin("/", "/"), "/")
+        self.assertEqual(http.urljoin("/", ""), "/")
+        self.assertEqual(http.urljoin("", "/pi/"), "/pi/")
+        self.assertEqual(http.urljoin("", "/pi"), "/pi")
+        self.assertEqual(http.urljoin("", "/"), "/")
+        self.assertEqual(http.urljoin("", ""), "/")
+
+if __name__ == '__main__':
+    unittest.main()