From 8b25df25148d8f335b3f52857cffecd0470d7ce3 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Mon, 22 Jul 2024 13:03:50 +0100 Subject: [PATCH] Cast back to input dtype in where --- ndonnx/_core/_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndonnx/_core/_impl.py b/ndonnx/_core/_impl.py index 9b28b93..9c374f2 100644 --- a/ndonnx/_core/_impl.py +++ b/ndonnx/_core/_impl.py @@ -520,7 +520,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array) -> ndx.Array: a.astype(dtypes.int32)._core(), b.astype(dtypes.int32)._core(), ) - ) + ).astype(a.dtype) elif isinstance(a.dtype, dtypes.CoreType): return _from_corearray( opx.where(