NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Indentation of unchanged context lines is inferred;
verify it against the target tree before applying.

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index e12e1a2..ec24c23 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -7,6 +7,13 @@
 Changelog
 =========
 
+0.9.1 (unreleased)
+------------------
+
+**Bug fix**
+
+- Fixed a bug in the construction of nullable arrays using :func:`ndonnx.asarray` where the shape of the null field would not match the values field if the provided ``np.ma.MaskedArray``'s mask was scalar.
+
 0.9.0 (2024-08-30)
 ------------------
 
diff --git a/ndonnx/_core/_nullableimpl.py b/ndonnx/_core/_nullableimpl.py
index 835f4fd..c03e3e8 100644
--- a/ndonnx/_core/_nullableimpl.py
+++ b/ndonnx/_core/_nullableimpl.py
@@ -4,10 +4,12 @@
 
 from typing import TYPE_CHECKING, Union
 
+import numpy as np
+
 import ndonnx as ndx
 
 from ._shapeimpl import UniformShapeOperations
-from ._utils import validate_core
+from ._utils import assemble_output_recurse, validate_core
 
 if TYPE_CHECKING:
     from ndonnx._array import Array
@@ -35,3 +37,43 @@ def where(self, condition, x, y):
             x = ndx.astype(x, target_dtype)
             y = ndx.astype(y, target_dtype)
         return super().where(condition, x, y)
+
+    def make_array(
+        self,
+        shape: tuple[int | None | str, ...],
+        dtype: Dtype,
+        eager_value: np.ndarray | None = None,
+    ) -> Array:
+        """Create a nullable array from ``eager_value`` or from ``shape`` alone.
+
+        Returns ``NotImplemented`` if ``dtype`` is not an ``ndx.Nullable``.
+        When ``eager_value`` is given, it is split via ``dtype._parse_input``
+        and the ``null`` field is broadcast to the shape of the ``values``
+        field. Otherwise both fields are created with ``ndx.array`` using
+        ``shape``.
+        """
+        if not isinstance(dtype, ndx.Nullable):
+            return NotImplemented
+
+        if eager_value is not None:
+            eager_values = dtype._parse_input(eager_value)
+            values_np = assemble_output_recurse(dtype.values, eager_values["values"])
+            values = ndx.asarray(values_np, dtype=dtype.values)
+            # A scalar mask would otherwise leave ``null`` with a different
+            # shape than ``values``, so expand it to the values' shape.
+            null = ndx.asarray(
+                np.broadcast_to(
+                    assemble_output_recurse(dtype.null, eager_values["null"]),
+                    values_np.shape,
+                ),
+                dtype=dtype.null,
+            )
+        else:
+            values = ndx.array(shape=shape, dtype=dtype.values)
+            null = ndx.array(shape=shape, dtype=dtype.null)
+
+        return ndx.Array._from_fields(
+            dtype,
+            values=values,
+            null=null,
+        )
diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py
index 67da73e..4e90232 100644
--- a/ndonnx/_core/_shapeimpl.py
+++ b/ndonnx/_core/_shapeimpl.py
@@ -14,7 +14,7 @@
 import ndonnx.additional as nda
 
 from ._interface import OperationsBlock
-from ._utils import from_corearray
+from ._utils import assemble_output_recurse, from_corearray
 
 if TYPE_CHECKING:
     from ndonnx._array import Array, IndexType
@@ -270,7 +270,7 @@ def make_array(
             if eager_values is None:
                 field_value = None
             else:
-                field_value = _assemble_output_recurse(field_dtype, eager_values[name])
+                field_value = assemble_output_recurse(field_dtype, eager_values[name])
             fields[name] = field_dtype._ops.make_array(
                 shape,
                 field_dtype,
@@ -293,14 +293,3 @@ def getitem(self, x: Array, index: IndexType) -> Array:
             index = index._core()
 
         return x._transmute(lambda corearray: corearray[index])
-
-
-def _assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray:
-    if isinstance(dtype, dtypes.CoreType):
-        return dtype._assemble_output(values)
-    else:
-        fields = {
-            name: _assemble_output_recurse(field_dtype, values[name])
-            for name, field_dtype in dtype._fields().items()
-        }
-        return dtype._assemble_output(fields)
diff --git a/ndonnx/_core/_utils.py b/ndonnx/_core/_utils.py
index 29e1b74..a6e1e85 100644
--- a/ndonnx/_core/_utils.py
+++ b/ndonnx/_core/_utils.py
@@ -1,16 +1,20 @@
 # Copyright (c) QuantCo 2023-2024
 # SPDX-License-Identifier: BSD-3-Clause
+from __future__ import annotations
 
 import functools
 import itertools
 from typing import TYPE_CHECKING
 
+import numpy as np
+
 import ndonnx as ndx
 import ndonnx._data_types as dtypes
 from ndonnx._utility import promote
 
 if TYPE_CHECKING:
     from ndonnx import Array
+    from ndonnx._data_types import Dtype
 
 
 def binary_op(
@@ -209,3 +213,20 @@ def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
 
     return wrapper
+
+
+def assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray:
+    """Recursively assemble the NumPy output for ``dtype`` from ``values``.
+
+    Core types delegate directly to ``dtype._assemble_output``. For any other
+    dtype, each field listed by ``dtype._fields()`` is assembled first and the
+    resulting mapping is passed to ``dtype._assemble_output``.
+    """
+    if isinstance(dtype, dtypes.CoreType):
+        return dtype._assemble_output(values)
+    else:
+        fields = {
+            name: assemble_output_recurse(field_dtype, values[name])
+            for name, field_dtype in dtype._fields().items()
+        }
+        return dtype._assemble_output(fields)
diff --git a/tests/test_masked.py b/tests/test_masked.py
index 814e794..5eccb0a 100644
--- a/tests/test_masked.py
+++ b/tests/test_masked.py
@@ -254,3 +254,24 @@ def test_broadcasting(arrays):
     # NumPy simply drops the masked array.
     # We do not want to do the same quite intentionally.
     np.testing.assert_equal(a.to_numpy(), e)
+
+
+@pytest.mark.parametrize(
+    "np_array",
+    [
+        np.ma.masked_array([1, 2, 3], mask=[0, 0, 1], dtype=np.int64),
+        np.ma.masked_array([1, 2, 3], mask=[0, 0, 1], dtype=np.float64),
+        np.ma.masked_array([1, 2]),
+        np.ma.masked_array(["a", "b"], mask=True),
+        np.ma.masked_array([1, 2, 3], mask=[[[0]]]),
+        np.ma.masked_array([[1, 2, 3]], mask=[True, False, True]),
+        np.ma.masked_array([1.0, 2.0, 3.0], mask=[0, 0, 1]),
+    ],
+)
+def test_initialization(np_array):
+    """Arrays built from masked input round-trip with matching field shapes."""
+    actual = ndx.asarray(np_array)
+    values = actual.values.to_numpy()
+    null = actual.null.to_numpy()
+    assert_array_equal(actual.to_numpy(), np_array)
+    assert values.shape == null.shape