Skip to content

Commit

Permalink
[BACKEND] Fix crash in coalesce pass with blocked ptr (#3866)
Browse files Browse the repository at this point in the history
The `setCoalescedEncoding` function can only handle operations whose
'mem access ptr' has type `RankedTensorType`:
```
  void
  setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op,
                       int numWarps, int threadsPerWarp,
                       llvm::MapVector<Operation *, Attribute> &layoutMap) {
    Value ptr = getMemAccessPtr(op);
    auto refTensorType = cast<RankedTensorType>(ptr.getType());
```
Therefore the caller in `runOnOperation` should avoid calling it when
the 'mem access ptr' does not have `RankedTensorType` (otherwise the
cast in the callee will fail).

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
  • Loading branch information
etiotto authored May 15, 2024
1 parent 4faa131 commit 25b4212
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
8 changes: 3 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Value ptr = getMemAccessPtr(curr);
if (!ptr)
return;
// We only convert `tensor<tt.ptr<>>` or `tt.ptr<tensor<>>` load/store
bool isPtrTensor = false, isTensorPointer = false;
// We only convert `tensor<tt.ptr<>>` load/store
bool isPtrTensor = false;
if (auto tensorType = dyn_cast<RankedTensorType>(ptr.getType()))
isPtrTensor = isa<PointerType>(tensorType.getElementType());
if (auto ptrType = dyn_cast<PointerType>(ptr.getType()))
isTensorPointer = isa<RankedTensorType>(ptrType.getPointeeType());
if (!isPtrTensor && !isTensorPointer)
if (!isPtrTensor)
return;
auto mod = curr->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
Expand Down
13 changes: 13 additions & 0 deletions test/TritonGPU/coalesce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,16 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr<f32> {tt.divisibility = 16
}

}

// -----

// COM: Reproducer for issue #3866
// CHECK-LABEL: @test_3866
// CHECK: tt.load {{.*}} : !tt.ptr<tensor<64x16xf16>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func public @test_3866(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i64) {
%0 = tt.make_tensor_ptr %arg0, [%arg2, %arg2], [%arg2, %arg2], [%arg1, %arg1] {order = array<i32: 1, 0>} : <tensor<64x16xf16>>
%1 = tt.load %0 : !tt.ptr<tensor<64x16xf16>>
tt.return
}
}

0 comments on commit 25b4212

Please sign in to comment.