diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 7899c2dcfab0..490e3355f059 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -163,9 +163,6 @@ async def close(self) -> None: filter_hidden=True, ) ) - - # unsubscribe from the event stream - self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id) self._closed = True def log(self, level: str, message: str, extra: dict | None = None) -> None: diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 3a0b705dd02d..0ea40f29faab 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -41,7 +41,7 @@ class SandboxConfig: remote_runtime_api_url: str = 'http://localhost:8000' local_runtime_url: str = 'http://localhost' - keep_runtime_alive: bool = True + keep_runtime_alive: bool = False rm_all_containers: bool = False api_key: str | None = None base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime @@ -60,7 +60,7 @@ class SandboxConfig: runtime_startup_env_vars: dict[str, str] = field(default_factory=dict) browsergym_eval_env: str | None = None platform: str | None = None - close_delay: int = 900 + close_delay: int = 15 remote_runtime_resource_factor: int = 1 enable_gpu: bool = False docker_runtime_kwargs: str | None = None diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index 2e2c67c5a8f3..3c9c81e77af6 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -18,10 +18,12 @@ class RemoteRuntimeBuilder(RuntimeBuilder): """This class interacts with the remote Runtime API for building and managing container images.""" - def __init__(self, api_url: str, api_key: str): + def __init__( + self, api_url: str, api_key: str, session: requests.Session | None = None + ): self.api_url = api_url self.api_key = api_key - self.session = requests.Session() + self.session = session or requests.Session() self.session.headers.update({'X-API-Key': self.api_key}) def build( @@ -53,13 +55,14 @@ def build( # Send the POST request to /build (Begins the build process) try: - response = send_request( + with send_request( self.session, 'POST', f'{self.api_url}/build', files=files, timeout=30, - ) + ) as response: + build_data = response.json() except requests.exceptions.HTTPError as e: if e.response.status_code == 429: logger.warning('Build was rate limited. Retrying in 30 seconds.') @@ -68,7 +71,6 @@ def build( else: raise e - build_data = response.json() build_id = build_data['build_id'] logger.info(f'Build initiated with ID: {build_id}') @@ -80,71 +82,62 @@ def build( logger.error('Build timed out after 30 minutes') raise AgentRuntimeBuildError('Build timed out after 30 minutes') - status_response = send_request( + with send_request( self.session, 'GET', f'{self.api_url}/build_status', params={'build_id': build_id}, - ) - - if status_response.status_code != 200: - logger.error(f'Failed to get build status: {status_response.text}') - raise AgentRuntimeBuildError( - f'Failed to get build status: {status_response.text}' - ) - - status_data = status_response.json() - status = status_data['status'] - logger.info(f'Build status: {status}') - - if status == 'SUCCESS': - logger.debug(f"Successfully built {status_data['image']}") - return status_data['image'] - elif status in [ - 'FAILURE', - 'INTERNAL_ERROR', - 'TIMEOUT', - 'CANCELLED', - 'EXPIRED', - ]: - error_message = status_data.get( - 'error', f'Build failed with status: {status}. Build ID: {build_id}' - ) - logger.error(error_message) - raise AgentRuntimeBuildError(error_message) + ) as status_response: + status_data = status_response.json() + status = status_data['status'] + logger.info(f'Build status: {status}') + + if status == 'SUCCESS': + logger.debug(f"Successfully built {status_data['image']}") + return status_data['image'] + elif status in [ + 'FAILURE', + 'INTERNAL_ERROR', + 'TIMEOUT', + 'CANCELLED', + 'EXPIRED', + ]: + error_message = status_data.get( + 'error', + f'Build failed with status: {status}. Build ID: {build_id}', + ) + logger.error(error_message) + raise AgentRuntimeBuildError(error_message) # Wait before polling again sleep_if_should_continue(30) - raise AgentRuntimeBuildError( - 'Build interrupted (likely received SIGTERM or SIGINT).' - ) + raise AgentRuntimeBuildError('Build interrupted') def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool: """Checks if an image exists in the remote registry using the /image_exists endpoint.""" params = {'image': image_name} - response = send_request( + with send_request( self.session, 'GET', f'{self.api_url}/image_exists', params=params, - ) - - if response.status_code != 200: - logger.error(f'Failed to check image existence: {response.text}') - raise AgentRuntimeBuildError( - f'Failed to check image existence: {response.text}' - ) - - result = response.json() - - if result['exists']: - logger.debug( - f"Image {image_name} exists. " - f"Uploaded at: {result['image']['upload_time']}, " - f"Size: {result['image']['image_size_bytes'] / 1024 / 1024:.2f} MB" - ) - else: - logger.debug(f'Image {image_name} does not exist.') - - return result['exists'] + ) as response: + if response.status_code != 200: + logger.error(f'Failed to check image existence: {response.text}') + raise AgentRuntimeBuildError( + f'Failed to check image existence: {response.text}' + ) + + result = response.json() + + if result['exists']: + logger.debug( + f"Image {image_name} exists. " + f"Uploaded at: {result['image']['upload_time']}, " + f"Size: {result['image']['image_size_bytes'] / 1024 / 1024:.2f} MB" + ) + else: + logger.debug(f'Image {image_name} does not exist.') + + return result['exists'] diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py index 0e0b7adc79e6..c0d75e3b700f 100644 --- a/openhands/runtime/impl/remote/remote_runtime.py +++ b/openhands/runtime/impl/remote/remote_runtime.py @@ -66,7 +66,9 @@ def __init__( ) self.runtime_builder = RemoteRuntimeBuilder( - self.config.sandbox.remote_runtime_api_url, self.config.sandbox.api_key + self.config.sandbox.remote_runtime_api_url, + self.config.sandbox.api_key, + self.session, ) self.runtime_id: str | None = None self.runtime_url: str | None = None diff --git a/openhands/runtime/utils/request.py b/openhands/runtime/utils/request.py index a145bd27f4e5..e05a083e7b0d 100644 --- a/openhands/runtime/utils/request.py +++ b/openhands/runtime/utils/request.py @@ -21,7 +21,7 @@ def __str__(self) -> str: return s -def is_rate_limit_error(exception): +def is_retryable_error(exception): return ( isinstance(exception, requests.HTTPError) and exception.response.status_code == 429 @@ -29,7 +29,7 @@ def is_rate_limit_error(exception): @retry( - retry=retry_if_exception(is_rate_limit_error), + retry=retry_if_exception(is_retryable_error), stop=stop_after_attempt(3) | stop_if_should_exit(), wait=wait_exponential(multiplier=1, min=4, max=60), ) @@ -48,6 +48,8 @@ def send_request( _json = response.json() except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError): _json = None + finally: + response.close() raise RequestHTTPError( e, response=e.response, diff --git a/openhands/server/listen_socket.py b/openhands/server/listen_socket.py index 884b7e3b7cef..027c24328852 100644 --- a/openhands/server/listen_socket.py +++ b/openhands/server/listen_socket.py @@ -30,7 +30,7 @@ async def connect(connection_id: str, environ, auth): logger.error('No conversation_id in query params') raise ConnectionRefusedError('No conversation_id in query params') - user_id = -1 + user_id = None if openhands_config.app_mode != AppMode.OSS: cookies_str = environ.get('HTTP_COOKIE', '') cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; ')) @@ -63,7 +63,7 @@ async def connect(connection_id: str, environ, auth): try: event_stream = await session_manager.join_conversation( - conversation_id, connection_id, settings + conversation_id, connection_id, settings, user_id ) except ConversationDoesNotExistError: logger.error(f'Conversation {conversation_id} does not exist') diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 8b788e23c99e..e94a491621cc 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Callable from fastapi import APIRouter, Body, Request @@ -42,7 +42,8 @@ async def new_conversation(request: Request, data: InitSessionRequest): logger.info('Initializing new conversation') logger.info('Loading settings') - settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request)) + user_id = get_user_id(request) + settings_store = await SettingsStoreImpl.get_instance(config, user_id) settings = await settings_store.load() logger.info('Settings loaded') @@ -55,9 +56,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): session_init_args['selected_repository'] = data.selected_repository conversation_init_data = ConversationInitData(**session_init_args) logger.info('Loading conversation store') - conversation_store = await ConversationStoreImpl.get_instance( - config, get_user_id(request) - ) + conversation_store = await ConversationStoreImpl.get_instance(config, user_id) logger.info('Conversation store loaded') conversation_id = uuid.uuid4().hex @@ -76,19 +75,19 @@ async def new_conversation(request: Request, data: InitSessionRequest): ConversationMetadata( conversation_id=conversation_id, title=conversation_title, - github_user_id=get_user_id(request), + github_user_id=user_id, selected_repository=data.selected_repository, ) ) logger.info(f'Starting agent loop for conversation {conversation_id}') event_stream = await session_manager.maybe_start_agent_loop( - conversation_id, conversation_init_data + conversation_id, conversation_init_data, user_id ) try: event_stream.subscribe( EventStreamSubscriber.SERVER, - _create_conversation_update_callback(get_user_id(request), conversation_id), + _create_conversation_update_callback(user_id, conversation_id), UPDATED_AT_CALLBACK_ID, ) except ValueError: @@ -112,8 +111,8 @@ async def search_conversations( for conversation in conversation_metadata_result_set.results if hasattr(conversation, 'created_at') ) - running_conversations = await session_manager.get_agent_loop_running( - set(conversation_ids) + running_conversations = await session_manager.get_running_agent_loops( + get_user_id(request), set(conversation_ids) ) result = ConversationInfoResultSet( results=await wait_all( @@ -222,5 +221,5 @@ def callback(*args, **kwargs): async def _update_timestamp_for_conversation(user_id: int, conversation_id: str): conversation_store = await ConversationStoreImpl.get_instance(config, user_id) conversation = await conversation_store.get_metadata(conversation_id) - conversation.last_updated_at = datetime.now() + conversation.last_updated_at = datetime.now(timezone.utc) await conversation_store.save_metadata(conversation) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 0bb0d2387e5e..3c4125ca0f15 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import Callable, Optional from openhands.controller import AgentController @@ -16,7 +17,7 @@ from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options from openhands.storage.files import FileStore -from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async +from openhands.utils.async_utils import call_sync_from_async from openhands.utils.shutdown_listener import should_continue WAIT_TIME_BEFORE_CLOSE = 300 @@ -36,7 +37,8 @@ class AgentSession: controller: AgentController | None = None runtime: Runtime | None = None security_analyzer: SecurityAnalyzer | None = None - _initializing: bool = False + _starting: bool = False + _started_at: float = 0 _closed: bool = False loop: asyncio.AbstractEventLoop | None = None @@ -88,7 +90,8 @@ async def start( if self._closed: logger.warning('Session closed before starting') return - self._initializing = True + self._starting = True + self._started_at = time.time() self._create_security_analyzer(config.security.security_analyzer) await self._create_runtime( runtime_name=runtime_name, @@ -109,24 +112,19 @@ async def start( self.event_stream.add_event( ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT ) - self._initializing = False + self._starting = False - def close(self): + async def close(self): """Closes the Agent session""" if self._closed: return self._closed = True - call_async_from_sync(self._close) - - async def _close(self): - seconds_waited = 0 - while self._initializing and should_continue(): + while self._starting and should_continue(): logger.debug( f'Waiting for initialization to finish before closing session {self.sid}' ) await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL) - seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL - if seconds_waited > WAIT_TIME_BEFORE_CLOSE: + if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE: logger.error( f'Waited too long for initialization to finish before closing session {self.sid}' ) diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 9e7f7d8b8d7f..68b111387590 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -2,6 +2,7 @@ import json import time from dataclasses import dataclass, field +from typing import Generic, Iterable, TypeVar from uuid import uuid4 import socketio @@ -9,12 +10,13 @@ from openhands.core.config import AppConfig from openhands.core.exceptions import AgentRuntimeUnavailableError from openhands.core.logger import openhands_logger as logger +from openhands.core.schema.agent import AgentState from openhands.events.stream import EventStream, session_exists from openhands.server.session.conversation import Conversation from openhands.server.session.session import ROOM_KEY, Session from openhands.server.settings import Settings from openhands.storage.files import FileStore -from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.async_utils import wait_all from openhands.utils.shutdown_listener import should_continue _REDIS_POLL_TIMEOUT = 1.5 @@ -22,6 +24,8 @@ _CLEANUP_INTERVAL = 15 _CLEANUP_EXCEPTION_WAIT_TIME = 15 +MAX_RUNNING_CONVERSATIONS = 3 +T = TypeVar('T') class ConversationDoesNotExistError(Exception): @@ -29,10 +33,10 @@ class ConversationDoesNotExistError(Exception): @dataclass -class _SessionIsRunningCheck: - request_id: str - request_sids: list[str] - running_sids: set[str] = field(default_factory=set) +class _ClusterQuery(Generic[T]): + query_id: str + request_ids: set[str] | None + result: T flag: asyncio.Event = field(default_factory=asyncio.Event) @@ -42,21 +46,18 @@ class SessionManager: config: AppConfig file_store: FileStore _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict) - local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) + _local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) _last_alive_timestamps: dict[str, float] = field(default_factory=dict) _redis_listen_task: asyncio.Task | None = None - _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field( + _running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field( default_factory=dict ) _active_conversations: dict[str, tuple[Conversation, int]] = field( default_factory=dict ) - _detached_conversations: dict[str, tuple[Conversation, float]] = field( - default_factory=dict - ) _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _cleanup_task: asyncio.Task | None = None - _has_remote_connections_flags: dict[str, asyncio.Event] = field( + _connection_queries: dict[str, _ClusterQuery[dict[str, str]]] = field( default_factory=dict ) @@ -64,7 +65,7 @@ async def __aenter__(self): redis_client = self._get_redis_client() if redis_client: self._redis_listen_task = asyncio.create_task(self._redis_subscribe()) - self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations()) + self._cleanup_task = asyncio.create_task(self._cleanup_stale()) return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -86,7 +87,7 @@ async def _redis_subscribe(self): logger.debug('_redis_subscribe') redis_client = self._get_redis_client() pubsub = redis_client.pubsub() - await pubsub.subscribe('oh_event') + await pubsub.subscribe('session_msg') while should_continue(): try: message = await pubsub.get_message( @@ -114,59 +115,71 @@ async def _process_message(self, message: dict): session = self._local_agent_loops_by_sid.get(sid) if session: await session.dispatch(data['data']) - elif message_type == 'is_session_running': + elif message_type == 'running_agent_loops_query': # Another node in the cluster is asking if the current node is running the session given. - request_id = data['request_id'] - sids = [ - sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid - ] + query_id = data['query_id'] + sids = self._get_running_agent_loops_locally( + data.get('user_id'), data.get('filter_to_sids') + ) if sids: await self._get_redis_client().publish( - 'oh_event', + 'session_msg', json.dumps( { - 'request_id': request_id, - 'sids': sids, - 'message_type': 'session_is_running', + 'query_id': query_id, + 'sids': list(sids), + 'message_type': 'running_agent_loops_response', } ), ) - elif message_type == 'session_is_running': - request_id = data['request_id'] + elif message_type == 'running_agent_loops_response': + query_id = data['query_id'] for sid in data['sids']: self._last_alive_timestamps[sid] = time.time() - check = self._session_is_running_checks.get(request_id) - if check: - check.running_sids.update(data['sids']) - if len(check.request_sids) == len(check.running_sids): - check.flag.set() - elif message_type == 'has_remote_connections_query': + running_query = self._running_sid_queries.get(query_id) + if running_query: + running_query.result.update(data['sids']) + if running_query.request_ids is not None and len( + running_query.request_ids + ) == len(running_query.result): + running_query.flag.set() + elif message_type == 'connections_query': # Another node in the cluster is asking if the current node is connected to a session - sid = data['sid'] - required = sid in self.local_connection_id_to_session_id.values() - if required: + query_id = data['query_id'] + connections = self._get_connections_locally( + data.get('user_id'), data.get('filter_to_sids') + ) + if connections: await self._get_redis_client().publish( - 'oh_event', + 'session_msg', json.dumps( - {'sid': sid, 'message_type': 'has_remote_connections_response'} + { + 'query_id': query_id, + 'connections': connections, + 'message_type': 'connections_response', + } ), ) - elif message_type == 'has_remote_connections_response': - sid = data['sid'] - flag = self._has_remote_connections_flags.get(sid) - if flag: - flag.set() + elif message_type == 'connections_response': + query_id = data['query_id'] + connection_query = self._connection_queries.get(query_id) + if connection_query: + connection_query.result.update(**data['connections']) + if connection_query.request_ids is not None and len( + connection_query.request_ids + ) == len(connection_query.result): + connection_query.flag.set() elif message_type == 'close_session': sid = data['sid'] if sid in self._local_agent_loops_by_sid: - await self._on_close_session(sid) + await self._close_session(sid) elif message_type == 'session_closing': # Session closing event - We only get this in the event of graceful shutdown, # which can't be guaranteed - nodes can simply vanish unexpectedly! sid = data['sid'] logger.debug(f'session_closing:{sid}') # Create a list of items to process to avoid modifying dict during iteration - items = list(self.local_connection_id_to_session_id.items()) + items = list(self._local_connection_id_to_session_id.items()) for connection_id, local_sid in items: if sid == local_sid: logger.warning( @@ -187,13 +200,6 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None: logger.info(f'Reusing active conversation {sid}') return conversation - # Check if we have a detached conversation we can reuse - if sid in self._detached_conversations: - conversation, _ = self._detached_conversations.pop(sid) - self._active_conversations[sid] = (conversation, 1) - logger.info(f'Reusing detached conversation {sid}') - return conversation - # Create new conversation if none exists c = Conversation(sid, file_store=self.file_store, config=self.config) try: @@ -209,13 +215,15 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None: self._active_conversations[sid] = (c, 1) return c - async def join_conversation(self, sid: str, connection_id: str, settings: Settings): + async def join_conversation( + self, sid: str, connection_id: str, settings: Settings, user_id: int | None + ): logger.info(f'join_conversation:{sid}:{connection_id}') await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) - self.local_connection_id_to_session_id[connection_id] = sid + self._local_connection_id_to_session_id[connection_id] = sid event_stream = await self._get_event_stream(sid) if not event_stream: - return await self.maybe_start_agent_loop(sid, settings) + return await self.maybe_start_agent_loop(sid, settings, user_id) return event_stream async def detach_from_conversation(self, conversation: Conversation): @@ -228,131 +236,202 @@ async def detach_from_conversation(self, conversation: Conversation): return else: self._active_conversations.pop(sid) - self._detached_conversations[sid] = (conversation, time.time()) - async def _cleanup_detached_conversations(self): + async def _cleanup_stale(self): while should_continue(): if self._get_redis_client(): # Debug info for HA envs logger.info( - f'Attached conversations: {len(self._active_conversations)}' - ) - logger.info( - f'Detached conversations: {len(self._detached_conversations)}' - ) - logger.info( - f'Running agent loops: {len(self._local_agent_loops_by_sid)}' - ) - logger.info( - f'Local connections: {len(self.local_connection_id_to_session_id)}' + f'agent_loops: {len(self._local_agent_loops_by_sid)}; local_connections: {len(self._local_connection_id_to_session_id)}' ) try: - async with self._conversations_lock: - # Create a list of items to process to avoid modifying dict during iteration - items = list(self._detached_conversations.items()) - for sid, (conversation, detach_time) in items: - await conversation.disconnect() - self._detached_conversations.pop(sid, None) + close_threshold = time.time() - self.config.sandbox.close_delay + running_loops = list(self._local_agent_loops_by_sid.items()) + running_loops.sort(key=lambda item: item[1].last_active_ts) + sid_to_close: list[str] = [] + for sid, session in running_loops: + controller = session.agent_session.controller + state = controller.state if controller else AgentState.STOPPED + if ( + session.last_active_ts < close_threshold + and state != AgentState.RUNNING + ): + sid_to_close.append(sid) + + connections = self._get_connections_locally( + filter_to_sids=set(sid_to_close) + ) + connected_sids = {sid for _, sid in connections.items()} + sid_to_close = [ + sid for sid in sid_to_close if sid not in connected_sids + ] + + if sid_to_close: + connections = await self._get_connections_remotely( + filter_to_sids=set(sid_to_close) + ) + connected_sids = {sid for _, sid in connections.items()} + sid_to_close = [ + sid for sid in sid_to_close if sid not in connected_sids + ] + await wait_all(self._close_session(sid) for sid in sid_to_close) await asyncio.sleep(_CLEANUP_INTERVAL) except asyncio.CancelledError: - async with self._conversations_lock: - for conversation, _ in self._detached_conversations.values(): - await conversation.disconnect() - self._detached_conversations.clear() + await wait_all( + self._close_session(sid) for sid in self._local_agent_loops_by_sid + ) return except Exception: - logger.warning('error_cleaning_detached_conversations', exc_info=True) + logger.warning('error_cleaning_stale', exc_info=True) await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME) - async def get_agent_loop_running(self, sids: set[str]) -> set[str]: - running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid) - check_cluster_sids = [sid for sid in sids if sid not in running_sids] - running_cluster_sids = await self.get_agent_loop_running_in_cluster( - check_cluster_sids - ) - running_sids.union(running_cluster_sids) - return running_sids - async def is_agent_loop_running(self, sid: str) -> bool: - if await self.is_agent_loop_running_locally(sid): - return True - if await self.is_agent_loop_running_in_cluster(sid): - return True - return False - - async def is_agent_loop_running_locally(self, sid: str) -> bool: - return sid in self._local_agent_loops_by_sid - - async def is_agent_loop_running_in_cluster(self, sid: str) -> bool: - running_sids = await self.get_agent_loop_running_in_cluster([sid]) - return bool(running_sids) - - async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]: + sids = await self.get_running_agent_loops(filter_to_sids={sid}) + return bool(sids) + + async def get_running_agent_loops( + self, user_id: int | None = None, filter_to_sids: set[str] | None = None + ) -> set[str]: + """Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest.""" + sids = self._get_running_agent_loops_locally(user_id, filter_to_sids) + remote_sids = await self._get_running_agent_loops_remotely( + user_id, filter_to_sids + ) + return sids.union(remote_sids) + + def _get_running_agent_loops_locally( + self, user_id: int | None = None, filter_to_sids: set[str] | None = None + ) -> set[str]: + items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items() + if filter_to_sids is not None: + items = (item for item in items if item[0] in filter_to_sids) + if user_id: + items = (item for item in items if item[1].user_id == user_id) + sids = {sid for sid, _ in items} + return sids + + async def _get_running_agent_loops_remotely( + self, + user_id: int | None = None, + filter_to_sids: set[str] | None = None, + ) -> set[str]: """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply""" redis_client = self._get_redis_client() if not redis_client: return set() flag = asyncio.Event() - request_id = str(uuid4()) - check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids) - self._session_is_running_checks[request_id] = check + query_id = str(uuid4()) + query = _ClusterQuery[set[str]]( + query_id=query_id, request_ids=filter_to_sids, result=set() + ) + self._running_sid_queries[query_id] = query try: - logger.debug(f'publish:is_session_running:{sids}') - await redis_client.publish( - 'oh_event', - json.dumps( - { - 'request_id': request_id, - 'sids': sids, - 'message_type': 'is_session_running', - } - ), + logger.debug( + f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}' ) + data: dict = { + 'query_id': query_id, + 'message_type': 'running_agent_loops_query', + } + if user_id: + data['user_id'] = user_id + if filter_to_sids: + data['filter_to_sids'] = list(filter_to_sids) + await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - return check.running_sids + return query.result except TimeoutError: # Nobody replied in time - return check.running_sids + return query.result finally: - self._session_is_running_checks.pop(request_id, None) + self._running_sid_queries.pop(query_id, None) + + async def get_connections( + self, user_id: int | None = None, filter_to_sids: set[str] | None = None + ) -> dict[str, str]: + connection_ids = self._get_connections_locally(user_id, filter_to_sids) + remote_connection_ids = await self._get_connections_remotely( + user_id, filter_to_sids + ) + connection_ids.update(**remote_connection_ids) + return connection_ids + + def _get_connections_locally( + self, user_id: int | None = None, filter_to_sids: set[str] | None = None + ) -> dict[str, str]: + connections = dict(**self._local_connection_id_to_session_id) + if filter_to_sids is not None: + connections = { + connection_id: sid + for connection_id, sid in connections.items() + if sid in filter_to_sids + } + if user_id: + for connection_id, sid in list(connections.items()): + session = self._local_agent_loops_by_sid.get(sid) + if not session or session.user_id != user_id: + connections.pop(connection_id) + return connections + + async def _get_connections_remotely( + self, user_id: int | None = None, filter_to_sids: set[str] | None = None + ) -> dict[str, str]: + redis_client = self._get_redis_client() + if not redis_client: + return {} - async def _has_remote_connections(self, sid: str) -> bool: - """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply""" - # Create a flag for the callback flag = asyncio.Event() - self._has_remote_connections_flags[sid] = flag + query_id = str(uuid4()) + query = _ClusterQuery[dict[str, str]]( + query_id=query_id, request_ids=filter_to_sids, result={} + ) + self._connection_queries[query_id] = query try: - await self._get_redis_client().publish( - 'oh_event', - json.dumps( - { - 'sid': sid, - 'message_type': 'has_remote_connections_query', - } - ), + logger.debug( + f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}' ) + data: dict = { + 'query_id': query_id, + 'message_type': 'connections_query', + } + if user_id: + data['user_id'] = user_id + if filter_to_sids: + data['filter_to_sids'] = list(filter_to_sids) + await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - result = flag.is_set() - return result + return query.result except TimeoutError: # Nobody replied in time - return False + return query.result finally: - self._has_remote_connections_flags.pop(sid, None) + self._connection_queries.pop(query_id, None) - async def maybe_start_agent_loop(self, sid: str, settings: Settings) -> EventStream: + async def maybe_start_agent_loop( + self, sid: str, settings: Settings, user_id: int | None = None + ) -> EventStream: logger.info(f'maybe_start_agent_loop:{sid}') session: Session | None = None if not await self.is_agent_loop_running(sid): logger.info(f'start_agent_loop:{sid}') + + response_ids = await self.get_running_agent_loops(user_id) + if len(response_ids) >= MAX_RUNNING_CONVERSATIONS: + logger.info('too_many_sessions_for:{user_id}') + await self.close_session(next(iter(response_ids))) + session = Session( - sid=sid, file_store=self.file_store, config=self.config, sio=self.sio + sid=sid, + file_store=self.file_store, + config=self.config, + sio=self.sio, + user_id=user_id, ) self._local_agent_loops_by_sid[sid] = session asyncio.create_task(session.initialize_agent(settings)) @@ -361,7 +440,6 @@ async def maybe_start_agent_loop(self, sid: str, settings: Settings) -> EventStr if not event_stream: logger.error(f'No event stream after starting agent loop: {sid}') raise RuntimeError(f'no_event_stream:{sid}') - asyncio.create_task(self._cleanup_session_later(sid)) return event_stream async def _get_event_stream(self, sid: str) -> EventStream | None: @@ -371,7 +449,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: logger.info(f'found_local_agent_loop:{sid}') return session.agent_session.event_stream - if await self.is_agent_loop_running_in_cluster(sid): + if await self._get_running_agent_loops_remotely(filter_to_sids={sid}): logger.info(f'found_remote_agent_loop:{sid}') return EventStream(sid, self.file_store) @@ -379,7 +457,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: async def send_to_event_stream(self, connection_id: str, data: dict): # If there is a local session running, send to that - sid = self.local_connection_id_to_session_id.get(connection_id) + sid = self._local_connection_id_to_session_id.get(connection_id) if not sid: raise RuntimeError(f'no_connected_session:{connection_id}') @@ -395,11 +473,11 @@ async def send_to_event_stream(self, connection_id: str, data: dict): next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL if ( next_alive_check > time.time() - or await self.is_agent_loop_running_in_cluster(sid) + or await self._get_running_agent_loops_remotely(filter_to_sids={sid}) ): # Send the event to the other pod await redis_client.publish( - 'oh_event', + 'session_msg', json.dumps( { 'sid': sid, @@ -413,75 +491,37 @@ async def send_to_event_stream(self, connection_id: str, data: dict): raise RuntimeError(f'no_connected_session:{connection_id}:{sid}') async def disconnect_from_session(self, connection_id: str): - sid = self.local_connection_id_to_session_id.pop(connection_id, None) + sid = self._local_connection_id_to_session_id.pop(connection_id, None) logger.info(f'disconnect_from_session:{connection_id}:{sid}') if not sid: # This can occur if the init action was never run. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}') return - if should_continue(): - asyncio.create_task(self._cleanup_session_later(sid)) - else: - await self._on_close_session(sid) - - async def _cleanup_session_later(self, sid: str): - # Once there have been no connections to a session for a reasonable period, we close it - try: - await asyncio.sleep(self.config.sandbox.close_delay) - finally: - # If the sleep was cancelled, we still want to close these - await self._cleanup_session(sid) - - async def _cleanup_session(self, sid: str) -> bool: - # Get local connections - logger.info(f'_cleanup_session:{sid}') - has_local_connections = next( - (True for v in self.local_connection_id_to_session_id.values() if v == sid), - False, - ) - if has_local_connections: - return False - - # If no local connections, get connections through redis - redis_client = self._get_redis_client() - if redis_client and await self._has_remote_connections(sid): - return False - - # We alert the cluster in case they are interested - if redis_client: - await redis_client.publish( - 'oh_event', - json.dumps({'sid': sid, 'message_type': 'session_closing'}), - ) - - await self._on_close_session(sid) - return True - async def close_session(self, sid: str): session = self._local_agent_loops_by_sid.get(sid) if session: - await self._on_close_session(sid) + await self._close_session(sid) redis_client = self._get_redis_client() if redis_client: await redis_client.publish( - 'oh_event', + 'session_msg', json.dumps({'sid': sid, 'message_type': 'close_session'}), ) - async def _on_close_session(self, sid: str): + async def _close_session(self, sid: str): logger.info(f'_close_session:{sid}') # Clear up local variables connection_ids_to_remove = list( connection_id - for connection_id, conn_sid in self.local_connection_id_to_session_id.items() + for connection_id, conn_sid in self._local_connection_id_to_session_id.items() if sid == conn_sid ) logger.info(f'removing connections: {connection_ids_to_remove}') for connnnection_id in connection_ids_to_remove: - self.local_connection_id_to_session_id.pop(connnnection_id, None) + self._local_connection_id_to_session_id.pop(connnnection_id, None) session = self._local_agent_loops_by_sid.pop(sid, None) if not session: @@ -490,12 +530,17 @@ async def _on_close_session(self, sid: str): logger.info(f'closing_session:{session.sid}') # We alert the cluster in case they are interested - redis_client = self._get_redis_client() - if redis_client: - await redis_client.publish( - 'oh_event', - json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), + try: + redis_client = self._get_redis_client() + if redis_client: + await redis_client.publish( + 'session_msg', + json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), + ) + except Exception: + logger.info( + 'error_publishing_close_session_event', exc_info=True, stack_info=True ) - await call_sync_from_async(session.close) + await session.close() logger.info(f'closed_session:{session.sid}') diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index a481fbd27078..de1f7000c0c0 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -37,6 +37,7 @@ class Session: loop: asyncio.AbstractEventLoop config: AppConfig file_store: FileStore + user_id: int | None def __init__( self, @@ -44,6 +45,7 @@ def __init__( config: AppConfig, file_store: FileStore, sio: socketio.AsyncServer | None, + user_id: int | None = None, ): self.sid = sid self.sio = sio @@ -58,10 +60,19 @@ def __init__( # Copying this means that when we update variables they are not applied to the shared global configuration! self.config = deepcopy(config) self.loop = asyncio.get_event_loop() - - def close(self): + self.user_id = user_id + + async def close(self): + if self.sio: + await self.sio.emit( + 'oh_event', + event_to_dict( + AgentStateChangedObservation('', AgentState.STOPPED.value) + ), + to=ROOM_KEY.format(sid=self.sid), + ) self.is_alive = False - self.agent_session.close() + asyncio.create_task(self.agent_session.close()) async def initialize_agent( self, @@ -113,9 +124,9 @@ async def initialize_agent( selected_repository=selected_repository, ) except Exception as e: - logger.exception(f'Error creating controller: {e}') + logger.exception(f'Error creating agent_session: {e}') await self.send_error( - f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..' + f'Error creating agent_session. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..' ) return diff --git a/openhands/storage/conversation/conversation_store.py b/openhands/storage/conversation/conversation_store.py index 1f0b41ea6c87..40b5a23f1251 100644 --- a/openhands/storage/conversation/conversation_store.py +++ b/openhands/storage/conversation/conversation_store.py @@ -40,5 +40,7 @@ async def search( @classmethod @abstractmethod - async def get_instance(cls, config: AppConfig, user_id: int) -> ConversationStore: + async def get_instance( + cls, config: AppConfig, user_id: int | None + ) -> ConversationStore: """Get a store for the user represented by the token given""" diff --git a/openhands/storage/conversation/file_conversation_store.py b/openhands/storage/conversation/file_conversation_store.py index 15679f404903..c51ab3219db3 100644 --- a/openhands/storage/conversation/file_conversation_store.py +++ b/openhands/storage/conversation/file_conversation_store.py @@ -94,7 +94,7 @@ def get_conversation_metadata_filename(self, conversation_id: str) -> str: @classmethod async def get_instance( - cls, config: AppConfig, user_id: int + cls, config: AppConfig, user_id: int | None ) -> FileConversationStore: file_store = get_file_store(config.file_store, config.file_store_path) return FileConversationStore(file_store) diff --git a/openhands/storage/data_models/conversation_info.py b/openhands/storage/data_models/conversation_info.py index 52464ec30f85..9bbeacd73d43 100644 --- a/openhands/storage/data_models/conversation_info.py +++ b/openhands/storage/data_models/conversation_info.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from openhands.storage.data_models.conversation_status import ConversationStatus @@ -13,4 +13,4 @@ class ConversationInfo: last_updated_at: datetime | None = None status: ConversationStatus = ConversationStatus.STOPPED selected_repository: str | None = None - created_at: datetime = field(default_factory=datetime.now) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/openhands/storage/data_models/conversation_metadata.py b/openhands/storage/data_models/conversation_metadata.py index e75bbf21f8d5..d19183b23376 100644 --- a/openhands/storage/data_models/conversation_metadata.py +++ b/openhands/storage/data_models/conversation_metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone @dataclass @@ -9,4 +9,4 @@ class ConversationMetadata: selected_repository: str | None title: str | None = None last_updated_at: datetime | None = None - created_at: datetime = field(default_factory=datetime.now) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/openhands/storage/s3.py b/openhands/storage/s3.py index b571d4288c89..76db87dc638a 100644 --- a/openhands/storage/s3.py +++ b/openhands/storage/s3.py @@ -41,6 +41,10 @@ def list(self, path: str) -> list[str]: def delete(self, path: str) -> None: try: - self.client.remove_object(self.bucket, path) + client = self.client + bucket = self.bucket + objects_to_delete = client.list_objects(bucket, prefix=path, recursive=True) + for obj in objects_to_delete: + client.remove_object(bucket, obj.object_name) except Exception as e: raise FileNotFoundError(f'Failed to delete S3 object at path {path}: {e}') diff --git a/openhands/storage/settings/file_settings_store.py b/openhands/storage/settings/file_settings_store.py index c8703b304c11..1b4d4b6b0b67 100644 --- a/openhands/storage/settings/file_settings_store.py +++ b/openhands/storage/settings/file_settings_store.py @@ -30,6 +30,8 @@ async def store(self, settings: Settings): await call_sync_from_async(self.file_store.write, self.path, json_str) @classmethod - async def get_instance(cls, config: AppConfig, user_id: int) -> FileSettingsStore: + async def get_instance( + cls, config: AppConfig, user_id: int | None + ) -> FileSettingsStore: file_store = get_file_store(config.file_store, config.file_store_path) return FileSettingsStore(file_store) diff --git a/openhands/storage/settings/settings_store.py b/openhands/storage/settings/settings_store.py index a371720600ac..1ff1aa85c34a 100644 --- a/openhands/storage/settings/settings_store.py +++ b/openhands/storage/settings/settings_store.py @@ -21,5 +21,7 @@ async def store(self, settings: Settings): @classmethod @abstractmethod - async def get_instance(cls, config: AppConfig, user_id: int) -> SettingsStore: + async def get_instance( + cls, config: AppConfig, user_id: int | None + ) -> SettingsStore: """Get a store for the user represented by the token given""" diff --git a/pyproject.toml b/pyproject.toml index 1a15ae754ca3..3753ae45d200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] + [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" @@ -129,6 +130,7 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" + [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py index 144f79f9f491..a2c705f270e5 100644 --- a/tests/unit/test_manager.py +++ b/tests/unit/test_manager.py @@ -44,28 +44,28 @@ async def test_session_not_running_in_cluster(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager.is_agent_loop_running_in_cluster( - 'non-existant-session' + result = await session_manager._get_running_agent_loops_remotely( + filter_to_sids={'non-existant-session'} ) - assert result is False + assert result == set() assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'oh_event', - '{"request_id": "' + 'session_msg', + '{"query_id": "' + str(id) - + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}', + + '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}', ) @pytest.mark.asyncio -async def test_session_is_running_in_cluster(): +async def test_get_running_agent_loops_remotely(): id = uuid4() sio = get_mock_sio( GetMessageMock( { - 'request_id': str(id), + 'query_id': str(id), 'sids': ['existing-session'], - 'message_type': 'session_is_running', + 'message_type': 'running_agent_loops_response', } ) ) @@ -76,16 +76,16 @@ async def test_session_is_running_in_cluster(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager.is_agent_loop_running_in_cluster( - 'existing-session' + result = await session_manager._get_running_agent_loops_remotely( + 1, {'existing-session'} ) - assert result is True + assert result == {'existing-session'} assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'oh_event', - '{"request_id": "' + 'session_msg', + '{"query_id": "' + str(id) - + '", "sids": ["existing-session"], "message_type": "is_session_running"}', + + '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}', ) @@ -96,8 +96,8 @@ async def test_init_new_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = False + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = set() with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1), @@ -106,8 +106,8 @@ async def test_init_new_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager.get_running_agent_loops', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -117,7 +117,7 @@ async def test_init_new_local_session(): 'new-session-id', ConversationInitData() ) await session_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData() + 'new-session-id', 'new-session-id', ConversationInitData(), 1 ) assert session_instance.initialize_agent.call_count == 1 assert sio.enter_room.await_count == 1 @@ -130,8 +130,8 @@ async def test_join_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = False + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = set() with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -140,8 +140,8 @@ async def test_join_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager.get_running_agent_loops', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -151,10 +151,10 @@ async def test_join_local_session(): 'new-session-id', ConversationInitData() ) await session_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData() + 'new-session-id', 'new-session-id', ConversationInitData(), None ) await session_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData() + 'new-session-id', 'new-session-id', ConversationInitData(), 0 ) assert session_instance.initialize_agent.call_count == 1 assert sio.enter_room.await_count == 2 @@ -167,8 +167,8 @@ async def test_join_cluster_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = True + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = {'new-session-id'} with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -177,15 +177,15 @@ async def test_join_cluster_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', + get_running_agent_loops_mock, ), ): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: await session_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData() + 'new-session-id', 'new-session-id', ConversationInitData(), 1 ) assert session_instance.initialize_agent.call_count == 0 assert sio.enter_room.await_count == 1 @@ -198,8 +198,8 @@ async def test_add_to_local_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = False + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = set() with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -208,8 +208,8 @@ async def test_add_to_local_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager.get_running_agent_loops', + get_running_agent_loops_mock, ), ): async with SessionManager( @@ -219,7 +219,7 @@ async def test_add_to_local_event_stream(): 'new-session-id', ConversationInitData() ) await session_manager.join_conversation( - 'new-session-id', 'connection-id', ConversationInitData() + 'new-session-id', 'connection-id', ConversationInitData(), 1 ) await session_manager.send_to_event_stream( 'connection-id', {'event_type': 'some_event'} @@ -234,8 +234,8 @@ async def test_add_to_cluster_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - is_agent_loop_running_in_cluster_mock = AsyncMock() - is_agent_loop_running_in_cluster_mock.return_value = True + get_running_agent_loops_mock = AsyncMock() + get_running_agent_loops_mock.return_value = {'new-session-id'} with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -244,22 +244,22 @@ async def test_add_to_cluster_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', - is_agent_loop_running_in_cluster_mock, + 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', + get_running_agent_loops_mock, ), ): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: await session_manager.join_conversation( - 'new-session-id', 'connection-id', ConversationInitData() + 'new-session-id', 'connection-id', ConversationInitData(), 1 ) await session_manager.send_to_event_stream( 'connection-id', {'event_type': 'some_event'} ) assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'oh_event', + 'session_msg', '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}', ) @@ -277,7 +277,7 @@ async def test_cleanup_session_connections(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - session_manager.local_connection_id_to_session_id.update( + session_manager._local_connection_id_to_session_id.update( { 'conn1': 'session1', 'conn2': 'session1', @@ -286,9 +286,9 @@ async def test_cleanup_session_connections(): } ) - await session_manager._on_close_session('session1') + await session_manager._close_session('session1') - remaining_connections = session_manager.local_connection_id_to_session_id + remaining_connections = session_manager._local_connection_id_to_session_id assert 'conn1' not in remaining_connections assert 'conn2' not in remaining_connections assert 'conn3' in remaining_connections