Skip to content

Commit

Permalink
Fix argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 24, 2024
1 parent ed532b5 commit 4e8ac5f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
16 changes: 16 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,3 +992,19 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=ndx.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=ndx.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=ndx.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=ndx.float32)),
],
)
def test_argmaxmin(func, x):
np_result = func(x)
ndx_result = getattr(ndx, func.__name__)(ndx.asarray(x)).to_numpy()
breakpoint()
np.testing.assert_equal(np_result, ndx_result)
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where
Expand Down

0 comments on commit 4e8ac5f

Please sign in to comment.