From 5316ecc76b839c080766757b215dae79e6158308 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 28 Aug 2024 10:21:26 +0200 Subject: [PATCH] Make static_map more general --- ndonnx/_opset_extensions.py | 27 +++++++----- ndonnx/additional/_additional.py | 4 +- tests/test_additional.py | 76 ++++++++++++++++++++++++++++---- 3 files changed, 87 insertions(+), 20 deletions(-) diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 2cb8660..1ac10df 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -707,24 +707,31 @@ def reshape_like(x: _CoreArray, y: _CoreArray) -> _CoreArray: def static_map( input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None ) -> _CoreArray: - keys = np.array(tuple(mapping.keys())) - if keys.dtype == np.int32: - keys = keys.astype(np.int64) - values = np.array(tuple(mapping.values())) + keys = np.asarray(tuple(mapping.keys())) + values = np.asarray(tuple(mapping.values())) + if isinstance(input.dtype, dtypes.Integral): + input = input.astype(dtypes.int64) + # Should only be a relevant path in Windows NumPy 1.x + if keys.dtype != np.dtype("int64") and keys.dtype.kind == "i": + keys = keys.astype(input.dtype.to_numpy_dtype()) + elif isinstance(input.dtype, dtypes.Floating): + input = input.astype(dtypes.float64) + value_dtype = values.dtype if default is None: if value_dtype.kind == "U" or ( value_dtype.kind == "O" and all(isinstance(x, str) for x in values.flat) ): - default_tensor = np.array(["MISSING"]) + default_tensor = np.asarray(["MISSING"]) else: - default_tensor = np.array([0], dtype=value_dtype) + default_tensor = np.asarray([0], dtype=value_dtype) + elif value_dtype.kind == "U": + default_tensor = np.asarray([default], dtype=np.str_) else: - default_tensor = np.array([default], dtype=value_dtype) - - if keys.dtype == np.float64 and isinstance(input.dtype, dtypes.Integral): + default_tensor = np.asarray([default], dtype=value_dtype) + if keys.dtype.kind == "f" and isinstance(input.dtype, dtypes.Integral): input = cast(input, dtypes.from_numpy_dtype(keys.dtype)) - elif keys.dtype == np.int64 and isinstance(input.dtype, dtypes.Floating): + elif keys.dtype.kind == "i" and isinstance(input.dtype, dtypes.Floating): keys = keys.astype(input.dtype.to_numpy_dtype()) return _CoreArray( ml.label_encoder( diff --git a/ndonnx/additional/_additional.py b/ndonnx/additional/_additional.py index 5ec6442..9c55764 100644 --- a/ndonnx/additional/_additional.py +++ b/ndonnx/additional/_additional.py @@ -88,7 +88,9 @@ def static_map( A new Array with the values mapped according to the mapping. """ if not isinstance(x.dtype, ndx.CoreType): - raise TypeError("static_map accepts only non-nullable arrays") + raise ndx.UnsupportedOperationError( + "'static_map' accepts only non-nullable arrays" + ) data = opx.static_map(x._core(), mapping, default) return ndx.Array._from_fields(data.dtype, data=data) diff --git a/tests/test_additional.py b/tests/test_additional.py index ee23b1c..0a5c62b 100644 --- a/tests/test_additional.py +++ b/tests/test_additional.py @@ -66,14 +66,9 @@ def test_searchsorted_raises(): ndx.searchsorted(a, b, side="middle") # type: ignore[arg-type] -@pytest.mark.skipif( - sys.platform.startswith("win"), - reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.", -) -def test_static_map(): +def test_static_map_lazy(): a = ndx.array(shape=(3,), dtype=ndx.int64) b = nda.static_map(a, {1: 2, 2: 3}) - model = ndx.build({"a": a}, {"b": b}) assert_array_equal([0, 2, 3], run(model, {"a": np.array([0, 1, 2])})["b"]) @@ -87,9 +82,72 @@ def test_static_map(): run(model, {"a": np.array([0.0, 2.0, 3.0, np.nan])})["b"], ) - a = ndx.asarray(["hello", "world", "!"]) - b = nda.static_map(a, {"hello": "hi", "world": "earth"}) - np.testing.assert_equal(["hi", "earth", "MISSING"], b.to_numpy()) + +@pytest.mark.skipif( + sys.platform.startswith("win") and np.__version__ < "2", + reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.", +) +@pytest.mark.parametrize( + "x, mapping, default, expected", + [ + ( + ndx.asarray(["hello", "world", "!"]), + {"hello": "hi", "world": "earth"}, + None, + ["hi", "earth", "MISSING"], + ), + ( + ndx.asarray(["hello", "world", "!"]), + {"hello": "hi", "world": "earth"}, + "DIFFERENT", + ["hi", "earth", "DIFFERENT"], + ), + (ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, None, [-1, -2, 0]), + (ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, 42, [-1, -2, 42]), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.int64), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.int32), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.int8), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.uint8), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [np.nan]], dtype=ndx.float32), + {0: -1, 1: -2, np.nan: 3.142}, + 42, + [[-1], [-2], [3.142]], + ), + ], +) +def test_static_map(x, mapping, default, expected): + actual = nda.static_map(x, mapping, default=default) + assert_array_equal(actual.to_numpy(), expected) + + +def test_static_map_unimplemented_for_nullable(): + a = ndx.asarray([1, 2, 3], dtype=ndx.int64) + m = ndx.asarray([True, False, True]) + a = nda.make_nullable(a, m) + + with pytest.raises(ndx.UnsupportedOperationError): + nda.static_map(a, {1: 2, 2: 3}) @pytest.mark.skipif(