Skip to content

Commit

Permalink
Shape
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 22, 2024
1 parent e9148c3 commit c44467a
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 20 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
Changelog
=========

0.7.1 (unreleased)
0.8.0 (unreleased)
------------------

**Bug fixes**

- Fixes parsing numpy arrays of type ``object`` (consisting of strings) as ``utf8``. Previously this worked correctly only for 1d arrays.

**Breaking change**

- :attr:`ndonnx.Array.shape` now strictly returns a ``tuple[int | None, ...]``, with unknown dimensions denoted by ``None``. This relies on ONNX shape inference for lazy arrays.


0.7.0 (2024-08-12)
------------------
Expand Down
31 changes: 14 additions & 17 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx.additional import shape
from ndonnx.additional import shape, static_shape

from ._corearray import _CoreArray
from ._index import ScalarIndexType
Expand Down Expand Up @@ -227,28 +227,25 @@ def ndim(self) -> int:
return len(self._static_shape)

@property
def shape(self) -> Array | tuple[int, ...]:
def shape(self) -> tuple[int | None, ...]:
"""The shape of the array.
Returns:
Array | tuple[int, ...]: The shape of the array.
Note that the shape of the array is a tuple of integers when the "eager value"
of the array can be determined. This is in strict compliance with the Array API standard which
presupposes tuples in compliance tests.
of the array can be determined.
When lazy arrays are involved, the shape is data-dependent (or runtime inputs dependent) and so we
cannot reliably determine the shape. In such cases, an integer ``ndx.Array`` is returned instead.
When the array is lazy, the shape is data-dependent and so we
cannot reliably determine the shape. In such cases, we fall back to the
`static_shape` function which may be implemented by data types using ONNX
shape inference.
"""
# The Array API standard expects that the shape has type tuple[int | None, ...], as do compliance tests.
# However, when the array is completely lazy, a concrete shape may not be determinable.
# We therefore provide an int Array which is provisioned for in the standard.
# See https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.shape.html#shape for more context.
eager_value = self.to_numpy()
if eager_value is not None:
return eager_value.shape
shape_array = shape(self).to_numpy()
if shape_array is not None:
return tuple(shape_array)
else:
return shape(self)
return static_shape(self)

@property
def values(self) -> Array:
Expand Down Expand Up @@ -371,8 +368,8 @@ def __bool__(self) -> bool:
raise ValueError(f"Cannot convert Array of shape {self.shape} to a bool")

def __len__(self) -> int:
if isinstance(eager_value := self.shape, tuple):
return eager_value[0]
if isinstance(self.shape[0], int):
return self.shape[0]
else:
raise ValueError(f"Cannot convert Array of shape {self.shape} to a length")

Expand Down Expand Up @@ -540,7 +537,7 @@ def len(self):
"""
if self.ndim != 1:
raise TypeError("Cannot call len on Array of rank != 1")
return ndx.asarray(self.shape[0])
return shape(self)[0]

def sum(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array:
"""See :py:func:`ndonnx.sum` for documentation."""
Expand Down
3 changes: 3 additions & 0 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,6 @@ def make_nullable(self, x, null) -> ndx.Array:

def can_cast(self, from_, to) -> bool:
    """Hook for deciding whether ``from_`` can be cast to ``to``.

    This base implementation makes no decision: it returns the
    ``NotImplemented`` sentinel instead of a ``bool``.
    """
    # NOTE(review): presumably ``from_`` and ``to`` are data types and
    # concrete operation sets override this hook -- confirm with callers.
    return NotImplemented

def static_shape(self, x) -> tuple[int | None, ...]:
    """Hook for returning the statically known shape of ``x``.

    This base implementation returns the ``NotImplemented`` sentinel.
    ``ndonnx.additional.static_shape`` raises ``UnsupportedOperationError``
    when the operations object it dispatches to returns this sentinel.
    """
    # NOTE(review): presumably this class is the operations object reached
    # through ``x.dtype._ops`` -- confirm against the dtype definitions.
    return NotImplemented
12 changes: 11 additions & 1 deletion ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._interface import OperationsBlock
from ._utils import from_corearray
Expand All @@ -24,6 +25,15 @@ def shape(self, x):
current = next(iter(current._fields.values()))
return from_corearray(opx.shape(current))

def static_shape(self, x) -> tuple[int | None, ...]:
    """Return the statically inferred shape of ``x``.

    Returns:
        tuple[int | None, ...]: One entry per dimension. Dimensions whose
        inferred size is not a plain ``int`` are reported as ``None``.

    Raises:
        ValueError: If no shape is available for the underlying tensor at
        all, so that not even the number of dimensions is known.
    """
    current = x
    # Arrays keep their data in ``_fields``; follow the first field until
    # the value is no longer an ``ndx.Array``.
    while isinstance(current, ndx.Array):
        current = next(iter(current._fields.values()))

    inferred_shape = current.var.unwrap_tensor().shape
    # A missing shape cannot be expressed as a tuple, and iterating ``None``
    # would fail with an unrelated-looking ``TypeError``.
    if inferred_shape is None:
        raise ValueError(
            "Cannot determine the static shape of an Array of unknown rank"
        )
    return tuple(dim if isinstance(dim, int) else None for dim in inferred_shape)

def take(self, x, indices, axis=None):
if axis is None:
axis = 0
Expand Down Expand Up @@ -138,7 +148,7 @@ def full(self, shape, fill_value, dtype=None, device=None):

def full_like(self, x, fill_value, dtype=None, device=None):
fill_value = ndx.asarray(fill_value).astype(dtype or x.dtype)
return ndx.broadcast_to(fill_value, ndx.asarray(x).shape)
return ndx.broadcast_to(fill_value, nda.shape(x))

def where(self, condition, x, y):
if x.dtype != y.dtype:
Expand Down
3 changes: 2 additions & 1 deletion ndonnx/additional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from ._additional import fill_null, make_nullable, isin, shape, static_map
from ._additional import fill_null, make_nullable, isin, shape, static_map, static_shape

__all__ = [
"fill_null",
"make_nullable",
"isin",
"shape",
"static_map",
"static_shape",
]
21 changes: 21 additions & 0 deletions ndonnx/additional/_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,27 @@ def shape(x: Array) -> Array:
return out


def static_shape(x: Array) -> tuple[int | None, ...]:
    """Return the shape of ``x`` as a tuple.

    The lookup is delegated to the operations object of ``x.dtype``. Typical
    implementations make use of ONNX shape inference, with ``None`` entries
    denoting unknown or symbolic dimensions.

    Parameters
    ----------
    x: Array
        Array to get the static shape of

    Returns
    -------
    out: tuple[int | None, ...]
        The shape reported by the data type's ``static_shape`` implementation.

    Raises
    ------
    UnsupportedOperationError
        If the data type of ``x`` does not implement ``static_shape``.
    """
    result = x.dtype._ops.static_shape(x)
    if result is not NotImplemented:
        return result
    raise ndx.UnsupportedOperationError(
        f"`static_shape` not implemented for `{x.dtype}`"
    )


def isin(x: Array, items: Sequence[Scalar]) -> Array:
"""Return true where the input ``Array`` contains an element in ``items``.
Expand Down

0 comments on commit c44467a

Please sign in to comment.