Skip to content

Commit

Permalink
Avoid exception based control flow
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jul 16, 2024
1 parent 7c7771d commit 1783d84
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 96 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
Changelog
=========

0.7.0 (unreleased)
------------------

**Other changes**

- ``ndonnx.result_type`` may now return ``None`` if the provided dtypes are user-defined.


0.6.1 (2024-07-12)
------------------

Expand Down
156 changes: 94 additions & 62 deletions ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ def acosh(self, x):
return _unary_op(x, opx.acosh, dtypes.float32)

def add(self, x, y) -> ndx.Array:
x, y = promote(x, y)
if x.dtype in (dtypes.utf8, dtypes.nutf8):
return _binary_op(x, y, opx.string_concat)
return _via_i64_f64(opx.add, [x, y])
operands = promote(x, y)
if operands is None:
return NotImplemented
if operands[0].dtype in (dtypes.utf8, dtypes.nutf8):
return _variadic_op(operands, opx.string_concat)
return _via_i64_f64(opx.add, operands)

def asin(self, x):
return _unary_op(x, opx.asin, dtypes.float32)
Expand All @@ -60,16 +62,19 @@ def atanh(self, x):
def bitwise_and(self, x, y):
# FIXME: Bitwise operations like this one should raise a type
# error when encountering booleans!
x, y = promote(x, y)
if x.dtype in (dtypes.bool, dtypes.nbool):
return self.logical_and(x, y)
operands = promote(x, y)
if operands is None:
return NotImplemented

if operands[0].dtype in (dtypes.bool, dtypes.nbool):
return self.logical_and(*operands)

return _binary_op(x, y, opx.bitwise_and)
return _variadic_op(operands, opx.bitwise_and)

# TODO: ONNX standard -> not cyclic
def bitwise_left_shift(self, x, y):
return _binary_op(
x, y, lambda a, b: opx.bit_shift(a, b, direction="LEFT"), dtypes.uint64
return _variadic_op(
[x, y], lambda a, b: opx.bit_shift(a, b, direction="LEFT"), dtypes.uint64
)

def bitwise_invert(self, x):
Expand All @@ -79,25 +84,28 @@ def bitwise_invert(self, x):
return _unary_op(x, opx.bitwise_not)

def bitwise_or(self, x, y):
x, y = promote(x, y)
if x.dtype in (dtypes.bool, dtypes.nbool):
return self.logical_or(x, y)
return _binary_op(x, y, opx.bitwise_or)
operands = promote(x, y)
if operands is None:
return None
if operands[0].dtype in (dtypes.bool, dtypes.nbool):
return self.logical_or(*operands)
return _variadic_op(operands, opx.bitwise_or)

# TODO: ONNX standard -> not cyclic
def bitwise_right_shift(self, x, y):
x, y = promote(x, y)
return _binary_op(
x,
y,
operands = promote(x, y)
if operands is None:
return None
return _variadic_op(
operands,
lambda x, y: opx.bit_shift(x, y, direction="RIGHT"),
dtypes.uint64,
)

def bitwise_xor(self, x, y):
if x.dtype in (dtypes.bool, dtypes.nbool):
return self.logical_xor(x, y)
return _binary_op(x, y, opx.bitwise_xor)
return _variadic_op([x, y], opx.bitwise_xor)

def ceil(self, x):
if isinstance(x.dtype, (dtypes.Floating, dtypes.NullableFloating)):
Expand All @@ -111,27 +119,32 @@ def cosh(self, x):
return _unary_op(x, opx.cosh, dtypes.float32)

def divide(self, x, y):
x, y = promote(x, y)
if not isinstance(x.dtype, (dtypes.Numerical, dtypes.NullableNumerical)):
operands = promote(x, y)
if operands is None:
return NotImplemented
dtype = operands[0].dtype
if not isinstance(dtype, (dtypes.Numerical, dtypes.NullableNumerical)):
raise TypeError(f"Unsupported dtype for divide: {x.dtype}")
bits = (
ndx.iinfo(x.dtype).bits
if isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral))
else ndx.finfo(x.dtype).bits
ndx.iinfo(dtype).bits
if isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral))
else ndx.finfo(dtype).bits
)
via_dtype = (
dtypes.float64
if bits > 32 or x.dtype in (dtypes.nuint32, dtypes.uint32)
if bits > 32 or dtype in (dtypes.nuint32, dtypes.uint32)
else dtypes.float32
)
return _variadic_op([x, y], opx.div, via_dtype=via_dtype, cast_return=False)
return _variadic_op(operands, opx.div, via_dtype=via_dtype, cast_return=False)

def equal(self, x, y) -> Array:
x, y = promote(x, y)
if isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral)):
return _variadic_op([x, y], opx.equal, dtypes.int64, cast_return=False)
operands = promote(x, y)
if operands is None:
return NotImplemented
if isinstance(operands[0].dtype, (dtypes.Integral, dtypes.NullableIntegral)):
return _variadic_op(operands, opx.equal, dtypes.int64, cast_return=False)
else:
return _binary_op(x, y, opx.equal)
return _variadic_op(operands, opx.equal)

def exp(self, x):
return _unary_op(x, opx.exp, dtypes.float32)
Expand All @@ -146,9 +159,11 @@ def floor(self, x):
return x

def floor_divide(self, x, y):
x, y = promote(x, y)
dtype = x.dtype
out = self.floor(self.divide(x, y))
operands = promote(x, y)
if operands is None:
return NotImplemented
dtype = operands[0].dtype
out = self.floor(self.divide(*operands))
if isinstance(
dtype,
(
Expand Down Expand Up @@ -208,7 +223,10 @@ def logaddexp(self, x, y):
return self.log(self.exp(x) + self.exp(y))

def logical_and(self, x, y):
x, y = promote(x, y)
operands = promote(x, y)
if operands is None:
return NotImplemented
x, y = operands
if x.dtype not in (dtypes.bool, dtypes.nbool):
raise TypeError(f"Unsupported dtype for logical_and: {x.dtype}")

Expand All @@ -228,7 +246,7 @@ def logical_and(self, x, y):
and y.to_numpy().item()
):
return x.copy()
return _binary_op(x, y, opx.and_)
return _variadic_op(operands, opx.and_)

def logical_not(self, x):
return _unary_op(x, opx.not_)
Expand All @@ -254,14 +272,16 @@ def logical_or(self, x, y):
and not y.to_numpy().item()
):
return x.copy()
return _binary_op(x, y, opx.or_)
return _variadic_op([x, y], opx.or_)

def logical_xor(self, x, y):
return _binary_op(x, y, opx.xor)
return _variadic_op([x, y], opx.xor)

def multiply(self, x, y):
x, y = promote(x, y)
dtype = ndx.result_type(x, y)
operands = promote(x, y)
if operands is None:
return NotImplemented
dtype = ndx.result_type(*operands)
via_dtype: dtypes.CoreType
if isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)) or dtype in (
dtypes.nbool,
Expand All @@ -272,7 +292,7 @@ def multiply(self, x, y):
via_dtype = dtypes.float64
else:
raise TypeError(f"Unsupported dtype for multiply: {dtype}")
return _binary_op(x, y, opx.mul, via_dtype)
return _variadic_op(operands, opx.mul, via_dtype)

def negative(self, x):
if isinstance(
Expand All @@ -290,12 +310,12 @@ def pow(self, x, y):
x, y = ndx.asarray(x), ndx.asarray(y)
dtype = ndx.result_type(x, y)
if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
return _binary_op(x, y, opx.pow, dtypes.int64)
return _variadic_op([x, y], opx.pow, dtypes.int64)
else:
return _binary_op(x, y, opx.pow)
return _variadic_op([x, y], opx.pow)

def remainder(self, x, y):
return _binary_op(x, y, lambda x, y: opx.mod(x, y, fmod=1))
return _variadic_op([x, y], lambda x, y: opx.mod(x, y, fmod=1))

def round(self, x):
x = ndx.asarray(x)
Expand Down Expand Up @@ -325,14 +345,20 @@ def sqrt(self, x):
return _unary_op(x, opx.sqrt, dtypes.float32)

def subtract(self, x, y):
x, y = promote(x, y)
if isinstance(
x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)
) or x.dtype in (dtypes.int16, dtypes.int8, dtypes.nint16, dtypes.nint8):
operands = promote(x, y)
if operands is None:
return NotImplemented
dtype = operands[0].dtype
if isinstance(dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)) or dtype in (
dtypes.int16,
dtypes.int8,
dtypes.nint16,
dtypes.nint8,
):
via_dtype = dtypes.int64
else:
via_dtype = None
return _binary_op(x, y, opx.sub, via_dtype=via_dtype)
return _variadic_op([x, y], opx.sub, via_dtype=via_dtype)

def tan(self, x):
return _unary_op(x, opx.tan, dtypes.float32)
Expand Down Expand Up @@ -467,7 +493,11 @@ def searchsorted(
def where(self, condition, x, y):
condition, x, y = ndx.asarray(condition), ndx.asarray(x), ndx.asarray(y)
if x.dtype != y.dtype:
x, y = promote(x, y)
operands = promote(x, y)
if operands is None:
return NotImplemented
else:
x, y = operands
if isinstance(condition.dtype, dtypes.Nullable) and not isinstance(
x.dtype, (dtypes.Nullable, dtypes.CoreType)
):
Expand Down Expand Up @@ -760,7 +790,10 @@ def clip(
and max.ndim == 0
and isinstance(x.dtype, dtypes.Numerical)
):
x, min, max = promote(x, min, max)
operands = promote(x, min, max)
if operands is None:
return NotImplemented
x, min, max = operands
if isinstance(x.dtype, dtypes._NullableCore):
out_null = x.null
x_values = x.values._core()
Expand Down Expand Up @@ -892,15 +925,6 @@ def shape(self, x):
return _from_corearray(opx.shape(current))


def _binary_op(
x: ndx.Array,
y: ndx.Array,
op: Callable[[_CoreArray, _CoreArray], _CoreArray],
via_dtype: dtypes.CoreType | None = None,
):
return _variadic_op([x, y], op, via_dtype)


def _unary_op(
x: ndx.Array,
op: Callable[[_CoreArray], _CoreArray],
Expand All @@ -915,11 +939,13 @@ def _variadic_op(
via_dtype: dtypes.CoreType | None = None,
cast_return: bool = True,
) -> ndx.Array:
args = promote(*args)
out_dtype = args[0].dtype
promoted_args = promote(*args)
if promoted_args is None:
return NotImplemented
out_dtype = promoted_args[0].dtype
if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)):
return NotImplemented
data, nulls = _split_nulls_and_values(*args)
data, nulls = _split_nulls_and_values(*promoted_args)
if via_dtype is None:
values = _from_corearray(op(*(x._core() for x in data)))
else:
Expand Down Expand Up @@ -975,6 +1001,9 @@ def _via_dtype(
ndx.Arrays are promoted to a common type prior to their first use.
"""
promoted = promote(*arrays)
if promoted is None:
return NotImplemented

out_dtype = promoted[0].dtype

if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype:
Expand Down Expand Up @@ -1010,6 +1039,9 @@ def _via_i64_f64(
"""
promoted_values = promote(*arrays)

if promoted_values is None:
return NotImplemented

dtype = promoted_values[0].dtype

via_dtype: dtypes.CoreType
Expand Down
Loading

0 comments on commit 1783d84

Please sign in to comment.