Skip to content

Commit

Permalink
Fix scope counting in pass/simplify (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck authored Jan 21, 2024
1 parent 8af79b8 commit 41346f5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/analyze/comp_unique_bounds.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ int CompUniqueBounds::Bound::countHeavyOps(const Expr &op) {
int CompUniqueBounds::Bound::countScope(
const Expr &expr,
const std::unordered_map<std::string, int> &orderedScope) {
int scope = 0;
for (auto &&use : allUses(expr))
int scope = -1; // 0 = first level var, -1 = no var
for (auto &&use : allNames(expr)) {
scope = std::max(scope, orderedScope.at(use));
}
return scope;
}

Expand Down
3 changes: 0 additions & 3 deletions src/analyze/comp_unique_bounds_pb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,6 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr(
int minScopeLevel = INT_MAX,
oldScopeLevel = countScope(reference, orderedScope);
for (auto &&[axis, scopeLevel] : axesScopeLevel) {
if (scopeLevel > oldScopeLevel) {
continue;
}
auto newRestrictedBound =
projectOutParamById(std::move(restrictedBound), axis);
if (!newRestrictedBound.isSingleValued())
Expand Down
21 changes: 20 additions & 1 deletion test/20.pass/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_const_fold(p):


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify])
def test_partial_fold(p):
def test_partial_fold_1(p):
# This is the case that we need a symbolic bound, instead
# of using integers only
with ft.VarDef("y", (4, 4), "int32", "output", "cpu") as y:
Expand All @@ -45,6 +45,25 @@ def test_partial_fold(p):
assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify])
def test_partial_fold_2(p):
with ft.VarDef("y", (32, 2), "int32", "output", "cpu") as y:
with ft.For("i", 0, 32) as i:
with ft.For("j", 0, 2) as j:
y[i, j] = i - 4 * (1 + ((2 * i + j) // 8)) + 4
ast = ft.pop_ast(verbose=True)
ast = p(ast)
print(ast)

with ft.VarDef("y", (32, 2), "int32", "output", "cpu") as y:
with ft.For("i", 0, 32) as i:
with ft.For("j", 0, 2) as j:
y[i, j] = i % 4
std = ft.pop_ast()

assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify])
def test_redundant_if(p):
with ft.VarDef("y", (4,), "int32", "output", "cpu") as y:
Expand Down

0 comments on commit 41346f5

Please sign in to comment.