Skip to content

Commit

Permalink
Fix make_nullable
Browse files Browse the repository at this point in the history
  • Loading branch information
MatejUrbanQC committed Aug 27, 2024
1 parent 84b8f4d commit 7317a8f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, nda.shape(x)),
null=ndx.broadcast_to(null, nda.shape(x)),
)


Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, nda.shape(x)),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, nda.shape(x)),
null=ndx.broadcast_to(null, nda.shape(x)),
)


Expand Down
26 changes: 26 additions & 0 deletions tests/test_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,29 @@ def test_isin():

a = ndx.asarray(["hello", "world"])
assert_array_equal([True, False], nda.isin(a, ["hello"]).to_numpy())


@pytest.mark.parametrize(
"dtype",
[
ndx.int64,
ndx.utf8,
ndx.bool,
],
)
@pytest.mark.parametrize(
"mask",
[
[True, False, True],
True,
False,
[True],
],
)
def test_make_nullable(dtype, mask):
a = ndx.asarray([1, 2, 3], dtype=dtype)
m = ndx.asarray(mask)

result = nda.make_nullable(a, m)
expected = np.ma.masked_array([1, 2, 3], mask, dtype.to_numpy_dtype())
assert_array_equal(result.to_numpy(), expected)

0 comments on commit 7317a8f

Please sign in to comment.