Skip to content

Commit

Permalink
User driven getitem and construction
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 23, 2024
1 parent b642d5a commit 1e48baa
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 60 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
Changelog
=========

0.9.0 (unreleased)
------------------

**New features**

- User-defined data types can now define how arrays with that dtype are constructed by implementing the :func:`make_array` function.
- User-defined data types can now define how they are indexed (via ``__getitem__``) by implementing the :func:`getitem` function.


0.8.0 (2024-08-22)
------------------

Expand Down
15 changes: 3 additions & 12 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx.additional import shape
from ndonnx.additional._additional import _getitem as getitem
from ndonnx.additional._additional import _static_shape as static_shape

from ._corearray import _CoreArray
Expand Down Expand Up @@ -47,7 +48,7 @@ def array(
out : Array
The new array. This represents an ONNX model input.
"""
return Array._construct(shape=shape, dtype=dtype)
return dtype._ops.make_array(shape, dtype)


def from_spox_var(
Expand Down Expand Up @@ -143,17 +144,7 @@ def astype(self, to: CoreType | StructType) -> Array:
return ndx.astype(self, to)

def __getitem__(self, index: IndexType) -> Array:
if isinstance(index, Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
raise TypeError(
f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
)

if isinstance(index, Array):
index = index._core()

return self._transmute(lambda corearray: corearray[index])
return getitem(self, index)

def __setitem__(
self, index: IndexType | Self, updates: int | bool | float | Array
Expand Down
23 changes: 10 additions & 13 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, unary_op, validate_core
Expand All @@ -21,7 +23,7 @@
from ndonnx import Array


class BooleanOperationsImpl(UniformShapeOperations):
class _BooleanOperationsImpl(OperationsBlock):
@validate_core
def equal(self, x, y) -> Array:
return binary_op(x, y, opx.equal)
Expand Down Expand Up @@ -162,17 +164,12 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
)

class BooleanOperationsImpl(
CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...

class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented

class NullableBooleanOperationsImpl(
NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...
34 changes: 34 additions & 0 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from spox import Tensor, argument

import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx._corearray import _CoreArray

from ._interface import OperationsBlock
from ._utils import validate_core


class CoreOperationsImpl(OperationsBlock):
    """Array construction and nullable-wrapping shared by the core data types."""

    def make_array(self, shape, dtype, eager_value=None):
        """Create an array of ``dtype`` holding a single ``data`` field.

        Without ``eager_value`` the field wraps a fresh spox ``argument`` typed
        as a ``Tensor`` of ``dtype.to_numpy_dtype()`` with the given ``shape``.
        With ``eager_value`` the field wraps the ``"data"`` entry produced by
        ``dtype._parse_input(eager_value)`` instead; ``shape`` is not consulted
        in that case.
        """
        if eager_value is None:
            payload = argument(Tensor(dtype.to_numpy_dtype(), shape))
        else:
            payload = dtype._parse_input(eager_value)["data"]

        return ndx.Array._from_fields(dtype, data=_CoreArray(payload))

    @validate_core
    def make_nullable(self, x, null):
        """Pair ``x`` with a boolean ``null`` array under the nullable form of its dtype.

        Raises:
            TypeError: if ``null`` does not have dtype ``ndx.bool``.
        """
        null_is_boolean = not (null.dtype != ndx.bool)
        if not null_is_boolean:
            raise TypeError("'null' must be a boolean array")

        nullable_dtype = dtypes.into_nullable(x.dtype)
        values = x.copy()
        # The null array is reshaped to the shape of the values it annotates.
        null_reshaped = ndx.reshape(null, x.shape)
        return ndx.Array._from_fields(
            nullable_dtype,
            values=values,
            null=null_reshaped,
        )
22 changes: 21 additions & 1 deletion ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes

if TYPE_CHECKING:
from ndonnx._array import IndexType


class OperationsBlock:
"""Interface for data types to implement top-level functions exported by ndonnx."""
Expand Down Expand Up @@ -413,3 +418,18 @@ def can_cast(self, from_, to) -> bool:

def static_shape(self, x) -> tuple[int | None, ...]:
return NotImplemented

# Hook through which a data type defines how arrays of that dtype are
# constructed. `shape` and `dtype` describe the array to build; `eager_value`
# optionally supplies a NumPy array of concrete values.
# The default returns `NotImplemented`, i.e. the data type does not provide
# its own construction logic.
# NOTE(review): `ndonnx.array` returns the result of this hook directly, so a
# data type leaving this unimplemented appears to hand `NotImplemented` back to
# the caller rather than raising - confirm whether that is intended.
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: dtypes.CoreType | dtypes.StructType,
eager_value: np.ndarray | None = None,
) -> ndx.Array:
return NotImplemented

# Hook through which a data type defines how its arrays are indexed
# (`Array.__getitem__`). `x` is the array being indexed and `index` the key.
# The default returns `NotImplemented`; the `_getitem` dispatcher in
# `ndonnx.additional` turns that into an `UnsupportedOperationError`.
def getitem(
self,
x: ndx.Array,
index: IndexType,
) -> ndx.Array:
return NotImplemented
4 changes: 4 additions & 0 deletions ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ def fill_null(self, x, value):
if value.dtype != x.values.dtype:
value = value.astype(x.values.dtype)
return ndx.where(x.null, value, x.values)

# Nullable data types do not support being wrapped as nullable again: this
# always returns `NotImplemented` so the dispatching caller can report the
# operation as unsupported (presumably mirroring how `_getitem` handles
# `NotImplemented` - confirm against `ndonnx.additional.make_nullable`).
@validate_core
def make_nullable(self, x, null):
return NotImplemented
26 changes: 11 additions & 15 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import ndonnx._opset_extensions as opx
from ndonnx._utility import promote

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import (
Expand All @@ -35,7 +37,7 @@
from ndonnx._corearray import _CoreArray


class NumericOperationsImpl(UniformShapeOperations):
class _NumericOperationsImpl(OperationsBlock):
# elementwise.py

@validate_core
Expand Down Expand Up @@ -805,17 +807,6 @@ def var(
- correction
)

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
)

@validate_core
def can_cast(self, from_, to) -> bool:
if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType):
Expand Down Expand Up @@ -948,9 +939,14 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)


class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented
class NumericOperationsImpl(
CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


class NullableNumericOperationsImpl(
NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


def _via_i64_f64(
Expand Down
30 changes: 30 additions & 0 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,33 @@ def zeros_like(self, x, dtype=None, device=None):

def ones_like(self, x, dtype=None, device=None):
return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device)

def make_array(self, shape, dtype, eager_value=None):
    """Build an array of ``dtype`` by constructing each of its fields.

    Every entry of ``dtype._fields()`` is created by delegating to that
    field dtype's own ``_ops.make_array`` with the same ``shape``. When
    ``eager_value`` is given it is first split via ``dtype._parse_input``;
    each field then receives its own piece, passed through the field dtype's
    ``_assemble_output``. Otherwise every field is built without a value.
    """
    parsed = None if eager_value is None else dtype._parse_input(eager_value)

    def _eager_piece(field_name, field_dtype):
        # No eager input at all means every field is constructed lazily.
        if parsed is None:
            return None
        return field_dtype._assemble_output(parsed[field_name])

    components = {
        field_name: field_dtype._ops.make_array(
            shape,
            field_dtype,
            _eager_piece(field_name, field_dtype),
        )
        for field_name, field_dtype in dtype._fields().items()
    }
    return ndx.Array._from_fields(dtype, **components)

def getitem(self, x, index):
    """Index ``x`` with ``index``, applying the same key to every field.

    An ``ndx.Array`` key must have an integral or boolean dtype and is
    lowered with ``_core()`` before use; any other key is passed through
    unchanged.

    Raises:
        TypeError: if ``index`` is an ``ndx.Array`` whose dtype is neither
            integral nor boolean.
    """
    key = index
    if isinstance(index, ndx.Array):
        is_integral = isinstance(index.dtype, dtypes.Integral)
        if not is_integral and not (index.dtype == dtypes.bool):
            raise TypeError(
                f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
            )
        key = index._core()

    def _select(corearray):
        return corearray[key]

    return x._transmute(_select)
22 changes: 9 additions & 13 deletions ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, validate_core
Expand All @@ -19,7 +21,7 @@
from ndonnx import Array


class StringOperationsImpl(UniformShapeOperations):
class _StringOperationsImpl(OperationsBlock):
@validate_core
def add(self, x, y) -> Array:
return binary_op(x, y, opx.string_concat)
Expand Down Expand Up @@ -68,18 +70,12 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array:
def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.zeros_like(x, dtype=dtype, device=device)

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
)
class StringOperationsImpl(
CoreOperationsImpl, _StringOperationsImpl, UniformShapeOperations
): ...


class NullableStringOperationsImpl(StringOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented
class NullableStringOperationsImpl(
NullableOperationsImpl, _StringOperationsImpl, UniformShapeOperations
): ...
12 changes: 7 additions & 5 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,21 @@ def asarray(
device=None,
) -> Array:
if not isinstance(x, Array):
arr = np.asanyarray(
eager_value = np.asanyarray(
x,
dtype=(
dtype.to_numpy_dtype() if isinstance(dtype, dtypes.CoreType) else None
),
)
if dtype is None:
dtype = dtypes.from_numpy_dtype(arr.dtype)
if isinstance(arr, np.ma.masked_array):
dtype = dtypes.from_numpy_dtype(eager_value.dtype)
if isinstance(eager_value, np.ma.masked_array):
dtype = dtypes.into_nullable(dtype)

ret = Array._construct(
shape=arr.shape, dtype=dtype, eager_values=dtype._parse_input(arr)
ret = dtype._ops.make_array(
shape=eager_value.shape,
dtype=dtype,
eager_value=eager_value,
)
else:
ret = x.copy() if copy is True else x
Expand Down
10 changes: 10 additions & 0 deletions ndonnx/additional/_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

if TYPE_CHECKING:
from ndonnx import Array
from ndonnx._array import IndexType

Scalar = TypeVar("Scalar", int, float, str)

Expand Down Expand Up @@ -147,6 +148,15 @@ def make_nullable(x: Array, null: Array) -> Array:
return out


def _getitem(x: Array, index: IndexType) -> ndx.Array:
    """Dispatch indexing of ``x`` to the ``getitem`` of its data type.

    Raises ``ndx.UnsupportedOperationError`` when the data type's
    implementation answers with ``NotImplemented``.
    """
    result = x.dtype._ops.getitem(x, index)
    if result is not NotImplemented:
        return result
    raise ndx.UnsupportedOperationError(
        f"`getitem` not implemented for `{x.dtype}`"
    )


def _static_shape(x: Array) -> tuple[int | None, ...]:
"""Return shape of the array as a tuple. Typical implementations will make use of
ONNX shape inference, with `None` entries denoting unknown or symbolic dimensions.
Expand Down
Loading

0 comments on commit 1e48baa

Please sign in to comment.