Commits

Tetsuya Morimoto committed aacb96d

added AuthResult model
changed to trace debug flow by each ip address

  • Participants
  • Parent commits e2fb1a3

Comments (0)

Files changed (3)

src/raido/models.py

         self.consumer = consumer
         self.user = user
 
+class AuthResult(ModelBaseMixin, db.Model):
+    __tablename__ = "authresult"
+    id = db.Column(db.Integer, primary_key=True)
+    ip_address = _Column(db.String(32), index=True)
+    progress = _Column(db.String(2048), default="[]")
+
+    def __init__(self, ip_addr):
+        self.ip_address = ip_addr
+
 # create tables if not exist
 db.create_all(bind=[None])

src/raido/provider.py

 
 app.add_url_rule("/favicon.ico", "favicon",
         lambda: redirect(app.config["RAIDO_LOGO"]))
-app.debug_progress = []  # consider later for simultaneous access
 
 @app.route("/", methods=["GET", "POST"])
 def index():
     return redirect(url_for("login"))
 
 @app.route("/login", methods=["GET", "POST"])
+@trace_debug_progress
 def login():
     app.logger.debug("Session: {0}".format(session))
     form = LoginForm(request.form)
     if request.method == "POST" and form.validate():
-        app.debug_progress.append("req_login")
+        g.debug_progress.append("req_login")
         if session.get("username") != form.username.data:
             logout()
         with get_object_or_none((User,), name=form.username.data) as u:
         session["username"] = form.username.data
         app.logger.debug("{0} is logged in".format(u.name))
         if form.next_url.data == url_for("index"):
-            app.debug_progress.append("res_logged_in")
+            g.debug_progress.append("res_logged_in")
         else:
-            app.debug_progress.append("redirect")
+            g.debug_progress.append("redirect")
         return redirect(form.next_url.data)
     # GET or others
-    app.debug_progress.append("get_login_form")
+    g.debug_progress.append("get_login_form")
     form.next_url.data = request.args.get("next", url_for("index"))
-    app.debug_progress.append("res_login_form")
+    g.debug_progress.append("res_login_form")
     return render_template("login.html", user=g.user, form=form)
 
 @app.route("/logout")
     return render_template("consumer/register.html", form=form)
 
 @app.route("/oauth/2/auth_code", methods=["GET"])
+@trace_debug_progress
 def oauth2_auth_code():
-    if app.debug_progress and app.debug_progress[-1] != "redirect":
-        app.debug_progress = []
-    app.debug_progress.append("req_code")
-    app.debug_progress.append("chk_logged_in")
+    if g.debug_progress and g.debug_progress[-1] != "redirect":
+        g.debug_progress = []
+    g.debug_progress.append("req_code")
+    g.debug_progress.append("chk_logged_in")
     if not g.user:
         app.logger.debug("need login before getting authentication code")
-        app.debug_progress.append("redirect")
+        g.debug_progress.append("redirect")
         return redirect(url_for("login", next=request.url))
 
     fields = ("client_id", "redirect_uri")
     db.session.commit()
 
     app.logger.debug("authentication code is created: {0}".format(code))
-    app.debug_progress.append("res_code")
+    g.debug_progress.append("res_code")
     return redirect("{0}?code={1}".format(c.redirect_uri, code))
 
 @app.route("/oauth/2/access_token", methods=["GET", "POST"])
+@trace_debug_progress
 def oauth2_access_token():
     def is_verified(c, code, request_dict):
         valid, verified, rd_params = True, None, None
         app.logger.debug(msg)
         return valid, verified, rd_params
 
-    app.debug_progress.append("req_token")
+    g.debug_progress.append("req_token")
     if request.method == "POST":
         code = request.form.get("code")
         request_dict = request.form
         if "next" in rd_params:
             result.update(next=rd_params["next"])
         app.logger.debug("access token is created: {0}".format(str(result)))
-    app.debug_progress.append("res_token")
+    g.debug_progress.append("res_token")
     return jsonify(**result)
 
 @app.route("/info", methods=["GET"])
 
 @app.route("/debug", methods=["GET"])
 def debug():
+    get_debug_progress()
     cfg, root_path = app.config, app.root_path  # just for alias
     # make whole flow
     whole_path = cfg["OAUTH2_WHOLE_FLOW"] + "." + cfg["DIAG_FORMAT"]
 
     # get/analysis your oauth flow result
     path = cfg["DEBUG_OAUTH2_FLOW"] + "." + cfg["DIAG_FORMAT"]
-    source, result = generate_diag_source(app.debug_progress)
-    app.logger.debug("progress: {0}".format(app.debug_progress))
+    source, result = generate_diag_source(g.debug_progress)
+    app.logger.debug("progress: {0}".format(g.debug_progress))
     app.logger.debug("diag source:\n{0}".format(source.encode("utf-8")))
     make_flow_diagram(path, root_path, cfg, source, clean=True)
 

src/raido/views/helper.py

     View Helper Module
 """
 
+import json
 import os
 from operator import getitem
 from os.path import join as pathjoin
 from random import randint
-from flask import flash
+from flask import flash, g, request
+from functools import wraps
 
-from raido.models import db, AuthCode, AccessToken
+from raido.models import db, AuthCode, AccessToken, AuthResult
 from raido.utils.contextmanagers import get_object_or_none
 
 def delete_data(entities, **kwargs):
             db.session.delete(obj)
             db.session.commit()
 
+def get_debug_progress():
+    r = AuthResult.query.filter_by(ip_address=request.remote_addr).first()
+    if not r:
+        r = AuthResult(request.remote_addr)
+        db.session.add(r)
+        db.session.commit()
+    g.debug_progress = json.loads(r.progress)
+
+def set_debug_progress():
+    r = AuthResult.query.filter_by(ip_address=request.remote_addr).first()
+    if r:
+        r.progress = json.dumps(g.debug_progress)
+        db.session.add(r)
+        db.session.commit()
+
+def trace_debug_progress(func):
+    @wraps(func)
+    def _trace_debug_progress(*args, **kwargs):
+        get_debug_progress()
+        response = func(*args, **kwargs)
+        set_debug_progress()
+        return response
+    return _trace_debug_progress
+
 def get_font_path(app_fonts):
     from os.path import isfile
     from raido.consts import DIAG_DEFAULT_FONTS