Skip to content

Commit

Permalink
Feature: Fork kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
zmbc committed Aug 4, 2024
1 parent d4a8703 commit 35c0a20
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 17 deletions.
31 changes: 17 additions & 14 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,23 @@ def __init__(self, **kwargs):
self.debug_just_my_code,
)

self.init_shell()

if _use_appnope() and self._darwin_app_nap:
# Disable app-nap as the kernel is not a gui but can have guis
import appnope # type:ignore[import-untyped]

appnope.nope()

self._new_threads_parent_header = {}
self._initialize_thread_hooks()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

def init_shell(self):
# Initialize the InteractiveShell subclass
self.shell = self.shell_class.instance(
parent=self,
Expand Down Expand Up @@ -145,20 +162,6 @@ def __init__(self, **kwargs):
for msg_type in comm_msg_types:
self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type)

if _use_appnope() and self._darwin_app_nap:
# Disable app-nap as the kernel is not a gui but can have guis
import appnope # type:ignore[import-untyped]

appnope.nope()

self._new_threads_parent_header = {}
self._initialize_thread_hooks()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

help_links = List(
[
{
Expand Down
95 changes: 93 additions & 2 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,99 @@ def start(self) -> None:
if self.poller is not None:
self.poller.start()
backend = "trio" if self.trio_loop else "asyncio"
run(self.main, backend=backend)
return

while True:
run(self.main, backend=backend)
if not getattr(self.kernel, "_fork_requested", False):
break
self.fork()

def fork(self):
# HACK: Why is this necessary?
# Without it, the *parent* kernel doesn't work.
# Also, it doesn't work if I try to start it again with
# self.init_iopub()...
self.iopub_thread.stop()

# Create a temporary connection file that will be inherited by the child process.
connection_file, conn = write_connection_file()

parent_pid = os.getpid()
pid = os.fork()
self.kernel._fork_requested = False # reset for parent AND child
if pid == 0:
self.log.debug("Child kernel with pid %s", os.getpid())

# close all sockets and ioloops
self.close()

# Reset all ports so they will be reinitialized with the ports from the connection file
for name in [
"%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control")
]:
setattr(self, name, 0)
self.connection_file = connection_file

# Reset the ZMQ context for it to be recreated
self.context = None

# Make ParentPoller work correctly (the new process is a child of the previous kernel)
self.parent_handle = parent_pid

# Session have a protection to send messages from forked processes through the `check_pid` flag.
self.session.pid = os.getpid()
self.session.key = conn["key"].encode()

self.init_connection_file()
self.init_poller()
self.init_sockets()
self.init_heartbeat()
self.init_io()

kernel = self.kernel
params = dict(
parent=self,
session=self.session,
control_socket=self.control_socket,
control_thread=self.control_thread,
debugpy_socket=self.debugpy_socket,
debug_shell_socket=self.debug_shell_socket,
shell_socket=self.shell_socket,
iopub_thread=self.iopub_thread,
iopub_socket=self.iopub_socket,
stdin_socket=self.stdin_socket,
log=self.log,
profile_dir=self.profile_dir,
)
for k, v in params.items():
setattr(kernel, k, v)

kernel.user_ns = kernel.shell.user_ns
kernel.init_shell()

kernel.record_ports({name + "_port": port for name, port in self._ports.items()})
self.kernel = kernel

# Allow the displayhook to get the execution count
self.displayhook.get_execution_count = lambda: kernel.execution_count

# shell init steps
self.init_shell()
if self.shell:
self.init_gui_pylab()
self.init_extensions()
self.init_code()
# flush stdout/stderr, so that anything written to these streams during
# initialization do not get associated with the first execution request
sys.stdout.flush()
sys.stderr.flush()
self.start()
else:
self.log.debug("Parent kernel will resume")
# keep a reference, since the will set this to None
post_fork_callback = self.kernel._post_fork_callback
post_fork_callback(pid, conn)
self.kernel._post_fork_callback = None

async def main(self):
async with create_task_group() as tg:
Expand Down
23 changes: 22 additions & 1 deletion ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _parent_header(self):
"shutdown_request",
"is_complete_request",
"interrupt_request",
"fork",
# deprecated:
"apply_request",
]
Expand All @@ -229,6 +230,25 @@ def _parent_header(self):
"usage_request",
]

def fork(self, stream, ident, parent):
# Forking in the (async)io loop is not supported.
# instead, we stop it, and use the io loop to pass
# information up the callstack
# loop = ioloop.IOLoop.current()
self._fork_requested = True

def post_fork_callback(pid, conn):
reply_content = json_clean({"status": "ok", "pid": pid, "conn": conn})
metadata = {}
metadata = self.finish_metadata(parent, metadata, reply_content)

self.session.send(
stream, "fork_reply", reply_content, parent, metadata=metadata, ident=ident
)

self._post_fork_callback = post_fork_callback
self.stop()

def __init__(self, **kwargs):
"""Initialize the kernel."""
super().__init__(**kwargs)
Expand Down Expand Up @@ -469,7 +489,8 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
if not self._is_test and self.control_socket is not None:
if self.control_thread:
self.control_thread.set_task(self.control_main)
self.control_thread.start()
if not self.control_thread.is_alive():
self.control_thread.start()
else:
tg.start_soon(self.control_main)

Expand Down
45 changes: 45 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .utils import (
TIMEOUT,
assemble_output,
connect_to_kernel,
execute,
flush_channels,
get_reply,
Expand Down Expand Up @@ -491,6 +492,50 @@ def test_shutdown():
assert not km.is_alive()


def test_fork_metadata():
with new_kernel() as kc:
from .test_message_spec import validate_message

km = kc.parent
fork_msg_id = kc.fork()
fork_reply = kc.get_shell_msg(timeout=TIMEOUT)
validate_message(fork_reply, "fork_reply", fork_msg_id)
assert fork_msg_id == fork_reply["parent_header"]["msg_id"] == fork_msg_id
assert fork_reply["content"]["conn"]["key"] != kc.session.key.decode()
fork_pid = fork_reply["content"]["pid"]
_check_status(fork_reply["content"])
wait_for_idle(kc)

assert fork_pid != km.provisioner.pid
# TODO: Inspect if `fork_pid` is running? Might need to use `psutil` for this in order to be cross platform

with connect_to_kernel(fork_reply["content"]["conn"], TIMEOUT) as kc_fork:
assert fork_reply["content"]["conn"]["key"] == kc_fork.session.key.decode()
kc_fork.shutdown()


def test_fork():
def execute_with_user_expression(kc, code, user_expression):
_, reply = execute(code, kc=kc, user_expressions={"my-user-expression": user_expression})
content = reply["user_expressions"]["my-user-expression"]["data"]["text/plain"]
wait_for_idle(kc)
return content

"""Kernel forks after fork_request"""
with kernel() as kc:
assert execute_with_user_expression(kc, "a = 1", "a") == "1"
assert execute_with_user_expression(kc, "b = 2", "b") == "2"
kc.fork()
fork_reply = kc.get_shell_msg(timeout=TIMEOUT)
wait_for_idle(kc)

with connect_to_kernel(fork_reply["content"]["conn"], TIMEOUT) as kc_fork:
assert execute_with_user_expression(kc_fork, "a = 11", "a, b") == str((11, 2))
assert execute_with_user_expression(kc_fork, "b = 12", "a, b") == str((11, 12))
assert execute_with_user_expression(kc, "z = 20", "a, b") == str((1, 2))
kc_fork.shutdown()


def test_interrupt_during_input():
"""
The kernel exits after being interrupted while waiting in input().
Expand Down
6 changes: 6 additions & 0 deletions tests/test_message_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class IsCompleteReplyIncomplete(Reference):
indent = Unicode()


class ForkReply(Reply):
pid = Integer()
conn = Dict()


# IOPub messages


Expand Down Expand Up @@ -255,6 +260,7 @@ class HistoryReply(Reply):
"stream": Stream(),
"display_data": DisplayData(),
"header": RHeader(),
"fork_reply": ForkReply(),
}

# -----------------------------------------------------------------------------
Expand Down
13 changes: 13 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,16 @@ def __enter__(self):
def __exit__(self, exc, value, tb):
os.chdir(self.old_wd)
return super().__exit__(exc, value, tb)


@contextmanager
def connect_to_kernel(connection_info, timeout):
from jupyter_client import BlockingKernelClient

kc = BlockingKernelClient()
kc.log.setLevel("DEBUG")
kc.load_connection_info(connection_info)
kc.start_channels()
kc.wait_for_ready(timeout)
yield kc
kc.stop_channels()

0 comments on commit 35c0a20

Please sign in to comment.