Skip to content

Commit

Permalink
feat: refactor CORS handling and enhance origin validation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
OdyAsh committed Jan 7, 2025
1 parent 792fc63 commit afbbdc6
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 21 deletions.
14 changes: 2 additions & 12 deletions src/ansari/app/main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ansari.app.main_whatsapp import router as whatsapp_router
from ansari.config import Settings, get_settings
from ansari.presenters.api_presenter import ApiPresenter
from ansari.util.general_helpers import validate_cors
from ansari.util.general_helpers import get_extended_origins, validate_cors

logger = get_logger()

Expand All @@ -46,17 +46,7 @@ async def http_exception_handler(request, exc: HTTPException):


def add_app_middleware():
origins = get_settings().ORIGINS

# This if condition only runs in local development
if get_settings().DEBUG_MODE:
# Change "3000" to the port of your frontend server (3000 is the default there)
local_origin = "http://localhost:3000"
zrok_origin = get_settings().ZROK_SHARE_TOKEN.get_secret_value() + ".share.zrok.io"
# If we don't execute the code below, we'll get a "400 Bad Request" error when
# trying to access the API from the local frontend
origins.append(local_origin)
origins.append(zrok_origin)
origins = get_extended_origins()

app.add_middleware(
CORSMiddleware,
Expand Down
14 changes: 8 additions & 6 deletions src/ansari/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,14 @@ def get_resource_path(filename):
@field_validator("ORIGINS")
def parse_origins(cls, v):
if isinstance(v, str):
return [origin.strip() for origin in v.strip('"').split(",")]
if isinstance(v, list):
return v
raise ValueError(
f"Invalid ORIGINS format: {v}. Expected a comma-separated string or a list.",
)
origins = [origin.strip() for origin in v.strip('"').split(",")]
elif isinstance(v, list):
origins = v
else:
raise ValueError(
f"Invalid ORIGINS format: {v}. Expected a comma-separated string or a list.",
)
return origins


@lru_cache
Expand Down
11 changes: 10 additions & 1 deletion src/ansari/presenters/whatsapp_presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,16 @@ async def handle_text_message(
else:
passed_time = (datetime.now() - last_message_time).total_seconds()

logger.debug(f"Time passed since user ({user_id_whatsapp})'s last whatsapp message: {passed_time / 60:.1f}mins")
# Log the time passed since the last message
if passed_time < 60:
passed_time_log = f"{passed_time:.1f}sec"
elif passed_time < 3600:
passed_time_log = f"{passed_time / 60:.1f}mins"
elif passed_time < 86400:
passed_time_log = f"{passed_time / 3600:.1f}hours"
else:
passed_time_log = f"{passed_time / 86400:.1f}days"
logger.debug(f"Time passed since user ({user_id_whatsapp})'s last whatsapp message: {passed_time_log}mins")

# Determine the allowed retention time
if get_settings().DEBUG_MODE:
Expand Down
29 changes: 27 additions & 2 deletions src/ansari/util/general_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,35 @@
logger = get_logger()


def get_extended_origins(settings: Settings = Depends(get_settings)):
origins = get_settings().ORIGINS

# This if condition only runs in local development
# Also, if we don't execute the code below, we'll get a "400 Bad Request" error when
# trying to access the API from the local frontend
if get_settings().DEBUG_MODE:
# Change "3000" to the port of your frontend server (3000 is the default there)
local_origin = "http://localhost:3000"
zrok_origin = get_settings().ZROK_SHARE_TOKEN.get_secret_value() + ".share.zrok.io"

if local_origin not in origins:
origins.append(local_origin)
if zrok_origin not in origins:
origins.append(zrok_origin)

# Make sure CI/CD of GitHub Actions is allowed
if "testserver" not in origins:
github_actions_origin = "testserver"
origins.append(github_actions_origin)

return origins


# Defined in a separate file to avoid circular imports between main_*.py files
def validate_cors(request: Request, settings: Settings = Depends(get_settings)) -> bool:
try:
# logger.debug(f"Headers of raw request are: {request.headers}")
origins = get_settings().ORIGINS
origins = get_extended_origins()
incoming_origin = [
request.headers.get("origin", ""), # If coming from ansari's frontend website
request.headers.get("host", ""), # If coming from Meta's WhatsApp API
Expand All @@ -22,7 +46,8 @@ def validate_cors(request: Request, settings: Settings = Depends(get_settings))
if any(i_o in origins for i_o in incoming_origin) or mobile == "ANSARI":
logger.debug("CORS OK")
return True
raise HTTPException(status_code=502, detail=f"Incoming origin/host: {incoming_origin} is not in origin list")
else:
raise HTTPException(status_code=502, detail=f"Incoming origin/host: {incoming_origin} is not in origin list")
except PyJWTError:
raise HTTPException(status_code=403, detail="Could not validate credentials")

Expand Down

0 comments on commit afbbdc6

Please sign in to comment.