Skip to content

Commit

Permalink
[BACKEND] Optimize code style in rewrite-tensor-pointer and add more …
Browse files Browse the repository at this point in the history
…tests (#4724)

The core Triton team is a small number of people, and we receive many PRs
(thank you!). To help us review your code more quickly, **if you are a new
contributor (fewer than 3 PRs merged) we ask that you complete the following
tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python
    code and using the instructions it generates is not minimal.)

---

Hello, maintainers and reviewers!

While reading the
[RewriteTensorPointer.cpp](https://github.com/triton-lang/triton/blob/main/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp)
pass, I noticed that the current implementation is somewhat redundant
and that its test is hard to understand, so I submitted this PR.

PR description:

- Use `llvm::make_early_inc_range` so that ops can be safely erased or
replaced while a block is being visited, instead of first copying the
block's ops into a separate vector

```cpp
    for (auto &region : op->getRegions()) {
      for (auto &block : region) {
        SmallVector<Operation *> blockCopy;
        for (auto &nestedOp : block)
          blockCopy.push_back(&nestedOp);
        for (auto &nestedOp : blockCopy) {
          if (auto newOp = rewriteOp(nestedOp, eraser))

-> 

    for (Region &region : op->getRegions()) {
      for (Block &block : region) {
        for (Operation &nestedOp : llvm::make_early_inc_range(block)) {
          if (auto newOp = rewriteOp(&nestedOp, eraser)) {
            visitOperation(newOp, eraser);
          }
```

- Update the operand list in place through the `oldOperands` parameter
instead of constructing and returning a new `SmallVector`.

```cpp
  static SmallVector<Value>
  generateNewOperands(const SmallVector<Value> &oldOperands, unsigned index,
                      const SmallVector<Value> &newValues) {

->

  static void generateNewOperands(SmallVector<Value> &oldOperands,
                                  unsigned index, ArrayRef<Value> newValues) {
```


- Delete some dead code.

- Add detailed tests; see `test/Triton/rewrite-tensor-pointer.mlir`.

Co-authored-by: Keren Zhou <[email protected]>
  • Loading branch information
tfruan2000 and Jokeren authored Sep 13, 2024
1 parent 5000e32 commit 1f5dc71
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 117 deletions.
64 changes: 27 additions & 37 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <memory>
#include <stack>

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
Expand Down Expand Up @@ -171,10 +173,7 @@ struct RewritedInfo {
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);

// Set zero padding value
TypedAttr attr =
elementType.isIntOrIndex()
? cast<TypedAttr>(builder.getIntegerAttr(elementType, 0))
: cast<TypedAttr>(builder.getFloatAttr(elementType, 0));
TypedAttr attr = builder.getZeroAttr(elementType);

// Float NaN padding case
if (padding.value() == triton::PaddingOption::PAD_NAN) {
Expand Down Expand Up @@ -209,18 +208,20 @@ class RewriteTensorPointerPass
});
}

static SmallVector<Value>
generateNewOperands(const SmallVector<Value> &oldOperands, unsigned index,
const SmallVector<Value> &newValues) {
assert(index < oldOperands.size());
SmallVector<Value> newOperands;
for (int i = 0; i < index; ++i)
newOperands.push_back(oldOperands[i]);
for (auto value : newValues)
newOperands.push_back(value);
for (auto i = index + 1; i < oldOperands.size(); ++i)
newOperands.push_back(oldOperands[i]);
return newOperands;
// Replaces, in place, the operand at position `index` of `oldOperands` with
// the sequence `newValues`. Operands before and after `index` keep their
// relative order, so the resulting size is `size - 1 + newValues.size()`.
// `index` must be a valid position in `oldOperands` (checked by assert).
static void generateNewOperands(SmallVector<Value> &oldOperands,
unsigned index, ArrayRef<Value> newValues) {
size_t size = oldOperands.size();
assert(index < size);
// Snapshot the original operands: `oldOperands` is cleared and rebuilt below.
SmallVector<Value> operands = oldOperands;
oldOperands.reserve(size - 1 + newValues.size());
oldOperands.clear();
// Copy back the operands that precede `index`, if any.
if (index != 0) {
oldOperands.append(operands.begin(), operands.begin() + index);
}
// The operand at `index` is dropped and `newValues` takes its place.
oldOperands.append(newValues.begin(), newValues.end());
// Copy back the operands that follow `index`, if any.
if (index != size - 1) {
oldOperands.append(operands.begin() + index + 1, operands.end());
}
}

Operation *rewriteMakeTensorPtrOp(OpBuilder &builder,
Expand Down Expand Up @@ -358,7 +359,7 @@ class RewriteTensorPointerPass
}
auto rematerialize = [&](Block *block) {
for (Operation &opInIf : block->getOperations()) {
auto newOp = builder.clone(opInIf, mapping);
builder.clone(opInIf, mapping);
}
};
builder.setInsertionPointToStart(newOp.thenBlock());
Expand Down Expand Up @@ -403,8 +404,7 @@ class RewriteTensorPointerPass
// Expand the tensor pointer into offsets
assert(rewritedInfo.count(newIterOperands[i]));
auto info = rewritedInfo[newIterOperands[i]];
newIterOperands =
generateNewOperands(newIterOperands, i, info.getOffsets());
generateNewOperands(newIterOperands, i, info.getOffsets());
i += info.length() - 1;
size += info.length() - 1;
}
Expand Down Expand Up @@ -439,9 +439,7 @@ class RewriteTensorPointerPass
// Clone body
builder.setInsertionPointToStart(newForOp.getBody());
for (auto &opInFor : *op.getBody()) {
auto *newOp = builder.clone(opInFor, mapping);
for (unsigned i = 0; i < opInFor.getNumResults(); ++i)
mapping.map(opInFor.getResult(i), newOp->getResult(i));
builder.clone(opInFor, mapping);
}

// Replace later usages
Expand Down Expand Up @@ -476,7 +474,7 @@ class RewriteTensorPointerPass

assert(rewritedInfo.count(newOperands[i]));
auto info = rewritedInfo[newOperands[i]];
newOperands = generateNewOperands(newOperands, i, info.getOffsets());
generateNewOperands(newOperands, i, info.getOffsets());
i += info.length() - 1;
size += info.length() - 1;
}
Expand All @@ -492,15 +490,13 @@ class RewriteTensorPointerPass
// Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers
// Rewriting functions return the next operation to visit, if there is no
// next one, simply return `nullptr`
std::pair<Value, RewritedInfo> rewrited;
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser);
} else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
return rewriteAdvanceOp(builder, advanceOp, eraser);
} else if (isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op)) {
return rewriteLoadStoreOp(builder, op, eraser);
} else if (op->getDialect()->getNamespace() == "scf" ||
op->getDialect()->getNamespace() == "cf") {
} else if (isa<scf::SCFDialect, cf::ControlFlowDialect>(op->getDialect())) {
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
return rewriteIfOp(builder, ifOp, eraser);
}
Expand All @@ -524,18 +520,12 @@ class RewriteTensorPointerPass
}

void visitOperation(Operation *op, std::stack<Operation *> &eraser) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// We need an extra copy because erasing operations may break the
// iterator behavior
SmallVector<Operation *> blockCopy;
for (auto &nestedOp : block)
blockCopy.push_back(&nestedOp);

// Rewrite and recursively visit
for (auto &nestedOp : blockCopy) {
if (auto newOp = rewriteOp(nestedOp, eraser))
for (Region &region : op->getRegions()) {
for (Block &block : region) {
for (Operation &nestedOp : llvm::make_early_inc_range(block)) {
if (auto newOp = rewriteOp(&nestedOp, eraser)) {
visitOperation(newOp, eraser);
}
}
}
}
Expand Down
Loading

0 comments on commit 1f5dc71

Please sign in to comment.