From b25cf3e6b891b64d017b9594a8d3713c26338d20 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Thu, 1 Aug 2024 15:16:49 +0100 Subject: [PATCH] Fix promotion precision loss --- ndonnx/_utility.py | 11 ++++++++++- tests/test_core.py | 8 ++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/ndonnx/_utility.py b/ndonnx/_utility.py index c4eb758..8ddefae 100644 --- a/ndonnx/_utility.py +++ b/ndonnx/_utility.py @@ -45,7 +45,16 @@ def _promote_with_none(*args: Array | npt.ArrayLike) -> list[Array | None]: signed_integer = True elif isinstance(arg, np.ndarray): arr_or_none.append(asarray(arg)) - elif isinstance(arg, (int, float, str, np.generic)): + elif isinstance(arg, float): + np_dtype = ( + np.dtype("float32") + if np.float32(arg) == np.float64(arg) + else np.dtype("float64") + ) + arr = asarray(arg, dtypes.from_numpy_dtype(np_dtype)) # type: ignore + arr_or_none.append(arr) + scalars.append(arr) + elif isinstance(arg, (int, str, np.generic)): np_dtype = np.min_scalar_type(arg) if np_dtype == np.dtype("float16"): np_dtype = np.dtype("float32") diff --git a/tests/test_core.py b/tests/test_core.py index 8ff85e4..6eedbee 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -634,3 +634,11 @@ def test_array_creation_with_invalid_fields(): def test_promote_nullable(): with pytest.warns(DeprecationWarning): assert ndx.promote_nullable(np.int64) == ndx.nint64 + + +@pytest.mark.parametrize("val", [1, 1.0, 1.5, 0.123456789, "a"]) +def test_scalar_promote(val): + x = ndx.asarray([val] * 4) + actual = (x + val).to_numpy() + expected = x.to_numpy() + val + np.testing.assert_equal(actual, expected, strict=True)