Skip to content

Commit

Permalink
Refactor CORS middleware and enhance localhost handling (#4624)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
  • Loading branch information
tofarr and openhands-agent authored Oct 30, 2024
1 parent e21abce commit 05645d1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
22 changes: 2 additions & 20 deletions openhands/server/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@
WebSocket,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware

import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.controller.agent import Agent
Expand All @@ -57,6 +55,7 @@
from openhands.llm import bedrock
from openhands.runtime.base import Runtime
from openhands.server.auth import get_sid_from_token, sign_token
from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
from openhands.server.session import SessionManager

load_dotenv()
Expand Down Expand Up @@ -93,30 +92,13 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'],
LocalhostCORSMiddleware,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)


class NoCacheMiddleware(BaseHTTPMiddleware):
"""
Middleware to disable caching for all routes by adding appropriate headers
"""

async def dispatch(self, request, call_next):
response = await call_next(request)
if not request.url.path.startswith('/assets'):
response.headers['Cache-Control'] = (
'no-cache, no-store, must-revalidate, max-age=0'
)
response.headers['Pragma'] = 'no-cache'
response.headers['Expires'] = '0'
return response


app.add_middleware(NoCacheMiddleware)

security_scheme = HTTPBearer()
Expand Down
43 changes: 43 additions & 0 deletions openhands/server/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from urllib.parse import urlparse

from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


class LocalhostCORSMiddleware(CORSMiddleware):
"""
Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
while using standard CORS rules for other origins.
"""

def __init__(self, app: ASGIApp, **kwargs) -> None:
super().__init__(app, **kwargs)

async def is_allowed_origin(self, origin: str) -> bool:
if origin:
parsed = urlparse(origin)
hostname = parsed.hostname or ''

# Allow any localhost/127.0.0.1 origin regardless of port
if hostname in ['localhost', '127.0.0.1']:
return True

# For missing origin or other origins, use the parent class's logic
return await super().is_allowed_origin(origin)


class NoCacheMiddleware(BaseHTTPMiddleware):
"""
Middleware to disable caching for all routes by adding appropriate headers
"""

async def dispatch(self, request, call_next):
response = await call_next(request)
if not request.url.path.startswith('/assets'):
response.headers['Cache-Control'] = (
'no-cache, no-store, must-revalidate, max-age=0'
)
response.headers['Pragma'] = 'no-cache'
response.headers['Expires'] = '0'
return response

0 comments on commit 05645d1

Please sign in to comment.