Commits

Michał Jaworski committed c052177

now oauth decorator can be used with optional resource_name operator, but still works without it, fixes #11

Comments (0)

Files changed (1)

oauth_provider/decorators.py

 from utils import initialize_server_request, send_oauth_error, get_oauth_request
 from consts import OAUTH_PARAMETERS_NAMES
 from store import store, InvalidTokenError
+from functools import wraps
 
-def oauth_required(view_func=None, resource_name=None):
-    return CheckOAuth(view_func, resource_name)
 
-class CheckOAuth(object):
+class CheckOauth(object):
     """
-    Class that checks that the OAuth parameters passes the given test, raising
+    Decorator that checks that the OAuth parameters passes the given test, raising
     an OAuth error otherwise. If the test is passed, the view function
     is invoked.
 
     CheckOAuth object is used as a method decorator, the view function
     is properly bound to its instance.
     """
-    def __init__(self, view_func, resource_name):
-        self.view_func = view_func
+    def __init__(self, resource_name=None):
         self.resource_name = resource_name
-        update_wrapper(self, view_func)
-        
-    def __get__(self, obj, cls=None):
-        view_func = self.view_func.__get__(obj, cls)
-        return CheckOAuth(view_func, self.resource_name)
-    
-    def __call__(self, request, *args, **kwargs):
-        if self.is_valid_request(request):
-            oauth_request = get_oauth_request(request)
-            consumer = store.get_consumer(request, oauth_request, 
-                            oauth_request.get_parameter('oauth_consumer_key'))
-            try:
-                token = store.get_access_token(request, oauth_request, 
-                                consumer, oauth_request.get_parameter('oauth_token'))
-            except InvalidTokenError:
-                return send_oauth_error(oauth2.Error(_('Invalid access token: %s') % oauth_request.get_parameter('oauth_token')))
-            try:
-                self.validate_token(request, consumer, token)
-            except oauth2.Error, e:
-                return send_oauth_error(e)
-            
-            if self.resource_name and token.resource.name != self.resource_name:
-                return send_oauth_error(oauth2.Error(_('You are not allowed to access this resource.')))
-            elif consumer and token:
-                return self.view_func(request, *args, **kwargs)
-        
-        return send_oauth_error(oauth2.Error(_('Invalid request parameters.')))
+
+    def __new__(cls, arg=None):
+        if not callable(arg):
+            return super(CheckOauth, cls).__new__(cls)
+        else:
+            obj =  super(CheckOauth, cls).__new__(cls)
+            obj.__init__()
+            return obj(arg)
+
+    def __call__(self, view_func):
+
+        @wraps(view_func)
+        def wrapped_view(request, *args, **kwargs):
+            if self.is_valid_request(request):
+                oauth_request = get_oauth_request(request)
+                consumer = store.get_consumer(request, oauth_request,
+                                oauth_request.get_parameter('oauth_consumer_key'))
+                try:
+                    token = store.get_access_token(request, oauth_request,
+                                    consumer, oauth_request.get_parameter('oauth_token'))
+                except InvalidTokenError:
+                    return send_oauth_error(oauth2.Error(_('Invalid access token: %s') % oauth_request.get_parameter('oauth_token')))
+                try:
+                    self.validate_token(request, consumer, token)
+                except oauth2.Error, e:
+                    return send_oauth_error(e)
+
+                if self.resource_name and token.resource.name != self.resource_name:
+                    return send_oauth_error(oauth2.Error(_('You are not allowed to access this resource.')))
+                elif consumer and token:
+                    return view_func(request, *args, **kwargs)
+
+            return send_oauth_error(oauth2.Error(_('Invalid request parameters.')))
+
+        return wrapped_view
 
     @staticmethod
     def is_valid_request(request):
     def validate_token(request, consumer, token):
         oauth_server, oauth_request = initialize_server_request(request)
         return oauth_server.verify_request(oauth_request, consumer, token)
+
+oauth_required = CheckOauth