Skip to content

Commit

Permalink
More rules to simplify IfExpr (#595)
Browse files Browse the repository at this point in the history
* More rules to simplify IfExpr

* Fix broken tests
  • Loading branch information
roastduck authored Jan 21, 2024
1 parent 41346f5 commit 2d404a5
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
43 changes: 43 additions & 0 deletions src/pass/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <analyze/analyze_linear.h>
#include <analyze/as_dnf.h>
#include <except.h>
#include <math/bounds.h>
#include <math/min_max.h>
#include <math/utils.h>
#include <pass/annotate_conds.h>
Expand Down Expand Up @@ -836,6 +837,48 @@ Expr SimplifyPass::visit(const IfExpr &_op) {
makeIfExpr(op->cond_, lin2expr(thenLin), lin2expr(elseLin)));
}

if (op->thenCase_->nodeType() == ASTNodeType::IntConst &&
op->elseCase_->nodeType() == ASTNodeType::IntConst) {
if (auto lc = linearComp(op->cond_);
lc.has_value() && lc->first.coeff_.size() == 1) {
auto &&x = lc->first.coeff_[0].a_;
auto xl = unique_->getIntLower(x);
auto xu = unique_->getIntUpper(x);
// TODO: Use saturation arithmetic when C++26 is available to be
// safer
if (xl > LLONG_MIN && xu < LLONG_MAX) {
auto &&[cl, cu] = lin2bounds(lc->first, lc->second, x);
if (cu.has_value() && !cl.has_value()) {
ASSERT(cu->lin().coeff_.empty());
if (xu - xl < 2 * (cu->lin().bias_ - xl + 1)) {
// x <= cu ? then : else === then + floor((x - xl) / (cu
// - xl + 1)) * (else - then)
return makeAdd(
op->thenCase_,
makeMul(
makeFloorDiv(
makeSub(x, makeIntConst(xl)),
makeAdd(cu->expr(), makeIntConst(-xl + 1))),
makeSub(op->elseCase_, op->thenCase_)));
}
}
if (cl.has_value() && !cu.has_value()) {
ASSERT(cl->lin().coeff_.empty());
if (xu - xl < 2 * cl->lin().bias_) {
// x >= cl ? then : else === else + floor((x - xl) / (cl
// - xl)) * (then - else)
return makeAdd(
op->elseCase_,
makeMul(makeFloorDiv(
makeSub(x, makeIntConst(xl)),
makeSub(cl->expr(), makeIntConst(xl))),
makeSub(op->thenCase_, op->elseCase_)));
}
}
}
}
}

return op;
}

Expand Down
34 changes: 34 additions & 0 deletions test/20.pass/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,40 @@ def test_sink_if_expr_into_linear_expression(p):
assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify])
def test_convert_if_expr_to_floor_1(p):
with ft.VarDef("y", (4,), "int32", "output", "cpu") as y:
with ft.For("i", 0, 4) as i:
y[i] = ft.if_then_else(i < 2, 1, 5)
ast = ft.pop_ast(verbose=True)
ast = p(ast)
print(ast)

with ft.VarDef("y", (4,), "int32", "output", "cpu") as y:
with ft.For("i", 0, 4) as i:
y[i] = 1 + (i // 2) * 4
std = ft.pop_ast()

assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify])
def test_convert_if_expr_to_floor_2(p):
with ft.VarDef("y", (4,), "int32", "output", "cpu") as y:
with ft.For("i", 0, 4) as i:
y[i] = ft.if_then_else(i >= 3, 5, 1)
ast = ft.pop_ast(verbose=True)
ast = p(ast)
print(ast)

with ft.VarDef("y", (4,), "int32", "output", "cpu") as y:
with ft.For("i", 0, 4) as i:
y[i] = 1 + (i // 3) * 4
std = ft.pop_ast()

assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify])
def test_accessible_after_writing_if(p):
with ft.VarDef([("x", (4,), "int32", "inout", "cpu"),
Expand Down
1 change: 0 additions & 1 deletion test/30.schedule/test_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ def test_if_expr():
y[i * 8 + j + ft.if_then_else(i <= 1, 16, -16)] = i + j
ast = ft.pop_ast(verbose=True)
ast = ft.schedule(ast, lambda s: s.reorder(["L2", "L1"]), verbose=1)
ast = ft.lower(ast, verbose=1)

with ft.VarDef("y", (32,), "int32", "output", "cpu") as y:
with ft.For("j", 0, 8, label="L2") as j:
Expand Down

0 comments on commit 2d404a5

Please sign in to comment.