Skip to content

Commit

Permalink
Implement real presburger-based CompUniqueBounds. (#283)
Browse files Browse the repository at this point in the history
* Abstraction over CompUniqueBounds.

* Prepare for isolating implementation.

* Rename and isolate implementation.

* Cleanup includes.

* Fix isl parser.

* Initial implementation of full PB unique bound.

* Fix for a type inconsistency in simplify.

It fixes 20.pass/test_simplify.py::test_multiple_min_max[pb_simplify].
But why?

* Fix pb_simplify on unreachable code.

* Additional coalescing.

* Fix compile error.

* Fix priority in pb_parser.

* Pass basic tests

* Reset default parenDespitePriority flag to preserve existing tests

* A new way to reconstruct min / max

* Fix inter-PBCtx use of PBSet + Parse more isl's AST nodes

* Multiple fixes

- Fix return value from CompUniqueBounds when there is no non-trivial
  bounds.
- Fix pass/shrink_var for whether using old shapes for bounds.
- Fix incorrect testing program in 20.pass/test_prop_one_time_use.py::test_thread_local_no_prop

* Fix for 20.pass/test_simplify but revert some new tests

* Fix for test/21.autograd/test_output_intermediates.py::test_dynamic_loop_range

* Fix for more tests

* Fix for more tests

* Fix pass/gpu/normalize_threads

* Add 20.pass/test_shrink_for.py::test_presburger_bounds

* Pass some previously disabled tests

---------

Co-authored-by: Shizhi Tang <[email protected]>
  • Loading branch information
Blealtan and roastduck authored Dec 30, 2023
1 parent 5ccfff4 commit 67fef0a
Show file tree
Hide file tree
Showing 40 changed files with 1,590 additions and 624 deletions.
20 changes: 10 additions & 10 deletions grammar/pb_parser.g
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ expr returns [Expr node]
{
$node = makeCeilDiv($expr0.node, $expr1.node);
}
| MIN '(' expr0=expr ',' expr1=expr ')'
{
$node = makeMin($expr0.node, $expr1.node);
}
| MAX '(' expr0=expr ',' expr1=expr ')'
{
$node = makeMax($expr0.node, $expr1.node);
}
| intConst expr
{
$node = makeMul($intConst.node, $expr.node);
Expand Down Expand Up @@ -121,20 +129,12 @@ expr returns [Expr node]
case 2: $node = makeSub($expr0.node, $expr1.node); break;
}
}
| MIN '(' expr0=expr ',' expr1=expr ')'
{
$node = makeMin($expr0.node, $expr1.node);
}
| MAX '(' expr0=expr ',' expr1=expr ')'
{
$node = makeMax($expr0.node, $expr1.node);
}
;

boolExpr returns [Expr node]
: '(' boolExpr ')'
: '(' boolExpr0=boolExpr ')'
{
$node = $boolExpr.node;
$node = $boolExpr0.node;
}
| expr0=expr
{
Expand Down
11 changes: 5 additions & 6 deletions include/analyze/comp_access_bound.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef FREE_TENSOR_COMP_ACCESS_BOUND_H
#define FREE_TENSOR_COMP_ACCESS_BOUND_H

#include <memory>
#include <unordered_map>
#include <unordered_set>

Expand Down Expand Up @@ -47,25 +48,23 @@ class CompAccessBound : public CompTransientBounds<SymbolTable<Visitor>> {
public:
struct Access {
std::vector<Expr> indices_, conds_;
std::vector<std::vector<LowerBound>> lower_;
std::vector<std::vector<UpperBound>> upper_;
std::vector<Ref<CompUniqueBounds::Bound>> bounds_;

Access(CompUniqueBounds &unique, const std::vector<Expr> &indices,
const std::vector<Expr> &conds,
const std::unordered_set<std::string> &names)
: indices_(indices), conds_(conds) {
for (auto &&idx : indices) {
lower_.emplace_back(unique.getDefinedLower(idx, names));
upper_.emplace_back(unique.getDefinedUpper(idx, names));
bounds_.emplace_back(
unique.getBound(idx)->restrictScope(names));
}
}

Access(CompUniqueBounds &unique, const std::vector<Expr> &indices,
const std::vector<Expr> &conds)
: indices_(indices), conds_(conds) {
for (auto &&idx : indices) {
lower_.emplace_back(unique.getLower(idx));
upper_.emplace_back(unique.getUpper(idx));
bounds_.emplace_back(unique.getBound(idx));
}
}
};
Expand Down
11 changes: 11 additions & 0 deletions include/analyze/comp_transient_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class CompTransientBoundsInterface {
public:
virtual TransientBound transient(const Expr &op) const = 0;
virtual const std::vector<Expr> &conds() const = 0;
virtual const Stmt &currentStmt() const = 0;
};

/**
Expand Down Expand Up @@ -53,6 +54,9 @@ class CompTransientBounds : public BaseClass,
// Original bounds
std::vector<Expr> conds_;

// Currently visited statement
Stmt currentStmt_;

public:
TransientBound transient(const Expr &op) const override {
if (transients_.count(op)) {
Expand All @@ -63,6 +67,8 @@ class CompTransientBounds : public BaseClass,

const std::vector<Expr> &conds() const override { return conds_; }

const Stmt &currentStmt() const override { return currentStmt_; };

private:
void applyCond(const Expr &_cond,
const std::unordered_set<std::string> &bodyAllWrites) {
Expand Down Expand Up @@ -240,6 +246,11 @@ class CompTransientBounds : public BaseClass,
op->id(), op->debugBlame());
}
}

typename BaseClass::StmtRetType visitStmt(const Stmt &op) override {
currentStmt_ = op;
return BaseClass::visitStmt(op);
}
};

} // namespace freetensor
Expand Down
144 changes: 113 additions & 31 deletions include/analyze/comp_unique_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,81 @@
#define FREE_TENSOR_COMP_UNIQUE_BOUNDS_H

#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include <analyze/comp_transient_bounds.h>
#include <hash.h>
#include <math/bounds.h>
#include <visitor.h>

namespace freetensor {

class CompUniqueBounds {
public:
enum class BoundType { Combination, Presburger };

class Bound {
public:
virtual ~Bound() {}

virtual BoundType type() const = 0;

/**
* Get an integer bound. In case of no solution, return LLONG_MAX or
* LLONG_MIN
*
* @{
*/
virtual int64_t lowerInt() const = 0;
virtual int64_t upperInt() const = 0;
/** @} */

/**
* If the bounded value is a constant integer, return it
*/
virtual std::optional<int64_t> getInt() const = 0;

/**
* Return an Expr for the bound. In case of no solution, return nullptr
*
* @{
*/
virtual Expr lowerExpr() const = 0;
virtual Expr upperExpr() const = 0;
/** @} */

virtual Ref<Bound>
restrictScope(const std::unordered_set<std::string> &scope) const = 0;

virtual Expr simplestExpr(
const std::unordered_map<std::string, int> &orderedScope) const = 0;
};

protected:
const CompTransientBoundsInterface &transients_;

public:
CompUniqueBounds(const CompTransientBoundsInterface &transients)
: transients_(transients) {}
virtual ~CompUniqueBounds() {}

virtual Ref<Bound> getBound(const Expr &op) = 0;

int64_t getIntLower(const Expr &op) { return getBound(op)->lowerInt(); }
int64_t getIntUpper(const Expr &op) { return getBound(op)->upperInt(); }
std::optional<int64_t> getInt(const Expr &op) {
return getBound(op)->getInt();
}

virtual bool alwaysLT(const Expr &lhs, const Expr &rhs) = 0;
virtual bool alwaysLE(const Expr &lhs, const Expr &rhs) = 0;

virtual std::pair<Expr, Expr>
unionBounds(const std::vector<Ref<Bound>> &bounds) = 0;
};

/**
* Compute bounds of each UNIQUE INTEGER (sub)expression
*
Expand All @@ -31,49 +98,52 @@ namespace freetensor {
* This pass is not accurate. Simplifying passes using this analysis may need
* to run for multiple rounds
*/
class CompUniqueBounds : public Visitor {
class CompUniqueBoundsCombination : public CompUniqueBounds, public Visitor {
typedef Visitor BaseClass;

public:
typedef std::vector<LowerBound> LowerBoundsList;
typedef std::vector<UpperBound> UpperBoundsList;
typedef ASTHashMap<Expr, LowerBoundsList> LowerBoundsMap;
typedef ASTHashMap<Expr, UpperBoundsList> UpperBoundsMap;

private:
const CompTransientBoundsInterface &transients_;

LowerBoundsMap lower_;
UpperBoundsMap upper_;

public:
CompUniqueBounds(const CompTransientBoundsInterface &transients)
: transients_(transients) {}
class Bound : public CompUniqueBounds::Bound {
// retrieving expr from Lower/UpperBound requires to be mutable. fake it
// here.
mutable std::vector<LowerBound> lowerBounds_;
mutable std::vector<UpperBound> upperBounds_;

LowerBoundsList getLower(const Expr &op) {
(*this)(op);
return lower_.at(op);
}
UpperBoundsList getUpper(const Expr &op) {
(*this)(op);
return upper_.at(op);
}
friend class CompUniqueBoundsCombination;

int64_t getIntLower(const Expr &op);
int64_t getIntUpper(const Expr &op);
std::optional<int64_t> getInt(const Expr &op);
public:
Bound(std::vector<LowerBound> lowerBounds,
std::vector<UpperBound> upperBounds)
: lowerBounds_(std::move(lowerBounds)),
upperBounds_(std::move(upperBounds)) {}

/**
* Get all bounds defined by only variables or iterators in `names`
* @{
*/
LowerBoundsList
getDefinedLower(const Expr &op,
const std::unordered_set<std::string> &names);
UpperBoundsList
getDefinedUpper(const Expr &op,
const std::unordered_set<std::string> &names);
/** @} */
BoundType type() const override { return BoundType::Combination; }

int64_t lowerInt() const override;
int64_t upperInt() const override;
std::optional<int64_t> getInt() const override;

Expr lowerExpr() const override;
Expr upperExpr() const override;

Ref<CompUniqueBounds::Bound> restrictScope(
const std::unordered_set<std::string> &scope) const override;

Expr simplestExpr(const std::unordered_map<std::string, int>
&orderedScope) const override;
};

CompUniqueBoundsCombination(const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients) {}

Ref<CompUniqueBounds::Bound> getBound(const Expr &op) override;

/**
* Check wheter `lhs` is always less than `rhs`
Expand All @@ -83,10 +153,22 @@ class CompUniqueBounds : public Visitor {
* For precise comparison, please use `getLower` or `getUpper` on
* `makeSub(lhs, rhs)`
*/
bool alwaysLT(const Expr &lhs, const Expr &rhs);
bool alwaysLE(const Expr &lhs, const Expr &rhs);
bool alwaysLT(const Expr &lhs, const Expr &rhs) override;
bool alwaysLE(const Expr &lhs, const Expr &rhs) override;

std::pair<Expr, Expr> unionBounds(
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds) override;

protected:
LowerBoundsList getLower(const Expr &op) {
(*this)(op);
return lower_.at(op);
}
UpperBoundsList getUpper(const Expr &op) {
(*this)(op);
return upper_.at(op);
}

template <class T> void setLower(const Expr &op, T &&list) {
lower_[op] = std::forward<T>(list);
}
Expand Down
1 change: 1 addition & 0 deletions include/analyze/structural_feature.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef FREE_TENSOR_STRUCTURAL_FEATURE_H
#define FREE_TENSOR_STRUCTURAL_FEATURE_H

#include <memory>
#include <unordered_map>
#include <unordered_set>

Expand Down
11 changes: 11 additions & 0 deletions include/math/parse_pb_expr.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#ifndef FREE_TENSOR_PARSE_PB_EXPR_H
#define FREE_TENSOR_PARSE_PB_EXPR_H

#include <iostream>

#include <expr.h>
#include <math/presburger.h>

namespace freetensor {

Expand All @@ -14,6 +17,8 @@ struct SimplePBFuncAST {
Expr cond_; // Maybe null
};

std::ostream &operator<<(std::ostream &os, const SimplePBFuncAST &ast);

/**
* A PBFunc parsed as ASTs
*/
Expand All @@ -29,6 +34,12 @@ PBFuncAST parsePBFunc(const std::string &str);
*/
SimplePBFuncAST parseSimplePBFunc(const std::string &str);

/**
* Construct AST from PBSet while preserving min and max with a special hack to
* ISL
*/
PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set);

} // namespace freetensor

#endif // FREE_TENSOR_PARSE_PB_EXPR_H
Loading

0 comments on commit 67fef0a

Please sign in to comment.