Skip to content

Commit

Permalink
Fix promotion precision loss
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 1, 2024
1 parent 2220369 commit b25cf3e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
11 changes: 10 additions & 1 deletion ndonnx/_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,16 @@ def _promote_with_none(*args: Array | npt.ArrayLike) -> list[Array | None]:
signed_integer = True
elif isinstance(arg, np.ndarray):
arr_or_none.append(asarray(arg))
elif isinstance(arg, (int, float, str, np.generic)):
elif isinstance(arg, float):
np_dtype = (
np.dtype("float32")
if np.float32(arg) == np.float64(arg)
else np.dtype("float64")
)
arr = asarray(arg, dtypes.from_numpy_dtype(np_dtype)) # type: ignore
arr_or_none.append(arr)
scalars.append(arr)
elif isinstance(arg, (int, str, np.generic)):
np_dtype = np.min_scalar_type(arg)
if np_dtype == np.dtype("float16"):
np_dtype = np.dtype("float32")
Expand Down
8 changes: 8 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,11 @@ def test_array_creation_with_invalid_fields():
def test_promote_nullable():
with pytest.warns(DeprecationWarning):
assert ndx.promote_nullable(np.int64) == ndx.nint64


@pytest.mark.parametrize("val", [1, 1.0, 1.5, 0.123456789, "a"])
def test_scalar_promote(val):
x = ndx.asarray([val] * 4)
actual = (x + val).to_numpy()
expected = x.to_numpy() + val
np.testing.assert_equal(actual, expected, strict=True)

0 comments on commit b25cf3e

Please sign in to comment.