Skip to content

Commit

Permalink
Add custom TokenError exception and refactor token handling for impro…
Browse files Browse the repository at this point in the history
…ved clarity
  • Loading branch information
Vianpyro committed Nov 20, 2024
1 parent a7eae1a commit 7094188
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 37 deletions.
63 changes: 42 additions & 21 deletions jwt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,21 @@
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


class TokenError(Exception):
"""
Custom exception for token-related errors.
"""

def __init__(self, message, status_code):
super().__init__(message)
self.status_code = status_code
self.message = message


def generate_access_token(player_id: int) -> str:
"""Generate a JWT token for a user."""
"""
Generate a short-lived JWT access token for a user.
"""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration
Expand All @@ -21,7 +34,9 @@ def generate_access_token(player_id: int) -> str:


def generate_refresh_token(player_id: int) -> str:
"""Generate a long-lived refresh token."""
"""
Generate a long-lived refresh token for a user.
"""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY,
Expand All @@ -30,35 +45,41 @@ def generate_refresh_token(player_id: int) -> str:
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def verify_token(token: str) -> dict | None:
"""Verify a JWT token and return the payload."""
token = request.headers.get("Authorization")

if not token or not token.startswith("Bearer "):
return jsonify(message="Token is missing or improperly formatted"), 401
def extract_token_from_header() -> str:
"""
Extract the Bearer token from the Authorization header.
"""
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise TokenError("Token is missing or improperly formatted", 401)
return auth_header.split("Bearer ")[1]

token = token.split("Bearer ")[1]

def verify_token(token: str) -> dict:
"""
Verify and decode a JWT token.
"""
try:
return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
return None # Token expired
raise TokenError("Token has expired", 401)
except jwt.InvalidTokenError:
return None # Invalid token
raise TokenError("Invalid token", 401)


def token_required(f):
"""
Decorator to protect routes by requiring a valid token.
"""

@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get("Authorization")
if not token:
return jsonify(message="Token is missing"), 401

decoded = verify_token(token)
if not decoded:
return jsonify(message="Token is invalid or expired"), 401

request.player_id = decoded["player_id"] # Attach user ID to the request
return f(*args, **kwargs)
try:
token = extract_token_from_header()
decoded = verify_token(token)
request.player_id = decoded["player_id"]
return f(*args, **kwargs)
except TokenError as e:
return jsonify(message=e.message), e.status_code

return decorated
27 changes: 11 additions & 16 deletions routes/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from pymysql import MySQLError

from db import get_db_connection
from jwt_helper import generate_access_token, generate_refresh_token, verify_token
from jwt_helper import (
generate_access_token,
generate_refresh_token,
verify_token,
extract_token_from_header,
TokenError,
)

load_dotenv()

Expand Down Expand Up @@ -110,23 +116,12 @@ def login():

@authentication_blueprint.route("/refresh", methods=["POST"])
def refresh_token():
auth_header = request.headers.get("Authorization")

if not auth_header or not auth_header.startswith("Bearer "):
return (
jsonify(message="Refresh token is required in the Authorization header"),
400,
)

refresh_token = auth_header.split("Bearer ")[1]

try:
decoded = verify_token(refresh_token)
token = extract_token_from_header()
decoded = verify_token(token)
player_id = decoded["player_id"]

new_access_token = generate_access_token(player_id)
return jsonify(access_token=new_access_token), 200
except jwt.ExpiredSignatureError:
return jsonify(message="Refresh token has expired, please log in again"), 401
except jwt.InvalidTokenError:
return jsonify(message="Invalid refresh token"), 401
except TokenError as e:
return jsonify(message=e.message), e.status_code

0 comments on commit 7094188

Please sign in to comment.