From 586fbcfca017ac93af9b2d3c0cb5d3132c3e0859 Mon Sep 17 00:00:00 2001 From: Franziska Kunsmann Date: Mon, 23 Dec 2024 16:15:26 +0100 Subject: [PATCH 1/4] rework login logic to support generic oauth2 providers --- frontend.py | 179 +++++++++++++++++++++++++++-------------- requirements.txt | 1 - templates/layout.jinja | 8 +- util/__init__.py | 34 ++++---- util/sso/__init__.py | 22 +++++ util/sso/github.py | 32 ++++++++ 6 files changed, 194 insertions(+), 82 deletions(-) create mode 100644 util/sso/__init__.py create mode 100644 util/sso/github.py diff --git a/frontend.py b/frontend.py index bfc54d5..a00d404 100644 --- a/frontend.py +++ b/frontend.py @@ -4,8 +4,10 @@ from datetime import datetime, timezone from secrets import token_hex from typing import Iterable +from urllib.parse import urlencode import iso8601 +import requests from flask import ( Flask, abort, @@ -17,7 +19,6 @@ session, url_for, ) -from flask_github import GitHub from prometheus_client import generate_latest from prometheus_client.core import REGISTRY, GaugeMetricFamily from prometheus_client.metrics_core import Metric @@ -38,11 +39,10 @@ get_assets_awaiting_moderation, get_random, get_user_assets, - login_disabled_for_user, + is_within_timeframe, login_required, - user_is_admin, - user_without_limits, ) +from util.sso import SSO_CONFIG app = Flask( __name__, @@ -52,8 +52,6 @@ app.wsgi_app = ProxyFix(app.wsgi_app) for copy_key in ( - "GITHUB_CLIENT_ID", - "GITHUB_CLIENT_SECRET", "MAX_UPLOADS", "ROOMS", "TIME_MAX", @@ -112,30 +110,49 @@ def collect(self) -> Iterable[Metric]: REGISTRY.register(SubmissionsCollector()) REGISTRY.register(InfobeamerCollector()) -github = GitHub(app) - app.session_interface = RedisSessionStore() @app.before_request def before_request(): - user = session.get("gh_login") - g.user_is_admin = user_is_admin(user) - g.user_without_limits = user_without_limits(user) + provider = session.get("oauth2_provider") + userinfo = session.get("oauth2_userinfo") + + g.user_is_admin = False + g.user_without_limits = False + g.userid = "" + g.username = "" + + if not provider or not userinfo: + return + + username = SSO_CONFIG[provider]["functions"]["username"](userinfo) + user_is_admin = SSO_CONFIG[provider]["functions"]["is_admin"](userinfo) + user_without_limits = SSO_CONFIG[provider]["functions"]["no_limit"](userinfo) - if login_disabled_for_user(user): - g.user = None - g.avatar = None + if not (user_is_admin or user_without_limits or is_within_timeframe()): return - g.user = user - g.avatar = session.get("gh_avatar") + g.user_is_admin = user_is_admin + g.user_without_limits = user_without_limits + g.userid = f"{provider}:{username}" + g.username = username + + +@app.context_processor +def login_providers(): + result = {} + + for provider, config in CONFIG["oauth2_providers"].items(): + result[provider] = config.get("display_name", provider.capitalize()) + + return {"login_providers": result} @app.context_processor def start_time_alert(): # if g.user is set, the user was successfully logged in (see above) - if g.user: + if g.userid: return {"start_time": None} start_time = datetime.fromtimestamp(CONFIG["TIME_MIN"], timezone.utc) @@ -146,47 +163,88 @@ def start_time_alert(): return {"start_time": start_time.strftime("%F %T")} -@app.route("/github-callback") -@github.authorized_handler -def authorized(access_token): - if access_token is None: - return redirect(url_for("index")) +@app.route("/login/") +def login(provider): + if g.userid: + return redirect(url_for("dashboard")) - state = request.args.get("state") - if state is None or state != session.get("state"): - return redirect(url_for("index")) - session.pop("state") + provider_config = CONFIG["oauth2_providers"].get(provider, {}) + if not provider_config or provider not in SSO_CONFIG: + abort(404) - github_user = github.get("user", access_token=access_token) - if github_user["type"] != "User": - return redirect(url_for("faq", _anchor="signup")) + session["oauth2_state"] = state = get_random() - if login_disabled_for_user(github_user["login"]): - return render_template("time_error.jinja") + qs = urlencode( + { + "client_id": provider_config["client_id"], + "redirect_uri": url_for( + "oauth2_callback", provider=provider, _external=True + ), + "response_type": "code", + "scope": " ".join(SSO_CONFIG[provider]["scopes"]), + "state": state, + } + ) + return redirect("{}?{}".format(SSO_CONFIG[provider]["authorize_url"], qs)) + + +@app.route("/login/callback/") +def oauth2_callback(provider): + if g.userid: + return redirect(url_for("dashboard")) + + provider_config = CONFIG["oauth2_providers"].get(provider, {}) + if not provider_config or provider not in SSO_CONFIG: + abort(404) + + if "error" in request.args: + for k, v in request.args.items(): + if k.startswith("error"): + flash(f"{k}: {v}") + return redirect(url_for("index")) - age = datetime.utcnow() - iso8601.parse_date(github_user["created_at"]).replace( - tzinfo=None + if request.args["state"] != session.get("oauth2_state"): + abort(401) + + if "code" not in request.args: + abort(400) + + r = requests.post( + SSO_CONFIG[provider]["token_url"], + data={ + "client_id": provider_config["client_id"], + "client_secret": provider_config["client_secret"], + "code": request.args["code"], + "grant_type": "authorization_code", + "redirect_uri": url_for( + "oauth2_callback", provider=provider, _external=True + ), + }, + headers={"Accept": "application/json"}, + ) + if r.status_code != 200: + abort(400) + oauth2_token = r.json().get("access_token") + + r = requests.get( + SSO_CONFIG[provider]["userinfo_url"], + headers={ + "Authorization": f"Bearer {oauth2_token}", + "Accept": "application/json", + }, ) + userinfo_json = r.json() - app.logger.info(f"user is {age.days} days old") - app.logger.info("user has {} followers".format(github_user["followers"])) - if age.days < 31 and github_user["followers"] < 10: + if not SSO_CONFIG[provider]["functions"]["login_allowed"](userinfo_json): return redirect(url_for("faq", _anchor="signup")) - session["gh_login"] = github_user["login"] + session["oauth2_provider"] = provider + session["oauth2_userinfo"] = userinfo_json if "redirect_after_login" in session: return redirect(session["redirect_after_login"]) return redirect(url_for("dashboard")) -@app.route("/login") -def login(): - if g.user: - return redirect(url_for("dashboard")) - session["state"] = state = get_random() - return github.authorize(state=state) - - @app.route("/logout") def logout(): session.clear() @@ -213,7 +271,7 @@ def saal(): auth = CONFIG.get("INTERRUPT_KEY") if not auth: abort(404) - if not user_is_admin(g.user) and request.args.get("auth") != auth: + if not g.user_is_admin and request.args.get("auth") != auth: abort(401) interrupt_key = get_scoped_api_key( @@ -271,13 +329,16 @@ def content_upload(): extension = "jpg" if filetype == "image" else "mp4" filename = "user/{}/{}_{}.{}".format( - g.user, datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"), token_hex(8), extension + g.userid, + datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + token_hex(8), + extension, ) condition = { "StringEquals": { "asset:filename": filename, "asset:filetype": filetype, - "userdata:user": g.user, + "userdata:user": g.userid, }, "NotExists": { "userdata:state": True, @@ -311,7 +372,7 @@ def content_upload(): ) return jsonify( filename=filename, - user=g.user, + user=g.userid, upload_key=get_scoped_api_key( [{"Action": "asset:upload", "Condition": condition, "Effect": "allow"}], uses=1, @@ -327,30 +388,30 @@ def content_request_review(asset_id): except Exception: abort(404) - if asset["userdata"].get("user") != g.user: + if asset["userdata"].get("user") != g.userid: return error("Cannot review") if "state" in asset["userdata"]: # not in new state? return error("Cannot review") moderation_message = "{asset} uploaded by {user}. ".format( - user=g.user, + user=g.userid, asset=asset["filetype"].capitalize(), ) if g.user_is_admin: - update_asset_userdata(asset, state=State.CONFIRMED, moderated_by=g.user) + update_asset_userdata(asset, state=State.CONFIRMED, moderated_by=g.userid) app.logger.warn( "auto-confirming {} because it was uploaded by admin {}".format( - asset["id"], g.user + asset["id"], g.userid ) ) moderation_message += "It was automatically confirmed because user is an admin." elif g.user_without_limits: - update_asset_userdata(asset, state=State.CONFIRMED, moderated_by=g.user) + update_asset_userdata(asset, state=State.CONFIRMED, moderated_by=g.userid) app.logger.warn( "auto-confirming {} because it was uploaded by no-limits user {}".format( - asset["id"], g.user + asset["id"], g.userid ) ) moderation_message += ( @@ -413,10 +474,10 @@ def content_moderate_result(asset_id, result): if result == "confirm": app.logger.info("Asset {} was confirmed".format(asset["id"])) - update_asset_userdata(asset, state=State.CONFIRMED, moderated_by=g.user) + update_asset_userdata(asset, state=State.CONFIRMED, moderated_by=g.userid) else: app.logger.info("Asset {} was rejected".format(asset["id"])) - update_asset_userdata(asset, state=State.REJECTED, moderated_by=g.user) + update_asset_userdata(asset, state=State.REJECTED, moderated_by=g.userid) return jsonify(ok=True) @@ -432,7 +493,7 @@ def content_update(asset_id): starts = request.values.get("starts", type=int) ends = request.values.get("ends", type=int) - if asset["userdata"].get("user") != g.user: + if asset["userdata"].get("user") != g.userid: return error("Cannot update") try: @@ -452,7 +513,7 @@ def content_delete(asset_id): except Exception: abort(404) - if asset["userdata"].get("user") != g.user: + if asset["userdata"].get("user") != g.userid: return error("Cannot delete") try: diff --git a/requirements.txt b/requirements.txt index dfa3cb7..99c0699 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ click==8.1.7 Deprecated==1.2.14 Flask==3.0.3 gevent==24.2.1 -GitHub-Flask==3.2.0 greenlet==3.0.3 gunicorn==22.0.0 httplib2==0.22.0 diff --git a/templates/layout.jinja b/templates/layout.jinja index 93a59ca..b839572 100644 --- a/templates/layout.jinja +++ b/templates/layout.jinja @@ -32,13 +32,15 @@
  • Slideshow
  • diff --git a/util/__init__.py b/util/__init__.py index 35a40b9..cdfde7d 100644 --- a/util/__init__.py +++ b/util/__init__.py @@ -3,7 +3,7 @@ import random import shutil import tempfile -from datetime import datetime +from datetime import datetime, timezone from functools import wraps from typing import NamedTuple, Optional @@ -14,25 +14,24 @@ from .ib_hosted import ib +DEFAULT_SSO_PROVIDER = CONFIG.get( + "DEFAULT_SSO_PROVIDER", list(CONFIG["oauth2_providers"].keys())[0] +) +DEFAULT_ADMIN_SSO_PROVIDER = CONFIG.get( + "DEFAULT_ADMIN_SSO_PROVIDER", list(CONFIG["oauth2_providers"].keys())[0] +) + def error(msg): return jsonify(error=msg), 400 -def user_is_admin(user) -> bool: - return user is not None and user.lower() in CONFIG.get("ADMIN_USERS", set()) - - -def user_without_limits(user) -> bool: - return user is not None and user.lower() in CONFIG.get("NO_LIMIT_USERS", set()) - - def login_required(f): @wraps(f) def decorated_function(*args, **kwargs): - if not g.user: + if not g.userid: session["redirect_after_login"] = request.url - return redirect(url_for("login")) + return redirect(url_for("login", provider=DEFAULT_SSO_PROVIDER)) return f(*args, **kwargs) return decorated_function @@ -41,9 +40,9 @@ def decorated_function(*args, **kwargs): def admin_required(f): @wraps(f) def decorated_function(*args, **kwargs): - if not g.user: + if not g.userid: session["redirect_after_login"] = request.url - return redirect(url_for("login")) + return redirect(url_for("login", provider=DEFAULT_ADMIN_SSO_PROVIDER)) if not g.user_is_admin: abort(401) return f(*args, **kwargs) @@ -136,7 +135,7 @@ def get_assets(cached=False): def get_user_assets(): - return [a for a in get_assets() if a.user == g.user and a.state != State.DELETED] + return [a for a in get_assets() if a.user == g.userid and a.state != State.DELETED] def get_assets_awaiting_moderation(): @@ -156,11 +155,8 @@ def get_all_live_assets(no_time_filter=False): ] -def login_disabled_for_user(user=None): - if user_is_admin(user) or user_without_limits(user): - return False - - now = datetime.now().timestamp() +def is_within_timeframe(): + now = datetime.now(timezone.utc).timestamp() return not (CONFIG["TIME_MIN"] < now < CONFIG["TIME_MAX"]) diff --git a/util/sso/__init__.py b/util/sso/__init__.py new file mode 100644 index 0000000..6b98c83 --- /dev/null +++ b/util/sso/__init__.py @@ -0,0 +1,22 @@ +from util.sso.github import ( + check_github_allowed_login, + check_github_is_admin, + check_github_no_limit, + get_github_username, +) + +SSO_CONFIG = { + "github": { + "display_name": "GitHub", + "authorize_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "scopes": ["user:email"], + "userinfo_url": "https://api.github.com/user", + "functions": { + "is_admin": check_github_is_admin, + "login_allowed": check_github_allowed_login, + "no_limit": check_github_no_limit, + "username": get_github_username, + }, + }, +} diff --git a/util/sso/github.py b/util/sso/github.py new file mode 100644 index 0000000..4c624ce --- /dev/null +++ b/util/sso/github.py @@ -0,0 +1,32 @@ +from datetime import datetime, timezone +from logging import getLogger + +from conf import CONFIG + +LOG = getLogger("SSO-Github") + + +def get_github_username(userinfo_json): + return "{} (GitHub)".format(userinfo_json["login"]) + + +def check_github_allowed_login(userinfo_json): + if userinfo_json["type"] != "User": + return False + + age = datetime.now(timezone.utc) - datetime.fromisoformat( + userinfo_json["created_at"] + ) + LOG.info(f"user is {age.days} days old") + LOG.info("user has {} followers".format(userinfo_json["followers"])) + if age.days < 31 and userinfo_json["followers"] < 10: + return False + return True + + +def check_github_is_admin(userinfo_json): + return f"github:{userinfo_json['login'].lower()}" in CONFIG["ADMIN_USERS"] + + +def check_github_no_limit(userinfo_json): + return f"github:{userinfo_json['login'].lower()}" in CONFIG["NO_LIMIT_USERS"] From 9352162c3cf9ad59adfd15781ab30505fda5d80a Mon Sep 17 00:00:00 2001 From: Franziska Kunsmann Date: Mon, 23 Dec 2024 16:39:35 +0100 Subject: [PATCH 2/4] ensure flashed messages get shown --- frontend.py | 6 ++++-- templates/layout.jinja | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/frontend.py b/frontend.py index a00d404..03ed7f6 100644 --- a/frontend.py +++ b/frontend.py @@ -6,11 +6,11 @@ from typing import Iterable from urllib.parse import urlencode -import iso8601 import requests from flask import ( Flask, abort, + flash, g, jsonify, redirect, @@ -200,7 +200,7 @@ def oauth2_callback(provider): if "error" in request.args: for k, v in request.args.items(): if k.startswith("error"): - flash(f"{k}: {v}") + flash(f"{k}: {v}", "danger") return redirect(url_for("index")) if request.args["state"] != session.get("oauth2_state"): @@ -236,6 +236,7 @@ def oauth2_callback(provider): userinfo_json = r.json() if not SSO_CONFIG[provider]["functions"]["login_allowed"](userinfo_json): + flash("You are not allowed to log in at this time.", "warning") return redirect(url_for("faq", _anchor="signup")) session["oauth2_provider"] = provider @@ -248,6 +249,7 @@ def oauth2_callback(provider): @app.route("/logout") def logout(): session.clear() + flash("You have been logged out", "info") return redirect(url_for("index")) diff --git a/templates/layout.jinja b/templates/layout.jinja index b839572..62e5a2f 100644 --- a/templates/layout.jinja +++ b/templates/layout.jinja @@ -47,6 +47,11 @@
    + {% for messages in get_flashed_messages(with_categories=True) %} + {% for category, message in messages %} + + {% endfor %} + {% endfor %} {% if start_time %}