diff --git a/python/cuda_parallel/cuda/parallel/experimental/_cccl.py b/python/cuda_parallel/cuda/parallel/experimental/_cccl.py
index 955274d66e4..e231f721238 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/_cccl.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/_cccl.py
@@ -11,7 +11,7 @@
 import numpy as np
 from numba import cuda, types
 
-from ._utils.cai import get_dtype, is_contiguous
+from ._utils.protocols import get_dtype, is_contiguous
 from .iterators._iterators import IteratorBase
 from .typing import DeviceArrayLike, GpuStruct
 
diff --git a/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py b/python/cuda_parallel/cuda/parallel/experimental/_utils/protocols.py
similarity index 66%
rename from python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
rename to python/cuda_parallel/cuda/parallel/experimental/_utils/protocols.py
index 3a3391f93f2..d62717115cb 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/_utils/protocols.py
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 """
-Utilities for extracting information from `__cuda_array_interface__`.
+Utilities for extracting information from protocols such as `__cuda_array_interface__` and `__cuda_stream__`.
 """
 
 from typing import Optional, Tuple
@@ -68,3 +68,30 @@ def is_contiguous(arr: DeviceArrayLike) -> bool:
     else:
         # not contiguous
         return False
+
+
+def validate_and_get_stream(stream) -> Optional[int]:
+    # null stream is allowed
+    if stream is None:
+        return None
+
+    try:
+        stream_property = stream.__cuda_stream__()
+    except AttributeError as e:
+        raise TypeError(
+            f"stream argument {stream} does not implement the '__cuda_stream__' protocol"
+        ) from e
+
+    try:
+        version, handle, *_ = stream_property
+    except (TypeError, ValueError) as e:
+        raise TypeError(
+            f"could not obtain __cuda_stream__ protocol version and handle from {stream_property}"
+        ) from e
+
+    if version == 0:
+        if not isinstance(handle, int):
+            raise TypeError(f"invalid stream handle {handle}")
+        return handle
+
+    raise TypeError(f"unsupported __cuda_stream__ version {version}")
diff --git a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
index 5e731bc4c50..f0b73f2b51d 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
@@ -16,7 +16,7 @@
 from .. import _cccl as cccl
 from .._bindings import get_bindings, get_paths
 from .._caching import CachableFunction, cache_with_key
-from .._utils import cai
+from .._utils import protocols
 from ..iterators._iterators import IteratorBase
 from ..typing import DeviceArrayLike, GpuStruct
 
@@ -63,7 +63,7 @@ def __init__(
         self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
             d_in_cccl.value_type.type.value
         )
-        self._ctor_d_out_dtype = cai.get_dtype(d_out)
+        self._ctor_d_out_dtype = protocols.get_dtype(d_out)
         self._ctor_init_dtype = h_init.dtype
         cc_major, cc_minor = cuda.get_current_device().compute_capability
         cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
@@ -89,7 +89,13 @@
             raise ValueError("Error building reduce")
 
     def __call__(
-        self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
+        self,
+        temp_storage,
+        d_in,
+        d_out,
+        num_items: int,
+        h_init: np.ndarray | GpuStruct,
+        stream=None,
     ):
         d_in_cccl = cccl.to_cccl_iter(d_in)
         if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
@@ -104,8 +110,9 @@ def __call__(
                 self._ctor_d_in_cccl_type_enum_name,
                 cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
             )
-        _dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
+        _dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
         _dtype_validation(self._ctor_init_dtype, h_init.dtype)
+        stream_handle = protocols.validate_and_get_stream(stream)
         bindings = get_bindings()
         if temp_storage is None:
             temp_storage_bytes = ctypes.c_size_t()
@@ -125,7 +132,7 @@ def __call__(
             ctypes.c_ulonglong(num_items),
             self.op_wrapper.handle(),
             cccl.to_cccl_value(h_init),
-            None,
+            stream_handle,
         )
         if error != enums.CUDA_SUCCESS:
             raise ValueError("Error reducing")
@@ -145,8 +152,10 @@ def make_cache_key(
     op: Callable,
    h_init: np.ndarray,
 ):
-    d_in_key = d_in.kind if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in)
-    d_out_key = cai.get_dtype(d_out)
+    d_in_key = (
+        d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in)
+    )
+    d_out_key = protocols.get_dtype(d_out)
     op_key = CachableFunction(op)
     h_init_key = h_init.dtype
     return (d_in_key, d_out_key, op_key, h_init_key)
diff --git a/python/cuda_parallel/tests/test_reduce.py b/python/cuda_parallel/tests/test_reduce.py
index 9549ef7bee3..65710954b0b 100644
--- a/python/cuda_parallel/tests/test_reduce.py
+++ b/python/cuda_parallel/tests/test_reduce.py
@@ -550,3 +550,108 @@ def binary_op(x, y):
     d_in = cp.zeros(size)[::2]
     with pytest.raises(ValueError, match="Non-contiguous arrays are not supported."):
         _ = algorithms.reduce_into(d_in, d_out, binary_op, h_init)
+
+
+def test_reduce_with_stream():
+    # Simple cupy stream wrapper that implements the __cuda_stream__ protocol for the purposes of this test
+    class Stream:
+        def __init__(self, cp_stream):
+            self.cp_stream = cp_stream
+
+        def __cuda_stream__(self):
+            return (0, self.cp_stream.ptr)
+
+    def add_op(x, y):
+        return x + y
+
+    h_init = np.asarray([0], dtype=np.int32)
+    h_in = random_int(5, np.int32)
+
+    stream = cp.cuda.Stream()
+    with stream:
+        d_in = cp.asarray(h_in)
+        d_out = cp.empty(1, dtype=np.int32)
+
+    stream_wrapper = Stream(stream)
+    reduce_into = algorithms.reduce_into(
+        d_in=d_in, d_out=d_out, op=add_op, h_init=h_init
+    )
+    temp_storage_size = reduce_into(
+        None,
+        d_in=d_in,
+        d_out=d_out,
+        num_items=d_in.size,
+        h_init=h_init,
+        stream=stream_wrapper,
+    )
+    with stream:
+        d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
+
+    reduce_into(d_temp_storage, d_in, d_out, d_in.size, h_init, stream=stream_wrapper)
+    with stream:
+        cp.testing.assert_allclose(d_in.sum().get(), d_out.get())
+
+
+def test_reduce_invalid_stream():
+    # Invalid stream that doesn't implement __cuda_stream__
+    class Stream1:
+        def __init__(self):
+            pass
+
+    # Invalid stream that implements __cuda_stream__ but returns the wrong type
+    class Stream2:
+        def __init__(self):
+            pass
+
+        def __cuda_stream__(self):
+            return None
+
+    # Invalid stream that returns an invalid handle
+    class Stream3:
+        def __init__(self):
+            pass
+
+        def __cuda_stream__(self):
+            return (0, None)
+
+    def add_op(x, y):
+        return x + y
+
+    d_out = cp.empty(1)
+    h_init = np.empty(1)
+    d_in = cp.empty(1)
+    reduce_into = algorithms.reduce_into(d_in, d_out, add_op, h_init)
+
+    with pytest.raises(
+        TypeError, match="does not implement the '__cuda_stream__' protocol"
+    ):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream1(),
+        )
+
+    with pytest.raises(
+        TypeError, match="could not obtain __cuda_stream__ protocol version and handle"
+    ):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream2(),
+        )
+
+    with pytest.raises(TypeError, match="invalid stream handle"):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream3(),
+        )
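
Caller-side usage sketch, distilled from test_reduce_with_stream above. Assumptions: cupy provides the device buffers and stream, and the StreamWrapper name is illustrative; any object whose __cuda_stream__() returns (0, <int handle>) is accepted by the new stream keyword.

import cupy as cp
import numpy as np
from cuda.parallel.experimental import algorithms


# Any object exposing __cuda_stream__ -> (version, handle) can be passed as `stream`.
class StreamWrapper:
    def __init__(self, cp_stream):
        self.cp_stream = cp_stream

    def __cuda_stream__(self):
        # Protocol version 0: return the raw CUDA stream handle as an int.
        return (0, self.cp_stream.ptr)


def add_op(x, y):
    return x + y


h_init = np.zeros(1, dtype=np.int32)
stream = cp.cuda.Stream()
with stream:
    d_in = cp.arange(10, dtype=np.int32)
    d_out = cp.empty(1, dtype=np.int32)

reduce_into = algorithms.reduce_into(d_in, d_out, add_op, h_init)

# The first call queries the temporary storage size; both calls are issued on the wrapped stream.
temp_storage_size = reduce_into(
    None, d_in, d_out, d_in.size, h_init, stream=StreamWrapper(stream)
)
with stream:
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
reduce_into(
    d_temp_storage, d_in, d_out, d_in.size, h_init, stream=StreamWrapper(stream)
)

Omitting stream (or passing None) keeps the previous behavior of launching on the null stream.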