Skip to content

Commit

Permalink
Make prod work with u32 default type (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Aug 22, 2024
1 parent af78c47 commit 9837b83
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
18 changes: 13 additions & 5 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def prod(
else:
axes = axis # type: ignore

x = x.astype(_determine_reduce_op_dtype(x, dtype))
x = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.uint32))

if isinstance(x.dtype, dtypes.NullableNumerical):
fill_value = ndx.asarray(1, dtype=x.dtype.values)
Expand Down Expand Up @@ -760,7 +760,7 @@ def sum(
else:
axes = axis # type: ignore

x = x.astype(_determine_reduce_op_dtype(x, dtype))
x = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.uint64))

# Fill any nulls with 0
if isinstance(x.dtype, dtypes.NullableNumerical):
Expand Down Expand Up @@ -972,14 +972,22 @@ def _via_i64_f64(


def _determine_reduce_op_dtype(
x: Array, dtype: dtypes.CoreType | dtypes.StructType | None
x: Array,
dtype: dtypes.CoreType | dtypes.StructType | None,
maximum_unsigned_dtype: dtypes.CoreType,
) -> dtypes.CoreType | dtypes.StructType:
if dtype is not None:
return dtype
elif isinstance(x.dtype, dtypes.Unsigned):
return dtypes.uint64
if ndx.iinfo(x.dtype).bits <= ndx.iinfo(maximum_unsigned_dtype).bits:
return maximum_unsigned_dtype
else:
raise TypeError(f"Cannot reduce {x.dtype} to a smaller unsigned dtype")
elif isinstance(x.dtype, dtypes.NullableUnsigned):
return dtypes.nuint64
if ndx.iinfo(x.dtype).bits <= ndx.iinfo(maximum_unsigned_dtype).bits:
return dtypes.promote_nullable(maximum_unsigned_dtype)
else:
raise TypeError(f"Cannot reduce {x.dtype} to a smaller unsigned dtype")
elif isinstance(x.dtype, dtypes.Integral):
return dtypes.int64
elif isinstance(x.dtype, dtypes.NullableIntegral):
Expand Down
28 changes: 26 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,14 +633,38 @@ def test_prod(dtype):
assert_array_equal(y.to_numpy(), actual)


@pytest.mark.parametrize(
"dtype",
[
ndx.uint8,
ndx.uint16,
ndx.uint32,
ndx.nuint8,
ndx.nuint16,
ndx.nuint32,
],
)
def test_prod_unsigned(dtype):
# We intentionally deviate from the Array API standard and reduce for <= uint32
# using uint32 as our default unsigned type due to lack of kernel support for uint64
x = ndx.asarray([2, 2]).astype(dtype)
y = ndx.prod(x)
if isinstance(dtype, ndx.Nullable):
input = np.asarray([2, 2], dtype=dtype.values.to_numpy_dtype())
input = np.ma.masked_array(input, mask=False)
else:
input = np.asarray([2, 2], dtype=dtype.to_numpy_dtype())
actual = np.prod(input).astype(np.uint32)

assert_array_equal(y.to_numpy(), actual)


@pytest.mark.parametrize(
"dtype",
[
ndx.float64,
ndx.nfloat64,
ndx.uint64,
ndx.uint32,
ndx.uint8,
ndx.nuint64,
],
)
Expand Down

0 comments on commit 9837b83

Please sign in to comment.