Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle periodic cleanup for sessions #3627

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions marimo/_server/session/session_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def __init__(self) -> None:
self.stale_code: Optional[UpdateCellCodes] = None

# Auto-saving
self.has_auto_exported_html = False
self.has_auto_exported_md = False
self.has_auto_exported_ipynb = False
self.last_auto_exported_html: Optional[float] = None
self.last_auto_exported_md: Optional[float] = None
self.last_auto_exported_ipynb: Optional[float] = None
self.last_active_time = time.time()

def _add_ui_value(self, name: str, value: Any) -> None:
self.ui_values[name] = value
Expand Down Expand Up @@ -239,18 +240,37 @@ def operations(self) -> list[MessageOperation]:
return all_ops

def mark_auto_export_html(self) -> None:
self.has_auto_exported_html = True
self.last_auto_exported_html = time.time()

def mark_auto_export_md(self) -> None:
self.has_auto_exported_md = True
self.last_auto_exported_md = time.time()

def mark_auto_export_ipynb(self) -> None:
self.has_auto_exported_ipynb = True
self.last_auto_exported_ipynb = time.time()

@property
def has_auto_exported_html(self) -> bool:
    """Whether an HTML auto-export is recorded at or after the last activity.

    False when no HTML export timestamp has been recorded, or when
    ``last_active_time`` is newer than that timestamp.
    """
    exported_at = self.last_auto_exported_html
    if exported_at is None:
        return False
    return exported_at >= self.last_active_time

@property
def has_auto_exported_md(self) -> bool:
    """Whether a Markdown auto-export is recorded at or after the last activity.

    False when no Markdown export timestamp has been recorded, or when
    ``last_active_time`` is newer than that timestamp.
    """
    exported_at = self.last_auto_exported_md
    if exported_at is None:
        return False
    return exported_at >= self.last_active_time

@property
def has_auto_exported_ipynb(self) -> bool:
    """Whether an ipynb auto-export is recorded at or after the last activity.

    False when no ipynb export timestamp has been recorded, or when
    ``last_active_time`` is newer than that timestamp.
    """
    exported_at = self.last_auto_exported_ipynb
    if exported_at is None:
        return False
    return exported_at >= self.last_active_time

def _touch(self) -> None:
self.has_auto_exported_html = False
self.has_auto_exported_md = False
self.has_auto_exported_ipynb = False
self.last_active_time = time.time()


def merge_cell_operation(
Expand Down
53 changes: 44 additions & 9 deletions marimo/_server/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from marimo._server.types import QueueType
from marimo._server.utils import print_, print_tabbed
from marimo._tracer import server_tracer
from marimo._utils.debounce import debounce
from marimo._utils.disposable import Disposable
from marimo._utils.distributor import (
ConnectionDistributor,
Expand Down Expand Up @@ -906,6 +907,11 @@ def handle_file_rename_for_watch(
return False, str(e)

def get_session(self, session_id: SessionId) -> Optional[Session]:
    """Look up a session by id, running orphan cleanup first.

    ``cleanup_orphaned_sessions`` is invoked on every lookup, so lookups
    double as the trigger for periodic cleanup. The result is whatever
    ``_get_session_by_id`` returns for ``session_id``.
    """
    self.cleanup_orphaned_sessions()
    found = self._get_session_by_id(session_id)
    return found

def _get_session_by_id(self, session_id: SessionId) -> Optional[Session]:
session = self.sessions.get(session_id)
if session:
return session
Expand Down Expand Up @@ -939,7 +945,7 @@ def maybe_resume_session(
# If in run mode, only resume the session if it is orphaned and has
# the same session id, otherwise we want to create a new session
if self.mode == SessionMode.RUN:
maybe_session = self.get_session(new_session_id)
maybe_session = self._get_session_by_id(new_session_id)
if (
maybe_session
and maybe_session.connection_state()
Expand All @@ -952,13 +958,6 @@ def maybe_resume_session(
return maybe_session
return None

# Cleanup sessions with dead kernels; materializing as a list because
# close_sessions mutates self.sessions
for session_id, session in list(self.sessions.items()):
task = session.kernel_manager.kernel_task
if task is not None and not task.is_alive():
self.close_session(session_id)

# Should only return an orphaned session
sessions_with_the_same_file: dict[SessionId, Session] = {
session_id: session
Expand Down Expand Up @@ -1023,7 +1022,7 @@ async def start_lsp_server(self) -> None:
def close_session(self, session_id: SessionId) -> bool:
"""Close a session and remove its file watcher if it has one."""
LOGGER.debug("Closing session %s", session_id)
session = self.get_session(session_id)
session = self._get_session_by_id(session_id)
if session is None:
return False

Expand Down Expand Up @@ -1065,6 +1064,42 @@ def get_active_connection_count(self) -> int:
]
)

# Run at most once every 60 seconds (debounced).
# Each session's TTL is checked inside the function, so the debounce only
# exists to avoid unnecessary work.
@debounce(60)
def cleanup_orphaned_sessions(self) -> None:
Comment on lines +1070 to +1071
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perplexing bug; I would have thought that

if self.manager.mode == SessionMode.RUN:
# When the websocket is closed, we wait session.ttl_seconds before
# closing the session. This is to prevent the session from being
# closed during an intermittent network issue.
def _close() -> None:
if self.status != ConnectionState.OPEN:
LOGGER.debug(
"Closing session %s (TTL EXPIRED)",
self.session_id,
)
# wait until TTL is expired before calling the cleanup
# function
cleanup_fn()
self.manager.close_session(self.session_id)
session = self.manager.get_session(self.session_id)
if session is not None:
cancellation_handle = asyncio.get_event_loop().call_later(
session.ttl_seconds, _close
)
self.cancel_close_handle = cancellation_handle

handled this case. Unless somehow the websocket's on_disconnect is not called?

"""Clean up any orphaned or dead sessions"""
if len(self.sessions) == 0:
return

cleaned_up_count = 0

LOGGER.debug("Cleaning up orphaned sessions")

# Materialize list since we'll modify self.sessions
for session_id, session in list(self.sessions.items()):
# Check if kernel is dead
task = session.kernel_manager.kernel_task
if task is not None and not task.is_alive():
cleaned_up_count += 1
self.close_session(session_id)
continue

# Check if session is orphaned and past TTL
if session.connection_state() == ConnectionState.ORPHANED:
# Get time since last activity
last_activity = session.session_view.last_active_time
stale_time = time.time() - last_activity
if stale_time > session.ttl_seconds:
cleaned_up_count += 1
self.close_session(session_id)

if cleaned_up_count > 0:
LOGGER.debug("Cleaned up %d orphaned sessions", cleaned_up_count)
else:
LOGGER.debug("No orphaned sessions to clean up")


class LspServer:
def __init__(self, port: int) -> None:
Expand Down
6 changes: 4 additions & 2 deletions marimo/_utils/debounce.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

import time
from functools import wraps
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., None])


def debounce(wait_time: float) -> Callable[[F], F]:
def debounce(wait_time_seconds: float) -> Callable[[F], F]:
"""
Decorator to prevent a function from being called more than once every
wait_time_seconds seconds.
Expand All @@ -19,7 +21,7 @@ def decorator(func: F) -> F:
def wrapped(*args: Any, **kwargs: Any) -> None:
nonlocal last_called
current_time = time.time()
if current_time - last_called >= wait_time:
if current_time - last_called >= wait_time_seconds:
last_called = current_time
func(*args, **kwargs)

Expand Down
8 changes: 7 additions & 1 deletion marimo/_utils/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ def start(self) -> Disposable:

def stop(self) -> None:
"""Stop distributing the response."""
asyncio.get_event_loop().remove_reader(self.input_connection.fileno())
try:
asyncio.get_event_loop().remove_reader(
self.input_connection.fileno()
)
except OSError:
# Handle may already be closed
pass
if not self.input_connection.closed:
self.input_connection.close()
self.consumers.clear()
Expand Down
137 changes: 135 additions & 2 deletions tests/_server/test_session_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from unittest.mock import MagicMock, Mock
import asyncio
import time
from unittest.mock import MagicMock, Mock, patch

import pytest

Expand Down Expand Up @@ -34,7 +36,7 @@ def mock_session():

@pytest.fixture
def session_manager():
return SessionManager(
sm = SessionManager(
file_router=AppFileRouter.new_file(),
mode=SessionMode.EDIT,
development_mode=False,
Expand All @@ -48,6 +50,13 @@ def session_manager():
ttl_seconds=None,
)

# Unwrap debounce from cleanup_orphaned_sessions
if hasattr(sm.cleanup_orphaned_sessions, "__wrapped__"):
unwrapped = sm.cleanup_orphaned_sessions.__wrapped__
sm.cleanup_orphaned_sessions = lambda: unwrapped(sm)
yield sm
sm.shutdown()


async def test_start_lsp_server(session_manager: SessionManager) -> None:
await session_manager.start_lsp_server()
Expand Down Expand Up @@ -209,3 +218,127 @@ def test_shutdown(
session_manager.lsp_server.stop.assert_called_once()
assert len(session_manager.sessions) == 0
assert mock_session.close.call_count == 2


def test_cleanup_orphaned_sessions_no_sessions(
    session_manager: SessionManager,
) -> None:
    """Cleanup on a manager with no sessions leaves it empty."""
    session_manager.cleanup_orphaned_sessions()
    assert not session_manager.sessions


async def test_cleanup_orphaned_sessions_dead_kernel(
    session_manager: SessionManager,
    mock_session_consumer: SessionConsumer,
) -> None:
    """A session whose kernel has been closed is removed by cleanup."""
    sid = "test_session_id"
    created = session_manager.create_session(
        sid,
        mock_session_consumer,
        query_params={},
        file_key=AppFileRouter.NEW_FILE,
    )

    # The kernel must exist before we close it.
    assert created.kernel_manager.kernel_task is not None
    created.kernel_manager.close_kernel()
    # Give the close a moment to take effect.
    await asyncio.sleep(0.05)

    session_manager.cleanup_orphaned_sessions()

    # The dead-kernel session is gone from the manager.
    assert sid not in session_manager.sessions
    created.close()


def test_cleanup_orphaned_sessions_stale_session(
    session_manager: SessionManager,
    mock_session_consumer: SessionConsumer,
) -> None:
    """An orphaned session idle for longer than the TTL is removed."""
    sid = "test_session_id"
    ttl_seconds = 1  # short TTL keeps the test simple
    session_manager.ttl_seconds = ttl_seconds

    created = session_manager.create_session(
        sid,
        mock_session_consumer,
        query_params={},
        file_key=AppFileRouter.NEW_FILE,
    )

    # Report the session as orphaned regardless of its real state.
    created.connection_state = lambda: ConnectionState.ORPHANED

    # Pretend the last activity happened more than one TTL ago.
    stale_timestamp = time.time() - ttl_seconds - 1
    with patch.object(
        created.session_view, "last_active_time", stale_timestamp
    ):
        session_manager.cleanup_orphaned_sessions()

    assert sid not in session_manager.sessions
    created.close()


def test_cleanup_orphaned_sessions_active_session(
    session_manager: SessionManager,
    mock_session_consumer: SessionConsumer,
) -> None:
    """Cleanup leaves a session with an open connection in place."""
    sid = "test_session_id"
    live = session_manager.create_session(
        sid,
        mock_session_consumer,
        query_params={},
        file_key=AppFileRouter.NEW_FILE,
    )

    # Report the connection as open regardless of its real state.
    live.connection_state = lambda: ConnectionState.OPEN

    session_manager.cleanup_orphaned_sessions()

    assert sid in session_manager.sessions
    live.close()


def test_cleanup_orphaned_sessions_not_stale_yet(
    session_manager: SessionManager,
    mock_session_consumer: SessionConsumer,
) -> None:
    """An orphaned session is kept while recent and removed once stale."""
    sid = "test_session_id"
    ttl_seconds = 60
    session_manager.ttl_seconds = ttl_seconds

    orphan = session_manager.create_session(
        sid,
        mock_session_consumer,
        query_params={},
        file_key=AppFileRouter.NEW_FILE,
    )

    # Report the session as orphaned regardless of its real state.
    orphan.connection_state = lambda: ConnectionState.ORPHANED

    # Phase 1: activity was just recorded, so cleanup must keep the session.
    orphan.session_view._touch()
    session_manager.cleanup_orphaned_sessions()
    assert sid in session_manager.sessions

    # Phase 2: push the last activity past the TTL; cleanup now removes it.
    orphan.session_view.last_active_time = time.time() - ttl_seconds - 1
    session_manager.cleanup_orphaned_sessions()
    assert sid not in session_manager.sessions
6 changes: 4 additions & 2 deletions tests/_utils/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Assuming the debounce decorator is defined in a file named debounce.py
from __future__ import annotations

import time

from marimo._utils.debounce import debounce
Expand All @@ -8,7 +10,7 @@


def test_debounce_within_period() -> None:
@debounce(wait_time=1.5)
@debounce(wait_time_seconds=1.5)
def my_function() -> None:
global call_count
call_count += 1
Expand All @@ -25,7 +27,7 @@ def my_function() -> None:


def test_debounce_after_period() -> None:
@debounce(wait_time=0.1)
@debounce(wait_time_seconds=0.1)
def my_function() -> None:
global call_count
call_count += 1
Expand Down
Loading