From a4577e844bb25a0f99de3841379cb54335ef840f Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Thu, 30 Jan 2025 12:00:32 -0500 Subject: [PATCH 1/3] fix: handle periodic cleanup for sessions --- marimo/_server/session/session_view.py | 38 +++++-- marimo/_server/sessions.py | 53 ++++++++-- marimo/_utils/debounce.py | 6 +- marimo/_utils/distributor.py | 8 +- tests/_server/test_session_manager.py | 137 ++++++++++++++++++++++++- tests/_utils/test_decorator.py | 6 +- 6 files changed, 223 insertions(+), 25 deletions(-) diff --git a/marimo/_server/session/session_view.py b/marimo/_server/session/session_view.py index c586854f0d3..40909c164b4 100644 --- a/marimo/_server/session/session_view.py +++ b/marimo/_server/session/session_view.py @@ -57,9 +57,10 @@ def __init__(self) -> None: self.stale_code: Optional[UpdateCellCodes] = None # Auto-saving - self.has_auto_exported_html = False - self.has_auto_exported_md = False - self.has_auto_exported_ipynb = False + self.last_auto_exported_html: Optional[float] = None + self.last_auto_exported_md: Optional[float] = None + self.last_auto_exported_ipynb: Optional[float] = None + self.last_active_time = time.time() def _add_ui_value(self, name: str, value: Any) -> None: self.ui_values[name] = value @@ -239,18 +240,37 @@ def operations(self) -> list[MessageOperation]: return all_ops def mark_auto_export_html(self) -> None: - self.has_auto_exported_html = True + self.last_auto_exported_html = time.time() def mark_auto_export_md(self) -> None: - self.has_auto_exported_md = True + self.last_auto_exported_md = time.time() def mark_auto_export_ipynb(self) -> None: - self.has_auto_exported_ipynb = True + self.last_auto_exported_ipynb = time.time() + + @property + def has_auto_exported_html(self) -> bool: + return ( + self.last_auto_exported_html is not None + and self.last_active_time <= self.last_auto_exported_html + ) + + @property + def has_auto_exported_md(self) -> bool: + return ( + self.last_auto_exported_md is not None + and self.last_active_time <= self.last_auto_exported_md + ) + + @property + def has_auto_exported_ipynb(self) -> bool: + return ( + self.last_auto_exported_ipynb is not None + and self.last_active_time <= self.last_auto_exported_ipynb + ) def _touch(self) -> None: - self.has_auto_exported_html = False - self.has_auto_exported_md = False - self.has_auto_exported_ipynb = False + self.last_active_time = time.time() def merge_cell_operation( diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 65dd8eee42c..c6e7ccb74ac 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -68,6 +68,7 @@ from marimo._server.types import QueueType from marimo._server.utils import print_, print_tabbed from marimo._tracer import server_tracer +from marimo._utils.debounce import debounce from marimo._utils.disposable import Disposable from marimo._utils.distributor import ( ConnectionDistributor, @@ -906,6 +907,11 @@ def handle_file_rename_for_watch( return False, str(e) def get_session(self, session_id: SessionId) -> Optional[Session]: + # Clean up orphaned sessions periodically when looking up sessions + self.cleanup_orphaned_sessions() + return self._get_session_by_id(session_id) + + def _get_session_by_id(self, session_id: SessionId) -> Optional[Session]: session = self.sessions.get(session_id) if session: return session @@ -939,7 +945,7 @@ def maybe_resume_session( # If in run mode, only resume the session if it is orphaned and has # the same session id, otherwise we want to create a new session if self.mode == SessionMode.RUN: - maybe_session = self.get_session(new_session_id) + maybe_session = self._get_session_by_id(new_session_id) if ( maybe_session and maybe_session.connection_state() @@ -952,13 +958,6 @@ def maybe_resume_session( return maybe_session return None - # Cleanup sessions with dead kernels; materializing as a list because - # close_sessions mutates self.sessions - for session_id, session in list(self.sessions.items()): - task = session.kernel_manager.kernel_task - if task is not None and not task.is_alive(): - self.close_session(session_id) - # Should only return an orphaned session sessions_with_the_same_file: dict[SessionId, Session] = { session_id: session @@ -1023,7 +1022,7 @@ async def start_lsp_server(self) -> None: def close_session(self, session_id: SessionId) -> bool: """Close a session and remove its file watcher if it has one.""" LOGGER.debug("Closing session %s", session_id) - session = self.get_session(session_id) + session = self._get_session_by_id(session_id) if session is None: return False @@ -1065,6 +1064,42 @@ def get_active_connection_count(self) -> int: ] ) + # Check every 1 minute. + # We check the TTL later, so the 1 minute check is just to avoid + # unnecessary work. + @debounce(60) + def cleanup_orphaned_sessions(self) -> None: + """Clean up any orphaned or dead sessions""" + if len(self.sessions) == 0: + return + + cleaned_up_count = 0 + + LOGGER.debug("Cleaning up orphaned sessions") + + # Materialize list since we'll modify self.sessions + for session_id, session in list(self.sessions.items()): + # Check if kernel is dead + task = session.kernel_manager.kernel_task + if task is not None and not task.is_alive(): + cleaned_up_count += 1 + self.close_session(session_id) + continue + + # Check if session is orphaned and past TTL + if session.connection_state() == ConnectionState.ORPHANED: + # Get time since last activity + last_activity = session.session_view.last_active_time + stale_time = time.time() - last_activity + if stale_time > session.ttl_seconds: + cleaned_up_count += 1 + self.close_session(session_id) + + if cleaned_up_count > 0: + LOGGER.debug("Cleaned up %d orphaned sessions", cleaned_up_count) + else: + LOGGER.debug("No orphaned sessions to clean up") + class LspServer: def __init__(self, port: int) -> None: diff --git a/marimo/_utils/debounce.py b/marimo/_utils/debounce.py index 9168a136816..5a2f1efaf6d 100644 --- a/marimo/_utils/debounce.py +++ b/marimo/_utils/debounce.py @@ -1,4 +1,6 @@ # Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + import time from functools import wraps from typing import Any, Callable, TypeVar, cast @@ -6,7 +8,7 @@ F = TypeVar("F", bound=Callable[..., None]) -def debounce(wait_time: float) -> Callable[[F], F]: +def debounce(wait_time_seconds: float) -> Callable[[F], F]: """ Decorator to prevent a function from being called more than once every wait_time seconds. @@ -19,7 +21,7 @@ def decorator(func: F) -> F: def wrapped(*args: Any, **kwargs: Any) -> None: nonlocal last_called current_time = time.time() - if current_time - last_called >= wait_time: + if current_time - last_called >= wait_time_seconds: last_called = current_time func(*args, **kwargs) diff --git a/marimo/_utils/distributor.py b/marimo/_utils/distributor.py index a1471c33302..d45e419352c 100644 --- a/marimo/_utils/distributor.py +++ b/marimo/_utils/distributor.py @@ -80,7 +80,13 @@ def start(self) -> Disposable: def stop(self) -> None: """Stop distributing the response.""" - asyncio.get_event_loop().remove_reader(self.input_connection.fileno()) + try: + asyncio.get_event_loop().remove_reader( + self.input_connection.fileno() + ) + except OSError: + # Handle may already be closed + pass if not self.input_connection.closed: self.input_connection.close() self.consumers.clear() diff --git a/tests/_server/test_session_manager.py b/tests/_server/test_session_manager.py index 3caa2ebb04b..682d438eb8e 100644 --- a/tests/_server/test_session_manager.py +++ b/tests/_server/test_session_manager.py @@ -1,6 +1,8 @@ from __future__ import annotations -from unittest.mock import MagicMock, Mock +import asyncio +import time +from unittest.mock import MagicMock, Mock, patch import pytest @@ -34,7 +36,7 @@ def mock_session(): @pytest.fixture def session_manager(): - return SessionManager( + sm = SessionManager( file_router=AppFileRouter.new_file(), mode=SessionMode.EDIT, development_mode=False, @@ -48,6 +50,13 @@ def session_manager(): ttl_seconds=None, ) + # Unwrap debounce from cleanup_orphaned_sessions + if hasattr(sm.cleanup_orphaned_sessions, "__wrapped__"): + unwrapped = sm.cleanup_orphaned_sessions.__wrapped__ + sm.cleanup_orphaned_sessions = lambda: unwrapped(sm) + yield sm + sm.shutdown() + async def test_start_lsp_server(session_manager: SessionManager) -> None: await session_manager.start_lsp_server() @@ -209,3 +218,127 @@ def test_shutdown( session_manager.lsp_server.stop.assert_called_once() assert len(session_manager.sessions) == 0 assert mock_session.close.call_count == 2 + + +def test_cleanup_orphaned_sessions_no_sessions( + session_manager: SessionManager, +) -> None: + # Test with no sessions + session_manager.cleanup_orphaned_sessions() + assert len(session_manager.sessions) == 0 + + +async def test_cleanup_orphaned_sessions_dead_kernel( + session_manager: SessionManager, + mock_session_consumer: SessionConsumer, +) -> None: + # Create a session with a dead kernel + session_id = "test_session_id" + session = session_manager.create_session( + session_id, + mock_session_consumer, + query_params={}, + file_key=AppFileRouter.NEW_FILE, + ) + + assert session.kernel_manager.kernel_task is not None + session.kernel_manager.close_kernel() + await asyncio.sleep(0.05) # Flush the close + + # Run cleanup + session_manager.cleanup_orphaned_sessions() + + # Session should be cleaned up + assert session_id not in session_manager.sessions + session.close() + + +def test_cleanup_orphaned_sessions_stale_session( + session_manager: SessionManager, + mock_session_consumer: SessionConsumer, +) -> None: + # Create a session that will become stale + session_id = "test_session_id" + ttl = 1 # 1 second TTL for testing + session_manager.ttl_seconds = ttl + + session = session_manager.create_session( + session_id, + mock_session_consumer, + query_params={}, + file_key=AppFileRouter.NEW_FILE, + ) + + # Mock the session to be orphaned + session.connection_state = lambda: ConnectionState.ORPHANED + + # Mock last_active_time to be in the past + with patch.object( + session.session_view, "last_active_time", time.time() - ttl - 1 + ): + # Run cleanup + session_manager.cleanup_orphaned_sessions() + + # Session should be cleaned up + assert session_id not in session_manager.sessions + session.close() + + +def test_cleanup_orphaned_sessions_active_session( + session_manager: SessionManager, + mock_session_consumer: SessionConsumer, +) -> None: + # Create an active session + session_id = "test_session_id" + session = session_manager.create_session( + session_id, + mock_session_consumer, + query_params={}, + file_key=AppFileRouter.NEW_FILE, + ) + + # Mock the session to be active + session.connection_state = lambda: ConnectionState.OPEN + + # Run cleanup + session_manager.cleanup_orphaned_sessions() + + # Session should still be there + assert session_id in session_manager.sessions + session.close() + + +def test_cleanup_orphaned_sessions_not_stale_yet( + session_manager: SessionManager, + mock_session_consumer: SessionConsumer, +) -> None: + # Create a session that is orphaned but not stale yet + session_id = "test_session_id" + ttl = 60 # 60 second TTL + session_manager.ttl_seconds = ttl + + session = session_manager.create_session( + session_id, + mock_session_consumer, + query_params={}, + file_key=AppFileRouter.NEW_FILE, + ) + + # Mock the session to be orphaned + session.connection_state = lambda: ConnectionState.ORPHANED + + # Update last_active_time to be recent + session.session_view._touch() + + # Run cleanup + session_manager.cleanup_orphaned_sessions() + + # Session should still be there since it hasn't exceeded TTL + assert session_id in session_manager.sessions + + # Set last_active_time to be stale + session.session_view.last_active_time = time.time() - ttl - 1 + + # Run cleanup + session_manager.cleanup_orphaned_sessions() + assert session_id not in session_manager.sessions diff --git a/tests/_utils/test_decorator.py b/tests/_utils/test_decorator.py index 797c490a897..8ea42b96381 100644 --- a/tests/_utils/test_decorator.py +++ b/tests/_utils/test_decorator.py @@ -1,4 +1,6 @@ # Assuming the debounce decorator is defined in a file named debounce.py +from __future__ import annotations + import time from marimo._utils.debounce import debounce @@ -8,7 +10,7 @@ def test_debounce_within_period() -> None: - @debounce(wait_time=1.5) + @debounce(wait_time_seconds=1.5) def my_function() -> None: global call_count call_count += 1 @@ -25,7 +27,7 @@ def my_function() -> None: def test_debounce_after_period() -> None: - @debounce(wait_time=0.1) + @debounce(wait_time_seconds=0.1) def my_function() -> None: global call_count call_count += 1 From 1c7857db87bceb970a5e143311bbb35345e5d613 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Thu, 30 Jan 2025 14:38:25 -0500 Subject: [PATCH 2/3] update time change --- marimo/_server/session/session_view.py | 36 +++++++------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/marimo/_server/session/session_view.py b/marimo/_server/session/session_view.py index 40909c164b4..a591f23e687 100644 --- a/marimo/_server/session/session_view.py +++ b/marimo/_server/session/session_view.py @@ -57,9 +57,9 @@ def __init__(self) -> None: self.stale_code: Optional[UpdateCellCodes] = None # Auto-saving - self.last_auto_exported_html: Optional[float] = None - self.last_auto_exported_md: Optional[float] = None - self.last_auto_exported_ipynb: Optional[float] = None + self.has_auto_exported_html = False + self.has_auto_exported_md = False + self.has_auto_exported_ipynb = False self.last_active_time = time.time() def _add_ui_value(self, name: str, value: Any) -> None: @@ -240,37 +240,19 @@ def operations(self) -> list[MessageOperation]: return all_ops def mark_auto_export_html(self) -> None: - self.last_auto_exported_html = time.time() + self.has_auto_exported_html = True def mark_auto_export_md(self) -> None: - self.last_auto_exported_md = time.time() + self.has_auto_exported_md = True def mark_auto_export_ipynb(self) -> None: - self.last_auto_exported_ipynb = time.time() - - @property - def has_auto_exported_html(self) -> bool: - return ( - self.last_auto_exported_html is not None - and self.last_active_time <= self.last_auto_exported_html - ) - - @property - def has_auto_exported_md(self) -> bool: - return ( - self.last_auto_exported_md is not None - and self.last_active_time <= self.last_auto_exported_md - ) - - @property - def has_auto_exported_ipynb(self) -> bool: - return ( - self.last_auto_exported_ipynb is not None - and self.last_active_time <= self.last_auto_exported_ipynb - ) + self.has_auto_exported_ipynb = True def _touch(self) -> None: self.last_active_time = time.time() + self.has_auto_exported_html = False + self.has_auto_exported_md = False + self.has_auto_exported_ipynb = False def merge_cell_operation( From 457cb5c033fca9928951e92e3d6fda70078c12d0 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Thu, 30 Jan 2025 14:48:31 -0500 Subject: [PATCH 3/3] maybe fixes --- frontend/src/components/pages/run-page.tsx | 7 ++++++- tests/_server/test_session_manager.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/frontend/src/components/pages/run-page.tsx b/frontend/src/components/pages/run-page.tsx index c2f3b837219..011c1f9ac0d 100644 --- a/frontend/src/components/pages/run-page.tsx +++ b/frontend/src/components/pages/run-page.tsx @@ -32,7 +32,12 @@ const Watermark = () => { rel="noreferrer" > made with marimo - marimo + marimo ); diff --git a/tests/_server/test_session_manager.py b/tests/_server/test_session_manager.py index 682d438eb8e..6dc7747dbe1 100644 --- a/tests/_server/test_session_manager.py +++ b/tests/_server/test_session_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import sys import time from unittest.mock import MagicMock, Mock, patch @@ -228,6 +229,7 @@ def test_cleanup_orphaned_sessions_no_sessions( assert len(session_manager.sessions) == 0 +@pytest.mark.xfail(sys.platform == "win32", reason="Flaky on Windows") async def test_cleanup_orphaned_sessions_dead_kernel( session_manager: SessionManager, mock_session_consumer: SessionConsumer,