Skip to content

Commit

Permalink
Merge branch 'branch-0.41' into python-blocking-progress-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev authored Nov 18, 2024
2 parents c317f82 + 042261d commit 546905d
Show file tree
Hide file tree
Showing 29 changed files with 412 additions and 175 deletions.
1 change: 1 addition & 0 deletions ci/build_wheel_distributed_ucxx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ package_dir="python/distributed-ucxx"
RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})"

./ci/build_wheel.sh distributed-ucxx "${package_dir}"
./ci/validate_wheel.sh "${package_dir}" dist

RAPIDS_PY_WHEEL_NAME="distributed_ucxx_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 python "${package_dir}/dist"
2 changes: 2 additions & 0 deletions ci/build_wheel_libucxx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ python -m auditwheel repair \
-w "${package_dir}/final_dist" \
${package_dir}/dist/*

./ci/validate_wheel.sh "${package_dir}" final_dist

RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 cpp "${package_dir}/final_dist"
2 changes: 2 additions & 0 deletions ci/build_wheel_ucxx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ python -m auditwheel repair \
-w "${package_dir}/final_dist" \
${package_dir}/dist/*

./ci/validate_wheel.sh "${package_dir}" final_dist

RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 python "${package_dir}/final_dist"
4 changes: 2 additions & 2 deletions ci/test_common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ run_py_tests_async() {
ENABLE_PYTHON_FUTURE=$3
SKIP=$4

CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50"
CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 30m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow"

if [ $SKIP -ne 0 ]; then
echo -e "\e[1;33mSkipping unstable test: ${CMD_LINE}\e[0m"
else
log_command "${CMD_LINE}"
UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50
UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 30m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow
fi
}

Expand Down
21 changes: 21 additions & 0 deletions ci/validate_wheel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Validate the wheel(s) of a package with 'pydistcheck' and 'twine'.
#
# Usage: validate_wheel.sh <package_dir> <wheel_dir_relative_path>
#   package_dir              directory to change into before validating
#   wheel_dir_relative_path  directory containing the *.whl files, relative
#                            to package_dir

set -euo pipefail

package_dir=$1
wheel_dir_relative_path=$2

cd "${package_dir}"

# Expand the glob into an array so that every wheel is passed as its own
# argument, even when several wheels exist or a path contains whitespace.
# 'nullglob' turns "no match" into an empty array instead of a literal pattern.
shopt -s nullglob
wheels=("${wheel_dir_relative_path}"/*.whl)
shopt -u nullglob

if [[ ${#wheels[@]} -eq 0 ]]; then
  echo "No wheels found matching '${wheel_dir_relative_path}/*.whl' in '${package_dir}'" >&2
  exit 1
fi

rapids-logger "validate packages with 'pydistcheck'"

pydistcheck \
  --inspect \
  "${wheels[@]}"

rapids-logger "validate packages with 'twine'"

twine check \
  --strict \
  "${wheels[@]}"
2 changes: 1 addition & 1 deletion cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Request : public Component {
/**
* @brief Cancel the request.
*
* Cancel the request. Often called by the an error handler or parent's object
* Cancel the request. Often called by the error handler or parent's object
* destructor but may be called by the user to cancel the request as well.
*/
virtual void cancel();
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/ucxx/request_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ class RequestAm : public Request {
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

/**
* @brief Cancel the request.
*
* Cancel the request. Often called by the error handler or parent's object
* destructor but may be called by the user to cancel the request as well.
*/
void cancel() override;

void populateDelayedSubmission() override;

/**
Expand Down
13 changes: 12 additions & 1 deletion cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class Worker : public Component {
[[nodiscard]] bool isFutureEnabled() const;

/**
* @brief Populate the future pool.
* @brief Populate the futures pool.
*
* To avoid taking blocking resources (such as the Python GIL) for every new future
* required by each `ucxx::Request`, the `ucxx::Worker` maintains a pool of futures
Expand All @@ -512,6 +512,17 @@ class Worker : public Component {
*/
virtual void populateFuturesPool();

/**
* @brief Clear the futures pool.
*
* Clear the futures pool, ensuring all references are removed and thus avoiding
* reference cycles that prevent the `ucxx::Worker` and other resources from cleaning
* up on time.
*
* @throws std::runtime_error if future support is not implemented.
*/
virtual void clearFuturesPool();

/**
* @brief Get a future from the pool.
*
Expand Down
15 changes: 14 additions & 1 deletion cpp/python/include/ucxx/python/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,28 @@ class Worker : public ::ucxx::Worker {
std::shared_ptr<Context> context, const bool enableDelayedSubmission, const bool enableFuture);

/**
* @brief Populate the Python future pool.
* @brief Populate the Python futures pool.
*
* To avoid taking the Python GIL for every new future required by each `ucxx::Request`,
* the `ucxx::python::Worker` maintains a pool of futures that can be acquired when a new
* `ucxx::Request` is created. Currently the pool has a maximum size of 100 objects, and
* will refill once it goes under 50, otherwise calling this function results in a no-op.
*
* @throws std::runtime_error if object was created with `enableFuture=false`.
*/
void populateFuturesPool() override;

/**
* @brief Clear the futures pool.
*
* Clear the futures pool, ensuring all references are removed and thus avoiding
* reference cycles that prevent the `ucxx::Worker` and other resources from cleaning
* up on time.
*
* This method is safe to call even if the object was created with `enableFuture=false`.
*/
void clearFuturesPool() override;

/**
* @brief Get a Python future from the pool.
*
Expand Down
15 changes: 15 additions & 0 deletions cpp/python/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ void Worker::populateFuturesPool()
}
}

/**
 * Clear the Python futures pool.
 *
 * Swaps the pool with an empty one and destroys the previous contents while both
 * `_futuresPoolMutex` and the Python GIL are held. This is a no-op when the worker
 * was created with `enableFuture=false`.
 */
void Worker::clearFuturesPool()
{
  if (_enableFuture) {
    ucxx_trace_req("ucxx::python::Worker::%s, Worker: %p, clearFuturesPool: %p",
                   __func__,
                   this,
                   shared_from_this().get());
    std::lock_guard<std::mutex> lock(_futuresPoolMutex);
    PyGILState_STATE state = PyGILState_Ensure();
    {
      // The inner scope makes the drained pool go out of scope, and thus release its
      // elements, BEFORE the GIL is released below.
      // NOTE(review): presumably the pooled futures hold Python references whose
      // release requires the GIL -- confirm against the future's destructor.
      decltype(_futuresPool) drainedPool;
      std::swap(_futuresPool, drainedPool);
    }
    PyGILState_Release(state);
  }
}

std::shared_ptr<::ucxx::Future> Worker::getFuture()
{
if (_enableFuture) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts)
if (_endpointErrorHandling)
param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE};

auto worker = ::ucxx::getWorker(_parent);
ucs_status_ptr_t status;
auto worker = ::ucxx::getWorker(_parent);
ucs_status_ptr_t status = nullptr;

if (worker->isProgressThreadRunning()) {
bool closeSuccess = false;
Expand Down
55 changes: 46 additions & 9 deletions cpp/src/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,26 @@ RequestAm::RequestAm(std::shared_ptr<Component> endpointOrWorker,
requestData);
}

/**
 * Cancel the Active Message request.
 *
 * With `_mutex` held, a request whose status is still `UCS_INPROGRESS` has its status
 * set to `UCS_ERR_CANCELED`; for any other status only a trace message is emitted.
 */
void RequestAm::cancel()
{
  std::lock_guard<std::recursive_mutex> lock(_mutex);

  if (_status != UCS_INPROGRESS) {
    // Nothing left to cancel, only report the status the request already has.
    ucxx_trace_req_f(_ownerString.c_str(),
                     this,
                     _request,
                     _operationName.c_str(),
                     "already completed with status: %d (%s)",
                     _status,
                     ucs_status_string(_status));
    return;
  }

  // AM requests do not use `_request`, thus `ucp_request_cancel()` cannot cancel
  // them; setting the status directly is what makes them cancelable.
  setStatus(UCS_ERR_CANCELED);
}

static void _amSendCallback(void* request, ucs_status_t status, void* user_data)
{
Request* req = reinterpret_cast<Request*>(user_data);
Expand Down Expand Up @@ -248,19 +268,29 @@ ucs_status_t RequestAm::recvCallback(void* arg,
amHeader.memoryType = UCS_MEMORY_TYPE_HOST;
}

std::shared_ptr<Buffer> buf = amData->_allocators.at(amHeader.memoryType)(length);
try {
buf = amData->_allocators.at(amHeader.memoryType)(length);
} catch (const std::exception& e) {
ucxx_debug("Exception calling allocator: %s", e.what());
}

auto recvAmMessage =
std::make_shared<internal::RecvAmMessage>(amData, ep, req, buf, receiverCallback);

ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL,
.cb = {.recv_am = _recvCompletedCallback},
.user_data = recvAmMessage.get()};
ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_USER_DATA |
UCP_OP_ATTR_FLAG_NO_IMM_CMPL,
.cb = {.recv_am = _recvCompletedCallback},
.user_data = recvAmMessage.get()};

if (buf == nullptr) {
ucxx_debug("Failed to allocate %lu bytes of memory", length);
recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY);
return UCS_ERR_NO_MEMORY;
}

ucs_status_ptr_t status =
ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param);
ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam);

if (req->_enablePythonFuture)
ucxx_trace_req_f(ownerString.c_str(),
Expand Down Expand Up @@ -302,7 +332,15 @@ ucs_status_t RequestAm::recvCallback(void* arg,
return UCS_INPROGRESS;
}
} else {
std::shared_ptr<Buffer> buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);
buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);

internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback);
if (buf == nullptr) {
ucxx_debug("Failed to allocate %lu bytes of memory", length);
recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY);
return UCS_ERR_NO_MEMORY;
}

if (length > 0) memcpy(buf->data(), data, length);

if (req->_enablePythonFuture)
Expand All @@ -326,7 +364,6 @@ ucs_status_t RequestAm::recvCallback(void* arg,
buf->data(),
length);

internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback);
recvAmMessage.callback(nullptr, UCS_OK);
return UCS_OK;
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ bool Worker::registerGenericPost(DelayedSubmissionCallbackType callback, uint64_

// Base-class stubs for the futures API: each one only invokes
// THROW_FUTURE_NOT_IMPLEMENTED().
// NOTE(review): the macro is defined elsewhere; judging by its name it throws because
// the base worker has no future support -- confirm against its definition.

// Populate the futures pool (stub).
void Worker::populateFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); }

// Clear the futures pool (stub).
void Worker::clearFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); }

// Get a future from the pool (stub); there is no return statement after the macro.
std::shared_ptr<Future> Worker::getFuture() { THROW_FUTURE_NOT_IMPLEMENTED(); }

RequestNotifierWaitState Worker::waitRequestNotifier(uint64_t periodNs)
Expand Down
36 changes: 30 additions & 6 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def _deregister_dask_resource(resource_id):
# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if len(ctx._dask_resources) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()
ucxx.stop_notifier_thread()
ctx.clear_progress_tasks()


def _allocate_dask_resources_tracker() -> None:
Expand Down Expand Up @@ -307,6 +307,25 @@ def _close_comm(ref):
comm._closed = True


def _finalizer(endpoint: ucxx.Endpoint, resource_id: int) -> None:
    """UCXX comms object finalizer.

    Attempt to close the UCXX endpoint if it's still alive, and deregister Dask
    resource. The resource is deregistered even if aborting the endpoint raises.

    Parameters
    ----------
    endpoint: ucxx.Endpoint
        The endpoint to close, or `None` if there is no endpoint to abort.
    resource_id: int
        The unique ID of the resource returned by `_register_dask_resource` upon
        registration.
    """
    try:
        if endpoint is not None:
            endpoint.abort()
    finally:
        # Always deregister: a failure while aborting the endpoint must not leave
        # the resource registered.
        _deregister_dask_resource(resource_id)


class UCXX(Comm):
"""Comm object using UCXX.
Expand Down Expand Up @@ -375,14 +394,19 @@ def __init__( # type: ignore[no-untyped-def]
else:
self._has_close_callback = False

self._resource_id = _register_dask_resource()
resource_id = _register_dask_resource()
self._resource_id = resource_id

logger.debug("UCX.__init__ %s", self)

weakref.finalize(self, _deregister_dask_resource, self._resource_id)
weakref.finalize(self, _finalizer, ep, resource_id)

def __del__(self) -> None:
self.abort()
def abort(self):
    """Abort the comm.

    Flags the comm as closed, aborts the endpoint when one is still attached and
    drops the reference to it, then deregisters this comm's Dask resource.
    """
    self._closed = True
    endpoint = self._ep
    if endpoint is not None:
        endpoint.abort()
        # Drop the reference only after the endpoint was aborted.
        self._ep = None
    _deregister_dask_resource(self._resource_id)

@property
def local_address(self) -> str:
Expand Down
8 changes: 8 additions & 0 deletions python/distributed-ucxx/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ matrix-entry = "cuda_suffixed=true"
[tool.setuptools.dynamic]
version = {file = "distributed_ucxx/VERSION"}

[tool.pydistcheck]
select = [
"distro-too-large-compressed",
]

# PyPI limit is 100 MiB, fail CI before we get too close to that
max_allowed_size_compressed = '75M'

[tool.pytest.ini_options]
markers = [
"ignore_alive_references",
Expand Down
Loading

0 comments on commit 546905d

Please sign in to comment.