Skip to content

Commit

Permalink
Refactor session auth (preparing for admin access to class pages).
Browse files Browse the repository at this point in the history
  • Loading branch information
liffiton committed Sep 2, 2024
1 parent be3e4f4 commit a0fdddf
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 68 deletions.
87 changes: 45 additions & 42 deletions src/gened/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ class AuthDict(TypedDict, total=False):
display_name: str
is_admin: bool
is_tester: bool
role_id: int | None # current role
role: RoleType | None # current role name (e.g., 'instructor')
class_id: int | None # current class ID
class_name: str | None # current class name
role_id: int | None # current role
role: RoleType | None # current role name (e.g., 'instructor')
class_experiments: list[str] # any experiments the current class is registered in
other_classes: list[ClassDict] # for storing active classes that are not the user's current class

Expand All @@ -62,20 +62,22 @@ def set_session_auth_user(user_id: int) -> None:
""" Set the current session's user (on login, after authentication).
Clears all other auth data in the session.
"""
auth: AuthDict = {
auth = {
'user_id': user_id,
}
session[AUTH_SESSION_KEY] = auth
_invalidate_g_auth()


def set_session_auth_role(role_id: int | None) -> None:
""" Set the current session's active role (on login or class switch).
def set_session_auth_class(class_id: int | None) -> None:
""" Set the current session's active class (on login or class switch).
Adds to any existing auth data in the session.
"""
auth: AuthDict = session[AUTH_SESSION_KEY]
auth['role_id'] = role_id
session[AUTH_SESSION_KEY] = auth
sess_auth = session.get(AUTH_SESSION_KEY, {})
assert 'user_id' in sess_auth
assert sess_auth['user_id'] is not None # must be logged in already for this function to be valid
sess_auth['class_id'] = class_id
session[AUTH_SESSION_KEY] = sess_auth
_invalidate_g_auth()


Expand All @@ -87,14 +89,13 @@ def _get_auth_from_session() -> AuthDict:
'user_id': None,
'is_admin': False,
'is_tester': False,
'role_id': None,
'role': None,
}
# Get the session auth dict, or an empty dict if it's not there, to find
# current user_id and role_id (if any).
sess_auth = session.get(AUTH_SESSION_KEY, {})
sess_user = sess_auth.get('user_id', None)
sess_role = sess_auth.get('role_id', None)
sess_class = sess_auth.get('class_id', None)

if not sess_user:
# No logged in user; return the base/empty auth data
Expand Down Expand Up @@ -122,16 +123,16 @@ def _get_auth_from_session() -> AuthDict:
auth_dict: AuthDict = {
# from session
'user_id': sess_user,
'role_id': sess_role,
'class_id': sess_class,
# from DB
'display_name': user_row['display_name'],
'is_admin': user_row['is_admin'],
'is_tester': user_row['is_tester'],
'auth_provider': user_row['auth_provider'],
# to be filled
'class_id': None,
'class_name': None,
'class_experiments': [],
'role_id': None,
'role': None,
'other_classes': [],
}
Expand All @@ -141,11 +142,11 @@ def _get_auth_from_session() -> AuthDict:
# Uses WHERE active=1 to only allow active roles.
role_rows = db.execute("""
SELECT
roles.id,
roles.id AS role_id,
roles.class_id,
roles.role,
classes.name,
classes.enabled,
roles.role
classes.enabled
FROM roles
JOIN classes ON classes.id=roles.class_id
WHERE roles.user_id=? AND roles.active=1
Expand All @@ -155,25 +156,27 @@ def _get_auth_from_session() -> AuthDict:
found_role = False # track whether the current role from auth is actually found as an active role
if role_rows:
for row in role_rows:
class_dict: ClassDict = {
'class_id': row['class_id'],
'class_name': row['name'],
'role': row['role'],
}
if row['id'] == auth_dict['role_id']:
if row['class_id'] == auth_dict['class_id']:
found_role = True
# merge class info into auth_dict
auth_dict |= class_dict # type: ignore[typeddict-item]
# add class/role info to auth_dict
auth_dict['class_name'] = row['name']
auth_dict['role_id'] = row['role_id']
auth_dict['role'] = row['role']
# check for any registered experiments in the current class
experiment_class_rows = db.execute("SELECT experiments.name FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiment_class.class_id=?", [class_dict['class_id']]).fetchall()
experiment_class_rows = db.execute("SELECT experiments.name FROM experiments JOIN experiment_class ON experiment_class.experiment_id=experiments.id WHERE experiment_class.class_id=?", [auth_dict['class_id']]).fetchall()
auth_dict['class_experiments'] = [row['name'] for row in experiment_class_rows]
elif row['enabled']:
# store a list of any other classes that are enabled (for switching UI)
class_dict: ClassDict = {
'class_id': row['class_id'],
'class_name': row['name'],
'role': row['role'],
}
auth_dict['other_classes'].append(class_dict)

if not found_role:
# ensure we don't keep a role_id in auth if it's not a valid/active one
auth_dict['role_id'] = None
# ensure we don't keep a class_id in auth if it's not a valid/active one
auth_dict['class_id'] = None

return auth_dict

Expand All @@ -185,29 +188,29 @@ def get_auth() -> AuthDict:
return g.auth # type: ignore[no-any-return]


def get_last_role(user_id: int) -> int | None:
""" Find and return the last role (as a role ID) for the given user,
as long as that role still exists and is currently active.
def get_last_class(user_id: int) -> int | None:
""" Find and return the last class (as a class ID) for the given user,
as long as the user still has an active role in that class.
Returns the role_id or None if nothing is found / matches.
Returns the class_id or None if nothing is found / matches.
"""
db = get_db()

role_row = db.execute("""
SELECT roles.id AS role_id
FROM roles
JOIN users ON roles.user_id=users.id
class_row = db.execute("""
SELECT users.last_class_id AS class_id
FROM users
JOIN roles ON roles.user_id=users.id
WHERE users.id=?
AND users.last_role_id=roles.id
AND roles.class_id=users.last_class_id
AND roles.active=1
""", [user_id]).fetchone()

if not role_row:
if not class_row:
return None

role_id = role_row['role_id']
assert isinstance(role_id, int)
return role_id
class_id = class_row['class_id']
assert isinstance(class_id, int)
return class_id


def ext_login_update_or_create(provider_name: str, user_normed: dict[str, str | None], query_tokens: int=0) -> Row:
Expand Down Expand Up @@ -280,9 +283,9 @@ def login() -> str | Response:
flash("Invalid username or password.", "warning")
else:
# Success!
last_role_id = get_last_role(auth_row['id'])
last_class_id = get_last_class(auth_row['id'])
set_session_auth_user(auth_row['id'])
set_session_auth_role(last_role_id)
set_session_auth_class(last_class_id)
return safe_redirect_next(default_endpoint="helper.help_form")

# we either have a GET request or we fell through the POST login attempt with a failure
Expand Down Expand Up @@ -318,7 +321,7 @@ def instructor_required(f: Callable[P, R]) -> Callable[P, Response | R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs) -> Response | R:
auth = get_auth()
if auth['role'] != "instructor":
if auth['role'] != "instructor" and not auth['is_admin']:
flash("Instructor login required.", "warning")
return redirect(url_for('auth.login', next=request.full_path))
return f(*args, **kwargs)
Expand Down
12 changes: 5 additions & 7 deletions src/gened/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from werkzeug.wrappers.response import Response

from .auth import get_auth, login_required, set_session_auth_role
from .auth import get_auth, login_required, set_session_auth_class
from .db import get_db
from .redir import safe_redirect_next
from .tz import date_is_past
Expand Down Expand Up @@ -141,7 +141,6 @@ def switch_class(class_id: int | None) -> bool:
user_id = auth['user_id']

db = get_db()
role_id = None # will be used if class_id is None

if class_id:
# check for a valid role in the new class
Expand All @@ -160,12 +159,11 @@ def switch_class(class_id: int | None) -> bool:
# no valid row found; change nothing and return failure
return False

# otherwise, here's our new role ID
role_id = row['role_id']
# otherwise, we can continue with this class_id

set_session_auth_role(role_id)
# record as user's latest active role
db.execute("UPDATE users SET last_role_id=? WHERE users.id=?", [role_id, user_id])
set_session_auth_class(class_id)
# record as user's latest active class
db.execute("UPDATE users SET last_class_id=? WHERE users.id=?", [class_id, user_id])
db.commit()
return True

Expand Down
16 changes: 6 additions & 10 deletions src/gened/lti.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .auth import (
ext_login_update_or_create,
set_session_auth_role,
set_session_auth_class,
set_session_auth_user,
)
from .classes import get_or_create_lti_class
Expand Down Expand Up @@ -106,19 +106,15 @@ def lti_login(lti: LTI) -> Response | tuple[str, int]: # noqa: ARG001 (unused a

if not role_row:
# Register this user
cur = db.execute("INSERT INTO roles(user_id, class_id, role) VALUES(?, ?, ?)", [user_id, class_id, role])
db.execute("INSERT INTO roles(user_id, class_id, role) VALUES(?, ?, ?)", [user_id, class_id, role])
db.commit()
role_id = cur.lastrowid
else:
role_id = role_row['id']

if not role_row['active']:
session.clear()
return abort(403)
elif not role_row['active']:
session.clear()
return abort(403)

# Record them as logged in in the session
set_session_auth_user(user_id)
set_session_auth_role(role_id)
set_session_auth_class(class_id)

# Redirect to the app
if role == "instructor":
Expand Down
22 changes: 22 additions & 0 deletions src/gened/migrations/20240901--last_role_to_last_class.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- SPDX-FileCopyrightText: 2024 Mark Liffiton <[email protected]>
--
-- SPDX-License-Identifier: AGPL-3.0-only

BEGIN;

CREATE UNIQUE INDEX roles_user_class_unique ON roles(user_id, class_id);

ALTER TABLE users ADD COLUMN
last_class_id INTEGER; -- most recently active class, used to re-activate on login (note: user may no longer have active role in this class)

UPDATE users
SET last_class_id = (
SELECT class_id
FROM roles
WHERE roles.id=users.last_role_id
)
WHERE last_role_id IS NOT NULL;

ALTER TABLE users DROP COLUMN last_role_id;

COMMIT;
15 changes: 9 additions & 6 deletions src/gened/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
#
# SPDX-License-Identifier: AGPL-3.0-only

from authlib.integrations.flask_client import OAuth, OAuthError # type: ignore [import-untyped]
from authlib.integrations.flask_client import ( # type: ignore [import-untyped]
OAuth,
OAuthError,
)
from flask import Blueprint, abort, current_app, redirect, request, session, url_for
from flask.app import Flask
from werkzeug.wrappers.response import Response

from .auth import (
ext_login_update_or_create,
get_last_role,
set_session_auth_role,
get_last_class,
set_session_auth_class,
set_session_auth_user,
)

Expand Down Expand Up @@ -114,12 +117,12 @@ def auth(provider_name: str) -> Response:
# Given 10 tokens by default if creating an account on first login.
user_row = ext_login_update_or_create(provider_name, user_normed, query_tokens=20)

# Get their last active role, if there is one (and it still exists and is active)
last_role_id = get_last_role(user_row['id'])
# Get their last active class, if there is one (and it still exists and user has active role in it)
last_class_id = get_last_class(user_row['id'])

# Now, either the user existed or has been created. Log them in!
set_session_auth_user(user_row['id'])
set_session_auth_role(last_role_id)
set_session_auth_class(last_class_id)

# Redirect to stored next_url (and reset) if one has been stored, else root path
next_url = session.get(NEXT_URL_SESSION_KEY) or "/"
Expand Down
4 changes: 3 additions & 1 deletion src/gened/schema_common.sql
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ INSERT INTO auth_providers(name) VALUES
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
auth_provider INTEGER NOT NULL,
last_role_id INTEGER, -- most recently activated role (note: may no longer exist if deleted) used to re-activate on login
last_class_id INTEGER, -- most recently active class, used to re-activate on login (note: user may no longer have active role in this class)
full_name TEXT,
email TEXT,
auth_name TEXT,
Expand Down Expand Up @@ -133,6 +133,8 @@ CREATE TABLE roles (
FOREIGN KEY(user_id) REFERENCES users(id),
FOREIGN KEY(class_id) REFERENCES classes(id)
);
DROP INDEX IF EXISTS roles_user_class_unique;
CREATE UNIQUE INDEX roles_user_class_unique ON roles(user_id, class_id);

-- Store/manage demonstration links
CREATE TABLE demo_links (
Expand Down
6 changes: 4 additions & 2 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,19 @@ def check_login(
assert sessauth['display_name'] == username
assert sessauth['is_admin'] == is_admin
assert sessauth['class_id'] is None
assert sessauth['role_id'] is None
assert 'auth_provider' in sessauth
assert sessauth['auth_provider'] == 'local'

else:
# Verify session auth contains correct values for non-logged-in user
sessauth = get_auth()
assert sessauth['user_id'] is None
assert sessauth['role_id'] is None
assert sessauth['is_admin'] is False
assert sessauth['role'] is None
assert 'display_name' not in sessauth
assert 'class_id' not in sessauth
assert 'role_id' not in sessauth
assert 'auth_provider' not in sessauth

assert message in response.text
Expand Down Expand Up @@ -118,10 +120,10 @@ def test_logout(client, auth):

sessauth = get_auth()
assert sessauth['user_id'] is None
assert sessauth['role_id'] is None
assert sessauth['is_admin'] is False
assert 'display_name' not in sessauth
assert 'class_id' not in sessauth
assert 'role_id' not in sessauth
assert 'auth_provider' not in sessauth

# Check if the user can access the login page and see the flashed message after logout
Expand Down

0 comments on commit a0fdddf

Please sign in to comment.