Skip to content

Commit

Permalink
Make static_map more general
Browse files: browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 28, 2024
1 parent 6da53f3 commit 0b69318
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 17 deletions.
27 changes: 17 additions & 10 deletions ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,24 +707,31 @@ def reshape_like(x: _CoreArray, y: _CoreArray) -> _CoreArray:
def static_map(
input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None
) -> _CoreArray:
keys = np.array(tuple(mapping.keys()))
if keys.dtype == np.int32:
keys = keys.astype(np.int64)
values = np.array(tuple(mapping.values()))
keys = np.asarray(tuple(mapping.keys()))
values = np.asarray(tuple(mapping.values()))
if isinstance(input.dtype, dtypes.Integral):
input = input.astype(dtypes.int64)
# Should only be a relevant path in Windows NumPy 1.x
if keys.dtype != np.dtype("int64") and keys.dtype.kind == "i":
keys = keys.astype(input.dtype.to_numpy_dtype())
elif isinstance(input.dtype, dtypes.Floating):
input = input.astype(dtypes.float64)

value_dtype = values.dtype
if default is None:
if value_dtype.kind == "U" or (
value_dtype.kind == "O" and all(isinstance(x, str) for x in values.flat)
):
default_tensor = np.array(["MISSING"])
default_tensor = np.asarray(["MISSING"])
else:
default_tensor = np.array([0], dtype=value_dtype)
default_tensor = np.asarray([0], dtype=value_dtype)
elif value_dtype.kind == "U":
default_tensor = np.asarray([default], dtype=np.str_)
else:
default_tensor = np.array([default], dtype=value_dtype)

if keys.dtype == np.float64 and isinstance(input.dtype, dtypes.Integral):
default_tensor = np.asarray([default], dtype=value_dtype)
if keys.dtype.kind == "f" and isinstance(input.dtype, dtypes.Integral):
input = cast(input, dtypes.from_numpy_dtype(keys.dtype))
elif keys.dtype == np.int64 and isinstance(input.dtype, dtypes.Floating):
elif keys.dtype.kind == "i" and isinstance(input.dtype, dtypes.Floating):
keys = keys.astype(input.dtype.to_numpy_dtype())
return _CoreArray(
ml.label_encoder(
Expand Down
4 changes: 3 additions & 1 deletion ndonnx/additional/_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def static_map(
A new Array with the values mapped according to the mapping.
"""
if not isinstance(x.dtype, ndx.CoreType):
raise TypeError("static_map accepts only non-nullable arrays")
raise ndx.UnsupportedOperationError(
"'static_map' accepts only non-nullable arrays"
)
data = opx.static_map(x._core(), mapping, default)
return ndx.Array._from_fields(data.dtype, data=data)

Expand Down
74 changes: 68 additions & 6 deletions tests/test_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ def test_searchsorted_raises():


@pytest.mark.skipif(
sys.platform.startswith("win"),
sys.platform.startswith("win") and np.__version__ < "2",
reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.",
)
def test_static_map():
def test_static_map_lazy():
a = ndx.array(shape=(3,), dtype=ndx.int64)
b = nda.static_map(a, {1: 2, 2: 3})

model = ndx.build({"a": a}, {"b": b})
assert_array_equal([0, 2, 3], run(model, {"a": np.array([0, 1, 2])})["b"])

Expand All @@ -87,9 +86,72 @@ def test_static_map():
run(model, {"a": np.array([0.0, 2.0, 3.0, np.nan])})["b"],
)

a = ndx.asarray(["hello", "world", "!"])
b = nda.static_map(a, {"hello": "hi", "world": "earth"})
np.testing.assert_equal(["hi", "earth", "MISSING"], b.to_numpy())

@pytest.mark.skipif(
    # Compare the parsed major version rather than the raw string:
    # `np.__version__ < "2"` is a lexicographic comparison and would
    # misorder multi-digit majors (e.g. "10.0" < "2" is True).
    sys.platform.startswith("win") and int(np.__version__.split(".")[0]) < 2,
    reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.",
)
@pytest.mark.parametrize(
    "x, mapping, default, expected",
    [
        # String keys/values without a default: the unmapped element is
        # expected to come back as "MISSING".
        (
            ndx.asarray(["hello", "world", "!"]),
            {"hello": "hi", "world": "earth"},
            None,
            ["hi", "earth", "MISSING"],
        ),
        # String keys/values with an explicit default.
        (
            ndx.asarray(["hello", "world", "!"]),
            {"hello": "hi", "world": "earth"},
            "DIFFERENT",
            ["hi", "earth", "DIFFERENT"],
        ),
        # Integer keys/values: the unmapped element is expected to be 0 when
        # no default is given, otherwise the supplied default.
        (ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, None, [-1, -2, 0]),
        (ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, 42, [-1, -2, 42]),
        # Two-dimensional input across several integer input dtypes.
        (
            ndx.asarray([[0], [1], [2]], dtype=ndx.int64),
            {0: -1, 1: -2},
            42,
            [[-1], [-2], [42]],
        ),
        (
            ndx.asarray([[0], [1], [2]], dtype=ndx.int32),
            {0: -1, 1: -2},
            42,
            [[-1], [-2], [42]],
        ),
        (
            ndx.asarray([[0], [1], [2]], dtype=ndx.int8),
            {0: -1, 1: -2},
            42,
            [[-1], [-2], [42]],
        ),
        (
            ndx.asarray([[0], [1], [2]], dtype=ndx.uint8),
            {0: -1, 1: -2},
            42,
            [[-1], [-2], [42]],
        ),
        # float32 input with a NaN key in the mapping.
        (
            ndx.asarray([[0], [1], [np.nan]], dtype=ndx.float32),
            {0: -1, 1: -2, np.nan: 3.142},
            42,
            [[-1], [-2], [3.142]],
        ),
    ],
)
def test_static_map(x, mapping, default, expected):
    """Check that ``nda.static_map`` yields ``expected`` for each case.

    Each case supplies an input array ``x``, a ``mapping`` of keys to values,
    an optional ``default`` for unmapped elements, and the ``expected``
    result after converting the output with ``to_numpy()``.
    """
    actual = nda.static_map(x, mapping, default=default)
    assert_array_equal(actual.to_numpy(), expected)


def test_static_map_unimplemented_for_nullable():
    """``static_map`` on a ``make_nullable`` array raises UnsupportedOperationError."""
    values = ndx.asarray([1, 2, 3], dtype=ndx.int64)
    null_mask = ndx.asarray([True, False, True])
    nullable_input = nda.make_nullable(values, null_mask)

    with pytest.raises(ndx.UnsupportedOperationError):
        nda.static_map(nullable_input, {1: 2, 2: 3})


@pytest.mark.skipif(
Expand Down

0 comments on commit 0b69318

Please sign in to comment.