diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index 70ff8ed451ac..f41846abadfe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -30,7 +30,9 @@ static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
 }
 
 bool isValidSchedule(const GPUMatmulShapeType &problem,
-                     const GPUMMASchedule &schedule, const bool mustBeAligned) {
+                     const GPUMMASchedule &schedule, const bool mustBeAligned,
+                     const int64_t subgroupSize, const bool transposedLhs,
+                     const bool transposedRhs) {
   auto alignedMSize =
       mustBeAligned
           ? problem.mSize
@@ -48,20 +50,43 @@ bool isValidSchedule(const GPUMatmulShapeType &problem,
   bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
                                    schedule.nWarpCount)) == 0;
   bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
-  return isValidN && isValidM && isValidK;
+
+  // Constraint to ensure wgTileSize is distributable by wgSize, such that we
+  // can distribute its corresponding vector.transfer_read.
+  const int64_t kMaxVectorLoadBitWidth = 128;
+  int64_t elemsPerThread =
+      kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth();
+  int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize;
+
+  int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
+  int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
+  int64_t kWgSize = schedule.kSize * schedule.kTileCount;
+  int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize;
+  int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize;
+
+  bool isDistributableLhs =
+      (innerLhsDimSize / elemsPerThread) % wgThreads == 0 ||
+      wgThreads % (innerLhsDimSize / elemsPerThread) == 0;
+  bool isDistributableRhs =
+      (innerRhsDimSize / elemsPerThread) % wgThreads == 0 ||
+      wgThreads % (innerRhsDimSize / elemsPerThread) == 0;
+
+  return isValidN && isValidM && isValidK && isDistributableLhs &&
+         isDistributableRhs;
 }
 
-FailureOr<GPUMMASchedule>
-fitScheduleInSharedMemory(const GPUMatmulShapeType &problem,
-                          ArrayRef<GPUMatmulShapeType> intrinsics,
-                          GPUMMASchedule schedule,
-                          int64_t sharedMemLimitInBytes, bool mustBeAligned) {
+FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
+    const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
+    GPUMMASchedule schedule, int64_t sharedMemLimitInBytes,
+    int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
+    bool mustBeAligned) {
   int64_t lhsBitwidth =
       intrinsics[schedule.index].aType.getIntOrFloatBitWidth();
   int64_t rhsBitwidth =
       intrinsics[schedule.index].bType.getIntOrFloatBitWidth();
-  while (!isValidSchedule(problem, schedule, mustBeAligned) ||
+  while (!isValidSchedule(problem, schedule, mustBeAligned, subgroupSize,
+                          transposedLhs, transposedRhs) ||
          calculateSharedMemoryUsedInBytes(schedule, lhsBitwidth, rhsBitwidth) >
              sharedMemLimitInBytes) {
     LLVM_DEBUG({
@@ -113,6 +138,7 @@ fitScheduleInSharedMemory(const GPUMatmulShapeType &problem,
 FailureOr<GPUMMASchedule> deduceMMASchedule(
     const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
     const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
+    int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
     bool canUpcastAcc, bool mustBeAligned) {
   for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
     if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
@@ -219,7 +245,8 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
         GPUMMASchedule{index, intrinsic.mSize,
                        intrinsic.nSize, intrinsic.kSize, mWarpCount, nWarpCount,
                        mTileCount, nTileCount, kTileCount},
-        sharedMemLimitInBytes, mustBeAligned);
+        sharedMemLimitInBytes, subgroupSize, transposedLhs, transposedRhs,
+        mustBeAligned);
   }
   return failure();
 }
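To make the new check easier to follow in isolation, here is a minimal standalone sketch of the same distributability arithmetic in plain C++. The `Schedule` struct and `isDistributable` helper below are illustrative stand-ins, not the IREE API, and the example numbers are read off the `contract_schedule_considering_read_layout` test added further down (MFMA 16x16x16, 1x4 subgroups, wave64, both operands f16):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Illustrative mirror of the schedule fields used by the new check.
struct Schedule {
  int64_t mSize, nSize, kSize;  // intrinsic shape
  int64_t mTileCount, nTileCount, kTileCount;
  int64_t mWarpCount, nWarpCount;
};

// Returns true iff the inner (contiguous) LHS/RHS workgroup-tile dimensions
// can be distributed over the workgroup's threads, assuming each thread
// issues 128-bit vector loads and both operands share one element type.
bool isDistributable(const Schedule &s, int64_t subgroupSize,
                     int64_t elemBitWidth, bool transposedLhs,
                     bool transposedRhs) {
  const int64_t kMaxVectorLoadBitWidth = 128;
  int64_t elemsPerThread = kMaxVectorLoadBitWidth / elemBitWidth;
  int64_t wgThreads = s.mWarpCount * s.nWarpCount * subgroupSize;

  int64_t mWgSize = s.mSize * s.mTileCount * s.mWarpCount;
  int64_t nWgSize = s.nSize * s.nTileCount * s.nWarpCount;
  int64_t kWgSize = s.kSize * s.kTileCount;
  int64_t innerLhs = transposedLhs ? mWgSize : kWgSize;
  int64_t innerRhs = transposedRhs ? kWgSize : nWgSize;

  auto divides = [&](int64_t innerDim) {
    int64_t loadsPerRow = innerDim / elemsPerThread;
    return loadsPerRow % wgThreads == 0 || wgThreads % loadsPerRow == 0;
  };
  return divides(innerLhs) && divides(innerRhs);
}

int main() {
  // Schedule picked for the new 160x1536x1536 f16 test: workgroup tile
  // M=16, N=128, K=128 from a 16x16x16 intrinsic with 1x4 subgroups.
  Schedule s{16, 16, 16, /*mTile=*/1, /*nTile=*/2, /*kTile=*/8,
             /*mWarp=*/1, /*nWarp=*/4};
  assert(isDistributable(s, /*subgroupSize=*/64, /*elemBitWidth=*/16,
                         /*transposedLhs=*/false, /*transposedRhs=*/false));
  std::cout << "schedule is distributable\n";
}
```

With these numbers each thread loads 8 f16 elements, the 256-thread workgroup covers the 128-element inner dimensions in whole rows, and the check passes.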
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index 596fca4ed5e9..bfcb3a018268 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -46,9 +46,12 @@ struct GPUMMASchedule {
 
 /// Returns a schedule for using one of the given MMA |intrinsics| to target the
 /// input |problem|. Returns std::nullopt if we cannot find such a schedule.
-FailureOr<GPUMMASchedule> deduceMMASchedule(
-    const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
-    const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
-    bool canUpcastAcc = false, bool mustBeAligned = true);
+FailureOr<GPUMMASchedule>
+deduceMMASchedule(const GPUMatmulShapeType &problem,
+                  ArrayRef<GPUMatmulShapeType> intrinsics,
+                  const GPUMMAHeuristicSeeds &seeds,
+                  int64_t sharedMemLimitInBytes, int64_t subgroupSize,
+                  bool transposedLhs = false, bool transposedRhs = false,
+                  bool canUpcastAcc = false, bool mustBeAligned = true);
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index a8f38ba5c60a..d73740f54c79 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -299,13 +299,14 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
   // First try to find a schedule with an exactly matching intrinsic.
-  FailureOr<GPUMMASchedule> schedule =
-      deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes);
+  FailureOr<GPUMMASchedule> schedule = deduceMMASchedule(
+      problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize);
   if (failed(schedule)) {
     // Then try again by allowing upcasting accumulator.
-    schedule =
-        deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
-                          /*canUpcastAcc=*/true);
+    schedule = deduceMMASchedule(
+        problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
+        /*transposedLhs=*/false, /*transposedRhs=*/false,
+        /*canUpcastAcc=*/true);
   }
   if (failed(schedule)) {
     return failure();
@@ -465,14 +466,25 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
 
   LDBG("Matmul Vector Distribution Config");
 
-  // First try to find a schedule with an exactly matching intrinsic.
   auto pipeline = CodeGenPipeline::LLVMGPUVectorDistribute;
-  std::optional<GPUMMASchedule> schedule =
-      deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes);
+
+  // Infer if lhs or rhs is transposed to help generate a better schedule.
+  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
+  bool transposedLhs =
+      kDim !=
+      llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
+  bool transposedRhs =
+      nDim !=
+      llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();
+
+  // First try to find a schedule with an exactly matching intrinsic.
+  std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
+      problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize);
   if (!schedule) {
     // Then try again by allowing upcasting accumulator.
     schedule =
         deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
+                          targetSubgroupSize, transposedLhs, transposedRhs,
                           /*canUpcastAcc=*/true);
   }
 
@@ -485,11 +497,13 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
     bool mustBeAligned = false;
     schedule =
         deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
+                          targetSubgroupSize, transposedLhs, transposedRhs,
                           /*canUpcastAcc=*/false, mustBeAligned);
     if (!schedule) {
       // Then try again by allowing upcasting accumulator.
       schedule =
           deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
+                            targetSubgroupSize, transposedLhs, transposedRhs,
                             /*canUpcastAcc=*/true, mustBeAligned);
     }
   }
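The transposition inference added above only inspects the innermost result of each operand's indexing map. A minimal standalone analogue in plain C++, with the indexing maps modeled as vectors of iteration-space dimension indices instead of `mlir::AffineMap` (all names here are illustrative, not the IREE API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// An indexing map is modeled as the list of iteration-space dimensions each
// operand dimension reads, e.g. the LHS of a plain matmul (m, k) -> {M, K}.
using IndexingMap = std::vector<int64_t>;

// The LHS counts as "transposed" for this heuristic when its innermost
// (fastest-varying) dimension is not K; the RHS when its innermost dimension
// is not N. That is exactly the comparison against kDim/nDim above.
bool isTransposedLhs(const IndexingMap &lhsMap, int64_t kDim) {
  return lhsMap.back() != kDim;
}
bool isTransposedRhs(const IndexingMap &rhsMap, int64_t nDim) {
  return rhsMap.back() != nDim;
}

int main() {
  // Iteration-space dims for a batch matmul: B=0, M=1, N=2, K=3.
  const int64_t kDim = 3, nDim = 2;
  IndexingMap lhs = {0, 1, 3};   // (b, m, k): innermost is K -> not transposed.
  IndexingMap rhsT = {0, 2, 3};  // (b, n, k): innermost is K -> transposed RHS.
  assert(!isTransposedLhs(lhs, kDim));
  assert(isTransposedRhs(rhsT, nDim));
}
```

The "transposed" flag therefore describes which workgroup-tile dimension is contiguous in memory for each operand, which is what the distributability check needs.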
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index df81dc2afeaf..a69bba567c5a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
 // RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s --check-prefix=CDNA3
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s --check-prefix=RDNA3
 
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
 // to be migrated to the rocdl heuristics, but for now is just physically
@@ -318,22 +318,22 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
   }
 }
 
-// CDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info,
-// CDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
+// RDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info,
+// RDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
 
-// CDNA3-LABEL: func.func @matmul_256x256x256_f16_f32
-// CDNA3-SAME: translation_info = #[[$TRANSLATION]]
-// CDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x8x1x1x1xf32>)
+// RDNA3-LABEL: func.func @matmul_256x256x256_f16_f32
+// RDNA3-SAME: translation_info = #[[$TRANSLATION]]
+// RDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x8x1x1x1xf32>)
 // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
 // along the K dimension. So in total 32 wmma ops.
-// CDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
-// CDNA3: scf.yield %{{.+}} : vector<2x2x8x1x1x1xf32>
+// RDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
+// RDNA3: scf.yield %{{.+}} : vector<2x2x8x1x1x1xf32>
 // Since each subgroup handles 2 * 2 tiles, and for each tile, each lane holds 4 values.
 // we will have 32 writes. We cannot do contiguous writes since the outputs columns has interleaved
 // thread ids.
-// CDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+// RDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
 
 // -----
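As a quick sanity check on the counts referenced in the comments above, the expected number of `amdgpu.wmma` ops per loop iteration follows directly from the schedule. A tiny standalone computation (plain C++; the values are read off the test above):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // From the RDNA3 test above: a 16x16x16 WMMA intrinsic, 2x2 tiles per
  // subgroup, and a 128-wide K step per scf.for iteration.
  const int64_t mTilesPerSubgroup = 2, nTilesPerSubgroup = 2;
  const int64_t kTileStep = 128, intrinsicK = 16;
  const int64_t kAccumulations = kTileStep / intrinsicK;  // 8 accumulations
  // 2 * 2 tiles, each accumulated 8 times along K -> 32 wmma per iteration,
  // matching the RDNA3-COUNT-32 expectation.
  assert(mTilesPerSubgroup * nTilesPerSubgroup * kAccumulations == 32);
}
```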
@@ -408,3 +408,68 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 // CHECK: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: %[[OUT_GLOBAL_SUB:.+]] = memref.subview %[[OUT_GLOBAL]]
 // CHECK: vector.transfer_write %{{.+}}, %[[OUT_GLOBAL_SUB]]
+
+// -----
+
+// This test ensures that the generated contraction schedule does not only work
+// for the contraction itself, but is also compatible with the transfer_read
+// layout anchors. Currently the transfer_read layout anchors expect
+// WorkgroupSize % (WgTileSize / numelPerThread) == 0, and this test ensures
+// that this constraint is satisfied.
+
+// NOTE: This test is not exhaustive of all possible ways the above condition
+// can break, but rather is an example of a matmul shape from a model that
+// broke our compilation heuristic.
+
+#pipeline_layout = #hal.pipeline.layout<
+  push_constants = 3,
+  sets = [
+    <0, bindings = [
+      <0, storage_buffer, ReadOnly>,
+      <1, storage_buffer>
+    ]>
+  ]>
+hal.executable public @contract_schedule_considering_read_layout {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @contract_schedule_considering_read_layout ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @contract_schedule_considering_read_layout() {
+        %cst = arith.constant 0.000000e+00 : f16
+        %0 = hal.interface.constant.load[0] : i32
+        %1 = hal.interface.constant.load[1] : i32
+        %2 = hal.interface.constant.load[2] : i32
+        %3 = arith.index_castui %0 : i32 to index
+        %4 = arith.index_castui %1 : i32 to index
+        %5 = arith.index_castui %2 : i32 to index
+        %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x160x1536xf16>>
+        %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x1536x1536xf16>>
+        %8 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<2x160x1536xf16>>
+        %9 = flow.dispatch.tensor.load %6, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x160x1536xf16>> -> tensor<2x160x1536xf16>
+        %10 = flow.dispatch.tensor.load %7, offsets = [0, 0, 0], sizes = [2, 1536, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1536x1536xf16>> -> tensor<2x1536x1536xf16>
+        %11 = tensor.empty() : tensor<2x160x1536xf16>
+        %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
+        %13 = linalg.batch_matmul ins(%9, %10 : tensor<2x160x1536xf16>, tensor<2x1536x1536xf16>) outs(%12 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
+        flow.dispatch.tensor.store %13, %8, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : tensor<2x160x1536xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x160x1536xf16>>
+        return
+      }
+    }
+  }
+}
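A quick standalone sanity check that the tile sizes chosen for this shape split into a whole number of 128-bit loads per thread, which is the intent behind the new constraint. This mirrors, but is not, the exact `isValidSchedule` formula; the tile and subgroup numbers are read off the CHECK lines that follow, assuming the gfx940 wave64 subgroup size of 64:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Tile sizes as they appear in the shared-memory allocations checked below
  // (the +4 padding columns are ignored): LHS 16x128, RHS 128x128, f16 data.
  const int64_t elemsPerLoad = 128 / 16;             // 128-bit loads of f16
  const int64_t wgThreads = 1 * 4 * 64;              // 1x4 subgroups, wave64
  const int64_t lhsLoads = 16 * 128 / elemsPerLoad;  // 256 loads
  const int64_t rhsLoads = 128 * 128 / elemsPerLoad; // 2048 loads
  // Both tiles split into a whole number of loads per thread, so the
  // vector.transfer_reads can be distributed evenly across the workgroup.
  assert(lhsLoads % wgThreads == 0 && rhsLoads % wgThreads == 0);
}
```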
+// Basic pipeline test to make sure it generates the instructions we expect.
+
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info,
+// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4>
+
+// CHECK-LABEL: func.func @contract_schedule_considering_read_layout()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-DAG: %[[RHS_SHARED:.+]] = memref.alloc() : memref<128x132xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[RHS_SHARED_SUB:.+]] = memref.subview %[[RHS_SHARED]][0, 0] [128, 128] [1, 1]
+// CHECK-DAG: %[[LHS_SHARED:.+]] = memref.alloc() : memref<16x132xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[LHS_SHARED_SUB:.+]] = memref.subview %[[LHS_SHARED]][0, 0] [16, 128] [1, 1]
+// CHECK: scf.for {{.*}} = %c0 to %c11 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x2x1x1x4x1xf16>)
+// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK: scf.yield
+// CHECK-COUNT-16: amdgpu.mfma
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 1db80aeef7d7..e491182c1054 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -913,13 +913,6 @@ LogicalResult setCooperativeMatrixConfig(
   int64_t sharedMemoryLimitInBytes =
       targetEnv.getResourceLimits().getMaxComputeSharedMemorySize();
 
-  FailureOr<GPUMMASchedule> schedule =
-      deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes);
-  if (failed(schedule))
-    return failure();
-
-  auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
-
   std::optional<int> subgroupSize = limits.getSubgroupSize();
   // AMD RDNA architectures supports both wave32 and wave64 modes. Prefer to use
   // wave32 mode for better performance.
@@ -928,6 +921,23 @@ LogicalResult setCooperativeMatrixConfig(
     subgroupSize = *minSize;
   }
 
+  // Infer if lhs or rhs is transposed to help generate a better schedule.
+  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
+  bool transposedLhs =
+      kIndex !=
+      llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
+  bool transposedRhs =
+      nIndex !=
+      llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();
+
+  FailureOr<GPUMMASchedule> schedule =
+      deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes,
+                        *subgroupSize, transposedLhs, transposedRhs);
+  if (failed(schedule))
+    return failure();
+
+  auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
+
   std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * *subgroupSize,
                                        schedule->mWarpCount, 1};
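For context on why the SPIRV call site now resolves the subgroup size before deducing a schedule: the workgroup size is derived from the deduced schedule, and its thread count is the same `wgThreads` quantity that the new `isValidSchedule` constraint reasons about. A minimal sketch of that relationship (plain C++; the helper name is an assumption of this illustration, not the IREE API):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Illustrative mirror of how the SPIRV path turns a deduced schedule into a
// workgroup size: x = nWarpCount * subgroupSize, y = mWarpCount, z = 1.
std::array<int64_t, 3> workgroupSizeFor(int64_t mWarpCount, int64_t nWarpCount,
                                        int64_t subgroupSize) {
  return {nWarpCount * subgroupSize, mWarpCount, 1};
}

int main() {
  // Example: 2x2 subgroups in wave32 mode (the RDNA3 preference noted above).
  auto wg = workgroupSizeFor(/*mWarpCount=*/2, /*nWarpCount=*/2,
                             /*subgroupSize=*/32);
  // The total thread count is mWarpCount * nWarpCount * subgroupSize, i.e.
  // exactly the wgThreads term used by the distributability check.
  assert(wg[0] * wg[1] * wg[2] == 2 * 2 * 32);
}
```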