Skip to content

Commit

Permalink
[Codegen][GPU] Loosen dim mapping restrictions on forall fusion (iree-org#17612)
Browse files Browse the repository at this point in the history

The `FuseForalls` pattern is currently restricted to forall loops whose dim
mappings are fully equivalent, which is stricter than necessary. This PR
loosens the restriction to require only that the two forall loops have the
same kind of mapping attribute (all-thread or all-warp) and an equivalent
first dim mapping.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Jun 10, 2024
1 parent 8ab07d2 commit d7744b7
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,27 @@ static FailureOr<int64_t> getTripCount(scf::ForallOp loop) {
static FailureOr<SmallVector<scf::ForallOp>>
getEquivalentMappingConsumerLoopNest(scf::ForallOp producer,
scf::ForallOp consumer) {

auto checkMappingTypes = [&](ArrayAttr array) {
return llvm::all_of(array.getValue(),
llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
llvm::all_of(array.getValue(),
llvm::IsaPred<gpu::GPUWarpMappingAttr>);
auto checkMappingTypes = [&](ArrayRef<Attribute> array) {
return llvm::all_of(array, llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
llvm::all_of(array, llvm::IsaPred<gpu::GPUWarpMappingAttr>);
};

ArrayAttr producerMapping = producer.getMappingAttr();
ArrayAttr consumerMapping = consumer.getMappingAttr();
ArrayRef<Attribute> producerMapping = producer.getMappingAttr().getValue();
ArrayRef<Attribute> consumerMapping = consumer.getMappingAttr().getValue();

if (producerMapping.empty() || consumerMapping.empty()) {
return failure();
}

if (producerMapping == consumerMapping &&
if (producerMapping.front() == consumerMapping.front() &&
checkMappingTypes(producerMapping) &&
checkMappingTypes(consumerMapping)) {
return SmallVector<scf::ForallOp>({consumer});
}

if (!llvm::all_of(producerMapping.getValue(),
if (!llvm::all_of(producerMapping,
llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
!llvm::all_of(consumerMapping.getValue(),
llvm::IsaPred<IREE::GPU::LaneIdAttr>)) {
!llvm::all_of(consumerMapping, llvm::IsaPred<IREE::GPU::LaneIdAttr>)) {
return failure();
}
auto outerWarpLoop = consumer->getParentOfType<scf::ForallOp>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,64 @@ module {

// -----

// Tile-offset maps: x2 / x4 scaling for the copy tile, x4 + k-offset for the
// source slice, and x16 scaling for the matmul tile.
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
#map3 = affine_map<(d0) -> (d0 * 16)>
// Regression test for loosened forall-fusion mapping restrictions: the
// producer and consumer foralls below use thread mappings of different
// lengths (3 dims vs 2 dims) but share the same first dim mapping
// (#gpu.thread<linear_dim_0>), which is presumably what makes them fusable
// after this change — confirm against the FuseForalls pattern.
module {
  func.func @forall_fuse_then_hoist_mixed_mappings() {
    %c4 = arith.constant 4 : index
    %c128 = arith.constant 128 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<0.0> : tensor<4x128xf16>
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf16>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf16>> -> tensor<128x128xf16>
    %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
    %6 = tensor.empty() : tensor<128x4xf16>
    %7 = tensor.empty() : tensor<4x128xf16>
    // Reduction loop over the K dimension in steps of 4 (0..128).
    %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<128x128xf32>) {
      // Producer forall: thread-mapped copy of a 2x4 slice of the K-tile into
      // a 128x4 staging tensor. Mapping has three linear thread dims.
      %9 = scf.forall (%arg2, %arg3, %arg4) in (1, 64, 1) shared_outs(%arg5 = %6) -> (tensor<128x4xf16>) {
        %12 = affine.apply #map(%arg3)
        %13 = affine.apply #map1(%arg4)
        %14 = affine.apply #map(%arg3)
        %15 = affine.apply #map2(%arg4)[%arg0]
        %extracted_slice = tensor.extract_slice %3[%14, %15] [2, 4] [1, 1] : tensor<128x128xf16> to tensor<2x4xf16>
        %extracted_slice_0 = tensor.extract_slice %arg5[%12, %13] [2, 4] [1, 1] : tensor<128x4xf16> to tensor<2x4xf16>
        %16 = linalg.copy ins(%extracted_slice : tensor<2x4xf16>) outs(%extracted_slice_0 : tensor<2x4xf16>) -> tensor<2x4xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %16 into %arg5[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
        }
      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_2>]}
      // Consumer forall: thread-mapped 16x16 matmul tiles reading the staged
      // copy %9. Mapping has only two linear thread dims, but the same first
      // dim mapping as the producer above.
      %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
        %12 = affine.apply #map3(%arg2)
        %13 = affine.apply #map3(%arg3)
        %extracted_slice = tensor.extract_slice %9[%12, 0] [16, 4] [1, 1] : tensor<128x4xf16> to tensor<16x4xf16>
        %extracted_slice_0 = tensor.extract_slice %cst[0, %13] [4, 16] [1, 1] : tensor<4x128xf16> to tensor<4x16xf16>
        %extracted_slice_1 = tensor.extract_slice %arg4[%12, %13] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
        %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<16x4xf16>, tensor<4x16xf16>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
        }
      } {mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]}
      scf.yield %11 : tensor<128x128xf32>
    }
    flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
    return
  }
}

// CHECK-LABEL: func @forall_fuse_then_hoist_mixed_mappings
// CHECK: %[[OUTER_PARALLEL:.+]] = scf.forall
// CHECK: %[[LOOP:.+]] = scf.for
// CHECK: scf.yield {{.*}} : tensor<16x16xf32>
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]]
// CHECK-NOT: scf.forall
// CHECK: flow.dispatch.tensor.store %[[OUTER_PARALLEL]]

// -----

#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
#map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
Expand Down

0 comments on commit d7744b7

Please sign in to comment.