From 70941887ca44f94162da54fd905bfa2d25d859f6 Mon Sep 17 00:00:00 2001 From: Vianpyro Date: Wed, 20 Nov 2024 12:05:22 -0500 Subject: [PATCH] Add custom TokenError exception and refactor token handling for improved clarity --- jwt_helper.py | 63 ++++++++++++++++++++++++++-------------- routes/authentication.py | 27 +++++++---------- 2 files changed, 53 insertions(+), 37 deletions(-) diff --git a/jwt_helper.py b/jwt_helper.py index 6ec619d..8a8cf01 100644 --- a/jwt_helper.py +++ b/jwt_helper.py @@ -10,8 +10,21 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=30) +class TokenError(Exception): + """ + Custom exception for token-related errors. + """ + + def __init__(self, message, status_code): + super().__init__(message) + self.status_code = status_code + self.message = message + + def generate_access_token(player_id: int) -> str: - """Generate a JWT token for a user.""" + """ + Generate a short-lived JWT access token for a user. + """ payload = { "player_id": player_id, "exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration @@ -21,7 +34,9 @@ def generate_access_token(player_id: int) -> str: def generate_refresh_token(player_id: int) -> str: - """Generate a long-lived refresh token.""" + """ + Generate a long-lived refresh token for a user. + """ payload = { "player_id": player_id, "exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY, @@ -30,35 +45,41 @@ def generate_refresh_token(player_id: int) -> str: return jwt.encode(payload, SECRET_KEY, algorithm="HS256") -def verify_token(token: str) -> dict | None: - """Verify a JWT token and return the payload.""" - token = request.headers.get("Authorization") - - if not token or not token.startswith("Bearer "): - return jsonify(message="Token is missing or improperly formatted"), 401 +def extract_token_from_header() -> str: + """ + Extract the Bearer token from the Authorization header. + """ + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise TokenError("Token is missing or improperly formatted", 401) + return auth_header.split("Bearer ")[1] - token = token.split("Bearer ")[1] +def verify_token(token: str) -> dict: + """ + Verify and decode a JWT token. + """ try: return jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) except jwt.ExpiredSignatureError: - return None # Token expired + raise TokenError("Token has expired", 401) except jwt.InvalidTokenError: - return None # Invalid token + raise TokenError("Invalid token", 401) def token_required(f): + """ + Decorator to protect routes by requiring a valid token. + """ + @wraps(f) def decorated(*args, **kwargs): - token = request.headers.get("Authorization") - if not token: - return jsonify(message="Token is missing"), 401 - - decoded = verify_token(token) - if not decoded: - return jsonify(message="Token is invalid or expired"), 401 - - request.player_id = decoded["player_id"] # Attach user ID to the request - return f(*args, **kwargs) + try: + token = extract_token_from_header() + decoded = verify_token(token) + request.player_id = decoded["player_id"] + return f(*args, **kwargs) + except TokenError as e: + return jsonify(message=e.message), e.status_code return decorated diff --git a/routes/authentication.py b/routes/authentication.py index 4593b09..4904e36 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -8,7 +8,13 @@ from pymysql import MySQLError from db import get_db_connection -from jwt_helper import generate_access_token, generate_refresh_token, verify_token +from jwt_helper import ( + generate_access_token, + generate_refresh_token, + verify_token, + extract_token_from_header, + TokenError, +) load_dotenv() @@ -110,23 +116,12 @@ def login(): @authentication_blueprint.route("/refresh", methods=["POST"]) def refresh_token(): - auth_header = request.headers.get("Authorization") - - if not auth_header or not auth_header.startswith("Bearer "): - return ( - jsonify(message="Refresh token is required in the Authorization header"), - 400, - ) - - refresh_token = auth_header.split("Bearer ")[1] - try: - decoded = verify_token(refresh_token) + token = extract_token_from_header() + decoded = verify_token(token) player_id = decoded["player_id"] new_access_token = generate_access_token(player_id) return jsonify(access_token=new_access_token), 200 - except jwt.ExpiredSignatureError: - return jsonify(message="Refresh token has expired, please log in again"), 401 - except jwt.InvalidTokenError: - return jsonify(message="Invalid refresh token"), 401 + except TokenError as e: + return jsonify(message=e.message), e.status_code