diff --git a/ndonnx/_core/_impl.py b/ndonnx/_core/_impl.py index 9b28b93..9c374f2 100644 --- a/ndonnx/_core/_impl.py +++ b/ndonnx/_core/_impl.py @@ -520,7 +520,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array) -> ndx.Array: a.astype(dtypes.int32)._core(), b.astype(dtypes.int32)._core(), ) - ) + ).astype(a.dtype) elif isinstance(a.dtype, dtypes.CoreType): return _from_corearray( opx.where(