From a1aa58b441748e3beff21471a1e68f12a36bde68 Mon Sep 17 00:00:00 2001 From: David Berard Date: Wed, 23 Oct 2024 10:00:00 -0700 Subject: [PATCH] [BACKEND] Use vectorized atomics on Hopper (#4971) Hopper supports vectorized atomics for add, max, and min. This PR adds support for generating these instructions. Note: atomic add/min/max also have packed instructions for f16x2 and bf16x2. Packed instructions were used prior to this PR, but vectorized instructions weren't. When vectorized instructions are available, this PR switches to using vectorized instructions (like .v2.f16 instead of .f16x2, or .v8.f16 instead of .v4.f16x2). When vectorized instructions aren't available, packed instructions will be used instead. This PR also adds a check for mask alignment, which wasn't previously checked. --- test/Conversion/tritongpu_to_llvm.mlir | 34 ++++- test/Conversion/tritongpu_to_llvm_hopper.mlir | 38 +++++ .../LoadStoreOpToLLVM.cpp | 144 ++++++++++++++---- 3 files changed, 184 insertions(+), 32 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e2f43f4ba629..e1a2ec68bd5a 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" @@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_nomask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index d44529966274..83653d57b65e 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -241,3 +241,41 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.v4.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index b19f3ac88ed1..760ba75d9816 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -98,6 +98,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, return mask; } +std::string getRegisterSizeCode(int size, bool is_float) { + switch (size) { + case 1: + return "b"; + case 16: + return "h"; + case 32: + return is_float ? "f" : "r"; + case 64: + return is_float ? "d" : "l"; + case 128: + return "q"; + default: + llvm_unreachable("Unsupported register size"); + } +} + // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo, @@ -632,6 +649,20 @@ struct AtomicRMWOpConversion : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + bool supportsVectorized(Operation *moduleOp, RMWOp opType, + Type elementType) const { + // vectorized atomics are only supported on hopper, + // and only for specific atomic ops (add, min, max). + // Note that "packed types" like f16x2 are supported sm60+. + auto computeCapability = getNVIDIAComputeCapability(moduleOp); + if (computeCapability < 90) { + return false; + } + + return opType == RMWOp::FADD && + (elementType.isF16() || elementType.isBF16() || elementType.isF32()); + } + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -664,45 +695,82 @@ struct AtomicRMWOpConversion : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); - // vec = 1, numElements = 1 for scalar - auto vec = getVectorSize(ptr); - auto vecOrig = vec; - int numElems = 1; - // tensor + // packed: e.g. packed=2 for f16x2 + // vec: e.g. .v2, .v4, .v8 version of atom instruction. + unsigned vec, vecOrig; + int numElems, packed; if (tensorTy) { + vec = getVectorSize(ptr); + if (llMask) { + vec = std::min(vec, getMaskAlignment(op.getMask())); + } + vecOrig = vec; + packed = 1; auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); - // mask + if (!supportsVectorized(moduleOp, atomicRmwAttr, + valTy.getElementType())) { + packed = + std::min(vecOrig, valTy.getElementType().isF16() ? 2 : 1); + vec = 1; + } numElems = tensorTy.getNumElements(); + } else { + // scalar + vec = 1; + vecOrig = 1; + numElems = 1; + packed = 1; } + assert((packed == 1 || vec == 1) && "packed or vec must be 1"); - if (vec == 1 && numElems > 1) + if (vec * packed == 1 && numElems > 1) op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " origin vec = " << vecOrig + << " packed = " << packed << " origin vec = " << vecOrig << " numElems = " << numElems; Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); - auto vecTy = vec_ty(valueElemTy, vec); + auto packedTy = vec_ty(valueElemTy, packed); SmallVector resultVals(elemsPerThread); - for (size_t i = 0; i < elemsPerThread; i += vec) { - Value rmwVal = undef(vecTy); - for (int ii = 0; ii < vec; ++ii) { - Value iiVal = createIndexAttrConstant( - rewriter, loc, getTypeConverter()->getIndexType(), ii); - rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); - } - + for (size_t i = 0; i < elemsPerThread; i += vec * packed) { Value rmwPtr = ptrElements[i]; Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; std::string sTy; PTXBuilder ptxBuilderAtomicRMW; - std::string tyId = valueElemNBits * vec == 64 - ? "l" - : (valueElemNBits * vec == 32 ? "r" : "h"); - auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l" + std::string tyId = + getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false); + + PTXBuilder::Operand *dstOpr; + if (vec > 1) { + dstOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + dstOpr->listAppend( + ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); + } + } else { + dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + } + auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); - auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + + PTXBuilder::Operand *valOpr; + if (vec > 1) { + valOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + valOpr->listAppend( + ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); + } + } else if (packed > 1) { + Value rmwVal = undef(packedTy); + for (int ii = 0; ii < packed; ++ii) { + rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii], + i32_val(ii)); + } + valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + } else { + valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId); + } auto scope = stringifyMemSyncScope(op.getScope()).str(); auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); @@ -725,7 +793,7 @@ struct AtomicRMWOpConversion rmwOp = "add"; rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); sTy = "f" + sBits; - sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : ""; + sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : ""; break; case RMWOp::MAX: sTy = "s" + sBits; @@ -750,15 +818,33 @@ struct AtomicRMWOpConversion std::string semStr; llvm::raw_string_ostream os(semStr); os << op.getSem(); - atom.o(semStr).o(rmwOp).o(sTy); + atom.o(semStr).o(rmwOp).v(vec).o(sTy); if (tensorTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); - auto retType = vec == 1 ? valueElemTy : vecTy; + Type retType; + if (vec > 1) { + SmallVector retTys(vec, valueElemTy); + retType = struct_ty(retTys); + } else if (packed > 1) { + retType = packedTy; + } else { + retType = valueElemTy; + } + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); + + if (vec > 1) { + for (unsigned ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = extract_val(valueElemTy, ret, ii); + } + } else if (packed > 1) { + for (unsigned ii = 0; ii < packed; ++ii) { + resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii)); + } + } else { + resultVals[i] = ret; } + } else { auto ASMReturnTy = void_ty(ctx); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);