diff --git a/src/gened/auth.py b/src/gened/auth.py index 8f98f19..12e672b 100644 --- a/src/gened/auth.py +++ b/src/gened/auth.py @@ -42,10 +42,10 @@ class AuthDict(TypedDict, total=False): display_name: str is_admin: bool is_tester: bool - role_id: int | None # current role - role: RoleType | None # current role name (e.g., 'instructor') class_id: int | None # current class ID class_name: str | None # current class name + role_id: int | None # current role + role: RoleType | None # current role name (e.g., 'instructor') class_experiments: list[str] # any experiments the current class is registered in other_classes: list[ClassDict] # for storing active classes that are not the user's current class @@ -62,20 +62,22 @@ def set_session_auth_user(user_id: int) -> None: """ Set the current session's user (on login, after authentication). Clears all other auth data in the session. """ - auth: AuthDict = { + auth = { 'user_id': user_id, } session[AUTH_SESSION_KEY] = auth _invalidate_g_auth() -def set_session_auth_role(role_id: int | None) -> None: - """ Set the current session's active role (on login or class switch). +def set_session_auth_class(class_id: int | None) -> None: + """ Set the current session's active class (on login or class switch). Adds to any existing auth data in the session. """ - auth: AuthDict = session[AUTH_SESSION_KEY] - auth['role_id'] = role_id - session[AUTH_SESSION_KEY] = auth + sess_auth = session.get(AUTH_SESSION_KEY, {}) + assert 'user_id' in sess_auth + assert sess_auth['user_id'] is not None # must be logged in already for this function to be valid + sess_auth['class_id'] = class_id + session[AUTH_SESSION_KEY] = sess_auth _invalidate_g_auth() @@ -87,14 +89,13 @@ def _get_auth_from_session() -> AuthDict: 'user_id': None, 'is_admin': False, 'is_tester': False, - 'role_id': None, 'role': None, } # Get the session auth dict, or an empty dict if it's not there, to find # current user_id and role_id (if any). sess_auth = session.get(AUTH_SESSION_KEY, {}) sess_user = sess_auth.get('user_id', None) - sess_role = sess_auth.get('role_id', None) + sess_class = sess_auth.get('class_id', None) if not sess_user: # No logged in user; return the base/empty auth data @@ -122,16 +123,16 @@ def _get_auth_from_session() -> AuthDict: auth_dict: AuthDict = { # from session 'user_id': sess_user, - 'role_id': sess_role, + 'class_id': sess_class, # from DB 'display_name': user_row['display_name'], 'is_admin': user_row['is_admin'], 'is_tester': user_row['is_tester'], 'auth_provider': user_row['auth_provider'], # to be filled - 'class_id': None, 'class_name': None, 'class_experiments': [], + 'role_id': None, 'role': None, 'other_classes': [], } @@ -141,11 +142,11 @@ def _get_auth_from_session() -> AuthDict: # Uses WHERE active=1 to only allow active roles. role_rows = db.execute(""" SELECT - roles.id, + roles.id AS role_id, roles.class_id, + roles.role, classes.name, - classes.enabled, - roles.role + classes.enabled FROM roles JOIN classes ON classes.id=roles.class_id WHERE roles.user_id=? AND roles.active=1 @@ -155,25 +156,27 @@ def _get_auth_from_session() -> AuthDict: found_role = False # track whether the current role from auth is actually found as an active role if role_rows: for row in role_rows: - class_dict: ClassDict = { - 'class_id': row['class_id'], - 'class_name': row['name'], - 'role': row['role'], - } - if row['id'] == auth_dict['role_id']: + if row['class_id'] == auth_dict['class_id']: found_role = True - # merge class info into auth_dict - auth_dict |= class_dict # type: ignore[typeddict-item] + # add class/role info to auth_dict + auth_dict['class_name'] = row['name'] + auth_dict['role_id'] = row['role_id'] + auth_dict['role'] = row['role'] # check for any registered experiments in the current class - experiment_class_rows = db.execute("SELECT experiments.name FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiment_class.class_id=?", [class_dict['class_id']]).fetchall() + experiment_class_rows = db.execute("SELECT experiments.name FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiment_class.class_id=?", [auth_dict['class_id']]).fetchall() auth_dict['class_experiments'] = [row['name'] for row in experiment_class_rows] elif row['enabled']: # store a list of any other classes that are enabled (for switching UI) + class_dict: ClassDict = { + 'class_id': row['class_id'], + 'class_name': row['name'], + 'role': row['role'], + } auth_dict['other_classes'].append(class_dict) if not found_role: - # ensure we don't keep a role_id in auth if it's not a valid/active one - auth_dict['role_id'] = None + # ensure we don't keep a class_id in auth if it's not a valid/active one + auth_dict['class_id'] = None return auth_dict @@ -185,29 +188,29 @@ def get_auth() -> AuthDict: return g.auth # type: ignore[no-any-return] -def get_last_role(user_id: int) -> int | None: - """ Find and return the last role (as a role ID) for the given user, - as long as that role still exists and is currently active. +def get_last_class(user_id: int) -> int | None: + """ Find and return the last class (as a class ID) for the given user, + as long as the user still has an active role in that class. - Returns the role_id or None if nothing is found / matches. + Returns the class_id or None if nothing is found / matches. """ db = get_db() - role_row = db.execute(""" - SELECT roles.id AS role_id - FROM roles - JOIN users ON roles.user_id=users.id + class_row = db.execute(""" + SELECT users.last_class_id AS class_id + FROM users + JOIN roles ON roles.user_id=users.id WHERE users.id=? - AND users.last_role_id=roles.id + AND roles.class_id=users.last_class_id AND roles.active=1 """, [user_id]).fetchone() - if not role_row: + if not class_row: return None - role_id = role_row['role_id'] - assert isinstance(role_id, int) - return role_id + class_id = class_row['class_id'] + assert isinstance(class_id, int) + return class_id def ext_login_update_or_create(provider_name: str, user_normed: dict[str, str | None], query_tokens: int=0) -> Row: @@ -280,9 +283,9 @@ def login() -> str | Response: flash("Invalid username or password.", "warning") else: # Success! - last_role_id = get_last_role(auth_row['id']) + last_class_id = get_last_class(auth_row['id']) set_session_auth_user(auth_row['id']) - set_session_auth_role(last_role_id) + set_session_auth_class(last_class_id) return safe_redirect_next(default_endpoint="helper.help_form") # we either have a GET request or we fell through the POST login attempt with a failure @@ -318,7 +321,7 @@ def instructor_required(f: Callable[P, R]) -> Callable[P, Response | R]: @wraps(f) def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R: auth = get_auth() - if auth['role'] != "instructor": + if auth['role'] != "instructor" and not auth['is_admin']: flash("Instructor login required.", "warning") return redirect(url_for('auth.login', next=request.full_path)) return f(*args, **kwargs) diff --git a/src/gened/classes.py b/src/gened/classes.py index cfd9f05..1c52fcd 100644 --- a/src/gened/classes.py +++ b/src/gened/classes.py @@ -17,7 +17,7 @@ ) from werkzeug.wrappers.response import Response -from .auth import get_auth, login_required, set_session_auth_role +from .auth import get_auth, login_required, set_session_auth_class from .db import get_db from .redir import safe_redirect_next from .tz import date_is_past @@ -141,7 +141,6 @@ def switch_class(class_id: int | None) -> bool: user_id = auth['user_id'] db = get_db() - role_id = None # will be used if class_id is None if class_id: # check for a valid role in the new class @@ -160,12 +159,11 @@ def switch_class(class_id: int | None) -> bool: # no valid row found; change nothing and return failure return False - # otherwise, here's our new role ID - role_id = row['role_id'] + # otherwise, we can continue with this class_id - set_session_auth_role(role_id) - # record as user's latest active role - db.execute("UPDATE users SET last_role_id=? WHERE users.id=?", [role_id, user_id]) + set_session_auth_class(class_id) + # record as user's latest active class + db.execute("UPDATE users SET last_class_id=? WHERE users.id=?", [class_id, user_id]) db.commit() return True diff --git a/src/gened/lti.py b/src/gened/lti.py index 9a31c1b..a7fe351 100644 --- a/src/gened/lti.py +++ b/src/gened/lti.py @@ -20,7 +20,7 @@ from .auth import ( ext_login_update_or_create, - set_session_auth_role, + set_session_auth_class, set_session_auth_user, ) from .classes import get_or_create_lti_class @@ -106,19 +106,15 @@ def lti_login(lti: LTI) -> Response | tuple[str, int]: # noqa: ARG001 (unused a if not role_row: # Register this user - cur = db.execute("INSERT INTO roles(user_id, class_id, role) VALUES(?, ?, ?)", [user_id, class_id, role]) + db.execute("INSERT INTO roles(user_id, class_id, role) VALUES(?, ?, ?)", [user_id, class_id, role]) db.commit() - role_id = cur.lastrowid - else: - role_id = role_row['id'] - - if not role_row['active']: - session.clear() - return abort(403) + elif not role_row['active']: + session.clear() + return abort(403) # Record them as logged in in the session set_session_auth_user(user_id) - set_session_auth_role(role_id) + set_session_auth_class(class_id) # Redirect to the app if role == "instructor": diff --git a/src/gened/migrations/20240901--last_role_to_last_class.sql b/src/gened/migrations/20240901--last_role_to_last_class.sql new file mode 100644 index 0000000..2f7ce47 --- /dev/null +++ b/src/gened/migrations/20240901--last_role_to_last_class.sql @@ -0,0 +1,22 @@ +-- SPDX-FileCopyrightText: 2024 Mark Liffiton +-- +-- SPDX-License-Identifier: AGPL-3.0-only + +BEGIN; + +CREATE UNIQUE INDEX roles_user_class_unique ON roles(user_id, class_id); + +ALTER TABLE users ADD COLUMN + last_class_id INTEGER; -- most recently active class, used to re-activate on login (note: user may no longer have active role in this class) + +UPDATE users +SET last_class_id = ( + SELECT class_id + FROM roles + WHERE roles.id=users.last_role_id +) +WHERE last_role_id IS NOT NULL; + +ALTER TABLE users DROP COLUMN last_role_id; + +COMMIT; diff --git a/src/gened/oauth.py b/src/gened/oauth.py index 4c03a39..bad5f48 100644 --- a/src/gened/oauth.py +++ b/src/gened/oauth.py @@ -2,15 +2,18 @@ # # SPDX-License-Identifier: AGPL-3.0-only -from authlib.integrations.flask_client import OAuth, OAuthError # type: ignore [import-untyped] +from authlib.integrations.flask_client import ( # type: ignore [import-untyped] + OAuth, + OAuthError, +) from flask import Blueprint, abort, current_app, redirect, request, session, url_for from flask.app import Flask from werkzeug.wrappers.response import Response from .auth import ( ext_login_update_or_create, - get_last_role, - set_session_auth_role, + get_last_class, + set_session_auth_class, set_session_auth_user, ) @@ -114,12 +117,12 @@ def auth(provider_name: str) -> Response: # Given 10 tokens by default if creating an account on first login. user_row = ext_login_update_or_create(provider_name, user_normed, query_tokens=20) - # Get their last active role, if there is one (and it still exists and is active) - last_role_id = get_last_role(user_row['id']) + # Get their last active class, if there is one (and it still exists and user has active role in it) + last_class_id = get_last_class(user_row['id']) # Now, either the user existed or has been created. Log them in! set_session_auth_user(user_row['id']) - set_session_auth_role(last_role_id) + set_session_auth_class(last_class_id) # Redirect to stored next_url (and reset) if one has been stored, else root path next_url = session.get(NEXT_URL_SESSION_KEY) or "/" diff --git a/src/gened/schema_common.sql b/src/gened/schema_common.sql index 2b0e231..8a5344d 100644 --- a/src/gened/schema_common.sql +++ b/src/gened/schema_common.sql @@ -53,7 +53,7 @@ INSERT INTO auth_providers(name) VALUES CREATE TABLE users ( id INTEGER PRIMARY KEY AUTOINCREMENT, auth_provider INTEGER NOT NULL, - last_role_id INTEGER, -- most recently activated role (note: may no longer exist if deleted) used to re-activate on login + last_class_id INTEGER, -- most recently active class, used to re-activate on login (note: user may no longer have active role in this class) full_name TEXT, email TEXT, auth_name TEXT, @@ -133,6 +133,8 @@ CREATE TABLE roles ( FOREIGN KEY(user_id) REFERENCES users(id), FOREIGN KEY(class_id) REFERENCES classes(id) ); +DROP INDEX IF EXISTS roles_user_class_unique; +CREATE UNIQUE INDEX roles_user_class_unique ON roles(user_id, class_id); -- Store/manage demonstration links CREATE TABLE demo_links ( diff --git a/tests/test_auth.py b/tests/test_auth.py index edc6ec4..038e501 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -46,6 +46,7 @@ def check_login( assert sessauth['display_name'] == username assert sessauth['is_admin'] == is_admin assert sessauth['class_id'] is None + assert sessauth['role_id'] is None assert 'auth_provider' in sessauth assert sessauth['auth_provider'] == 'local' @@ -53,10 +54,11 @@ def check_login( # Verify session auth contains correct values for non-logged-in user sessauth = get_auth() assert sessauth['user_id'] is None - assert sessauth['role_id'] is None assert sessauth['is_admin'] is False + assert sessauth['role'] is None assert 'display_name' not in sessauth assert 'class_id' not in sessauth + assert 'role_id' not in sessauth assert 'auth_provider' not in sessauth assert message in response.text @@ -118,10 +120,10 @@ def test_logout(client, auth): sessauth = get_auth() assert sessauth['user_id'] is None - assert sessauth['role_id'] is None assert sessauth['is_admin'] is False assert 'display_name' not in sessauth assert 'class_id' not in sessauth + assert 'role_id' not in sessauth assert 'auth_provider' not in sessauth # Check if the user can access the login page and see the flashed message after logout