Skip to content

Commit

Permalink
Complete .shape -> nda.shape
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 23, 2024
1 parent b642d5a commit 13851f0
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def size(self) -> ndx.Array:
out: Array
Scalar ``Array`` instance whose value is the number of elements in the original array.
"""
return ndx.prod(self.shape)
return ndx.prod(shape(self))

@property
def T(self) -> ndx.Array: # noqa: N802
Expand Down
3 changes: 2 additions & 1 deletion ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
Expand Down Expand Up @@ -169,7 +170,7 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
null=ndx.reshape(null, nda.shape(x)),
)


Expand Down
30 changes: 15 additions & 15 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda
from ndonnx._utility import promote

from ._nullableimpl import NullableOperationsImpl
Expand Down Expand Up @@ -198,7 +199,7 @@ def isfinite(self, x):
def isinf(self, x):
if isinstance(x.dtype, (dtypes.Floating, dtypes.NullableFloating)):
return unary_op(x, opx.isinf)
return ndx.full(x.shape, fill_value=False)
return ndx.full(nda.shape(x), fill_value=False)

@validate_core
def isnan(self, x):
Expand Down Expand Up @@ -400,9 +401,7 @@ def nonzero(self, x) -> tuple[Array, ...]:
return (ndx.arange(0, x != 0, dtype=dtypes.int64),)

ret_full_flattened = ndx.reshape(
from_corearray(
opx.ndindex(ndx.asarray(x.shape, dtype=dtypes.int64)._core())
)[x != 0],
from_corearray(opx.ndindex(nda.shape(x)._core()))[x != 0],
[-1],
)

Expand All @@ -413,7 +412,7 @@ def nonzero(self, x) -> tuple[Array, ...]:
ret_full_flattened._core(),
ndx.arange(
i,
ret_full_flattened.shape[0],
nda.shape(ret_full_flattened)[0],
x.ndim,
dtype=dtypes.int64,
)._core(),
Expand Down Expand Up @@ -454,20 +453,21 @@ def searchsorted(
from_corearray, opx.get_indices(x1._core(), x2._core(), positions._core())
)

how_many = ndx.zeros(ndx.asarray(combined.shape) + 1, dtype=dtypes.int64)
combined_shape = nda.shape(combined)
how_many = ndx.zeros(combined_shape + 1, dtype=dtypes.int64)
how_many[
ndx.where(indices_x1 + 1 <= combined.shape[0], indices_x1 + 1, indices_x1)
ndx.where(indices_x1 + 1 <= combined_shape[0], indices_x1 + 1, indices_x1)
] = counts
how_many = ndx.cumulative_sum(how_many, include_initial=True)

ret = ndx.zeros(x2.shape, dtype=dtypes.int64)
ret = ndx.zeros(nda.shape(x2), dtype=dtypes.int64)

if side == "left":
ret = how_many[indices_x2]
ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64) - 1
ret[nan_mask] = nda.shape(x1) - 1
else:
ret = how_many[indices_x2 + 1]
ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64)
ret[nan_mask] = nda.shape(x1)

return ret

Expand All @@ -494,15 +494,15 @@ def unique_all(self, x):
# FIXME: I think we can simply use arange/ones+cumsum or something for the indices
# maybe: indices = opx.cumsum(ones_like(flattened, dtype=dtypes.i64), axis=ndx.asarray(0))
indices = opx.squeeze(
opx.ndindex(ndx.asarray(flattened.shape, dtype=dtypes.int64)._core()),
opx.ndindex(nda.shape(flattened)._core()),
opx.const([1], dtype=dtypes.int64),
)

ret = namedtuple("ret", ["values", "indices", "inverse_indices", "counts"])

values = from_corearray(ret_opd[0])
indices = from_corearray(indices[ret_opd[1]])
inverse_indices = ndx.reshape(from_corearray(ret_opd[2]), x.shape)
inverse_indices = ndx.reshape(from_corearray(ret_opd[2]), nda.shape(x))
counts = from_corearray(ret_opd[3])

return ret(
Expand Down Expand Up @@ -535,7 +535,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True):
if axis < 0:
axis += x.ndim

_len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core()
_len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core()
return _via_i64_f64(
lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[1], [x]
)
Expand All @@ -544,7 +544,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True):
def sort(self, x, *, axis=-1, descending=False, stable=True):
if axis < 0:
axis += x.ndim
_len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core()
_len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core()
return _via_i64_f64(
lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[0], [x]
)
Expand Down Expand Up @@ -813,7 +813,7 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
null=ndx.reshape(null, nda.shape(x)),
)

@validate_core
Expand Down
10 changes: 4 additions & 6 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def roll(self, x, shift, axis):
if not isinstance(shift, Sequence):
shift = [shift]

old_shape = x.shape
old_shape = nda.shape(x)

if axis is None:
x = ndx.reshape(x, [-1])
Expand All @@ -112,9 +112,7 @@ def roll(self, x, shift, axis):
raise ValueError("shift and axis must have the same length")

for sh, ax in zip(shift, axis):
len_single = opx.gather(
ndx.asarray(x.shape, dtype=dtypes.int64)._core(), opx.const(ax)
)
len_single = opx.gather(nda.shape(x)._core(), opx.const(ax))
shift_single = opx.add(opx.const(-sh, dtype=dtypes.int64), len_single)
# Find the needed element index and then gather from it
range = opx.cast(
Expand Down Expand Up @@ -246,7 +244,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array):
return output

def zeros_like(self, x, dtype=None, device=None):
return ndx.zeros(x.shape, dtype=dtype or x.dtype, device=device)
return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device)

def ones_like(self, x, dtype=None, device=None):
return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device)
return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device)
3 changes: 2 additions & 1 deletion ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
Expand Down Expand Up @@ -76,7 +77,7 @@ def make_nullable(self, x, null):
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
null=ndx.reshape(null, nda.shape(x)),
)


Expand Down
3 changes: 2 additions & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ndonnx._data_types as dtypes
from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore
from ndonnx._data_types.structtype import StructType
from ndonnx.additional import shape

from . import _opset_extensions as opx
from ._array import Array, _from_corearray
Expand Down Expand Up @@ -574,7 +575,7 @@ def numeric_like(x):
ret = numeric_like(next(it))
while (x := next(it, None)) is not None:
ret = ret + numeric_like(x)
target_shape = ret.shape
target_shape = shape(ret)
return [broadcast_to(x, target_shape) for x in arrays]


Expand Down

0 comments on commit 13851f0

Please sign in to comment.