Skip to content

Commit

Permalink
Merge branch 'run-array-api-tests-on-pr' into f64-failure-array-api
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jul 24, 2024
2 parents 937d2b0 + df1a1ca commit c2b07e5
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 17 deletions.
10 changes: 1 addition & 9 deletions .github/workflows/array-api.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
name: Array API coverage tests
on:
push:
branches:
- "*"
schedule:
- cron: "0 8 * * *"
on: [push]

# Automatically stop old builds on the same branch/PR
concurrency:
Expand All @@ -14,9 +9,6 @@ concurrency:
jobs:
array-api-tests:
# Run if the commit message contains 'run array-api tests' or if the job is triggered on schedule
if: >-
contains(github.event.head_commit.message, 'run array-api tests') ||
github.event_name == 'schedule'
name: Array API test
timeout-minutes: 90
runs-on: ubuntu-latest
Expand Down
6 changes: 1 addition & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
name: CI
on:
push:
branches:
- main
pull_request:
on: [push]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
2 changes: 2 additions & 0 deletions ndonnx/_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def collect_lazy_arguments(obj):
return ndx.asarray(obj_value, obj.dtype)
elif isinstance(obj, (list, tuple)):
return type(obj)(map(collect_lazy_arguments, obj))
elif isinstance(obj, dict):
return {key: collect_lazy_arguments(value) for key, value in obj.items()}
elif isinstance(obj, slice):
return slice(
collect_lazy_arguments(obj.start),
Expand Down
7 changes: 4 additions & 3 deletions skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
array_api_tests/test_has_names.py::test_has_names[linalg-cholesky]
array_api_tests/test_linalg.py
array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_set_functions.py::test_unique_all
array_api_tests/test_set_functions.py::test_unique_counts
array_api_tests/test_set_functions.py::test_unique_inverse
Expand Down Expand Up @@ -104,3 +101,7 @@ array_api_tests/test_has_names.py::test_has_names[array_attribute-size]
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
array_api_tests/test_has_names.py::test_has_names[creation-meshgrid]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide
array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
26 changes: 26 additions & 0 deletions tests/ndonnx/test_constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import numpy as np
import pytest
import spox.opset.ai.onnx.v20 as op

import ndonnx as ndx
from ndonnx._propagation import eager_propagate


def test_add():
Expand Down Expand Up @@ -237,3 +239,27 @@ def test_where_folding(cond, x, y, expected_operators):
model_proto = ndx.build(inputs, {"out": out})
operators_used_const = {node.op_type for node in model_proto.graph.node}
assert operators_used_const == expected_operators


def test_eager_propagation_nested_parameters():
@eager_propagate
def function(
x: ndx.Array, mapping: dict[str, ndx.Array], seq: list[ndx.Array]
) -> tuple[ndx.Array, ndx.Array]:
# do some spox stuff
a = ndx.from_spox_var(op.sigmoid(mapping["a"].astype(ndx.float64).spox_var()))
b = ndx.from_spox_var(
op.regex_full_match(seq[0].spox_var(), pattern="^hello.*")
)
return (a + x) * mapping["b"], b

x, y = function(
ndx.asarray([1, 2, 3, 4]),
{"a": ndx.asarray([1, -10, 120, 40]), "b": 10},
[ndx.asarray(["a", "hello world", "world hello"])],
)
expected_x = np.asarray([17.310586, 20.000454, 40.0, 50.0])
expected_y = np.asarray([False, True, False])

np.testing.assert_allclose(x.to_numpy(), expected_x)
np.testing.assert_array_equal(y.to_numpy(), expected_y)

0 comments on commit c2b07e5

Please sign in to comment.