Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User driven getitem and construction #60

Merged
merged 8 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ Changelog
0.9.0 (unreleased)
------------------

**New features**

- User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function.
- User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function.

**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
Expand Down
19 changes: 7 additions & 12 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx.additional import shape
from ndonnx.additional._additional import _getitem as getitem
from ndonnx.additional._additional import _static_shape as static_shape

from ._corearray import _CoreArray
Expand Down Expand Up @@ -47,7 +48,11 @@ def array(
out : Array
The new array. This represents an ONNX model input.
"""
return Array._construct(shape=shape, dtype=dtype)
if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented:
return out
raise ndx.UnsupportedOperationError(
f"No implementation of `make_array` for {dtype}"
)


def from_spox_var(
Expand Down Expand Up @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array:
return ndx.astype(self, to)

def __getitem__(self, index: IndexType) -> Array:
if isinstance(index, Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
raise TypeError(
f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
)

if isinstance(index, Array):
index = index._core()

return self._transmute(lambda corearray: corearray[index])
return getitem(self, index)

def __setitem__(
self, index: IndexType | Self, updates: int | bool | float | Array
Expand Down
24 changes: 10 additions & 14 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, unary_op, validate_core
Expand All @@ -22,7 +23,7 @@
from ndonnx import Array


class BooleanOperationsImpl(UniformShapeOperations):
class _BooleanOperationsImpl(OperationsBlock):
@validate_core
def equal(self, x, y) -> Array:
return binary_op(x, y, opx.equal)
Expand Down Expand Up @@ -163,17 +164,12 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

class BooleanOperationsImpl(
CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...

class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented

class NullableBooleanOperationsImpl(
NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...
49 changes: 49 additions & 0 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from typing import TYPE_CHECKING

import numpy as np
from spox import Tensor, argument

import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx.additional as nda
from ndonnx._corearray import _CoreArray

from ._interface import OperationsBlock
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import Dtype


class CoreOperationsImpl(OperationsBlock):
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: "Dtype",
eager_value: np.ndarray | None = None,
) -> "Array":
if not isinstance(dtype, dtypes.CoreType):
return NotImplemented
return ndx.Array._from_fields(
dtype,
data=_CoreArray(
dtype._parse_input(eager_value)["data"]
if eager_value is not None
else argument(Tensor(dtype.to_numpy_dtype(), shape))
),
)

@validate_core
def make_nullable(self, x: "Array", null: "Array") -> "Array":
if null.dtype != ndx.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)
22 changes: 21 additions & 1 deletion ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes

if TYPE_CHECKING:
from ndonnx._array import IndexType


class OperationsBlock:
"""Interface for data types to implement top-level functions exported by ndonnx."""
Expand Down Expand Up @@ -413,3 +418,18 @@ def can_cast(self, from_, to) -> bool:

def static_shape(self, x) -> tuple[int | None, ...]:
return NotImplemented

def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: dtypes.CoreType | dtypes.StructType,
eager_value: np.ndarray | None = None,
) -> ndx.Array:
return NotImplemented

def getitem(
self,
x: ndx.Array,
index: IndexType,
) -> ndx.Array:
return NotImplemented
13 changes: 12 additions & 1 deletion ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from typing import TYPE_CHECKING, Union

import ndonnx as ndx

from ._interface import OperationsBlock
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import CoreType, StructType

Dtype = Union[CoreType, StructType]


class NullableOperationsImpl(OperationsBlock):
@validate_core
def fill_null(self, x, value):
def fill_null(self, x: "Array", value) -> "Array":
value = ndx.asarray(value)
if value.dtype != x.values.dtype:
value = value.astype(x.values.dtype)
return ndx.where(x.null, value, x.values)

@validate_core
def make_nullable(self, x: "Array", null: "Array") -> "Array":
return NotImplemented
26 changes: 11 additions & 15 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import ndonnx.additional as nda
from ndonnx._utility import promote

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import (
Expand All @@ -36,7 +38,7 @@
from ndonnx._corearray import _CoreArray


class NumericOperationsImpl(UniformShapeOperations):
class _NumericOperationsImpl(OperationsBlock):
# elementwise.py

@validate_core
Expand Down Expand Up @@ -837,17 +839,6 @@ def var(
- correction
)

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
def can_cast(self, from_, to) -> bool:
if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType):
Expand Down Expand Up @@ -980,9 +971,14 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)


class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented
class NumericOperationsImpl(
CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


class NullableNumericOperationsImpl(
NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


def _via_i64_f64(
Expand Down
58 changes: 57 additions & 1 deletion ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -13,6 +14,10 @@
from ._interface import OperationsBlock
from ._utils import from_corearray

if TYPE_CHECKING:
from ndonnx._array import Array, IndexType
from ndonnx._data_types import Dtype


class UniformShapeOperations(OperationsBlock):
"""Provides implementation for shape/indexing operations that are generic across all
Expand Down Expand Up @@ -247,4 +252,55 @@ def zeros_like(self, x, dtype=None, device=None):
return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device)

def ones_like(self, x, dtype=None, device=None):
return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device)
return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device)

def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: "Dtype",
eager_value: np.ndarray | None = None,
) -> "Array":
if isinstance(dtype, dtypes.CoreType):
return NotImplemented

fields: dict[str, ndx.Array] = {}

eager_values = None if eager_value is None else dtype._parse_input(eager_value)
for name, field_dtype in dtype._fields().items():
if eager_values is None:
field_value = None
else:
field_value = _assemble_output_recurse(field_dtype, eager_values[name])
fields[name] = field_dtype._ops.make_array(
shape,
field_dtype,
field_value,
)
return ndx.Array._from_fields(
dtype,
**fields,
)

def getitem(self, x: "Array", index: "IndexType") -> "Array":
if isinstance(index, ndx.Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
raise TypeError(
f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
)

if isinstance(index, ndx.Array):
index = index._core()

return x._transmute(lambda corearray: corearray[index])


def _assemble_output_recurse(dtype: "Dtype", values: dict) -> np.ndarray:
if isinstance(dtype, dtypes.CoreType):
return dtype._assemble_output(values)
else:
fields = {
name: _assemble_output_recurse(field_dtype, values[name])
for name, field_dtype in dtype._fields().items()
}
return dtype._assemble_output(fields)
23 changes: 9 additions & 14 deletions ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, validate_core
Expand All @@ -20,7 +21,7 @@
from ndonnx import Array


class StringOperationsImpl(UniformShapeOperations):
class _StringOperationsImpl(OperationsBlock):
@validate_core
def add(self, x, y) -> Array:
return binary_op(x, y, opx.string_concat)
Expand Down Expand Up @@ -69,18 +70,12 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array:
def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.zeros_like(x, dtype=dtype, device=device)

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)
class StringOperationsImpl(
CoreOperationsImpl, _StringOperationsImpl, UniformShapeOperations
): ...


class NullableStringOperationsImpl(StringOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented
class NullableStringOperationsImpl(
NullableOperationsImpl, _StringOperationsImpl, UniformShapeOperations
): ...
Loading