Skip to content

Commit

Permalink
Fix missing UnsupportedOperationError
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 28, 2024
1 parent d64dd52 commit 39eebde
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
38 changes: 31 additions & 7 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,11 @@ def numeric_like(x):


def broadcast_to(x, shape):
return x.dtype._ops.broadcast_to(x, shape)
if (out := x.dtype._ops.broadcast_to(x, shape)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for broadcast_to: '{x.dtype}'"
)


# TODO: onnxruntime doesn't work for 2 empty arrays of integer type
Expand All @@ -599,27 +603,47 @@ def concat(arrays, /, *, axis: int | None = 0):


def expand_dims(x, axis=0):
return x.dtype._ops.expand_dims(x, axis)
if (out := x.dtype._ops.expand_dims(x, axis)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for expand_dims: '{x.dtype}'"
)


def flip(x, axis=None):
return x.dtype._ops.flip(x, axis=axis)
if (out := x.dtype._ops.flip(x, axis=axis)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for flip: '{x.dtype}'")


def permute_dims(x, axes):
return x.dtype._ops.permute_dims(x, axes)
if (out := x.dtype._ops.permute_dims(x, axes)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for permute_dims: '{x.dtype}'"
)


def reshape(x, shape, *, copy=None):
return x.dtype._ops.reshape(x, shape, copy=copy)
if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for reshape: '{x.dtype}'"
)


def roll(x, shift, axis=None):
return x.dtype._ops.roll(x, shift, axis)
if (out := x.dtype._ops.roll(x, shift, axis)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for roll: '{x.dtype}'")


def squeeze(x, axis):
return x.dtype._ops.squeeze(x, axis)
if (out := x.dtype._ops.squeeze(x, axis)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for squeeze: '{x.dtype}'"
)


def stack(arrays, axis=0):
Expand Down
7 changes: 6 additions & 1 deletion tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
CastError,
CoreType,
)
from ndonnx._experimental import CastMixin, Schema, StructType, UniformShapeOperations
from ndonnx._experimental import (
CastMixin,
Schema,
StructType,
UniformShapeOperations,
)

from .utils import assert_array_equal

Expand Down

0 comments on commit 39eebde

Please sign in to comment.