diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 194dae8..2aa8730 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,8 +11,10 @@ Changelog 0.9.0 (unreleased) ------------------ -**New feature** +**New features** +- User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function. +- User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function. - :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx. **Bug fixes** diff --git a/ndonnx/_array.py b/ndonnx/_array.py index 08195a7..c06bf66 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -15,6 +15,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes from ndonnx.additional import shape +from ndonnx.additional._additional import _getitem as getitem from ndonnx.additional._additional import _static_shape as static_shape from ._corearray import _CoreArray @@ -47,7 +48,11 @@ def array( out : Array The new array. This represents an ONNX model input. """ - return Array._construct(shape=shape, dtype=dtype) + if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented: + return out + raise ndx.UnsupportedOperationError( + f"No implementation of `make_array` for {dtype}" + ) def from_spox_var( @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array: return ndx.astype(self, to) def __getitem__(self, index: IndexType) -> Array: - if isinstance(index, Array) and not ( - isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool - ): - raise TypeError( - f"Index must be an integral or boolean 'Array', not `{index.dtype}`" - ) - - if isinstance(index, Array): - index = index._core() - - return self._transmute(lambda corearray: corearray[index]) + return getitem(self, index) def __setitem__( self, index: IndexType | Self, updates: int | bool | float | Array diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index e778f62..f8c93c8 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -12,8 +12,9 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx -import ndonnx.additional as nda +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations from ._utils import binary_op, unary_op, validate_core @@ -22,7 +23,7 @@ from ndonnx import Array -class BooleanOperationsImpl(UniformShapeOperations): +class _BooleanOperationsImpl(OperationsBlock): @validate_core def equal(self, x, y) -> Array: return binary_op(x, y, opx.equal) @@ -163,17 +164,12 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array: def nonzero(self, x) -> tuple[Array, ...]: return ndx.nonzero(x.astype(ndx.int8)) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.broadcast_to(null, nda.shape(x)), - ) +class BooleanOperationsImpl( + CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations +): ... -class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented + +class NullableBooleanOperationsImpl( + NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations +): ... diff --git a/ndonnx/_core/_coreimpl.py b/ndonnx/_core/_coreimpl.py new file mode 100644 index 0000000..e5b2a58 --- /dev/null +++ b/ndonnx/_core/_coreimpl.py @@ -0,0 +1,50 @@ +# Copyright (c) QuantCo 2023-2024 +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from spox import Tensor, argument + +import ndonnx as ndx +import ndonnx._data_types as dtypes +import ndonnx.additional as nda +from ndonnx._corearray import _CoreArray + +from ._interface import OperationsBlock +from ._utils import validate_core + +if TYPE_CHECKING: + from ndonnx._array import Array + from ndonnx._data_types import Dtype + + +class CoreOperationsImpl(OperationsBlock): + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> Array: + if not isinstance(dtype, dtypes.CoreType): + return NotImplemented + return ndx.Array._from_fields( + dtype, + data=_CoreArray( + dtype._parse_input(eager_value)["data"] + if eager_value is not None + else argument(Tensor(dtype.to_numpy_dtype(), shape)) + ), + ) + + @validate_core + def make_nullable(self, x: Array, null: Array) -> Array: + if null.dtype != ndx.bool: + raise TypeError("'null' must be a boolean array") + + return ndx.Array._from_fields( + dtypes.into_nullable(x.dtype), + values=x.copy(), + null=ndx.broadcast_to(null, nda.shape(x)), + ) diff --git a/ndonnx/_core/_interface.py b/ndonnx/_core/_interface.py index 37fca9a..5340f4f 100644 --- a/ndonnx/_core/_interface.py +++ b/ndonnx/_core/_interface.py @@ -3,11 +3,17 @@ from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING, Literal + +import numpy as np import ndonnx as ndx import ndonnx._data_types as dtypes +if TYPE_CHECKING: + from ndonnx._array import IndexType + from ndonnx._data_types import Dtype + class OperationsBlock: """Interface for data types to implement top-level functions exported by ndonnx.""" @@ -251,7 +257,7 @@ def cumulative_sum( x, *, axis: int | None = None, - dtype: ndx.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, include_initial: bool = False, ): return NotImplemented @@ -270,7 +276,7 @@ def prod( x, *, axis=None, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, keepdims: bool = False, ) -> ndx.Array: return NotImplemented @@ -293,7 +299,7 @@ def sum( x, *, axis=None, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, keepdims: bool = False, ) -> ndx.Array: return NotImplemented @@ -305,7 +311,7 @@ def var( axis=None, keepdims: bool = False, correction=0.0, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, ) -> ndx.Array: return NotImplemented @@ -352,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array: def ones( self, shape, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, device=None, ): return NotImplemented @@ -365,14 +371,12 @@ def ones_like( def zeros( self, shape, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, device=None, ): return NotImplemented - def zeros_like( - self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None - ): + def zeros_like(self, x, dtype: Dtype | None = None, device=None): return NotImplemented def empty(self, shape, dtype=None, device=None) -> ndx.Array: @@ -413,3 +417,18 @@ def can_cast(self, from_, to) -> bool: def static_shape(self, x) -> tuple[int | None, ...]: return NotImplemented + + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> ndx.Array: + return NotImplemented + + def getitem( + self, + x: ndx.Array, + index: IndexType, + ) -> ndx.Array: + return NotImplemented diff --git a/ndonnx/_core/_nullableimpl.py b/ndonnx/_core/_nullableimpl.py index 71115fc..ce1013b 100644 --- a/ndonnx/_core/_nullableimpl.py +++ b/ndonnx/_core/_nullableimpl.py @@ -1,16 +1,29 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING, Union import ndonnx as ndx from ._interface import OperationsBlock from ._utils import validate_core +if TYPE_CHECKING: + from ndonnx._array import Array + from ndonnx._data_types import CoreType, StructType + + Dtype = Union[CoreType, StructType] + class NullableOperationsImpl(OperationsBlock): @validate_core - def fill_null(self, x, value): + def fill_null(self, x: Array, value) -> Array: value = ndx.asarray(value) if value.dtype != x.values.dtype: value = value.astype(x.values.dtype) return ndx.where(x.null, value, x.values) + + @validate_core + def make_nullable(self, x: Array, null: Array) -> Array: + return NotImplemented diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index ecfdcd0..e01edd7 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -19,6 +19,8 @@ import ndonnx.additional as nda from ndonnx._utility import promote +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations from ._utils import ( @@ -36,7 +38,7 @@ from ndonnx._corearray import _CoreArray -class NumericOperationsImpl(UniformShapeOperations): +class _NumericOperationsImpl(OperationsBlock): # elementwise.py @validate_core @@ -837,17 +839,6 @@ def var( - correction ) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.broadcast_to(null, nda.shape(x)), - ) - @validate_core def can_cast(self, from_, to) -> bool: if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType): @@ -980,9 +971,14 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array: return ndx.full_like(x, 0, dtype=dtype) -class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class NumericOperationsImpl( + CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations +): ... + + +class NullableNumericOperationsImpl( + NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations +): ... def _via_i64_f64( diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index b68f795..30d619c 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -1,7 +1,10 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING import numpy as np @@ -13,6 +16,10 @@ from ._interface import OperationsBlock from ._utils import from_corearray +if TYPE_CHECKING: + from ndonnx._array import Array, IndexType + from ndonnx._data_types import Dtype + class UniformShapeOperations(OperationsBlock): """Provides implementation for shape/indexing operations that are generic across all @@ -247,4 +254,55 @@ def zeros_like(self, x, dtype=None, device=None): return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device) def ones_like(self, x, dtype=None, device=None): - return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device) + return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device) + + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> Array: + if isinstance(dtype, dtypes.CoreType): + return NotImplemented + + fields: dict[str, ndx.Array] = {} + + eager_values = None if eager_value is None else dtype._parse_input(eager_value) + for name, field_dtype in dtype._fields().items(): + if eager_values is None: + field_value = None + else: + field_value = _assemble_output_recurse(field_dtype, eager_values[name]) + fields[name] = field_dtype._ops.make_array( + shape, + field_dtype, + field_value, + ) + return ndx.Array._from_fields( + dtype, + **fields, + ) + + def getitem(self, x: Array, index: IndexType) -> Array: + if isinstance(index, ndx.Array) and not ( + isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool + ): + raise TypeError( + f"Index must be an integral or boolean 'Array', not `{index.dtype}`" + ) + + if isinstance(index, ndx.Array): + index = index._core() + + return x._transmute(lambda corearray: corearray[index]) + + +def _assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray: + if isinstance(dtype, dtypes.CoreType): + return dtype._assemble_output(values) + else: + fields = { + name: _assemble_output_recurse(field_dtype, values[name]) + for name, field_dtype in dtype._fields().items() + } + return dtype._assemble_output(fields) diff --git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index 2a42152..1ba2802 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -10,8 +10,9 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx -import ndonnx.additional as nda +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations from ._utils import binary_op, validate_core @@ -20,7 +21,7 @@ from ndonnx import Array -class StringOperationsImpl(UniformShapeOperations): +class _StringOperationsImpl(OperationsBlock): @validate_core def add(self, x, y) -> Array: return binary_op(x, y, opx.string_concat) @@ -69,18 +70,12 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array: def empty_like(self, x, dtype=None, device=None) -> ndx.Array: return ndx.zeros_like(x, dtype=dtype, device=device) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.broadcast_to(null, nda.shape(x)), - ) +class StringOperationsImpl( + CoreOperationsImpl, _StringOperationsImpl, UniformShapeOperations +): ... -class NullableStringOperationsImpl(StringOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class NullableStringOperationsImpl( + NullableOperationsImpl, _StringOperationsImpl, UniformShapeOperations +): ... diff --git a/ndonnx/_data_types/__init__.py b/ndonnx/_data_types/__init__.py index 83fe3f5..392abe0 100644 --- a/ndonnx/_data_types/__init__.py +++ b/ndonnx/_data_types/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations from ndonnx._utility import deprecated - +from typing import Union from .aliases import ( bool, float32, @@ -99,6 +99,9 @@ def into_nullable(dtype: StructType | CoreType) -> NullableCore: raise ValueError(f"Cannot promote {dtype} to nullable") +Dtype = Union[CoreType, StructType] + + @deprecated( "Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. " "To create nullable array, use 'ndonnx.additional.make_nullable' instead." @@ -151,4 +154,5 @@ def promote_nullable(dtype: StructType | CoreType) -> NullableCore: "Schema", "CastMixin", "CastError", + "Dtype", ] diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 4fa8c50..d15dd16 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -61,20 +61,26 @@ def asarray( device=None, ) -> Array: if not isinstance(x, Array): - arr = np.asanyarray( + eager_value = np.asanyarray( x, dtype=( dtype.to_numpy_dtype() if isinstance(dtype, dtypes.CoreType) else None ), ) if dtype is None: - dtype = dtypes.from_numpy_dtype(arr.dtype) - if isinstance(arr, np.ma.masked_array): + dtype = dtypes.from_numpy_dtype(eager_value.dtype) + if isinstance(eager_value, np.ma.masked_array): dtype = dtypes.into_nullable(dtype) - ret = Array._construct( - shape=arr.shape, dtype=dtype, eager_values=dtype._parse_input(arr) + ret = dtype._ops.make_array( + shape=eager_value.shape, + dtype=dtype, + eager_value=eager_value, ) + if ret is NotImplemented: + raise UnsupportedOperationError( + f"Unsupported operand type for asarray: '{dtype}'" + ) else: ret = x.copy() if copy is True else x diff --git a/ndonnx/additional/_additional.py b/ndonnx/additional/_additional.py index 9c55764..a9f1fc4 100644 --- a/ndonnx/additional/_additional.py +++ b/ndonnx/additional/_additional.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from ndonnx import Array + from ndonnx._array import IndexType Scalar = TypeVar("Scalar", int, float, str) @@ -149,6 +150,15 @@ def make_nullable(x: Array, null: Array) -> Array: return out +def _getitem(x: Array, index: IndexType) -> ndx.Array: + out = x.dtype._ops.getitem(x, index) + if out is NotImplemented: + raise ndx.UnsupportedOperationError( + f"'getitem' not implemented for `{x.dtype}`" + ) + return out + + def _static_shape(x: Array) -> tuple[int | None, ...]: """Return shape of the array as a tuple. Typical implementations will make use of ONNX shape inference, with `None` entries denoting unknown or symbolic dimensions. diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 81a4fb6..d9919fd 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -3,6 +3,7 @@ from __future__ import annotations +import functools import re import numpy as np @@ -10,6 +11,7 @@ from typing_extensions import Self import ndonnx as ndx +import ndonnx.additional as nda from ndonnx import ( Array, CastError, @@ -17,6 +19,7 @@ ) from ndonnx._experimental import ( CastMixin, + OperationsBlock, Schema, StructType, UniformShapeOperations, @@ -137,6 +140,94 @@ def _cast_from(self, array: Array) -> Array: _ops = Unsigned96Impl() +class ListImpl(OperationsBlock): + def make_array( + self, + shape: tuple[int | str | None, ...], + dtype: CoreType | StructType, + eager_value: np.ndarray | None = None, + ) -> Array: + if eager_value is None: + return Array._from_fields( + dtype, + endpoints=ndx.array(shape=shape + (2,), dtype=ndx.int64), + items=ndx.array(shape=(None,), dtype=ndx.utf8), + ) + else: + fields = dtype._parse_input(eager_value) + return Array._from_fields( + dtype, **{name: ndx.asarray(field) for name, field in fields.items()} + ) + + def getitem( + self, + x: Array, + index, + ) -> Array: + if isinstance(index, int): + index = slice(index, index + 1), ... + + return Array._from_fields( + dtype=x.dtype, + endpoints=x.endpoints[index], + items=x.items.copy(), + ) + + def shape(self, x) -> Array: + return nda.shape(x.endpoints)[:-1] + + def static_shape(self, x) -> tuple[int | None, ...]: + return x.endpoints.shape[:-1] + + +class List(StructType): + # The fields here have different shapes + def _fields(self) -> dict[str, StructType | CoreType]: + return { + "endpoints": ndx.int64, + "items": ndx.utf8, + } + + def _parse_input(self, x: np.ndarray) -> dict: + assert x.dtype == object + assert all(isinstance(x, list) for x in x.flat) + + endpoints = np.empty(x.shape + (2,), dtype=np.int64) + items = np.empty( + functools.reduce(lambda acc, elem: acc + len(elem), x.flat, 0), dtype=object + ) + + cur_items_idx = 0 + for idx in np.ndindex(x.shape): + endpoints[idx, :] = [cur_items_idx, cur_items_idx + len(x[idx])] + for elem in x[idx]: + items[cur_items_idx] = elem + cur_items_idx += 1 + + return { + "endpoints": endpoints, + "items": items.astype(np.str_), + } + + def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray: + endpoints = fields["endpoints"] + items = fields["items"] + + out = np.empty(endpoints.shape[:-1], dtype=object) + for idx in np.ndindex(endpoints.shape[:-1]): + start, end = endpoints[idx] + out[idx] = items[start:end].tolist() + return out + + def copy(self) -> Self: + return self + + def _schema(self) -> Schema: + return Schema(type_name="List", author="value from data!") + + _ops = ListImpl() + + def custom_equal(x: Array, y: Array) -> Array: if x.dtype != Unsigned96() or y.dtype != Unsigned96(): raise ValueError("Can only compare Unsigned96 arrays") @@ -280,3 +371,46 @@ def test_custom_dtype_capable_creation_functions(): assert_array_equal( ndx.ones_like(x, dtype=ndx.int32).to_numpy(), np.ones_like(arr, dtype=np.int32) ) + + +def test_create_dtype_mismatched_shape_fields_eager(): + array = np.empty(shape=(2,), dtype=object) + array[0] = ["a", "bcd", "e"] + array[1] = ["f", "gh"] + x = ndx.asarray(array, dtype=List()) + assert_array_equal(x.to_numpy(), array) + assert x[0].to_numpy().item() == ["a", "bcd", "e"] + assert_array_equal(nda.shape(x).to_numpy(), np.array([2], dtype=np.int64)) + assert x.shape == (2,) + + +def test_create_dtype_mismatched_shape_fields_lazy(): + x = ndx.array(shape=("N", "M", 2), dtype=List()) + assert x.shape == (None, None, 2) + out = x[1:2, 0, ...] + + ndx.build({"x": x}, {"out": out}) + + +def test_recursive_construction(): + class MyNInt64(StructType): + def _fields(self) -> dict[str, StructType | CoreType]: + return {"x": ndx.nint64} + + def _parse_input(self, x: np.ndarray) -> dict: + return {"x": ndx.nint64._parse_input(x)} + + def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray: + return fields["x"] + + def copy(self) -> Self: + return self + + def _schema(self) -> Schema: + return Schema(type_name="my_nint64", author="me") + + _ops = UniformShapeOperations() + + my_nint64 = MyNInt64() + a = ndx.asarray(np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64), my_nint64) + assert_array_equal(a.to_numpy(), np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64))