Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into construct-member
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 28, 2024
2 parents f89fe11 + d64dd52 commit b86e765
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 23 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Changelog
**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
- :func:`ndonnx.cumulative_sum` now correctly applies the ``include_initial`` parameter and works around missing onnxruntime kernels for unsigned integral types.
- :func:`ndonnx.additional.make_nullable` applies broadcasting to the provided null array (instead of reshape like it did previously). This allows writing ``make_nullable(x, False)`` to turn an array into nullable.

**Breaking change**

Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, nda.shape(x)),
null=ndx.broadcast_to(null, nda.shape(x)),
)
40 changes: 36 additions & 4 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def searchsorted(
how_many[
ndx.where(indices_x1 + 1 <= combined_shape[0], indices_x1 + 1, indices_x1)
] = counts
how_many = ndx.cumulative_sum(how_many, include_initial=True)
how_many = ndx.cumulative_sum(how_many, include_initial=False, axis=None)

ret = ndx.zeros(nda.shape(x2), dtype=dtypes.int64)

Expand Down Expand Up @@ -568,13 +568,45 @@ def cumulative_sum(
axis = 0
else:
raise ValueError("axis must be specified for multi-dimensional arrays")

if dtype is None:
if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
if ndx.iinfo(x.dtype).bits < 64:
out = x.astype(dtypes.int64)
else:
raise ndx.UnsupportedOperationError(
f"Cannot perform `cumulative_sum` using {x.dtype}"
)
else:
out = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.int64))
else:
out = out.astype(dtype)

out = from_corearray(
opx.cumsum(
x._core(), axis=opx.const(axis), exclusive=int(not include_initial)
out._core(),
axis=opx.const(axis),
exclusive=0,
)
)
if dtype is not None:
out = out.astype(dtype)

if isinstance(x.dtype, dtypes.Unsigned):
out = out.astype(ndx.uint64)
elif isinstance(x.dtype, dtypes.NullableUnsigned):
out = out.astype(ndx.nuint64)

# Exclude axis and create zeros of that shape
if include_initial:
out_shape = nda.shape(out)
out_shape[axis] = 1
out = ndx.concat(
[
ndx.zeros(out_shape, dtype=out.dtype),
out,
],
axis=axis,
)

return out

@validate_core
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def broadcast_to(x, shape):
# TODO: onnxruntime doesn't work for 2 empty arrays of integer type
# TODO: what is the appropriate strategy to dispatch? (iterate over the inputs and keep trying is reasonable but it can
# change the outcome based on order if poorly implemented)
def concat(arrays, axis=None):
def concat(arrays, /, *, axis: int | None = 0):
if axis is None:
arrays = [reshape(x, [-1]) for x in arrays]
axis = 0
Expand Down
27 changes: 17 additions & 10 deletions ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,24 +707,31 @@ def reshape_like(x: _CoreArray, y: _CoreArray) -> _CoreArray:
def static_map(
input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None
) -> _CoreArray:
keys = np.array(tuple(mapping.keys()))
if keys.dtype == np.int32:
keys = keys.astype(np.int64)
values = np.array(tuple(mapping.values()))
keys = np.asarray(tuple(mapping.keys()))
values = np.asarray(tuple(mapping.values()))
if isinstance(input.dtype, dtypes.Integral):
input = input.astype(dtypes.int64)
# Should only be a relevant path in Windows NumPy 1.x
if keys.dtype != np.dtype("int64") and keys.dtype.kind == "i":
keys = keys.astype(input.dtype.to_numpy_dtype())
elif isinstance(input.dtype, dtypes.Floating):
input = input.astype(dtypes.float64)

value_dtype = values.dtype
if default is None:
if value_dtype.kind == "U" or (
value_dtype.kind == "O" and all(isinstance(x, str) for x in values.flat)
):
default_tensor = np.array(["MISSING"])
default_tensor = np.asarray(["MISSING"])
else:
default_tensor = np.array([0], dtype=value_dtype)
default_tensor = np.asarray([0], dtype=value_dtype)
elif value_dtype.kind == "U":
default_tensor = np.asarray([default], dtype=np.str_)
else:
default_tensor = np.array([default], dtype=value_dtype)

if keys.dtype == np.float64 and isinstance(input.dtype, dtypes.Integral):
default_tensor = np.asarray([default], dtype=value_dtype)
if keys.dtype.kind == "f" and isinstance(input.dtype, dtypes.Integral):
input = cast(input, dtypes.from_numpy_dtype(keys.dtype))
elif keys.dtype == np.int64 and isinstance(input.dtype, dtypes.Floating):
elif keys.dtype.kind == "i" and isinstance(input.dtype, dtypes.Floating):
keys = keys.astype(input.dtype.to_numpy_dtype())
return _CoreArray(
ml.label_encoder(
Expand Down
4 changes: 3 additions & 1 deletion ndonnx/additional/_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def static_map(
A new Array with the values mapped according to the mapping.
"""
if not isinstance(x.dtype, ndx.CoreType):
raise TypeError("static_map accepts only non-nullable arrays")
raise ndx.UnsupportedOperationError(
"'static_map' accepts only non-nullable arrays"
)
data = opx.static_map(x._core(), mapping, default)
return ndx.Array._from_fields(data.dtype, data=data)

Expand Down
100 changes: 94 additions & 6 deletions tests/test_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ def test_searchsorted_raises():


@pytest.mark.skipif(
sys.platform.startswith("win"),
sys.platform.startswith("win") and np.__version__ < "2",
reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.",
)
def test_static_map():
def test_static_map_lazy():
a = ndx.array(shape=(3,), dtype=ndx.int64)
b = nda.static_map(a, {1: 2, 2: 3})

model = ndx.build({"a": a}, {"b": b})
assert_array_equal([0, 2, 3], run(model, {"a": np.array([0, 1, 2])})["b"])

Expand All @@ -87,9 +86,72 @@ def test_static_map():
run(model, {"a": np.array([0.0, 2.0, 3.0, np.nan])})["b"],
)

a = ndx.asarray(["hello", "world", "!"])
b = nda.static_map(a, {"hello": "hi", "world": "earth"})
np.testing.assert_equal(["hi", "earth", "MISSING"], b.to_numpy())

@pytest.mark.skipif(
sys.platform.startswith("win") and np.__version__ < "2",
reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.",
)
@pytest.mark.parametrize(
"x, mapping, default, expected",
[
(
ndx.asarray(["hello", "world", "!"]),
{"hello": "hi", "world": "earth"},
None,
["hi", "earth", "MISSING"],
),
(
ndx.asarray(["hello", "world", "!"]),
{"hello": "hi", "world": "earth"},
"DIFFERENT",
["hi", "earth", "DIFFERENT"],
),
(ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, None, [-1, -2, 0]),
(ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, 42, [-1, -2, 42]),
(
ndx.asarray([[0], [1], [2]], dtype=ndx.int64),
{0: -1, 1: -2},
42,
[[-1], [-2], [42]],
),
(
ndx.asarray([[0], [1], [2]], dtype=ndx.int32),
{0: -1, 1: -2},
42,
[[-1], [-2], [42]],
),
(
ndx.asarray([[0], [1], [2]], dtype=ndx.int8),
{0: -1, 1: -2},
42,
[[-1], [-2], [42]],
),
(
ndx.asarray([[0], [1], [2]], dtype=ndx.uint8),
{0: -1, 1: -2},
42,
[[-1], [-2], [42]],
),
(
ndx.asarray([[0], [1], [np.nan]], dtype=ndx.float32),
{0: -1, 1: -2, np.nan: 3.142},
42,
[[-1], [-2], [3.142]],
),
],
)
def test_static_map(x, mapping, default, expected):
actual = nda.static_map(x, mapping, default=default)
assert_array_equal(actual.to_numpy(), expected)


def test_static_map_unimplemented_for_nullable():
a = ndx.asarray([1, 2, 3], dtype=ndx.int64)
m = ndx.asarray([True, False, True])
a = nda.make_nullable(a, m)

with pytest.raises(ndx.UnsupportedOperationError):
nda.static_map(a, {1: 2, 2: 3})


@pytest.mark.skipif(
Expand Down Expand Up @@ -117,3 +179,29 @@ def test_isin():

a = ndx.asarray(["hello", "world"])
assert_array_equal([True, False], nda.isin(a, ["hello"]).to_numpy())


@pytest.mark.parametrize(
"dtype",
[
ndx.int64,
ndx.utf8,
ndx.bool,
],
)
@pytest.mark.parametrize(
"mask",
[
[True, False, True],
True,
False,
[True],
],
)
def test_make_nullable(dtype, mask):
a = ndx.asarray([1, 2, 3], dtype=dtype)
m = ndx.asarray(mask)

result = nda.make_nullable(a, m)
expected = np.ma.masked_array([1, 2, 3], mask, dtype.to_numpy_dtype())
assert_array_equal(result.to_numpy(), expected)
40 changes: 40 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,3 +926,43 @@ def test_lazy_array_shape(x, expected_shape):
def test_dynamic_reshape_has_no_static_shape(x, shape):
with pytest.raises(ValueError, match="Could not determine static shape"):
ndx.reshape(x, shape).shape


@pytest.mark.skipif(
not np.__version__.startswith("2"), reason="NumPy >= 2 used for test assertions"
)
@pytest.mark.parametrize("include_initial", [True, False])
@pytest.mark.parametrize(
"dtype",
[ndx.int32, ndx.int64, ndx.float32, ndx.float64, ndx.uint8, ndx.uint16, ndx.uint32],
)
@pytest.mark.parametrize(
"array, axis",
[
([1, 2, 3], None),
([1, 2, 3], 0),
([[1, 2], [3, 4]], 0),
([[1, 2], [3, 4]], 1),
([[1, 2, 50], [3, 4, 5]], 1),
([[[[1]]], [[[3]]]], 0),
([[[[1]]], [[[3]]]], 1),
],
)
def test_cumulative_sum(array, axis, include_initial, dtype):
a = ndx.asarray(array, dtype=dtype)
assert_array_equal(
ndx.cumulative_sum(a, include_initial=include_initial, axis=axis).to_numpy(),
np.cumulative_sum(
np.asarray(array, a.dtype.to_numpy_dtype()),
include_initial=include_initial,
axis=axis,
),
)


def test_no_unsafe_cumulative_sum_cast():
with pytest.raises(
ndx.UnsupportedOperationError, match="Cannot perform `cumulative_sum`"
):
a = ndx.asarray([1, 2, 3], ndx.uint64)
ndx.cumulative_sum(a)

0 comments on commit b86e765

Please sign in to comment.