Skip to content

Commit

Permalink
Restructure source files for CompUniqueBounds' implementations (#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck authored Dec 31, 2023
1 parent 67fef0a commit 57e4671
Show file tree
Hide file tree
Showing 19 changed files with 193 additions and 176 deletions.
1 change: 0 additions & 1 deletion ffi/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include <pass/make_reduction.h>
#include <pass/merge_and_hoist_if.h>
#include <pass/move_out_first_or_last_iter.h>
#include <pass/pb_simplify.h>
#include <pass/prop_one_time_use.h>
#include <pass/remove_cyclic_assign.h>
#include <pass/remove_dead_var.h>
Expand Down
143 changes: 0 additions & 143 deletions include/analyze/comp_unique_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
#include <unordered_set>

#include <analyze/comp_transient_bounds.h>
#include <hash.h>
#include <math/bounds.h>
#include <visitor.h>

namespace freetensor {

Expand Down Expand Up @@ -77,146 +74,6 @@ class CompUniqueBounds {
unionBounds(const std::vector<Ref<Bound>> &bounds) = 0;
};

/**
* Compute bounds of each UNIQUE INTEGER (sub)expression
*
* E.g.
*
* ```
* if (x < 2) {
* ... = x;
* }
* ... = x;
* ```
*
* Two UNIQUE expressions `x` have different upper bounds
*
* For each statements in the AST, a corresponding instance of this class should
* be created to deal with all (sub)expressions in the statement, so as to
* distinguish different `x` sites in the example above
*
* This pass is not accurate. Simplifying passes using this analysis may need
* to run for multiple rounds
*/
class CompUniqueBoundsCombination : public CompUniqueBounds, public Visitor {
typedef Visitor BaseClass;

typedef std::vector<LowerBound> LowerBoundsList;
typedef std::vector<UpperBound> UpperBoundsList;
typedef ASTHashMap<Expr, LowerBoundsList> LowerBoundsMap;
typedef ASTHashMap<Expr, UpperBoundsList> UpperBoundsMap;

LowerBoundsMap lower_;
UpperBoundsMap upper_;

public:
class Bound : public CompUniqueBounds::Bound {
// retrieving expr from Lower/UpperBound requires to be mutable. fake it
// here.
mutable std::vector<LowerBound> lowerBounds_;
mutable std::vector<UpperBound> upperBounds_;

friend class CompUniqueBoundsCombination;

public:
Bound(std::vector<LowerBound> lowerBounds,
std::vector<UpperBound> upperBounds)
: lowerBounds_(std::move(lowerBounds)),
upperBounds_(std::move(upperBounds)) {}

BoundType type() const override { return BoundType::Combination; }

int64_t lowerInt() const override;
int64_t upperInt() const override;
std::optional<int64_t> getInt() const override;

Expr lowerExpr() const override;
Expr upperExpr() const override;

Ref<CompUniqueBounds::Bound> restrictScope(
const std::unordered_set<std::string> &scope) const override;

Expr simplestExpr(const std::unordered_map<std::string, int>
&orderedScope) const override;
};

CompUniqueBoundsCombination(const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients) {}

Ref<CompUniqueBounds::Bound> getBound(const Expr &op) override;

/**
* Check wheter `lhs` is always less than `rhs`
*
* This is a fast non-recursing function, which check the less-than relation
* literally, without invoking CompUniqueBounds again, but maybe imprecise.
* For precise comparison, please use `getLower` or `getUpper` on
* `makeSub(lhs, rhs)`
*/
bool alwaysLT(const Expr &lhs, const Expr &rhs) override;
bool alwaysLE(const Expr &lhs, const Expr &rhs) override;

std::pair<Expr, Expr> unionBounds(
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds) override;

protected:
LowerBoundsList getLower(const Expr &op) {
(*this)(op);
return lower_.at(op);
}
UpperBoundsList getUpper(const Expr &op) {
(*this)(op);
return upper_.at(op);
}

template <class T> void setLower(const Expr &op, T &&list) {
lower_[op] = std::forward<T>(list);
}
template <class T> void setUpper(const Expr &op, T &&list) {
upper_[op] = std::forward<T>(list);
}

/**
* Insert a new bound to a list of bounds. But if the new bound is a trivial
* deduction of existing bounds in the list, it will not be inserted
*
* @{
*/
void updLower(LowerBoundsList &list, const LowerBound &bound) const;
void updUpper(UpperBoundsList &list, const UpperBound &bound) const;
/** @} */

private:
/**
* When analyzing Add, Sub and Mul, we first convert it to an linear
* expression before analyzing bounds, so `a - a: l <= a <= r` results in `0
* <= a - a <= 0`, instead of `l - r, l - a, a - r, 0 <= a - a <= r - l, r -
* a, a - l, 0`
*/
void visitLinear(const Expr &op);

void insertSignDataTypeInfo(const Expr &op);

protected:
void visitExpr(const Expr &op) override;

void visit(const Var &op) override;
void visit(const Load &op) override;
void visit(const Cast &op) override;
void visit(const Intrinsic &op) override;
void visit(const IntConst &op) override;
void visit(const Add &op) override;
void visit(const Sub &op) override;
void visit(const Mul &op) override;
void visit(const Square &op) override;
void visit(const FloorDiv &op) override;
void visit(const CeilDiv &op) override;
void visit(const Mod &op) override;
void visit(const Min &op) override;
void visit(const Max &op) override;
void visit(const IfExpr &op) override;
};

} // namespace freetensor

#endif // FREE_TENSOR_COMP_UNIQUE_BOUNDS_H
159 changes: 159 additions & 0 deletions include/analyze/comp_unique_bounds_combination.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#ifndef FREE_TENSOR_COMP_UNIQUE_BOUNDS_COMBINATION_H
#define FREE_TENSOR_COMP_UNIQUE_BOUNDS_COMBINATION_H

#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include <analyze/comp_transient_bounds.h>
#include <analyze/comp_unique_bounds.h>
#include <hash.h>
#include <math/bounds.h>
#include <visitor.h>

namespace freetensor {

/**
* Compute bounds of each UNIQUE INTEGER (sub)expression
*
* E.g.
*
* ```
* if (x < 2) {
* ... = x;
* }
* ... = x;
* ```
*
* Two UNIQUE expressions `x` have different upper bounds
*
* For each statements in the AST, a corresponding instance of this class should
* be created to deal with all (sub)expressions in the statement, so as to
* distinguish different `x` sites in the example above
*
* This pass is not accurate. Simplifying passes using this analysis may need
* to run for multiple rounds
*/
class CompUniqueBoundsCombination : public CompUniqueBounds, public Visitor {
typedef Visitor BaseClass;

typedef std::vector<LowerBound> LowerBoundsList;
typedef std::vector<UpperBound> UpperBoundsList;
typedef ASTHashMap<Expr, LowerBoundsList> LowerBoundsMap;
typedef ASTHashMap<Expr, UpperBoundsList> UpperBoundsMap;

LowerBoundsMap lower_;
UpperBoundsMap upper_;

public:
class Bound : public CompUniqueBounds::Bound {
// retrieving expr from Lower/UpperBound requires to be mutable. fake it
// here.
mutable std::vector<LowerBound> lowerBounds_;
mutable std::vector<UpperBound> upperBounds_;

friend class CompUniqueBoundsCombination;

public:
Bound(std::vector<LowerBound> lowerBounds,
std::vector<UpperBound> upperBounds)
: lowerBounds_(std::move(lowerBounds)),
upperBounds_(std::move(upperBounds)) {}

BoundType type() const override { return BoundType::Combination; }

int64_t lowerInt() const override;
int64_t upperInt() const override;
std::optional<int64_t> getInt() const override;

Expr lowerExpr() const override;
Expr upperExpr() const override;

Ref<CompUniqueBounds::Bound> restrictScope(
const std::unordered_set<std::string> &scope) const override;

Expr simplestExpr(const std::unordered_map<std::string, int>
&orderedScope) const override;
};

CompUniqueBoundsCombination(const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients) {}

Ref<CompUniqueBounds::Bound> getBound(const Expr &op) override;

/**
* Check wheter `lhs` is always less than `rhs`
*
* This is a fast non-recursing function, which check the less-than relation
* literally, without invoking CompUniqueBounds again, but maybe imprecise.
* For precise comparison, please use `getLower` or `getUpper` on
* `makeSub(lhs, rhs)`
*/
bool alwaysLT(const Expr &lhs, const Expr &rhs) override;
bool alwaysLE(const Expr &lhs, const Expr &rhs) override;

std::pair<Expr, Expr> unionBounds(
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds) override;

protected:
LowerBoundsList getLower(const Expr &op) {
(*this)(op);
return lower_.at(op);
}
UpperBoundsList getUpper(const Expr &op) {
(*this)(op);
return upper_.at(op);
}

template <class T> void setLower(const Expr &op, T &&list) {
lower_[op] = std::forward<T>(list);
}
template <class T> void setUpper(const Expr &op, T &&list) {
upper_[op] = std::forward<T>(list);
}

/**
* Insert a new bound to a list of bounds. But if the new bound is a trivial
* deduction of existing bounds in the list, it will not be inserted
*
* @{
*/
void updLower(LowerBoundsList &list, const LowerBound &bound) const;
void updUpper(UpperBoundsList &list, const UpperBound &bound) const;
/** @} */

private:
/**
* When analyzing Add, Sub and Mul, we first convert it to an linear
* expression before analyzing bounds, so `a - a: l <= a <= r` results in `0
* <= a - a <= 0`, instead of `l - r, l - a, a - r, 0 <= a - a <= r - l, r -
* a, a - l, 0`
*/
void visitLinear(const Expr &op);

void insertSignDataTypeInfo(const Expr &op);

protected:
void visitExpr(const Expr &op) override;

void visit(const Var &op) override;
void visit(const Load &op) override;
void visit(const Cast &op) override;
void visit(const Intrinsic &op) override;
void visit(const IntConst &op) override;
void visit(const Add &op) override;
void visit(const Sub &op) override;
void visit(const Mul &op) override;
void visit(const Square &op) override;
void visit(const FloorDiv &op) override;
void visit(const CeilDiv &op) override;
void visit(const Mod &op) override;
void visit(const Min &op) override;
void visit(const Max &op) override;
void visit(const IfExpr &op) override;
};

} // namespace freetensor

#endif // FREE_TENSOR_COMP_UNIQUE_BOUNDS_COMBINATION_H
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
#ifndef FREE_TENSOR_PB_SIMPLIFY_H
#define FREE_TENSOR_PB_SIMPLIFY_H
#ifndef FREE_TENSOR_COMP_UNIQUE_BOUNDS_PB_H
#define FREE_TENSOR_COMP_UNIQUE_BOUNDS_PB_H

#include <optional>
#include <unordered_map>
#include <unordered_set>

#include <analyze/comp_unique_bounds.h>
#include <math/gen_pb_expr.h>
#include <math/parse_pb_expr.h>
#include <math/presburger.h>
#include <math/utils.h>
#include <pass/simplify.h>

namespace freetensor {

Expand Down Expand Up @@ -78,18 +75,6 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds) override;
};

class PBSimplify : public SimplifyPass {
public:
PBSimplify()
: SimplifyPass([](const CompTransientBoundsInterface &tr) {
return Ref<CompUniqueBoundsPB>::make(tr);
}) {}
};

Stmt pbSimplify(const Stmt &op);

DEFINE_PASS_FOR_FUNC(pbSimplify)

} // namespace freetensor

#endif // FREE_TENSOR_PB_SIMPLIFY_H
#endif // FREE_TENSOR_COMP_UNIQUE_BOUNDS_PB_H
Loading

0 comments on commit 57e4671

Please sign in to comment.