Skip to content

Commit

Permalink
Refactor where into CoreOperationsImpl and NullableOperationsImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 29, 2024
1 parent 3fcec54 commit 7b7fa90
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 24 deletions.
8 changes: 0 additions & 8 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,6 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)


class BooleanOperationsImpl(
CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
Expand Down
8 changes: 8 additions & 0 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ def make_nullable(self, x: Array, null: Array) -> Array:
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)
8 changes: 8 additions & 0 deletions ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ def fill_null(self, x: Array, value) -> Array:
@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
return NotImplemented

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)
8 changes: 0 additions & 8 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,6 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array:
def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)


class NumericOperationsImpl(
CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
Expand Down
8 changes: 0 additions & 8 deletions ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,6 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array:
def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.zeros_like(x, dtype=dtype, device=device)

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)


class StringOperationsImpl(
CoreOperationsImpl, _StringOperationsImpl, UniformShapeOperations
Expand Down

0 comments on commit 7b7fa90

Please sign in to comment.