From 39eebde6adcc32969f8f45982271d09cb5e49bed Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 28 Aug 2024 17:55:26 +0200 Subject: [PATCH] Fix missing UnsupportedOperationError --- ndonnx/_funcs.py | 38 +++++++++++++++++++++++++++++++------- tests/test_dtypes.py | 7 ++++++- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 92f5bb1..a79fa31 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -580,7 +580,11 @@ def numeric_like(x): def broadcast_to(x, shape): - return x.dtype._ops.broadcast_to(x, shape) + if (out := x.dtype._ops.broadcast_to(x, shape)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for broadcast_to: '{x.dtype}'" + ) # TODO: onnxruntime doesn't work for 2 empty arrays of integer type @@ -599,27 +603,47 @@ def concat(arrays, /, *, axis: int | None = 0): def expand_dims(x, axis=0): - return x.dtype._ops.expand_dims(x, axis) + if (out := x.dtype._ops.expand_dims(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for expand_dims: '{x.dtype}'" + ) def flip(x, axis=None): - return x.dtype._ops.flip(x, axis=axis) + if (out := x.dtype._ops.flip(x, axis=axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for flip: '{x.dtype}'") def permute_dims(x, axes): - return x.dtype._ops.permute_dims(x, axes) + if (out := x.dtype._ops.permute_dims(x, axes)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for permute_dims: '{x.dtype}'" + ) def reshape(x, shape, *, copy=None): - return x.dtype._ops.reshape(x, shape, copy=copy) + if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for reshape: '{x.dtype}'" + ) def roll(x, shift, axis=None): - return x.dtype._ops.roll(x, shift, axis) + if (out := x.dtype._ops.roll(x, shift, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for roll: '{x.dtype}'") def squeeze(x, axis): - return x.dtype._ops.squeeze(x, axis) + if (out := x.dtype._ops.squeeze(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for squeeze: '{x.dtype}'" + ) def stack(arrays, axis=0): diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 048978d..81a4fb6 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -15,7 +15,12 @@ CastError, CoreType, ) -from ndonnx._experimental import CastMixin, Schema, StructType, UniformShapeOperations +from ndonnx._experimental import ( + CastMixin, + Schema, + StructType, + UniformShapeOperations, +) from .utils import assert_array_equal