[Performance] Improve the code generated by the RewriteTensorPointer pass. #1766

Open
mfrancepillois opened this issue Aug 2, 2024 · 1 comment · Fixed by #2181 · May be fixed by #2359

@mfrancepillois (Contributor)

When a Triton::MakeTensorPtrOp has to be rewritten by the RewriteTensorPointer pass to use "regular" memory operations, the generated code seems less performant than code written directly with regular operations.
This is because the Triton::AdvanceOp operations are used as anchors to generate the new memory accesses, which causes the entire pointer-computation code to end up inside the loop, even though a significant part of these instructions could be hoisted out of the loop.
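
For reference, here is a minimal Triton sketch of a tutorial-03-style matmul written with block pointers (not the exact tutorial source; the kernel name, argument names, and block sizes are illustrative). tl.make_block_ptr and tl.advance are the source-level constructs that lower to Triton::MakeTensorPtrOp and Triton::AdvanceOp, so this is the kind of kernel the pass has to rewrite:

import triton
import triton.language as tl

@triton.jit
def matmul_block_ptr_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                            stride_am, stride_ak, stride_bk, stride_bn,
                            stride_cm, stride_cn,
                            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers carry (base, shape, strides, offsets, block_shape, order).
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_M, 0),
                                    block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(a, b)
        # These tl.advance calls become the tt.advance anchors mentioned above.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_block_ptr, acc.to(tl.float16), boundary_check=(0, 1))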

For example, the relevant section of the TritonGPU MLIR for tutorial 03 (which uses regular pointers) looks like:

%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
%17 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%18 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%19 = arith.addi %17, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
%20 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
%21 = tt.splat %arg3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%22 = arith.remsi %19, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%23 = arith.muli %13, %c256_i32 : i32
%24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%30 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
%31 = arith.muli %29, %30 : tensor<128x1xi32, #blocked1>
%32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%34 = tt.broadcast %31 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1>
%37 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%39 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
%41 = tt.splat %arg7 : i32 -> tensor<64x1xi32, #blocked>
%42 = arith.muli %40, %41 : tensor<64x1xi32, #blocked>
%43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
%44 = tt.broadcast %42 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked>
%45 = tt.broadcast %43 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked>
%46 = arith.addi %44, %45 : tensor<64x256xi32, #blocked>
%47 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked>
%48 = tt.addptr %47, %46 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
%49 = arith.addi %arg5, %c63_i32 : i32
%50 = arith.divsi %49, %c64_i32 : i32
%51 = arith.muli %arg7, %c64_i32 : i32
%52 = tt.splat %51 : i32 -> tensor<64x256xi32, #blocked>
%53:3 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %48) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>)  : i32 {
      %72 = arith.muli %arg9, %c64_i32 : i32 
      %73 = arith.subi %arg5, %72 : i32 
      %74 = tt.splat %73 : i32 -> tensor<1x64xi32, #blocked1> 
      %75 = arith.cmpi slt, %33, %74 : tensor<1x64xi32, #blocked1>
      %76 = tt.broadcast %75 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> 
      %77 = tt.load %arg11, %76, %cst_1 : tensor<128x64x!tt.ptr<f16>, #blocked1> 
      %78 = triton_gpu.local_alloc %77 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> 
      %79 = tt.splat %73 : i32 -> tensor<64x1xi32, #blocked> 
      %80 = arith.cmpi slt, %40, %79 : tensor<64x1xi32, #blocked> 
      %81 = tt.broadcast %80 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> 
      %82 = tt.load %arg12, %81, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>
      %83 = triton_gpu.local_alloc %82 : (tensor<64x256xf16, #blocked>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory>
      %84 = triton_gpu.local_load %78 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %85 = triton_gpu.local_load %83 : !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma> 
      %87 = tt.addptr %arg11, %cst_2 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> 
      %88 = tt.addptr %arg12, %52 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked> 
      scf.yield %86, %87, %88 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>
    }

By contrast, the code for an equivalent of tutorial 03 written with block pointers (after forcing the block pointers to be rewritten by the pass) looks like:

%18 = arith.extsi %arg7 : i32 to i64
%19 = arith.extsi %17 : i32 to i64
%20 = arith.addi %arg5, %c63_i32 : i32
%21 = arith.divsi %20, %c64_i32 : i32 
%22:3 = scf.for %arg9 = %c0_i32 to %21 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<128x256xf32, #mma>, i64, i64)  : i32 {
      %43 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked> 
      %44 = tt.splat %16 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %45 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %46 = arith.extsi %45 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %47 = arith.addi %44, %46 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> 
      %49 = tt.splat %15 : i64 -> tensor<128x1xi64, #blocked> 
      %50 = arith.muli %48, %49 : tensor<128x1xi64, #blocked> 
      %51 = tt.broadcast %50 : tensor<128x1xi64, #blocked> -> tensor<128x64xi64, #blocked> 
      %52 = tt.addptr %43, %51 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> 
      %53 = tt.splat %arg11 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %55 = arith.extsi %54 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %56 = arith.addi %53, %55 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %57 = tt.expand_dims %56 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> 
      %58 = tt.broadcast %57 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> 
      %59 = tt.addptr %52, %58 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> 
      %60 = tt.load %59 : tensor<128x64x!tt.ptr<f16>, #blocked> 
      %61 = triton_gpu.local_alloc %60 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> 
      %62 = triton_gpu.local_load %61 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> 
      %63 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked1> 
      %64 = tt.splat %arg12 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %65 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %66 = arith.extsi %65 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %67 = arith.addi %64, %66 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %68 = tt.expand_dims %67 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> 
      %69 = tt.splat %18 : i64 -> tensor<64x1xi64, #blocked1> 
      %70 = arith.muli %68, %69 : tensor<64x1xi64, #blocked1> 
      %71 = tt.broadcast %70 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> 
      %72 = tt.addptr %63, %71 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> 
      %73 = tt.splat %19 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %74 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %75 = arith.extsi %74 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %76 = arith.addi %73, %75 : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %77 = tt.expand_dims %76 {axis = 0 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi64, #blocked1> 
      %78 = tt.broadcast %77 : tensor<1x256xi64, #blocked1> -> tensor<64x256xi64, #blocked1> 
      %79 = tt.addptr %72, %78 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> 
      %80 = tt.load %79 : tensor<64x256x!tt.ptr<f16>, #blocked1> 
      %81 = triton_gpu.local_alloc %80 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> 
      %82 = triton_gpu.local_load %81 : !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> 
      %83 = tt.dot %62, %82, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma>
      %84 = arith.addi %arg11, %c64_i64 : i64
      %85 = arith.addi %arg12, %c64_i64 : i64
      scf.yield %83, %84, %85 : tensor<128x256xf32, #mma>, i64, i64
    }

The RewriteTensorPointer pass should therefore be improved to hoist these loop-invariant pointer computations out of the loop.
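
For comparison, here is a minimal sketch of the same matmul written with regular pointers, roughly mirroring the first TTGIR snippet above (again, names and block sizes are illustrative). All of the address setup is loop-invariant and computed once before the loop; only the K-boundary masks and a cheap pointer increment remain inside, which is the shape the rewritten code should ideally have:

import triton
import triton.language as tl

@triton.jit
def matmul_regular_ptr_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                              stride_am, stride_ak, stride_bk, stride_bn,
                              stride_cm, stride_cn,
                              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Loop-invariant address setup: computed once, outside the loop.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Only the boundary masks and a constant pointer increment live in the loop.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)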

@mfrancepillois added the enhancement and performance labels Aug 2, 2024
@chengjunlu self-assigned this Aug 30, 2024
@chengjunlu (Contributor)

We can fall back to gather/scatter memory accesses when lowering tt.load and tt.store from TTGIR to SIMT LLVM. The offsets can be recalculated from the block pointer information during lowering.

That way we could remove the RewriteTensorPointer pass, which may not be efficient.
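
To illustrate the idea, here is a small NumPy sketch (purely conceptual; the gather_addresses helper and the example values are hypothetical, not the actual lowering code) of how per-element addresses and an in-bounds mask could be recomputed from the block-pointer metadata (base, shape, strides, offsets) at the load site, i.e. a gather-style access:

import numpy as np

def gather_addresses(base, shape, strides, offsets, block_shape):
    # Hypothetical helper: expand a 2D block pointer into per-element byte
    # addresses plus a boundary mask, the way a gather-style lowering of
    # tt.load could recompute them from the block-pointer operands.
    rows = offsets[0] + np.arange(block_shape[0])
    cols = offsets[1] + np.arange(block_shape[1])
    addrs = base + rows[:, None] * strides[0] + cols[None, :] * strides[1]
    mask = (rows[:, None] < shape[0]) & (cols[None, :] < shape[1])
    return addrs, mask

# Example: a 4x8 block at offsets (2, 4) of a row-major 6x10 f16 tensor
# (strides given in bytes), starting at an arbitrary base address.
addrs, mask = gather_addresses(base=0x1000, shape=(6, 10), strides=(20, 2),
                               offsets=(2, 4), block_shape=(4, 8))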

@chengjunlu changed the title from "Improve the code generated by the RewriteTensorPointer pass." to "[Performance] Improve the code generated by the RewriteTensorPointer pass." Aug 30, 2024
@whitneywhtsang linked a pull request Sep 26, 2024 that will close this issue
@vlad-penkin assigned etiotto and unassigned etiotto and chengjunlu Oct 15, 2024