cuda.parallel: Add optional stream argument to reduce_into() #3348
base: main
@@ -6,7 +6,7 @@
 from __future__ import annotations  # TODO: required for Python 3.7 docs env
 
 import ctypes
-from typing import Callable
+from typing import Callable, Optional
 
 import numba
 import numpy as np
@@ -46,6 +46,30 @@ def _dtype_validation(dt1, dt2):
         raise TypeError(f"dtype mismatch: __init__={dt1}, __call__={dt2}")
 
 
+def _validate_and_get_stream(stream) -> Optional[int]:
+    # null stream is allowed
+    if stream is None:
+        return None
Comment on lines +50 to +52

So I think this would be a common source of bugs. The first naive question is: would users need to call reduce_into(..., stream=cp.cuda.get_current_stream()) in order to preserve the respective library's stream ordering? If that's the case, we probably should explicitly forbid

FYI, in nvmath-python

Could you expand on this a bit? How would

I guess I see what you're asking. For interpreting a provided stream, the protocol will help. For understanding the stream semantics of each array library (e.g., does it have the notion of a current stream?), no, the protocol would not help.

Yes - this would match the API of the corresponding C++ function where the

Please correct me if I'm wrong, but my understanding is that your concern is about the use of APIs like CuPy's

with torch.cuda.stream() as s:
    torch.some_function(...)  # no need to pass a stream explicitly, uses `s` implicitly

The above works great as long as I'm only using PyTorch, but if I want to combine PyTorch with, e.g., CuPy:

with torch.cuda.stream() as s:
    torch.some_function(...)  # no need to pass a stream explicitly, uses `s` implicitly

    # need to pass stream explicitly, as cupy doesn't know about
    # PyTorch's "current stream":
    cupy.some_other_function(..., s)

I certainly agree that what you're describing is a concern, but I don't feel that the API decisions of downstream libraries like CuPy or PyTorch should influence the APIs of upstream libraries like cuda.parallel. The default stream is a reasonable default and 'just works' across the ecosystem for the majority of users who don't necessarily want to use CUDA streams. I would prefer we keep things easy for them. If someone opts in to using streams, then I think it's fine to require that they take additional care to pass streams appropriately to the various functions they use across libraries (which is something they need to do already).
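[Editor's note] To make the explicit opt-in concrete, here is a minimal sketch from the caller's side. It assumes a thin wrapper like the one in the test further below, since cp.cuda.Stream does not itself implement __cuda_stream__; the names here are illustrative only, not part of the PR.

import cupy as cp


class StreamWrapper:
    # hypothetical adapter exposing a CuPy stream through the __cuda_stream__ protocol
    def __init__(self, cp_stream):
        self.cp_stream = cp_stream

    @property
    def __cuda_stream__(self):
        return (0, self.cp_stream.ptr)  # (protocol version, raw stream handle)


s = cp.cuda.Stream()
with s:
    d_in = cp.arange(1024, dtype=cp.int32)  # work enqueued on `s` by CuPy
    d_out = cp.empty(1, dtype=cp.int32)

# Opting in to streams means passing one explicitly; e.g., with a reduce_into
# object built as in the tests below:
#   reduce_into(d_temp_storage, d_in, d_out, d_in.size, h_init, stream=StreamWrapper(s))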
+
+    if not hasattr(stream, "__cuda_stream__"):
+        raise TypeError(
+            f"stream argument {stream} does not implement the '__cuda_stream__' protocol"
+        )
+
+    stream_property = stream.__cuda_stream__
Comment on lines +54 to +59

In general, EAFP is more idiomatic and, as being discussed in NVIDIA/cuda-python#348, also much faster. So instead of:

if not hasattr(obj, "attr"):
    raise TypeError(...)
attr = obj.attr

prefer:

try:
    attr = obj.attr
except AttributeError:
    raise TypeError(...)

As a rule of thumb, we should avoid hasattr. (Side note: IIRC there's at least one other place in the codebase where we're using hasattr.)
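[Editor's note] For concreteness, here is roughly what the helper would look like rewritten in the EAFP style suggested above; this is a sketch of the suggestion, not code that is part of this PR.

from typing import Optional


def _validate_and_get_stream(stream) -> Optional[int]:
    # null stream is allowed
    if stream is None:
        return None

    try:
        # EAFP: attempt the attribute access and translate the failure,
        # instead of probing with hasattr() first
        stream_property = stream.__cuda_stream__
    except AttributeError:
        raise TypeError(
            f"stream argument {stream} does not implement the '__cuda_stream__' protocol"
        )

    if (
        isinstance(stream_property, tuple)
        and len(stream_property) == 2
        and all(isinstance(i, int) for i in stream_property)
    ):
        version, handle = stream_property
        return handle

    raise TypeError(
        f"__cuda_stream__ property of '{stream}' must return a 'Tuple[int, int]'; "
        f"got {stream_property} instead"
    )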
+
+    if (
+        isinstance(stream_property, tuple)
+        and len(stream_property) == 2
+        and all(isinstance(i, int) for i in stream_property)
+    ):
+        version, handle = stream_property
+        return handle
+
+    raise TypeError(
+        f"__cuda_stream__ property of '{stream}' must return a 'Tuple[int, int]'; got {stream_property} instead"
+    )
Comment on lines +60 to +70

nit: Personally, I'd be a bit more lax here and really only ensure that

We should also check if version is 0, because we could change it to, say, a 3-tuple in a future version of the protocol.
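[Editor's note] A sketch of what the laxer check combined with the version guard could look like; this illustrates the review suggestions under stated assumptions (the function name and error messages are made up) and is not code from the PR.

def _validate_and_get_stream_lax(stream):
    # hypothetical variant: only unpack the handle and guard on the protocol
    # version, rather than requiring exactly a Tuple[int, int]
    if stream is None:
        return None

    try:
        stream_property = stream.__cuda_stream__
        version = stream_property[0]
    except (AttributeError, TypeError, IndexError):
        raise TypeError(
            f"stream argument {stream} does not expose a usable '__cuda_stream__' property"
        )

    if version != 0:
        # a later protocol version might return, e.g., a 3-tuple with a different layout
        raise TypeError(f"unsupported __cuda_stream__ protocol version: {version}")

    return stream_property[1]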
+
+
 class _Reduce:
     # TODO: constructor shouldn't require concrete `d_in`, `d_out`:
     def __init__(
@@ -85,7 +109,9 @@ def __init__(
         if error != enums.CUDA_SUCCESS:
             raise ValueError("Error building reduce")
 
-    def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray):
+    def __call__(
+        self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray, stream=None
I don't have a strong opinion on this, just FYI: in some places we use

def func(..., *, stream):

so that we enforce passing stream as a keyword-only argument.
+    ):
         d_in_cccl = cccl.to_cccl_iter(d_in)
         if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
             assert num_items is not None
@@ -101,6 +127,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
         )
         _dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
         _dtype_validation(self._ctor_init_dtype, h_init.dtype)
+        stream_handle = _validate_and_get_stream(stream)
         bindings = get_bindings()
         if temp_storage is None:
             temp_storage_bytes = ctypes.c_size_t()
@@ -120,7 +147,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
             ctypes.c_ulonglong(num_items),
             self.op_wrapper.handle(),
             cccl.host_array_to_value(h_init),
-            None,
+            stream_handle,
         )
         if error != enums.CUDA_SUCCESS:
             raise ValueError("Error reducing")
@@ -550,3 +550,88 @@ def binary_op(x, y):
     d_in = cp.zeros(size)[::2]
     with pytest.raises(ValueError, match="Non-contiguous arrays are not supported."):
         _ = algorithms.reduce_into(d_in, d_out, binary_op, h_init)
+
+
+def test_reduce_with_stream():
+    # Simple cupy stream wrapper that implements the __cuda_stream__ protocol
+    # for the purposes of this test
+    class Stream:
+        def __init__(self, cp_stream):
+            self.cp_stream = cp_stream
+
+        @property
+        def __cuda_stream__(self):
+            return (0, self.cp_stream.ptr)
+
+    def add_op(x, y):
+        return x + y
+
+    h_init = np.asarray([0], dtype=np.int32)
+    h_in = random_int(5, np.int32)
+
+    stream = cp.cuda.Stream()
+    with stream:
+        d_in = cp.asarray(h_in)
+        d_out = cp.empty(1, dtype=np.int32)
+
+    stream_wrapper = Stream(stream)
+    reduce_into = algorithms.reduce_into(
+        d_in=d_in, d_out=d_out, op=add_op, h_init=h_init
+    )
+    temp_storage_size = reduce_into(
+        None,
+        d_in=d_in,
+        d_out=d_out,
+        num_items=d_in.size,
+        h_init=h_init,
+        stream=stream_wrapper,
+    )
+    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
+
+    reduce_into(d_temp_storage, d_in, d_out, d_in.size, h_init, stream=stream_wrapper)
+    np.testing.assert_allclose(d_in.sum().get(), d_out.get())
Comment on lines +590 to +591

We should call

Perhaps our wrapper

Wearing my CuPy hat: Just call

cupy.asnumpy(..., stream=stream, blocking=True)

or

with stream:
    cupy.asnumpy(..., blocking=True)

to perform a stream-ordered, blocking copy to host. No need to add

Even better is to call

with stream:
    cp.testing.assert_allclose(...)

In general, is it recommended that we rely on this?

Yes, it's a public API.
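[Editor's note] Applied to the test above, the suggested patterns would look roughly like the following self-contained sketch; here d_out is a stand-in for the result written by reduce_into, and nothing below is code from the PR.

import cupy as cp
import numpy as np

h_in = np.arange(5, dtype=np.int32)

stream = cp.cuda.Stream()
with stream:
    d_in = cp.asarray(h_in)
    d_out = cp.asarray([d_in.sum()])  # stand-in for the reduce_into result

# Option 1: blocking, stream-ordered copy of the result to host, then a NumPy comparison
actual = cp.asnumpy(d_out, stream=stream, blocking=True)
np.testing.assert_allclose(h_in.sum(), actual)

# Option 2: compare device arrays directly under the stream with CuPy's testing helper
with stream:
    cp.testing.assert_allclose(d_in.sum(), d_out)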
+
+
+def test_reduce_invalid_stream():
+    # Invalid stream that doesn't implement __cuda_stream__
+    class Stream1:
+        def __init__(self):
+            pass
+
+    # Invalid stream that implements __cuda_stream__ but returns the wrong type
+    class Stream2:
+        def __init__(self):
+            pass
+
+        @property
+        def __cuda_stream__(self):
+            return None
+
+    def add_op(x, y):
+        return x + y
+
+    d_out = cp.empty(1)
+    h_init = np.empty(1)
+    d_in = cp.empty(1)
+    reduce_into = algorithms.reduce_into(d_in, d_out, add_op, h_init)
+
+    with pytest.raises(
+        TypeError, match="does not implement the '__cuda_stream__' protocol"
+    ):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream1(),
+        )
+
+    with pytest.raises(TypeError, match="must return a 'Tuple\\[int, int\\]';"):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream2(),
+        )
In terms of where this function should live, here are a couple of suggestions:

- _utils/stream.py
- rename _utils/cai.py to _utils/protocols.py and move it there (this module could be general utilities for working with protocol objects like __cuda_array_interface__ and __cuda_stream__)