diff --git a/src/analyze/comp_unique_bounds.cc b/src/analyze/comp_unique_bounds.cc index 3d2c92e96..2a67f299e 100644 --- a/src/analyze/comp_unique_bounds.cc +++ b/src/analyze/comp_unique_bounds.cc @@ -33,9 +33,10 @@ int CompUniqueBounds::Bound::countHeavyOps(const Expr &op) { int CompUniqueBounds::Bound::countScope( const Expr &expr, const std::unordered_map &orderedScope) { - int scope = 0; - for (auto &&use : allUses(expr)) + int scope = -1; // 0 = first level var, -1 = no var + for (auto &&use : allNames(expr)) { scope = std::max(scope, orderedScope.at(use)); + } return scope; } diff --git a/src/analyze/comp_unique_bounds_pb.cc b/src/analyze/comp_unique_bounds_pb.cc index 00c79523f..19f34cb38 100644 --- a/src/analyze/comp_unique_bounds_pb.cc +++ b/src/analyze/comp_unique_bounds_pb.cc @@ -135,9 +135,6 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr( int minScopeLevel = INT_MAX, oldScopeLevel = countScope(reference, orderedScope); for (auto &&[axis, scopeLevel] : axesScopeLevel) { - if (scopeLevel > oldScopeLevel) { - continue; - } auto newRestrictedBound = projectOutParamById(std::move(restrictedBound), axis); if (!newRestrictedBound.isSingleValued()) diff --git a/test/20.pass/test_simplify.py b/test/20.pass/test_simplify.py index 9cb59ba33..355ad1c19 100644 --- a/test/20.pass/test_simplify.py +++ b/test/20.pass/test_simplify.py @@ -25,7 +25,7 @@ def test_const_fold(p): @pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) -def test_partial_fold(p): +def test_partial_fold_1(p): # This is the case that we need a symbolic bound, instead # of using integers only with ft.VarDef("y", (4, 4), "int32", "output", "cpu") as y: @@ -45,6 +45,25 @@ def test_partial_fold(p): assert std.match(ast) +@pytest.mark.parametrize('p', [ft.pb_simplify]) +def test_partial_fold_2(p): + with ft.VarDef("y", (32, 2), "int32", "output", "cpu") as y: + with ft.For("i", 0, 32) as i: + with ft.For("j", 0, 2) as j: + y[i, j] = i - 4 * (1 + ((2 * i + j) // 8)) + 4 + ast = ft.pop_ast(verbose=True) + ast = p(ast) + print(ast) + + with ft.VarDef("y", (32, 2), "int32", "output", "cpu") as y: + with ft.For("i", 0, 32) as i: + with ft.For("j", 0, 2) as j: + y[i, j] = i % 4 + std = ft.pop_ast() + + assert std.match(ast) + + @pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_redundant_if(p): with ft.VarDef("y", (4,), "int32", "output", "cpu") as y: