Skip to content

Commit

Permalink
Fix parsing of scalar string arrays (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatejUrbanQC authored Aug 20, 2024
1 parent 2cf6994 commit e9148c3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
Changelog
=========

0.7.1 (unreleased)
------------------

**Bug fixes**

- Fixes parsing of NumPy arrays of dtype ``object`` (consisting of strings) as ``utf8``. Previously this worked correctly only for one-dimensional arrays; zero-dimensional (scalar) and multi-dimensional arrays are now handled as well.


0.7.0 (2024-08-12)
------------------

Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_data_types/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def to_numpy_dtype() -> np.dtype:

def _parse_input(self, data: np.ndarray) -> dict[str, np.ndarray]:
if data.dtype.kind == "U" or (
data.dtype.kind == "O" and all(isinstance(x, str) for x in data)
data.dtype.kind == "O" and all(isinstance(x, str) for x in data.flat)
):
return {"data": data.astype(np.str_)}
else:
Expand Down
40 changes: 34 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,44 @@ def test_null_promotion():
np.testing.assert_equal(np.add(inputs["a"], inputs["b"]), actual["c"])


def test_asarray():
a = ndx.asarray([1, 2, 3], dtype=ndx.int64)
assert a.dtype == ndx.int64
@pytest.mark.parametrize(
"array, dtype, expected_dtype",
[
([1, 2, 3], ndx.int64, ndx.int64),
(np.array([1, 2, 3], np.int64), None, ndx.int64),
(1, ndx.int64, ndx.int64),
(1, ndx.float64, ndx.float64),
(["a", "b"], ndx.utf8, ndx.utf8),
(np.array(["a", "b"]), None, ndx.utf8),
(np.array(["a", "b"], object), None, ndx.utf8),
("a", ndx.utf8, ndx.utf8),
(np.array("a"), None, ndx.utf8),
(np.array("a", object), None, ndx.utf8),
([["a"]], None, ndx.utf8),
(np.array([["a"]]), None, ndx.utf8),
(np.array([["a"]], object), None, ndx.utf8),
],
)
def test_asarray(array, dtype, expected_dtype):
a = ndx.asarray(array, dtype=dtype)
assert a.dtype == expected_dtype
np.testing.assert_array_equal(
np.array([1, 2, 3], np.int64), a.to_numpy(), strict=True
np.array(array, expected_dtype.to_numpy_dtype()), a.to_numpy(), strict=True
)


def test_asarray_masked():
np_arr = np.ma.masked_array([1, 2, 3], mask=[0, 0, 1])
@pytest.mark.parametrize(
"np_arr",
[
np.ma.masked_array([1, 2, 3], mask=[0, 0, 1]),
np.ma.masked_array(1, mask=0),
np.ma.masked_array(["a", "b"], mask=[1, 0]),
np.ma.masked_array("a", mask=0),
np.ma.masked_array(["a", "b"], mask=[1, 0], dtype=object),
np.ma.masked_array("a", mask=0, dtype=object),
],
)
def test_asarray_masked(np_arr):
ndx_arr = ndx.asarray(np_arr)
assert isinstance(ndx_arr, ndx.Array)
assert isinstance(ndx_arr.to_numpy(), np.ma.MaskedArray)
Expand Down

0 comments on commit e9148c3

Please sign in to comment.