Skip to content

Commit

Permalink
Merge pull request #6701 from hotosm/fastapi-refactor
Browse files Browse the repository at this point in the history
pm only dependency function and injection in the required functions
  • Loading branch information
prabinoid authored Jan 21, 2025
2 parents 618eae8 + 1197601 commit 98b2bb3
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 34 deletions.
11 changes: 4 additions & 7 deletions backend/api/issues/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from backend.models.dtos.mapping_issues_dto import MappingIssueCategoryDTO
from backend.models.dtos.user_dto import AuthUserDTO
from backend.services.mapping_issues_service import MappingIssueCategoryService
from backend.services.users.authentication_service import login_required
from backend.services.users.authentication_service import pm_only

router = APIRouter(
prefix="/tasks",
Expand Down Expand Up @@ -50,11 +50,10 @@ async def get(category_id: int, db: Database = Depends(get_db)):


@router.patch("/issues/categories/{category_id}/")
# @tm.pm_only()
async def patch(
request: Request,
category_id: int,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
db: Database = Depends(get_db),
data: MappingIssueCategoryDTO = Body(...),
):
Expand Down Expand Up @@ -120,11 +119,10 @@ async def patch(


@router.delete("/issues/categories/{category_id}/")
# @tm.pm_only()
async def delete(
request: Request,
category_id: int,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
db: Database = Depends(get_db),
):
"""
Expand Down Expand Up @@ -195,10 +193,9 @@ async def get(request: Request, db: Database = Depends(get_db)):


@router.post("/issues/categories/", response_model=MappingIssueCategoryDTO)
# @tm.pm_only()
async def post(
request: Request,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
db: Database = Depends(get_db),
data: dict = Body(...),
):
Expand Down
26 changes: 16 additions & 10 deletions backend/api/licenses/resources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from backend.models.dtos.user_dto import AuthUserDTO
from backend.services.users.authentication_service import pm_only
from databases import Database
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
Expand All @@ -14,10 +16,11 @@


@router.post("/")
# TODO: refactor decorator functions
# @requires("authenticated")
# @tm.pm_only()
async def post(license_dto: LicenseDTO, db: Database = Depends(get_db)):
async def post(
license_dto: LicenseDTO,
db: Database = Depends(get_db),
user: AuthUserDTO = Depends(pm_only),
):
"""
Creates a new mapping license
---
Expand Down Expand Up @@ -93,10 +96,11 @@ async def get(


@router.patch("/{license_id}/")
# @requires("authenticated")
# @tm.pm_only()
async def patch(
license_dto: LicenseDTO, license_id: int, db: Database = Depends(get_db)
license_dto: LicenseDTO,
license_id: int,
db: Database = Depends(get_db),
user: AuthUserDTO = Depends(pm_only),
):
"""
Update a specified mapping license
Expand Down Expand Up @@ -148,9 +152,11 @@ async def patch(


@router.delete("/{license_id}/")
# @requires("authenticated")
# @tm.pm_only()
async def delete(license_id: int, db: Database = Depends(get_db)):
async def delete(
license_id: int,
db: Database = Depends(get_db),
user: AuthUserDTO = Depends(pm_only),
):
"""
Delete a specified mapping license
---
Expand Down
5 changes: 2 additions & 3 deletions backend/api/projects/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ProjectAdminServiceError,
)
from backend.services.project_service import ProjectService
from backend.services.users.authentication_service import login_required
from backend.services.users.authentication_service import login_required, pm_only

router = APIRouter(
prefix="/projects",
Expand Down Expand Up @@ -379,10 +379,9 @@ async def post(


@router.post("/actions/intersecting-tiles/")
# @tm.pm_only()
async def post(
request: Request,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
grid_dto: GridDTO = Body(...),
):
"""
Expand Down
15 changes: 5 additions & 10 deletions backend/api/users/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from backend.models.dtos.user_dto import AuthUserDTO, UserDTO, UserRegisterEmailDTO
from backend.services.interests_service import InterestService
from backend.services.messaging.message_service import MessageService
from backend.services.users.authentication_service import login_required
from backend.services.users.authentication_service import login_required, pm_only
from backend.services.users.user_service import UserService, UserServiceError

router = APIRouter(
Expand Down Expand Up @@ -115,15 +115,12 @@ async def patch(
return verification_sent


# class UsersActionsSetLevelAPI(Resource):
# @token_auth.login_required
@router.patch("/{username}/actions/set-level/{level}/")
# @tm.pm_only()
async def patch(
request: Request,
username,
level,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
db: Database = Depends(get_db),
):
"""
Expand Down Expand Up @@ -175,12 +172,11 @@ async def patch(


@router.patch("/{username}/actions/set-role/{role}/")
# @tm.pm_only()
async def patch(
request: Request,
username: str,
role: str,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
db: Database = Depends(get_db),
):
"""
Expand Down Expand Up @@ -232,12 +228,12 @@ async def patch(


@router.patch("/{user_name}/actions/set-expert-mode/{is_expert}/")
# @tm.pm_only()
# @tm.
async def patch(
request: Request,
user_name,
is_expert,
user: AuthUserDTO = Depends(login_required),
user: AuthUserDTO = Depends(pm_only),
db: Database = Depends(get_db),
):
"""
Expand Down Expand Up @@ -279,7 +275,6 @@ async def patch(


@router.patch("/me/actions/verify-email/")
# @tm.pm_only()
async def patch(
request: Request,
user: AuthUserDTO = Depends(login_required),
Expand Down
20 changes: 16 additions & 4 deletions backend/services/users/authentication_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

from databases import Database
from fastapi import HTTPException, Security, status
from fastapi import Depends, HTTPException, Security, status
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
Expand All @@ -19,7 +19,9 @@

from backend.api.utils import TMAPIDecorators
from backend.config import settings
from backend.db import get_db
from backend.models.dtos.user_dto import AuthUserDTO
from backend.models.postgis.statuses import UserRole
from backend.models.postgis.user import User
from backend.services.messaging.message_service import MessageService
from backend.services.users.user_service import NotFound, UserService
Expand Down Expand Up @@ -283,6 +285,7 @@ async def login_required_optional(

async def pm_only(
Authorization: str = Security(APIKeyHeader(name="Authorization")),
db: Database = Depends(get_db),
):
if not Authorization:
raise HTTPException(status_code=401, detail="Authorization header missing")
Expand All @@ -293,15 +296,24 @@ async def pm_only(
try:
decoded_token = base64.b64decode(credentials).decode("ascii")
except UnicodeDecodeError:
logger.debug("Unable to decode token")
raise HTTPException(status_code=401, detail="Invalid token")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise AuthenticationError("Invalid auth credentials")
raise HTTPException(status_code=401, detail="Invalid auth credentials")

valid_token, user_id = AuthenticationService.is_valid_token(decoded_token, 604800)
if not valid_token:
logger.debug("Token not valid")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"Error": "Token is expired or invalid", "SubCode": "InvalidToken"},
headers={"WWW-Authenticate": "Bearer"},
)

query = "SELECT id, username, role FROM users WHERE id = :user_id"
user = await db.fetch_one(query=query, values={"user_id": user_id})
if not user:
raise HTTPException(status_code=404, detail="User not found")

if user["role"] != UserRole.ADMIN.value:
raise HTTPException(status_code=403, detail="Admin access required")

return AuthUserDTO(id=user["id"])

0 comments on commit 98b2bb3

Please sign in to comment.