diff --git a/ci/build_wheel_distributed_ucxx.sh b/ci/build_wheel_distributed_ucxx.sh index f6ee95a7..512e36bd 100755 --- a/ci/build_wheel_distributed_ucxx.sh +++ b/ci/build_wheel_distributed_ucxx.sh @@ -8,5 +8,6 @@ package_dir="python/distributed-ucxx" RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" ./ci/build_wheel.sh distributed-ucxx "${package_dir}" +./ci/validate_wheel.sh "${package_dir}" dist RAPIDS_PY_WHEEL_NAME="distributed_ucxx_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 python "${package_dir}/dist" diff --git a/ci/build_wheel_libucxx.sh b/ci/build_wheel_libucxx.sh index e9262077..4ae818b5 100755 --- a/ci/build_wheel_libucxx.sh +++ b/ci/build_wheel_libucxx.sh @@ -37,4 +37,6 @@ python -m auditwheel repair \ -w "${package_dir}/final_dist" \ ${package_dir}/dist/* +./ci/validate_wheel.sh "${package_dir}" final_dist + RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 cpp "${package_dir}/final_dist" diff --git a/ci/build_wheel_ucxx.sh b/ci/build_wheel_ucxx.sh index 78602085..b8e60c4c 100755 --- a/ci/build_wheel_ucxx.sh +++ b/ci/build_wheel_ucxx.sh @@ -28,4 +28,6 @@ python -m auditwheel repair \ -w "${package_dir}/final_dist" \ ${package_dir}/dist/* +./ci/validate_wheel.sh "${package_dir}" final_dist + RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 python "${package_dir}/final_dist" diff --git a/ci/test_common.sh b/ci/test_common.sh index 48592e36..d6f2754a 100755 --- a/ci/test_common.sh +++ b/ci/test_common.sh @@ -128,13 +128,13 @@ run_py_tests_async() { ENABLE_PYTHON_FUTURE=$3 SKIP=$4 - CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50" + CMD_LINE="UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 30m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow" if [ $SKIP -ne 0 ]; then echo -e "\e[1;33mSkipping unstable test: ${CMD_LINE}\e[0m" else log_command "${CMD_LINE}" - UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 20m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --durations=50 + UCXPY_PROGRESS_MODE=${PROGRESS_MODE} UCXPY_ENABLE_DELAYED_SUBMISSION=${ENABLE_DELAYED_SUBMISSION} UCXPY_ENABLE_PYTHON_FUTURE=${ENABLE_PYTHON_FUTURE} timeout 30m python -m pytest -vs python/ucxx/ucxx/_lib_async/tests/ --runslow fi } diff --git a/ci/validate_wheel.sh b/ci/validate_wheel.sh new file mode 100755 index 00000000..5910a5c5 --- /dev/null +++ b/ci/validate_wheel.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. + +set -euo pipefail + +package_dir=$1 +wheel_dir_relative_path=$2 + +cd "${package_dir}" + +rapids-logger "validate packages with 'pydistcheck'" + +pydistcheck \ + --inspect \ + "$(echo ${wheel_dir_relative_path}/*.whl)" + +rapids-logger "validate packages with 'twine'" + +twine check \ + --strict \ + "$(echo ${wheel_dir_relative_path}/*.whl)" diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 2584dda5..0a7b1505 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -120,7 +120,7 @@ class Request : public Component { /** * @brief Cancel the request. * - * Cancel the request. 
Often called by the an error handler or parent's object + * Cancel the request. Often called by the error handler or parent's object * destructor but may be called by the user to cancel the request as well. */ virtual void cancel(); diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index fb5cfd96..b7937c0c 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -96,6 +96,14 @@ class RequestAm : public Request { RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); + /** + * @brief Cancel the request. + * + * Cancel the request. Often called by the error handler or parent's object + * destructor but may be called by the user to cancel the request as well. + */ + void cancel() override; + void populateDelayedSubmission() override; /** diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index c9808d0b..30fd1220 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -500,7 +500,7 @@ class Worker : public Component { [[nodiscard]] bool isFutureEnabled() const; /** - * @brief Populate the future pool. + * @brief Populate the futures pool. * * To avoid taking blocking resources (such as the Python GIL) for every new future * required by each `ucxx::Request`, the `ucxx::Worker` maintains a pool of futures @@ -512,6 +512,17 @@ class Worker : public Component { */ virtual void populateFuturesPool(); + /** + * @brief Clear the futures pool. + * + * Clear the futures pool, ensuring all references are removed and thus avoiding + * reference cycles that prevent the `ucxx::Worker` and other resources from cleaning + * up on time. + * + * @throws std::runtime_error if future support is not implemented. + */ + virtual void clearFuturesPool(); + /** * @brief Get a future from the pool. * diff --git a/cpp/python/include/ucxx/python/worker.h b/cpp/python/include/ucxx/python/worker.h index be4b77fd..3140305a 100644 --- a/cpp/python/include/ucxx/python/worker.h +++ b/cpp/python/include/ucxx/python/worker.h @@ -86,15 +86,28 @@ class Worker : public ::ucxx::Worker { std::shared_ptr context, const bool enableDelayedSubmission, const bool enableFuture); /** - * @brief Populate the Python future pool. + * @brief Populate the Python futures pool. * * To avoid taking the Python GIL for every new future required by each `ucxx::Request`, * the `ucxx::python::Worker` maintains a pool of futures that can be acquired when a new * `ucxx::Request` is created. Currently the pool has a maximum size of 100 objects, and * will refill once it goes under 50, otherwise calling this functions results in a no-op. + * + * @throws std::runtime_error if object was created with `enableFuture=false`. */ void populateFuturesPool() override; + /** + * @brief Clear the futures pool. + * + * Clear the futures pool, ensuring all references are removed and thus avoiding + * reference cycles that prevent the `ucxx::Worker` and other resources from cleaning + * up on time. + * + * This method is safe to be called even if object was created with `enableFuture=false`. + */ + void clearFuturesPool() override; + /** * @brief Get a Python future from the pool. 
* diff --git a/cpp/python/src/worker.cpp b/cpp/python/src/worker.cpp index 4534e952..b248c73d 100644 --- a/cpp/python/src/worker.cpp +++ b/cpp/python/src/worker.cpp @@ -71,6 +71,21 @@ void Worker::populateFuturesPool() } } +void Worker::clearFuturesPool() +{ + if (_enableFuture) { + ucxx_trace_req("ucxx::python::Worker::%s, Worker: %p, populateFuturesPool: %p", + __func__, + this, + shared_from_this().get()); + std::lock_guard lock(_futuresPoolMutex); + PyGILState_STATE state = PyGILState_Ensure(); + decltype(_futuresPool) newFuturesPool; + std::swap(_futuresPool, newFuturesPool); + PyGILState_Release(state); + } +} + std::shared_ptr<::ucxx::Future> Worker::getFuture() { if (_enableFuture) { diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 4731b78d..5f5870c3 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -262,8 +262,8 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) if (_endpointErrorHandling) param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE}; - auto worker = ::ucxx::getWorker(_parent); - ucs_status_ptr_t status; + auto worker = ::ucxx::getWorker(_parent); + ucs_status_ptr_t status = nullptr; if (worker->isProgressThreadRunning()) { bool closeSuccess = false; diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index 713f6bdc..de29de0d 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -154,6 +154,26 @@ RequestAm::RequestAm(std::shared_ptr endpointOrWorker, requestData); } +void RequestAm::cancel() +{ + std::lock_guard lock(_mutex); + if (_status == UCS_INPROGRESS) { + /** + * This is needed to ensure AM requests are cancelable, since they do not + * use the `_request`, thus `ucp_request_cancel()` cannot cancel them. + */ + setStatus(UCS_ERR_CANCELED); + } else { + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "already completed with status: %d (%s)", + _status, + ucs_status_string(_status)); + } +} + static void _amSendCallback(void* request, ucs_status_t status, void* user_data) { Request* req = reinterpret_cast(user_data); @@ -248,19 +268,29 @@ ucs_status_t RequestAm::recvCallback(void* arg, amHeader.memoryType = UCS_MEMORY_TYPE_HOST; } - std::shared_ptr buf = amData->_allocators.at(amHeader.memoryType)(length); + try { + buf = amData->_allocators.at(amHeader.memoryType)(length); + } catch (const std::exception& e) { + ucxx_debug("Exception calling allocator: %s", e.what()); + } auto recvAmMessage = std::make_shared(amData, ep, req, buf, receiverCallback); - ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_USER_DATA | - UCP_OP_ATTR_FLAG_NO_IMM_CMPL, - .cb = {.recv_am = _recvCompletedCallback}, - .user_data = recvAmMessage.get()}; + ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | + UCP_OP_ATTR_FLAG_NO_IMM_CMPL, + .cb = {.recv_am = _recvCompletedCallback}, + .user_data = recvAmMessage.get()}; + + if (buf == nullptr) { + ucxx_debug("Failed to allocate %lu bytes of memory", length); + recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY); + return UCS_ERR_NO_MEMORY; + } ucs_status_ptr_t status = - ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param); + ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam); if (req->_enablePythonFuture) ucxx_trace_req_f(ownerString.c_str(), @@ -302,7 +332,15 @@ ucs_status_t RequestAm::recvCallback(void* arg, return UCS_INPROGRESS; 
} } else { - std::shared_ptr buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); + buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); + + internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback); + if (buf == nullptr) { + ucxx_debug("Failed to allocate %lu bytes of memory", length); + recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY); + return UCS_ERR_NO_MEMORY; + } + if (length > 0) memcpy(buf->data(), data, length); if (req->_enablePythonFuture) @@ -326,7 +364,6 @@ ucs_status_t RequestAm::recvCallback(void* arg, buf->data(), length); - internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback); recvAmMessage.callback(nullptr, UCS_OK); return UCS_OK; } diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 2adf8eca..a266493b 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -422,6 +422,8 @@ bool Worker::registerGenericPost(DelayedSubmissionCallbackType callback, uint64_ void Worker::populateFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); } +void Worker::clearFuturesPool() { THROW_FUTURE_NOT_IMPLEMENTED(); } + std::shared_ptr Worker::getFuture() { THROW_FUTURE_NOT_IMPLEMENTED(); } RequestNotifierWaitState Worker::waitRequestNotifier(uint64_t periodNs) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 1f5fc1df..df820a83 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -157,8 +157,8 @@ def _deregister_dask_resource(resource_id): # Stop notifier thread and progress tasks if no Dask resources using # UCXX communicators are running anymore. if len(ctx._dask_resources) == 0: - ctx.stop_notifier_thread() - ctx.progress_tasks.clear() + ucxx.stop_notifier_thread() + ctx.clear_progress_tasks() def _allocate_dask_resources_tracker() -> None: @@ -307,6 +307,25 @@ def _close_comm(ref): comm._closed = True +def _finalizer(endpoint: ucxx.Endpoint, resource_id: int) -> None: + """UCXX comms object finalizer. + + Attempt to close the UCXX endpoint if it's still alive, and deregister Dask + resource. + + Parameters + ---------- + endpoint: ucx_api.UCXEndpoint + The endpoint to close. + resource_id: int + The unique ID of the resource returned by `_register_dask_resource` upon + registration. + """ + if endpoint is not None: + endpoint.abort() + _deregister_dask_resource(resource_id) + + class UCXX(Comm): """Comm object using UCXX. 
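The `_finalizer` added above replaces `UCXX.__del__`: it is registered below via `weakref.finalize(self, _finalizer, ep, resource_id)`, and because it receives plain values rather than `self` or a bound method, registering it cannot keep the comm object alive. A minimal, hypothetical sketch of that pattern follows; the names (`_open_resources`, `_FakeEndpoint`, `Comm`) are illustrative only and are not the distributed-ucxx API.

```python
import weakref

# Hypothetical registry standing in for the Dask resource tracker used above.
_open_resources = set()


class _FakeEndpoint:
    def abort(self):
        pass


def _finalizer(endpoint, resource_id):
    # Module-level callback receiving plain values: it owns no reference back
    # to the comm object, so registering it cannot keep that object alive.
    if endpoint is not None:
        endpoint.abort()
    _open_resources.discard(resource_id)


class Comm:
    def __init__(self, endpoint, resource_id):
        self._ep = endpoint
        self._resource_id = resource_id
        _open_resources.add(resource_id)
        self._finalize = weakref.finalize(self, _finalizer, endpoint, resource_id)

    def abort(self):
        # Eager close: run the finalizer now; it is a no-op if it already ran.
        self._finalize()
        self._ep = None


comm = Comm(_FakeEndpoint(), resource_id=1)
comm.abort()
assert 1 not in _open_resources  # released exactly once, even after later GC
```

The real `abort()` in the hunk below still deregisters the Dask resource explicitly; the finalizer only acts as the safety net for comm objects that are garbage collected without being closed.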
@@ -375,14 +394,19 @@ def __init__( # type: ignore[no-untyped-def] else: self._has_close_callback = False - self._resource_id = _register_dask_resource() + resource_id = _register_dask_resource() + self._resource_id = resource_id logger.debug("UCX.__init__ %s", self) - weakref.finalize(self, _deregister_dask_resource, self._resource_id) + weakref.finalize(self, _finalizer, ep, resource_id) - def __del__(self) -> None: - self.abort() + def abort(self): + self._closed = True + if self._ep is not None: + self._ep.abort() + self._ep = None + _deregister_dask_resource(self._resource_id) @property def local_address(self) -> str: diff --git a/python/distributed-ucxx/pyproject.toml b/python/distributed-ucxx/pyproject.toml index 1005df8d..da7291cf 100644 --- a/python/distributed-ucxx/pyproject.toml +++ b/python/distributed-ucxx/pyproject.toml @@ -125,6 +125,14 @@ matrix-entry = "cuda_suffixed=true" [tool.setuptools.dynamic] version = {file = "distributed_ucxx/VERSION"} +[tool.pydistcheck] +select = [ + "distro-too-large-compressed", +] + +# PyPI limit is 100 MiB, fail CI before we get too close to that +max_allowed_size_compressed = '75M' + [tool.pytest.ini_options] markers = [ "ignore_alive_references", diff --git a/python/libucxx/libucxx/load.py b/python/libucxx/libucxx/load.py index 4d49d2cd..be514139 100644 --- a/python/libucxx/libucxx/load.py +++ b/python/libucxx/libucxx/load.py @@ -15,8 +15,35 @@ import ctypes import os +# Loading with RTLD_LOCAL adds the library itself to the loader's +# loaded library cache without loading any symbols into the global +# namespace. This allows libraries that express a dependency on +# this library to be loaded later and successfully satisfy this dependency +# without polluting the global symbol table with symbols from +# libucxx that could conflict with symbols from other DSOs. +PREFERRED_LOAD_FLAG = ctypes.RTLD_LOCAL + + +def _load_system_installation(soname: str): + """Try to dlopen() the library indicated by ``soname`` + + Raises ``OSError`` if library cannot be loaded. + """ + return ctypes.CDLL(soname, PREFERRED_LOAD_FLAG) + + +def _load_wheel_installation(soname: str): + """Try to dlopen() the library indicated by ``soname`` + + Returns ``None`` if the library cannot be loaded. + """ + if os.path.isfile(lib := os.path.join(os.path.dirname(__file__), "lib64", soname)): + return ctypes.CDLL(lib, PREFERRED_LOAD_FLAG) + return None + def load_library(): + """Dynamically load libucxx.so and its dependencies""" # If libucx was installed as a wheel, we must request it to load the library # symbols. Otherwise, we assume that the library was installed in a system path # that ld can find. @@ -28,25 +55,33 @@ def load_library(): libucx.load_library() del libucx - # Dynamically load libucxx.so. Prefer a system library if one is present to - # avoid clobbering symbols that other packages might expect, but if no - # other library is present use the one in the wheel. + prefer_system_installation = ( + os.getenv("RAPIDS_LIBUCXX_PREFER_SYSTEM_LIBRARY", "false").lower() != "false" + ) + + soname = "libucxx.so" libucxx_lib = None - try: - libucxx_lib = ctypes.CDLL("libucxx.so", ctypes.RTLD_GLOBAL) - except OSError: - # If neither of these directories contain the library, we assume we are in an - # environment where the C++ library is already installed somewhere else and the - # CMake build of the libucxx Python package was a no-op. 
Note that this approach - # won't work for real editable installs of the libucxx package, but that's not a - # use case I think we need to support. scikit-build-core has limited support for - # importlib.resources so there isn't a clean way to support that case yet. - for lib_dir in ("lib", "lib64"): - if os.path.isfile( - lib := os.path.join(os.path.dirname(__file__), lib_dir, "libucxx.so") - ): - libucxx_lib = ctypes.CDLL(lib, ctypes.RTLD_GLOBAL) - break + if prefer_system_installation: + # Prefer a system library if one is present to + # avoid clobbering symbols that other packages might expect, but if no + # other library is present use the one in the wheel. + try: + libucxx_lib = _load_system_installation(soname) + except OSError: + libucxx_lib = _load_wheel_installation(soname) + else: + # Prefer the libraries bundled in this package. If they aren't found + # (which might be the case in builds where the library was prebuilt + # before packaging the wheel), look for a system installation. + try: + libucxx_lib = _load_wheel_installation(soname) + if libucxx_lib is None: + libucxx_lib = _load_system_installation(soname) + except OSError: + # If none of the searches above succeed, just silently return None + # and rely on other mechanisms (like RPATHs on other DSOs) to + # help the loader find the library. + pass # The caller almost never needs to do anything with this library, but no # harm in offering the option since this object at least provides a handle diff --git a/python/libucxx/pyproject.toml b/python/libucxx/pyproject.toml index e035effb..8309badf 100644 --- a/python/libucxx/pyproject.toml +++ b/python/libucxx/pyproject.toml @@ -65,3 +65,11 @@ requires = [ "libucx==1.15.0", "ninja", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
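The `load_library` rewrite above now defaults to the wheel-bundled `libucxx.so` (loaded with `RTLD_LOCAL`), falling back to a system copy, and honours `RAPIDS_LIBUCXX_PREFER_SYSTEM_LIBRARY` to reverse that order. A hedged usage sketch follows; the `libucxx.load` import path is inferred from the file location above, and the environment variable only needs to be set before `load_library()` is called.

```python
import os

# Any value other than "false" (case-insensitive) opts in to the system
# installation; unset or "false" keeps the wheel-bundled library preferred.
os.environ["RAPIDS_LIBUCXX_PREFER_SYSTEM_LIBRARY"] = "true"

from libucxx.load import load_library

lib = load_library()
# `lib` is a ctypes.CDLL handle on success, or None when neither a wheel nor
# a system copy could be found and resolution is left to RPATHs on dependent
# extension modules.
print(lib)
```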
+ +[tool.pydistcheck] +select = [ + "distro-too-large-compressed", +] + +# PyPI limit is 100 MiB, fail CI before we get too close to that +max_allowed_size_compressed = '75M' diff --git a/python/ucxx/pyproject.toml b/python/ucxx/pyproject.toml index c548e85e..730fc0cb 100644 --- a/python/ucxx/pyproject.toml +++ b/python/ucxx/pyproject.toml @@ -79,3 +79,11 @@ wheel.exclude = ["*.pyx", "CMakeLists.txt"] provider = "scikit_build_core.metadata.regex" input = "ucxx/VERSION" regex = "(?P.*)" + +[tool.pydistcheck] +select = [ + "distro-too-large-compressed", +] + +# PyPI limit is 100 MiB, fail CI before we get too close to that +max_allowed_size_compressed = '75M' diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx index 2455a781..434bef16 100644 --- a/python/ucxx/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/ucxx/_lib/libucxx.pyx @@ -95,7 +95,7 @@ def _get_host_buffer(uintptr_t recv_buffer_ptr): return np.asarray(HostBufferAdapter._from_host_buffer(host_buffer)) -cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length): +cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length) noexcept nogil: cdef shared_ptr[RMMBuffer] rmm_buffer = make_shared[RMMBuffer](length) return dynamic_pointer_cast[Buffer, RMMBuffer](rmm_buffer) @@ -741,6 +741,10 @@ cdef class UCXWorker(): with nogil: self._worker.get().populateFuturesPool() + def clear_python_futures_pool(self) -> None: + with nogil: + self._worker.get().clearFuturesPool() + def is_delayed_submission_enabled(self) -> bool: warnings.warn( "UCXWorker.is_delayed_submission_enabled() is deprecated and will soon " diff --git a/python/ucxx/ucxx/_lib/ucxx_api.pxd b/python/ucxx/ucxx/_lib/ucxx_api.pxd index 9c30a4c3..88512a4f 100644 --- a/python/ucxx/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/ucxx/_lib/ucxx_api.pxd @@ -155,7 +155,7 @@ cdef extern from "" namespace "ucxx" nogil: void* data() except +raise_py_error cdef cppclass RMMBuffer: - RMMBuffer(const size_t size_t) + RMMBuffer(const size_t size_t) except +raise_py_error BufferType getType() size_t getSize() unique_ptr[device_buffer] release() except +raise_py_error @@ -251,6 +251,7 @@ cdef extern from "" namespace "ucxx" nogil: ) except +raise_py_error void runRequestNotifier() except +raise_py_error void populateFuturesPool() except +raise_py_error + void clearFuturesPool() shared_ptr[Request] tagRecv( void* buffer, size_t length, diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index a146b309..29d14f93 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -23,6 +23,14 @@ logger = logging.getLogger("ucx") +ProgressTasks = dict() + + +def clear_progress_tasks(): + global ProgressTasks + ProgressTasks.clear() + + class ApplicationContext: """ The context of the Asyncio interface of UCX. 
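Moving the progress-task registry to the module-level `ProgressTasks` dict above lets the `weakref.finalize(self, clear_progress_tasks)` call register a plain, argument-free function. That matters because CPython will never collect an object while its registered finalizer callback or arguments still reference it, directly or indirectly; a callback bound to the instance, or to state that points back at it, is exactly such a reference. A small, self-contained demonstration of the difference, with illustrative names only:

```python
import gc
import weakref

# Module-level registry, analogous to ProgressTasks above.
_tasks = {}


def _clear_tasks():
    _tasks.clear()


class GoodContext:
    def __init__(self):
        _tasks["loop"] = object()
        # Callback owns no reference to `self`, so `self` can be collected.
        self._fin = weakref.finalize(self, _clear_tasks)


class LeakyContext:
    def __init__(self):
        _tasks["loop"] = object()
        # Anti-pattern: the bound method keeps a strong reference to `self`,
        # so this object stays alive until interpreter shutdown.
        self._fin = weakref.finalize(self, self._clear)

    def _clear(self):
        _tasks.clear()


good, leaky = GoodContext(), LeakyContext()
good_fin, leaky_fin = good._fin, leaky._fin
del good, leaky
gc.collect()
print(good_fin.alive, leaky_fin.alive)  # expected: False True
```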
@@ -40,7 +48,6 @@ def __init__( enable_python_future=None, exchange_peer_info_timeout=10.0, ): - self.progress_tasks = dict() self.notifier_thread_q = None self.notifier_thread = None self._listener_active_clients = ActiveClients() @@ -62,7 +69,7 @@ def __init__( self.start_notifier_thread() - weakref.finalize(self, self.progress_tasks.clear) + weakref.finalize(self, clear_progress_tasks) # Ensure progress even before Endpoints get created, for example to # receive messages directly on a worker after a remote endpoint @@ -194,6 +201,10 @@ def ucp_worker_info(self): def worker_address(self): return self.worker.address + def clear_progress_tasks(self) -> None: + global ProgressTasks + ProgressTasks.clear() + def start_notifier_thread(self): if self.worker.enable_python_future and self.notifier_thread is None: logger.debug("UCXX_ENABLE_PYTHON available, enabling notifier thread") @@ -456,7 +467,8 @@ def continuous_ucx_progress(self, event_loop=None): Python 3.10+) is used. """ loop = event_loop if event_loop is not None else get_event_loop() - if loop in self.progress_tasks: + global ProgressTasks + if loop in ProgressTasks: return # Progress has already been guaranteed for the current event loop logger.info(f"Starting progress in '{self.progress_mode}' mode") @@ -470,7 +482,7 @@ def continuous_ucx_progress(self, event_loop=None): elif self.progress_mode == "blocking": task = BlockingMode(self.worker, loop) - self.progress_tasks[loop] = task + ProgressTasks[loop] = task def get_ucp_worker(self): """Returns the underlying UCP worker handle (ucp_worker_h) diff --git a/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py b/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py index 02cde6fd..4e32caed 100644 --- a/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py +++ b/python/ucxx/ucxx/_lib_async/continuous_ucx_progress.py @@ -73,9 +73,7 @@ def __init__(self, worker, event_loop, polling_mode=False): super().__init__(worker, event_loop) worker.set_progress_thread_start_callback(_create_context) worker.start_progress_thread(polling_mode=polling_mode, epoll_timeout=1) - - def __del__(self): - self.worker.stop_progress_thread() + weakref.finalize(self, worker.stop_progress_thread) class PollingMode(ProgressTask): diff --git a/python/ucxx/ucxx/_lib_async/endpoint.py b/python/ucxx/ucxx/_lib_async/endpoint.py index 5fee43c9..48d7877a 100644 --- a/python/ucxx/ucxx/_lib_async/endpoint.py +++ b/python/ucxx/ucxx/_lib_async/endpoint.py @@ -6,6 +6,7 @@ import asyncio import logging import warnings +import weakref import ucxx._lib.libucxx as ucx_api from ucxx._lib.arr import Array @@ -17,6 +18,23 @@ logger = logging.getLogger("ucx") +def _finalizer(endpoint: ucx_api.UCXEndpoint) -> None: + """Endpoint finalizer. + + Attempt to close the endpoint if it's still alive. + + Parameters + ---------- + endpoint: ucx_api.UCXEndpoint + The endpoint to close. + """ + if endpoint is not None: + logger.debug(f"Endpoint _finalize(): {endpoint.handle:#x}") + # Wait for a maximum of `period` ns + endpoint.close_blocking(period=10**10, max_attempts=1) + endpoint.remove_close_callback() + + class Endpoint: """An endpoint represents a connection to a peer @@ -41,8 +59,7 @@ def __init__(self, endpoint, ctx, tags=None): self._close_after_n_recv = None self._tags = tags - def __del__(self): - self.abort() + weakref.finalize(self, _finalizer, endpoint) @property def alive(self): @@ -107,7 +124,7 @@ def abort(self, period=10**10, max_attempts=1): if worker is running a progress thread and `period > 0`. 
""" if self._ep is not None: - logger.debug("Endpoint.abort(): 0x%x" % self.uid) + logger.debug(f"Endpoint.abort(): {self.uid:#x}") # Wait for a maximum of `period` ns self._ep.close_blocking(period=period, max_attempts=max_attempts) self._ep.remove_close_callback() diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 9531d86c..c68ccf37 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -5,6 +5,7 @@ import logging import os import threading +import weakref import ucxx._lib.libucxx as ucx_api from ucxx.exceptions import UCXMessageTruncatedError @@ -30,37 +31,62 @@ def __init__(self): self._locks = dict() self._active_clients = dict() - def add_listener(self, id: int) -> None: - if id in self._active_clients: - raise ValueError("Listener {id} is already registered in ActiveClients.") + def add_listener(self, ident: int) -> None: + if ident in self._active_clients: + raise ValueError("Listener {ident} is already registered in ActiveClients.") - self._locks[id] = threading.Lock() - self._active_clients[id] = 0 + self._locks[ident] = threading.Lock() + self._active_clients[ident] = 0 - def remove_listener(self, id: int) -> None: - with self._locks[id]: - active_clients = self.get_active(id) + def remove_listener(self, ident: int) -> None: + with self._locks[ident]: + active_clients = self.get_active(ident) if active_clients > 0: raise RuntimeError( - "Listener {id} is being removed from ActiveClients, but " + "Listener {ident} is being removed from ActiveClients, but " f"{active_clients} active client(s) is(are) still accounted for." ) - del self._locks[id] - del self._active_clients[id] + del self._locks[ident] + del self._active_clients[ident] - def inc(self, id: int) -> None: - with self._locks[id]: - self._active_clients[id] += 1 + def inc(self, ident: int) -> None: + with self._locks[ident]: + self._active_clients[ident] += 1 - def dec(self, id: int) -> None: - with self._locks[id]: - if self._active_clients[id] == 0: - raise ValueError(f"There are no active clients for listener {id}") - self._active_clients[id] -= 1 + def dec(self, ident: int) -> None: + with self._locks[ident]: + if self._active_clients[ident] == 0: + raise ValueError(f"There are no active clients for listener {ident}") + self._active_clients[ident] -= 1 - def get_active(self, id: int) -> int: - return self._active_clients[id] + def get_active(self, ident: int) -> int: + return self._active_clients[ident] + + +def _finalizer(ident: int, active_clients: ActiveClients) -> None: + """Listener finalizer. + + Finalize the listener and remove it from the `ActiveClients`. If there are + active clients, a warning is logged. + + Parameters + ---------- + ident: int + The unique identifier of the `Listener`. + active_clients: ActiveClients + Instance of `ActiveClients` owned by the parent `ApplicationContext` + from which to remove the `Listener`. + """ + try: + active_clients.remove_listener(ident) + except RuntimeError: + active_clients = active_clients.get_active(ident) + logger.warning( + f"Listener object is being destroyed, but {active_clients} client " + "handler(s) is(are) still alive. This usually indicates the Listener " + "was prematurely destroyed." + ) class Listener: @@ -70,26 +96,17 @@ class Listener: Please use `create_listener()` to create an Listener. 
""" - def __init__(self, listener, id, active_clients): + def __init__(self, listener, ident, active_clients): if not isinstance(listener, ucx_api.UCXListener): raise ValueError("listener must be an instance of UCXListener") self._listener = listener - active_clients.add_listener(id) - self._id = id + active_clients.add_listener(ident) + self._ident = ident self._active_clients = active_clients - def __del__(self): - try: - self._active_clients.remove_listener(self._id) - except RuntimeError: - active_clients = self._active_clients.get_active(self._id) - logger.warning( - f"Listener object is being destroyed, but {active_clients} client " - "handler(s) is(are) still alive. This usually indicates the Listener " - "was prematurely destroyed." - ) + weakref.finalize(self, _finalizer, ident, active_clients) @property def closed(self): @@ -108,7 +125,7 @@ def port(self): @property def active_clients(self): - return self._active_clients.get_active(self._id) + return self._active_clients.get_active(self._ident) def close(self): """Closing the listener""" @@ -121,11 +138,11 @@ async def _listener_handler_coroutine( func, endpoint_error_handling, exchange_peer_info_timeout, - id, + ident, active_clients, ): # def _listener_handler_coroutine( - # conn_request, ctx, func, endpoint_error_handling, id, active_clients + # conn_request, ctx, func, endpoint_error_handling, ident, active_clients # ): # We create the Endpoint in five steps: # 1) Create endpoint from conn_request @@ -133,7 +150,7 @@ async def _listener_handler_coroutine( # 3) Exchange endpoint info such as tags # 4) Setup control receive callback # 5) Execute the listener's callback function - active_clients.inc(id) + active_clients.inc(ident) endpoint = conn_request seed = os.urandom(16) @@ -186,9 +203,9 @@ async def _listener_handler_coroutine( else: func(ep) - active_clients.dec(id) + active_clients.dec(ident) - # Ensure `ep` is destroyed and `__del__` is called + # Ensure no references to `ep` remain to permit garbage collection. del ep @@ -199,7 +216,7 @@ def _listener_handler( ctx, endpoint_error_handling, exchange_peer_info_timeout, - id, + ident, active_clients, ): asyncio.run_coroutine_threadsafe( @@ -209,7 +226,7 @@ def _listener_handler( callback_func, endpoint_error_handling, exchange_peer_info_timeout, - id, + ident, active_clients, ), event_loop, diff --git a/python/ucxx/ucxx/_lib_async/notifier_thread.py b/python/ucxx/ucxx/_lib_async/notifier_thread.py index 23254a36..d906860a 100644 --- a/python/ucxx/ucxx/_lib_async/notifier_thread.py +++ b/python/ucxx/ucxx/_lib_async/notifier_thread.py @@ -47,7 +47,7 @@ def _notifierThread(event_loop, worker, q): ) if state == ucx_api.PythonRequestNotifierWaitState.Shutdown or shutdown is True: - return + break elif state == ucx_api.PythonRequestNotifierWaitState.Timeout: continue @@ -62,3 +62,7 @@ def _notifierThread(event_loop, worker, q): logger.debug("Notifier Thread Result Timeout") except Exception as e: logger.debug(f"Notifier Thread Result Exception: {e}") + + # Clear all Python futures to ensure no references are held to the + # `ucxx::Worker` that will prevent destructors from running. 
+ worker.clear_python_futures_pool() diff --git a/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py b/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py index abd82669..5b127fa9 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_endpoint.py @@ -58,9 +58,6 @@ async def client_node(port): @pytest.mark.asyncio @pytest.mark.parametrize("transfer_api", ["am", "tag", "tag_multi"]) async def test_cancel(transfer_api): - if transfer_api == "am": - pytest.skip("AM not implemented yet") - q = Queue() async def server_node(ep): diff --git a/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py b/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py index 5832fa46..c53c5a1e 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py @@ -50,6 +50,9 @@ def _test_from_worker_address_error_client(q1, q2, error_type): async def run(): # Receive worker address from server via multiprocessing.Queue remote_address = ucxx.get_ucx_address_from_buffer(q1.get()) + if error_type == "unreachable": + server_closed = q1.get() + assert server_closed == "Server closed" if error_type == "unreachable": with pytest.raises( @@ -162,9 +165,6 @@ async def run(): }, ) def test_from_worker_address_error(error_type): - if error_type in ["timeout_am_send", "timeout_am_recv"]: - pytest.skip("AM not implemented yet") - q1 = mp.Queue() q2 = mp.Queue() @@ -180,6 +180,10 @@ def test_from_worker_address_error(error_type): ) client.start() + if error_type == "unreachable": + server.join() + q1.put("Server closed") + join_processes([client, server], timeout=30) terminate_process(server) try: diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py index 6bd0a6c2..e49783c6 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_two_workers.py @@ -9,7 +9,7 @@ import numpy as np import pytest -import ucxx as ucxx +import ucxx from ucxx._lib_async.utils import get_event_loop from ucxx._lib_async.utils_test import ( am_recv, @@ -27,49 +27,54 @@ distributed = pytest.importorskip("distributed") cloudpickle = pytest.importorskip("cloudpickle") +# Enable for additional debug output +VERBOSE = False + ITERATIONS = 30 +def print_with_pid(msg): + if VERBOSE: + print(f"[{os.getpid()}] {msg}") + + async def get_ep(name, port): addr = ucxx.get_address() ep = await ucxx.create_endpoint(addr, port) return ep -def register_am_allocators(): - ucxx.register_am_allocator(lambda n: np.empty(n, dtype=np.uint8), "host") - ucxx.register_am_allocator(lambda n: rmm.DeviceBuffer(size=n), "cuda") - - def client(port, func, comm_api): - # wait for server to come up - # receive cudf object - # deserialize - # assert deserialized msg is cdf - # send receipt + # 1. Wait for server to come up + # 2. Loop receiving object multiple times from server + # 3. Send close message + # 4. 
Assert last received message has correct content from distributed.utils import nbytes - ucxx.init() - - if comm_api == "am": - register_am_allocators() - # must create context before importing # cudf/cupy/etc + ucxx.init() + async def read(): await asyncio.sleep(1) ep = await get_ep("client", port) - msg = None - import cupy - cupy.cuda.set_allocator(None) for i in range(ITERATIONS): - print(f"Client iteration {i}") + print_with_pid(f"Client iteration {i}") if comm_api == "tag": frames, msg = await recv(ep) else: - frames, msg = await am_recv(ep) + while True: + try: + frames, msg = await am_recv(ep) + except ucxx.exceptions.UCXNoMemoryError as e: + # Client didn't receive/consume messages quickly enough, + # new AM failed to allocate memory and raised this + # exception, we need to keep trying. + print_with_pid(f"Client exception: {type(e)} {e}") + else: + break close_msg = b"shutdown listener" @@ -81,13 +86,13 @@ async def read(): else: await ep.am_send(close_msg) - print("Shutting Down Client...") + print_with_pid("Shutting Down Client...") return msg["data"] rx_cuda_obj = get_event_loop().run_until_complete(read()) rx_cuda_obj + rx_cuda_obj num_bytes = nbytes(rx_cuda_obj) - print(f"TOTAL DATA RECEIVED: {num_bytes}") + print_with_pid(f"TOTAL DATA RECEIVED: {num_bytes}") cuda_obj_generator = cloudpickle.loads(func) pure_cuda_obj = cuda_obj_generator() @@ -101,39 +106,39 @@ async def read(): def server(port, func, comm_api): - # create listener receiver - # write cudf object - # confirm message is sent correctly + # 1. Create listener receiver + # 2. Loop sending object multiple times to connected client + # 3. Receive close message and close listener from distributed.comm.utils import to_frames from distributed.protocol import to_serialize ucxx.init() - if comm_api == "am": - register_am_allocators() - async def f(listener_port): - # coroutine shows up when the client asks - # to connect + # Coroutine shows up when the client asks to connect async def write(ep): - import cupy - - cupy.cuda.set_allocator(None) - - print("CREATING CUDA OBJECT IN SERVER...") + print_with_pid("CREATING CUDA OBJECT IN SERVER...") cuda_obj_generator = cloudpickle.loads(func) cuda_obj = cuda_obj_generator() msg = {"data": to_serialize(cuda_obj)} frames = await to_frames(msg, serializers=("cuda", "dask", "pickle")) for i in range(ITERATIONS): - print(f"Server iteration {i}") + print_with_pid(f"Server iteration {i}") # Send meta data if comm_api == "tag": await send(ep, frames) else: - await am_send(ep, frames) - - print("CONFIRM RECEIPT") + while True: + try: + await am_send(ep, frames) + except ucxx.exceptions.UCXNoMemoryError as e: + # Memory pressure due to client taking too long to + # receive will raise an exception. 
+ print_with_pid(f"Listener exception: {type(e)} {e}") + else: + break + + print_with_pid("CONFIRM RECEIPT") close_msg = b"shutdown listener" if comm_api == "tag": @@ -147,7 +152,7 @@ async def write(ep): recv_msg = msg.tobytes() assert recv_msg == close_msg - print("Shutting Down Server...") + print_with_pid("Shutting Down Server...") await ep.close() lf.close() @@ -156,10 +161,8 @@ async def write(ep): try: while not lf.closed: await asyncio.sleep(0.1) - # except ucxx.UCXCloseError: - # pass - except Exception as e: - print(f"Exception: {e=}") + except ucxx.UCXCloseError: + pass loop = get_event_loop() loop.run_until_complete(f(port)) @@ -199,33 +202,28 @@ def cupy_obj(): @pytest.mark.slow -@pytest.mark.skipif( - get_num_gpus() <= 2, reason="Machine does not have more than two GPUs" -) +@pytest.mark.skipif(get_num_gpus() <= 2, reason="Machine needs at least two GPUs") @pytest.mark.parametrize( "cuda_obj_generator", [dataframe, empty_dataframe, series, cupy_obj] ) @pytest.mark.parametrize("comm_api", ["tag", "am"]) def test_send_recv_cu(cuda_obj_generator, comm_api): - if comm_api == "am": - pytest.skip("AM not implemented yet") - base_env = os.environ env_client = base_env.copy() - # grab first two devices + # Grab first two devices cvd = get_cuda_devices()[:2] cvd = ",".join(map(str, cvd)) - # reverse CVD for other worker + # Reverse CVD for client env_client["CUDA_VISIBLE_DEVICES"] = cvd[::-1] port = random.randint(13000, 15500) - # serialize function and send to the client and server - # server will use the return value of the contents, - # serialize the values, then send serialized values to client. - # client will compare return values of the deserialized - # data sent from the server + # Serialize function and send to the client and server. The server will use + # the return value of the contents, serialize the values, then send + # serialized values to client. The client will compare return values of the + # deserialized data sent from the server. func = cloudpickle.dumps(cuda_obj_generator) + ctx = multiprocessing.get_context("spawn") server_process = ctx.Process( name="server", target=server, args=[port, func, comm_api] @@ -235,12 +233,12 @@ def test_send_recv_cu(cuda_obj_generator, comm_api): ) server_process.start() - # cudf will ping the driver for validity of device - # this will influence device on which a cuda context is created. - # work around is to update env with new CVD before spawning + # cuDF will ping the driver for validity of device, this will influence + # device on which a cuda context is created. 
Workaround is to update + # env with new CVD before spawning os.environ.update(env_client) client_process.start() - join_processes([client, server], timeout=30) - terminate_process(client) - terminate_process(server) + join_processes([client_process, server_process], timeout=3000) + terminate_process(client_process) + terminate_process(server_process) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py b/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py index 163c4fc3..bc39d05a 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_shutdown.py @@ -35,8 +35,6 @@ async def _shutdown_recv(ep, message_type): @pytest.mark.parametrize("message_type", ["tag", "am"]) async def test_server_shutdown(message_type): """The server calls shutdown""" - if message_type == "am": - pytest.skip("AM not implemented yet") async def server_node(ep): with pytest.raises(ucxx.exceptions.UCXCanceledError): @@ -67,8 +65,6 @@ async def client_node(port): @pytest.mark.parametrize("message_type", ["tag", "am"]) async def test_client_shutdown(message_type): """The client calls shutdown""" - if message_type == "am": - pytest.skip("AM not implemented yet") async def client_node(port): ep = await ucxx.create_endpoint( @@ -96,8 +92,6 @@ async def server_node(ep): @pytest.mark.parametrize("message_type", ["tag", "am"]) async def test_listener_close(message_type): """The server close the listener""" - if message_type == "am": - pytest.skip("AM not implemented yet") async def client_node(listener): ep = await ucxx.create_endpoint( @@ -125,8 +119,6 @@ async def server_node(ep): @pytest.mark.parametrize("message_type", ["tag", "am"]) async def test_listener_del(message_type): """The client delete the listener""" - if message_type == "am": - pytest.skip("AM not implemented yet") async def server_node(ep): await _shutdown_send(ep, message_type) @@ -156,8 +148,6 @@ async def server_node(ep): @pytest.mark.parametrize("message_type", ["tag", "am"]) async def test_close_after_n_recv(message_type): """The Endpoint.close_after_n_recv()""" - if message_type == "am": - pytest.skip("AM not implemented yet") async def server_node(ep): for _ in range(10):
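Now that a failed AM allocation fails the request with `UCS_ERR_NO_MEMORY` (see the `request_am.cpp` changes above) instead of going unreported, the tests retry `am_send`/`am_recv` whenever `ucxx.exceptions.UCXNoMemoryError` is raised. A hedged sketch of that retry shape factored into a helper follows; the backoff delay and attempt cap are illustrative and not part of the test suite.

```python
import asyncio

import ucxx


async def retry_on_no_memory(operation, *, delay=0.01, max_attempts=None):
    """Re-run an AM operation until it stops failing with UCXNoMemoryError.

    `operation` is a zero-argument callable returning a fresh awaitable,
    e.g. ``lambda: ep.am_send(frames)`` for a hypothetical endpoint ``ep``.
    """
    attempts = 0
    while True:
        attempts += 1
        try:
            return await operation()
        except ucxx.exceptions.UCXNoMemoryError:
            # The AM allocator could not satisfy the message yet
            # (UCS_ERR_NO_MEMORY); back off briefly and try again.
            if max_attempts is not None and attempts >= max_attempts:
                raise
            await asyncio.sleep(delay)
```

Each `while True: try/except` block in the tests above is an inlined version of this loop; the helper form simply makes the retry condition reusable.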