Commits

maluke  committed f96403d

rewrite using classes

  • Participants
  • Parent commits 7ef5d1e

Comments (0)

Files changed (1)

File contrib/decorators.py

 
 import new
 
-class webob_wrap(object):
+class WrapperBase(object):
+    _min_args = 1
     def __new__(cls, *args, **kw):
-        if not args:
-            return lambda func: webob_wrap(func, **kw)
+        if len(args) < cls._min_args:
+            return lambda *newargs: cls(*(args+newargs), **kw)
         else:
-            return super(webob_wrap, cls).__new__(cls, *args, **kw)
+            inst = super(WrapperBase, cls).__new__(cls, *args, **kw)
+            inst._key = str(id(inst))
+            inst._original_args = list(args)
+            inst._original_kw = kw.copy()
+            return inst
 
-    def __init__(self, __func=None, **kw):
-        self.func = __func
-        self._original_kw = kw.copy()
-        self.default_request_charset = kw.pop('default_request_charset', 'UTF-8')
+    def __repr__(self):
+        args = []
+        for arg in self._original_args:
+            args.append(repr(arg))
+        for name, val in self._original_kw.iteritems():
+            args.append('%s=%r' % (name, val))
+        return '%s(%s)' % (self.__class__.__name__, ', '.join(args))
+
+    def __get__(self, owner, instance):
+        wrapped = getattr(instance, self._key, None)
+        if wrapped is None:
+            bound = new.instancemethod(self.func, owner, instance)
+            wrapped = self.__class__(bound, **self._original_kw)
+            setattr(instance, self._key, wrapped)
+        return wrapped
+
+
+class webob_wrap(WrapperBase):
+    def __init__(self, func, default_request_charset='UTF-8', **kw):
+        self.func = func
+        self.default_request_charset = default_request_charset
         self.kw = kw
-        self._key = str(id(self))
 
     def __call__(self, environ, start_response):
         req = Request(environ)
             app = exc
         return app(environ, start_response)
 
-    def __get__(self, owner, instance):
-        wrapped = getattr(instance, self._key, None)
-        if wrapped is None:
-            bound = new.instancemethod(self.func, owner, instance)
-            wrapped = webob_wrap(bound, **self._original_kw)
-            setattr(instance, self._key, wrapped)
-        return wrapped
 
-    def __repr__(self):
-        kwstr = ', '.join('%s=%r' % (name, val) for (name, val) in self._original_kw.iteritems())
-        return '%s(%r, %s)' % (self.__class__.__name__, self.func, kwstr)
+class webob_middleware(WrapperBase):
+    _min_args = 2
+    def __init__(self, mwfunc, next_app, **kw):
+        self.mwfunc = mwfunc
+        self.next_app = next_app
+        self.kw = kw
 
+    @webob_wrap
+    def __call__(self, req):
+        return self.mwfunc(req, self.next_app, **self.kw)
 
-def webob_middleware(middleware):
-    def wrapper(app, **kw):
-        @webob_wrap
-        def middleware_app(req):
-            return middleware(req, app, **kw)
-        return middleware_app
-    return wrapper
 
-def webob_postprocessor(processor):
-    @webob_middleware
-    def postprocessor_middleware(req, app, no_range=True, decode_content=True, **kw):
-        if no_range:
+class webob_postprocessor(WrapperBase):
+    _min_args = 2
+    def __init__(self, postprocessor, next_app, no_range=True, decode_content=True, **kw):
+        self.postprocessor = postprocessor
+        self.next_app = next_app
+        self.no_range = no_range
+        self.decode_content = decode_content
+        self.kw = kw
+
+    @webob_wrap
+    def __call__(self, req):
+        if self.no_range:
             req.range = req.if_range = None
-        resp = req.get_response(app)
+        resp = req.get_response(self.next_app)
         if resp._app_iter or resp._body:
-            if decode_content:
+            if self.decode_content:
                 resp.decode_content()
-            processor(resp, **kw)
+            self.postprocessor(resp, **self.kw)
         return resp
-    return postprocessor_middleware
+
 
 if __name__ == '__main__':
     def test(app, url='/'):
             return Response(self.val)
 
     assert test(App('123')).body == '123'
+
+    @webob_middleware
+    def mw(req, app):
+        r = req.get_response(app)
+        r.md5_etag()
+        return r
+
+    mwa = mw(app)
+    mwr = test(mwa)
+    print mw, mwa
+    assert mwr.body == '1'
+    assert mwr.etag is not None
+
+    @webob_postprocessor
+    def double(r):
+        r.body += r.body
+
+    #print double
+    #print double.__call__
+    dapp = double(app)
+    print dapp
+    assert test(dapp).body == '11'