Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize token handling #13

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions jwt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,21 @@
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


class TokenError(Exception):
"""
Custom exception for token-related errors.
"""

def __init__(self, message, status_code):
super().__init__(message)
self.status_code = status_code
self.message = message


def generate_access_token(player_id: int) -> str:
"""Generate a JWT token for a user."""
"""
Generate a short-lived JWT access token for a user.
"""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration
Expand All @@ -21,7 +34,9 @@ def generate_access_token(player_id: int) -> str:


def generate_refresh_token(player_id: int) -> str:
"""Generate a long-lived refresh token."""
"""
Generate a long-lived refresh token for a user.
"""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY,
Expand All @@ -30,35 +45,41 @@ def generate_refresh_token(player_id: int) -> str:
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def verify_token(token: str) -> dict | None:
"""Verify a JWT token and return the payload."""
token = request.headers.get("Authorization")

if not token or not token.startswith("Bearer "):
return jsonify(message="Token is missing or improperly formatted"), 401
def extract_token_from_header() -> str:
"""
Extract the Bearer token from the Authorization header.
"""
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise TokenError("Token is missing or improperly formatted", 401)
return auth_header.split("Bearer ")[1]

token = token.split("Bearer ")[1]

def verify_token(token: str) -> dict:
"""
Verify and decode a JWT token.
"""
try:
return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
return None # Token expired
raise TokenError("Token has expired", 401)
except jwt.InvalidTokenError:
return None # Invalid token
raise TokenError("Invalid token", 401)


def token_required(f):
"""
Decorator to protect routes by requiring a valid token.
"""

@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get("Authorization")
if not token:
return jsonify(message="Token is missing"), 401

decoded = verify_token(token)
if not decoded:
return jsonify(message="Token is invalid or expired"), 401

request.player_id = decoded["player_id"] # Attach user ID to the request
return f(*args, **kwargs)
try:
token = extract_token_from_header()
decoded = verify_token(token)
request.player_id = decoded["player_id"]
return f(*args, **kwargs)
except TokenError as e:
return jsonify(message=e.message), e.status_code

return decorated
28 changes: 11 additions & 17 deletions routes/authentication.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import os
from re import match

import jwt
from argon2 import PasswordHasher, exceptions
from dotenv import load_dotenv
from flask import Blueprint, jsonify, request
from pymysql import MySQLError

from db import get_db_connection
from jwt_helper import generate_access_token, generate_refresh_token, verify_token
from jwt_helper import (
generate_access_token,
generate_refresh_token,
verify_token,
extract_token_from_header,
TokenError,
)

load_dotenv()

Expand Down Expand Up @@ -110,23 +115,12 @@ def login():

@authentication_blueprint.route("/refresh", methods=["POST"])
def refresh_token():
auth_header = request.headers.get("Authorization")

if not auth_header or not auth_header.startswith("Bearer "):
return (
jsonify(message="Refresh token is required in the Authorization header"),
400,
)

refresh_token = auth_header.split("Bearer ")[1]

try:
decoded = verify_token(refresh_token)
token = extract_token_from_header()
decoded = verify_token(token)
player_id = decoded["player_id"]

new_access_token = generate_access_token(player_id)
return jsonify(access_token=new_access_token), 200
except jwt.ExpiredSignatureError:
return jsonify(message="Refresh token has expired, please log in again"), 401
except jwt.InvalidTokenError:
return jsonify(message="Invalid refresh token"), 401
except TokenError as e:
return jsonify(message=e.message), e.status_code