diff --git a/ndonnx/_core/_impl.py b/ndonnx/_core/_impl.py index f64fb03..a114f72 100644 --- a/ndonnx/_core/_impl.py +++ b/ndonnx/_core/_impl.py @@ -113,19 +113,15 @@ def cosh(self, x): def divide(self, x, y): x, y = promote(x, y) + bits = ( + ndx.finfo(x.dtype).bits + if x.dtype in (dtypes.NullableFloating, dtypes.Floating) + else ndx.iinfo(x.dtype).bits + ) + via_dtype = ( dtypes.float64 - if x.dtype - in ( - dtypes.nint64, - dtypes.int64, - dtypes.nfloat64, - dtypes.float64, - dtypes.uint64, - dtypes.uint32, - dtypes.nuint64, - dtypes.nuint32, - ) + if bits > 32 or x.dtype in (dtypes.nuint32, dtypes.uint32) else dtypes.float32 ) return _variadic_op([x, y], opx.div, via_dtype=via_dtype, cast_return=False)