From 9837b835cca65b87b2052fdbac1c31c0200c3893 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:11:27 +0100 Subject: [PATCH] Make prod work with u32 default type (#58) --- ndonnx/_core/_numericimpl.py | 18 +++++++++++++----- tests/test_core.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 0b31361..0b8641f 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -663,7 +663,7 @@ def prod( else: axes = axis # type: ignore - x = x.astype(_determine_reduce_op_dtype(x, dtype)) + x = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.uint32)) if isinstance(x.dtype, dtypes.NullableNumerical): fill_value = ndx.asarray(1, dtype=x.dtype.values) @@ -760,7 +760,7 @@ def sum( else: axes = axis # type: ignore - x = x.astype(_determine_reduce_op_dtype(x, dtype)) + x = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.uint64)) # Fill any nulls with 0 if isinstance(x.dtype, dtypes.NullableNumerical): @@ -972,14 +972,22 @@ def _via_i64_f64( def _determine_reduce_op_dtype( - x: Array, dtype: dtypes.CoreType | dtypes.StructType | None + x: Array, + dtype: dtypes.CoreType | dtypes.StructType | None, + maximum_unsigned_dtype: dtypes.CoreType, ) -> dtypes.CoreType | dtypes.StructType: if dtype is not None: return dtype elif isinstance(x.dtype, dtypes.Unsigned): - return dtypes.uint64 + if ndx.iinfo(x.dtype).bits <= ndx.iinfo(maximum_unsigned_dtype).bits: + return maximum_unsigned_dtype + else: + raise TypeError(f"Cannot reduce {x.dtype} to a smaller unsigned dtype") elif isinstance(x.dtype, dtypes.NullableUnsigned): - return dtypes.nuint64 + if ndx.iinfo(x.dtype).bits <= ndx.iinfo(maximum_unsigned_dtype).bits: + return dtypes.promote_nullable(maximum_unsigned_dtype) + else: + raise TypeError(f"Cannot reduce {x.dtype} to a smaller unsigned dtype") elif isinstance(x.dtype, dtypes.Integral): return dtypes.int64 elif isinstance(x.dtype, dtypes.NullableIntegral): diff --git a/tests/test_core.py b/tests/test_core.py index fa90dbe..7ec974f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -633,14 +633,38 @@ def test_prod(dtype): assert_array_equal(y.to_numpy(), actual) +@pytest.mark.parametrize( + "dtype", + [ + ndx.uint8, + ndx.uint16, + ndx.uint32, + ndx.nuint8, + ndx.nuint16, + ndx.nuint32, + ], +) +def test_prod_unsigned(dtype): + # We intentionally deviate from the Array API standard and reduce for <= uint32 + # using uint32 as our default unsigned type due to lack of kernel support for uint64 + x = ndx.asarray([2, 2]).astype(dtype) + y = ndx.prod(x) + if isinstance(dtype, ndx.Nullable): + input = np.asarray([2, 2], dtype=dtype.values.to_numpy_dtype()) + input = np.ma.masked_array(input, mask=False) + else: + input = np.asarray([2, 2], dtype=dtype.to_numpy_dtype()) + actual = np.prod(input).astype(np.uint32) + + assert_array_equal(y.to_numpy(), actual) + + @pytest.mark.parametrize( "dtype", [ ndx.float64, ndx.nfloat64, ndx.uint64, - ndx.uint32, - ndx.uint8, ndx.nuint64, ], )