diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index c9808d0b..30fd1220 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -500,7 +500,7 @@ class Worker : public Component { [[nodiscard]] bool isFutureEnabled() const; /** - * @brief Populate the future pool. + * @brief Populate the futures pool. * * To avoid taking blocking resources (such as the Python GIL) for every new future * required by each `ucxx::Request`, the `ucxx::Worker` maintains a pool of futures @@ -512,6 +512,17 @@ class Worker : public Component { */ virtual void populateFuturesPool(); + /** + * @brief Clear the futures pool. + * + * Clear the futures pool, ensuring all references are removed and thus avoiding + * reference cycles that prevent the `ucxx::Worker` and other resources from cleaning + * up on time. + * + * @throws std::runtime_error if future support is not implemented. + */ + virtual void clearFuturesPool(); + /** * @brief Get a future from the pool. * diff --git a/cpp/python/include/ucxx/python/worker.h b/cpp/python/include/ucxx/python/worker.h index be4b77fd..3140305a 100644 --- a/cpp/python/include/ucxx/python/worker.h +++ b/cpp/python/include/ucxx/python/worker.h @@ -86,15 +86,28 @@ class Worker : public ::ucxx::Worker { std::shared_ptr context, const bool enableDelayedSubmission, const bool enableFuture); /** - * @brief Populate the Python future pool. + * @brief Populate the Python futures pool. * * To avoid taking the Python GIL for every new future required by each `ucxx::Request`, * the `ucxx::python::Worker` maintains a pool of futures that can be acquired when a new * `ucxx::Request` is created. Currently the pool has a maximum size of 100 objects, and * will refill once it goes under 50, otherwise calling this functions results in a no-op. + * + * @throws std::runtime_error if object was created with `enableFuture=false`. */ void populateFuturesPool() override; + /** + * @brief Clear the futures pool. + * + * Clear the futures pool, ensuring all references are removed and thus avoiding + * reference cycles that prevent the `ucxx::Worker` and other resources from cleaning + * up on time. + * + * This method is safe to be called even if object was created with `enableFuture=false`. + */ + void clearFuturesPool() override; + /** * @brief Get a Python future from the pool. * diff --git a/cpp/python/src/worker.cpp b/cpp/python/src/worker.cpp index 4534e952..b248c73d 100644 --- a/cpp/python/src/worker.cpp +++ b/cpp/python/src/worker.cpp @@ -71,6 +71,21 @@ void Worker::populateFuturesPool() } } +void Worker::clearFuturesPool() +{ + if (_enableFuture) { + ucxx_trace_req("ucxx::python::Worker::%s, Worker: %p, populateFuturesPool: %p", + __func__, + this, + shared_from_this().get()); + std::lock_guard lock(_futuresPoolMutex); + PyGILState_STATE state = PyGILState_Ensure(); + decltype(_futuresPool) newFuturesPool; + std::swap(_futuresPool, newFuturesPool); + PyGILState_Release(state); + } +} + std::shared_ptr<::ucxx::Future> Worker::getFuture() { if (_enableFuture) { diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 2adf8eca..a266493b 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -422,6 +422,8 @@ bool Worker::registerGenericPost(DelayedSubmissionCallbackType callback, uint64_ void Worker::populateFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); } +void Worker::clearFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); } + std::shared_ptr Worker::getFuture() { THROW_FUTURE_NOT_IMPLEMENTED(); } RequestNotifierWaitState Worker::waitRequestNotifier(uint64_t periodNs) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 53785b78..df820a83 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -157,8 +157,8 @@ def _deregister_dask_resource(resource_id): # Stop notifier thread and progress tasks if no Dask resources using # UCXX communicators are running anymore. if len(ctx._dask_resources) == 0: - ctx.stop_notifier_thread() - ctx.progress_tasks.clear() + ucxx.stop_notifier_thread() + ctx.clear_progress_tasks() def _allocate_dask_resources_tracker() -> None: diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx index 09bf68d4..434bef16 100644 --- a/python/ucxx/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/ucxx/_lib/libucxx.pyx @@ -741,6 +741,10 @@ cdef class UCXWorker(): with nogil: self._worker.get().populateFuturesPool() + def clear_python_futures_pool(self) -> None: + with nogil: + self._worker.get().clearFuturesPool() + def is_delayed_submission_enabled(self) -> bool: warnings.warn( "UCXWorker.is_delayed_submission_enabled() is deprecated and will soon " diff --git a/python/ucxx/ucxx/_lib/ucxx_api.pxd b/python/ucxx/ucxx/_lib/ucxx_api.pxd index ec8cef1a..88512a4f 100644 --- a/python/ucxx/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/ucxx/_lib/ucxx_api.pxd @@ -251,6 +251,7 @@ cdef extern from "" namespace "ucxx" nogil: ) except +raise_py_error void runRequestNotifier() except +raise_py_error void populateFuturesPool() except +raise_py_error + void clearFuturesPool() shared_ptr[Request] tagRecv( void* buffer, size_t length, diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 4a488309..6aac4346 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -23,6 +23,14 @@ logger = logging.getLogger("ucx") +ProgressTasks = dict() + + +def clear_progress_tasks(): + global ProgressTasks + ProgressTasks.clear() + + class ApplicationContext: """ The context of the Asyncio interface of UCX. @@ -40,7 +48,6 @@ def __init__( enable_python_future=None, exchange_peer_info_timeout=10.0, ): - self.progress_tasks = dict() self.notifier_thread_q = None self.notifier_thread = None self._listener_active_clients = ActiveClients() @@ -62,7 +69,7 @@ def __init__( self.start_notifier_thread() - weakref.finalize(self, self.progress_tasks.clear) + weakref.finalize(self, clear_progress_tasks) # Ensure progress even before Endpoints get created, for example to # receive messages directly on a worker after a remote endpoint @@ -194,6 +201,10 @@ def ucp_worker_info(self): def worker_address(self): return self.worker.address + def clear_progress_tasks(self) -> None: + global ProgressTasks + ProgressTasks.clear() + def start_notifier_thread(self): if self.worker.enable_python_future and self.notifier_thread is None: logger.debug("UCXX_ENABLE_PYTHON available, enabling notifier thread") @@ -456,7 +467,8 @@ def continuous_ucx_progress(self, event_loop=None): Python 3.10+) is used. """ loop = event_loop if event_loop is not None else get_event_loop() - if loop in self.progress_tasks: + global ProgressTasks + if loop in ProgressTasks: return # Progress has already been guaranteed for the current event loop if self.progress_mode == "thread": @@ -468,7 +480,7 @@ def continuous_ucx_progress(self, event_loop=None): elif self.progress_mode == "blocking": task = BlockingMode(self.worker, loop) - self.progress_tasks[loop] = task + ProgressTasks[loop] = task def get_ucp_worker(self): """Returns the underlying UCP worker handle (ucp_worker_h) diff --git a/python/ucxx/ucxx/_lib_async/notifier_thread.py b/python/ucxx/ucxx/_lib_async/notifier_thread.py index 23254a36..d906860a 100644 --- a/python/ucxx/ucxx/_lib_async/notifier_thread.py +++ b/python/ucxx/ucxx/_lib_async/notifier_thread.py @@ -47,7 +47,7 @@ def _notifierThread(event_loop, worker, q): ) if state == ucx_api.PythonRequestNotifierWaitState.Shutdown or shutdown is True: - return + break elif state == ucx_api.PythonRequestNotifierWaitState.Timeout: continue @@ -62,3 +62,7 @@ def _notifierThread(event_loop, worker, q): logger.debug("Notifier Thread Result Timeout") except Exception as e: logger.debug(f"Notifier Thread Result Exception: {e}") + + # Clear all Python futures to ensure no references are held to the + # `ucxx::Worker` that will prevent destructors from running. + worker.clear_python_futures_pool()