Commits

Hong Minhee committed 3ff6dd0

Added `vlastic.rom.decorator.NegotiativeView`, a higher-order view functor class.

Comments (0)

Files changed (5)

test/http_test.py

         msg = Message(self.start_line, dict(self.headers), self.body)
         self.assertEquals(self.headers, msg.headers)
 
+    def test_init_with_str_headers(self):
+        msg = Message(self.start_line, [("Accept", "text/plain")])
+        self.assertEquals([(b"Accept", b"text/plain")], msg.headers)
+
     def test_init_with_start_line_tuple(self):
         msg = Message(("GET", "/abc/def", "HTTP/1.1"), self.headers, self.body)
         self.assertEquals(self.start_line, msg.start_line)
         )
         self.assertEquals(b"123", b"".join(msg.body))
 
+    def test_contains_bytes(self):
+        self.assertTrue(b"Content-Length" in self.msg)
+        self.assertTrue(b"content-LENGTH" in self.msg)
+        self.assertFalse(b"Connection" in self.msg)
+
+    def test_contains_str(self):
+        self.assertTrue("Content-Length" in self.msg)
+        self.assertTrue("content-LENGTH" in self.msg)
+        self.assertFalse("Connection" in self.msg)
+
+    def test_contains_typeerror(self):
+        self.assertRaises(TypeError, lambda: 1 in self.msg)
+
     def test_getitem_bytes(self):
         self.assertEquals(b"128", self.msg[b"Content-Length"])
         self.assertEquals(b"128", self.msg[b"content-LENGTH"])
         self.assertEquals("/", str(req.path))
         self.assertEquals(self.headers, req.headers)
 
+    def test_accept(self):
+        req = Request(self.start_line, self.headers)
+        self.assertTrue(isinstance(req.accept, NegotiativeHeader))
+        self.assertTrue("image/gif" in req.accept)
+        self.assertTrue("image/jpeg" in req.accept)
+        self.assertTrue("image/pjpeg" in req.accept)
+
+    def test_accept_none(self):
+        req = Request(self.start_line, {})
+        self.assertEquals(id(None), id(req.accept))
+
 
 class ResponseTest(unittest.TestCase):
     """Unit tests for Response class."""
         self.assertEquals("Not Found", res.reason_phrase)
 
     def test_init_each(self):
-        res = Response((b"HTTP/1.1", 200, b"OK"), b"abcde")
+        res = Response((b"HTTP/1.1", 200, b"OK"), {}, b"abcde")
         self.assertEquals(200, res.status_code)
         self.assertEquals("OK", res.reason_phrase)
 
         self.res = MethodNotAllowedError()
 
 
+class NotAcceptableErrorTest(ErrorResponseTest):
+    """Unit tests for NotAcceptableError (406) class."""
+
+    def setUp(self):
+        self.status_code, self.reason_phrase = 406, "Not Acceptable"
+        self.res = NotAcceptableError()
+
 
 class NegotiativeHeaderTest(unittest.TestCase):
     """Unit tests for NegotiativeHeader class."""
         response = partial(self.root.__getitem__, 1)
         self.assertRaises(KeyError, response)
 
+
+class NegotiativeViewTest(unittest.TestCase):
+    """Unit tests for NegotiativeView class."""
+
+    @staticmethod
+    def _html_view(context):
+        return Response(
+            ("HTTP/1.1", 200, "OK"),
+            {"Content-Type": "text/html"},
+            "<pre>{0[message]}</pre>".format(context)
+        )
+
+    @staticmethod
+    def _json_view(context):
+        import json
+        return Response(
+            ("HTTP/1.1", 200, "OK"),
+            {"Content-Type": "application/json"},
+            json.dumps(context)
+        )
+
+    def setUp(self):
+        self.view = NegotiativeView({
+            "text/plain": test_view,
+            "text/html": self._html_view,
+            "application/json": self._json_view
+        })
+
+    def assert_view(self, expected_body, mime_type, expected_mime_type=None):
+        request = Request(
+            ("GET", "/", "HTTP/1.1"),
+            {"Accept": mime_type} if mime_type else {}
+        )
+        response = self.view({"message": "Content negotiation test"}, request)
+        self.assertEquals(
+            expected_mime_type or mime_type,
+            response["Content-Type"]
+        )
+        self.assertEquals(expected_body, response.body)
+
+    def test_negotiation(self):
+        self.assert_view("Content negotiation test", "text/plain")
+        self.assert_view("<pre>Content negotiation test</pre>", "text/html")
+        self.assert_view(
+            '{"message": "Content negotiation test"}',
+            "application/json"
+        )
+
+    def test_negotiation_best_match(self):
+        self.assert_view(
+            '{"message": "Content negotiation test"}',
+            "text/plain; q=0.5, application/json; q=1.0",
+            "application/json"
+        )
+
+    def test_default(self):
+        self.assert_view(
+            "<pre>Content negotiation test</pre>",
+            None, "text/html"
+        )
+
+    def test_not_acceptable(self):
+        error = self.view(
+            {},
+            Request(("GET", "/", "HTTP/1.1"), {"Accept": "image/png"})
+        )
+        self.assertEquals(406, error.status_code)
+

test/wsgi_test.py

     def test_start_response(self):
         self.assertEquals("200 OK", self.start_response.status)
         self.assertEquals(
-            [("Content-Type", "text/plain")],
+            [(b"Content-Type", b"text/plain")],
             self.start_response.response_headers
         )
 
 import re
 
 __all__ = ["Message", "RequestPath", "Request", "Response", "ErrorResponse",
-           "NotFoundError", "MethodNotAllowedError", "NegotiativeHeader",
-           "NegotiativeLanguageHeader"]
+           "NotFoundError", "MethodNotAllowedError", "NotAcceptableError",
+           "NegotiativeHeader", "NegotiativeLanguageHeader"]
 
 
 class Message:
                 for m in self.HEADER_PATTERN.finditer(match.group("headers"))
             ]
         def each(start_line, headers, body=b"", encoding=None):
+            self.encoding = encoding or getdefaultencoding()
             if not isinstance(start_line, bytes):
-                encoding = encoding or getdefaultencoding()
                 start_line = b" ".join(
-                    prt if isinstance(prt, bytes) else str(prt).encode(encoding)
+                    prt if isinstance(prt, bytes)
+                        else str(prt).encode(self.encoding)
                     for prt in start_line
                 )
-            self.encoding = encoding
             self.start_line = start_line
-            self.headers = list(headers.items() if isinstance(headers, dict)
-                                                else headers)
+            headers = (headers.items() if isinstance(headers, dict)
+                                       else headers)
+            def encode(s):
+                return s.encode(self.encoding) if isinstance(s, str) else s
+            self.headers = list((encode(name), encode(content))
+                                for name, content in headers)
             self.body = (body,) if isinstance(body, bytes) else body
         try:
             full_message(*args, **kwargs)
         self.encoding = self.encoding or getdefaultencoding()
 
     def _case_insensitive_re(self, name):
+        if not isinstance(name, (bytes, str)):
+            name_type = type(name).__name__
+            raise TypeError("expected bytes or a str, but %s given" % name_type)
         if isinstance(name, str):
             name = name.encode(self.encoding)
         return re.compile(b"^" + re.escape(name) + b"$", re.IGNORECASE)
 
+    def __contains__(self, name):
+        """Returns True if it contains the header. Case-insensitive."""
+        pattern = self._case_insensitive_re(name)
+        return any(pattern.match(key) for key, _ in self.headers)
+
     def __getitem__(self, name):
         """Get a header value by the name.
 
             >>> msg["content-length"]
             "123"
 
+        Raises KeyError when the header is not there.
+
+            >>> msg["Connection"]
+            Traceback (most recent call last):
+              File "<stdin>", line 1, in <module>
+            KeyError: there is no header 'Connection'
+
         """
-        if not isinstance(name, (bytes, str)):
-            name_type = type(name).__name__
-            raise TypeError("expected bytes or a str, but %s given" % name_type)
         pattern = self._case_insensitive_re(name)
         for field, value in self.headers:
             if pattern.match(field):
 
     def __setitem__(self, name, value):
         """Set a header. It accepts a str or bytes for a name and a value."""
+        pattern = self._case_insensitive_re(name)
         if isinstance(name, str):
             name = name.encode(self.encoding)
-        elif not isinstance(name, bytes):
-            raise TypeError("name must be a str or bytes")
         if isinstance(value, str):
             value = value.encode(self.encoding)
         elif not isinstance(value, bytes):
             raise TypeError("value must be a str or bytes")
-        pattern = self._case_insensitive_re(name)
         for i, (field, _) in enumerate(self.headers):
             if pattern.match(field):
                 self.headers[i] = (name, value)
             self.start_line.decode().split()
         self.path = RequestPath(full_path)
 
+    @property
+    def accept(self):
+        """Returns a NegotiativeHeader instance of 'Accept' header.
+
+        If 'Accept' header does not exist, it returns None.
+
+        """
+        try:
+            return NegotiativeHeader(self[b"Accept"], self.encoding)
+        except KeyError:
+            pass
+
 
 class Response(Message):
     """HTTP responses [1] class.
         ErrorResponse.__init__(self, 405, b"Method Not Allowed", headers)
 
 
-class NegotiativeHeader():
+class NotAcceptableError(ErrorResponse):
+    """406 Not Acceptable response class."""
+
+    def __init__(self, headers={}):
+        ErrorResponse.__init__(self, 406, b"Not Acceptable", headers)
+
+
+class NegotiativeHeader:
     """HTTP content negotiation [1] class. It used by Accept, Accept-Charset,
     Accept-Encoding, Accept-Ranges. [2]
 
     http://code.activestate.com/recipes/576653/
 
     """
-
     class K:
         def __init__(self, obj):
             self.obj = obj
 class LanguageComparator:
     @staticmethod
     def unpack_lang_country(language):
-        """ "en-US" -> ("en", "US"), "en" -> ("en", None) """
+        """'en-US' -> ('en', 'US'), 'en' -> ('en', '')"""
         lang_country = language.split("-")
         lang = lang_country[0]
         if len(lang_country) == 2:
             self.c_country_list.append(client[1])
 
     def __call__(self, a, b):
-        """server character language comparing function"""
+        """Compares server character languages."""
         a_lang, a_country = self.unpack_lang_country(a)
         b_lang, b_country = self.unpack_lang_country(b)
         a_lang_indexes = self.lang_indexes(self.c_lang_list, a_lang)
         return result
 
 
-class NegotiativeLanguageHeader():
+class NegotiativeLanguageHeader:
     """HTTP Accept-Language negotiation [1] class.
 
     [1] http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4
         lang_cmp = LanguageComparator(sorted_lang)
         server_lang.sort(key=CmpToKey(lang_cmp))
         return server_lang
+

vlastic/rom/decorator.py

 """Some utility decorators for resource-object mapping."""
 
+from .. import http
 from . import resource
 
-__all__ = ["method", "get", "post", "put", "delete"]
+__all__ = ["NegotiativeView", "method", "get", "post", "put", "delete"]
+
+
+class NegotiativeView(dict):
+    """Accept-aware higher-order view functor. Its constructor takes a
+    {'mime/type': view} map for negotiation.
+
+    """
+
+    def default_view(self, *args):
+        """You can alter this method to set default view. Default view is
+        chosen when the request doesn't contain Accept header.
+
+        """
+        try:
+            view = self["text/html"]
+        except KeyError:
+            view = self["application/xhtml+xml"]
+        return view(*args)
+
+    def __call__(self, context, request):
+        accept = request.accept
+        if accept:
+            type = accept.best_match(self.keys())
+            try:
+                view = self[type]
+            except KeyError:
+                return http.NotAcceptableError()
+        else:
+            view = self.default_view
+        try:
+            return view(context, request)
+        except TypeError:
+            return view(context)
 
 
 class method: