[CINN] Add the TileTransposeTactic #70942

Open · wants to merge 1 commit into develop
21 changes: 16 additions & 5 deletions paddle/cinn/backends/codegen_gpu_dev.cc
@@ -95,12 +95,23 @@ std::vector<ir::stmt::StmtRef> CodeGenGpuDev::GenerateBufferAliasStmts(
}

for (auto &t : unique_tensors) {
- auto data_type = t->type();
- auto data_ptr_type = data_type;
- data_ptr_type.set_cpp_handle();
+ auto tensor_type = t->type();
+ auto tensor_ptr_type = tensor_type;
+ tensor_ptr_type.set_cpp_handle();
+
+ auto buffer_type = t->buffer->dtype;
+ auto buffer_ptr_type = buffer_type;
+ buffer_ptr_type.set_cpp_handle();
+
+ Expr t_var = Var(t->name, tensor_ptr_type);
+ Expr buf_var = Var(t->buffer->name, buffer_ptr_type);
+
+ // A tensor and its buffer may have different types when multiple tensors
+ // share the same buffer. In this case, add a Cast before aliasing.
+ if (tensor_type != buffer_type) {
+   buf_var = common::cast(buf_var, tensor_ptr_type);
+ }

- Var t_var(t->name, data_ptr_type);
- Var buf_var(t->buffer->name, data_ptr_type);
buffer_alias.push_back(ir::stmt::Let(t_var, buf_var));
}

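For illustration only, the alias statements this generates look roughly like the following when two tensors with different element types share one buffer (the variable and buffer names here are invented, and the exact pointer spelling depends on the codegen):

// Hypothetical emitted aliases: a float tensor and an int32 tensor sharing
// one float buffer. The second alias now casts before binding.
float* var_fp32 = _shared_buf;                   // tensor_type == buffer_type
int32_t* var_int32 = ((int32_t*)(_shared_buf));  // Cast inserted before aliasing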
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
@@ -21,6 +21,7 @@
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/common/enforce.h"
@@ -36,6 +37,7 @@ void DynamicShapeGroupScheduler::Init() {
InitBuckets();
tactics_.emplace_back(CreateAlignIterSpaceTactic());
tactics_.emplace_back(CreateTileBroadcastTactic());
tactics_.emplace_back(CreateTileTransposeTactic());
tactics_.emplace_back(CreateTileFirstGeneralTactic());
tactics_.emplace_back(CreateComputeInlineTactic());
tactics_.emplace_back(CreateComputeAtReductionTactic());
1 change: 1 addition & 0 deletions paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
@@ -8,4 +8,5 @@ gather_srcs(cinnapi_src SRCS compute_at_reduction_tactic.cc)
gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
gather_srcs(cinnapi_src SRCS tile_broadcast_tactic.cc)
gather_srcs(cinnapi_src SRCS tile_transpose_tactic.cc)
gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc)
348 changes: 348 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.cc
@@ -0,0 +1,348 @@
// Copyright (c) 2025 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/runtime/intrinsic.h"

PD_DECLARE_bool(cinn_enable_tile_transpose);

namespace cinn {
namespace ir {
namespace {

class TileTransposeTactic final : public ScheduleTactic {
public:
void Init(ScheduleContext* context, ir::IRSchedule* sch) override;

void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

std::string TacticName() const override { return "TileTransposeTactic"; }

private:
struct Candidate {
// The transposed load to do CacheRead.
ir::Expr load;

// The block where this load first appears. We will do CacheRead on this
// block, and later blocks will simply reuse the first load's result.
std::string first_appear_block_id;

// The buffer index of this load in the first block where it appears.
int buffer_index;
};

void InitCandidates(ir::IRSchedule* sch);

void InitAxisInfo();

bool CandidateExists(const ir::Expr& load);

/**
* Create cache blocks for the `buffer_index`-th load in the block, and do
* tiling for the created cache blocks. This doesn't tile the block itself.
*/
void TileCacheBlock(ir::IRSchedule* sch,
const std::string& block_id,
int buffer_index);

/**
* Do tiling for the block itself.
*/
void TileBlock(ir::IRSchedule* sch, const std::string& block_id);

void CanonicalizeLayout(ir::IRSchedule* sch, const std::string& block_id);

void FuseAndBind(ir::IRSchedule* sch,
const std::string& block_id,
bool need_sync = false);

private:
ScheduleContext* context_;
bool can_apply_;

std::vector<int> common_perm_;

std::vector<int> high_axis_;
std::vector<int> src_low_axis_;
std::vector<int> dst_low_axis_;

// Map from the candidate loads' tensor names to the corresponding Candidate
// structs.
// Note: the same tensor name doesn't necessarily refer to the same load,
// because the load indices may differ. Therefore, each tensor name maps to a
// list of Candidates for loads of that tensor with different indices.
std::unordered_map<std::string, std::vector<Candidate>> tensor2candidates_;

// Map from each block's id to the candidates in the block.
std::unordered_map<std::string, std::vector<Candidate>> block2candidates_;
};

std::vector<int> GetTransposePerm(const std::vector<ir::Expr>& store_indices,
const std::vector<ir::Expr>& load_indices) {
int data_rank = store_indices.size();
if (load_indices.size() != data_rank) return {};
std::vector<int> perm(data_rank);

for (int i = 0; i < data_rank; ++i) {
ir::Expr index = load_indices[i];
if (!index.is_var()) return {};
auto it = std::find(store_indices.begin(), store_indices.end(), index);
if (it == store_indices.end()) return {};
perm[it - store_indices.begin()] = i;
}
return perm;
}
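To make the returned convention concrete, here is a small worked example (the index names are illustrative):

// Example: store B[i, j, k], load A[k, i, j].
//   load axis 0 (k) matches store axis 2  ->  perm[2] = 0
//   load axis 1 (i) matches store axis 0  ->  perm[0] = 1
//   load axis 2 (j) matches store axis 1  ->  perm[1] = 2
// Result: perm = {1, 2, 0}; perm[d] is the load axis whose index appears at
// store axis d.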

std::vector<int> OffsetVec(const std::vector<int>& vec, int offset) {
std::vector<int> new_vec = vec;
for (auto& e : new_vec) e += offset;
return new_vec;
}

std::vector<int> ArangeVec(int count, int begin = 0) {
std::vector<int> vec(count);
std::iota(vec.begin(), vec.end(), begin);
return vec;
}

void TileTransposeTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
context_ = context;
can_apply_ = false;
if (!FLAGS_cinn_enable_tile_transpose) return;

ir::Expr module_root = sch->GetModule().GetExprs().front();
ir::Expr root_block = ir::analyzer::GetRootSBlock(module_root);
auto* root_node = root_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();

if (root_node->attrs.count(kTileMethod) > 0) return;
if (!context->config.base_info->reduce_axis.empty()) return;

InitCandidates(sch);

VLOG(4) << "common_perm: " << utils::Join(common_perm_, ", ");
if (common_perm_.empty()) return;

can_apply_ = true;
root_node->attrs[kTileMethod] = TacticName();

InitAxisInfo();
}

bool TileTransposeTactic::CandidateExists(const ir::Expr& load) {
auto& tensor_name = load.As<ir::Load>()->tensor.as_tensor()->name;
auto it = tensor2candidates_.find(tensor_name);
if (it == tensor2candidates_.end()) return false;
for (auto& candidate : it->second) {
if (candidate.load == load) return true;
}
return false;
}

void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
common_perm_.clear();
tensor2candidates_.clear();
block2candidates_.clear();

for (auto& block : sch->GetAllBlocks()) {
std::vector<ir::Expr> loops = sch->GetLoops(block);
std::string block_id = ir::analyzer::GetBlockName(block);
ir::Expr store = ir::analyzer::GetStoreOfSBlock(block);
store = ir::analyzer::ExpandIterVar(store, block);
store = ir::analyzer::CanonicalizeLoopVar(store, loops);

std::vector<ir::Expr> loads = ir::ir_utils::CollectIRNodesInOrder(
store.As<ir::Store>()->value,
[](const ir::Expr* x) { return x->As<ir::Load>(); });

for (int i = 0; i < loads.size(); ++i) {
ir::Expr load = loads[i];
auto* tensor = load.As<ir::Load>()->tensor.as_tensor();
if (sch->HasBlock(tensor->name)) continue;

if (CandidateExists(load)) continue;

std::vector<int> perm = GetTransposePerm(store.As<ir::Store>()->indices,
load.As<ir::Load>()->indices);
VLOG(4) << "GetTransposePerm on load [" << i << "]: " << load
<< " perm: " << utils::Join(perm, ", ");

// Skip if this is not a full-rank load (an unsupported access such as
// reshape or slice), or if the innermost axis is not actually transposed.
if (perm.size() != loops.size()) continue;
if (perm.back() == perm.size() - 1) continue;

if (common_perm_.empty()) {
common_perm_ = perm;
} else if (common_perm_ != perm) {
common_perm_.clear();
return;
}

Candidate candidate{load, block_id, i};
tensor2candidates_[tensor->name].push_back(candidate);
block2candidates_[block_id].push_back(candidate);
}
}
}

void TileTransposeTactic::InitAxisInfo() {
std::set<int> src_low_axis;
std::set<int> dst_low_axis;
std::set<int> high_axis;

dst_low_axis.insert(common_perm_.size() - 1);
for (int i = common_perm_.size() - 2; i >= 0; --i) {
if (common_perm_[i] + 1 != common_perm_[i + 1]) break;
dst_low_axis.insert(i);
}

for (int i = 0; i < common_perm_.size(); ++i) {
if (common_perm_[i] == common_perm_.size() - 1) {
src_low_axis.insert(i);
for (int j = i - 1; j >= 0; j--) {
if (common_perm_[j] + 1 != common_perm_[j + 1]) break;
src_low_axis.insert(j);
}
}
}

for (int i = 0; i < common_perm_.size(); ++i) high_axis.insert(i);
for (auto i : src_low_axis) high_axis.erase(i);
for (auto i : dst_low_axis) high_axis.erase(i);

high_axis_.assign(high_axis.begin(), high_axis.end());
src_low_axis_.assign(src_low_axis.begin(), src_low_axis.end());
dst_low_axis_.assign(dst_low_axis.begin(), dst_low_axis.end());
}
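Worked through on a concrete permutation (the values are illustrative), the axis classification comes out as follows:

// Example: common_perm_ = {0, 2, 3, 1}
//   dst_low_axis_ = {3}     trailing axes whose perm values form a consecutive
//                           run ending at the last axis
//   src_low_axis_ = {1, 2}  the axis mapped to the last source axis, plus the
//                           consecutive run of axes leading up to it
//   high_axis_    = {0}     all remaining axes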

void TileTransposeTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
if (!can_apply_) return;

for (auto& candidate : block2candidates_[block_id]) {
if (candidate.first_appear_block_id == block_id) {
TileCacheBlock(sch, block_id, candidate.buffer_index);
}
}

TileBlock(sch, block_id);
VLOG(4) << "After TileTransposeTactic on [" << block_id
<< "]: " << sch->GetModule().GetExprs().front();
}

void TileTransposeTactic::TileCacheBlock(ir::IRSchedule* sch,
const std::string& block_id,
int buffer_index) {
// Step 1. Create shared and local cache blocks for the load, and reorder
// their loops by the transpose permutation.
ir::Expr shared_cache_block =
sch->CacheRead(sch->GetBlock(block_id), buffer_index, "shared");
std::string shared_cache_block_id =
ir::analyzer::GetBlockName(shared_cache_block);

ir::Expr local_cache_block =
sch->CacheRead(sch->GetBlock(block_id), buffer_index, "local");
std::string local_cache_block_id =
ir::analyzer::GetBlockName(local_cache_block);

sch->Reorder(shared_cache_block_id, common_perm_);
sch->Reorder(local_cache_block_id, common_perm_);

// Step 2. Register the cache blocks as outputs and annotate their transpose
// stages.
context_->output_names.insert(shared_cache_block_id);
context_->output_names.insert(local_cache_block_id);

sch->Annotate(sch->GetBlock(shared_cache_block_id), "transpose_stage", 0);
sch->Annotate(sch->GetBlock(local_cache_block_id), "transpose_stage", 1);

// Step 3. Convert the layout to [high_axis, src_low_axis, dst_low_axis]
CanonicalizeLayout(sch, shared_cache_block_id);
CanonicalizeLayout(sch, local_cache_block_id);

// Step 4. The core tiling: split the two fused low axes so that each 32x32
// tile is covered by 8x32 threads, with each thread handling 4 elements.
int offset = high_axis_.size();
sch->Split(shared_cache_block_id, offset + 1, {-1, 4, 8});
sch->Split(shared_cache_block_id, offset, {-1, 32});

sch->Split(local_cache_block_id, offset + 1, {-1, 32});
sch->Split(local_cache_block_id, offset, {-1, 4, 8});

sch->Reorder(shared_cache_block_id, OffsetVec({0, 2, 3, 4, 1}, offset));
sch->Reorder(local_cache_block_id, OffsetVec({0, 3, 1, 2, 4}, offset));

// Step 5. Fuse the outer loops and bind them to GPU blocks and threads (with
// thread synchronization).
FuseAndBind(sch, shared_cache_block_id, /* need_sync = */ true);
FuseAndBind(sch, local_cache_block_id, /* need_sync = */ true);
}
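Following the splits and reorders above, the shared cache block ends up with roughly this loop structure (a sketch assuming the fused low axes are multiples of 32; symbolic outer extents are elided):

// for (fused)              // high axes + tile indices     -> blockIdx.x
//   for (d0 in 0..4)       // serial: 4 elements per thread
//     for (d1 in 0..8)     //                               -> threadIdx.y
//       for (d2 in 0..32)  //                               -> threadIdx.x
//         shared_cache[...] = input[...];  // stages a 32x32 tile in shared memory

The local cache block is tiled with the mirrored split, so its threadIdx.x runs along the other low axis.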

void TileTransposeTactic::TileBlock(ir::IRSchedule* sch,
const std::string& block_id) {
CanonicalizeLayout(sch, block_id);

int offset = high_axis_.size();
sch->Split(block_id, offset + 1, {-1, 32});
sch->Split(block_id, offset, {-1, 4, 8});

sch->Reorder(block_id, OffsetVec({0, 3, 1, 2, 4}, offset));

FuseAndBind(sch, block_id);

if (context_->output_names.count(block_id) == 0) {
ir::Expr block = sch->GetBlock(block_id);
sch->SetBuffer(block, "local");
}
}

void TileTransposeTactic::CanonicalizeLayout(ir::IRSchedule* sch,
const std::string& block_id) {
std::vector<int> order = high_axis_;
order.insert(order.end(), src_low_axis_.begin(), src_low_axis_.end());
order.insert(order.end(), dst_low_axis_.begin(), dst_low_axis_.end());

sch->Reorder(block_id, order);

std::vector<int> src_low_axis =
ArangeVec(src_low_axis_.size(), high_axis_.size());
std::vector<int> dst_low_axis =
ArangeVec(dst_low_axis_.size(), high_axis_.size() + src_low_axis_.size());

sch->Fuse(block_id, dst_low_axis);
sch->Fuse(block_id, src_low_axis);
}
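Continuing the example permutation from above, CanonicalizeLayout rearranges the loops as follows (a sketch, not the exact IR):

// Example: common_perm_ = {0, 2, 3, 1}, so high = {0}, src_low = {1, 2},
// dst_low = {3}.
//   Reorder to [high, src_low, dst_low] = [0, 1, 2, 3]  (already canonical here)
//   Fuse dst_low ({3})    -> a single axis, effectively a no-op
//   Fuse src_low ({1, 2}) -> merged into one loop
// Resulting loops: [high, fused src_low, dst_low]

Fusing dst_low before src_low keeps the src_low loop positions valid while fusing.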

void TileTransposeTactic::FuseAndBind(ir::IRSchedule* sch,
const std::string& block_id,
bool need_sync) {
int offset = high_axis_.size();
sch->Fuse(block_id, ArangeVec(offset + 2));

std::vector<ir::Expr> loops = sch->GetLoops(block_id);
sch->Bind(loops[0], "blockIdx.x");
sch->Bind(loops[2], "threadIdx.y");
sch->Bind(loops[3], "threadIdx.x");

if (need_sync) {
sch->SyncThreads(sch->GetLoops(block_id)[0], /* after_node = */ false);
}
}

} // namespace

std::unique_ptr<ScheduleTactic> CreateTileTransposeTactic() {
return std::make_unique<TileTransposeTactic>();
}

} // namespace ir
} // namespace cinn