Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into construct-member
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 24, 2024
2 parents 1534a39 + 84b8f4d commit 1bee70d
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 20 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ Changelog
- User defined data types can now define how arrays with that dtype are constructed by implementing the :func:`make_array` function.
- User defined data types can now define how they are indexed (via `__getitem__`) by implementing the :func:`getitem` function.

**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.

**Breaking change**

- Iterating over dynamic dimensions of :class:`~ndonnx.Array` is no longer allowed since it commonly led to infinite loops when used without an explicit break condition.


0.8.0 (2024-08-22)
------------------
Expand Down
19 changes: 17 additions & 2 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def array(
out : Array
The new array. This represents an ONNX model input.
"""
return dtype._ops.make_array(shape, dtype)
if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented:
return out
raise ndx.UnsupportedOperationError(
f"No implementation of `make_array` for {dtype}"
)


def from_spox_var(
Expand Down Expand Up @@ -110,6 +114,17 @@ def __getattr__(self, name: str) -> Array:
else:
raise AttributeError(f"Field {name} not found")

def __iter__(self):
    """Iterate over the leading axis of the array.

    Returns
    -------
    A generator yielding ``self[i, ...]`` for each index ``i`` along
    dimension 0.

    Raises
    ------
    ValueError
        If the array is 0-dimensional, or if dimension 0 does not have a
        static (``int``) length.
    """
    dims = self.shape
    try:
        n, *_ = dims
    except ValueError:
        # Star-unpacking an empty shape raises ValueError (not IndexError);
        # translate it into a message that names the actual problem.
        raise ValueError("iteration over 0-d array") from None
    if isinstance(n, int):
        return (self[i, ...] for i in range(n))
    # A non-int leading dimension has no known length at this point, so
    # there is no finite range to iterate over.
    raise ValueError(
        "iteration requires dimension of static length, but dimension 0 is dynamic."
    )

def _set(self, other: Array) -> Array:
self.dtype = other.dtype
self._fields = other._fields
Expand Down Expand Up @@ -497,7 +512,7 @@ def size(self) -> ndx.Array:
out: Array
Scalar ``Array`` instance whose value is the number of elements in the original array.
"""
return ndx.prod(self.shape)
return ndx.prod(shape(self))

@property
def T(self) -> ndx.Array: # noqa: N802
Expand Down
25 changes: 12 additions & 13 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def isfinite(self, x):
def isinf(self, x):
if isinstance(x.dtype, (dtypes.Floating, dtypes.NullableFloating)):
return unary_op(x, opx.isinf)
return ndx.full(x.shape, fill_value=False)
return ndx.full(nda.shape(x), fill_value=False)

@validate_core
def isnan(self, x):
Expand Down Expand Up @@ -403,9 +403,7 @@ def nonzero(self, x) -> tuple[Array, ...]:
return (ndx.arange(0, x != 0, dtype=dtypes.int64),)

ret_full_flattened = ndx.reshape(
from_corearray(
opx.ndindex(ndx.asarray(x.shape, dtype=dtypes.int64)._core())
)[x != 0],
from_corearray(opx.ndindex(nda.shape(x)._core()))[x != 0],
[-1],
)

Expand All @@ -416,7 +414,7 @@ def nonzero(self, x) -> tuple[Array, ...]:
ret_full_flattened._core(),
ndx.arange(
i,
ret_full_flattened.shape[0],
nda.shape(ret_full_flattened)[0],
x.ndim,
dtype=dtypes.int64,
)._core(),
Expand Down Expand Up @@ -457,20 +455,21 @@ def searchsorted(
from_corearray, opx.get_indices(x1._core(), x2._core(), positions._core())
)

how_many = ndx.zeros(ndx.asarray(combined.shape) + 1, dtype=dtypes.int64)
combined_shape = nda.shape(combined)
how_many = ndx.zeros(combined_shape + 1, dtype=dtypes.int64)
how_many[
ndx.where(indices_x1 + 1 <= combined.shape[0], indices_x1 + 1, indices_x1)
ndx.where(indices_x1 + 1 <= combined_shape[0], indices_x1 + 1, indices_x1)
] = counts
how_many = ndx.cumulative_sum(how_many, include_initial=True)

ret = ndx.zeros(x2.shape, dtype=dtypes.int64)
ret = ndx.zeros(nda.shape(x2), dtype=dtypes.int64)

if side == "left":
ret = how_many[indices_x2]
ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64) - 1
ret[nan_mask] = nda.shape(x1) - 1
else:
ret = how_many[indices_x2 + 1]
ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64)
ret[nan_mask] = nda.shape(x1)

return ret

Expand All @@ -497,7 +496,7 @@ def unique_all(self, x):
# FIXME: I think we can simply use arange/ones+cumsum or something for the indices
# maybe: indices = opx.cumsum(ones_like(flattened, dtype=dtypes.i64), axis=ndx.asarray(0))
indices = opx.squeeze(
opx.ndindex(ndx.asarray(flattened.shape, dtype=dtypes.int64)._core()),
opx.ndindex(nda.shape(flattened)._core()),
opx.const([1], dtype=dtypes.int64),
)

Expand Down Expand Up @@ -538,7 +537,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True):
if axis < 0:
axis += x.ndim

_len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core()
_len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core()
return _via_i64_f64(
lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[1], [x]
)
Expand All @@ -547,7 +546,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True):
def sort(self, x, *, axis=-1, descending=False, stable=True):
if axis < 0:
axis += x.ndim
_len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core()
_len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core()
return _via_i64_f64(
lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[0], [x]
)
Expand Down
6 changes: 2 additions & 4 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ def roll(self, x, shift, axis):
raise ValueError("shift and axis must have the same length")

for sh, ax in zip(shift, axis):
len_single = opx.gather(
ndx.asarray(x.shape, dtype=dtypes.int64)._core(), opx.const(ax)
)
len_single = opx.gather(nda.shape(x)._core(), opx.const(ax))
shift_single = opx.add(opx.const(-sh, dtype=dtypes.int64), len_single)
# Find the needed element index and then gather from it
range = opx.cast(
Expand Down Expand Up @@ -246,7 +244,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array):
return output

def zeros_like(self, x, dtype=None, device=None):
return ndx.zeros(x.shape, dtype=dtype or x.dtype, device=device)
return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device)

def ones_like(self, x, dtype=None, device=None):
return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device)
Expand Down
7 changes: 6 additions & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ndonnx._data_types as dtypes
from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore
from ndonnx._data_types.structtype import StructType
from ndonnx.additional import shape

from . import _opset_extensions as opx
from ._array import Array, _from_corearray
Expand Down Expand Up @@ -76,6 +77,10 @@ def asarray(
dtype=dtype,
eager_value=eager_value,
)
if ret is NotImplemented:
raise UnsupportedOperationError(
f"Unsupported operand type for asarray: '{dtype}'"
)
else:
ret = x.copy() if copy is True else x

Expand Down Expand Up @@ -576,7 +581,7 @@ def numeric_like(x):
ret = numeric_like(next(it))
while (x := next(it, None)) is not None:
ret = ret + numeric_like(x)
target_shape = ret.shape
target_shape = shape(ret)
return [broadcast_to(x, target_shape) for x in arrays]


Expand Down
44 changes: 44 additions & 0 deletions tests/test_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

import pytest

import ndonnx as ndx


def test_iter_for_loop():
    """A ``for`` loop over a statically shaped 1-d array yields exactly ``n`` arrays."""
    n = 5
    a = ndx.array(shape=(n,), dtype=ndx.int64)

    seen = 0
    for i, el in enumerate(a):  # type: ignore
        assert isinstance(el, ndx.Array)
        # Valid indices are 0..n-1, so reaching index n already means the
        # iterator ran past the end of the array.
        if i >= n:
            raise AssertionError("Iterated past the number of elements")
        seen += 1
    assert seen == n


@pytest.mark.parametrize(
    "arr",
    [
        ndx.asarray([1]),
        ndx.asarray([[1], [2]]),
        ndx.array(shape=(2,), dtype=ndx.int64),
        ndx.array(shape=(2, 3), dtype=ndx.int64),
        ndx.array(shape=(2, "N"), dtype=ndx.int64),
    ],
)
def test_create_iterators(arr):
    """The first item produced by ``iter(arr)`` drops the leading axis of ``arr``."""
    first = next(iter(arr))
    # One rank lower than the source array ...
    assert first.ndim == arr.ndim - 1
    # ... and the remaining dimensions are unchanged.
    assert first.shape == arr.shape[1:]


def test_0d_not_iterable():
    """Requesting an iterator over a 0-d array raises ``ValueError``."""
    zero_dim = ndx.array(shape=(), dtype=ndx.int64)
    with pytest.raises(ValueError):
        iterator = iter(zero_dim)
        next(iterator)


def test_raises_dynamic_dim():
    """Iterating over an array whose leading dimension is dynamic raises ``ValueError``."""
    # Build the array outside the ``raises`` block so that an error during
    # construction cannot be mistaken for the iteration error under test.
    dynamic = ndx.array(shape=("N",), dtype=ndx.int64)
    with pytest.raises(ValueError, match="dimension 0 is dynamic"):
        iter(dynamic)

0 comments on commit 1bee70d

Please sign in to comment.