diff --git a/.github/workflows/array-api.yml b/.github/workflows/array-api.yml index 1dbce37..852c84b 100644 --- a/.github/workflows/array-api.yml +++ b/.github/workflows/array-api.yml @@ -1,10 +1,5 @@ name: Array API coverage tests -on: - push: - branches: - - "*" - schedule: - - cron: "0 8 * * *" +on: [push] # Automatically stop old builds on the same branch/PR concurrency: @@ -14,9 +9,6 @@ concurrency: jobs: array-api-tests: # Run if the commit message contains 'run array-api tests' or if the job is triggered on schedule - if: >- - contains(github.event.head_commit.message, 'run array-api tests') || - github.event_name == 'schedule' name: Array API test timeout-minutes: 90 runs-on: ubuntu-latest diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 971c461..883b492 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,9 +1,5 @@ name: CI -on: - push: - branches: - - main - pull_request: +on: [push] concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/ndonnx/_propagation.py b/ndonnx/_propagation.py index b1fdce1..9efeb12 100644 --- a/ndonnx/_propagation.py +++ b/ndonnx/_propagation.py @@ -133,6 +133,8 @@ def collect_lazy_arguments(obj): return ndx.asarray(obj_value, obj.dtype) elif isinstance(obj, (list, tuple)): return type(obj)(map(collect_lazy_arguments, obj)) + elif isinstance(obj, dict): + return {key: collect_lazy_arguments(value) for key, value in obj.items()} elif isinstance(obj, slice): return slice( collect_lazy_arguments(obj.start), diff --git a/skips.txt b/skips.txt index 03f7287..81c8dbf 100644 --- a/skips.txt +++ b/skips.txt @@ -31,9 +31,6 @@ array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot] array_api_tests/test_has_names.py::test_has_names[linalg-cross] array_api_tests/test_has_names.py::test_has_names[linalg-cholesky] array_api_tests/test_linalg.py -array_api_tests/test_operators_and_elementwise_functions.py::test_atan2 -array_api_tests/test_operators_and_elementwise_functions.py::test_sign -array_api_tests/test_operators_and_elementwise_functions.py::test_sinh array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_inverse @@ -104,3 +101,7 @@ array_api_tests/test_has_names.py::test_has_names[array_attribute-size] array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_has_names.py::test_has_names[creation-meshgrid] array_api_tests/test_operators_and_elementwise_functions.py::test_pow +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide +array_api_tests/test_operators_and_elementwise_functions.py::test_atan2 +array_api_tests/test_operators_and_elementwise_functions.py::test_sign +array_api_tests/test_operators_and_elementwise_functions.py::test_sinh diff --git a/tests/ndonnx/test_constant_propagation.py b/tests/ndonnx/test_constant_propagation.py index 8c2ad69..0a92591 100644 --- a/tests/ndonnx/test_constant_propagation.py +++ b/tests/ndonnx/test_constant_propagation.py @@ -5,8 +5,10 @@ import numpy as np import pytest +import spox.opset.ai.onnx.v20 as op import ndonnx as ndx +from ndonnx._propagation import eager_propagate def test_add(): @@ -237,3 +239,27 @@ def test_where_folding(cond, x, y, expected_operators): model_proto = ndx.build(inputs, {"out": out}) operators_used_const = {node.op_type for node in model_proto.graph.node} assert operators_used_const == expected_operators + + +def test_eager_propagation_nested_parameters(): + @eager_propagate + def function( + x: ndx.Array, mapping: dict[str, ndx.Array], seq: list[ndx.Array] + ) -> tuple[ndx.Array, ndx.Array]: + # do some spox stuff + a = ndx.from_spox_var(op.sigmoid(mapping["a"].astype(ndx.float64).spox_var())) + b = ndx.from_spox_var( + op.regex_full_match(seq[0].spox_var(), pattern="^hello.*") + ) + return (a + x) * mapping["b"], b + + x, y = function( + ndx.asarray([1, 2, 3, 4]), + {"a": ndx.asarray([1, -10, 120, 40]), "b": 10}, + [ndx.asarray(["a", "hello world", "world hello"])], + ) + expected_x = np.asarray([17.310586, 20.000454, 40.0, 50.0]) + expected_y = np.asarray([False, True, False]) + + np.testing.assert_allclose(x.to_numpy(), expected_x) + np.testing.assert_array_equal(y.to_numpy(), expected_y)