From 05645d1bbdb3f7ba3a5c00a96a7ec7e4b4adefa8 Mon Sep 17 00:00:00 2001 From: tofarr Date: Wed, 30 Oct 2024 08:46:22 -0600 Subject: [PATCH] Refactor CORS middleware and enhance localhost handling (#4624) Co-authored-by: openhands --- openhands/server/listen.py | 22 ++--------------- openhands/server/middleware.py | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 20 deletions(-) create mode 100644 openhands/server/middleware.py diff --git a/openhands/server/listen.py b/openhands/server/listen.py index c3a63853453..fc740e80293 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -29,12 +29,10 @@ WebSocket, status, ) -from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import HTTPBearer from fastapi.staticfiles import StaticFiles from pydantic import BaseModel -from starlette.middleware.base import BaseHTTPMiddleware import openhands.agenthub # noqa F401 (we import this to get the agents registered) from openhands.controller.agent import Agent @@ -57,6 +55,7 @@ from openhands.llm import bedrock from openhands.runtime.base import Runtime from openhands.server.auth import get_sid_from_token, sign_token +from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware from openhands.server.session import SessionManager load_dotenv() @@ -93,30 +92,13 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) app.add_middleware( - CORSMiddleware, - allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'], + LocalhostCORSMiddleware, allow_credentials=True, allow_methods=['*'], allow_headers=['*'], ) -class NoCacheMiddleware(BaseHTTPMiddleware): - """ - Middleware to disable caching for all routes by adding appropriate headers - """ - - async def dispatch(self, request, call_next): - response = await call_next(request) - if not request.url.path.startswith('/assets'): - response.headers['Cache-Control'] = ( - 'no-cache, no-store, must-revalidate, max-age=0' - ) - response.headers['Pragma'] = 'no-cache' - response.headers['Expires'] = '0' - return response - - app.add_middleware(NoCacheMiddleware) security_scheme = HTTPBearer() diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py new file mode 100644 index 00000000000..f09ac0788ae --- /dev/null +++ b/openhands/server/middleware.py @@ -0,0 +1,43 @@ +from urllib.parse import urlparse + +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + + +class LocalhostCORSMiddleware(CORSMiddleware): + """ + Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, + while using standard CORS rules for other origins. + """ + + def __init__(self, app: ASGIApp, **kwargs) -> None: + super().__init__(app, **kwargs) + + async def is_allowed_origin(self, origin: str) -> bool: + if origin: + parsed = urlparse(origin) + hostname = parsed.hostname or '' + + # Allow any localhost/127.0.0.1 origin regardless of port + if hostname in ['localhost', '127.0.0.1']: + return True + + # For missing origin or other origins, use the parent class's logic + return await super().is_allowed_origin(origin) + + +class NoCacheMiddleware(BaseHTTPMiddleware): + """ + Middleware to disable caching for all routes by adding appropriate headers + """ + + async def dispatch(self, request, call_next): + response = await call_next(request) + if not request.url.path.startswith('/assets'): + response.headers['Cache-Control'] = ( + 'no-cache, no-store, must-revalidate, max-age=0' + ) + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + return response