From 7317a8fbc4233b8a879eddbf4ad7b8f09fe3404f Mon Sep 17 00:00:00 2001 From: Matej Urban Date: Tue, 27 Aug 2024 11:56:15 +0100 Subject: [PATCH] Fix make_nullable --- ndonnx/_core/_boolimpl.py | 2 +- ndonnx/_core/_numericimpl.py | 2 +- ndonnx/_core/_stringimpl.py | 2 +- tests/test_additional.py | 26 ++++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index 806c2d2..bf78657 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -170,7 +170,7 @@ def make_nullable(self, x, null): return ndx.Array._from_fields( dtypes.into_nullable(x.dtype), values=x.copy(), - null=ndx.reshape(null, nda.shape(x)), + null=ndx.broadcast_to(null, nda.shape(x)), ) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index f7dbf25..856cbf2 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -813,7 +813,7 @@ def make_nullable(self, x, null): return ndx.Array._from_fields( dtypes.into_nullable(x.dtype), values=x.copy(), - null=ndx.reshape(null, nda.shape(x)), + null=ndx.broadcast_to(null, nda.shape(x)), ) @validate_core diff --git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index ddb4f6e..e71cd83 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -77,7 +77,7 @@ def make_nullable(self, x, null): return ndx.Array._from_fields( dtypes.into_nullable(x.dtype), values=x.copy(), - null=ndx.reshape(null, nda.shape(x)), + null=ndx.broadcast_to(null, nda.shape(x)), ) diff --git a/tests/test_additional.py b/tests/test_additional.py index d314dbc..ee23b1c 100644 --- a/tests/test_additional.py +++ b/tests/test_additional.py @@ -117,3 +117,29 @@ def test_isin(): a = ndx.asarray(["hello", "world"]) assert_array_equal([True, False], nda.isin(a, ["hello"]).to_numpy()) + + +@pytest.mark.parametrize( + "dtype", + [ + ndx.int64, + ndx.utf8, + ndx.bool, + ], +) +@pytest.mark.parametrize( + "mask", + [ + [True, False, True], + True, + False, + [True], + ], +) +def test_make_nullable(dtype, mask): + a = ndx.asarray([1, 2, 3], dtype=dtype) + m = ndx.asarray(mask) + + result = nda.make_nullable(a, m) + expected = np.ma.masked_array([1, 2, 3], mask, dtype.to_numpy_dtype()) + assert_array_equal(result.to_numpy(), expected)