Skip to content

Commit

Permalink
Fix initialization with mask (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Sep 24, 2024
1 parent 34d48f9 commit def1e52
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 14 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
Changelog
=========

0.9.1 (unreleased)
------------------

**Bug fix**

- Fixed a bug in the construction of nullable arrays using :func:`ndonnx.asarray` where the shape of the null field would not match the values field if the provided `np.ma.MaskedArray`'s mask was scalar.


0.9.0 (2024-08-30)
------------------
Expand Down
34 changes: 33 additions & 1 deletion ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

from typing import TYPE_CHECKING, Union

import numpy as np

import ndonnx as ndx

from ._shapeimpl import UniformShapeOperations
from ._utils import validate_core
from ._utils import assemble_output_recurse, validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
Expand Down Expand Up @@ -35,3 +37,33 @@ def where(self, condition, x, y):
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)

def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> Array:
if not isinstance(dtype, ndx.Nullable):
return NotImplemented

if eager_value is not None:
eager_values = dtype._parse_input(eager_value)
values_np = assemble_output_recurse(dtype.values, eager_values["values"])
values = ndx.asarray(values_np, dtype=dtype.values)
null = ndx.asarray(
np.broadcast_to(
assemble_output_recurse(dtype.null, eager_values["null"]),
values_np.shape,
),
dtype=dtype.null,
)
else:
values = ndx.array(shape=shape, dtype=dtype.values)
null = ndx.array(shape=shape, dtype=dtype.null)

return ndx.Array._from_fields(
dtype,
values=values,
null=null,
)
15 changes: 2 additions & 13 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import ndonnx.additional as nda

from ._interface import OperationsBlock
from ._utils import from_corearray
from ._utils import assemble_output_recurse, from_corearray

if TYPE_CHECKING:
from ndonnx._array import Array, IndexType
Expand Down Expand Up @@ -270,7 +270,7 @@ def make_array(
if eager_values is None:
field_value = None
else:
field_value = _assemble_output_recurse(field_dtype, eager_values[name])
field_value = assemble_output_recurse(field_dtype, eager_values[name])
fields[name] = field_dtype._ops.make_array(
shape,
field_dtype,
Expand All @@ -293,14 +293,3 @@ def getitem(self, x: Array, index: IndexType) -> Array:
index = index._core()

return x._transmute(lambda corearray: corearray[index])


def _assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray:
if isinstance(dtype, dtypes.CoreType):
return dtype._assemble_output(values)
else:
fields = {
name: _assemble_output_recurse(field_dtype, values[name])
for name, field_dtype in dtype._fields().items()
}
return dtype._assemble_output(fields)
15 changes: 15 additions & 0 deletions ndonnx/_core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import functools
import itertools
from typing import TYPE_CHECKING

import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx._utility import promote

if TYPE_CHECKING:
from ndonnx import Array
from ndonnx._data_types import Dtype


def binary_op(
Expand Down Expand Up @@ -209,3 +213,14 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray:
if isinstance(dtype, dtypes.CoreType):
return dtype._assemble_output(values)
else:
fields = {
name: assemble_output_recurse(field_dtype, values[name])
for name, field_dtype in dtype._fields().items()
}
return dtype._assemble_output(fields)
20 changes: 20 additions & 0 deletions tests/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,23 @@ def test_broadcasting(arrays):
# NumPy simply drops the masked array.
# We do not want to do the same quite intentionally.
np.testing.assert_equal(a.to_numpy(), e)


@pytest.mark.parametrize(
"np_array",
[
np.ma.masked_array([1, 2, 3], mask=[0, 0, 1], dtype=np.int64),
np.ma.masked_array([1, 2, 3], mask=[0, 0, 1], dtype=np.float64),
np.ma.masked_array([1, 2]),
np.ma.masked_array(["a", "b"], mask=True),
np.ma.masked_array([1, 2, 3], mask=[[[0]]]),
np.ma.masked_array([[1, 2, 3]], mask=[True, False, True]),
np.ma.masked_array([1.0, 2.0, 3.0], mask=[0, 0, 1]),
],
)
def test_initialization(np_array):
actual = ndx.asarray(np_array)
values = actual.values.to_numpy()
null = actual.null.to_numpy()
assert_array_equal(actual.to_numpy(), np_array)
assert values.shape == null.shape

0 comments on commit def1e52

Please sign in to comment.