diff --git a/ndonnx/_array.py b/ndonnx/_array.py index cf9124f..caddcd2 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -506,7 +506,7 @@ def size(self) -> ndx.Array: out: Array Scalar ``Array`` instance whose value is the number of elements in the original array. """ - return ndx.prod(self.shape) + return ndx.prod(shape(self)) @property def T(self) -> ndx.Array: # noqa: N802 diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index 9d7b946..806c2d2 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -12,6 +12,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +import ndonnx.additional as nda from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations @@ -169,7 +170,7 @@ def make_nullable(self, x, null): return ndx.Array._from_fields( dtypes.into_nullable(x.dtype), values=x.copy(), - null=ndx.reshape(null, x.shape), + null=ndx.reshape(null, nda.shape(x)), ) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 0b8641f..f7dbf25 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -16,6 +16,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +import ndonnx.additional as nda from ndonnx._utility import promote from ._nullableimpl import NullableOperationsImpl @@ -198,7 +199,7 @@ def isfinite(self, x): def isinf(self, x): if isinstance(x.dtype, (dtypes.Floating, dtypes.NullableFloating)): return unary_op(x, opx.isinf) - return ndx.full(x.shape, fill_value=False) + return ndx.full(nda.shape(x), fill_value=False) @validate_core def isnan(self, x): @@ -400,9 +401,7 @@ def nonzero(self, x) -> tuple[Array, ...]: return (ndx.arange(0, x != 0, dtype=dtypes.int64),) ret_full_flattened = ndx.reshape( - from_corearray( - opx.ndindex(ndx.asarray(x.shape, dtype=dtypes.int64)._core()) - )[x != 0], + from_corearray(opx.ndindex(nda.shape(x)._core()))[x != 0], [-1], ) @@ -413,7 +412,7 @@ def nonzero(self, x) -> tuple[Array, ...]: ret_full_flattened._core(), ndx.arange( i, - ret_full_flattened.shape[0], + nda.shape(ret_full_flattened)[0], x.ndim, dtype=dtypes.int64, )._core(), @@ -454,20 +453,21 @@ def searchsorted( from_corearray, opx.get_indices(x1._core(), x2._core(), positions._core()) ) - how_many = ndx.zeros(ndx.asarray(combined.shape) + 1, dtype=dtypes.int64) + combined_shape = nda.shape(combined) + how_many = ndx.zeros(combined_shape + 1, dtype=dtypes.int64) how_many[ - ndx.where(indices_x1 + 1 <= combined.shape[0], indices_x1 + 1, indices_x1) + ndx.where(indices_x1 + 1 <= combined_shape[0], indices_x1 + 1, indices_x1) ] = counts how_many = ndx.cumulative_sum(how_many, include_initial=True) - ret = ndx.zeros(x2.shape, dtype=dtypes.int64) + ret = ndx.zeros(nda.shape(x2), dtype=dtypes.int64) if side == "left": ret = how_many[indices_x2] - ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64) - 1 + ret[nan_mask] = nda.shape(x1) - 1 else: ret = how_many[indices_x2 + 1] - ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64) + ret[nan_mask] = nda.shape(x1) return ret @@ -494,7 +494,7 @@ def unique_all(self, x): # FIXME: I think we can simply use arange/ones+cumsum or something for the indices # maybe: indices = opx.cumsum(ones_like(flattened, dtype=dtypes.i64), axis=ndx.asarray(0)) indices = opx.squeeze( - opx.ndindex(ndx.asarray(flattened.shape, dtype=dtypes.int64)._core()), + opx.ndindex(nda.shape(flattened)._core()), opx.const([1], dtype=dtypes.int64), ) @@ -502,7 +502,7 @@ def unique_all(self, x): values = from_corearray(ret_opd[0]) indices = from_corearray(indices[ret_opd[1]]) - inverse_indices = ndx.reshape(from_corearray(ret_opd[2]), x.shape) + inverse_indices = ndx.reshape(from_corearray(ret_opd[2]), nda.shape(x)) counts = from_corearray(ret_opd[3]) return ret( @@ -535,7 +535,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True): if axis < 0: axis += x.ndim - _len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core() + _len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core() return _via_i64_f64( lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[1], [x] ) @@ -544,7 +544,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True): def sort(self, x, *, axis=-1, descending=False, stable=True): if axis < 0: axis += x.ndim - _len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core() + _len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core() return _via_i64_f64( lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[0], [x] ) @@ -813,7 +813,7 @@ def make_nullable(self, x, null): return ndx.Array._from_fields( dtypes.into_nullable(x.dtype), values=x.copy(), - null=ndx.reshape(null, x.shape), + null=ndx.reshape(null, nda.shape(x)), ) @validate_core diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index 8674d6d..b68f795 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -99,7 +99,7 @@ def roll(self, x, shift, axis): if not isinstance(shift, Sequence): shift = [shift] - old_shape = x.shape + old_shape = nda.shape(x) if axis is None: x = ndx.reshape(x, [-1]) @@ -112,9 +112,7 @@ def roll(self, x, shift, axis): raise ValueError("shift and axis must have the same length") for sh, ax in zip(shift, axis): - len_single = opx.gather( - ndx.asarray(x.shape, dtype=dtypes.int64)._core(), opx.const(ax) - ) + len_single = opx.gather(nda.shape(x)._core(), opx.const(ax)) shift_single = opx.add(opx.const(-sh, dtype=dtypes.int64), len_single) # Find the needed element index and then gather from it range = opx.cast( @@ -246,7 +244,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array): return output def zeros_like(self, x, dtype=None, device=None): - return ndx.zeros(x.shape, dtype=dtype or x.dtype, device=device) + return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device) def ones_like(self, x, dtype=None, device=None): - return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device) + return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device) diff --git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index 0218a4a..ddb4f6e 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -10,6 +10,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +import ndonnx.additional as nda from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations @@ -76,7 +77,7 @@ def make_nullable(self, x, null): return ndx.Array._from_fields( dtypes.into_nullable(x.dtype), values=x.copy(), - null=ndx.reshape(null, x.shape), + null=ndx.reshape(null, nda.shape(x)), ) diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 1088a28..8a2b448 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -13,6 +13,7 @@ import ndonnx._data_types as dtypes from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore from ndonnx._data_types.structtype import StructType +from ndonnx.additional import shape from . import _opset_extensions as opx from ._array import Array, _from_corearray @@ -574,7 +575,7 @@ def numeric_like(x): ret = numeric_like(next(it)) while (x := next(it, None)) is not None: ret = ret + numeric_like(x) - target_shape = ret.shape + target_shape = shape(ret) return [broadcast_to(x, target_shape) for x in arrays]