From 5c5ce1d224bbe326974f4e0e3459ff971c168928 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 10 Jul 2024 15:37:26 +0200 Subject: [PATCH] Match Array API tests (run array-api tests) --- ndonnx/_core/_impl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ndonnx/_core/_impl.py b/ndonnx/_core/_impl.py index 1c0dc75..0a32b87 100644 --- a/ndonnx/_core/_impl.py +++ b/ndonnx/_core/_impl.py @@ -112,9 +112,13 @@ def cosh(self, x): return _unary_op(x, opx.cosh, dtypes.float32) def divide(self, x, y): - return _variadic_op( - [x, y], opx.div, via_dtype=dtypes.float64, cast_return=False + x, y = promote(x, y) + via_dtype = ( + dtypes.float64 + if x.dtype in (dtypes.nint64, dtypes.int64) + else dtypes.float32 ) + return _variadic_op([x, y], opx.div, via_dtype=via_dtype, cast_return=False) def equal(self, x, y) -> Array: x, y = promote(x, y)