diff --git a/app/main.py b/app/main.py index 2affc0e..bb71d34 100644 --- a/app/main.py +++ b/app/main.py @@ -25,11 +25,60 @@ import string import time from logger_config import setup_logger +from slowapi import Limiter +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware +from fastapi import Depends logger = setup_logger() +RATE_LIMITING_ENABLED = True + app = FastAPI() +limiter = Limiter( + key_func=get_remote_address, + strategy="fixed-window", + storage_uri="memory://", + enabled=RATE_LIMITING_ENABLED, +) + +app.state.limiter = limiter + +# Add rate limiting middleware +app.add_middleware(SlowAPIMiddleware) + + +@app.exception_handler(RateLimitExceeded) +async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + logger.warning( + f"Rate limit exceeded - IP: {get_remote_address(request)}, " + f"Path: {request.url.path}, " + f"Method: {request.method}" + ) + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded", "retry_after": str(exc.retry_after)}, + ) + + +# Define rate limit decorators for different endpoints +def auth_rate_limit(): + return limiter.limit( + "10 per minute", error_message="Authentication rate limit exceeded" + ) + + +def user_rate_limit(): + return limiter.limit( + "10 per minute", error_message="User endpoint rate limit exceeded" + ) + + +auth.router.dependencies.append(Depends(auth_rate_limit())) +user.router.dependencies.append(Depends(user_rate_limit())) + @app.middleware("http") async def log_requests(request: Request, call_next): @@ -78,8 +127,11 @@ def authjwt_exception_handler(request: Request, exc: AuthJWTException): return JSONResponse(status_code=exc.status_code, content={"detail": exc.message}) +app.include_router(auth.router, prefix="/auth", tags=["authentication"]) +app.include_router(user.router, prefix="/users", tags=["users"]) + + app.include_router(auth_group.router) -app.include_router(auth.router) app.include_router(batch.router) app.include_router(enrollment_record.router) app.include_router(form.router) @@ -92,11 +144,11 @@ def authjwt_exception_handler(request: Request, exc: AuthJWTException): app.include_router(student.router) app.include_router(teacher.router) app.include_router(user_session.router) -app.include_router(user.router) @app.get("/") -def index(): +@limiter.limit("10 per minute") +async def index(request: Request): # Added request parameter and made it async return "Welcome to Portal!" diff --git a/app/requirements.txt b/app/requirements.txt index 06f82e2..99abd7a 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -22,3 +22,4 @@ uvicorn==0.21.1 uvloop==0.17.0 watchfiles==0.18.1 websockets==10.4 +slowapi==0.1.4 diff --git a/app/router/auth.py b/app/router/auth.py index 91ffb27..81caf72 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -14,16 +14,22 @@ def index(): # if user is valid, generates both access token and refresh token. Otherwise, only an access token. @router.post("/create-access-token") def create_access_token(auth_user: AuthUser, Authorize: AuthJWT = Depends()): + # Define access_token and refresh_token as empty strings + access_token = "" refresh_token = "" - data = auth_user.data + data = auth_user.data if auth_user.data is not None else {} - if auth_user.data is None: - data = {} + # Validate auth_user.type + if auth_user.type not in ["user", "organization"]: + raise HTTPException( + status_code=400, + detail="Invalid user type! Must be either 'user' or 'organization'.", + ) if auth_user.type == "organization": if not auth_user.name: - return HTTPException( - status_code=400, detail="Data Parameter {} is missing!".format("name") + raise HTTPException( + status_code=400, detail="Data Parameter 'name' is missing!" ) expires = datetime.timedelta(weeks=260) access_token = Authorize.create_access_token( @@ -34,9 +40,8 @@ def create_access_token(auth_user: AuthUser, Authorize: AuthJWT = Depends()): elif auth_user.type == "user": if "is_user_valid" not in auth_user.dict().keys(): - return HTTPException( - status_code=400, - detail="Data Parameter {} is missing!".format("is_user_valid"), + raise HTTPException( + status_code=400, detail="Data Parameter 'is_user_valid' is missing!" ) if auth_user.is_user_valid: refresh_token = Authorize.create_refresh_token( @@ -55,10 +60,7 @@ def refresh_token(Authorize: AuthJWT = Depends()): Authorize.jwt_refresh_token_required() current_user = Authorize.get_jwt_subject() old_data = Authorize.get_raw_jwt() - if "group" in old_data: - custom_claims = {"group": old_data["group"]} - else: - custom_claims = {} + custom_claims = {"group": old_data.get("group", {})} new_access_token = Authorize.create_access_token( subject=current_user, user_claims=custom_claims )