diff --git a/ndonnx/_core/_impl.py b/ndonnx/_core/_impl.py index 8681ea4..60c19fd 100644 --- a/ndonnx/_core/_impl.py +++ b/ndonnx/_core/_impl.py @@ -289,7 +289,12 @@ def positive(self, x): def pow(self, x, y): x, y = ndx.asarray(x), ndx.asarray(y) dtype = ndx.result_type(x, y) - if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)) or dtype in ( + dtypes.int8, + dtypes.nint8, + dtypes.int16, + dtypes.nint16, + ): return _binary_op(x, y, opx.pow, dtypes.int64) else: return _binary_op(x, y, opx.pow)