diff --git a/frontend/src/context/socket.tsx b/frontend/src/context/socket.tsx index 259acd2e3e2..de36bcce38c 100644 --- a/frontend/src/context/socket.tsx +++ b/frontend/src/context/socket.tsx @@ -50,7 +50,8 @@ function SocketProvider({ children }: SocketProviderProps) { const baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || fallback; const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; const ws = new WebSocket( - `${protocol}//${baseUrl}/ws${options?.token ? `?token=${options.token}` : ""}`, + `${protocol}//${baseUrl}/ws`, + ["openhands", options?.token || ""] // First protocol is our real protocol, second is the auth token ); ws.addEventListener("open", (event) => { diff --git a/openhands/server/listen.py b/openhands/server/listen.py index f7a9a9aa120..909460296cc 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -302,34 +302,36 @@ async def websocket_endpoint(websocket: WebSocket): {"action": "finish", "args": {}} ``` """ - cookies = dict( - cookie.split('=') - for cookie in websocket.headers.get('cookie', '').split('; ') - if cookie - ) - auth_cookie = cookies.get('openhands_auth') - - if not await authenticate_github_user(auth_cookie): + # Get protocols from Sec-WebSocket-Protocol header + protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ') + + # The first protocol should be our real protocol (e.g. 'openhands') + # The second protocol should contain our auth token + if len(protocols) < 2: await websocket.close(code=status.WS_1008_POLICY_VIOLATION) return - await asyncio.wait_for(websocket.accept(), 10) + real_protocol = protocols[0] + auth_token = protocols[1] + + # Verify GitHub authentication + if not await authenticate_github_user(auth_token): + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return - if websocket.query_params.get('token'): - token = websocket.query_params.get('token') - sid = get_sid_from_token(token, config.jwt_secret) + # Accept the connection with the real protocol + await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10) - if sid == '': - await websocket.send_json({'error': 'Invalid token', 'error_code': 401}) - await websocket.close() - return - else: - sid = str(uuid.uuid4()) - token = sign_token({'sid': sid}, config.jwt_secret) + # Extract session ID from token + sid = get_sid_from_token(auth_token, config.jwt_secret) + if sid == '': + await websocket.send_json({'error': 'Invalid token', 'error_code': 401}) + await websocket.close() + return logger.info(f'New session: {sid}') session = session_manager.add_or_restart_session(sid, websocket) - await websocket.send_json({'token': token, 'status': 'ok'}) + await websocket.send_json({'token': auth_token, 'status': 'ok'}) latest_event_id = -1 if websocket.query_params.get('latest_event_id'):