Skip to content

Commit

Permalink
Drop pyproject.toml entries, start windows fix
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jun 10, 2024
1 parent c4fa1fd commit a307913
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 33 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ concurrency:

jobs:
array-api-tests:
# Run if the commit message contains 'run array-api tests' or if the job is triggered on schedule
if: >-
contains(github.event.head_commit.message, 'run array-api tests') ||
github.event_name == 'schedule'
name: Array API test
timeout-minutes: 90
runs-on: ubuntu-latest-8core
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ jobs:
fail-fast: false
matrix:
os:
- ubuntu-latest
- macos-latest
# - ubuntu-latest
# - macos-latest
- windows-latest
environment:
- py310
# - py310
- py311
- py312
# - py312
steps:
- name: Checkout branch
uses: actions/checkout@v4
Expand Down
5 changes: 4 additions & 1 deletion ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,10 @@ def unique_all(self, x):

# FIXME: I think we can simply use arange/ones+cumsum or something for the indices
# maybe: indices = opx.cumsum(ones_like(flattened, dtype=dtypes.i64), axis=ndx.asarray(0))
indices = opx.squeeze(opx.ndindex(opx.shape(flattened._core())), opx.const([1]))
indices = opx.squeeze(
opx.ndindex(opx.shape(flattened._core())),
opx.const([1], dtype=dtypes.int64),
)

ret = namedtuple("ret", ["values", "indices", "inverse_indices", "counts"])

Expand Down
11 changes: 9 additions & 2 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,9 @@ def expand_dims(x, axis=0):
if (out := x.dtype._ops.expand_dims(x, axis)) is not NotImplemented:
return out
return x._transmute(
lambda corearray: opx.unsqueeze(corearray, axes=opx.const([axis]))
lambda corearray: opx.unsqueeze(
corearray, axes=opx.const([axis], dtype=dtypes.int64)
)
)


Expand Down Expand Up @@ -694,7 +696,12 @@ def roll(x, shift, axis=None):
shift_single = opx.add(opx.const(-sh), len_single)
# Find the needed element index and then gather from it
range = opx.cast(
opx.range(opx.const(0), len_single, opx.const(1)), to=dtypes.int64
opx.range(
opx.const(0, dtype=len_single.dtype),
len_single,
opx.const(1, dtype=len_single.dtype),
),
to=dtypes.int64,
)
new_indices = opx.mod(opx.add(range, shift_single), len_single)
x = take(x, _from_corearray(new_indices), axis=ax)
Expand Down
31 changes: 22 additions & 9 deletions ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def const(value, dtype: CoreType | None = None) -> _CoreArray:

@eager_propagate
def squeeze(data: _CoreArray, axes: _CoreArray | None = None) -> _CoreArray:
if axes is not None and axes.dtype != dtypes.int64:
raise ValueError(f"Expected axes to be of type int64, got {axes.dtype}")
return _CoreArray(op.squeeze(data.var, axes=axes.var if axes is not None else None))


Expand Down Expand Up @@ -364,6 +366,8 @@ def concat(inputs: list[_CoreArray], axis: int) -> _CoreArray:

@eager_propagate
def unsqueeze(data: _CoreArray, axes: _CoreArray) -> _CoreArray:
if axes.dtype != dtypes.int64:
raise TypeError(f"axes must be int64, got {axes.dtype}")
return _CoreArray(op.unsqueeze(data.var, axes.var))


Expand Down Expand Up @@ -503,7 +507,7 @@ def getitem_null(corearray: _CoreArray, index: _CoreArray) -> _CoreArray:
if get_rank(index) == 0:

def extend_shape(x: Var) -> Var:
return op.concat([op.const([1]), op.shape(x)], axis=0)
return op.concat([op.const([1], dtype=np.int64), op.shape(x)], axis=0)

var = op.reshape(var, extend_shape(var), allowzero=True)
index_var = op.reshape(index.var, extend_shape(index.var), allowzero=True)
Expand All @@ -514,7 +518,8 @@ def extend_shape(x: Var) -> Var:
reshaped_var, reshaped_index = var, index.var
else:
ret_shape = op.concat(
[op.const([-1]), op.shape(var, start=get_rank(index))], axis=0
[op.const([-1], dtype=np.int64), op.shape(var, start=get_rank(index))],
axis=0,
)
reshaped_var = op.reshape(var, ret_shape, allowzero=True)
reshaped_index = op.reshape(index.var, op.const([-1]), allowzero=True)
Expand Down Expand Up @@ -583,7 +588,7 @@ def getitem(
index_filtered = [x for x in index if isinstance(x, (type(None), slice))]
axis_new_axes = [ind for ind, x in enumerate(index_filtered) if x is None]
if len(axis_new_axes) != 0:
var = op.unsqueeze(var, axes=op.const(axis_new_axes))
var = op.unsqueeze(var, axes=op.const(axis_new_axes, dtype=np.int64))

return _CoreArray(var)

Expand Down Expand Up @@ -640,14 +645,19 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA
axes_indices = [axes_permutation.index(i) for i in builtins.range(rank)]

shape_var = shape.var
dtype = shape_var.unwrap_tensor().dtype
ranges = [
(
op.range(op.const(0), op.gather(shape_var, op.const(i)), op.const(1))
op.range(
op.const(0, dtype=dtype),
op.gather(shape_var, op.const(i)),
op.const(1, dtype=dtype),
)
if i not in to_reverse
else op.range(
op.sub(op.gather(shape_var, op.const(i)), op.const(1)),
op.const(-1),
op.const(-1),
op.sub(op.gather(shape_var, op.const(i)), op.const(1, dtype=dtype)),
op.const(-1, dtype=dtype),
op.const(-1, dtype=dtype),
)
)
for i in builtins.range(rank)
Expand All @@ -657,7 +667,8 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA
op.unsqueeze(
r,
op.const(
[j for j in builtins.range(rank) if axes_indices[i] != j], dtype=np.int_
[j for j in builtins.range(rank) if axes_indices[i] != j],
dtype=np.int64,
),
)
for i, r in enumerate(ranges)
Expand All @@ -669,7 +680,7 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA
expanded_ranges = [op.expand(r, shape_var) for r in fit_ranges]

ret = op.concat(
[op.unsqueeze(r, op.const([-1])) for r in expanded_ranges],
[op.unsqueeze(r, op.const([-1], dtype=np.int64)) for r in expanded_ranges],
axis=-1,
)

Expand All @@ -692,6 +703,8 @@ def static_map(
input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None
) -> _CoreArray:
keys = np.array(tuple(mapping.keys()))
if keys.dtype == np.int32:
keys = keys.astype(np.int64)
values = np.array(tuple(mapping.values()))
value_dtype = values.dtype
if default is None:
Expand Down
9 changes: 0 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,3 @@ exclude = ["docs/"]

[tool.typos.default]
extend-ignore-identifiers-re = ["scatter_nd", "arange"]

[tool.pixi.project]
channels = ["conda-forge"]
platforms = ["osx-arm64"]

[tool.pixi.pypi-dependencies]
ndonnx = { path = ".", editable = true }

[tool.pixi.tasks]
4 changes: 2 additions & 2 deletions tests/ndonnx/test_constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def dynamic_masking_model(mode: Literal["lazy", "constant"]):

def constant_indexing_model(mode: Literal["lazy", "constant"]):
if mode == "constant":
a = ndx.asarray([0, 1, 2, 3])
a = ndx.asarray([0, 1, 2, 3], dtype=ndx.int64)
else:
a = ndx.array(
shape=("N",),
dtype=ndx.int64,
)
b = ndx.asarray([5, 7, 8, 8, 9, 9, 234])
b = ndx.asarray([5, 7, 8, 8, 9, 9, 234], dtype=ndx.int64)
idx = ndx.asarray([1, 3, 5, 0])
result = a * b[idx]
return ndx.build({"a": a} if mode == "lazy" else {}, {"y": result})
Expand Down
14 changes: 8 additions & 6 deletions tests/ndonnx/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def test_null_promotion():


def test_asarray():
a = ndx.asarray([1, 2, 3])
a = ndx.asarray([1, 2, 3], dtype=ndx.int64)
assert a.dtype == ndx.int64
np.testing.assert_array_equal(np.array([1, 2, 3]), a.to_numpy(), strict=True)
np.testing.assert_array_equal(
np.array([1, 2, 3], np.int64), a.to_numpy(), strict=True
)


def test_asarray_masked():
Expand Down Expand Up @@ -366,7 +368,7 @@ def test_matrix_transpose():
model = ndx.build({"a": a}, {"b": b})
np.testing.assert_equal(
npx.matrix_transpose(npx.reshape(npx.arange(3 * 2 * 3), (3, 2, 3))),
run(model, {"a": np.arange(3 * 2 * 3).reshape(3, 2, 3)})["b"],
run(model, {"a": np.arange(3 * 2 * 3, dtype=np.int64).reshape(3, 2, 3)})["b"],
)


Expand All @@ -377,7 +379,7 @@ def test_matrix_transpose_attribute():
model = ndx.build({"a": a}, {"b": b})
np.testing.assert_equal(
npx.reshape(npx.arange(3 * 2 * 3), (3, 2, 3)).mT,
run(model, {"a": np.arange(3 * 2 * 3).reshape(3, 2, 3)})["b"],
run(model, {"a": np.arange(3 * 2 * 3, dtype=np.int64).reshape(3, 2, 3)})["b"],
)


Expand All @@ -388,7 +390,7 @@ def test_transpose_attribute():
model = ndx.build({"a": a}, {"b": b})
np.testing.assert_equal(
npx.reshape(npx.arange(3 * 2), (3, 2)).T,
run(model, {"a": np.arange(3 * 2).reshape(3, 2)})["b"],
run(model, {"a": np.arange(3 * 2, dtype=np.int64).reshape(3, 2)})["b"],
)


Expand All @@ -399,7 +401,7 @@ def test_array_spox_interoperability():
model = ndx.build({"a": a}, {"b": b})
expected = npx.reshape(npx.arange(3 * 2), (3, 2)) + 5
input = np.ma.masked_array(
np.arange(3 * 2).reshape(3, 2), mask=np.ones((3, 2), dtype=bool)
np.arange(3 * 2, dtype=np.int64).reshape(3, 2), mask=np.ones((3, 2), dtype=bool)
)
actual = run(model, {"a": input})["b"]
np.testing.assert_equal(expected, actual)
Expand Down

0 comments on commit a307913

Please sign in to comment.