Skip to content

Commit

Permalink
Fix full_like dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 1, 2024
1 parent abc9baf commit 120cb27
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 29 deletions.
39 changes: 39 additions & 0 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def bitwise_xor(self, x, y) -> Array:
def logical_and(self, x, y):
# If one of the operands is True and the broadcasted shape can be guaranteed to be the other array's shape,
# we can return this other array directly.
x, y = map(ndx.asarray, (x, y))
if (
x.to_numpy() is not None
and x.to_numpy().size == 1
Expand Down Expand Up @@ -108,3 +109,41 @@ def any(self, x, *, axis=None, keepdims: bool = False):
return ndx.max(x.astype(ndx.int8), axis=axis, keepdims=keepdims).astype(
ndx.bool
)

def ones(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = dtypes.float64 if dtype is None else dtype
return ndx.full(shape, True, dtype=dtype)

def ones_like(
self, x, dtype: dtypes.StructType | dtypes.CoreType | None = None, device=None
):
return ndx.full_like(x, True, dtype=dtype)

def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = dtypes.float64 if dtype is None else dtype
return ndx.full(shape, False, dtype=dtype)

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
return ndx.full_like(x, False, dtype=dtype)

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
shape = ndx.asarray(shape, dtype=dtypes.int64)
return ndx.full(shape, False, dtype=dtype)

def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, False, dtype=dtype)

def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))
15 changes: 6 additions & 9 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,34 +537,31 @@ def ones(
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = dtypes.float64 if dtype is None else dtype
return ndx.full(shape, 1, dtype=dtype)
return NotImplemented

def ones_like(
self, x, dtype: dtypes.StructType | dtypes.CoreType | None = None, device=None
):
return ndx.full_like(x, 1, dtype=dtype)
return NotImplemented

def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = dtypes.float64 if dtype is None else dtype
return ndx.full(shape, 0, dtype=dtype)
return NotImplemented

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
return ndx.full_like(x, 0, dtype=dtype)
return NotImplemented

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
shape = ndx.asarray(shape, dtype=dtypes.int64)
return ndx.full(shape, 0, dtype=dtype)
return NotImplemented

def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)
return NotImplemented

def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Array:
return NotImplemented
Expand Down
48 changes: 43 additions & 5 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,19 @@ def argmin(self, x, axis=None, keepdims=False):
)
return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x])

def nonzero(self, x):
def nonzero(self, x) -> tuple[Array, ...]:
if not isinstance(x.dtype, dtypes.CoreType):
return NotImplemented

if x.ndim == 0:
return [ndx.arange(0, x != 0, dtype=dtypes.int64)]
return (ndx.arange(0, x != 0, dtype=dtypes.int64),)

ret_full_flattened = ndx.reshape(
from_corearray(opx.ndindex(opx.shape(x)))[x != 0],
from_corearray(opx.ndindex(opx.shape(x._core())))[x != 0],
[-1],
)

return [
return tuple(
ndx.reshape(
from_corearray(
opx.gather_elements(
Expand All @@ -350,7 +353,7 @@ def nonzero(self, x):
[-1],
)
for i in range(x.ndim)
]
)

def searchsorted(
self,
Expand Down Expand Up @@ -825,6 +828,41 @@ def linspace(
np.linspace(start, stop, num=num, endpoint=endpoint), dtype=dtype
)

def ones(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = dtypes.float64 if dtype is None else dtype
return ndx.full(shape, 1, dtype=dtype)

def ones_like(
self, x, dtype: dtypes.StructType | dtypes.CoreType | None = None, device=None
):
return ndx.full_like(x, 1, dtype=dtype)

def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = dtypes.float64 if dtype is None else dtype
return ndx.full(shape, 0, dtype=dtype)

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
return ndx.full_like(x, 0, dtype=dtype)

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
shape = ndx.asarray(shape, dtype=dtypes.int64)
return ndx.full(shape, 0, dtype=dtype)

def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)


def _via_i64_f64(
fn: Callable[..., _CoreArray], arrays: list[Array], *, cast_return=True
Expand Down
19 changes: 16 additions & 3 deletions ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
from ndonnx._data_types.structtype import StructType

from ._interface import OperationsBlock
from ._utils import binary_op
Expand All @@ -30,8 +30,21 @@ def can_cast(self, from_, to) -> bool:
return np.can_cast(from_.to_numpy_dtype(), to.to_numpy_dtype())
return NotImplemented

def zeros(self, shape, dtype: ndx.CoreType | StructType | None = None, device=None):
return ndx.full(shape, "", dtype=dtype, device=device)
def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
return ndx.full(shape, "", dtype=dtype)

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
return ndx.full_like(x, "", dtype=dtype)

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
return ndx.zeros(shape, dtype=dtype, device=device)

def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.zeros_like(x, dtype=dtype, device=device)
52 changes: 40 additions & 12 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,20 @@ def asarray(

def empty(shape, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None):
dtype = dtype or dtypes.float64
return dtype._ops.empty(shape, dtype=dtype)
if (out := dtype._ops.empty(shape, dtype=dtype)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for empty: '{dtype}'")


def empty_like(
x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
dtype = dtype or dtypes.float64
return dtype._ops.empty_like(x, dtype=dtype)
dtype = dtype or x.dtype
if (out := dtype._ops.empty_like(x, dtype=dtype)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for empty_like: '{dtype}'"
)


def eye(
Expand All @@ -113,7 +119,11 @@ def full(
device=None,
):
dtype = asarray(fill_value).dtype if dtype is None else dtype
return dtype._ops.full(shape, fill_value, dtype=dtype, device=device)
if (
out := dtype._ops.full(shape, fill_value, dtype=dtype, device=device)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for full: '{dtype}'")


def full_like(
Expand All @@ -122,8 +132,14 @@ def full_like(
dtype: dtypes.CoreType | dtypes.StructType | None = None,
device=None,
):
dtype = asarray(fill_value).dtype if dtype is None else dtype
return dtype._ops.full_like(x, fill_value, dtype=dtype, device=device)
dtype = x.dtype if dtype is None else dtype
if (
out := dtype._ops.full_like(x, fill_value, dtype=dtype, device=device)
) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for full_like: '{dtype}'"
)


def linspace(
Expand All @@ -146,12 +162,18 @@ def linspace(

def ones(shape, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None):
dtype = dtypes.float64 if dtype is None else dtype
return dtype._ops.ones(shape, dtype=dtype)
if (out := dtype._ops.ones(shape, dtype=dtype)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for ones: '{dtype}'")


def ones_like(x, dtype: dtypes.StructType | dtypes.CoreType | None = None, device=None):
dtype = dtypes.float64 if dtype is None else dtype
return dtype._ops.ones_like(x, dtype=dtype)
dtype = x.dtype if dtype is None else dtype
if (out := dtype._ops.ones_like(x, dtype=dtype)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for ones_like: '{dtype}'"
)


def tril(x: Array, *, k: int = 0) -> Array:
Expand All @@ -170,14 +192,20 @@ def triu(x: Array, *, k: int = 0) -> Array:

def zeros(shape, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None):
dtype = dtypes.float64 if dtype is None else dtype
return dtype._ops.zeros(shape, dtype=dtype)
if (out := dtype._ops.zeros(shape, dtype=dtype)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for zeros: '{dtype}'")


def zeros_like(
x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
dtype = dtypes.float64 if dtype is None else dtype
return dtype._ops.zeros_like(x, dtype=dtype)
dtype = x.dtype if dtype is None else dtype
if (out := dtype._ops.zeros_like(x, dtype=dtype)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for zeros_like: '{dtype}'"
)


# data_type.py
Expand Down
20 changes: 20 additions & 0 deletions tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ def linspace(
lower=ndx.linspace(start, stop, num, dtype=ndx.uint64, endpoint=endpoint),
)

def zeros(self, shape, dtype: CoreType | StructType | None = None, device=None):
return Array._from_fields(
Unsigned96(),
upper=ndx.zeros(shape, dtype=ndx.uint32, device=device),
lower=ndx.zeros(shape, dtype=ndx.uint64, device=device),
)

def ones(self, shape, dtype: CoreType | StructType | None = None, device=None):
return Array._from_fields(
Unsigned96(),
upper=ndx.zeros(shape, dtype=ndx.uint32, device=device),
lower=ndx.ones(shape, dtype=ndx.uint64, device=device),
)

def empty(self, shape, dtype=None, device=None) -> Array:
return ndx.zeros(shape, dtype=Unsigned96(), device=device)

def empty_like(self, x, dtype=None, device=None) -> Array:
return ndx.zeros_like(x, dtype=Unsigned96(), device=device)


class Unsigned96(StructType, CastMixin):
def _fields(self) -> dict[str, StructType | CoreType]:
Expand Down

0 comments on commit 120cb27

Please sign in to comment.