Skip to content

Commit

Permalink
[GPUHeuristic] Modify schedule generator to consider distribution of …
Browse files Browse the repository at this point in the history
…tranfer_read layout anchor (iree-org#17636)

Modify heuristic to take into account layout of transfer reads, S.T we will not generate invalid schedules who's transfer read cannot be distributed because the sizes do not match up.

For example in one matmul with N-dim with these sizes
[wgTileSize, elemPerThread, threadSize] = [192, 8, 128]. 
There is no good layout for this because, the numbers of threads
needed would be 192/8 == 24, and Since the threadSize pre-determined by 
schedule is 128, we will have 128 % 24 != 0. Hence we cannot distribute it.

This patch introduce constraints in our heuristic to solve these cases.

---------

Signed-off-by: stanley-nod <[email protected]>
  • Loading branch information
raikonenfnu authored Jun 12, 2024
1 parent c1e542d commit 52b21f8
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 38 deletions.
45 changes: 36 additions & 9 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
}

bool isValidSchedule(const GPUMatmulShapeType &problem,
const GPUMMASchedule &schedule, const bool mustBeAligned) {
const GPUMMASchedule &schedule, const bool mustBeAligned,
const int64_t subgroupSize, const bool transposedLhs,
const bool transposedRhs) {
auto alignedMSize =
mustBeAligned
? problem.mSize
Expand All @@ -48,20 +50,43 @@ bool isValidSchedule(const GPUMatmulShapeType &problem,
bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
schedule.nWarpCount)) == 0;
bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
return isValidN && isValidM && isValidK;

// Constraint to ensure wgTileSize is distributable by wgSize.
// such that we can distribute to it's corresponding vector.transfer_read.
const int64_t kMaxVectorLoadBitWidth = 128;
int64_t elemsPerThread =
kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth();
int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize;

int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
int64_t kWgSize = schedule.kSize * schedule.kTileCount;
int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize;
int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize;

bool isDistributableLhs =
(innerLhsDimSize / elemsPerThread) % wgThreads == 0 ||
wgThreads % (innerLhsDimSize / elemsPerThread) == 0;
bool isDistributableRhs =
(innerRhsDimSize / elemsPerThread) % wgThreads == 0 ||
wgThreads % (innerRhsDimSize / elemsPerThread) == 0;

return isValidN && isValidM && isValidK && isDistributableLhs &&
isDistributableRhs;
}

FailureOr<GPUMMASchedule>
fitScheduleInSharedMemory(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
GPUMMASchedule schedule,
int64_t sharedMemLimitInBytes, bool mustBeAligned) {
FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
GPUMMASchedule schedule, int64_t sharedMemLimitInBytes,
int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
bool mustBeAligned) {
int64_t lhsBitwidth =
intrinsics[schedule.index].aType.getIntOrFloatBitWidth();
int64_t rhsBitwidth =
intrinsics[schedule.index].bType.getIntOrFloatBitWidth();

while (!isValidSchedule(problem, schedule, mustBeAligned) ||
while (!isValidSchedule(problem, schedule, mustBeAligned, subgroupSize,
transposedLhs, transposedRhs) ||
calculateSharedMemoryUsedInBytes(schedule, lhsBitwidth, rhsBitwidth) >
sharedMemLimitInBytes) {
LLVM_DEBUG({
Expand Down Expand Up @@ -113,6 +138,7 @@ fitScheduleInSharedMemory(const GPUMatmulShapeType &problem,
FailureOr<GPUMMASchedule> deduceMMASchedule(
const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
int64_t subgroupSize, bool transposedLhs, bool transposedRhs,
bool canUpcastAcc, bool mustBeAligned) {
for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
Expand Down Expand Up @@ -219,7 +245,8 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
GPUMMASchedule{index, intrinsic.mSize, intrinsic.nSize, intrinsic.kSize,
mWarpCount, nWarpCount, mTileCount, nTileCount,
kTileCount},
sharedMemLimitInBytes, mustBeAligned);
sharedMemLimitInBytes, subgroupSize, transposedLhs, transposedRhs,
mustBeAligned);
}
return failure();
}
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ struct GPUMMASchedule {

/// Returns a schedule for using one of the given MMA |intrinsics| to target the
/// input |problem|. Returns std::nullopt if we cannot find such a schedule.
FailureOr<GPUMMASchedule> deduceMMASchedule(
const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
bool canUpcastAcc = false, bool mustBeAligned = true);
FailureOr<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds,
int64_t sharedMemLimitInBytes, int64_t subgroupSize,
bool transposedLhs = false, bool transposedRhs = false,
bool canUpcastAcc = false, bool mustBeAligned = true);

} // namespace mlir::iree_compiler
30 changes: 22 additions & 8 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,14 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();

// First try to find a schedule with an exactly matching intrinsic.
FailureOr<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes);
FailureOr<GPUMMASchedule> schedule = deduceMMASchedule(
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize);
if (failed(schedule)) {
// Then try again by allowing upcasting accumulator.
schedule =
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
/*canUpcastAcc=*/true);
schedule = deduceMMASchedule(
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
/*transposedLhs*/ false, /*transposedRhs*/ false,
/*canUpcastAcc=*/true);
}
if (failed(schedule)) {
return failure();
Expand Down Expand Up @@ -465,14 +466,25 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,

LDBG("Matmul Vector Distribution Config");

// First try to find a schedule with an exactly matching intrinsic.
auto pipeline = CodeGenPipeline::LLVMGPUVectorDistribute;
std::optional<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes);

// Infer if lhs or rhs is transposed to help generate better schedule.
SmallVector<AffineMap> maps = op.getIndexingMapsArray();
bool transposedLhs =
kDim !=
llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
bool transposedRhs =
nDim !=
llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();

// First try to find a schedule with an exactly matching intrinsic.
std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize);
if (!schedule) {
// Then try again by allowing upcasting accumulator.
schedule =
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
targetSubgroupSize, transposedLhs, transposedRhs,
/*canUpcastAcc=*/true);
}

Expand All @@ -485,11 +497,13 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
bool mustBeAligned = false;
schedule =
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
targetSubgroupSize, transposedLhs, transposedRhs,
/*canUpcastAcc=*/false, mustBeAligned);
if (!schedule) {
// Then try again by allowing upcasting accumulator.
schedule =
deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes,
targetSubgroupSize, transposedLhs, transposedRhs,
/*canUpcastAcc=*/true, mustBeAligned);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s --check-prefix=CDNA3
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s --check-prefix=RDNA3

// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
// to be migrated to the rocdl heuristics, but for now is just physically
Expand Down Expand Up @@ -318,22 +318,22 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
}
}

// CDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32
// CDNA3-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
// CDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
// RDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32
// RDNA3-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
// RDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>


// CDNA3-LABEL: func.func @matmul_256x256x256_f16_f32
// CDNA3-SAME: translation_info = #[[$TRANSLATION]]
// CDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x8x1x1x1xf32>)
// RDNA3-LABEL: func.func @matmul_256x256x256_f16_f32
// RDNA3-SAME: translation_info = #[[$TRANSLATION]]
// RDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x8x1x1x1xf32>)
// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
// along the K dimension. So in total 32 wmma ops.
// CDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
// CDNA3: scf.yield %{{.+}} : vector<2x2x8x1x1x1xf32>
// RDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
// RDNA3: scf.yield %{{.+}} : vector<2x2x8x1x1x1xf32>
// Since each subgroup handles 2 * 2 tiles, and for each tile, each lane holds 4 values.
// we will have 32 writes. We cannot do contiguous writes since the outputs columns has interleaved
// thread ids.
// CDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
// RDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>

// -----

Expand Down Expand Up @@ -408,3 +408,68 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
// CHECK: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: %[[OUT_GLOBAL_SUB:.+]] = memref.subview %[[OUT_GLOBAL]]
// CHECK: vector.transfer_write %{{.+}}, %[[OUT_GLOBAL_SUB]]

// -----

// This test ensures that we are generating contraction schedules does not only work on contraction,
// but also will be compatible with transfer_read layouts anchors.
// Currently the transfer_read layout anchors expects WorkgroupSize % (WgTileSize / numelPerThread) == 0.
// this test ensure that this constraint is satisfied.

// NOTE: This test is not exhaustive of all possible ways the above condition is breaking,
// but rather is an example of a matmul shape from a model that broke our compilation heuristic.

#pipeline_layout = #hal.pipeline.layout<
push_constants = 3,
sets = [
<0, bindings = [
<0, storage_buffer, ReadOnly>,
<1, storage_buffer>
]>
]>
hal.executable public @contract_schedule_considering_read_layout {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @contract_schedule_considering_read_layout ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @contract_schedule_considering_read_layout() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = arith.index_castui %0 : i32 to index
%4 = arith.index_castui %1 : i32 to index
%5 = arith.index_castui %2 : i32 to index
%6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x160x1536xf16>>
%7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x1536x1536xf16>>
%8 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<2x160x1536xf16>>
%9 = flow.dispatch.tensor.load %6, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x160x1536xf16>> -> tensor<2x160x1536xf16>
%10 = flow.dispatch.tensor.load %7, offsets = [0, 0, 0], sizes = [2, 1536, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1536x1536xf16>> -> tensor<2x1536x1536xf16>
%11 = tensor.empty() : tensor<2x160x1536xf16>
%12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
%13 = linalg.batch_matmul ins(%9, %10 : tensor<2x160x1536xf16>, tensor<2x1536x1536xf16>) outs(%12 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
flow.dispatch.tensor.store %13, %8, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : tensor<2x160x1536xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x160x1536xf16>>
return
}
}
}
}
// Basic pipeline test to make sure it generates the instructions we expect.

// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4>

// CHECK-LABEL: func.func @contract_schedule_considering_read_layout()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-DAG: %[[RHS_SHARED:.+]] = memref.alloc() : memref<128x132xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[RHS_SHARED_SUB:.+]] = memref.subview %[[RHS_SHARED]][0, 0] [128, 128] [1, 1]
// CHECK-DAG: %[[LHS_SHARED:.+]] = memref.alloc() : memref<16x132xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[LHS_SHARED_SUB:.+]] = memref.subview %[[LHS_SHARED]][0, 0] [16, 128] [1, 1]
// CHECK: scf.for {{.*}} = %c0 to %c11 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x2x1x1x4x1xf16>)
// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield
// CHECK-COUNT-16: amdgpu.mfma
24 changes: 17 additions & 7 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,13 +913,6 @@ LogicalResult setCooperativeMatrixConfig(
int64_t sharedMemoryLimitInBytes =
targetEnv.getResourceLimits().getMaxComputeSharedMemorySize();

FailureOr<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes);
if (failed(schedule))
return failure();

auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;

std::optional<int64_t> subgroupSize = limits.getSubgroupSize();
// AMD RDNA architectures supports both wave32 and wave64 modes. Prefer to use
// wave32 mode for better performance.
Expand All @@ -928,6 +921,23 @@ LogicalResult setCooperativeMatrixConfig(
subgroupSize = *minSize;
}

// Infer if lhs or rhs is transposed to help generate better schedule.
SmallVector<AffineMap> maps = op.getIndexingMapsArray();
bool transposedLhs =
kIndex !=
llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
bool transposedRhs =
nIndex !=
llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();

FailureOr<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes,
*subgroupSize, transposedLhs, transposedRhs);
if (failed(schedule))
return failure();

auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;

std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * *subgroupSize,
schedule->mWarpCount, 1};

Expand Down

0 comments on commit 52b21f8

Please sign in to comment.