Donald Stufft avatar Donald Stufft committed b5352ce

Port the OAuth handler to use Werkzeug

Comments (0)

Files changed (1)

             # OpenID
             routing.Rule("/id/", endpoint="openid_discovery"),
             routing.Rule("/id/<username>/", endpoint="openid_user"),
+
+            # OAuth
+            routing.Submount("/oauth", [
+                routing.Rule("/request_token/",
+                    endpoint="oauth_request_token"),
+                routing.Rule("/access_token/", endpoint="oauth_access_token"),
+                routing.Rule("/authorise", endpoint="oauth_authorise"),
+                routing.Rule("/add_release", endpoint="oauth_add_release"),
+                routing.Rule("/upload", endpoint="oauth_upload"),
+                routing.Rule("/docupload", endpoint="oauth_docupload"),
+                routing.Rule("/test", endpoint="oauth_test_access"),
+            ]),
         ])
 
     @property
             # Dispatch to the endpoint
             return getattr(self, "do_%s" % endpoint)(request, database, **args)
 
-            # Now we have a username try running OAuth if necessary
-            if script_name == '/oauth':
-                return self.run_oauth(request)
-
             if request.content_type == "text/xml":
                 return self.xmlrpc(request)
 
         self.form['version'] = data['version']
         self.display(ok_message=message)
 
-    def submit(self, parameters=None, response=True):
+    def do_submit(self, request, database, parameters=None):
         ''' Handle the submission of distro metadata.
         '''
         # make sure the user is identified
         else:
             return None
 
-    #
-    # OAuth
-    #
-    def run_oauth(self):
-        if self.env.get('HTTPS') != 'on':
-            raise NotFound('HTTPS must be used to access this URL')
-
-        path = self.env.get('PATH_INFO')
-
-        if path == '/request_token':
-            self.oauth_request_token()
-        elif path == '/access_token':
-            self.oauth_access_token()
-        elif path == '/authorise':
-            self.oauth_authorise()
-        elif path == '/add_release':
-            self.oauth_add_release()
-        elif path == '/upload':
-            self.oauth_upload()
-        elif path == '/docupload':
-            self.oauth_docupload()
-        elif path == '/test':
-            self.oauth_test_access()
-        else:
-            raise NotFound()
-
-    def _oauth_request(self):
-        uri = self.url_machine + self.env['REQUEST_URI']
-        if not self.env.get('HTTP_AUTHORIZATION'):
+    def _oauth_request(self, request):
+        if not request.authorization:
             raise OAuthError('PyPI OAuth requires header authorization')
-        params = dict(self.form)
+
+        params = dict(request.form)
         # don't use file upload in signature
         if 'content' in params:
             del params['content']
-        return oauth.OAuthRequest.from_request(self.env['REQUEST_METHOD'],
-            uri, dict(Authorization=self.env['HTTP_AUTHORIZATION']), params)
-
-    def _oauth_server(self):
-        data_store = store.OAuthDataStore(self.store)
+
+        return oauth.OAuthRequest.from_request(
+            request.method,
+            request.url,
+            dict(Authorization=request.environ['HTTP_AUTHORIZATION']),
+            params,
+        )
+
+    def _oauth_server(self, database):
+        data_store = store.OAuthDataStore(database)
         o = oauth.OAuthSignatureMethod_HMAC_SHA1()
         signature_methods = {o.get_name(): o}
         return oauth.OAuthServer(data_store, signature_methods)
 
-    def oauth_request_token(self):
-        s = self._oauth_server()
-        r = self._oauth_request()
+    def oauth_request_token(self, request, database):
+        s = self._oauth_server(database)
+        r = self._oauth_request(request)
         token = s.fetch_request_token(r)
-        self.store.commit()
-        self.write_plain(str(token))
-
-    def oauth_access_token(self):
-        s = self._oauth_server()
-        r = self._oauth_request()
+        return Response(str(token), content_type="text/plain")
+
+    def oauth_access_token(self, request, database):
+        s = self._oauth_server(database)
+        r = self._oauth_request(request)
         token = s.fetch_access_token(r)
         if token is None:
             raise OAuthError('Request Token not authorised')
-        self.store.commit()
-        self.write_plain(str(token))
-
-    def oauth_authorise(self):
-        if 'oauth_token' not in self.form:
+        return Response(str(token), content_type="text/plain")
+
+    def oauth_authorise(self, request, database):
+        if "oauth_token" not in request.form:
             raise FormError('oauth_token and oauth_callback are required')
-        if not self.authenticated:
-            self.write_template('oauth_notloggedin.pt',
-                title="OAuth authorisation attempt")
-            return
-
-        oauth_token = self.form['oauth_token']
-        oauth_callback = self.form['oauth_callback']
-
-        ok = self.form.get('ok')
-        cancel = self.form.get('cancel')
-
-        s = self._oauth_server()
+
+        if not request.authenticated:
+            return self.render_template(
+                request,
+                database,
+                "oauth_notloggedin.pt",
+                title="OAuth authorization attempt",
+            )
+
+        oauth_token = request.form['oauth_token']
+        oauth_callback = request.form['oauth_callback']
+
+        ok = request.form.get('ok')
+        cancel = request.form.get('cancel')
+
+        s = self._oauth_server(database)
 
         if not ok and not cancel:
-            description = s.data_store._get_consumer_description(request_token=oauth_token)
-            action_url = self.url_machine + '/oauth/authorise'
-            return self.write_template('oauth_authorise.pt',
-                title='PyPI - the Python Package Index',
+            description = s.data_store._get_consumer_description(
+                request_token=oauth_token,
+            )
+            action_url = urlparse.urljoin(request.host_url, "/oauth/authorise")
+
+            return self.render_template(
+                request,
+                database,
+                "oauth_authorise.pt",
+                title="PyPI - the Python Package Index",
                 action_url=action_url,
                 oauth_token=oauth_token,
                 oauth_callback=oauth_callback,
-                description=description)
+                description=description,
+            )
 
         if '%3A' in oauth_callback:
             oauth_callback = urllib.unquote(oauth_callback)
             raise RedirectTemporary(oauth_callback)
 
         # register the user against the request token
-        s.authorize_token(oauth_token, self.username)
-
-        # commit all changes now
-        self.store.commit()
-
-        url = oauth_callback + '?oauth_token=%s'%oauth_token
+        s.authorize_token(oauth_token, request.user.username)
+
+        url = oauth_callback + '?oauth_token=%s' % oauth_token
         raise RedirectTemporary(url)
 
-    def _parse_request(self):
+    def _parse_request(self, request, database):
         '''Read OAuth access request information from the request.
 
         Return the consumer (OAuthConsumer instance), the access token
         accompanying the request) and the user account number authorized by the
         access token.
         '''
-        s = self._oauth_server()
-        r = self._oauth_request()
+        s = self._oauth_server(database)
+        r = self._oauth_request(request)
         consumer, token, params = s.verify_request(r)
         user = s.data_store._get_user(token)
         # recognise the user as accessing during this request
-        self.username = user
-        self.store.set_user(user, self.remote_addr, False)
-        self.authenticated = True
+        request.user = User(
+            authenticated=True,
+            username=user,
+            last_login=None,
+        )
+        database.set_user(user, request.remote_addr, False)
         return consumer, token, params, user
 
-    def oauth_test_access(self):
+    def oauth_test_access(self, request, database):
         '''A resource that is protected so access without an access token is
         disallowed.
         '''
-        consumer, token, params, user = self._parse_request()
-        message = 'Access allowed for %s (ps. I got params=%r)'%(user, params)
-        self.write_plain(message)
-
-    def oauth_add_release(self):
+        consumer, token, params, user = self._parse_request(reqest, database)
+        message = 'Access allowed for %s (ps. I got params=%r)' % (
+            user, params,
+        )
+        return Response(message, content_type="text/plain")
+
+    def oauth_add_release(self, request, database):
         '''Add a new release.
 
         Returns "OK" if all is well otherwise .. who knows (TODO this needs to
         be clarified and cleaned up).
         '''
-        consumer, token, params, user = self._parse_request()
-        self.submit(params, False)
-        self.write_plain('OK')
-
-    def oauth_upload(self):
+        consumer, token, params, user = self._parse_request(request, database)
+        resp = self.do_submit(request, database, parameters=params)
+        return Response("OK", content_type="text/plain")
+
+    def oauth_upload(self, request, database):
         '''Upload a file for a package release.
         '''
-        consumer, token, params, user = self._parse_request()
-        self.file_upload(False)
-        self.write_plain('OK')
-
-    def oauth_docupload(self):
+        consumer, token, params, user = self._parse_request(request, database)
+        self.do_file_upload(request, database)
+        return Response("OK", content_type="text/plain")
+
+    def oauth_docupload(self, request, database):
         '''Upload a documentation bundle.
         '''
-        consumer, token, params, user = self._parse_request()
-        message = 'Access allowed for %s (ps. I got params=%r)'%(user, params)
-        self.write_plain(message)
-
+        consumer, token, params, user = self._parse_request(request, database)
+        message = 'Access allowed for %s (ps. I got params=%r)' % (
+            user, params,
+        )
+        return Response(message, content_type="text/plain")
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.