From 67fef0a45a13ef61169016c86880afea78a96ac8 Mon Sep 17 00:00:00 2001 From: Blealtan Date: Sat, 30 Dec 2023 19:44:53 +0800 Subject: [PATCH] Implement real presburger-based CompUniqueBounds. (#283) * Abstraction over CompUniqueBounds. * Prepare for isolating implementation. * Rename and isolate implementation. * Cleanup includes. * Fix isl parser. * Initial implementation of full PB unique bound. * Fix for a type inconsistency in simplify. It fixes 20.pass/test_simplify.py::test_multiple_min_max[pb_simplify]. But why? * Fix pb_simplify on unreachable code. * Additional coalescing. * Fix compile error. * Fix priority in pb_parser. * Pass basic tests * Reset default parenDespitePriority flag to preserve existing tests * A new way to reconstruct min / max * Fix inter-PBCtx use of PBSet + Parse more isl's AST nodes * Multiple fixes - Fix return value from CompUniqueBounds when there is no non-trivial bounds. - Fix pass/shrink_var for whether using old shapes for bounds. - Fix incorrect testing program in 20.pass/test_prop_one_time_use.py::test_thread_local_no_prop * Fix for 20.pass/test_simplify but revert some new tests * Fix for test/21.autograd/test_output_intermediates.py::test_dynamic_loop_range * Fix for more tests * Fix for more tests * Fix pass/gpu/normalize_threads * Add 20.pass/test_shrink_for.py::test_presburger_bounds * Pass some previously disabled tests --------- Co-authored-by: Shizhi Tang --- grammar/pb_parser.g | 20 +- include/analyze/comp_access_bound.h | 11 +- include/analyze/comp_transient_bounds.h | 11 + include/analyze/comp_unique_bounds.h | 144 +++++++-- include/analyze/structural_feature.h | 1 + include/math/parse_pb_expr.h | 11 + include/math/presburger.h | 210 ++++++++++++- include/pass/gpu/normalize_threads.h | 18 ++ include/pass/gpu/normalize_var_in_kernel.h | 3 +- include/pass/make_parallel_reduction.h | 6 +- include/pass/pb_simplify.h | 66 ++++- include/pass/shrink_for.h | 13 +- include/pass/shrink_var.h | 34 ++- include/pass/simplify.h | 2 +- include/pass/use_builtin_div.h | 1 + include/serialize/print_ast.h | 19 +- src/analyze/comp_access_bound.cc | 66 +---- src/analyze/comp_unique_bounds.cc | 257 +++++++++++----- src/analyze/structural_feature.cc | 48 +-- src/autograd/analyze_version.cc | 11 +- src/math/parse_pb_expr.cc | 224 ++++++++++++++ src/pass/gpu/lower_parallel_reduction.cc | 40 ++- src/pass/gpu/make_sync.cc | 2 +- src/pass/gpu/normalize_thread_dims.cc | 33 +-- src/pass/gpu/normalize_threads.cc | 50 +++- src/pass/gpu/normalize_var_in_kernel.cc | 18 +- src/pass/make_parallel_reduction.cc | 88 +++--- src/pass/pb_simplify.cc | 278 +++++++++++++++--- src/pass/shrink_for.cc | 35 +-- src/pass/shrink_var.cc | 55 ++-- src/pass/simplify.cc | 147 +++++---- src/pass/use_builtin_div.cc | 2 +- src/pass/z3_simplify.cc | 8 +- src/serialize/print_ast.cc | 9 +- test/20.pass/test_prop_one_time_use.py | 6 +- test/20.pass/test_shrink_for.py | 24 ++ test/20.pass/test_shrink_var.py | 129 ++++---- test/20.pass/test_simplify.py | 88 +++--- test/21.autograd/test_output_intermediates.py | 2 +- test/30.schedule/test_cache.py | 24 +- 40 files changed, 1590 insertions(+), 624 deletions(-) diff --git a/grammar/pb_parser.g b/grammar/pb_parser.g index 72dbf94a6..6c473a094 100644 --- a/grammar/pb_parser.g +++ b/grammar/pb_parser.g @@ -87,6 +87,14 @@ expr returns [Expr node] { $node = makeCeilDiv($expr0.node, $expr1.node); } + | MIN '(' expr0=expr ',' expr1=expr ')' + { + $node = makeMin($expr0.node, $expr1.node); + } + | MAX '(' expr0=expr ',' expr1=expr ')' + { + $node = 
makeMax($expr0.node, $expr1.node); + } | intConst expr { $node = makeMul($intConst.node, $expr.node); @@ -121,20 +129,12 @@ expr returns [Expr node] case 2: $node = makeSub($expr0.node, $expr1.node); break; } } - | MIN '(' expr0=expr ',' expr1=expr ')' - { - $node = makeMin($expr0.node, $expr1.node); - } - | MAX '(' expr0=expr ',' expr1=expr ')' - { - $node = makeMax($expr0.node, $expr1.node); - } ; boolExpr returns [Expr node] - : '(' boolExpr ')' + : '(' boolExpr0=boolExpr ')' { - $node = $boolExpr.node; + $node = $boolExpr0.node; } | expr0=expr { diff --git a/include/analyze/comp_access_bound.h b/include/analyze/comp_access_bound.h index a044aa90e..322cb47fe 100644 --- a/include/analyze/comp_access_bound.h +++ b/include/analyze/comp_access_bound.h @@ -1,6 +1,7 @@ #ifndef FREE_TENSOR_COMP_ACCESS_BOUND_H #define FREE_TENSOR_COMP_ACCESS_BOUND_H +#include #include #include @@ -47,16 +48,15 @@ class CompAccessBound : public CompTransientBounds> { public: struct Access { std::vector indices_, conds_; - std::vector> lower_; - std::vector> upper_; + std::vector> bounds_; Access(CompUniqueBounds &unique, const std::vector &indices, const std::vector &conds, const std::unordered_set &names) : indices_(indices), conds_(conds) { for (auto &&idx : indices) { - lower_.emplace_back(unique.getDefinedLower(idx, names)); - upper_.emplace_back(unique.getDefinedUpper(idx, names)); + bounds_.emplace_back( + unique.getBound(idx)->restrictScope(names)); } } @@ -64,8 +64,7 @@ class CompAccessBound : public CompTransientBounds> { const std::vector &conds) : indices_(indices), conds_(conds) { for (auto &&idx : indices) { - lower_.emplace_back(unique.getLower(idx)); - upper_.emplace_back(unique.getUpper(idx)); + bounds_.emplace_back(unique.getBound(idx)); } } }; diff --git a/include/analyze/comp_transient_bounds.h b/include/analyze/comp_transient_bounds.h index cf4929e91..dac0b1020 100644 --- a/include/analyze/comp_transient_bounds.h +++ b/include/analyze/comp_transient_bounds.h @@ -23,6 +23,7 @@ class CompTransientBoundsInterface { public: virtual TransientBound transient(const Expr &op) const = 0; virtual const std::vector &conds() const = 0; + virtual const Stmt ¤tStmt() const = 0; }; /** @@ -53,6 +54,9 @@ class CompTransientBounds : public BaseClass, // Original bounds std::vector conds_; + // Currently visited statement + Stmt currentStmt_; + public: TransientBound transient(const Expr &op) const override { if (transients_.count(op)) { @@ -63,6 +67,8 @@ class CompTransientBounds : public BaseClass, const std::vector &conds() const override { return conds_; } + const Stmt ¤tStmt() const override { return currentStmt_; }; + private: void applyCond(const Expr &_cond, const std::unordered_set &bodyAllWrites) { @@ -240,6 +246,11 @@ class CompTransientBounds : public BaseClass, op->id(), op->debugBlame()); } } + + typename BaseClass::StmtRetType visitStmt(const Stmt &op) override { + currentStmt_ = op; + return BaseClass::visitStmt(op); + } }; } // namespace freetensor diff --git a/include/analyze/comp_unique_bounds.h b/include/analyze/comp_unique_bounds.h index 0d3392343..53ac63d59 100644 --- a/include/analyze/comp_unique_bounds.h +++ b/include/analyze/comp_unique_bounds.h @@ -2,14 +2,81 @@ #define FREE_TENSOR_COMP_UNIQUE_BOUNDS_H #include +#include +#include #include #include #include +#include #include namespace freetensor { +class CompUniqueBounds { + public: + enum class BoundType { Combination, Presburger }; + + class Bound { + public: + virtual ~Bound() {} + + virtual BoundType type() const = 0; + + /** + * 
Get an integer bound. In case of no solution, return LLONG_MAX or + * LLONG_MIN + * + * @{ + */ + virtual int64_t lowerInt() const = 0; + virtual int64_t upperInt() const = 0; + /** @} */ + + /** + * If the bounded value is a constant integer, return it + */ + virtual std::optional getInt() const = 0; + + /** + * Return an Expr for the bound. In case of no solution, return nullptr + * + * @{ + */ + virtual Expr lowerExpr() const = 0; + virtual Expr upperExpr() const = 0; + /** @} */ + + virtual Ref + restrictScope(const std::unordered_set &scope) const = 0; + + virtual Expr simplestExpr( + const std::unordered_map &orderedScope) const = 0; + }; + + protected: + const CompTransientBoundsInterface &transients_; + + public: + CompUniqueBounds(const CompTransientBoundsInterface &transients) + : transients_(transients) {} + virtual ~CompUniqueBounds() {} + + virtual Ref getBound(const Expr &op) = 0; + + int64_t getIntLower(const Expr &op) { return getBound(op)->lowerInt(); } + int64_t getIntUpper(const Expr &op) { return getBound(op)->upperInt(); } + std::optional getInt(const Expr &op) { + return getBound(op)->getInt(); + } + + virtual bool alwaysLT(const Expr &lhs, const Expr &rhs) = 0; + virtual bool alwaysLE(const Expr &lhs, const Expr &rhs) = 0; + + virtual std::pair + unionBounds(const std::vector> &bounds) = 0; +}; + /** * Compute bounds of each UNIQUE INTEGER (sub)expression * @@ -31,49 +98,52 @@ namespace freetensor { * This pass is not accurate. Simplifying passes using this analysis may need * to run for multiple rounds */ -class CompUniqueBounds : public Visitor { +class CompUniqueBoundsCombination : public CompUniqueBounds, public Visitor { typedef Visitor BaseClass; - public: typedef std::vector LowerBoundsList; typedef std::vector UpperBoundsList; typedef ASTHashMap LowerBoundsMap; typedef ASTHashMap UpperBoundsMap; - private: - const CompTransientBoundsInterface &transients_; - LowerBoundsMap lower_; UpperBoundsMap upper_; public: - CompUniqueBounds(const CompTransientBoundsInterface &transients) - : transients_(transients) {} + class Bound : public CompUniqueBounds::Bound { + // Retrieving an Expr from LowerBound/UpperBound requires them to be + // mutable; we fake it here.
+ mutable std::vector lowerBounds_; + mutable std::vector upperBounds_; + + friend class CompUniqueBoundsCombination; - LowerBoundsList getLower(const Expr &op) { - (*this)(op); - return lower_.at(op); - } - UpperBoundsList getUpper(const Expr &op) { - (*this)(op); - return upper_.at(op); - } + public: + Bound(std::vector lowerBounds, + std::vector upperBounds) + : lowerBounds_(std::move(lowerBounds)), + upperBounds_(std::move(upperBounds)) {} - int64_t getIntLower(const Expr &op); - int64_t getIntUpper(const Expr &op); - std::optional getInt(const Expr &op); + BoundType type() const override { return BoundType::Combination; } - /** - * Get all bounds defined by only variables or iterators in `names` - * @{ - */ - LowerBoundsList - getDefinedLower(const Expr &op, - const std::unordered_set &names); - UpperBoundsList - getDefinedUpper(const Expr &op, - const std::unordered_set &names); - /** @} */ + int64_t lowerInt() const override; + int64_t upperInt() const override; + std::optional getInt() const override; + + Expr lowerExpr() const override; + Expr upperExpr() const override; + + Ref restrictScope( + const std::unordered_set &scope) const override; + + Expr simplestExpr(const std::unordered_map + &orderedScope) const override; + }; + + CompUniqueBoundsCombination(const CompTransientBoundsInterface &transients) + : CompUniqueBounds(transients) {} + + Ref getBound(const Expr &op) override; /** * Check whether `lhs` is always less than `rhs` @@ -83,10 +153,22 @@ class CompUniqueBounds : public Visitor { * For precise comparison, please use `getLower` or `getUpper` on * `makeSub(lhs, rhs)` */ - bool alwaysLT(const Expr &lhs, const Expr &rhs); - bool alwaysLE(const Expr &lhs, const Expr &rhs); + bool alwaysLT(const Expr &lhs, const Expr &rhs) override; + bool alwaysLE(const Expr &lhs, const Expr &rhs) override; + + std::pair unionBounds( + const std::vector> &bounds) override; protected: + LowerBoundsList getLower(const Expr &op) { + (*this)(op); + return lower_.at(op); + } + UpperBoundsList getUpper(const Expr &op) { + (*this)(op); + return upper_.at(op); + } + template void setLower(const Expr &op, T &&list) { lower_[op] = std::forward(list); } diff --git a/include/analyze/structural_feature.h b/include/analyze/structural_feature.h index 9b447497d..b3d969e92 100644 --- a/include/analyze/structural_feature.h +++ b/include/analyze/structural_feature.h @@ -1,6 +1,7 @@ #ifndef FREE_TENSOR_STRUCTURAL_FEATURE_H #define FREE_TENSOR_STRUCTURAL_FEATURE_H +#include #include #include diff --git a/include/math/parse_pb_expr.h b/include/math/parse_pb_expr.h index 515d1eb95..6a88cf8e6 100644 --- a/include/math/parse_pb_expr.h +++ b/include/math/parse_pb_expr.h @@ -1,7 +1,10 @@ #ifndef FREE_TENSOR_PARSE_PB_EXPR_H #define FREE_TENSOR_PARSE_PB_EXPR_H +#include + #include +#include namespace freetensor { @@ -14,6 +17,8 @@ struct SimplePBFuncAST { Expr cond_; // Maybe null }; +std::ostream &operator<<(std::ostream &os, const SimplePBFuncAST &ast); + /** * A PBFunc parsed as ASTs */ @@ -29,6 +34,12 @@ PBFuncAST parsePBFunc(const std::string &str); */ SimplePBFuncAST parseSimplePBFunc(const std::string &str); +/** + * Construct AST from PBSet while preserving min and max with a special hack on + * ISL + */ +PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set); + } // namespace freetensor #endif // FREE_TENSOR_PARSE_PB_EXPR_H diff --git a/include/math/presburger.h b/include/math/presburger.h index 8adf34f21..b34256792 100644 --- a/include/math/presburger.h +++ b/include/math/presburger.h 
@@ -1,18 +1,22 @@ #ifndef FREE_TENSOR_PRESBURGER_H #define FREE_TENSOR_PRESBURGER_H +#include #include +#include #include #include #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -116,11 +120,21 @@ class PBMap { bool isSingleValued() const { return isl_map_is_single_valued(get()); } bool isBijective() const { return isl_map_is_bijective(get()); } - isl_size nBasic() const { return isl_map_n_basic_map(map_); } + isl_size nBasic() const { return isl_map_n_basic_map(get()); } - isl_size nInDims() const { return isl_map_dim(map_, isl_dim_in); } - isl_size nOutDims() const { return isl_map_dim(map_, isl_dim_out); } - isl_size nParamDims() const { return isl_map_dim(map_, isl_dim_param); } + isl_size nInDims() const { return isl_map_dim(get(), isl_dim_in); } + isl_size nOutDims() const { return isl_map_dim(get(), isl_dim_out); } + isl_size nParamDims() const { return isl_map_dim(get(), isl_dim_param); } + + const char *nameInDim(unsigned i) const { + return isl_map_get_dim_name(get(), isl_dim_in, i); + } + const char *nameOutDim(unsigned i) const { + return isl_map_get_dim_name(get(), isl_dim_out, i); + } + const char *nameParamDim(unsigned i) const { + return isl_map_get_dim_name(get(), isl_dim_param, i); + } friend std::ostream &operator<<(std::ostream &os, const PBMap &map) { return os << isl_map_to_str(map.map_); @@ -163,7 +177,12 @@ class PBVal { isl_val *copy() const { return COPY_ISL_PTR(val_, val); } isl_val *move() { return MOVE_ISL_PTR(val_); } + bool isNaN() const { return isl_val_is_nan(get()); } bool isRat() const { return isl_val_is_rat(get()); } + bool isInt() const { return isl_val_is_int(get()); } + bool isInf() const { return isl_val_is_infty(get()); } + bool isNegInf() const { return isl_val_is_neginfty(get()); } + int numSi() const { return isl_val_get_num_si(get()); } int denSi() const { return isl_val_get_den_si(get()); } @@ -219,9 +238,26 @@ class PBSet { return isl_set_is_empty(get()); } + bool isSingleValued() const { return isl_set_is_singleton(get()); } + isl_size nBasic() const { return isl_set_n_basic_set(set_); } - isl_size nDims() const { return isl_set_dim(set_, isl_dim_set); } + isl_size nDims() const { return isl_set_dim(get(), isl_dim_set); } + isl_size nParamDims() const { return isl_set_dim(get(), isl_dim_param); } + + const char *nameDim(unsigned i) const { + return isl_set_get_dim_name(get(), isl_dim_set, i); + } + const char *nameParamDim(unsigned i) const { + return isl_set_get_dim_name(get(), isl_dim_param, i); + } + + bool hasLowerBound(unsigned i) const { + return isl_set_dim_has_lower_bound(get(), isl_dim_set, i); + } + bool hasUpperBound(unsigned i) const { + return isl_set_dim_has_upper_bound(get(), isl_dim_set, i); + } friend std::ostream &operator<<(std::ostream &os, const PBSet &set) { return os << isl_set_to_str(set.set_); @@ -277,14 +313,82 @@ class PBSpace { } }; +class PBSingleFunc { + isl_pw_aff *func_ = nullptr; + + public: + PBSingleFunc() {} + PBSingleFunc(isl_pw_aff *func) : func_(func) {} + ~PBSingleFunc() { + if (func_ != nullptr) { + isl_pw_aff_free(func_); + } + } + + PBSingleFunc(const PBSingleFunc &other) : func_(other.copy()) {} + PBSingleFunc &operator=(const PBSingleFunc &other) { + if (func_ != nullptr) { + isl_pw_aff_free(func_); + } + func_ = other.copy(); + return *this; + } + + PBSingleFunc(PBSingleFunc &&other) : func_(other.move()) {} + PBSingleFunc &operator=(PBSingleFunc &&other) { + if (func_ != nullptr) { + isl_pw_aff_free(func_); + } + func_ = 
other.move(); + return *this; + } + + bool isValid() const { return func_ != nullptr; } + + isl_pw_aff *get() const { return GET_ISL_PTR(func_); } + isl_pw_aff *copy() const { return COPY_ISL_PTR(func_, pw_aff); } + isl_pw_aff *move() { return MOVE_ISL_PTR(func_); } + + isl_size nInDims() const { return isl_pw_aff_dim(get(), isl_dim_in); } + + std::vector> pieces() const { + std::vector> result; + isl_pw_aff_foreach_piece( + get(), + [](isl_set *set, isl_aff *piece, void *user) { + ((std::vector> *)user) + ->emplace_back(PBSet(set), + PBSingleFunc(isl_pw_aff_from_aff(piece))); + return isl_stat_ok; + }, + &result); + return result; + } + + friend std::ostream &operator<<(std::ostream &os, + const PBSingleFunc &func) { + return os << isl_pw_aff_to_str(func.func_); + } +}; + class PBFunc { isl_pw_multi_aff *func_ = nullptr; public: PBFunc() {} PBFunc(isl_pw_multi_aff *func) : func_(func) {} + + PBFunc(const PBSingleFunc &singleFunc) + : func_(isl_pw_multi_aff_from_pw_aff(singleFunc.copy())) {} + PBFunc(PBSingleFunc &&singleFunc) + : func_(isl_pw_multi_aff_from_pw_aff(singleFunc.move())) {} + PBFunc(const PBMap &map) : func_(isl_pw_multi_aff_from_map(map.copy())) {} PBFunc(PBMap &&map) : func_(isl_pw_multi_aff_from_map(map.move())) {} + + PBFunc(const PBSet &set) : func_(isl_pw_multi_aff_from_set(set.copy())) {} + PBFunc(PBSet &&set) : func_(isl_pw_multi_aff_from_set(set.move())) {} + ~PBFunc() { if (func_ != nullptr) { isl_pw_multi_aff_free(func_); } } @@ -315,6 +419,30 @@ class PBFunc { isl_pw_multi_aff *copy() const { return COPY_ISL_PTR(func_, pw_multi_aff); } isl_pw_multi_aff *move() { return MOVE_ISL_PTR(func_); } + isl_size nInDims() const { return isl_pw_multi_aff_dim(get(), isl_dim_in); } + isl_size nOutDims() const { + return isl_pw_multi_aff_dim(get(), isl_dim_out); + } + + PBSingleFunc operator[](isl_size i) const { + return isl_pw_multi_aff_get_pw_aff(get(), i); + } + + std::vector> pieces() const { + std::vector> result; + isl_pw_multi_aff_foreach_piece( + get(), + [](isl_set *set, isl_multi_aff *piece, void *user) { + ((std::vector> *)user) + ->emplace_back( + PBSet(set), + PBFunc(isl_pw_multi_aff_from_multi_aff(piece))); + return isl_stat_ok; + }, + &result); + return result; + } + friend std::ostream &operator<<(std::ostream &os, const PBFunc &func) { return os << isl_pw_multi_aff_to_str(func.func_); } @@ -381,6 +509,8 @@ template concept PBSpaceRef = std::same_as>; template concept PBFuncRef = std::same_as>; +template +concept PBSingleFuncRef = std::same_as>; template auto PBRefTake(std::remove_reference_t &t) { return t.copy(); @@ -397,10 +527,31 @@ template PBMap projectOutAllParams(T &&map) { return isl_map_project_out_all_params(PBRefTake(map)); } +template +PBSet projectOutParamById(T &&set, const std::string &name) { + isl_ctx *ctx = isl_set_get_ctx(set.get()); + return isl_set_project_out_param_id( + PBRefTake(set), isl_id_alloc(ctx, name.c_str(), nullptr)); +} +template +PBSet projectOutParamDims(T &&set, unsigned first, unsigned n) { + return isl_set_project_out(PBRefTake(set), isl_dim_param, first, n); +} template PBSet projectOutDims(T &&set, unsigned first, unsigned n) { return isl_set_project_out(PBRefTake(set), isl_dim_set, first, n); } + +template +PBMap projectOutParamById(T &&map, const std::string &name) { + isl_ctx *ctx = isl_map_get_ctx(map.get()); + return isl_map_project_out_param_id( + PBRefTake(map), isl_id_alloc(ctx, name.c_str(), nullptr)); +} +template +PBMap projectOutParamDims(T &&map, unsigned first, unsigned n) { + return 
isl_map_project_out(PBRefTake(map), isl_dim_param, first, n); +} template PBMap projectOutInputDims(T &&map, unsigned first, unsigned n) { return isl_map_project_out(PBRefTake(map), isl_dim_in, first, n); } @@ -489,6 +640,17 @@ PBMap moveDimsParamToOutput(T &&map, unsigned first, unsigned n, isl_dim_param, first, n); } +template +PBSet moveDimsSetToParam(T &&set, unsigned first, unsigned n, unsigned target) { + return isl_set_move_dims(PBRefTake(set), isl_dim_param, target, + isl_dim_set, first, n); +} +template +PBSet moveDimsParamToSet(T &&set, unsigned first, unsigned n, unsigned target) { + return isl_set_move_dims(PBRefTake(set), isl_dim_set, target, + isl_dim_param, first, n); +} + template std::pair padToSameDims(T &&lhs, U &&rhs) { auto n = std::max(lhs.nDims(), rhs.nDims()); @@ -548,6 +710,15 @@ template PBMap intersectRange(T &&lhs, U &&rhs) { return isl_map_intersect_range(PBRefTake(lhs), PBRefTake(rhs)); } +template +PBSingleFunc intersectDomain(T &&lhs, U &&rhs) { + return isl_pw_aff_intersect_domain(PBRefTake(lhs), PBRefTake(rhs)); +} +template PBFunc intersectDomain(T &&lhs, U &&rhs) { + return isl_multi_pw_aff_intersect_domain(PBRefTake(lhs), + PBRefTake(rhs)); +} + template PBSet intersectParams(T &&lhs, U &&rhs) { return isl_set_intersect_params(PBRefTake(lhs), PBRefTake(rhs)); } @@ -678,6 +849,13 @@ template PBSet range(T &&map) { return isl_map_range(PBRefTake(map)); } +template PBSet domain(T &&func) { + return isl_pw_aff_domain(PBRefTake(func)); +} +template PBSet domain(T &&func) { + return isl_multi_pw_aff_domain(PBRefTake(func)); +} + template PBSet params(T &&set) { return isl_set_params(PBRefTake(set)); } @@ -703,6 +881,10 @@ template PBVal dimMinVal(T &&set, int pos) { return isl_set_dim_min_val(PBRefTake(set), pos); } +inline PBVal dimFixVal(const PBSet &set, int pos) { + return isl_set_plain_get_val_if_fixed(set.get(), isl_dim_set, pos); +} + template PBSpace spaceMapFromSet(T &&space) { return isl_space_map_from_set(PBRefTake(space)); } @@ -715,6 +897,16 @@ template PBPoint sample(T &&set) { return isl_set_sample_point(PBRefTake(set)); } +template +PBSingleFunc min(T &&lhs, U &&rhs) { + return isl_pw_aff_min(PBRefTake(lhs), PBRefTake(rhs)); } + +template +PBSingleFunc max(T &&lhs, U &&rhs) { + return isl_pw_aff_max(PBRefTake(lhs), PBRefTake(rhs)); +} + /** * @brief Compute the set of coefficients corresponding to the given set * @@ -763,6 +955,14 @@ inline bool operator==(const PBMap &lhs, const PBMap &rhs) { return isl_map_is_equal(lhs.get(), rhs.get()); } +inline bool operator==(const PBSingleFunc &lhs, const PBSingleFunc &rhs) { + return isl_pw_aff_is_equal(lhs.get(), rhs.get()); +} + +inline bool operator==(const PBFunc &lhs, const PBFunc &rhs) { + return isl_pw_multi_aff_is_equal(lhs.get(), rhs.get()); +} + class PBBuildExpr { std::string expr_; explicit PBBuildExpr(const std::string &expr) : expr_(expr) {} diff --git a/include/pass/gpu/normalize_threads.h b/include/pass/gpu/normalize_threads.h index b6824abfe..c8a4a6252 100644 --- a/include/pass/gpu/normalize_threads.h +++ b/include/pass/gpu/normalize_threads.h @@ -7,6 +7,7 @@ #include #include +#include namespace freetensor { @@ -39,6 +40,23 @@ class NormalizeThreads : public Mutator { Stmt visit(const Eval &op) override; }; +class ShrinkNormalizedThreads : public ShrinkFor { + typedef ShrinkFor BaseClass; + + std::unordered_set openLoopsInKernel_; + bool inKernel_ = false; + + protected: + bool filterLoop(const For &op) override; + + std::unordered_set + filterNames(const std::unordered_set &names) 
override; + + protected: + using BaseClass::visit; + Stmt visit(const For &op) override; +}; + /** * Make GPU parallel scopes to be GPU kernel-like scopes * diff --git a/include/pass/gpu/normalize_var_in_kernel.h b/include/pass/gpu/normalize_var_in_kernel.h index 539702ef9..eb38c8b03 100644 --- a/include/pass/gpu/normalize_var_in_kernel.h +++ b/include/pass/gpu/normalize_var_in_kernel.h @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace freetensor { @@ -16,7 +17,7 @@ namespace gpu { class NormalizeVarInKernel : public CompTransientBounds> { typedef CompTransientBounds> BaseClass; - std::vector legalNames_; + std::unordered_set legalNames_; std::vector varsToHoist_; std::unordered_set usedNamesInKernel_; diff --git a/include/pass/make_parallel_reduction.h b/include/pass/make_parallel_reduction.h index eb839f3fb..88ad2ff71 100644 --- a/include/pass/make_parallel_reduction.h +++ b/include/pass/make_parallel_reduction.h @@ -1,10 +1,12 @@ #ifndef FREE_TENSOR_MAKE_PARLLEL_REDUCTION_H #define FREE_TENSOR_MAKE_PARLLEL_REDUCTION_H +#include #include #include #include +#include #include #include #include @@ -62,8 +64,8 @@ class MakeLoopCarriedReduction struct ReductionItemFactors { ReduceOp op_; std::string var_; - std::vector>> lower_, - upper_; // [dim][access][bound] + std::vector>> + bound_; // [dim][access] bool syncFlush_; }; diff --git a/include/pass/pb_simplify.h b/include/pass/pb_simplify.h index d7c2d5b63..ffd1ec2af 100644 --- a/include/pass/pb_simplify.h +++ b/include/pass/pb_simplify.h @@ -1,11 +1,15 @@ #ifndef FREE_TENSOR_PB_SIMPLIFY_H #define FREE_TENSOR_PB_SIMPLIFY_H +#include #include #include +#include #include +#include #include +#include #include namespace freetensor { @@ -16,27 +20,69 @@ namespace freetensor { * For each statement in the AST, a corresponding instance of this class should * be created to deal with all (sub)expressions in the statement */ -class PBCompBounds : public CompUniqueBounds { +class CompUniqueBoundsPB : public CompUniqueBounds { + public: + class Bound : public CompUniqueBounds::Bound { + Ref ctx_; + // isl var -> ft expr: the demangling map yielded from GenPBExpr, + // shared from CompUniqueBoundsPB::cachedFreeVars_ + Ref> demangleMap_; + // the isl bounding set; its params are all the outer variables and its + // single output is the bounded expression + PBSet bound_; + + friend class CompUniqueBoundsPB; + + public: + Bound(Ref ctx, + Ref> demangleMap, + PBSet bound) + : ctx_(std::move(ctx)), demangleMap_(std::move(demangleMap)), + bound_(std::move(bound)) {} + + BoundType type() const override { return BoundType::Presburger; } + + int64_t lowerInt() const override; + int64_t upperInt() const override; + std::optional getInt() const override; + + Expr lowerExpr() const override; + Expr upperExpr() const override; + + Ref restrictScope( + const std::unordered_set &scope) const override; + + Expr simplestExpr(const std::unordered_map + &orderedScope) const override; + }; + + private: const CompTransientBoundsInterface &transients_; GenPBExpr genPBExpr_; - PBCtx isl_; - std::unordered_set visited_; + Ref ctx_; - public: - PBCompBounds(const CompTransientBoundsInterface &transients) - : CompUniqueBounds(transients), transients_(transients) {} + Stmt cachedPlace_; + PBSet cachedConds_; + Ref> cachedFreeVars_; + std::unordered_map> cachedValues_; - protected: - using CompUniqueBounds::visit; + public: + CompUniqueBoundsPB(const CompTransientBoundsInterface &transients) + : CompUniqueBounds(transients), transients_(transients), + 
ctx_(Ref::make()) {} - void visitExpr(const Expr &op) override; + Ref getBound(const Expr &op) override; + bool alwaysLE(const Expr &lhs, const Expr &rhs) override; + bool alwaysLT(const Expr &lhs, const Expr &rhs) override; + std::pair unionBounds( + const std::vector> &bounds) override; }; class PBSimplify : public SimplifyPass { public: PBSimplify() : SimplifyPass([](const CompTransientBoundsInterface &tr) { - return Ref::make(tr); + return Ref::make(tr); }) {} }; diff --git a/include/pass/shrink_for.h b/include/pass/shrink_for.h index 94b9859f7..ac33a1db1 100644 --- a/include/pass/shrink_for.h +++ b/include/pass/shrink_for.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -26,9 +27,7 @@ class CheckSideEffect : public Visitor { class ShrinkFor : public CompTransientBounds> { typedef CompTransientBounds> BaseClass; - ASTHashMap>, - std::vector>>> - newRange_; + ASTHashMap>> newRange_; std::vector iterStack_; std::vector> namesStack_; @@ -39,6 +38,14 @@ class ShrinkFor : public CompTransientBounds> { public: void setSubAST(const Stmt &subAST); + protected: + virtual bool filterLoop(const For &op) { return true; } + + virtual std::unordered_set + filterNames(const std::unordered_set &names) { + return names; + } + protected: using BaseClass::visit; diff --git a/include/pass/shrink_var.h b/include/pass/shrink_var.h index 6e297d6ee..be86216a7 100644 --- a/include/pass/shrink_var.h +++ b/include/pass/shrink_var.h @@ -11,20 +11,34 @@ namespace freetensor { class ShrinkVar : public Mutator { - std::unordered_map> lower_, upper_; - const std::unordered_map &newRange_; + // Bound considering the old shape. Used to prevent making the shape even + // larger after shrinking + const std::unordered_map &newRangeWithShape_; + + // Bound without considering the old shape. 
Used for preventing redundant + // guards for maybe-unsafe user code + const std::unordered_map &newRangeWithoutShape_; + bool guardReads_; + + std::unordered_map> lowerWithShape_, + upperWithShape_; + std::unordered_map> lowerWithoutShape_, + upperWithoutShape_; std::unordered_map guards_; public: - ShrinkVar(const std::unordered_map &newRange, + ShrinkVar(const std::unordered_map &newRangeWithShape, + const std::unordered_map &newRangeWithoutShape, bool guardReads = false) - : newRange_(newRange), guardReads_(guardReads) {} + : newRangeWithShape_(newRangeWithShape), + newRangeWithoutShape_(newRangeWithoutShape), guardReads_(guardReads) { + } private: template T modifyAccess(const T &op) { - if (lower_.count(op->var_)) { - auto &&offset = lower_.at(op->var_); + if (lowerWithoutShape_.count(op->var_)) { + auto &&offset = lowerWithoutShape_.at(op->var_); ASSERT(offset.size() == op->indices_.size()); for (auto &&[idx, off] : views::zip(op->indices_, offset)) { if (off.isValid()) { @@ -39,8 +53,8 @@ class ShrinkVar : public Mutator { // We add the check w.r.t oldOp because it is simpler, which brings less // redundancy to pass/simplify Expr guard; - if (upper_.count(op->var_)) { - auto &&upper = upper_.at(op->var_); + if (upperWithoutShape_.count(op->var_)) { + auto &&upper = upperWithoutShape_.at(op->var_); ASSERT(upper.size() == op->indices_.size()); for (auto &&[idx, u] : views::zip(oldOp->indices_, upper)) { if (u.isValid()) { @@ -49,8 +63,8 @@ class ShrinkVar : public Mutator { } } } - if (lower_.count(op->var_)) { - auto &&lower = lower_.at(op->var_); + if (lowerWithoutShape_.count(op->var_)) { + auto &&lower = lowerWithoutShape_.at(op->var_); ASSERT(lower.size() == op->indices_.size()); for (auto &&[idx, l] : views::zip(oldOp->indices_, lower)) { if (l.isValid()) { diff --git a/include/pass/simplify.h b/include/pass/simplify.h index 1d8f6b320..29b89adce 100644 --- a/include/pass/simplify.h +++ b/include/pass/simplify.h @@ -107,7 +107,7 @@ class BuiltinSimplify : public SimplifyPass { public: BuiltinSimplify() : SimplifyPass([](const CompTransientBoundsInterface &tr) { - return Ref::make(tr); + return Ref::make(tr); }) {} }; diff --git a/include/pass/use_builtin_div.h b/include/pass/use_builtin_div.h index 7bbcf3481..c66252c83 100644 --- a/include/pass/use_builtin_div.h +++ b/include/pass/use_builtin_div.h @@ -5,6 +5,7 @@ #include #include #include +#include #include namespace freetensor { diff --git a/include/serialize/print_ast.h b/include/serialize/print_ast.h index 7c20d3f45..f96f7f173 100644 --- a/include/serialize/print_ast.h +++ b/include/serialize/print_ast.h @@ -12,7 +12,8 @@ namespace freetensor { class PrintVisitor : public CodeGen { bool printAllId_ = false, pretty_ = false, dtypeInLoad_ = false, - hexFloat_ = false, printSourceLocation_ = false; + hexFloat_ = false, parenDespitePriority_ = false, + printSourceLocation_ = false; const std::unordered_set keywords = { "if", "else", "for", "in", "assert", "assume", "func", "true", "false", }; @@ -59,10 +60,12 @@ class PrintVisitor : public CodeGen { bool parentheses = true) { auto old_priority = precedence_; precedence_ = new_priority; - if (parentheses && old_priority > precedence_) + if (parentheses && + (parenDespitePriority_ || old_priority > precedence_)) os() << "("; inner(); - if (parentheses && old_priority > precedence_) + if (parentheses && + (parenDespitePriority_ || old_priority > precedence_)) os() << ")"; precedence_ = old_priority; } @@ -80,9 +83,11 @@ class PrintVisitor : public CodeGen { public: PrintVisitor(bool 
printAllId = false, bool pretty = false, bool dtypeInLoad = false, bool hexFloat = false, - bool compact = false, bool printSourceLocation = false) + bool compact = false, bool parenDespitePriority = false, + bool printSourceLocation = false) : CodeGen(compact), printAllId_(printAllId), pretty_(pretty), dtypeInLoad_(dtypeInLoad), hexFloat_(hexFloat), + parenDespitePriority_(parenDespitePriority), printSourceLocation_(printSourceLocation) { os() << manipNoIdSign(true) << (printSourceLocation ? manipMetadataWithLocation @@ -176,10 +181,10 @@ std::string toString(const AST &op, bool pretty); std::string toString(const AST &op, bool pretty, bool printAllId); std::string toString(const AST &op, bool pretty, bool printAllId, bool dtypeInLoad, bool hexFloat = false, - bool compact = false); + bool compact = false, bool parenDespitePriority = false); std::string toString(const AST &op, bool pretty, bool printAllId, bool dtypeInLoad, bool hexFloat, bool compact, - bool printSourceLocation); + bool parenDespitePriority, bool printSourceLocation); /** @} */ /** @@ -187,7 +192,7 @@ std::string toString(const AST &op, bool pretty, bool printAllId, */ inline std::string dumpAST(const AST &op, bool dtypeInLoad = false, bool hexFloat = true) { - return toString(op, false, true, dtypeInLoad, hexFloat, true, false); + return toString(op, false, true, dtypeInLoad, hexFloat, true, false, false); } /** diff --git a/src/analyze/comp_access_bound.cc b/src/analyze/comp_access_bound.cc index e46d84935..4f8c36846 100644 --- a/src/analyze/comp_access_bound.cc +++ b/src/analyze/comp_access_bound.cc @@ -3,6 +3,7 @@ #include #include #include +#include namespace freetensor { @@ -47,7 +48,7 @@ void FindMemType::visit(const VarDef &op) { void CompAccessBound::visitStmt(const Stmt &stmt) { // CompUniqueBounds requires one instance per Stmt auto uniqueOfOuterStmt = unique_; - unique_ = Ref::make(*this); + unique_ = Ref::make(*this); if (stmt->id() == filterSubTree_) { filtered_ = true; @@ -90,56 +91,21 @@ void CompAccessBound::visit(const VarDef &op) { } for (size_t i = 0; i < n; i++) { - std::vector> lower, upper; - for (size_t j = 0, jEnd = access_.size(); j < jEnd; j++) { - ASSERT(access_[j].indices_.size() == n); - auto &&index = access_[j].indices_[i]; - std::vector lowerItem; - if (checkAllDefined(defs_, index)) { - lowerItem.emplace_back(index); - } - bool insertedNonTrivialBounds = false; - for (auto &&b : access_[j].lower_[i]) { - if (!HashComparator{}(index, b.expr())) { - lowerItem.emplace_back(b.expr()); - insertedNonTrivialBounds = true; - } - } - if (includeTrivialBound_ || insertedNonTrivialBounds) { - // If `insertedNonTrivialBounds`, we still include the trivial - // bound, to avoid make a variable even larger after - // pass/shrink_var - lowerItem.emplace_back(makeIntConst(0)); - } - lower.emplace_back(std::move(lowerItem)); + // union the bounds of all accesses and get the lower and upper + // expression + auto [l, u] = unique_->unionBounds( + // get bounds of the i-th dimension + access_ | views::transform([&](auto &&a) { return a.bounds_[i]; }) | + // ... and pack into vector + ranges::to()); + // include the original trivial bounds, if specified + if (includeTrivialBound_) { + auto &&tl = makeIntConst(0); + auto &&tu = + makeSub(op->buffer_->tensor()->shape()[i], makeIntConst(1)); + l = l.isValid() ? makeMax(l, tl) : tl; + u = u.isValid() ? 
makeMin(u, tu) : tu; } - - for (size_t j = 0, jEnd = access_.size(); j < jEnd; j++) { - ASSERT(access_[j].indices_.size() == n); - auto &&index = access_[j].indices_[i]; - std::vector upperItem; - if (checkAllDefined(defs_, index)) { - upperItem.emplace_back(index); - } - bool insertedNonTrivialBounds = false; - for (auto &&b : access_[j].upper_[i]) { - if (!HashComparator{}(index, b.expr())) { - upperItem.emplace_back(b.expr()); - insertedNonTrivialBounds = true; - } - } - if (includeTrivialBound_ || insertedNonTrivialBounds) { - // If `insertedNonTrivialBounds`, we still include the trivial - // bound, to avoid make a variable even larger after - // pass/shrink_var - upperItem.emplace_back(makeSub( - op->buffer_->tensor()->shape()[i], makeIntConst(1))); - } - upper.emplace_back(std::move(upperItem)); - } - - auto l = makeMinMax(lower); - auto u = makeMaxMin(upper); result_.lower_.emplace_back(l); result_.upper_.emplace_back(u); if (l.isValid() && u.isValid()) { diff --git a/src/analyze/comp_unique_bounds.cc b/src/analyze/comp_unique_bounds.cc index 4cdd91935..5fba460c1 100644 --- a/src/analyze/comp_unique_bounds.cc +++ b/src/analyze/comp_unique_bounds.cc @@ -6,11 +6,132 @@ #include #include #include +#include +#include namespace freetensor { -void CompUniqueBounds::updLower(LowerBoundsList &list, - const LowerBound &bound) const { +namespace { + +class CountHeavyOps : public Visitor { + int cnt_ = 0; + + public: + int cnt() const { return cnt_; } + + protected: + void visitExpr(const Expr &op) { + Visitor::visitExpr(op); + if (!op->isConst() && op->nodeType() != ASTNodeType::Add && + op->nodeType() != ASTNodeType::Sub && + op->nodeType() != ASTNodeType::Mul) { + cnt_++; + } + } +}; + +static int countHeavyOps(const Expr &op) { + CountHeavyOps visitor; + visitor(op); + return visitor.cnt(); +} + +} // namespace + +int64_t CompUniqueBoundsCombination::Bound::lowerInt() const { + int64_t ret = LLONG_MIN; + for (auto &&b : lowerBounds_) { + if (b.lin().isConst()) { + auto bias = b.lin().bias_; + ret = std::max(ret, ceilDiv(bias.p_, bias.q_)); + } + } + return ret; +} +int64_t CompUniqueBoundsCombination::Bound::upperInt() const { + int64_t ret = LLONG_MAX; + for (auto &&b : upperBounds_) { + if (b.lin().isConst()) { + auto bias = b.lin().bias_; + ret = std::min(ret, floorDiv(bias.p_, bias.q_)); + } + } + return ret; +} +std::optional CompUniqueBoundsCombination::Bound::getInt() const { + auto lower = lowerInt(); + auto upper = upperInt(); + return lower == upper ? 
std::make_optional(lower) : std::nullopt; +} + +Expr CompUniqueBoundsCombination::Bound::lowerExpr() const { + Expr result; + for (LowerBound &b : lowerBounds_) { + if (result.isValid()) + result = makeMax(result, b.expr()); + else + result = b.expr(); + } + return result; +} +Expr CompUniqueBoundsCombination::Bound::upperExpr() const { + Expr result; + for (UpperBound &b : upperBounds_) { + if (result.isValid()) + result = makeMin(result, b.expr()); + else + result = b.expr(); + } + return result; +} + +Ref CompUniqueBoundsCombination::Bound::restrictScope( + const std::unordered_set &scope) const { + auto filter = views::filter([&](auto &b) { + return checkAllDefined(scope, b.allNames()); + }) | + ranges::to(); + return Ref::make(filter(lowerBounds_), filter(upperBounds_)); +} + +Expr CompUniqueBoundsCombination::Bound::simplestExpr( + const std::unordered_map &orderedScope) const { + Expr best = nullptr; + auto bestScope = -1, bestHeavyOps = -1; + for (auto &&lower : lowerBounds_) { + for (auto &&upper : upperBounds_) { + // Check upper <= lower ==> equal + // Here we use the less precise alwaysLE instead of analyzing bounds + // of `upper - lower`, in order to avoid infinite recursion + if (freetensor::alwaysLE(upper, lower)) { + // We need to choose the simplest one. Otherwise we are always + // picking the original expression + Expr expr; + if (upper.lin().coeff_.size() + (upper.lin().bias_ != 0) > + lower.lin().coeff_.size() + (lower.lin().bias_ != 0)) { + expr = lower.expr(); + } else { + expr = upper.expr(); + } + // firstly prefer the one whose innermost dependent scope is outermost + int scope = 0; + for (auto &&use : allUses(expr)) + scope = std::max(scope, orderedScope.at(use)); + // secondly choose the one with the fewest heavy operations + auto heavyOps = countHeavyOps(expr); + if (!best.isValid() || scope < bestScope || + (scope == bestScope && heavyOps < bestHeavyOps)) { + best = expr, bestScope = scope, bestHeavyOps = heavyOps; + } + break; + } + } + } + return best; +} + +void CompUniqueBoundsCombination::updLower(LowerBoundsList &list, + const LowerBound &bound) const { for (LowerBound &old : list) { // The same .expr() does not mean the same bounds // E.g. 1 * floor(a / 4) vs. (1/4) * a @@ -30,8 +151,8 @@ void CompUniqueBounds::updLower(LowerBoundsList &list, list.emplace_back(bound); } -void CompUniqueBounds::updUpper(UpperBoundsList &list, - const UpperBound &bound) const { +void CompUniqueBoundsCombination::updUpper(UpperBoundsList &list, - const UpperBound &bound) const { for (UpperBound &old : list) { // The same .expr() does not mean the same bounds // E.g. 1 * floor(a / 4) vs. (1/4) * a @@ -51,57 +172,32 @@ void CompUniqueBounds::updUpper(UpperBoundsList &list, list.emplace_back(bound); } -int64_t CompUniqueBounds::getIntLower(const Expr &op) { - int64_t ret = LLONG_MIN; - for (auto &&b : getLower(op)) { - if (b.lin().isConst()) { - auto bias = b.lin().bias_; - ret = std::max(ret, ceilDiv(bias.p_, bias.q_)); - } - } - return ret; -} - -int64_t CompUniqueBounds::getIntUpper(const Expr &op) { - int64_t ret = LLONG_MAX; - for (auto &&b : getUpper(op)) { - if (b.lin().isConst()) { - auto bias = b.lin().bias_; - ret = std::min(ret, floorDiv(bias.p_, bias.q_)); - } - } - return ret; -} - -std::optional CompUniqueBounds::getInt(const Expr &op) { - auto lower = getIntLower(op); - auto upper = getIntUpper(op); - return lower == upper ? 
std::make_optional(lower) : std::nullopt; -} - -CompUniqueBounds::LowerBoundsList CompUniqueBounds::getDefinedLower( - const Expr &op, const std::unordered_set &names) { - LowerBoundsList ret; - for (auto &&b : getLower(op)) { - if (checkAllDefined(names, b.allNames())) { - ret.emplace_back(b); - } - } - return ret; -} - -CompUniqueBounds::UpperBoundsList CompUniqueBounds::getDefinedUpper( - const Expr &op, const std::unordered_set &names) { - UpperBoundsList ret; - for (auto &&b : getUpper(op)) { - if (checkAllDefined(names, b.allNames())) { - ret.emplace_back(b); - } - } - return ret; +Ref +CompUniqueBoundsCombination::getBound(const Expr &op) { + auto lower = getLower(op); + bool selfInLower = false; + for (auto &&l : lower) + if (op == l.expr()) { + selfInLower = true; + break; + } + if (!selfInLower) + lower.emplace_back(op); + + auto upper = getUpper(op); + bool selfInUpper = false; + for (auto &&u : upper) + if (op == u.expr()) { + selfInUpper = true; + break; + } + if (!selfInUpper) + upper.emplace_back(op); + + return Ref::make(std::move(lower), std::move(upper)); } -bool CompUniqueBounds::alwaysLT(const Expr &lhs, const Expr &rhs) { +bool CompUniqueBoundsCombination::alwaysLT(const Expr &lhs, const Expr &rhs) { for (auto &&b1 : getUpper(lhs)) { for (auto &&b2 : getLower(rhs)) { if (freetensor::alwaysLT(b1, b2)) { @@ -112,7 +208,7 @@ bool CompUniqueBounds::alwaysLT(const Expr &lhs, const Expr &rhs) { return false; } -bool CompUniqueBounds::alwaysLE(const Expr &lhs, const Expr &rhs) { +bool CompUniqueBoundsCombination::alwaysLE(const Expr &lhs, const Expr &rhs) { for (auto &&b1 : getUpper(lhs)) { for (auto &&b2 : getLower(rhs)) { if (freetensor::alwaysLE(b1, b2)) { @@ -123,7 +219,24 @@ bool CompUniqueBounds::alwaysLE(const Expr &lhs, const Expr &rhs) { return false; } -void CompUniqueBounds::insertSignDataTypeInfo(const Expr &op) { +std::pair CompUniqueBoundsCombination::unionBounds( + const std::vector> &bounds) { + std::vector> lowers, uppers; + for (auto &&rb : bounds) { + ASSERT(rb->type() == BoundType::Combination); + Bound &b = *rb.as().get(); + std::vector lowerTerm, upperTerm; + for (auto &&l : b.lowerBounds_) + lowerTerm.emplace_back(l.expr()); + for (auto &&u : b.upperBounds_) + upperTerm.emplace_back(u.expr()); + lowers.emplace_back(std::move(lowerTerm)); + uppers.emplace_back(std::move(upperTerm)); + } + return {makeMinMax(lowers), makeMaxMin(uppers)}; +} + +void CompUniqueBoundsCombination::insertSignDataTypeInfo(const Expr &op) { switch (op->dtype().sign()) { case SignDataType::GT0: updLower(lower_[op], LowerBound{LinearExpr>{{}, 1}}); @@ -141,7 +254,7 @@ void CompUniqueBounds::insertSignDataTypeInfo(const Expr &op) { } } -void CompUniqueBounds::visitExpr(const Expr &op) { +void CompUniqueBoundsCombination::visitExpr(const Expr &op) { if (lower_.count(op) || upper_.count(op)) { return; } @@ -176,32 +289,32 @@ void CompUniqueBounds::visitExpr(const Expr &op) { } } -void CompUniqueBounds::visit(const Var &op) { +void CompUniqueBoundsCombination::visit(const Var &op) { BaseClass::visit(op); updLower(lower_[op], LowerBound{op}); updUpper(upper_[op], UpperBound{op}); } -void CompUniqueBounds::visit(const Load &op) { +void CompUniqueBoundsCombination::visit(const Load &op) { BaseClass::visit(op); updLower(lower_[op], LowerBound{op}); updUpper(upper_[op], UpperBound{op}); insertSignDataTypeInfo(op); } -void CompUniqueBounds::visit(const Cast &op) { +void CompUniqueBoundsCombination::visit(const Cast &op) { BaseClass::visit(op); // TODO: Use the expression itself as a bound 
just like Load? insertSignDataTypeInfo(op); } -void CompUniqueBounds::visit(const Intrinsic &op) { +void CompUniqueBoundsCombination::visit(const Intrinsic &op) { BaseClass::visit(op); // TODO: Use the expression itself as a bound just like Load? insertSignDataTypeInfo(op); } -void CompUniqueBounds::visit(const IntConst &op) { +void CompUniqueBoundsCombination::visit(const IntConst &op) { BaseClass::visit(op); updLower(lower_[op], LowerBound{LinearExpr>{{}, op->val_}}); @@ -209,7 +322,7 @@ void CompUniqueBounds::visit(const IntConst &op) { UpperBound{LinearExpr>{{}, op->val_}}); } -void CompUniqueBounds::visitLinear(const Expr &op) { +void CompUniqueBoundsCombination::visitLinear(const Expr &op) { auto &lower = lower_[op]; auto &upper = upper_[op]; @@ -291,22 +404,22 @@ void CompUniqueBounds::visitLinear(const Expr &op) { upper = std::move(retUpper); } -void CompUniqueBounds::visit(const Add &op) { +void CompUniqueBoundsCombination::visit(const Add &op) { // no need to recurse. getLower or getUpper recurses visitLinear(op); } -void CompUniqueBounds::visit(const Sub &op) { +void CompUniqueBoundsCombination::visit(const Sub &op) { // no need to recurse. getLower or getUpper recurses visitLinear(op); } -void CompUniqueBounds::visit(const Mul &op) { +void CompUniqueBoundsCombination::visit(const Mul &op) { // no need to recurse. getLower or getUpper recurses visitLinear(op); } -void CompUniqueBounds::visit(const Square &op) { +void CompUniqueBoundsCombination::visit(const Square &op) { // no need to recurse. getLower or getUpper recurses auto &lower = lower_[op]; @@ -317,7 +430,7 @@ void CompUniqueBounds::visit(const Square &op) { } } -void CompUniqueBounds::visit(const FloorDiv &op) { +void CompUniqueBoundsCombination::visit(const FloorDiv &op) { // no need to recurse. getLower or getUpper recurses auto &lower = lower_[op]; @@ -341,7 +454,7 @@ void CompUniqueBounds::visit(const FloorDiv &op) { } } -void CompUniqueBounds::visit(const CeilDiv &op) { +void CompUniqueBoundsCombination::visit(const CeilDiv &op) { // no need to recurse. getLower or getUpper recurses auto &lower = lower_[op]; @@ -365,7 +478,7 @@ void CompUniqueBounds::visit(const CeilDiv &op) { } } -void CompUniqueBounds::visit(const Mod &op) { +void CompUniqueBoundsCombination::visit(const Mod &op) { // no need to recurse. getLower or getUpper recurses if (auto &&l = getInt(op->lhs_); l.has_value()) { if (auto &&r = getInt(op->rhs_); r.has_value()) { @@ -393,7 +506,7 @@ void CompUniqueBounds::visit(const Mod &op) { } } -void CompUniqueBounds::visit(const Min &op) { +void CompUniqueBoundsCombination::visit(const Min &op) { // no need to recurse. getLower or getUpper recurses auto &lower = lower_[op]; auto &upper = upper_[op]; @@ -468,7 +581,7 @@ void CompUniqueBounds::visit(const Min &op) { } } -void CompUniqueBounds::visit(const Max &op) { +void CompUniqueBoundsCombination::visit(const Max &op) { // no need to recurse. getLower or getUpper recurses auto &lower = lower_[op]; auto &upper = upper_[op]; @@ -543,7 +656,7 @@ void CompUniqueBounds::visit(const Max &op) { } } -void CompUniqueBounds::visit(const IfExpr &op) { +void CompUniqueBoundsCombination::visit(const IfExpr &op) { // no need to recurse. 
getLower or getUpper recurses auto &lower = lower_[op]; auto &upper = upper_[op]; diff --git a/src/analyze/structural_feature.cc b/src/analyze/structural_feature.cc index 34dde7851..711400bde 100644 --- a/src/analyze/structural_feature.cc +++ b/src/analyze/structural_feature.cc @@ -117,40 +117,18 @@ int64_t StructuralFeature::calcArea( int64_t area = 1; size_t n = accesses.front().indices_.size(); for (size_t i = 0; i < n; i++) { - std::vector> lower, upper; - for (size_t j = 0, jEnd = accesses.size(); j < jEnd; j++) { - ASSERT(accesses[j].indices_.size() == n); - auto &&index = accesses[j].indices_[i]; - std::vector lowerItem({makeIntConst(0)}); - if (checkAllDefined(names(), index)) { - lowerItem.emplace_back(index); - } - for (auto b : accesses[j].lower_[i]) { - if (checkAllDefined(names(), b.allNames())) { - lowerItem.emplace_back(b.expr()); - } - } - lower.emplace_back(std::move(lowerItem)); - } - - for (size_t j = 0, jEnd = accesses.size(); j < jEnd; j++) { - ASSERT(accesses[j].indices_.size() == n); - auto &&index = accesses[j].indices_[i]; - std::vector upperItem( - {makeSub(buffer(var)->tensor()->shape()[i], makeIntConst(1))}); - if (checkAllDefined(names(), index)) { - upperItem.emplace_back(index); - } - for (auto b : accesses[j].upper_[i]) { - if (checkAllDefined(names(), b.allNames())) { - upperItem.emplace_back(b.expr()); - } - } - upper.emplace_back(std::move(upperItem)); - } - - auto l = makeMinMax(lower); - auto u = makeMaxMin(upper); + // union the bounds of all accesses and get the lower and upper + // expression + auto [l, u] = bound_->unionBounds( + // get bounds of the i-th dimension + accesses | views::transform([&](auto &&a) { + return a.bounds_[i]->restrictScope(names()); + }) | + // ... and pack into vector + ranges::to()); + l = makeMax(l, makeIntConst(0)); + u = makeMin( + u, makeSub(buffer(var)->tensor()->shape()[i], makeIntConst(1))); auto len = makeAdd(makeSub(u, l), makeIntConst(1)); if (auto constLen = bound_->getInt(len); constLen.has_value()) { area *= *constLen; @@ -207,7 +185,7 @@ void StructuralFeature::visitUnaryOp(const UnaryExpr &op) { void StructuralFeature::visitStmt(const Stmt &op) { auto boundOfOuterStmt = bound_; - bound_ = Ref::make(*this); + bound_ = Ref::make(*this); BaseClass::visitStmt(op); calcFeatures(op); bound_ = boundOfOuterStmt; diff --git a/src/autograd/analyze_version.cc b/src/autograd/analyze_version.cc index 476ef8ee6..6258aad1a 100644 --- a/src/autograd/analyze_version.cc +++ b/src/autograd/analyze_version.cc @@ -85,16 +85,13 @@ void CountScopeLen::visit(const For &op) { Expr len = op->len_; if (!allIters(len).empty()) { // Need to relax - CompUniqueBounds bound(*this); + CompUniqueBoundsCombination bound(*this); auto allowedNames = ranges::to( defs() | views::keys); // Names from all `VarDef`s but not `For`s - Expr relaxedInnerLen; - for (auto &&b : bound.getDefinedUpper(len, allowedNames)) { - relaxedInnerLen = relaxedInnerLen.isValid() - ? 
makeMin(relaxedInnerLen, b.expr()) - : b.expr(); - } + Expr relaxedInnerLen = bound.getBound(len) + ->restrictScope(allowedNames) + ->upperExpr(); if (!relaxedInnerLen.isValid()) { // Fallback to dynamic tape ASSERT(false); // Unimplemented diff --git a/src/math/parse_pb_expr.cc b/src/math/parse_pb_expr.cc index cbfe994c8..873877b96 100644 --- a/src/math/parse_pb_expr.cc +++ b/src/math/parse_pb_expr.cc @@ -1,13 +1,27 @@ #include +#include +#include +#include +#include +#include #include #include #include #include #include +#include namespace freetensor { +std::ostream &operator<<(std::ostream &os, const SimplePBFuncAST &ast) { + os << "[" << ast.args_ << "] -> [" << ast.values_ << "]"; + if (ast.cond_.isValid()) { + os << " : " << ast.cond_; + } + return os; +} + namespace { /** @@ -119,4 +133,214 @@ SimplePBFuncAST parseSimplePBFunc(const std::string &str) { return ret.front(); } +namespace { + +Expr isl2Expr(__isl_take isl_ast_expr *e) { + Expr res; + try { + switch (isl_ast_expr_get_type(e)) { + case isl_ast_expr_id: { + auto id = isl_ast_expr_get_id(e); + std::string name = isl_id_get_name(id); + res = makeVar(name); + isl_id_free(id); + break; + } + case isl_ast_expr_int: { + auto val = isl_ast_expr_get_val(e); + ASSERT(isl_val_get_den_si(val) == 1); + res = makeIntConst(isl_val_get_num_si(val)); + isl_val_free(val); + break; + } + case isl_ast_expr_op: { + auto args = views::ints(0, isl_ast_expr_op_get_n_arg(e)) | + views::transform([&](int i) { + auto result = + isl2Expr(isl_ast_expr_op_get_arg(e, i)); + return result; + }) | + ranges::to_vector; + switch (isl_ast_expr_op_get_type(e)) { + case isl_ast_expr_op_and: + ASSERT(args.size() == 2); + res = makeLAnd(args[0], args[1]); + break; + case isl_ast_expr_op_or: + ASSERT(args.size() == 2); + res = makeLOr(args[0], args[1]); + break; + case isl_ast_expr_op_max: { + ASSERT(!args.empty()); + Expr result = args[0]; + for (size_t i = 1; i < args.size(); ++i) + result = makeMax(result, args[i]); + res = result; + } break; + case isl_ast_expr_op_min: { + ASSERT(!args.empty()); + Expr result = args[0]; + for (size_t i = 1; i < args.size(); ++i) + result = makeMin(result, args[i]); + res = result; + } break; + case isl_ast_expr_op_add: + ASSERT(args.size() == 2); + res = makeAdd(args[0], args[1]); + break; + case isl_ast_expr_op_sub: + ASSERT(args.size() == 2); + res = makeSub(args[0], args[1]); + break; + case isl_ast_expr_op_minus: + ASSERT(args.size() == 1); + res = makeMul(makeIntConst(-1), args[0]); + break; + case isl_ast_expr_op_mul: + ASSERT(args.size() == 2); + res = makeMul(args[0], args[1]); + break; + case isl_ast_expr_op_div: // Exact division. Any rounding is OK. By + // default we use FloorDiv + case isl_ast_expr_op_fdiv_q: // Floor division + case isl_ast_expr_op_pdiv_q: // Floor division on non-negative + // divisor + ASSERT(args.size() == 2); + res = makeFloorDiv(args[0], args[1]); + break; + case isl_ast_expr_op_pdiv_r: // Remainder on non-negative divisor. + // Equivalent to Mod. We prefer Mod + // over Remainder + case isl_ast_expr_op_zdiv_r: // Divisible ? 
0 : any non-zero value + ASSERT(args.size() == 2); + res = makeMod(args[0], args[1]); + break; + case isl_ast_expr_op_select: + ASSERT(args.size() == 3); + res = makeIfExpr(args[0], args[1], args[2]); + break; + case isl_ast_expr_op_eq: + ASSERT(args.size() == 2); + res = makeEQ(args[0], args[1]); + break; + case isl_ast_expr_op_le: + ASSERT(args.size() == 2); + res = makeLE(args[0], args[1]); + break; + case isl_ast_expr_op_lt: + ASSERT(args.size() == 2); + res = makeLT(args[0], args[1]); + break; + case isl_ast_expr_op_ge: + ASSERT(args.size() == 2); + res = makeGE(args[0], args[1]); + break; + case isl_ast_expr_op_gt: + ASSERT(args.size() == 2); + res = makeGT(args[0], args[1]); + break; + default: + ASSERT(false); + } + } break; + default: + ASSERT(false); + } + } catch (...) { + isl_ast_expr_free(e); + throw; + } + isl_ast_expr_free(e); + return res; +} + +PBFuncAST isl2Func(__isl_take isl_ast_node *node) { + PBFuncAST ret; + try { + if (isl_ast_node_get_type(node) == isl_ast_node_if) { + auto cond = isl2Expr(isl_ast_node_if_get_cond(node)); + for (auto &&[thenNames, thenFT, thenCond] : + isl2Func(isl_ast_node_if_get_then(node))) { + ret.push_back(SimplePBFuncAST{ + thenNames, thenFT, + thenCond.isValid() ? makeLAnd(cond, thenCond) : cond}); + } + if (isl_ast_node_if_has_else(node)) { + for (auto &&[elseNames, elseFT, elseCond] : + isl2Func(isl_ast_node_if_get_else(node))) { + ret.push_back(SimplePBFuncAST{ + elseNames, elseFT, + elseCond.isValid() ? makeLAnd(makeLNot(cond), elseCond) + : makeLNot(cond)}); + } + } + + } else { + // otherwise, node is a user node + ASSERT(isl_ast_node_get_type(node) == isl_ast_node_user); + auto expr = isl_ast_node_user_get_expr(node); + try { + ASSERT(isl_ast_expr_get_type(expr) == isl_ast_expr_op); + ASSERT(isl_ast_expr_op_get_type(expr) == isl_ast_expr_op_call); + auto nVals = + isl_ast_expr_op_get_n_arg(expr) - + 1; // The arguments of the user node are the values we need. The + // first argument of the user node is its name + auto vals = + views::ints(1, nVals + 1) | views::transform([&](int i) { + return isl2Expr(isl_ast_expr_op_get_arg(expr, i)); + }) | + ranges::to_vector; + + std::unordered_set names; + for (auto &&item : vals) { + for (auto &&name : allNames(item)) { + names.insert(name); + } + } + ret = {SimplePBFuncAST{ranges::to(names), vals, + nullptr}}; + } catch (...) { + isl_ast_expr_free(expr); + throw; + } + isl_ast_expr_free(expr); + } + } catch (...) { + isl_ast_node_free(node); + throw; + } + isl_ast_node_free(node); + return ret; +} + +} // Anonymous namespace + +PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { + // This is a hack on isl's schedule generation. Treat the set as an iteration + // domain. For a single-valued set, the generated AST contains zero or one + // statement, possibly spread across multiple branches. We can recover an Expr + // from the statement and the branches' conditions. + + ASSERT(set.isSingleValued()); + + isl_options_set_ast_build_detect_min_max(ctx.get(), 1); + + PBFuncAST ret; + isl_ast_build *build = isl_ast_build_alloc(ctx.get()); + try { + isl_schedule *s = + isl_schedule_from_domain(isl_union_set_from_set(set.copy())); + isl_ast_node *ast = + isl_ast_build_node_from_schedule(build /* keep */, s /* take */); + ret = isl2Func(ast /* take */); + } catch (...) 
{ + isl_ast_build_free(build); + throw; + } + isl_ast_build_free(build); + + return ret; +} + } // namespace freetensor diff --git a/src/pass/gpu/lower_parallel_reduction.cc b/src/pass/gpu/lower_parallel_reduction.cc index ea2b01c32..9825ed860 100644 --- a/src/pass/gpu/lower_parallel_reduction.cc +++ b/src/pass/gpu/lower_parallel_reduction.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -250,7 +251,7 @@ Stmt CorrectInterThreadDependence::visit(const For &op) { for (auto &&ws : it->second) { VarDef vardef = ws; - CompUniqueBounds unique(*this); + CompUniqueBoundsCombination unique(*this); auto &&red = ws2red_.at(ws->id()).second; auto &shape = vardef->buffer_->tensor()->shape(); for (auto &&[dim, oldBegin, oldEnd] : @@ -258,14 +259,11 @@ Stmt CorrectInterThreadDependence::visit(const For &op) { red->begins_, red->ends_)) { for (auto &&name : allNames(dim)) { if (!names().count(name)) { - Expr newDim; - for (auto &&b : unique.getDefinedUpper( - makeMin(dim, makeSub(oldEnd, oldBegin)), - names())) { - newDim = newDim.isValid() - ? makeMin(std::move(newDim), b.expr()) - : b.expr(); - } + Expr newDim = unique + .getBound(makeMin( + dim, makeSub(oldEnd, oldBegin))) + ->restrictScope(names()) + ->upperExpr(); ASSERT(newDim.isValid()); dim = std::move(newDim); break; @@ -353,15 +351,24 @@ Stmt lowerParallelReduction(const Stmt &_op) { // 4. Try to make the workspace smaller by `pass/shrink_var`. Here we // use custom bounds only considering the real use of the workspaces - std::unordered_map bounds; + std::unordered_map boundsWithShape, boundsWithoutShape; for (auto &&[wsId, scopeId] : insertBinaryReduction.ws2scope()) { - bounds[wsId] = compAccessBound(op, wsId, COMP_ACCESS_BOUND_READ, - false, scopeId); + boundsWithShape[wsId] = compAccessBound( + op, wsId, COMP_ACCESS_BOUND_READ, true, scopeId); + boundsWithoutShape[wsId] = compAccessBound( + op, wsId, COMP_ACCESS_BOUND_READ, false, scopeId); + + // Don't touch the thread dimension + boundsWithShape[wsId].lower_[0] = nullptr; + boundsWithShape[wsId].upper_[0] = nullptr; + boundsWithShape[wsId].len_[0] = nullptr; + boundsWithoutShape[wsId].lower_[0] = nullptr; + boundsWithoutShape[wsId].upper_[0] = nullptr; + boundsWithoutShape[wsId].len_[0] = nullptr; } - op = ShrinkVar(bounds, true)(op); + op = ShrinkVar(boundsWithShape, boundsWithoutShape, true)(op); - // 5. Simplify, to flatten singleton loops, and to simplify the - // expressions from `pass/shrink_var` + // 5. Simplify expressions from `pass/shrink_var` op = simplify(op); // 6. 
As per our definition of inter-thread dependence, a VarDef defined @@ -374,6 +381,9 @@ Stmt lowerParallelReduction(const Stmt &_op) { op = CorrectInterThreadDependence(insertWorkspaces.ws2red())(op); } + // Find and flatten singleton loops + op = shrinkFor(op); + op = simplify(op); return op; } diff --git a/src/pass/gpu/make_sync.cc b/src/pass/gpu/make_sync.cc index b7d838179..979ef2a3e 100644 --- a/src/pass/gpu/make_sync.cc +++ b/src/pass/gpu/make_sync.cc @@ -48,7 +48,7 @@ class RelaxOneLoop : public CompTransientBounds> { auto op = __op.as(); if (op->id() == loopId_) { - CompUniqueBounds bound(*this); + CompUniqueBoundsCombination bound(*this); // Already normalized ASSERT(op->begin_->nodeType() == ASTNodeType::IntConst && op->begin_.as()->val_ == 0); diff --git a/src/pass/gpu/normalize_thread_dims.cc b/src/pass/gpu/normalize_thread_dims.cc index 34ff69ba7..1feab9c56 100644 --- a/src/pass/gpu/normalize_thread_dims.cc +++ b/src/pass/gpu/normalize_thread_dims.cc @@ -39,21 +39,20 @@ Stmt NormalizeThreadDims::visit(const For &_op) { inKernel_ = oldInKernel; // CompUniqueBounds requires one instance per Stmt - CompUniqueBounds bound(*this); + CompUniqueBoundsCombination bound(*this); + + std::unordered_set allLegalNames; + for (auto &&name : names()) { + if (isLegalLen({name})) + allLegalNames.emplace(name); + } if (!isLegalLen(op->begin_)) { op->body_ = makeIf( makeUnbound(makeGE(makeVar(op->iter_), op->begin_)), op->body_); - Expr begin; - for (auto &&b : bound.getLower(op->begin_)) { - if (isLegalLen(b.allNames())) { - if (b.lin().isConst() && b.lin().bias_ <= INT_MIN + 1) { - continue; - } - begin = - begin.isValid() ? makeMax(begin, b.expr()) : b.expr(); - } - } + Expr begin = bound.getBound(op->begin_) + ->restrictScope(allLegalNames) + ->lowerExpr(); if (!begin.isValid()) { throw InvalidProgram( "Length of " + toString(op->property_->parallel_) + @@ -68,15 +67,9 @@ Stmt NormalizeThreadDims::visit(const For &_op) { if (!isLegalLen(op->end_)) { op->body_ = makeIf( makeUnbound(makeLT(makeVar(op->iter_), op->end_)), op->body_); - Expr end; - for (auto &&b : bound.getUpper(op->end_)) { - if (isLegalLen(b.allNames())) { - if (b.lin().isConst() && b.lin().bias_ >= INT_MAX - 1) { - continue; - } - end = end.isValid() ? 
makeMin(end, b.expr()) : b.expr(); - } - } + Expr end = bound.getBound(op->end_) + ->restrictScope(allLegalNames) + ->upperExpr(); if (!end.isValid()) { throw InvalidProgram( "Length of " + toString(op->property_->parallel_) + diff --git a/src/pass/gpu/normalize_threads.cc b/src/pass/gpu/normalize_threads.cc index 643b00d14..ee383c7e1 100644 --- a/src/pass/gpu/normalize_threads.cc +++ b/src/pass/gpu/normalize_threads.cc @@ -3,10 +3,10 @@ #include #include -#include #include #include -#include +#include +#include namespace freetensor { @@ -145,13 +145,55 @@ Stmt NormalizeThreads::visit(const Eval &op) { return doVisitStmt(Mutator::visit(op)); } +bool ShrinkNormalizedThreads::filterLoop(const For &op) { + return std::holds_alternative(op->property_->parallel_); +} + +std::unordered_set ShrinkNormalizedThreads::filterNames( + const std::unordered_set &names) { + std::unordered_set ret; + for (auto &&name : names) { + if (hasLoop(name)) { + // Only iterators from outside the kernel are OK + if (openLoopsInKernel_.count(loop(name))) { + continue; + } + } else if (!hasDef(name) || buffer(name)->mtype() != MemType::ByValue) { + continue; + } + ret.insert(name); + } + return ret; +} + +Stmt ShrinkNormalizedThreads::visit(const For &op) { + Stmt ret; + if (std::holds_alternative(op->property_->parallel_)) { + openLoopsInKernel_.insert(op); + auto oldInKernel = inKernel_; + inKernel_ = true; + ret = BaseClass::visit(op); + inKernel_ = oldInKernel; + openLoopsInKernel_.erase(op); + } else { + if (inKernel_) { + openLoopsInKernel_.insert(op); + ret = BaseClass::visit(op); + openLoopsInKernel_.erase(op); + } else { + ret = BaseClass::visit(op); + } + } + return ret; +} + Stmt normalizeThreads(const Stmt &_op) { auto op = normalizeLoops(_op, [](const For &l) { return std::holds_alternative(l->property_->parallel_); }); op = NormalizeThreads(op)(op); - op = shrinkFor(op); - op = normalizeThreadDims(op); + op = ShrinkNormalizedThreads{}(op); + op = simplify(z3Simplify(op)); // NOTE: Although we have inserted a lot of identical `if`s, we must delay // `pass/merge_and_hoist_if` until we have done `pass/gpu/make_sync`. // Otherwise, we are introducing dependences between an `if`'s "then" case diff --git a/src/pass/gpu/normalize_var_in_kernel.cc b/src/pass/gpu/normalize_var_in_kernel.cc index 7b896f172..5b377e5fb 100644 --- a/src/pass/gpu/normalize_var_in_kernel.cc +++ b/src/pass/gpu/normalize_var_in_kernel.cc @@ -53,15 +53,11 @@ Stmt NormalizeVarInKernel::visit(const VarDef &_op) { auto op = __op.as(); // CompUniqueBounds requires one instance per Stmt - CompUniqueBounds unique(*this); + CompUniqueBoundsCombination unique(*this); for (auto &dim : op->buffer_->tensor()->shape()) { - Expr newDim; - for (auto &&b : unique.getDefinedUpper( - dim, ranges::to(legalNames_))) { - newDim = newDim.isValid() ?
makeMin(std::move(newDim), b.expr()) - : b.expr(); - } + Expr newDim = + unique.getBound(dim)->restrictScope(legalNames_)->upperExpr(); if (!newDim.isValid()) { throw InvalidProgram( "The shape of " + toString(op->id()) + " " + op->name_ + @@ -91,9 +87,9 @@ Stmt NormalizeVarInKernel::visit(const VarDef &_op) { return op; } } else { - legalNames_.emplace_back(_op->name_); + legalNames_.insert(_op->name_); auto ret = BaseClass::visit(_op); - legalNames_.pop_back(); + legalNames_.erase(_op->name_); return ret; } } @@ -120,9 +116,9 @@ Stmt NormalizeVarInKernel::visit(const For &op) { nameCntInKernel_.clear(); return ret; } else if (!inKernel_) { // out of kernel - legalNames_.emplace_back(op->iter_); + legalNames_.insert(op->iter_); auto ret = BaseClass::visit(op); - legalNames_.pop_back(); + legalNames_.erase(op->iter_); return ret; } else { // in kernel return BaseClass::visit(op); diff --git a/src/pass/make_parallel_reduction.cc b/src/pass/make_parallel_reduction.cc index 0bc03cd58..c26a7360f 100644 --- a/src/pass/make_parallel_reduction.cc +++ b/src/pass/make_parallel_reduction.cc @@ -96,61 +96,40 @@ Stmt MakeLoopCarriedReduction::visit(const ReduceTo &_op) { return op; } - CompUniqueBounds unique(*this); + CompUniqueBoundsCombination unique(*this); for (auto &&loopId : paraLoopStack_ | views::slice(needSyncUpTo + 1, (int)paraLoopStack_.size())) { if (toAlter_.at(op->id()).count(loopId)) { - std::vector> lowers, uppers; // [dim][bound] + std::vector> bounds; // [dim] for (auto &&[i, idx, dim] : views::zip( views::ints(0, ranges::unreachable), _op->indices_, buffer(_op->var_)->tensor()->shape())) { - std::vector dimLowers{makeIntConst(0)}, - dimUppers{dim}; - for (auto &&item : unique.getDefinedLower( - idx, scopeDefined_.at(loopId))) { - dimLowers.emplace_back(item.expr()); - } - for (auto &&item : unique.getDefinedUpper( - idx, scopeDefined_.at(loopId))) { - dimUppers.emplace_back(item.expr()); - } - lowers.emplace_back(std::move(dimLowers)); - uppers.emplace_back(std::move(dimUppers)); + bounds.emplace_back(unique.getBound(idx)->restrictScope( + scopeDefined_.at(loopId))); } - for (auto &[redOp, var, allLowers, allUppers, syncFlush] : - forReductions_[loopId]) { - // allLowers, allUppers : [dim][access][bound] + for (auto &[redOp, var, allBounds, syncFlush] : + forReductions_[loopId]) { // allBounds : [dim][access] if (redOp == op->op_ && var == op->var_) { - ASSERT(allLowers.size() == lowers.size()); - ASSERT(allUppers.size() == uppers.size()); - for (auto &&[allLowersItem, lowersItem] : - views::zip(allLowers, lowers)) { - allLowersItem.emplace_back(lowersItem); - } - for (auto &&[allUppersItem, uppersItem] : - views::zip(allUppers, uppers)) { - allUppersItem.emplace_back(uppersItem); + ASSERT(allBounds.size() == bounds.size()); + for (auto &&[allBoundsItem, boundsItem] : + views::zip(allBounds, bounds)) { + allBoundsItem.emplace_back(boundsItem); } syncFlush |= needSyncUpTo >= 0; goto done; } } { - std::vector>> allLowers( - lowers.size()), - allUppers(uppers.size()); - for (auto &&[allLowersItem, lowersItem] : - views::zip(allLowers, lowers)) { - allLowersItem.emplace_back(lowersItem); - } - for (auto &&[allUppersItem, uppersItem] : - views::zip(allUppers, uppers)) { - allUppersItem.emplace_back(uppersItem); + std::vector>> + allBounds(bounds.size()); + for (auto &&[allBoundsItem, boundsItem] : + views::zip(allBounds, bounds)) { + allBoundsItem.emplace_back(boundsItem); } forReductions_[loopId].emplace_back(ReductionItemFactors{ - op->op_, op->var_, std::move(allLowers), - 
std::move(allUppers), needSyncUpTo >= 0}); + op->op_, op->var_, std::move(allBounds), + needSyncUpTo >= 0}); } done:; } @@ -172,16 +151,17 @@ Stmt MakeLoopCarriedReduction::visit(const For &_op) { scopeDefined_.erase(_op->id()); paraScopes_.erase(_op->id()); + CompUniqueBoundsCombination unique(*this); if (forReductions_.count(op->id())) { - for (auto &&[redOp, var, allLowers, allUppers, syncFlush] : + for (auto &&[redOp, var, allBounds, syncFlush] : forReductions_.at(op->id())) { std::vector begins, ends; - for (auto &&dimLowers : allLowers) { - begins.emplace_back(makeMinMax(dimLowers)); - } - for (auto &&dimUppers : allUppers) { + for (auto &&[dimBounds, dimVarSize] : + views::zip(allBounds, buffer(var)->tensor()->shape())) { + auto [l, u] = unique.unionBounds(dimBounds); + begins.emplace_back(makeMax(makeIntConst(0), l)); ends.emplace_back( - makeAdd(makeMaxMin(dimUppers), makeIntConst(1))); + makeMin(dimVarSize, makeAdd(u, makeIntConst(1)))); } op->property_->reductions_.emplace_back(makeReductionItem( redOp, var, std::move(begins), std::move(ends), syncFlush)); @@ -254,12 +234,12 @@ Stmt MakeSyncReduction::visit(const ReduceTo &_op) { // There will be no cross-thread dependences except the reduction we // are working on (guranteed by schedule/parallelize). Therefore, We // can cache the variable being reduced, so it can be first reduced - // serially inside a thread, before reduced to the finally target in a - // synchronized operation. We will cache over some serial inner loops, - // if reduction is invariant to this loop, or if the loop densly - // iterates over the reduction - ID loopToCache; // Scope to flush locally accumulated result to target - // tensor + // serially inside a thread, before being reduced to the final target + // in a synchronized operation.
We will cache over some serial inner + // loops, if the reduction is invariant to the loop, or if the loop + // densely iterates over the reduction + ID loopToCache; // Scope to flush locally accumulated result to + // target tensor std::vector preserveDim(op->indices_.size(), false); if (serialOverRed_.count(op->id())) { // Cache at out of the outer-most serial fully reduction loop @@ -310,8 +290,8 @@ Stmt MakeSyncReduction::visit(const ReduceTo &_op) { newTargetIndices.emplace_back(idx); } } - // Try to reuse existing cache array with the same size and the same - // target indices + // Try to reuse existing cache array with the same size and the + // same target indices for (auto &existing : cacheSync_[loopToCache]) { if (existing.oldNode_->var_ == _op->var_ && existing.preserveDim_ == preserveDim && @@ -418,8 +398,8 @@ Stmt makeParallelReduction(const Stmt &_op, const Ref &target) { std::holds_alternative(parallel) && std::get(parallel).level_ == CUDAScope::Thread && d.later() != d.earlier()) { - // Use `__syncthreads` inserted by `pass/gpu/make_sync`, instead of - // synchronizing individual `ReduceTo`s + // Use `__syncthreads` inserted by `pass/gpu/make_sync`, instead + // of synchronizing individual `ReduceTo`s return; } toAlter[d.later().as()->id()].insert(loopId); diff --git a/src/pass/pb_simplify.cc b/src/pass/pb_simplify.cc index f283d6625..cff455d0f 100644 --- a/src/pass/pb_simplify.cc +++ b/src/pass/pb_simplify.cc @@ -1,69 +1,255 @@ +#include +#include +#include + +#include #include +#include +#include +#include #include #include +#include #include +#include namespace freetensor { -template -static void appendTo(std::vector &target, const std::vector &other) { - target.insert(target.end(), other.begin(), other.end()); +int64_t CompUniqueBoundsPB::Bound::lowerInt() const { + auto lower = dimMinVal(bound_, 0); + if (lower.isNegInf()) + return LLONG_MIN; + return floorDiv(lower.numSi(), lower.denSi()); +} +int64_t CompUniqueBoundsPB::Bound::upperInt() const { + auto upper = dimMaxVal(bound_, 0); + if (upper.isInf()) + return LLONG_MAX; + return ceilDiv(upper.numSi(), upper.denSi()); +} +std::optional CompUniqueBoundsPB::Bound::getInt() const { + auto upper = dimFixVal(bound_, 0); + if (!upper.isInt()) + return std::nullopt; + ASSERT(upper.denSi() == 1); + return upper.numSi(); } -void PBCompBounds::visitExpr(const Expr &op) { - CompUniqueBounds::visitExpr(op); +namespace { - if (visited_.count(op)) { - return; +// Translate a computed bound function in ISL back to our Expr. +// We prefer min/max expressions as final results, so we first test if simply +// taking min/max of the pieces gives the correct result; if not, fall back to +// IfExpr. Expr translateBoundFunc( PBCtx &ctx, const PBSet &boundSet, const std::unordered_map &demangleMap) { + if (boundSet.empty()) { + return nullptr; } - visited_.insert(op); - if (!isInt(op->dtype())) { - return; + // TODO: clear out unrelated params + PBSet compactedBoundSet = coalesce(boundSet); + auto parsed = parsePBFuncReconstructMinMax(ctx, compactedBoundSet); + + Expr result; + ReplaceIter demangler(demangleMap); + for (auto piece : views::reverse(parsed)) { + for (auto &&arg : piece.args_) + ASSERT(demangleMap.contains(arg)); + ASSERT(piece.values_.size() == 1); + auto pieceExpr = demangler(piece.values_[0]); + if (piece.cond_.isValid()) { + auto condExpr = demangler(piece.cond_); + result = result.isValid() ?
makeIfExpr(condExpr, pieceExpr, result) + : pieceExpr; + } else { + result = pieceExpr; + } } - auto &&[expr, vars] = genPBExpr_.gen(op); - // We use the original conditions instead of relying on transient bounds - // here. E.g., for x + y <= 2, and we are computing the maximum value of x + - // y, we shall not rely on x < 2 - y and y < 2 - x. Instead, we use x + y < - // 2 directly - std::vector condExprs; - for (auto &&cond : transients_.conds()) { - auto &&[condExpr, condVars] = genPBExpr_.gen(cond); - for (auto &&var : condVars) { - vars.insert(var); + return result; +} + +} // namespace + +Expr CompUniqueBoundsPB::Bound::lowerExpr() const { + return bound_.hasLowerBound(0) + ? translateBoundFunc(*ctx_, lexmin(bound_), *demangleMap_) + : nullptr; +} +Expr CompUniqueBoundsPB::Bound::upperExpr() const { + return bound_.hasUpperBound(0) + ? translateBoundFunc(*ctx_, lexmax(bound_), *demangleMap_) + : nullptr; +} + +Ref CompUniqueBoundsPB::Bound::restrictScope( + const std::unordered_set &scope) const { + std::vector axesToProject; + for (int i = 0; i < bound_.nParamDims(); ++i) { + for (auto &&used : allNames(demangleMap_->at(bound_.nameParamDim(i)))) { + if (!scope.contains(used)) { + axesToProject.emplace_back(i); + break; + } } - condExprs.emplace_back(condExpr); } + auto newBound = bound_; + for (auto axes : views::reverse(axesToProject)) + newBound = projectOutParamDims(newBound, axes, 1); + return Ref::make(ctx_, demangleMap_, newBound); +} + +Expr CompUniqueBoundsPB::Bound::simplestExpr( + const std::unordered_map &orderedScope) const { + + // first test that the original map is single valued + if (!bound_.isSingleValued()) + return nullptr; - std::string str = "{["; - for (auto &&[i, var] : views::enumerate(vars)) { - str += (i == 0 ? "" : ", ") + var.second; + std::vector> axesScopeLevel; + for (int i = 0; i < bound_.nParamDims(); ++i) { + auto name = bound_.nameParamDim(i); + int scopeLevel = 0; + for (auto &&used : allUses(demangleMap_->at(name))) + scopeLevel = std::max(scopeLevel, orderedScope.at(used)); + axesScopeLevel.emplace_back(name, scopeLevel); } - str += "] -> [" + expr + "]"; - for (auto &&[i, cond] : views::enumerate(condExprs)) { - str += (i == 0 ?
": " : " and ") + cond; + // sort to innermost first; we will try to remove them one by one + std::sort(axesScopeLevel.begin(), axesScopeLevel.end(), + [](auto &&a, auto &&b) { return a.second > b.second; }); + + // remove one axis at a time, stopping once the bound would no longer be + // single valued + auto restrictedBound = bound_; + for (auto &&[axis, _] : axesScopeLevel) { + auto newRestrictedBound = + projectOutParamById(std::move(restrictedBound), axis); + if (!newRestrictedBound.isSingleValued()) + break; + restrictedBound = std::move(newRestrictedBound); } - str += "}"; - PBMap map(isl_, str); - PBSet image = range(std::move(map)); - PBVal maxVal = dimMaxVal(image, 0); - if (maxVal.isRat()) { - auto &&list = getUpper(op); - auto maxP = maxVal.numSi(); - auto maxQ = maxVal.denSi(); - updUpper(list, UpperBound{LinearExpr>{ - {}, Rational{maxP, maxQ}}}); - setUpper(op, std::move(list)); + return translateBoundFunc(*ctx_, restrictedBound, *demangleMap_); +} + +Ref CompUniqueBoundsPB::getBound(const Expr &op) { + if (!isInt(op->dtype())) + return nullptr; + + // check if the cache is valid + if (auto place = transients_.currentStmt(); place != cachedPlace_) { + // invalid; refresh it with the new transient conditions + cachedPlace_ = place; + + // construct full condition + Expr fullCond = makeBoolConst(true); + for (auto &&cond : transients_.conds()) + fullCond = makeLAnd(fullCond, cond); + + // generate PB condition + auto [str, varMap] = genPBExpr_.gen(fullCond); + cachedConds_ = + PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) + + "] -> { [unique_bounded_var]: " + str + " }"); + + // initialize known demangle map + cachedFreeVars_ = decltype(cachedFreeVars_)::make(); + for (auto &&[expr, pbVar] : varMap) { + ASSERT(!cachedFreeVars_->contains(pbVar)); + (*cachedFreeVars_)[pbVar] = expr; + } + + // clear cached query results + cachedValues_.clear(); } - PBVal minVal = dimMinVal(image, 0); - if (minVal.isRat()) { - auto &&list = getLower(op); - auto minP = minVal.numSi(); - auto minQ = minVal.denSi(); - updLower(list, LowerBound{LinearExpr>{ - {}, Rational{minP, minQ}}}); - setLower(op, std::move(list)); + + // find in cached results + if (auto it = cachedValues_.find(op); it != cachedValues_.end()) + return it->second; + + // not previously queried, construct the bound + auto [str, varMap] = genPBExpr_.gen(op); + auto bound = + (intersect(PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) + + "] -> { [" + str + "] }"), + cachedConds_)); + // update free variables + for (auto &&[expr, pbVar] : varMap) { + if (auto it = cachedFreeVars_->find(pbVar); + it != cachedFreeVars_->end()) + ASSERT(HashComparator()(it->second, expr)); + else + (*cachedFreeVars_)[pbVar] = expr; } + return cachedValues_[op] = Ref::make(ctx_, cachedFreeVars_, bound); +} + +bool CompUniqueBoundsPB::alwaysLE(const Expr &lhs, const Expr &rhs) { + auto l = insertDims(getBound(lhs).as()->bound_, 1, 1), + r = insertDims(getBound(rhs).as()->bound_, 0, 1); + // we check for the emptiness of l > r; if empty, it means we never have l > + // r, or equivalently always have l <= r + auto combined = intersect(intersect(l, r), PBSet(*ctx_, "{[l, r]: l > r}")); + return combined.empty(); +} + +bool CompUniqueBoundsPB::alwaysLT(const Expr &lhs, const Expr &rhs) { + auto l = insertDims(getBound(lhs).as()->bound_, 1, 1), + r = insertDims(getBound(rhs).as()->bound_, 0, 1); + // similar to alwaysLE, but !LT = GE + auto combined = + intersect(intersect(l, r), PBSet(*ctx_, "{[l, r]: l >= r}")); + return combined.empty(); +} + +std::pair
CompUniqueBoundsPB::unionBounds( + const std::vector> &_bounds) { + // if no bound is present, return an empty range + if (_bounds.size() == 0) + return {makeIntConst(0), makeIntConst(-1)}; + + // PBSets in _bounds may be from a foreign ctx. Reconstruct them in our ctx + auto bounds = ranges::to( + _bounds | views::transform([&](auto &&_bound) { + ASSERT(_bound->type() == BoundType::Presburger); + auto &&bound = _bound.template as(); + return Ref::make(ctx_, bound->demangleMap_, + PBSet(*ctx_, toString(bound->bound_))); + })); + + // union the bounds + PBSet bound = bounds[0]->bound_; + for (size_t i = 1; i < bounds.size(); ++i) { + bound = uni(std::move(bound), bounds[i]->bound_); + } + bound = coalesce(std::move(bound)); + + // construct the demangle map + std::unordered_map demangleMap; + for (isl_size dim = 0; dim < bound.nParamDims(); ++dim) { + auto dimName = bound.nameParamDim(dim); + Expr demangled; + for (const auto &srcBound : bounds) { + auto &&srcDemangleMap = *srcBound.as()->demangleMap_; + auto it = srcDemangleMap.find(dimName); + if (it != srcDemangleMap.end()) { + if (demangled.isValid()) { + ASSERT(HashComparator{}(demangled, it->second)); + } else { + demangled = it->second; + } + } + } + demangleMap[dimName] = demangled; + } + + // translate the lower and upper bounds back to expressions + auto l = bound.hasLowerBound(0) + ? translateBoundFunc(*ctx_, lexmin(bound), demangleMap) + : nullptr; + auto u = bound.hasUpperBound(0) + ? translateBoundFunc(*ctx_, lexmax(bound), demangleMap) + : nullptr; + return {l, u}; } Stmt pbSimplify(const Stmt &op) { diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc index 1e7db0652..38fceaa28 100644 --- a/src/pass/shrink_for.cc +++ b/src/pass/shrink_for.cc @@ -56,7 +56,7 @@ Stmt ShrinkFor::visitStmt(const Stmt &stmt) { } if (checker.hasSideEffect()) { for (auto &&[_var, _names] : views::zip(iterStack_, namesStack_)) { - auto &&names = _names; + auto &&names = filterNames(_names); // We need linear programming from PBCompBounds, because the // minimum/maximum value of a linear function does not always appear // See 2.pass/test_shrink_for.py::test_linear_bounds // // PBCompBounds requires one instance per Stmt - PBCompBounds bound(*this); + CompUniqueBoundsPB bound(*this); + // Trigger recomputing in analyze/comp_unique_bounds auto var = deepCopy(_var).as(); - - std::vector lower, upper; - for (auto &&b : bound.getDefinedLower(var, names)) { - lower.emplace_back(b.expr()); - } - for (auto &&b : bound.getDefinedUpper(var, names)) { - upper.emplace_back(b.expr()); - } - newRange_[var].first.emplace_back(std::move(lower)); - newRange_[var].second.emplace_back(std::move(upper)); + newRange_[var].emplace_back( + bound.getBound(var)->restrictScope(names)); } } @@ -95,18 +88,22 @@ Stmt ShrinkFor::visit(const For &_op) { namesStack_.pop_back(); iterStack_.pop_back(); + if (!filterLoop(op)) { + return op; + } + if (!newRange_.count(var)) { return makeStmtSeq({}); } - auto lower = makeMinMax(newRange_.at(var).first); - auto upper = makeMaxMin(newRange_.at(var).second); + + // CompUniqueBoundsPB requires one instance per Stmt + CompUniqueBoundsPB bound(*this); + + auto [lower, upper] = bound.unionBounds(newRange_[var]); if (op->property_->unroll_) { // Backends do not support these loops to be of variable lengths - // PBCompBounds requires one instance per Stmt - PBCompBounds bound(*this); - lower = makeIntConst(bound.getIntLower(lower)); upper = makeIntConst(bound.getIntUpper(upper)); } @@
-152,8 +149,8 @@ Stmt shrinkFor(const Stmt &_op, const Stmt &subAST, bool doSimplify) { shrinker.setSubAST(subAST); op = shrinker(op); - if (doSimplify) - op = simplify(op); + if (doSimplify) // Make new ranges simple + remove redundant branches + op = simplify(z3Simplify(op)); return op; } diff --git a/src/pass/shrink_var.cc b/src/pass/shrink_var.cc index 7fb122b85..7eb649b52 100644 --- a/src/pass/shrink_var.cc +++ b/src/pass/shrink_var.cc @@ -24,33 +24,48 @@ Stmt ShrinkVar::visit(const VarDef &_op) { if (isInputting(_op->buffer_->atype()) || isOutputting(_op->buffer_->atype()) || _op->viewOf_.has_value() || !findAllStmt(_op, isViewOfThis).empty() || _op->pinned_ || - !newRange_.count(_op->id())) { + !newRangeWithShape_.count(_op->id()) || + !newRangeWithoutShape_.count(_op->id())) { return Mutator::visit(_op); } - auto &&range = newRange_.at(_op->id()); + auto &&rangeWithShape = newRangeWithShape_.at(_op->id()); + auto &&rangeWithoutShape = newRangeWithoutShape_.at(_op->id()); size_t n = _op->buffer_->tensor()->shape().size(); - ASSERT(range.lower_.size() == n); - ASSERT(range.upper_.size() == n); - ASSERT(range.len_.size() == n); - ASSERT(!lower_.count(_op->name_)); - ASSERT(!upper_.count(_op->name_)); - lower_[_op->name_] = range.lower_; - upper_[_op->name_] = range.upper_; + + ASSERT(rangeWithShape.lower_.size() == n); + ASSERT(rangeWithShape.upper_.size() == n); + ASSERT(rangeWithShape.len_.size() == n); + ASSERT(!lowerWithShape_.count(_op->name_)); + ASSERT(!upperWithShape_.count(_op->name_)); + lowerWithShape_[_op->name_] = rangeWithShape.lower_; + upperWithShape_[_op->name_] = rangeWithShape.upper_; + + ASSERT(rangeWithoutShape.lower_.size() == n); + ASSERT(rangeWithoutShape.upper_.size() == n); + ASSERT(rangeWithoutShape.len_.size() == n); + ASSERT(!lowerWithoutShape_.count(_op->name_)); + ASSERT(!upperWithoutShape_.count(_op->name_)); + lowerWithoutShape_[_op->name_] = rangeWithoutShape.lower_; + upperWithoutShape_[_op->name_] = rangeWithoutShape.upper_; auto __op = Mutator::visit(_op); ASSERT(__op->nodeType() == ASTNodeType::VarDef); auto op = __op.as(); for (auto &&[len, newLen] : - views::zip(op->buffer_->tensor()->shape(), range.len_)) { + views::zip(op->buffer_->tensor()->shape(), rangeWithShape.len_)) { if (newLen.isValid()) { len = newLen; } } - lower_.erase(_op->name_); - upper_.erase(_op->name_); + + lowerWithShape_.erase(_op->name_); + upperWithShape_.erase(_op->name_); + lowerWithoutShape_.erase(_op->name_); + upperWithoutShape_.erase(_op->name_); + return op; } @@ -92,14 +107,16 @@ Stmt shrinkVar(const Stmt &_op) { // (3) Simplify the new indicies // (1) - std::unordered_map bounds; + std::unordered_map boundsWithShape, boundsWithoutShape; for (auto &&[varDefId, name] : allDefs(op, {AccessType::Cache})) { - bounds[varDefId] = + boundsWithShape[varDefId] = + compAccessBound(op, varDefId, COMP_ACCESS_BOUND_READ, true); + boundsWithoutShape[varDefId] = compAccessBound(op, varDefId, COMP_ACCESS_BOUND_READ, false); } // (2) - op = ShrinkVar(bounds)(op); + op = ShrinkVar(boundsWithShape, boundsWithoutShape)(op); // (3) return simplify(z3Simplify(op)); @@ -109,12 +126,14 @@ Stmt shrinkSingleVar(const Stmt &_op, const ID &varDefId) { auto op = removeDeadVar(_op); // (1) - std::unordered_map bounds; - bounds[varDefId] = + std::unordered_map boundsWithShape, boundsWithoutShape; + boundsWithShape[varDefId] = + compAccessBound(op, varDefId, COMP_ACCESS_BOUND_READ, true); + boundsWithoutShape[varDefId] = compAccessBound(op, varDefId, COMP_ACCESS_BOUND_READ, false); // (2) - op = 
ShrinkVar(bounds)(op); + op = ShrinkVar(boundsWithShape, boundsWithoutShape)(op); // (3) return simplify(z3Simplify(op)); diff --git a/src/pass/simplify.cc b/src/pass/simplify.cc index 74e386192..84ca10e43 100644 --- a/src/pass/simplify.cc +++ b/src/pass/simplify.cc @@ -9,29 +9,6 @@ namespace freetensor { -class CountHeavyOps : public Visitor { - int cnt_ = 0; - - public: - int cnt() const { return cnt_; } - - protected: - void visitExpr(const Expr &op) { - Visitor::visitExpr(op); - if (!op->isConst() && op->nodeType() != ASTNodeType::Add && - op->nodeType() != ASTNodeType::Sub && - op->nodeType() != ASTNodeType::Mul) { - cnt_++; - } - } -}; - -static int countHeavyOps(const Expr &op) { - CountHeavyOps visitor; - visitor(op); - return visitor.cnt(); -} - static std::vector factorize(const Expr &expr) { std::vector factors; std::function recur = [&](const Expr &expr) { @@ -76,6 +53,49 @@ static Expr reduceMul(const std::vector &factors) { return ret.isValid() ? ret : makeIntConst(1); } +static Expr recursiveNegateMul(const Expr &e) { + if (e->nodeType() == ASTNodeType::IntConst) { + return makeIntConst(-e.as()->val_); + } else if (e->nodeType() == ASTNodeType::FloatConst) { + return makeFloatConst(-e.as()->val_); + } else if (e->nodeType() == ASTNodeType::Mul) { + auto &&mul = e.as(); + if (auto &&nl = recursiveNegateMul(mul->lhs_); nl.isValid()) { + return makeMul(nl, mul->rhs_); + } else if (auto &&nr = recursiveNegateMul(mul->rhs_); nr.isValid()) { + return makeMul(mul->lhs_, nr); + } else { + return nullptr; + } + } else { + return nullptr; + } +} + +static std::pair recursiveGetConstOffset(const Expr &e) { + if (e->nodeType() == ASTNodeType::IntConst) { + return {nullptr, e.as()->val_}; + } else if (e->nodeType() == ASTNodeType::Add) { + auto &&add = e.as(); + auto &&[le, lc] = recursiveGetConstOffset(add->lhs_); + auto &&[re, rc] = recursiveGetConstOffset(add->rhs_); + return {le.isValid() && re.isValid() ? makeAdd(le, re) + : le.isValid() ? le + : re, + lc + rc}; + } else if (e->nodeType() == ASTNodeType::Sub) { + auto &&sub = e.as(); + auto &&[le, lc] = recursiveGetConstOffset(sub->lhs_); + auto &&[re, rc] = recursiveGetConstOffset(sub->rhs_); + return {le.isValid() && re.isValid() ? makeSub(le, re) + : le.isValid() ? le + : makeSub(makeIntConst(0), re), + lc - rc}; + } else { + return {e, 0}; + } +} + void FindInnerMostScope::visit(const Var &op) { Visitor::visit(op); if (!varScope_.count(op->name_)) { @@ -118,36 +138,12 @@ Expr SimplifyPass::visitExpr(const Expr &_op) { return op; } - Expr best = nullptr; - auto bestScope = -1, bestHeavyOps = -1; - for (auto &&lower : unique_->getLower(op)) { - for (auto &&upper : unique_->getUpper(op)) { - // Check upper <= lower ==> equal - // Here we use the less precise alwaysLE instead of analyzing bounds - // of `upper - lower`, in order to avoid infinite recursion - if (freetensor::alwaysLE(upper, lower)) { - // We need to choose the simplest one. 
Otherwise we are always - // picking the original expression - Expr expr; - if (upper.lin().coeff_.size() + (upper.lin().bias_ != 0) > - lower.lin().coeff_.size() + (lower.lin().bias_ != 0)) { - expr = lower.expr(); - } else { - expr = upper.expr(); - } - auto scope = findInnerMostScope(varScope_, expr); - auto heavyOps = countHeavyOps(expr); - if (!best.isValid() || scope < bestScope || - (scope == bestScope && heavyOps < bestHeavyOps)) { - best = expr, bestScope = scope, bestHeavyOps = heavyOps; - } - break; - } + if (auto bound = unique_->getBound(op); bound.isValid()) { + Expr best = bound->simplestExpr(varScope_); + if (best.isValid() && !HashComparator()(best, op)) { + return best; } } - if (best.isValid() && !HashComparator()(best, op)) { - return best; - } return op; } @@ -166,21 +162,34 @@ Expr SimplifyPass::visit(const Add &_op) { return op->lhs_; } + if (op->lhs_->nodeType() == ASTNodeType::IntConst) { + if (auto &&[re, rc] = recursiveGetConstOffset(op->rhs_); rc != 0) { + return makeAdd(makeIntConst(op->lhs_.as()->val_ + rc), + re); + } + } + if (op->rhs_->nodeType() == ASTNodeType::IntConst) { + if (auto &&[le, lc] = recursiveGetConstOffset(op->lhs_); lc != 0) { + return makeAdd( + le, makeIntConst(op->rhs_.as()->val_ + lc)); + } + } + if (op->lhs_->isConst() && op->rhs_->nodeType() == ASTNodeType::Min) { return makeMin(makeAdd(op->lhs_, op->rhs_.as()->lhs_), makeAdd(op->lhs_, op->rhs_.as()->rhs_)); } if (op->lhs_->isConst() && op->rhs_->nodeType() == ASTNodeType::Max) { - return makeMax(makeAdd(op->lhs_, op->rhs_.as()->lhs_), - makeAdd(op->lhs_, op->rhs_.as()->rhs_)); + return makeMax(makeAdd(op->lhs_, op->rhs_.as()->lhs_), + makeAdd(op->lhs_, op->rhs_.as()->rhs_)); } if (op->lhs_->nodeType() == ASTNodeType::Min && op->rhs_->isConst()) { return makeMin(makeAdd(op->lhs_.as()->lhs_, op->rhs_), makeAdd(op->lhs_.as()->rhs_, op->rhs_)); } if (op->lhs_->nodeType() == ASTNodeType::Max && op->rhs_->isConst()) { - return makeMax(makeAdd(op->lhs_.as()->lhs_, op->rhs_), - makeAdd(op->lhs_.as()->rhs_, op->rhs_)); + return makeMax(makeAdd(op->lhs_.as()->lhs_, op->rhs_), + makeAdd(op->lhs_.as()->rhs_, op->rhs_)); } return op; @@ -194,10 +203,28 @@ Expr SimplifyPass::visit(const Sub &_op) { ASSERT(__op->nodeType() == ASTNodeType::Sub); auto op = __op.as(); + if (equals(op->lhs_, 0)) { + if (auto &&nr = recursiveNegateMul(op->rhs_); nr.isValid()) { + return nr; + } + } if (equals(op->rhs_, 0)) { return op->lhs_; } + if (op->lhs_->nodeType() == ASTNodeType::IntConst) { + if (auto &&[re, rc] = recursiveGetConstOffset(op->rhs_); rc != 0) { + return makeSub(makeIntConst(op->lhs_.as()->val_ - rc), + re); + } + } + if (op->rhs_->nodeType() == ASTNodeType::IntConst) { + if (auto &&[le, lc] = recursiveGetConstOffset(op->lhs_); lc != 0) { + return makeAdd( + le, makeIntConst(lc - op->rhs_.as()->val_)); + } + } + if (op->lhs_->isConst() && op->rhs_->nodeType() == ASTNodeType::Min) { return makeMax(makeSub(op->lhs_, op->rhs_.as()->lhs_), makeSub(op->lhs_, op->rhs_.as()->rhs_)); @@ -238,6 +265,16 @@ Expr SimplifyPass::visit(const Mul &_op) { if (equals(op->rhs_, 0)) { return makeIntConst(0); } + if (equals(op->lhs_, -1)) { + if (auto &&nr = recursiveNegateMul(op->rhs_); nr.isValid()) { + return nr; + } + } + if (equals(op->rhs_, -1)) { + if (auto &&nl = recursiveNegateMul(op->lhs_); nl.isValid()) { + return nl; + } + } if (op->lhs_->isConst() && op->rhs_->nodeType() == ASTNodeType::Min) { if (isGT0(op->lhs_->dtype())) { diff --git a/src/pass/use_builtin_div.cc b/src/pass/use_builtin_div.cc index 
f3e7b0ab4..2911725fa 100644 --- a/src/pass/use_builtin_div.cc +++ b/src/pass/use_builtin_div.cc @@ -6,7 +6,7 @@ static Expr makeNeg(const Expr &expr) { return makeSub(makeIntConst(0), expr); } Stmt UseBuiltinDiv::visitStmt(const Stmt &op) { auto boundOfOuterStmt = bound_; - bound_ = Ref::make(*this); + bound_ = Ref::make(*this); auto ret = BaseClass::visitStmt(op); bound_ = boundOfOuterStmt; return ret; diff --git a/src/pass/z3_simplify.cc b/src/pass/z3_simplify.cc index 6f09e73d4..ebb778e68 100644 --- a/src/pass/z3_simplify.cc +++ b/src/pass/z3_simplify.cc @@ -433,8 +433,12 @@ Expr Z3Simplify::visit(const IfExpr &op) { } auto thenCase = (*this)(op->thenCase_); auto elseCase = (*this)(op->elseCase_); - return makeIfExpr(std::move(cond), std::move(thenCase), std::move(elseCase), - op->debugBlame()); + auto ret = makeIfExpr(cond, thenCase, elseCase, op->debugBlame()); + if (exists(cond) && exists(thenCase) && exists(elseCase)) { + put(ret, z3::ite(get(cond), get(thenCase), get(elseCase)), + cat(conds(cond), cat(conds(thenCase), conds(elseCase)))); + } + return ret; } Stmt Z3Simplify::visit(const If &op) { diff --git a/src/serialize/print_ast.cc b/src/serialize/print_ast.cc index 40dc3a778..d29783a69 100644 --- a/src/serialize/print_ast.cc +++ b/src/serialize/print_ast.cc @@ -794,16 +794,17 @@ std::string toString(const AST &op, bool pretty, bool printAllId) { } std::string toString(const AST &op, bool pretty, bool printAllId, - bool dtypeInLoad, bool hexFloat, bool compact) { + bool dtypeInLoad, bool hexFloat, bool compact, + bool parenDespitePriority) { return toString(op, pretty, printAllId, dtypeInLoad, hexFloat, compact, - Config::printSourceLocation()); + parenDespitePriority, Config::printSourceLocation()); } std::string toString(const AST &op, bool pretty, bool printAllId, bool dtypeInLoad, bool hexFloat, bool compact, - bool printSourceLocation) { + bool parenDespitePriority, bool printSourceLocation) { PrintVisitor visitor(printAllId, pretty, dtypeInLoad, hexFloat, compact, - printSourceLocation); + parenDespitePriority, printSourceLocation); visitor(op); return visitor.toString( [](const CodeGenStream &stream) { return stream.os_.str(); }); diff --git a/test/20.pass/test_prop_one_time_use.py b/test/20.pass/test_prop_one_time_use.py index 99cb43b20..8941664d9 100644 --- a/test/20.pass/test_prop_one_time_use.py +++ b/test/20.pass/test_prop_one_time_use.py @@ -237,7 +237,7 @@ def test_thread_local_no_prop(): ("u", (10,), "float32", "cache", "gpu/shared"), ]) as (t, u): with ft.For("j", 0, 10, label="Lj1") as j: - t[j] = x[i, j] + (t[j - 1] if j > 0 else 0) + t[j] = x[i, j] + t[(j + 1) % 10] u[j] = ft.sin(t[j]) * ft.cos(t[j]) with ft.For("j", 0, 10, label="Lj2") as j: # Used `u` for only once, but we can't propagate the `t`-expression @@ -248,7 +248,7 @@ def test_thread_local_no_prop(): s.parallelize("Li", "blockIdx.x") s.parallelize("Lj2", "threadIdx.x") ast = s.ast() - ast = ft.lower(ast, verbose=1) + ast = ft.lower(ast, verbose=1, skip_passes=['use_builtin_div']) with ft.VarDef([("x", (5, 10), "float32", "input", "gpu/global"), ("y", (5, 10), "float32", "output", "gpu/global")]) as (x, @@ -258,7 +258,7 @@ def test_thread_local_no_prop(): with ft.VarDef("t", (10,), "float32", "cache", "gpu/local") as t: with ft.For("j", 0, 10, label="Lj1") as j: - t[j] = x[i, j] + (t[j - 1] if j > 0 else 0) + t[j] = x[i, j] + t[(j + 1) % 10] u[j] = ft.sin(t[j]) * ft.cos(t[j]) with ft.For("j", 0, 10, label="Lj2") as j: y[i, j] = u[j] # Unchanged diff --git a/test/20.pass/test_shrink_for.py 
b/test/20.pass/test_shrink_for.py index bf8a2708b..e49b358ac 100644 --- a/test/20.pass/test_shrink_for.py +++ b/test/20.pass/test_shrink_for.py @@ -77,6 +77,30 @@ def test_linear_bounds(): assert std.match(ast) +def test_presburger_bounds(): + with ft.VarDef([("x", (128, 128), "int32", "input", "cpu"), + ("y", (128, 128), "int32", "output", "cpu")]) as (x, y): + with ft.For("i0", 0, 8) as i0: + with ft.For("j0", 0, 8) as j0: + with ft.For("i", 0, 128) as i: + with ft.For("j", 0, 128) as j: + with ft.If(ft.l_and(i // 16 == i0, j // 16 == j0)): + y[i, j] = x[i, j] * 2 + ast = ft.pop_ast(verbose=True) + ast = ft.lower(ast, verbose=1) + + with ft.VarDef([("x", (128, 128), "int32", "input", "cpu"), + ("y", (128, 128), "int32", "output", "cpu")]) as (x, y): + with ft.For("i0", 0, 8) as i0: + with ft.For("j0", 0, 8) as j0: + with ft.For("i", 16 * i0, 16 * i0 + 16) as i: + with ft.For("j", 16 * j0, 16 * j0 + 16) as j: + y[i, j] = x[i, j] * 2 + std = ft.pop_ast() + + assert std.match(ast) + + def test_multiple_branches(): with ft.VarDef([("x", (32, 4), "int32", "inout", "cpu"), ("y", (32, 4), "int32", "output", "cpu")]) as (x, y): diff --git a/test/20.pass/test_shrink_var.py b/test/20.pass/test_shrink_var.py index 3811547c2..632936b0d 100644 --- a/test/20.pass/test_shrink_var.py +++ b/test/20.pass/test_shrink_var.py @@ -188,68 +188,67 @@ def test_no_changing_unbounded_var(): assert ft.pop_ast().match(ast) -# FIXME: Fix this test -#def test_const_in_branch_1(): -# with ft.VarDef([("x", (5,), "int32", "input", "cpu"), -# ("y1", (4,), "int32", "output", "cpu"), -# ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): -# with ft.For("i", 0, 4) as i: -# with ft.VarDef("b", (4,), "int32", "cache", "cpu") as b: -# with ft.If(i == 2): -# b[2] = x[2] -# with ft.Else(): -# b[i] = x[i] + x[i + 1] -# y1[i] = b[i] * i -# y2[i] = b[i] + i -# ast = ft.pop_ast() -# print(ast) -# ast = ft.lower(ast) -# print(ast) -# -# with ft.VarDef([("x", (5,), "int32", "input", "cpu"), -# ("y1", (4,), "int32", "output", "cpu"), -# ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): -# with ft.For("i", 0, 4) as i: -# with ft.VarDef("b", (1,), "int32", "cache", "cpu") as b: -# with ft.If(i == 2): -# b[0] = x[2] -# with ft.Else(): -# b[0] = x[i] + x[i + 1] -# y1[i] = b[0] * i -# y2[i] = b[0] + i -# std = ft.pop_ast() -# -# assert std.match(ast) - -# FIXME: Fix this test -#def test_const_in_branch_2(): -# with ft.VarDef([("x", (5,), "int32", "input", "cpu"), -# ("y1", (4,), "int32", "output", "cpu"), -# ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): -# with ft.For("i", 0, 4) as i: -# with ft.VarDef("b", (4,), "int32", "cache", "cpu") as b: -# with ft.If(i < 3): -# b[i] = x[i] + x[i + 1] -# with ft.Else(): -# b[3] = x[3] -# y1[i] = b[i] * i -# y2[i] = b[i] + i -# ast = ft.pop_ast() -# print(ast) -# ast = ft.lower(ast) -# print(ast) -# -# with ft.VarDef([("x", (5,), "int32", "input", "cpu"), -# ("y1", (4,), "int32", "output", "cpu"), -# ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): -# with ft.For("i", 0, 4) as i: -# with ft.VarDef("b", (1,), "int32", "cache", "cpu") as b: -# with ft.If(i < 3): -# b[0] = x[i] + x[i + 1] -# with ft.Else(): -# b[0] = x[3] -# y1[i] = b[0] * i -# y2[i] = b[0] + i -# std = ft.pop_ast() -# -# assert std.match(ast) +def test_const_in_branch_1(): + with ft.VarDef([("x", (5,), "int32", "input", "cpu"), + ("y1", (4,), "int32", "output", "cpu"), + ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): + with ft.For("i", 0, 4) as i: + with ft.VarDef("b", (4,), "int32", "cache", 
"cpu") as b: + with ft.If(i == 2): + b[2] = x[2] + with ft.Else(): + b[i] = x[i] + x[i + 1] + y1[i] = b[i] * i + y2[i] = b[i] + i + ast = ft.pop_ast() + print(ast) + ast = ft.lower(ast) + print(ast) + + with ft.VarDef([("x", (5,), "int32", "input", "cpu"), + ("y1", (4,), "int32", "output", "cpu"), + ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): + with ft.For("i", 0, 4) as i: + with ft.VarDef("b", (1,), "int32", "cache", "cpu") as b: + with ft.If(i == 2): + b[0] = x[2] + with ft.Else(): + b[0] = x[i] + x[i + 1] + y1[i] = b[0] * i + y2[i] = b[0] + i + std = ft.pop_ast() + + assert std.match(ast) + + +def test_const_in_branch_2(): + with ft.VarDef([("x", (5,), "int32", "input", "cpu"), + ("y1", (4,), "int32", "output", "cpu"), + ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): + with ft.For("i", 0, 4) as i: + with ft.VarDef("b", (4,), "int32", "cache", "cpu") as b: + with ft.If(i < 3): + b[i] = x[i] + x[i + 1] + with ft.Else(): + b[3] = x[3] + y1[i] = b[i] * i + y2[i] = b[i] + i + ast = ft.pop_ast() + print(ast) + ast = ft.lower(ast) + print(ast) + + with ft.VarDef([("x", (5,), "int32", "input", "cpu"), + ("y1", (4,), "int32", "output", "cpu"), + ("y2", (4,), "int32", "output", "cpu")]) as (x, y1, y2): + with ft.For("i", 0, 4) as i: + with ft.VarDef("b", (1,), "int32", "cache", "cpu") as b: + with ft.If(i < 3): + b[0] = x[i] + x[i + 1] + with ft.Else(): + b[0] = x[3] + y1[i] = b[0] * i + y2[i] = b[0] + i + std = ft.pop_ast() + + assert std.match(ast) diff --git a/test/20.pass/test_simplify.py b/test/20.pass/test_simplify.py index 141d6749c..7052f0139 100644 --- a/test/20.pass/test_simplify.py +++ b/test/20.pass/test_simplify.py @@ -7,7 +7,7 @@ # test these two passes -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_const_fold(p): with ft.VarDef("y", (4,), "int32", "output", "cpu") as y: with ft.For("i", 0, 4) as i: @@ -24,7 +24,7 @@ def test_const_fold(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_partial_fold(p): # This is the case that we need a symbolic bound, instead # of using integers only @@ -45,7 +45,7 @@ def test_partial_fold(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_redundant_if(p): with ft.VarDef("y", (4,), "int32", "output", "cpu") as y: with ft.For("i", 0, 4) as i: @@ -63,7 +63,7 @@ def test_redundant_if(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_redundant_if_2(p): with ft.VarDef("y", (4,), "int32", "output", "cpu") as y: with ft.For("i", 0, 4) as i: @@ -81,7 +81,7 @@ def test_redundant_if_2(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_redundant_if_3(p): with ft.VarDef([("n", (), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (n, y): @@ -101,30 +101,30 @@ def test_redundant_if_3(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_int_max(p): - with ft.VarDef([("a", (5, 32), "int32", "input", "cpu"), - ("b", (5, 32), "int32", "output", "cpu")]) as (a, b): - with 
ft.For("i", 0, 5) as i: + with ft.VarDef([("a", (20, 64), "int32", "input", "cpu"), + ("b", (20, 64), "int32", "output", "cpu")]) as (a, b): + with ft.For("i", 0, 20) as i: with ft.For("j", 0, 2147483647) as j: - with ft.If(j < ft.min(-32 * (i % 4) + 100, 32)): + with ft.If(j < ft.min(-32 * (i % 4) + 100, 64)): b[i, j] = a[i, j] + 1 ast = ft.pop_ast(verbose=True) ast = p(ast) print(ast) - with ft.VarDef([("a", (5, 32), "int32", "input", "cpu"), - ("b", (5, 32), "int32", "output", "cpu")]) as (a, b): - with ft.For("i", 0, 5) as i: + with ft.VarDef([("a", (20, 64), "int32", "input", "cpu"), + ("b", (20, 64), "int32", "output", "cpu")]) as (a, b): + with ft.For("i", 0, 20) as i: with ft.For("j", 0, 2147483647) as j: - with ft.If(j < ft.min(-32 * (i % 4) + 100, 32)): + with ft.If(j < ft.min(-32 * (i % 4) + 100, 64)): b[i, j] = a[i, j] + 1 std = ft.pop_ast() assert std.match(ast) # Unchanged -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_redundant_min(p): with ft.VarDef("y", (4,), "int32", "output", "cpu") as y: with ft.For("i", 0, 4) as i: @@ -142,7 +142,7 @@ def test_redundant_min(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_redundant_max(p): with ft.VarDef("y", (4,), "int32", "output", "cpu") as y: with ft.For("i", 0, 4) as i: @@ -160,7 +160,7 @@ def test_redundant_max(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_multiple_mins_1(p): with ft.VarDef([("x", (4,), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (x, y): @@ -200,7 +200,7 @@ def test_multiple_mins_2(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_multiple_maxes_1(p): with ft.VarDef([("x", (4,), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (x, y): @@ -283,7 +283,7 @@ def test_multiple_mins_separted_by_scalar_op(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_precondition_from_if(p): with ft.VarDef([ ("x1", (4,), "int32", "input", "cpu"), @@ -314,7 +314,7 @@ def test_precondition_from_if(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_multiple_preconditions_from_if(p): with ft.VarDef([ ("x1", (4,), "int32", "input", "cpu"), @@ -345,7 +345,7 @@ def test_multiple_preconditions_from_if(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_precondition_from_assert(p): with ft.VarDef([ ("x1", (4,), "int32", "input", "cpu"), @@ -372,7 +372,7 @@ def test_precondition_from_assert(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_assert_false(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -384,7 +384,7 @@ def test_assert_false(p): ast = p(ast) 
-@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_unreachable_assert_false(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -436,7 +436,7 @@ def test_precondition_from_sign_type(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_different_scope(p): with ft.VarDef([ ("x", (4, 10), "int32", "input", "cpu"), @@ -478,7 +478,7 @@ def test_different_scope(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_dynamic(p): with ft.VarDef([("n", (), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (n, y): @@ -498,7 +498,7 @@ def test_dynamic(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_floor_div_1(p): with ft.VarDef([("n", (), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (n, y): @@ -518,7 +518,7 @@ def test_floor_div_1(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_floor_div_2(p): with ft.VarDef([("n", (), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (n, y): @@ -538,7 +538,7 @@ def test_floor_div_2(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_floor_div_3(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -555,7 +555,7 @@ def test_floor_div_3(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_floor_div_4(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -572,7 +572,7 @@ def test_floor_div_4(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_floor_div_5(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -589,7 +589,7 @@ def test_floor_div_5(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_floor_div_6(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -606,7 +606,7 @@ def test_floor_div_6(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_mod_1(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -623,7 +623,7 @@ def test_mod_1(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_mod_2(p): with ft.VarDef([("x", (), "int32", "input", "cpu"), ("y", (), "int32", "output", "cpu")]) as (x, y): @@ -646,7 +646,7 @@ def test_mod_2(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', 
[ft.pb_simplify, ft.simplify]) def test_divisible_div(p): with ft.VarDef([("a", (), "int32", "input", "cpu"), ("b", (), "int32", "input", "cpu"), @@ -665,7 +665,7 @@ def test_divisible_div(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_divisible_mod(p): with ft.VarDef([("a", (), "int32", "input", "cpu"), ("b", (), "int32", "input", "cpu"), @@ -684,7 +684,7 @@ def test_divisible_mod(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_reduce_fraction_for_div(p): with ft.VarDef([("a", (), "int32", "input", "cpu"), ("b", (), "int32", "input", "cpu"), @@ -705,7 +705,7 @@ def test_reduce_fraction_for_div(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_not_reduce_fraction_for_mod(p): with ft.VarDef([("a", (), "int32", "input", "cpu"), ("b", (), "int32", "input", "cpu"), @@ -726,7 +726,7 @@ def test_not_reduce_fraction_for_mod(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_simplify_not_cmp(p): with ft.VarDef([ ("x", (4,), "int32", "input", "cpu"), @@ -769,7 +769,7 @@ def test_simplify_not_cmp(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_simplify_not_logic_op(p): with ft.VarDef([("x", (4,), "int32", "input", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (x, y): @@ -794,7 +794,7 @@ def test_simplify_not_logic_op(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_min_minus_min(p): with ft.VarDef([ ("x", (), "int32", "input", "cpu"), @@ -817,7 +817,7 @@ def test_min_minus_min(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_min_max_as_bound(p): with ft.VarDef([("l", (), "int32", "input", "cpu"), ("r", (), "int32", "input", "cpu")]) as (l, r): @@ -841,7 +841,7 @@ def test_min_max_as_bound(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_accessible_after_writing_if(p): with ft.VarDef([("x", (4,), "int32", "inout", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (x, y): @@ -867,7 +867,7 @@ def test_accessible_after_writing_if(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify, ft.z3_simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_accessible_after_writing_for(p): with ft.VarDef([("x", (4,), "int32", "inout", "cpu"), ("y", (4,), "int32", "output", "cpu")]) as (x, y): @@ -896,7 +896,7 @@ def test_accessible_after_writing_for(p): assert std.match(ast) -@pytest.mark.parametrize('p', [ft.simplify]) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) def test_loop_length_0_or_1(p): with ft.VarDef("n", (), "int32", "input", "cpu") as n: with ft.Assert(n[()] <= 1): diff --git a/test/21.autograd/test_output_intermediates.py b/test/21.autograd/test_output_intermediates.py index e5fd04e17..5a0b2f923 100644 --- a/test/21.autograd/test_output_intermediates.py +++ b/test/21.autograd/test_output_intermediates.py @@ -296,7 +296,7 
@@ def expected(n: ft.Var[(), "int32"], x, y, z_tape): #! label: V_z z = ft.empty((), "float32") z[()] = x[bn, pn] + 1 - z_tape[bn * (n * n) + pn + 1 - 1] = z[()] + z_tape[bn * (n * n) + pn] = z[()] y[pn % n] += z[()] * z[()] assert expected.body.match(ast) diff --git a/test/30.schedule/test_cache.py b/test/30.schedule/test_cache.py index c12e50bf6..bfeabfaca 100644 --- a/test/30.schedule/test_cache.py +++ b/test/30.schedule/test_cache.py @@ -312,13 +312,14 @@ def test_cache_with_multiple_conditions(): def test_fill_is_necessary_when_possibly_not_written(): - with ft.VarDef([("x", (2, 4), "int32", "input", "cpu"), - ("y", (2, 4), "int32", "output", "cpu")]) as (x, y): + with ft.VarDef([("x", (3, 4), "int32", "input", "cpu"), + ("y", (3, 4), "int32", "output", "cpu")]) as (x, y): with ft.For("i", 0, 2, label="L1") as i: with ft.For("j", 0, 4, label="L2") as j: with ft.If(i == 0): - y[0, j] = x[0, j] - y[1, j] = x[1, j] + y[1, j] = x[0, j] + y[0, j] = x[1, j] + y[2, j] = x[1, j] ast = ft.pop_ast(verbose=True) s = ft.Schedule(ast) s.cache(s.find("L2").body, "y", "cpu") @@ -326,17 +327,18 @@ def test_fill_is_necessary_when_possibly_not_written(): print(ast) ast = ft.lower(ast, verbose=1) - with ft.VarDef([("x", (2, 4), "int32", "input", "cpu"), - ("y", (2, 4), "int32", "output", "cpu")]) as (x, y): + with ft.VarDef([("x", (3, 4), "int32", "input", "cpu"), + ("y", (3, 4), "int32", "output", "cpu")]) as (x, y): with ft.For("i", 0, 2, label="L1") as i: with ft.For("j", 0, 4, label="L2") as j: - with ft.VarDef("b", (2, 1), "int32", "cache", "cpu") as b: - with ft.For("k", 0, 2) as k: + with ft.VarDef("b", (3, 1), "int32", "cache", "cpu") as b: + with ft.For("k", 0, 3) as k: b[k, 0] = y[k, j] # This statement is necessary with ft.If(i == 0): - b[0, 0] = x[0, j] - b[1, 0] = x[1, j] - with ft.For("k", 0, 2) as k: + b[1, 0] = x[0, j] + b[0, 0] = x[1, j] + b[2, 0] = x[1, j] + with ft.For("k", 0, 3) as k: y[k, j] = b[k, 0] std = ft.pop_ast()
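A quick end-to-end sketch of the behavior this patch enables, mirroring test_presburger_bounds above (this driver script is not part of the patch; only `ft.pb_simplify`, `ft.lower`, and the AST builders shown in the tests are assumed, and the 1-D shape is our own simplification of the 2-D test):

import freetensor as ft

# A 1-D reduction of test_presburger_bounds: the effective range of `i` is
# only expressible as a Presburger condition (i // 16 == i0), not as a single
# linear inequality.
with ft.VarDef([("x", (128,), "int32", "input", "cpu"),
                ("y", (128,), "int32", "output", "cpu")]) as (x, y):
    with ft.For("i0", 0, 8) as i0:
        with ft.For("i", 0, 128) as i:
            with ft.If(i // 16 == i0):
                y[i] = x[i] * 2
ast = ft.pop_ast()

# ft.pb_simplify exercises CompUniqueBoundsPB directly; ft.lower additionally
# runs pass/shrink_for, which, per test_presburger_bounds, should now tighten
# the inner loop to 16 * i0 <= i < 16 * i0 + 16.
print(ft.pb_simplify(ast))
print(ft.lower(ast))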