Skip to content

Commit

Permalink
Fix cumulative_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 27, 2024
1 parent 84b8f4d commit 1ab34a7
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Changelog
**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
- Fixes :func:`~ndonnx.cumulative_sum` to correctly apply the ``include_initial`` parameter and workaround missing ORT kernels for unsigned integral types.

**Breaking change**

Expand Down
38 changes: 34 additions & 4 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def searchsorted(
how_many[
ndx.where(indices_x1 + 1 <= combined_shape[0], indices_x1 + 1, indices_x1)
] = counts
how_many = ndx.cumulative_sum(how_many, include_initial=True)
how_many = ndx.cumulative_sum(how_many, include_initial=False, axis=None)

ret = ndx.zeros(nda.shape(x2), dtype=dtypes.int64)

Expand Down Expand Up @@ -566,13 +566,43 @@ def cumulative_sum(
axis = 0
else:
raise ValueError("axis must be specified for multi-dimensional arrays")

if dtype is None:
if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
if ndx.iinfo(x.dtype).bits < 64:
out = x.astype(dtypes.int64)
else:
raise ValueError(f"Cannot perform `cumulative_sum` using {x.dtype}")
else:
out = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.int64))
else:
out = out.astype(dtype)

out = from_corearray(
opx.cumsum(
x._core(), axis=opx.const(axis), exclusive=int(not include_initial)
out._core(),
axis=opx.const(axis),
exclusive=0,
)
)
if dtype is not None:
out = out.astype(dtype)

if isinstance(x.dtype, dtypes.Unsigned):
out = out.astype(ndx.uint64)
elif isinstance(x.dtype, dtypes.NullableUnsigned):
out = out.astype(ndx.nuint64)

# Exclude axis and create zeros of that shape
if include_initial:
out_shape = nda.shape(out)
out_shape[axis] = 1
out = ndx.concat(
[
ndx.zeros(out_shape, dtype=out.dtype),
out,
],
axis=axis,
)

return out

@validate_core
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def broadcast_to(x, shape):
# TODO: onnxruntime doesn't work for 2 empty arrays of integer type
# TODO: what is the appropriate strategy to dispatch? (iterate over the inputs and keep trying is reasonable but it can
# change the outcome based on order if poorly implemented)
def concat(arrays, axis=None):
def concat(arrays, axis=0):
if axis is None:
arrays = [reshape(x, [-1]) for x in arrays]
axis = 0
Expand Down
38 changes: 38 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,3 +926,41 @@ def test_lazy_array_shape(x, expected_shape):
def test_dynamic_reshape_has_no_static_shape(x, shape):
with pytest.raises(ValueError, match="Could not determine static shape"):
ndx.reshape(x, shape).shape


@pytest.mark.skipif(
not np.__version__.startswith("2"), reason="NumPy >= 2 used for test assertions"
)
@pytest.mark.parametrize("include_initial", [True, False])
@pytest.mark.parametrize(
"dtype",
[ndx.int32, ndx.int64, ndx.float32, ndx.float64, ndx.uint8, ndx.uint16, ndx.uint32],
)
@pytest.mark.parametrize(
"array, axis",
[
([1, 2, 3], None),
([1, 2, 3], 0),
([[1, 2], [3, 4]], 0),
([[1, 2], [3, 4]], 1),
([[1, 2, 50], [3, 4, 5]], 1),
([[[[1]]], [[[3]]]], 0),
([[[[1]]], [[[3]]]], 1),
],
)
def test_cumulative_sum(array, axis, include_initial, dtype):
a = ndx.asarray(array, dtype=dtype)
assert_array_equal(
ndx.cumulative_sum(a, include_initial=include_initial, axis=axis).to_numpy(),
np.cumulative_sum(
np.asarray(array, a.dtype.to_numpy_dtype()),
include_initial=include_initial,
axis=axis,
),
)


def test_no_unsafe_cumulative_sum_cast():
with pytest.raises(ValueError, match="Cannot perform `cumulative_sum`"):
a = ndx.asarray([1, 2, 3], ndx.uint64)
ndx.cumulative_sum(a)

0 comments on commit 1ab34a7

Please sign in to comment.