Tetsuya Morimoto avatar Tetsuya Morimoto committed 2ef35c3

changed to ignore parameters included in "redirect_uri" for validation
changed to be able to use "next" parameter from Consumer login

Comments (0)

Files changed (3)

src/raido/consumer.py

     data = {"access_token": resp["access_token"]}
     me = raido_provider.get("/info", data=data)
     return "Logged in as id={id} name={name} birth date={birth_date} "\
-           "redirect={0}".format(request.args.get("next"), **me.data)
+           "redirect={0}".format(resp.get("next"), **me.data)
 
 @raido_provider.tokengetter
 def get_raido_provider_oauth_token():

src/raido/provider.py

 from raido.utils.misc import (
         convert_datetime_to_str, convert_str_to_datetime,
         get_log_file_handler)
-from raido.utils.reqtools import get_parameter
+from raido.utils.reqtools import get_parameter, get_parameter_with_uri_parsed
 from raido.views.helper import *
 
 app.add_url_rule("/favicon.ico", "favicon",
 @app.route("/consumer/list", methods=["GET", "POST"])
 def consumer_list():
     if request.method == "POST":
-        param = get_parameter(request.form, ("client_id",))
-        delete_data((Consumer,), **param)
+        delete_data((Consumer,), client_id=request.form.get("client_id"))
         return redirect(url_for("consumer_list"))
     # GET or others
     consumers = Consumer.query.all()
         app.logger.debug("need login before getting authentication code")
         return redirect(url_for("login", next=request.url))
 
-    param = get_parameter(request.args, ("client_id", "redirect_uri"))
-    c = Consumer.query.filter_by(**param).first()
+    fields = ("client_id", "redirect_uri")
+    params, _ = get_parameter_with_uri_parsed(request.args, fields)
+    c = Consumer.query.filter_by(**params).first()
     if not c:
-        app.logger.debug("wrong consumer parameter: {0}".format(str(param)))
+        app.logger.debug("wrong consumer parameter: {0}".format(str(params)))
         return redirect(url_for("consumer_list"))
 
     code = generate_auth_code()
 @app.route("/oauth/2/access_token", methods=["GET", "POST"])
 def oauth2_access_token():
     def is_verified(c, code, request_dict):
-        valid, result = True, None
+        valid, verified, rd_params = True, None, None
         msg = "authentication code is verified: {0}".format(code)
         if not c:
             valid = False
-            result = msg = "invalid authentication code: {0}".format(code)
+            verified = msg = "invalid authentication code: {0}".format(code)
         elif c.expires_in <= datetime.now():
             valid = False
-            result = msg = "expired authentication code: {0}".format(code)
+            verified = msg = "expired authentication code: {0}".format(code)
         elif not c.user.logged_in:
             valid = False
             msg = "need login before getting access token"
-            result = redirect(url_for("login", next=request.url))
+            verified = redirect(url_for("login", next=request.url))
 
         fields = ("client_id", "client_secret", "redirect_uri")
-        param = get_parameter(request_dict, fields)
-        if valid and not Consumer.query.filter_by(**param).first():
+        params, rd_params = get_parameter_with_uri_parsed(request.args, fields)
+        if valid and not Consumer.query.filter_by(**params).first():
             valid = False
-            msg = "wrong consumer parameter: {0}".format(str(param))
-            result = redirect(url_for("consumer_list"))
+            msg = "wrong consumer parameter: {0}".format(str(params))
+            verified = redirect(url_for("consumer_list"))
         app.logger.debug(msg)
-        return valid, result
+        return valid, verified, rd_params
 
     if request.method == "POST":
         code = request.form.get("code")
                                  auth_code=code) as q:
         c = q.first()
         # verify authentication code and parameter
-        valid, result = is_verified(c, code, request_dict)
+        valid, verified, rd_params = is_verified(c, code, request_dict)
         if not valid:
-            return result
+            return verified
 
         # The authorization code MUST expire shortly
         # after it is issued to mitigate the risk of leaks.
         t = AccessToken(token, c.consumer, c.user)
         db.session.add(t)
         db.session.commit()
-        app.logger.debug("access token is created: {0}".format(token))
-    return "access_token={0}".format(token)
+
+        result = {"access_token": token}
+        if "next" in rd_params:
+            result.update(next=rd_params["next"])
+        app.logger.debug("access token is created: {0}".format(str(result)))
+    return jsonify(**result)
 
 @app.route("/info", methods=["GET"])
 def info():

src/raido/utils/reqtools.py

         {'id': u'1', 'name': u'Peter'}
     """
     return dict((p, request_dict.get(p)) for p in fields)
+
+def get_parameter_with_uri_parsed(request_dict, fields, key="redirect_uri"):
+    """ call :func:`get_parameter` except for parsing "redirect_uri"
+
+    :param request_dict: :data:`flask.request.args` or
+                         :data:`flask.request.form`
+    :param fields: key items for getting
+    :param key: uri parameter name (default="redirect_uri")
+
+    Example usage::
+
+        >>> from werkzeug import ImmutableMultiDict
+        >>> args = ImmutableMultiDict([
+        ...     ("client_id", 100),
+        ...     ("redirect_uri", "http://loc/auth?next=http%3A%2F%2Fwww"),
+        ...     ("other_parameter", "detarame"),
+        ...     ])
+        >>> fields = ("client_id", "redirect_uri")
+        >>> params, rd_params = get_parameter_with_uri_parsed(args, fields)
+        >>> params
+        {'redirect_uri': 'http://loc/auth', 'client_id': 100}
+        >>> rd_params
+        {'next': 'http://www'}
+    """
+    params = get_parameter(request_dict, fields)
+    rd_uri, rd_params = parse_uri_and_params(params[key])
+    params.update(redirect_uri=rd_uri)
+    return params, rd_params
+
+def parse_uri_and_params(uri):
+    """ parse :data:`uri`, then return uri and parameters
+
+    :param uri: full url string
+
+    Example usage::
+
+        >>> parse_uri_and_params("http://loc/authorized?query=text&key=100")
+        ('http://loc/authorized', {'query': 'text', 'key': '100'})
+    """
+    from urlparse import parse_qsl, urlparse
+    pr = urlparse(uri)
+    uri_only = "{0}://{1}{2}".format(pr.scheme, pr.netloc, pr.path)
+    params = dict(parse_qsl(pr.query))
+    return uri_only, params
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.