Skip to content

Commit

Permalink
fix: kernel globals memory leak in run mode (#3634)
Browse files Browse the repository at this point in the history
This change fixes a memory leak in run mode in which a kernel's globals
memory appeared to not be freed on session (thread) exit. However,
another (smaller) leak appears to still exist.

This change also makes sure to stop the stream's buffered writer thread
when kernel exits.

Also, add a TODO for performance of VariableValue construction --
calling str() on some values, such as a bytearray of size 1GB, can make
memory usage spike to 25GB.

Hopefully improves #3623

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
akshayka and pre-commit-ci[bot] authored Jan 31, 2025
1 parent bea2632 commit 805ba31
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 9 deletions.
11 changes: 7 additions & 4 deletions marimo/_messaging/console_output_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _add_output_to_buffer(


def buffered_writer(
msg_queue: deque[ConsoleMsg],
msg_queue: deque[ConsoleMsg | None],
stream: Stream,
cv: Condition,
) -> None:
Expand All @@ -83,6 +83,8 @@ def buffered_writer(
variable is used to synchronize access to `msg_queue`, and to obtain
notifications when messages have been added. (A deque + condition variable
was noticeably faster than the builtin queue.Queue in testing.)
A `None` passed to `msg_queue` signals the writer should terminate.
"""

# only have a non-None timer when there's at least one output buffered
Expand All @@ -102,9 +104,10 @@ def buffered_writer(
if timer is not None or not msg_queue:
cv.wait(timeout=timer)
while msg_queue:
_add_output_to_buffer(
msg_queue.popleft(), outputs_buffered_per_cell
)
msg = msg_queue.popleft()
if msg is None:
return
_add_output_to_buffer(msg, outputs_buffered_per_cell)
if outputs_buffered_per_cell and timer is None:
# start the timeout timer
timer = TIMEOUT_S
Expand Down
4 changes: 2 additions & 2 deletions marimo/_messaging/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,9 @@ def _stringify(self, value: object) -> str:
if table_manager is not None:
return str(table_manager)
else:
# TODO(akshayka): str(value) can be extremely expensive
# for some objects; find a better solution.
return str(value)[:50]

return str(value)[:50]
except BaseException:
# Catch-all: some libraries like Polars have bugs and raise
# BaseExceptions, which shouldn't crash the kernel
Expand Down
11 changes: 10 additions & 1 deletion marimo/_messaging/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(

# Console outputs are buffered
self.console_msg_cv = threading.Condition(threading.Lock())
self.console_msg_queue: deque[ConsoleMsg] = deque()
self.console_msg_queue: deque[ConsoleMsg | None] = deque()
self.buffered_console_thread = threading.Thread(
target=buffered_writer,
args=(self.console_msg_queue, self, self.console_msg_cv),
Expand All @@ -110,6 +110,15 @@ def write(self, op: str, data: dict[Any, Any]) -> None:
# server process shutting down
LOGGER.debug("Error when writing (op: %s) to pipe: %s", op, e)

def stop(self) -> None:
"""Teardown resources created by the stream."""
# Sending `None` through the queue signals the console thread to exit.
# We don't join the thread in case its processing outputs still; don't
# want to block the entire program.
self.console_msg_queue.append(None)
with self.console_msg_cv:
self.console_msg_cv.notify()


def _forward_os_stream(standard_stream: Stdout | Stderr, fd: int) -> None:
"""Watch a file descriptor and forward it to a stream object."""
Expand Down
13 changes: 11 additions & 2 deletions marimo/_runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
KernelRuntimeContext,
initialize_kernel_context,
)
from marimo._runtime.context.types import teardown_context
from marimo._runtime.control_flow import MarimoInterrupt
from marimo._runtime.input_override import input_override
from marimo._runtime.packages.module_registry import ModuleRegistry
Expand Down Expand Up @@ -2251,7 +2252,7 @@ def _enqueue_control_request(req: ControlRequest) -> None:

ui_element_request_mgr = SetUIElementRequestManager(set_ui_element_queue)

async def control_loop() -> None:
async def control_loop(kernel: Kernel) -> None:
while True:
try:
request: ControlRequest | None = control_queue.get()
Expand All @@ -2272,14 +2273,22 @@ async def control_loop() -> None:
# top-level await; nothing else is awaited. Don't introduce async
# primitives anywhere else in the runtime unless there is a *very* good
# reason; prefer using threads (for performance and clarity).
asyncio.run(control_loop())
asyncio.run(control_loop(kernel))

if stdout is not None:
stdout._watcher.stop()
if stderr is not None:
stderr._watcher.stop()
get_context().virtual_file_registry.shutdown()
stream.stop()

if profiler is not None and profile_path is not None:
profiler.disable()
profiler.dump_stats(profile_path)

# TODO(akshayka): There's a memory leak in run mode, with memory
# usage increasing with each session creation. Somehow the kernel
# appears to leak, even though the thread exits. As a hack we manually
# clear various data structures.
teardown_context()
kernel._module.__dict__.clear()
2 changes: 2 additions & 0 deletions marimo/_runtime/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_context,
initialize_context,
runtime_context_installed,
teardown_context,
)

# Set of thread ids for running mo.Threads
Expand Down Expand Up @@ -94,3 +95,4 @@ def run(self) -> None:
THREADS.add(thread_id)
super().run()
THREADS.remove(thread_id)
teardown_context()

0 comments on commit 805ba31

Please sign in to comment.