From a770d6f0ba914adc152489e7e2ea2d818a060ce2 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Thu, 22 Aug 2024 01:19:35 +0100 Subject: [PATCH] Shape --- CHANGELOG.rst | 6 +++++- ndonnx/_array.py | 33 +++++++++++++++----------------- ndonnx/_core/_interface.py | 3 +++ ndonnx/_core/_shapeimpl.py | 12 +++++++++++- ndonnx/additional/__init__.py | 3 ++- ndonnx/additional/_additional.py | 21 ++++++++++++++++++++ 6 files changed, 57 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a4d1ee3..4713642 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,13 +7,17 @@ Changelog ========= -0.7.1 (unreleased) +0.8.0 (unreleased) ------------------ **Bug fixes** - Fixes parsing numpy arrays of type ``object`` (consisting of strings) as ``utf8``. Previously this worked correctly only for 1d arrays. +**Breaking change** + +- :meth:`ndonnx.Array.shape` now strictly returns a ``tuple[int | None, ...]``, with unknown dimensions denoted by ``None``. This relies on ONNX shape inference for lazy arrays. + 0.7.0 (2024-08-12) ------------------ diff --git a/ndonnx/_array.py b/ndonnx/_array.py index 2585a60..e130790 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -14,7 +14,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes -from ndonnx.additional import shape +from ndonnx.additional import shape, static_shape from ._corearray import _CoreArray from ._index import ScalarIndexType @@ -227,28 +227,25 @@ def ndim(self) -> int: return len(self._static_shape) @property - def shape(self) -> Array | tuple[int, ...]: + def shape(self) -> tuple[int | None, ...]: """The shape of the array. Returns: - Array | tuple[int, ...]: The shape of the array. + tuple[int | None, ...]: The shape of the array. Note that the shape of the array is a tuple of integers when the "eager value" - of the array can be determined. This is in strict compliance with the Array API standard which - presupposes tuples in compliance tests. + of the array can be determined. - When lazy arrays are involved, the shape is data-dependent (or runtime inputs dependent) and so we - cannot reliably determine the shape. In such cases, an integer ``ndx.Array`` is returned instead. + When the array is lazy, getting the concrete shape may not be possible. + We fall back to the `static_shape` function, which may be implemented + using ONNX shape inference. This may have dimensions set as `None` where + they are unknown or symbolic with respect to input shapes. """ - # The Array API standard expects that the shape has type tuple[int | None, ...], as do compliance tests. - # However, when the array is completely lazy, a concrete shape may not be determinable. - # We therefore provide an int Array which is provisioned for in the standard. - # See https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.shape.html#shape for more context. - eager_value = self.to_numpy() - if eager_value is not None: - return eager_value.shape + shape_array = shape(self).to_numpy() + if shape_array is not None: + return tuple(map(int, shape_array)) else: - return shape(self) + return static_shape(self) @property def values(self) -> Array: @@ -371,8 +368,8 @@ def __bool__(self) -> bool: raise ValueError(f"Cannot convert Array of shape {self.shape} to a bool") def __len__(self) -> int: - if isinstance(eager_value := self.shape, tuple): - return eager_value[0] + if isinstance(self.shape[0], int): + return self.shape[0] else: raise ValueError(f"Cannot convert Array of shape {self.shape} to a length") @@ -540,7 +537,7 @@ def len(self): """ if self.ndim != 1: raise TypeError("Cannot call len on Array of rank != 1") - return ndx.asarray(self.shape[0]) + return shape(self)[0] def sum(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array: """See :py:func:`ndonnx.sum` for documentation.""" diff --git a/ndonnx/_core/_interface.py b/ndonnx/_core/_interface.py index 8d3a765..37fca9a 100644 --- a/ndonnx/_core/_interface.py +++ b/ndonnx/_core/_interface.py @@ -410,3 +410,6 @@ def make_nullable(self, x, null) -> ndx.Array: def can_cast(self, from_, to) -> bool: return NotImplemented + + def static_shape(self, x) -> tuple[int | None, ...]: + return NotImplemented diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index f3919c6..fd90bac 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -8,6 +8,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +import ndonnx.additional as nda from ._interface import OperationsBlock from ._utils import from_corearray @@ -24,6 +25,15 @@ def shape(self, x): current = next(iter(current._fields.values())) return from_corearray(opx.shape(current)) + def static_shape(self, x) -> tuple[int | None, ...]: + current = x + while isinstance(current, ndx.Array): + current = next(iter(current._fields.values())) + return tuple( + None if not isinstance(dim, int) else dim + for dim in current.var.unwrap_tensor().shape + ) + def take(self, x, indices, axis=None): if axis is None: axis = 0 @@ -138,7 +148,7 @@ def full(self, shape, fill_value, dtype=None, device=None): def full_like(self, x, fill_value, dtype=None, device=None): fill_value = ndx.asarray(fill_value).astype(dtype or x.dtype) - return ndx.broadcast_to(fill_value, ndx.asarray(x).shape) + return ndx.broadcast_to(fill_value, nda.shape(x)) def where(self, condition, x, y): if x.dtype != y.dtype: diff --git a/ndonnx/additional/__init__.py b/ndonnx/additional/__init__.py index 7554609..62bed8e 100644 --- a/ndonnx/additional/__init__.py +++ b/ndonnx/additional/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -from ._additional import fill_null, make_nullable, isin, shape, static_map +from ._additional import fill_null, make_nullable, isin, shape, static_map, static_shape __all__ = [ "fill_null", @@ -9,4 +9,5 @@ "isin", "shape", "static_map", + "static_shape", ] diff --git a/ndonnx/additional/_additional.py b/ndonnx/additional/_additional.py index 711cd69..7600c08 100644 --- a/ndonnx/additional/_additional.py +++ b/ndonnx/additional/_additional.py @@ -39,6 +39,27 @@ def shape(x: Array) -> Array: return out +def static_shape(x: Array) -> tuple[int | None, ...]: + """Return shape of the array as a tuple. Typical implementations will make use of + ONNX shape inference, with `None` entries denoting unknown or symbolic dimension. + + Parameters + ---------- + x: Array + Array to get static shape of + + Returns + ------- + out: tuple[int | None, ...] + """ + out = x.dtype._ops.static_shape(x) + if out is NotImplemented: + raise ndx.UnsupportedOperationError( + f"`static_shape` not implemented for `{x.dtype}`" + ) + return out + + def isin(x: Array, items: Sequence[Scalar]) -> Array: """Return true where the input ``Array`` contains an element in ``items``.