Skip to content

Commit

Permalink
Merge pull request #327 from rapidsai/branch-0.41
Browse files Browse the repository at this point in the history
Forward-merge branch-0.41 into branch-0.42
  • Loading branch information
GPUtester authored Nov 15, 2024
2 parents f1fd6ee + c4e8e04 commit 6b78d1f
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 9 deletions.
13 changes: 12 additions & 1 deletion cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class Worker : public Component {
[[nodiscard]] bool isFutureEnabled() const;

/**
* @brief Populate the future pool.
* @brief Populate the futures pool.
*
* To avoid taking blocking resources (such as the Python GIL) for every new future
* required by each `ucxx::Request`, the `ucxx::Worker` maintains a pool of futures
Expand All @@ -512,6 +512,17 @@ class Worker : public Component {
*/
virtual void populateFuturesPool();

/**
* @brief Clear the futures pool.
*
* Clear the futures pool, ensuring all references are removed and thus avoiding
* reference cycles that prevent the `ucxx::Worker` and other resources from cleaning
* up on time.
*
* @throws std::runtime_error if future support is not implemented.
*/
virtual void clearFuturesPool();

/**
* @brief Get a future from the pool.
*
Expand Down
15 changes: 14 additions & 1 deletion cpp/python/include/ucxx/python/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,28 @@ class Worker : public ::ucxx::Worker {
std::shared_ptr<Context> context, const bool enableDelayedSubmission, const bool enableFuture);

/**
* @brief Populate the Python future pool.
* @brief Populate the Python futures pool.
*
* To avoid taking the Python GIL for every new future required by each `ucxx::Request`,
* the `ucxx::python::Worker` maintains a pool of futures that can be acquired when a new
* `ucxx::Request` is created. Currently the pool has a maximum size of 100 objects, and
* will refill once it goes under 50, otherwise calling this functions results in a no-op.
*
* @throws std::runtime_error if object was created with `enableFuture=false`.
*/
void populateFuturesPool() override;

/**
* @brief Clear the futures pool.
*
* Clear the futures pool, ensuring all references are removed and thus avoiding
* reference cycles that prevent the `ucxx::Worker` and other resources from cleaning
* up on time.
*
* This method is safe to be called even if object was created with `enableFuture=false`.
*/
void clearFuturesPool() override;

/**
* @brief Get a Python future from the pool.
*
Expand Down
15 changes: 15 additions & 0 deletions cpp/python/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ void Worker::populateFuturesPool()
}
}

void Worker::clearFuturesPool()
{
if (_enableFuture) {
ucxx_trace_req("ucxx::python::Worker::%s, Worker: %p, populateFuturesPool: %p",
__func__,
this,
shared_from_this().get());
std::lock_guard<std::mutex> lock(_futuresPoolMutex);
PyGILState_STATE state = PyGILState_Ensure();
decltype(_futuresPool) newFuturesPool;
std::swap(_futuresPool, newFuturesPool);
PyGILState_Release(state);
}
}

std::shared_ptr<::ucxx::Future> Worker::getFuture()
{
if (_enableFuture) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ bool Worker::registerGenericPost(DelayedSubmissionCallbackType callback, uint64_

void Worker::populateFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); }

void Worker::clearFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); }

std::shared_ptr<Future> Worker::getFuture() { THROW_FUTURE_NOT_IMPLEMENTED(); }

RequestNotifierWaitState Worker::waitRequestNotifier(uint64_t periodNs)
Expand Down
4 changes: 2 additions & 2 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def _deregister_dask_resource(resource_id):
# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if len(ctx._dask_resources) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()
ucxx.stop_notifier_thread()
ctx.clear_progress_tasks()


def _allocate_dask_resources_tracker() -> None:
Expand Down
4 changes: 4 additions & 0 deletions python/ucxx/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,10 @@ cdef class UCXWorker():
with nogil:
self._worker.get().populateFuturesPool()

def clear_python_futures_pool(self) -> None:
with nogil:
self._worker.get().clearFuturesPool()

def is_delayed_submission_enabled(self) -> bool:
warnings.warn(
"UCXWorker.is_delayed_submission_enabled() is deprecated and will soon "
Expand Down
1 change: 1 addition & 0 deletions python/ucxx/ucxx/_lib/ucxx_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
) except +raise_py_error
void runRequestNotifier() except +raise_py_error
void populateFuturesPool() except +raise_py_error
void clearFuturesPool()
shared_ptr[Request] tagRecv(
void* buffer,
size_t length,
Expand Down
20 changes: 16 additions & 4 deletions python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
logger = logging.getLogger("ucx")


ProgressTasks = dict()


def clear_progress_tasks():
global ProgressTasks
ProgressTasks.clear()


class ApplicationContext:
"""
The context of the Asyncio interface of UCX.
Expand All @@ -40,7 +48,6 @@ def __init__(
enable_python_future=None,
exchange_peer_info_timeout=10.0,
):
self.progress_tasks = dict()
self.notifier_thread_q = None
self.notifier_thread = None
self._listener_active_clients = ActiveClients()
Expand All @@ -62,7 +69,7 @@ def __init__(

self.start_notifier_thread()

weakref.finalize(self, self.progress_tasks.clear)
weakref.finalize(self, clear_progress_tasks)

# Ensure progress even before Endpoints get created, for example to
# receive messages directly on a worker after a remote endpoint
Expand Down Expand Up @@ -194,6 +201,10 @@ def ucp_worker_info(self):
def worker_address(self):
return self.worker.address

def clear_progress_tasks(self) -> None:
global ProgressTasks
ProgressTasks.clear()

def start_notifier_thread(self):
if self.worker.enable_python_future and self.notifier_thread is None:
logger.debug("UCXX_ENABLE_PYTHON available, enabling notifier thread")
Expand Down Expand Up @@ -456,7 +467,8 @@ def continuous_ucx_progress(self, event_loop=None):
Python 3.10+) is used.
"""
loop = event_loop if event_loop is not None else get_event_loop()
if loop in self.progress_tasks:
global ProgressTasks
if loop in ProgressTasks:
return # Progress has already been guaranteed for the current event loop

if self.progress_mode == "thread":
Expand All @@ -468,7 +480,7 @@ def continuous_ucx_progress(self, event_loop=None):
elif self.progress_mode == "blocking":
task = BlockingMode(self.worker, loop)

self.progress_tasks[loop] = task
ProgressTasks[loop] = task

def get_ucp_worker(self):
"""Returns the underlying UCP worker handle (ucp_worker_h)
Expand Down
6 changes: 5 additions & 1 deletion python/ucxx/ucxx/_lib_async/notifier_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _notifierThread(event_loop, worker, q):
)

if state == ucx_api.PythonRequestNotifierWaitState.Shutdown or shutdown is True:
return
break
elif state == ucx_api.PythonRequestNotifierWaitState.Timeout:
continue

Expand All @@ -62,3 +62,7 @@ def _notifierThread(event_loop, worker, q):
logger.debug("Notifier Thread Result Timeout")
except Exception as e:
logger.debug(f"Notifier Thread Result Exception: {e}")

# Clear all Python futures to ensure no references are held to the
# `ucxx::Worker` that will prevent destructors from running.
worker.clear_python_futures_pool()

0 comments on commit 6b78d1f

Please sign in to comment.