Skip to content

Commit

Permalink
Make NullableCore public (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Aug 29, 2024
1 parent d64dd52 commit c132f93
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 29 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Changelog
0.9.0 (unreleased)
------------------

**New feature**

- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx.

**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
Expand Down
2 changes: 2 additions & 0 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Floating,
Integral,
Nullable,
NullableCore,
NullableFloating,
NullableIntegral,
NullableNumerical,
Expand Down Expand Up @@ -323,6 +324,7 @@
"Floating",
"NullableIntegral",
"Nullable",
"NullableCore",
"Integral",
"CoreType",
"CastError",
Expand Down
4 changes: 2 additions & 2 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def can_cast(self, from_, to) -> bool:

@validate_core
def all(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
Expand All @@ -110,7 +110,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):

@validate_core
def any(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
Expand Down
10 changes: 5 additions & 5 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def clip(
and isinstance(x.dtype, dtypes.Numerical)
):
x, min, max = promote(x, min, max)
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
out_null = x.null
x_values = x.values._core()
clipped = from_corearray(opx.clip(x_values, min._core(), max._core()))
Expand Down Expand Up @@ -856,7 +856,7 @@ def can_cast(self, from_, to) -> bool:

@validate_core
def all(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
Expand All @@ -866,7 +866,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):

@validate_core
def any(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
Expand Down Expand Up @@ -898,7 +898,7 @@ def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Ar

@validate_core
def tril(self, x, k=0) -> ndx.Array:
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
# NumPy appears to just ignore the mask so we do the same
x = x.values
return x._transmute(
Expand All @@ -909,7 +909,7 @@ def tril(self, x, k=0) -> ndx.Array:

@validate_core
def triu(self, x, k=0) -> ndx.Array:
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
# NumPy appears to just ignore the mask so we do the same
x = x.values
return x._transmute(
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
if dtype is not None and not isinstance(
dtype, (dtypes.CoreType, dtypes._NullableCore)
dtype, (dtypes.CoreType, dtypes.NullableCore)
):
raise TypeError("'dtype' must be a CoreType or NullableCoreType")
if dtype in (None, dtypes.utf8, dtypes.nutf8):
Expand Down
6 changes: 3 additions & 3 deletions ndonnx/_core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def variadic_op(
):
args = promote(*args)
out_dtype = args[0].dtype
if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)):
if not isinstance(out_dtype, (dtypes.CoreType, dtypes.NullableCore)):
raise TypeError(
f"Expected ndx.Array with CoreType or NullableCoreType, got {args[0].dtype}"
)
Expand Down Expand Up @@ -100,7 +100,7 @@ def _via_dtype(
promoted = promote(*arrays)
out_dtype = promoted[0].dtype

if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype:
if isinstance(out_dtype, dtypes.NullableCore) and out_dtype.values == dtype:
dtype = out_dtype

values, nulls = split_nulls_and_values(
Expand Down Expand Up @@ -203,7 +203,7 @@ def validate_core(func):
def wrapper(*args, **kwargs):
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, ndx.Array) and not isinstance(
arg.dtype, (dtypes.CoreType, dtypes._NullableCore)
arg.dtype, (dtypes.CoreType, dtypes.NullableCore)
):
return NotImplemented
return func(*args, **kwargs)
Expand Down
12 changes: 6 additions & 6 deletions ndonnx/_data_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
NullableUnsigned,
Numerical,
Unsigned,
_NullableCore,
NullableCore,
from_numpy_dtype,
get_finfo,
get_iinfo,
Expand All @@ -51,7 +51,7 @@
from .structtype import StructType


def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
def into_nullable(dtype: StructType | CoreType) -> NullableCore:
"""Return nullable counterpart, if present.
Parameters
Expand All @@ -61,7 +61,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
Returns
-------
out : _NullableCore
out : NullableCore
The nullable counterpart of the input type.
Raises
Expand Down Expand Up @@ -93,7 +93,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
return nuint64
elif dtype == utf8:
return nutf8
elif isinstance(dtype, _NullableCore):
elif isinstance(dtype, NullableCore):
return dtype
else:
raise ValueError(f"Cannot promote {dtype} to nullable")
Expand All @@ -103,14 +103,14 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
"Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. "
"To create nullable array, use 'ndonnx.additional.make_nullable' instead."
)
def promote_nullable(dtype: StructType | CoreType) -> _NullableCore:
def promote_nullable(dtype: StructType | CoreType) -> NullableCore:
return into_nullable(dtype)


__all__ = [
"CoreType",
"StructType",
"_NullableCore",
"NullableCore",
"NullableFloating",
"NullableIntegral",
"NullableUnsigned",
Expand Down
20 changes: 10 additions & 10 deletions ndonnx/_data_types/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _fields(self) -> dict[str, StructType | CoreType]:
}


class _NullableCore(Nullable[CoreType], CastMixin):
class NullableCore(Nullable[CoreType], CastMixin):
def copy(self) -> Self:
return self

Expand All @@ -213,7 +213,7 @@ def _schema(self) -> Schema:
return Schema(type_name=type(self).__name__, author="ndonnx")

def _cast_to(self, array: Array, dtype: CoreType | StructType) -> Array:
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
return ndx.Array._from_fields(
dtype,
values=self.values._cast_to(array.values, dtype.values),
Expand All @@ -230,7 +230,7 @@ def _cast_from(self, array: Array) -> Array:
values=self.values._cast_from(array),
null=ndx.zeros_like(array, dtype=Boolean()),
)
elif isinstance(array.dtype, _NullableCore):
elif isinstance(array.dtype, NullableCore):
return ndx.Array._from_fields(
self,
values=self.values._cast_from(array.values),
Expand All @@ -240,7 +240,7 @@ def _cast_from(self, array: Array) -> Array:
raise CastError(f"Cannot cast from {array.dtype} to {self}")


class NullableNumerical(_NullableCore):
class NullableNumerical(NullableCore):
"""Base class for nullable numerical data types."""

_ops: OperationsBlock = NullableNumericOperationsImpl()
Expand Down Expand Up @@ -312,14 +312,14 @@ class NFloat64(NullableFloating):
null = Boolean()


class NBoolean(_NullableCore):
class NBoolean(NullableCore):
values = Boolean()
null = Boolean()

_ops: OperationsBlock = NullableBooleanOperationsImpl()


class NUtf8(_NullableCore):
class NUtf8(NullableCore):
values = Utf8()
null = Boolean()

Expand Down Expand Up @@ -405,18 +405,18 @@ def _from_dtype(cls, dtype: CoreType) -> Finfo:
)


def get_finfo(dtype: _NullableCore | CoreType) -> Finfo:
def get_finfo(dtype: NullableCore | CoreType) -> Finfo:
try:
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
dtype = dtype.values
return Finfo._from_dtype(dtype)
except KeyError:
raise TypeError(f"'{dtype}' is not a floating point data type.")


def get_iinfo(dtype: _NullableCore | CoreType) -> Iinfo:
def get_iinfo(dtype: NullableCore | CoreType) -> Iinfo:
try:
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
dtype = dtype.values
return Iinfo._from_dtype(dtype)
except KeyError:
Expand Down
4 changes: 2 additions & 2 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy.typing as npt

import ndonnx._data_types as dtypes
from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore
from ndonnx._data_types import CastError, CastMixin, CoreType, NullableCore
from ndonnx._data_types.structtype import StructType
from ndonnx.additional import shape

Expand Down Expand Up @@ -291,7 +291,7 @@ def result_type(
np_dtypes = []
for dtype in observed_dtypes:
if isinstance(dtype, dtypes.StructType):
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
nullable = True
np_dtypes.append(dtype.values.to_numpy_dtype())
else:
Expand Down

0 comments on commit c132f93

Please sign in to comment.