Skip to content

Commit

Permalink
Merge branch 'main' into 20-ndxprod-fails-for-float64
Browse files Browse the repository at this point in the history
  • Loading branch information
MatejUrbanQC authored Jul 24, 2024
2 parents 937d2b0 + 4583e13 commit 7d414f2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array) -> ndx.Array:
a.astype(dtypes.int32)._core(),
b.astype(dtypes.int32)._core(),
)
)
).astype(a.dtype)
elif a.dtype in (dtypes.uint16, dtypes.uint32, dtypes.uint64):
return _from_corearray(
opx.where(
Expand Down
2 changes: 2 additions & 0 deletions ndonnx/_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def collect_lazy_arguments(obj):
return ndx.asarray(obj_value, obj.dtype)
elif isinstance(obj, (list, tuple)):
return type(obj)(map(collect_lazy_arguments, obj))
elif isinstance(obj, dict):
return {key: collect_lazy_arguments(value) for key, value in obj.items()}
elif isinstance(obj, slice):
return slice(
collect_lazy_arguments(obj.start),
Expand Down
26 changes: 26 additions & 0 deletions tests/ndonnx/test_constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import numpy as np
import pytest
import spox.opset.ai.onnx.v20 as op

import ndonnx as ndx
from ndonnx._propagation import eager_propagate


def test_add():
Expand Down Expand Up @@ -237,3 +239,27 @@ def test_where_folding(cond, x, y, expected_operators):
model_proto = ndx.build(inputs, {"out": out})
operators_used_const = {node.op_type for node in model_proto.graph.node}
assert operators_used_const == expected_operators


def test_eager_propagation_nested_parameters():
@eager_propagate
def function(
x: ndx.Array, mapping: dict[str, ndx.Array], seq: list[ndx.Array]
) -> tuple[ndx.Array, ndx.Array]:
# do some spox stuff
a = ndx.from_spox_var(op.sigmoid(mapping["a"].astype(ndx.float64).spox_var()))
b = ndx.from_spox_var(
op.regex_full_match(seq[0].spox_var(), pattern="^hello.*")
)
return (a + x) * mapping["b"], b

x, y = function(
ndx.asarray([1, 2, 3, 4]),
{"a": ndx.asarray([1, -10, 120, 40]), "b": 10},
[ndx.asarray(["a", "hello world", "world hello"])],
)
expected_x = np.asarray([17.310586, 20.000454, 40.0, 50.0])
expected_y = np.asarray([False, True, False])

np.testing.assert_allclose(x.to_numpy(), expected_x)
np.testing.assert_array_equal(y.to_numpy(), expected_y)

0 comments on commit 7d414f2

Please sign in to comment.