Skip to content

Commit

Permalink
Fix ones_like (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Oct 1, 2024
1 parent fd81960 commit d185f2c
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
Changelog
=========

0.9.1 (unreleased)
0.9.1 (2024-10-01)
------------------

**Bug fix**

- Fixed a bug in the construction of nullable arrays using :func:`ndonnx.asarray` where the shape of the null field would not match the values field if the provided `np.ma.MaskedArray`'s mask was scalar.
- Fixed a bug in the implementation of :func:`ndonnx.ones_like` where the static shape was being used to construct the array of ones.


0.9.0 (2024-08-30)
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def zeros_like(self, x, dtype=None, device=None):
return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device)

def ones_like(self, x, dtype=None, device=None):
return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device)
return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device)

def make_array(
self,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,16 @@ def test_creation_full():
)


def test_creation_ones_like():
a = ndx.array(shape=("N",), dtype=ndx.int64)
b = ndx.ones_like(a)
model = ndx.build({"a": a}, {"b": b})
assert_array_equal(
run(model, {"a": np.array([1, 2, 3], dtype=np.int64)})["b"],
np.ones(3, dtype=np.int64),
)


@pytest.mark.parametrize(
"args, expected",
[
Expand Down
1 change: 1 addition & 0 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,4 @@ array_api_tests/test_special_cases.py::test_unary[acos(x_i < -1) -> NaN]
array_api_tests/test_special_cases.py::test_unary[sqrt(x_i < 0) -> NaN]
array_api_tests/test_special_cases.py::test_unary[asin(x_i < -1) -> NaN]
array_api_tests/test_special_cases.py::test_unary[atanh(x_i > 1) -> NaN]
array_api_tests/test_special_cases.py::test_unary[atanh(x_i < -1) -> NaN]

0 comments on commit d185f2c

Please sign in to comment.