diff --git a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py b/libshortfin/bindings/python/_shortfin/asyncio_bridge.py new file mode 100644 index 000000000..2cfc7f600 --- /dev/null +++ b/libshortfin/bindings/python/_shortfin/asyncio_bridge.py @@ -0,0 +1,74 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +from collections.abc import Callable +from contextvars import Context +from typing_extensions import Unpack + +from . import lib as sfl + + +class PyWorkerEventLoop(asyncio.AbstractEventLoop): + def __init__(self, worker: sfl.local.Worker): + self._worker = worker + + def get_debug(self): + # Requirement of asyncio. + return False + + def create_task(self, coro): + return asyncio.Task(coro, loop=self) + + def create_future(self): + return asyncio.Future(loop=self) + + def time(self) -> float: + return self._worker._now() / 1e9 + + def call_soon_threadsafe(self, callback, *args, context=None) -> asyncio.Handle: + def on_worker(): + asyncio.set_event_loop(self) + return callback(*args) + + self._worker.call_threadsafe(on_worker) + # TODO: Return future. + + def call_soon(self, callback, *args, context=None) -> asyncio.Handle: + handle = _Handle(callback, args, self, context) + self._worker.call(handle._sf_maybe_run) + return handle + + def call_later( + self, delay: float, callback, *args, context=None + ) -> asyncio.TimerHandle: + w = self._worker + deadline = w._delay_to_deadline_ns(delay) + handle = _TimerHandle(deadline / 1e9, callback, args, self, context) + w.delay_call(deadline, handle._sf_maybe_run) + return handle + + def call_exception_handler(self, context) -> None: + # TODO: Should route this to the central exception handler. + raise RuntimeError(f"Async exception on {self._worker}: {context}") + + def _timer_handle_cancelled(self, handle): + # We don't do anything special: just skip it if it comes up. + pass + + +class _Handle(asyncio.Handle): + def _sf_maybe_run(self): + if self.cancelled(): + return + self._run() + + +class _TimerHandle(asyncio.TimerHandle): + def _sf_maybe_run(self): + if self.cancelled(): + return + self._run() diff --git a/libshortfin/bindings/python/lib_ext.cc b/libshortfin/bindings/python/lib_ext.cc index 7a65a15f8..8cb92f54c 100644 --- a/libshortfin/bindings/python/lib_ext.cc +++ b/libshortfin/bindings/python/lib_ext.cc @@ -7,6 +7,7 @@ #include "./lib_ext.h" #include "./utils.h" +#include "shortfin/local/process.h" #include "shortfin/local/scope.h" #include "shortfin/local/system.h" #include "shortfin/local/systems/amdgpu.h" @@ -16,6 +17,204 @@ namespace shortfin::python { +namespace { + +class Refs { + public: + py::object asyncio_create_task = + py::module_::import_("asyncio").attr("create_task"); + py::object asyncio_set_event_loop = + py::module_::import_("asyncio").attr("set_event_loop"); + py::object asyncio_set_running_loop = + py::module_::import_("asyncio.events").attr("_set_running_loop"); + py::object threading_Thread = + py::module_::import_("threading").attr("Thread"); + py::object threading_current_thread = + py::module_::import_("threading").attr("current_thread"); + py::object threading_main_thread = + py::module_::import_("threading").attr("main_thread"); + + py::handle lazy_PyWorkerEventLoop() { + if (!lazy_PyWorkerEventLoop_.is_valid()) { + lazy_PyWorkerEventLoop_ = py::module_::import_("_shortfin.asyncio_bridge") + .attr("PyWorkerEventLoop"); + } + return lazy_PyWorkerEventLoop_; + } + + private: + py::object lazy_PyWorkerEventLoop_; +}; + +// Custom worker which hosts an asyncio event loop. +class PyWorker : public local::Worker { + public: + PyWorker(PyInterpreterState *interp, std::shared_ptr refs, + Options options) + : Worker(std::move(options)), interp_(interp), refs_(std::move(refs)) {} + + void WaitForShutdown() override { + // Need to release the GIL if blocking. + py::gil_scoped_release g; + Worker::WaitForShutdown(); + } + + void OnThreadStart() override { + // If our own thread, teach Python about it. Not done for donated. + if (options().owned_thread) { + PyThreadState_New(interp_); + } + + py::gil_scoped_acquire g; + // Aside from set_event_loop being old and _set_running_loop being new + // it isn't clear to me that either can be left off. + refs_->asyncio_set_event_loop(loop_); + refs_->asyncio_set_running_loop(loop_); + } + + void OnThreadStop() override { + { + // Do Python level thread cleanup. + py::gil_scoped_acquire g; + loop_.reset(); + + // Scrub thread state if not donated. + if (options().owned_thread) { + PyThreadState_Clear(PyThreadState_Get()); + } else { + // Otherwise, juse reset the event loop. + refs_->asyncio_set_event_loop(py::none()); + refs_->asyncio_set_running_loop(py::none()); + } + } + + // And destroy our thread state (if not donated). + // TODO: PyThreadState_Delete seems like it should be used here, but I + // couldn't find that being done and I couldn't find a way to use it + // with the GIL/thread state correct. + if (options().owned_thread) { + PyThreadState_Swap(nullptr); + } + } + + std::string to_s() { return fmt::format("PyWorker(name='{}')", name()); } + + py::object loop_; + PyInterpreterState *interp_; + std::shared_ptr refs_; +}; + +std::unique_ptr CreatePyWorker(std::shared_ptr refs, + local::Worker::Options options) { + PyInterpreterState *interp = PyInterpreterState_Get(); + auto new_worker = + std::make_unique(interp, std::move(refs), std::move(options)); + py::object worker_obj = py::cast(*new_worker.get(), py::rv_policy::reference); + new_worker->loop_ = new_worker->refs_->lazy_PyWorkerEventLoop()(worker_obj); + return new_worker; +} + +class PyProcess : public local::detail::BaseProcess { + public: + PyProcess(std::shared_ptr scope, std::shared_ptr refs) + : BaseProcess(std::move(scope)), refs_(std::move(refs)) {} + using BaseProcess::Launch; + + void ScheduleOnWorker() override { + // This is tricky: We need to retain the object reference across the + // thread transition, but on the receiving side, the GIL will not be + // held initially, so we must avoid any refcount maintenance until it + // is acquired. Therefore, we manually borrow a reference and steal it in + // the callback. + py::handle self_object = py::cast(this, py::rv_policy::none); + self_object.inc_ref(); + scope()->worker().CallThreadsafe( + std::bind(&PyProcess::RunOnWorker, self_object)); + } + static void RunOnWorker(py::handle self_handle) { + { + py::gil_scoped_acquire g; + // Steal the reference back from ScheduleOnWorker. Important: this is + // very likely the last reference to the process. So self must not be + // touched after self_object goes out of scope. + py::object self_object = py::steal(self_handle); + PyProcess *self = py::cast(self_handle); + // We assume that the run method either returns None (def) or a coroutine + // (async def). + auto coro = self_object.attr("run")(); + if (!coro.is_none()) { + auto task = self->refs_->asyncio_create_task(coro); + // Capture the self object to avoid lifetime hazzard with PyProcess + // going away before done. + task.attr("add_done_callback")( + py::cpp_function([self_object](py::handle future) { + PyProcess *done_self = py::cast(self_object); + done_self->Terminate(); + })); + } else { + // Synchronous termination. + self->Terminate(); + } + } + } + + std::shared_ptr refs_; +}; + +py::object RunInForeground(std::shared_ptr refs, local::System &self, + py::object coro) { + bool is_main_thread = + refs->threading_current_thread().is(refs->threading_main_thread()); + + PyWorker &worker = dynamic_cast(self.init_worker()); + py::object result = py::none(); + auto done_callback = [&](py::handle future) { + worker.Kill(); + result = future.attr("result")(); + }; + worker.CallThreadsafe([&]() { + // Run within the worker we are about to donate to. + py::gil_scoped_acquire g; + auto task = refs->asyncio_create_task(coro); + task.attr("add_done_callback")(py::cpp_function(done_callback)); + }); + + auto run = py::cpp_function([&]() { + // Release GIL and run until the worker exits. + { + py::gil_scoped_release g; + worker.RunOnCurrentThread(); + } + }); + + // If running on the main thread, we spawn a background thread and join + // it because that shields it from receiving spurious KeyboardInterrupt + // exceptions at inopportune points. + if (is_main_thread) { + auto thread = refs->threading_Thread(/*group=*/py::none(), /*target=*/run); + thread.attr("start")(); + try { + thread.attr("join")(); + } catch (...) { + logging::warn("Exception caught in run(). Shutting down."); + // Leak warnings are hopeless in exceptional termination. + py::set_leak_warnings(false); + // Give it a go waiting for the worker thread to exit. + worker.Kill(); + thread.attr("join")(); + self.Shutdown(); + throw; + } + } else { + run(); + } + + self.Shutdown(); + return result; +} + +} // namespace + NB_MODULE(lib, m) { m.def("initialize", shortfin::GlobalInitialize); auto local_m = m.def_submodule("local"); @@ -28,10 +227,35 @@ NB_MODULE(lib, m) { } void BindLocal(py::module_ &m) { + // Keep weak refs to key objects that need explicit atexit shutdown. + auto weakref = py::module_::import_("weakref"); + py::object live_system_refs = weakref.attr("WeakSet")(); + auto atexit = py::module_::import_("atexit"); + // Manually shutdown all System instances atexit if still alive (it is + // not reliable to shutdown during interpreter finalization). + atexit.attr("register")(py::cpp_function([](py::handle live_system_refs) { + for (auto it = live_system_refs.begin(); + it != live_system_refs.end(); ++it) { + (*it).attr("shutdown")(); + } + }), + live_system_refs); + auto refs = std::make_shared(); + auto worker_factory = [refs](local::Worker::Options options) { + return CreatePyWorker(refs, std::move(options)); + }; + py::class_(m, "SystemBuilder") - .def("create_system", - [](local::SystemBuilder &self) { return self.CreateSystem(); }); - py::class_(m, "System") + .def("create_system", [live_system_refs, + worker_factory](local::SystemBuilder &self) { + auto system_ptr = self.CreateSystem(); + system_ptr->set_worker_factory(worker_factory); + auto system_obj = py::cast(system_ptr, py::rv_policy::take_ownership); + live_system_refs.attr("add")(system_obj); + return system_obj; + }); + py::class_(m, "System", py::is_weak_referenceable()) + .def("shutdown", &local::System::Shutdown) // Access devices by list, name, or lookup. .def_prop_ro("device_names", [](local::System &self) { @@ -53,7 +277,30 @@ void BindLocal(py::module_ &m) { return it->second; }, py::rv_policy::reference_internal) - .def("create_scope", &local::System::CreateScope); + .def( + "create_scope", + [](local::System &self, PyWorker &worker) { + return self.CreateScope(worker); + }, + py::rv_policy::reference_internal) + .def( + "create_scope", + [](local::System &self) { return self.CreateScope(); }, + py::rv_policy::reference_internal) + .def( + "create_worker", + [refs](local::System &self, std::string name) -> PyWorker & { + local::Worker::Options options(self.host_allocator(), + std::move(name)); + return dynamic_cast(self.CreateWorker(options)); + }, + py::arg("name"), py::rv_policy::reference_internal) + .def( + "run", + [refs](local::System &self, py::object coro) { + return RunInForeground(refs, self, std::move(coro)); + }, + py::arg("coro")); // Support classes. py::class_(m, "Node") @@ -133,6 +380,75 @@ void BindLocal(py::module_ &m) { return self.scope.device(name); }, py::rv_policy::reference_internal); + + py::class_(m, "_Worker", py::is_weak_referenceable()) + .def("__repr__", &local::Worker::to_s); + py::class_(m, "Worker") + .def_ro("loop", &PyWorker::loop_) + .def("call_threadsafe", &PyWorker::CallThreadsafe) + .def("call", + [](local::Worker &self, py::handle callable) { + callable.inc_ref(); // Stolen within the callback. + auto thunk = +[](void *user_data, iree_loop_t loop, + iree_status_t status) noexcept -> iree_status_t { + py::gil_scoped_acquire g; + py::object user_callable = + py::steal(static_cast(user_data)); + IREE_RETURN_IF_ERROR(status); + try { + user_callable(); + } catch (std::exception &e) { + return iree_make_status( + IREE_STATUS_UNKNOWN, + "Python exception raised from async callback: %s", + e.what()); + } + return iree_ok_status(); + }; + SHORTFIN_THROW_IF_ERROR(self.CallLowLevel(thunk, callable.ptr())); + }) + .def("delay_call", + [](local::Worker &self, iree_time_t deadline_ns, + py::handle callable) { + callable.inc_ref(); // Stolen within the callback. + auto thunk = +[](void *user_data, iree_loop_t loop, + iree_status_t status) noexcept -> iree_status_t { + py::gil_scoped_acquire g; + py::object user_callable = + py::steal(static_cast(user_data)); + IREE_RETURN_IF_ERROR(status); + try { + user_callable(); + } catch (std::exception &e) { + return iree_make_status( + IREE_STATUS_UNKNOWN, + "Python exception raised from async callback: %s", + e.what()); + } + return iree_ok_status(); + }; + SHORTFIN_THROW_IF_ERROR(self.WaitUntilLowLevel( + iree_make_deadline(deadline_ns), thunk, callable.ptr())); + }) + .def("_delay_to_deadline_ns", + [](local::Worker &self, double delay_seconds) { + return self.ConvertRelativeTimeoutToDeadlineNs( + static_cast(delay_seconds * 1e9)); + }) + .def("_now", [](local::Worker &self) { return self.now(); }) + .def("__repr__", &PyWorker::to_s); + + py::class_(m, "Process") + .def("__init__", [](py::args, py::kwargs) {}) + .def_static( + "__new__", + [refs](py::handle py_type, std::shared_ptr scope, + py::args, py::kwargs) { + return custom_new(py_type, std::move(scope), refs); + }) + .def_prop_ro("scope", &PyProcess::scope) + .def("launch", &PyProcess::Launch) + .def("__repr__", &PyProcess::to_s); } void BindHostSystem(py::module_ &global_m) { diff --git a/libshortfin/bindings/python/lib_ext.h b/libshortfin/bindings/python/lib_ext.h index b3f898a48..5fb38b87b 100644 --- a/libshortfin/bindings/python/lib_ext.h +++ b/libshortfin/bindings/python/lib_ext.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include diff --git a/libshortfin/bindings/python/shortfin/__init__.py b/libshortfin/bindings/python/shortfin/__init__.py new file mode 100644 index 000000000..4075ec768 --- /dev/null +++ b/libshortfin/bindings/python/shortfin/__init__.py @@ -0,0 +1,38 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from _shortfin import lib as _sfl + +# Most classes from the native "local" namespace are aliased to the top +# level of the public API. +Device = _sfl.local.Device +Node = _sfl.local.Node +Process = _sfl.local.Process +Scope = _sfl.local.Scope +ScopedDevice = _sfl.local.ScopedDevice +System = _sfl.local.System +SystemBuilder = _sfl.local.SystemBuilder +Worker = _sfl.local.Worker + +# Array is auto-imported. +from . import array + +# System namespaces. +from . import amdgpu +from . import host + +__all__ = [ + "Device", + "Node", + "Scope", + "ScopedDevice", + "System", + "SystemBuilder", + "Worker", + # System namespaces. + "amdgpu", + "host", +] diff --git a/libshortfin/bindings/python/shortfin/amdgpu.py b/libshortfin/bindings/python/shortfin/amdgpu.py new file mode 100644 index 000000000..d10dfa459 --- /dev/null +++ b/libshortfin/bindings/python/shortfin/amdgpu.py @@ -0,0 +1,24 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from _shortfin import lib as _sfl + +_is_available = False + + +def is_available(): + return _is_available + + +if hasattr(_sfl.local, "amdgpu"): + AMDGPUDevice = _sfl.local.amdgpu.AMDGPUDevice + SystemBuilder = _sfl.local.amdgpu.SystemBuilder + + __all__ = [ + "AMDGPUDevice", + "SystemBuilder", + ] + _is_available = True diff --git a/libshortfin/bindings/python/shortfin/array.py b/libshortfin/bindings/python/shortfin/array.py new file mode 100644 index 000000000..e99595554 --- /dev/null +++ b/libshortfin/bindings/python/shortfin/array.py @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from _shortfin import lib as _sfl + +# All dtype aliases. +opaque8 = _sfl.array.opaque8 +opaque16 = _sfl.array.opaque16 +opaque32 = _sfl.array.opaque32 +opaque64 = _sfl.array.opaque64 +bool8 = _sfl.array.bool8 +int4 = _sfl.array.int4 +sint4 = _sfl.array.sint4 +uint4 = _sfl.array.uint4 +int8 = _sfl.array.int8 +sint8 = _sfl.array.sint8 +uint8 = _sfl.array.uint8 +int16 = _sfl.array.int16 +sint16 = _sfl.array.sint16 +uint16 = _sfl.array.uint16 +int32 = _sfl.array.int32 +sint32 = _sfl.array.sint32 +uint32 = _sfl.array.uint32 +int64 = _sfl.array.int64 +sint64 = _sfl.array.sint64 +uint64 = _sfl.array.uint64 +float16 = _sfl.array.float16 +float32 = _sfl.array.float32 +float64 = _sfl.array.float64 +bfloat16 = _sfl.array.bfloat16 +complex64 = _sfl.array.complex64 +complex128 = _sfl.array.complex128 + + +base_array = _sfl.array.base_array +device_array = _sfl.array.device_array +host_array = _sfl.array.host_array +storage = _sfl.array.storage +DType = _sfl.array.DType + + +__all__ = [ + # DType aliases. + "opaque8", + "opaque16", + "opaque32", + "opaque64", + "bool8", + "int4", + "sint4", + "uint4", + "int8", + "sint8", + "uint8", + "int16", + "sint16", + "uint16", + "int32", + "sint32", + "uint32", + "int64", + "sint64", + "uint64", + "float16", + "float32", + "float64", + "bfloat16", + "complex64", + "complex128", + # Classes. + "base_array", + "device_array", + "host_array", + "storage", + "DType", +] diff --git a/libshortfin/bindings/python/shortfin/host.py b/libshortfin/bindings/python/shortfin/host.py new file mode 100644 index 000000000..66631fcc4 --- /dev/null +++ b/libshortfin/bindings/python/shortfin/host.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from _shortfin import lib as _sfl + +CPUSystemBuilder = _sfl.local.host.CPUSystemBuilder +HostCPUDevice = _sfl.local.host.HostCPUDevice +SystemBuilder = _sfl.local.host.SystemBuilder + +__all__ = [ + "CPUSystemBuilder" "HostCPUSystemBuilder", + "HostCPUDevice", + "SystemBuilder", +] diff --git a/libshortfin/examples/python/basic_asyncio.py b/libshortfin/examples/python/basic_asyncio.py new file mode 100644 index 000000000..a757b50a5 --- /dev/null +++ b/libshortfin/examples/python/basic_asyncio.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import threading +import time + +import shortfin as sf + +lsys = sf.host.CPUSystemBuilder().create_system() +worker = lsys.create_worker("main") +print("Worker:", worker) + + +async def do_something(i, delay): + print(f"({i}): FROM ASYNC do_something (tid={threading.get_ident()})", delay) + print(f"({i}): Time:", asyncio.get_running_loop().time(), "Delay:", delay) + await asyncio.sleep(delay) + print(f"({i}): DONE", delay) + return delay + + +import random + +fs = [] +total_delay = 0.0 +max_delay = 0.0 +for i in range(20): + delay = random.random() * 2 + total_delay += delay + max_delay = max(max_delay, delay) + print("SCHEDULE", i) + fs.append(asyncio.run_coroutine_threadsafe(do_something(i, delay), worker.loop)) + +for f in fs: + print(f.result()) + +print("TOTAL DELAY:", total_delay, "MAX:", max_delay) diff --git a/libshortfin/examples/python/process.py b/libshortfin/examples/python/process.py new file mode 100644 index 000000000..f5907d0e0 --- /dev/null +++ b/libshortfin/examples/python/process.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio + +import shortfin as sf + +lsys = sf.host.CPUSystemBuilder().create_system() + + +class MyProcess(sf.Process): + def __init__(self, scope, arg): + super().__init__(scope) + self.arg = arg + + async def run(self): + print("Hello async:", self.arg, self) + if self.arg < 10: + await asyncio.sleep(0.3) + MyProcess(self.scope, self.arg + 1).launch() + + +async def main(): + worker = lsys.create_worker("main") + scope = lsys.create_scope(worker) + for i in range(10): + MyProcess(scope, i).launch() + await asyncio.sleep(0.1) + MyProcess(scope, i * 100).launch() + await asyncio.sleep(0.1) + MyProcess(scope, i * 1000).launch() + await asyncio.sleep(2.5) + return i + + +print("RESULT:", lsys.run(main())) diff --git a/libshortfin/setup.py b/libshortfin/setup.py index 5387ec10e..f3fc4e9b9 100644 --- a/libshortfin/setup.py +++ b/libshortfin/setup.py @@ -55,8 +55,11 @@ def copy_extensions_to_source(self, *args, **kwargs): ... +python_src_dir = rel_source_dir / "bindings" / "python" +python_bin_dir = rel_binary_dir / "bindings" / "python" + setup( - name="libshortfin", + name="shortfin", version="0.9", description="Shortfin native library implementation", author="SHARK Authors", @@ -64,14 +67,14 @@ def copy_extensions_to_source(self, *args, **kwargs): "_shortfin", "_shortfin_default", # TODO: Conditionally map additional native library variants. + "shortfin", ], zip_safe=False, package_dir={ - "_shortfin": str(rel_source_dir / "bindings" / "python" / "_shortfin"), - "_shortfin_default": str( - rel_binary_dir / "bindings" / "python" / "_shortfin_default" - ), + "_shortfin": str(python_src_dir / "_shortfin"), + "_shortfin_default": str(python_bin_dir / "_shortfin_default"), # TODO: Conditionally map additional native library variants. + "shortfin": str(python_src_dir / "shortfin"), }, ext_modules=[ BuiltExtension("_shortfin_default.lib"), diff --git a/libshortfin/src/CMakeLists.txt b/libshortfin/src/CMakeLists.txt index 916070bd4..1a69094e0 100644 --- a/libshortfin/src/CMakeLists.txt +++ b/libshortfin/src/CMakeLists.txt @@ -19,7 +19,6 @@ shortfin_public_library( COMPONENTS shortfin_array shortfin_local - shortfin_process shortfin_support shortfin_systems_amdgpu shortfin_systems_host diff --git a/libshortfin/src/shortfin/CMakeLists.txt b/libshortfin/src/shortfin/CMakeLists.txt index b22f48791..3c50ed1eb 100644 --- a/libshortfin/src/shortfin/CMakeLists.txt +++ b/libshortfin/src/shortfin/CMakeLists.txt @@ -6,5 +6,4 @@ add_subdirectory(array) add_subdirectory(local) -add_subdirectory(process) add_subdirectory(support) diff --git a/libshortfin/src/shortfin/local/CMakeLists.txt b/libshortfin/src/shortfin/local/CMakeLists.txt index 228b2e914..db410ec0a 100644 --- a/libshortfin/src/shortfin/local/CMakeLists.txt +++ b/libshortfin/src/shortfin/local/CMakeLists.txt @@ -11,19 +11,23 @@ shortfin_cc_component( shortfin_local HDRS device.h + process.h + worker.h scheduler.h scope.h system.h SRCS device.cc + process.cc + worker.cc scheduler.cc scope.cc system.cc COMPONENTS - shortfin_process shortfin_support DEPS iree_base_base + iree_base_loop_sync iree_hal_hal iree_modules_hal_hal iree_vm_vm diff --git a/libshortfin/src/shortfin/local/process.cc b/libshortfin/src/shortfin/local/process.cc new file mode 100644 index 000000000..56fbcf904 --- /dev/null +++ b/libshortfin/src/shortfin/local/process.cc @@ -0,0 +1,76 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/local/process.h" + +#include "fmt/core.h" +#include "shortfin/local/system.h" +#include "shortfin/support/logging.h" + +namespace shortfin::local { + +detail::BaseProcess::BaseProcess(std::shared_ptr scope) + : scope_(std::move(scope)) {} + +detail::BaseProcess::~BaseProcess() {} + +int64_t detail::BaseProcess::pid() const { + iree_slim_mutex_lock_guard g(lock_); + return pid_; +} + +std::string detail::BaseProcess::to_s() const { + int pid; + { + iree_slim_mutex_lock_guard g(lock_); + pid = pid_; + } + + if (pid == 0) { + return fmt::format("Process(NOT_STARTED, worker='{}')", + scope_->worker().name()); + } else if (pid < 0) { + return fmt::format("Process(TERMINATED, worker='{}')", + scope_->worker().name()); + } else { + return fmt::format("Process(pid={}, worker='{}')", pid, + scope_->worker().name()); + } +} + +void detail::BaseProcess::Launch() { + Scope* scope = scope_.get(); + { + iree_slim_mutex_lock_guard g(lock_); + if (pid_ != 0) { + throw std::logic_error("Process can only be launched a single time"); + } + pid_ = scope->system().AllocateProcess(this); + } + + ScheduleOnWorker(); +} + +void detail::BaseProcess::ScheduleOnWorker() { + logging::info("ScheduleOnWorker()"); + Terminate(); +} + +void detail::BaseProcess::Terminate() { + int deallocate_pid; + { + iree_slim_mutex_lock_guard g(lock_); + deallocate_pid = pid_; + pid_ = -1; + } + if (deallocate_pid > 0) { + scope_->system().DeallocateProcess(deallocate_pid); + } else { + logging::warn("Process signalled termination multiple times (ignored)"); + } +} + +} // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/process.h b/libshortfin/src/shortfin/local/process.h new file mode 100644 index 000000000..72efe9864 --- /dev/null +++ b/libshortfin/src/shortfin/local/process.h @@ -0,0 +1,73 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_LOCAL_PROCESS_H +#define SHORTFIN_LOCAL_PROCESS_H + +#include +#include + +#include "shortfin/local/scope.h" +#include "shortfin/local/worker.h" +#include "shortfin/support/api.h" +#include "shortfin/support/iree_concurrency.h" + +namespace shortfin::local { + +namespace detail { + +// Processes have a unique lifetime and also can be extended from languages +// other than C++. We therefore have a more binding friendly base class that +// can be used when the Process is aggregated in some kind of foreign +// structure and external lifetime management. +class SHORTFIN_API BaseProcess { + public: + BaseProcess(std::shared_ptr scope); + BaseProcess(const BaseProcess &) = delete; + virtual ~BaseProcess(); + + // The unique pid of this process (or zero if not launched). + int64_t pid() const; + std::string to_s() const; + std::shared_ptr &scope() { return scope_; } + + protected: + // Launches the process. + void Launch(); + + // Subclasses will have ScheduleOnWorker() called exactly once during + // Launch(). The subclass must eventually call Terminate(), either + // synchronously within this call frame or asynchronously at a future point. + virtual void ScheduleOnWorker(); + + // Called when this process has asynchronously finished. + void Terminate(); + + private: + std::shared_ptr scope_; + + // Process control state. Since this can be accessed by multiple threads, + // it is protected by a lock. Most process state can only be accessed on + // the worker thread and is unprotected. + mutable iree_slim_mutex lock_; + // Pid is 0 if not yet started, -1 if terminated, and a postive value if + // running. + int64_t pid_ = 0; +}; + +} // namespace detail + +// Processes are the primary unit of scheduling in shortfin. They are light +// weight entities that are created on a Worker and operate in an event +// driven fashion (i.e. cps, async/await, co-routines, etc). +class SHORTFIN_API Process : public detail::BaseProcess { + public: + using BaseProcess::BaseProcess; +}; + +} // namespace shortfin::local + +#endif // SHORTFIN_LOCAL_PROCESS_H diff --git a/libshortfin/src/shortfin/local/scope.cc b/libshortfin/src/shortfin/local/scope.cc index 28edf5535..56e96e73e 100644 --- a/libshortfin/src/shortfin/local/scope.cc +++ b/libshortfin/src/shortfin/local/scope.cc @@ -9,6 +9,7 @@ #include #include +#include "shortfin/local/system.h" #include "shortfin/support/logging.h" namespace shortfin::local { @@ -17,17 +18,24 @@ namespace shortfin::local { // Scope // -------------------------------------------------------------------------- // -Scope::Scope(iree_allocator_t host_allocator, +Scope::Scope(std::shared_ptr system, Worker &worker, std::span> devices) - : host_allocator_(host_allocator), scheduler_(host_allocator) { + : host_allocator_(system->host_allocator()), + scheduler_(system->host_allocator()), + system_(std::move(system)), + worker_(worker) { for (auto &it : devices) { AddDevice(it.first, it.second); } Initialize(); } -Scope::Scope(iree_allocator_t host_allocator, std::span devices) - : host_allocator_(host_allocator), scheduler_(host_allocator) { +Scope::Scope(std::shared_ptr system, Worker &worker, + std::span devices) + : host_allocator_(system->host_allocator()), + scheduler_(system->host_allocator()), + system_(std::move(system)), + worker_(worker) { for (auto *device : devices) { AddDevice(device->address().logical_device_class, device); } diff --git a/libshortfin/src/shortfin/local/scope.h b/libshortfin/src/shortfin/local/scope.h index 5082c069f..78f7c1039 100644 --- a/libshortfin/src/shortfin/local/scope.h +++ b/libshortfin/src/shortfin/local/scope.h @@ -18,6 +18,8 @@ namespace shortfin::local { class SHORTFIN_API Scope; +class SHORTFIN_API System; +class SHORTFIN_API Worker; // Wraps a Scope and a DeviceAffinity together. This is used in all // Scope based APIs as a short-hand for "device" as it contains everything @@ -60,17 +62,27 @@ class SHORTFIN_API ScopedDevice { // situations, this can be customized. By default, devices are added in the // order defined by the system and will have an `` corresponding to // their order. It is up to the constructor to produce a sensible arrangement. -class SHORTFIN_API Scope { +class SHORTFIN_API Scope : public std::enable_shared_from_this { public: // Initialize with devices using logical_device_class as the device class. - Scope(iree_allocator_t host_allocator, std::span devices); + Scope(std::shared_ptr system, Worker &worker, + std::span devices); // Initialize with devices with custom device class names. - Scope(iree_allocator_t host_allocator, + Scope(std::shared_ptr system, Worker &worker, std::span> devices); Scope(const Scope &) = delete; // Ensure polymorphic. virtual ~Scope(); + // All scopes are created as shared pointers. + std::shared_ptr shared_ptr() { return shared_from_this(); } + + // The worker that this scope is bound to. + Worker &worker() { return worker_; } + + // System that this scope is bound to. + System &system() { return *system_; } + // Device access. // Throws std::invalid_argument on lookup failure. Device *raw_device(std::string_view name) const; @@ -116,6 +128,10 @@ class SHORTFIN_API Scope { // Map of `` to Device. std::unordered_map named_devices_; detail::Scheduler scheduler_; + + // Back reference to owning system. + std::shared_ptr system_; + Worker &worker_; }; } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/system.cc b/libshortfin/src/shortfin/local/system.cc index 64ab43c02..3c25e375c 100644 --- a/libshortfin/src/shortfin/local/system.cc +++ b/libshortfin/src/shortfin/local/system.cc @@ -13,25 +13,6 @@ namespace shortfin::local { -namespace { - -// A Scope with a back reference to the System from which it -// originated. -class ExtendingScope : public Scope { - public: - using Scope::Scope; - - private: - std::shared_ptr backref_; - friend std::shared_ptr &mutable_local_scope_backref(ExtendingScope &); -}; - -std::shared_ptr &mutable_local_scope_backref(ExtendingScope &scope) { - return scope.backref_; -} - -} // namespace - // -------------------------------------------------------------------------- // // System // -------------------------------------------------------------------------- // @@ -44,13 +25,51 @@ System::System(iree_allocator_t host_allocator) } System::~System() { + bool needs_shutdown = false; + { + iree_slim_mutex_lock_guard guard(lock_); + if (initialized_ && !shutdown_) { + needs_shutdown = true; + } + } + if (needs_shutdown) { + logging::warn( + "Implicit Shutdown from System destructor. Please call Shutdown() " + "explicitly for maximum stability."); + Shutdown(); + } +} + +std::unique_ptr System::DefaultWorkerFactory(Worker::Options options) { + return std::make_unique(std::move(options)); +} + +void System::set_worker_factory(Worker::Factory factory) { + iree_slim_mutex_lock_guard guard(lock_); + worker_factory_ = std::move(factory); +} + +void System::Shutdown() { + // Stop workers. + std::vector> local_workers; + { + iree_slim_mutex_lock_guard guard(lock_); + if (!initialized_ || shutdown_) return; + shutdown_ = true; + workers_by_name_.clear(); + local_workers.swap(workers_); + } + // Worker drain and shutdown. - for (auto &worker : workers_) { + for (auto &worker : local_workers) { worker->Kill(); } - for (auto &worker : workers_) { - worker->WaitForShutdown(); + for (auto &worker : local_workers) { + if (worker->options().owned_thread) { + worker->WaitForShutdown(); + } } + local_workers.clear(); // Orderly destruction of heavy-weight objects. // Shutdown order is important so we don't leave it to field ordering. @@ -65,11 +84,15 @@ System::~System() { hal_drivers_.clear(); } -std::unique_ptr System::CreateScope() { - auto new_scope = - std::make_unique(host_allocator(), devices()); - mutable_local_scope_backref(*new_scope) = shared_from_this(); - return new_scope; +std::shared_ptr System::CreateScope(Worker &worker) { + iree_slim_mutex_lock_guard guard(lock_); + return std::make_shared(shared_ptr(), worker, devices()); +} + +std::shared_ptr System::CreateScope() { + Worker &w = init_worker(); + iree_slim_mutex_lock_guard guard(lock_); + return std::make_shared(shared_ptr(), w, devices()); } void System::InitializeNodes(int node_count) { @@ -83,6 +106,46 @@ void System::InitializeNodes(int node_count) { } } +Worker &System::CreateWorker(Worker::Options options) { + Worker *unowned_worker; + { + iree_slim_mutex_lock_guard guard(lock_); + if (options.name == std::string_view("__init__")) { + throw std::invalid_argument( + "Cannot create worker '__init__' (reserved name)"); + } + if (workers_by_name_.count(options.name) != 0) { + throw std::invalid_argument(fmt::format( + "Cannot create worker with duplicate name '{}'", options.name)); + } + auto worker = worker_factory_(std::move(options)); + workers_.push_back(std::move(worker)); + unowned_worker = workers_.back().get(); + workers_by_name_[unowned_worker->name()] = unowned_worker; + } + if (unowned_worker->options().owned_thread) { + unowned_worker->Start(); + } + return *unowned_worker; +} + +Worker &System::init_worker() { + iree_slim_mutex_lock_guard guard(lock_); + auto found_it = workers_by_name_.find("__init__"); + if (found_it != workers_by_name_.end()) { + return *found_it->second; + } + + // Create. + Worker::Options options(host_allocator(), "__init__"); + options.owned_thread = false; + auto worker = worker_factory_(std::move(options)); + workers_.push_back(std::move(worker)); + Worker *unowned_worker = workers_.back().get(); + workers_by_name_[unowned_worker->name()] = unowned_worker; + return *unowned_worker; +} + void System::InitializeHalDriver(std::string_view moniker, iree_hal_driver_ptr driver) { AssertNotInitialized(); @@ -107,17 +170,21 @@ void System::InitializeHalDevice(std::unique_ptr device) { } void System::FinishInitialization() { + iree_slim_mutex_lock_guard guard(lock_); AssertNotInitialized(); + initialized_ = true; +} - // TODO: Remove this. Just testing. - // workers_.push_back( - // std::make_unique(Worker::Options(host_allocator(), - // "worker:0"))); - // workers_.back()->Start(); - // workers_.back()->EnqueueCallback( - // []() { spdlog::info("Hi from a worker callback"); }); +int64_t System::AllocateProcess(detail::BaseProcess *p) { + iree_slim_mutex_lock_guard guard(lock_); + int pid = next_pid_++; + processes_by_pid_[pid] = p; + return pid; +} - initialized_ = true; +void System::DeallocateProcess(int64_t pid) { + iree_slim_mutex_lock_guard guard(lock_); + processes_by_pid_.erase(pid); } } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/system.h b/libshortfin/src/shortfin/local/system.h index 08277aa73..a4a47e3a6 100644 --- a/libshortfin/src/shortfin/local/system.h +++ b/libshortfin/src/shortfin/local/system.h @@ -15,12 +15,18 @@ #include #include "shortfin/local/device.h" -#include "shortfin/process/worker.h" +#include "shortfin/local/worker.h" #include "shortfin/support/api.h" +#include "shortfin/support/iree_concurrency.h" #include "shortfin/support/iree_helpers.h" +#include "shortfin/support/stl_extras.h" namespace shortfin::local { +namespace detail { +class BaseProcess; +} // namespace detail + class Scope; class System; class SystemBuilder; @@ -43,6 +49,14 @@ class SHORTFIN_API System : public std::enable_shared_from_this { System(const System &) = delete; ~System(); + // Sets a worker factory that will be used for all subsequently created + // Worker instances. Certain bindings and integrations may need special + // kinds of Worker classes, and this can customize that. + void set_worker_factory(Worker::Factory factory); + + // Explicit shutdown (vs in destructor) is encouraged. + void Shutdown(); + // Get a shared pointer from the instance. std::shared_ptr shared_ptr() { return shared_from_this(); } @@ -61,7 +75,19 @@ class SHORTFIN_API System : public std::enable_shared_from_this { // Creates a new Scope bound to this System (it will internally // hold a reference to this instance). All devices in system order will be // added to the scope. - std::unique_ptr CreateScope(); + std::shared_ptr CreateScope(Worker &worker); + + // Creates a scope bound to the init worker. + std::shared_ptr CreateScope(); + + // Creates and starts a worker (if it is configured to run in a thread). + Worker &CreateWorker(Worker::Options options); + + // Accesses the initialization worker that is intended to be run on the main + // or adopted thread to perform any async interactions with the system. + // Internally, this worker is called "__init__". It will be created on + // demand if it does not yet exist. + Worker &init_worker(); // Initialization APIs. Calls to these methods is only permitted between // construction and Initialize(). @@ -73,6 +99,7 @@ class SHORTFIN_API System : public std::enable_shared_from_this { void FinishInitialization(); private: + static std::unique_ptr DefaultWorkerFactory(Worker::Options options); void AssertNotInitialized() { if (initialized_) { throw std::logic_error( @@ -81,8 +108,19 @@ class SHORTFIN_API System : public std::enable_shared_from_this { } } + // Allocates a process in the process table and returns its new pid. + // This is done on process construction. Note that it acquires the + // system lock and is non-reentrant. + int64_t AllocateProcess(detail::BaseProcess *); + // Deallocates a process by pid. This is done on process destruction. Note + // that is acquires the system lock and is non-reentrant. + void DeallocateProcess(int64_t pid); + const iree_allocator_t host_allocator_; + string_interner interner_; + iree_slim_mutex lock_; + // NUMA nodes relevant to this system. std::vector nodes_; @@ -100,11 +138,20 @@ class SHORTFIN_API System : public std::enable_shared_from_this { iree_vm_instance_ptr vm_instance_; // Workers. + Worker::Factory worker_factory_ = System::DefaultWorkerFactory; std::vector> workers_; + std::unordered_map workers_by_name_; + + // Process management. + int next_pid_ = 1; + std::unordered_map processes_by_pid_; // Whether initialization is complete. If true, various low level // mutations are disallowed. bool initialized_ = false; + bool shutdown_ = false; + + friend class detail::BaseProcess; }; using SystemPtr = std::shared_ptr; diff --git a/libshortfin/src/shortfin/process/worker.cc b/libshortfin/src/shortfin/local/worker.cc similarity index 60% rename from libshortfin/src/shortfin/process/worker.cc rename to libshortfin/src/shortfin/local/worker.cc index 7276aa01e..ff5be12d6 100644 --- a/libshortfin/src/shortfin/process/worker.cc +++ b/libshortfin/src/shortfin/local/worker.cc @@ -4,11 +4,11 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "shortfin/process/worker.h" +#include "shortfin/local/worker.h" #include "shortfin/support/logging.h" -namespace shortfin { +namespace shortfin::local { Worker::Worker(const Options options) : options_(std::move(options)), @@ -34,14 +34,19 @@ Worker::~Worker() { thread_.reset(); } +std::string Worker::to_s() { + return fmt::format("", options_.name); +} + +void Worker::OnThreadStart() {} +void Worker::OnThreadStop() {} + iree_status_t Worker::TransactLoop(iree_status_t signal_status) { if (!iree_status_is_ok(signal_status)) { // TODO: Handle failure. return signal_status; } - logging::info("Transact loop!"); - { // An outside thread cannot change the state we are managing without // entering this critical section, so it is safe to reset the event @@ -63,31 +68,32 @@ iree_status_t Worker::TransactLoop(iree_status_t signal_status) { next_thunk(); } next_thunks_.clear(); + return ScheduleExternalTransactEvent(); +} - return iree_ok_status(); +iree_status_t Worker::ScheduleExternalTransactEvent() { + return iree_loop_wait_one( + loop_, signal_transact_.await(), iree_infinite_timeout(), + +[](void* self, iree_loop_t loop, iree_status_t status) { + return static_cast(self)->TransactLoop(status); + }, + this); } -int Worker::Run() { +int Worker::RunOnThread() { auto RunLoop = [&]() -> iree_status_t { + IREE_RETURN_IF_ERROR(ScheduleExternalTransactEvent()); for (;;) { { iree_slim_mutex_lock_guard guard(mu_); if (kill_) break; } - - // Need to re-add our transact event on each cycle because it was - // necessarily fulfilled each time. - IREE_RETURN_IF_ERROR(iree_loop_wait_one( - loop_, signal_transact_.await(), iree_infinite_timeout(), - +[](void* self, iree_loop_t loop, iree_status_t status) { - return static_cast(self)->TransactLoop(status); - }, - this)); - IREE_RETURN_IF_ERROR(iree_loop_drain(loop_, iree_infinite_timeout())); + IREE_RETURN_IF_ERROR(iree_loop_drain(loop_, options_.quantum)); } return iree_ok_status(); }; + OnThreadStart(); { auto loop_status = RunLoop(); if (!iree_status_is_ok(loop_status)) { @@ -95,18 +101,22 @@ int Worker::Run() { iree_status_abort(loop_status); } } + OnThreadStop(); signal_ended_.set(); return 0; } void Worker::Start() { + if (!options_.owned_thread) { + throw std::logic_error("Cannot start worker when owned_thread=false"); + } if (thread_) { throw std::logic_error("Cannot start Worker multiple times"); } auto EntryFunction = - +[](void* self) { return static_cast(self)->Run(); }; + +[](void* self) { return static_cast(self)->RunOnThread(); }; iree_thread_create_params_t params = { .name = {options_.name.data(), options_.name.size()}, // Need to make sure that the thread can access thread_ so need to create @@ -119,8 +129,8 @@ void Worker::Start() { } void Worker::Kill() { - if (!thread_) { - throw std::logic_error("Cannot Drain a Worker that was not started"); + if (options_.owned_thread && !thread_) { + throw std::logic_error("Cannot kill a Worker that was not started"); } { iree_slim_mutex_lock_guard guard(mu_); @@ -130,6 +140,9 @@ void Worker::Kill() { } void Worker::WaitForShutdown() { + if (!options_.owned_thread) { + throw std::logic_error("Cannot shutdown worker when owned_thread=false"); + } if (!thread_) { throw std::logic_error("Cannot Shutdown a Worker that was not started"); } @@ -147,7 +160,22 @@ void Worker::WaitForShutdown() { } } -void Worker::EnqueueCallback(std::function callback) { +void Worker::RunOnCurrentThread() { + if (options_.owned_thread) { + throw std::logic_error( + "Cannot RunOnCurrentThread if worker was configured for owned_thread"); + } + { + iree_slim_mutex_lock_guard guard(mu_); + if (has_run_) { + throw std::logic_error("Cannot RunOnCurrentThread if already finished"); + } + has_run_ = true; + } + RunOnThread(); +} + +void Worker::CallThreadsafe(std::function callback) { { iree_slim_mutex_lock_guard guard(mu_); pending_thunks_.push_back(std::move(callback)); @@ -155,4 +183,26 @@ void Worker::EnqueueCallback(std::function callback) { signal_transact_.set(); } -} // namespace shortfin +iree_status_t Worker::CallLowLevel( + iree_status_t (*callback)(void* user_data, iree_loop_t loop, + iree_status_t status) noexcept, + void* user_data, iree_loop_priority_e priority) noexcept { + return iree_loop_call(loop_, priority, callback, user_data); +} + +iree_status_t Worker::WaitUntilLowLevel( + iree_timeout_t timeout, + iree_status_t (*callback)(void* user_data, iree_loop_t loop, + iree_status_t status) noexcept, + void* user_data) { + return iree_loop_wait_until(loop_, timeout, callback, user_data); +} + +// Time management. +iree_time_t Worker::now() { return iree_time_now(); } +iree_time_t Worker::ConvertRelativeTimeoutToDeadlineNs( + iree_duration_t timeout_ns) { + return iree_relative_timeout_to_deadline_ns(timeout_ns); +} + +} // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/worker.h b/libshortfin/src/shortfin/local/worker.h new file mode 100644 index 000000000..7ec308f2c --- /dev/null +++ b/libshortfin/src/shortfin/local/worker.h @@ -0,0 +1,114 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_WORKER_H +#define SHORTFIN_WORKER_H + +#include +#include +#include + +#include "iree/base/loop_sync.h" +#include "shortfin/support/api.h" +#include "shortfin/support/iree_concurrency.h" + +namespace shortfin::local { + +// Cooperative worker. +class SHORTFIN_API Worker { + public: + struct Options { + iree_allocator_t allocator; + std::string name; + + // Controls the maximum duration that can transpire between the loop + // making an outer trip where it can exit and perform other outside-world + // maintenance. Without this, the loop could run forever if there was + // an infinite/long async wait or something. + iree_timeout_t quantum = iree_make_timeout_ms(500); + + // Whether to create the worker on an owned thread. If false, then the + // worker is set up to be adopted and a thread will not be created. + bool owned_thread = true; + + Options(iree_allocator_t allocator, std::string name) + : allocator(allocator), name(name) {} + }; + using Factory = std::function(Options options)>; + + Worker(Options options); + Worker(const Worker &) = delete; + virtual ~Worker(); + + const Options &options() const { return options_; } + const std::string_view name() const { return options_.name; } + std::string to_s(); + + void Start(); + void Kill(); + virtual void WaitForShutdown(); + + // Runs on the current thread. This is used instead of Start() when + // owned_thread is false. + void RunOnCurrentThread(); + + // Enqueues a callback to the worker from another thread. + void CallThreadsafe(std::function callback); + + // Operations that can be done from on the worker. + // Callback to execute user code on the loop "soon". This variant must not + // raise exceptions and matches the underlying C API. It should not generally + // be used by "regular users" but can be useful for bindings that wish to + // reduce the tolls/hops. + iree_status_t CallLowLevel( + iree_status_t (*callback)(void *user_data, iree_loop_t loop, + iree_status_t status) noexcept, + void *user_data, + iree_loop_priority_e priority = IREE_LOOP_PRIORITY_DEFAULT) noexcept; + + // Calls back after a timeout. + iree_status_t WaitUntilLowLevel( + iree_timeout_t timeout, + iree_status_t (*callback)(void *user_data, iree_loop_t loop, + iree_status_t status) noexcept, + void *user_data); + + // Time management. + // Returns the current absolute time in nanoseconds. + iree_time_t now(); + iree_time_t ConvertRelativeTimeoutToDeadlineNs(iree_duration_t timeout_ns); + + protected: + virtual void OnThreadStart(); + virtual void OnThreadStop(); + + private: + int RunOnThread(); + iree_status_t ScheduleExternalTransactEvent(); + iree_status_t TransactLoop(iree_status_t signal_status); + + const Options options_; + iree_slim_mutex mu_; + iree_thread_ptr thread_; + iree_event signal_transact_; + iree_event signal_ended_; + + // State management. These are all manipulated both on and off the worker + // thread. + std::vector> pending_thunks_; + bool kill_ = false; + bool has_run_ = false; + + // Loop management. This is all purely operated on the worker thread. + iree_loop_sync_scope_t loop_scope_; + iree_loop_sync_t *loop_sync_; + iree_loop_t loop_; + std::vector> next_thunks_; +}; + +} // namespace shortfin::local + +#endif // SHORTFIN_WORKER_H diff --git a/libshortfin/src/shortfin/process/CMakeLists.txt b/libshortfin/src/shortfin/process/CMakeLists.txt deleted file mode 100644 index d4e418168..000000000 --- a/libshortfin/src/shortfin/process/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception - -shortfin_cc_component( - NAME - shortfin_process - HDRS - worker.h - SRCS - worker.cc - COMPONENTS - shortfin_support - DEPS - iree_base_base - iree_base_loop_sync -) diff --git a/libshortfin/src/shortfin/process/worker.h b/libshortfin/src/shortfin/process/worker.h deleted file mode 100644 index a7b9e2d2b..000000000 --- a/libshortfin/src/shortfin/process/worker.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 Advanced Micro Devices, Inc -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef SHORTFIN_WORKER_H -#define SHORTFIN_WORKER_H - -#include -#include -#include - -#include "iree/base/loop_sync.h" -#include "shortfin/support/iree_concurrency.h" - -namespace shortfin { - -// Cooperative worker. -class Worker { - public: - struct Options { - iree_allocator_t allocator; - std::string name; - - Options(iree_allocator_t allocator, std::string name) - : allocator(allocator), name(name) {} - }; - - Worker(Options options); - Worker(const Worker &) = delete; - ~Worker(); - - const std::string_view name() const { return options_.name; } - - void Start(); - void Kill(); - void WaitForShutdown(); - - // Enqueues a callback on the worker. - void EnqueueCallback(std::function callback); - - private: - int Run(); - iree_status_t TransactLoop(iree_status_t signal_status); - - const Options options_; - iree_slim_mutex mu_; - iree_thread_ptr thread_; - iree_event signal_transact_; - iree_event signal_ended_; - - // State management. These are all manipulated both on and off the worker - // thread. - bool kill_ = false; - std::vector> pending_thunks_; - - // Loop management. This is all purely operated on the worker thread. - iree_loop_sync_scope_t loop_scope_; - iree_loop_sync_t *loop_sync_; - iree_loop_t loop_; - std::vector> next_thunks_; -}; - -} // namespace shortfin - -#endif // SHORTFIN_WORKER_H diff --git a/libshortfin/tests/array_test.py b/libshortfin/tests/array_test.py index 2aacf0bed..41cf51aa8 100644 --- a/libshortfin/tests/array_test.py +++ b/libshortfin/tests/array_test.py @@ -19,7 +19,9 @@ def lsys(): @pytest.fixture def scope(lsys): - return lsys.create_scope() + # TODO: Should adopt the main thread. + worker = lsys.create_worker("main") + return lsys.create_scope(worker) def test_storage(scope): diff --git a/libshortfin/tests/local_scope_test.py b/libshortfin/tests/local_scope_test.py index 476ca018a..de3598711 100644 --- a/libshortfin/tests/local_scope_test.py +++ b/libshortfin/tests/local_scope_test.py @@ -18,7 +18,9 @@ def lsys(): @pytest.fixture def scope(lsys): - return lsys.create_scope() + # TODO: Should adopt the main thread. + worker = lsys.create_worker("main") + return lsys.create_scope(worker) def test_raw_device_access(scope):