[BACKEND] Use vectorized atomics on Hopper (#4971)
Hopper supports vectorized atomics for add, max, and min. This PR adds
support for generating these instructions.

Note: atomic add/min/max also have packed instructions for f16x2 and
bf16x2. Packed instructions were already used prior to this PR, but
vectorized instructions were not. When vectorized instructions are
available, this PR switches to them (e.g. .v2.f16 instead of .f16x2, or
.v8.f16 instead of .v4.f16x2); when they aren't available, packed
instructions are still used.
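
For illustration, the two instruction shapes look roughly like this (a
hand-written PTX sketch: the opcodes match the lowering checked in the
tests below, but the register names and operands are made up):

    // packed f16 atomic (sm_60+): two f16 values packed into one 32-bit register
    atom.global.gpu.acq_rel.add.noftz.f16x2  %r0, [%rd0], %r1;

    // vectorized atomics (sm_90+): one instruction covers 2/4/8 elements
    atom.global.gpu.acq_rel.add.noftz.v2.f16 {%h0, %h1}, [%rd0], {%h2, %h3};
    atom.global.gpu.acq_rel.add.v4.f32       {%f0, %f1, %f2, %f3}, [%rd0], {%f4, %f5, %f6, %f7};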

This PR also adds a check for mask alignment, which previously wasn't
taken into account; the vectorization width is now clamped to the
alignment of the mask.
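
For example, in the Hopper f32 test added below, the mask is only known
to be uniform across groups of two elements (tt.constancy = 2), so the
lowering is expected to emit two .v2.f32 atomics per thread instead of a
single .v4.f32 one. Roughly (again a sketch with made-up registers):

    // mask alignment of 2 caps the vector width at 2
    atom.global.gpu.acq_rel.add.v2.f32 {%f0, %f1}, [%rd0], {%f2, %f3};
    atom.global.gpu.acq_rel.add.v2.f32 {%f4, %f5}, [%rd1], {%f6, %f7};
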
davidberard98 authored Oct 23, 2024
1 parent a20ce64 commit a1aa58b
Showing 3 changed files with 184 additions and 32 deletions.
34 changes: 31 additions & 3 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
@@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: atomic_add_f32_scalar
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
// CHECK: llvm.icmp "eq"
@@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
@@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_nomask
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_withmask
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
38 changes: 38 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -241,3 +241,41 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f32_nomask
// CHECK: atom.global.gpu.acq_rel.add.v4.f32
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f32_withmask
// CHECK: atom.global.gpu.acq_rel.add.v2.f32
// CHECK: atom.global.gpu.acq_rel.add.v2.f32
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_withmask
// CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}
144 changes: 115 additions & 29 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -98,6 +98,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
return mask;
}

std::string getRegisterSizeCode(int size, bool is_float) {
switch (size) {
case 1:
return "b";
case 16:
return "h";
case 32:
return is_float ? "f" : "r";
case 64:
return is_float ? "d" : "l";
case 128:
return "q";
default:
llvm_unreachable("Unsupported register size");
}
}

// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo,
@@ -632,6 +649,20 @@ struct AtomicRMWOpConversion
: ConvertOpToLLVMPattern<triton::AtomicRMWOp>(converter, benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}

bool supportsVectorized(Operation *moduleOp, RMWOp opType,
Type elementType) const {
// vectorized atomics are only supported on hopper,
// and only for specific atomic ops (add, min, max).
// Note that "packed types" like f16x2 are supported on sm60+.
auto computeCapability = getNVIDIAComputeCapability(moduleOp);
if (computeCapability < 90) {
return false;
}

return opType == RMWOp::FADD &&
(elementType.isF16() || elementType.isBF16() || elementType.isF32());
}

LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -664,45 +695,82 @@ struct AtomicRMWOpConversion
: valueTy;
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
auto vecOrig = vec;
int numElems = 1;
// tensor
// packed: e.g. packed=2 for f16x2
// vec: e.g. .v2, .v4, .v8 version of atom instruction.
unsigned vec, vecOrig;
int numElems, packed;
if (tensorTy) {
vec = getVectorSize(ptr);
if (llMask) {
vec = std::min<unsigned>(vec, getMaskAlignment(op.getMask()));
}
vecOrig = vec;
packed = 1;
auto valTy = cast<RankedTensorType>(val.getType());
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
if (!supportsVectorized(moduleOp, atomicRmwAttr,
valTy.getElementType())) {
packed =
std::min<unsigned>(vecOrig, valTy.getElementType().isF16() ? 2 : 1);
vec = 1;
}
numElems = tensorTy.getNumElements();
} else {
// scalar
vec = 1;
vecOrig = 1;
numElems = 1;
packed = 1;
}
assert((packed == 1 || vec == 1) && "packed or vec must be 1");

if (vec == 1 && numElems > 1)
if (vec * packed == 1 && numElems > 1)
op->emitRemark() << "Warning: vectorization fails vec = " << vec
<< " origin vec = " << vecOrig
<< " packed = " << packed << " origin vec = " << vecOrig
<< " numElems = " << numElems;

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);

auto vecTy = vec_ty(valueElemTy, vec);
auto packedTy = vec_ty(valueElemTy, packed);
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
}

for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
Value rmwPtr = ptrElements[i];
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
std::string tyId = valueElemNBits * vec == 64
? "l"
: (valueElemNBits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
// 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
std::string tyId =
getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false);

PTXBuilder::Operand *dstOpr;
if (vec > 1) {
dstOpr = ptxBuilderAtomicRMW.newListOperand();
for (unsigned ii = 0; ii < vec; ++ii) {
dstOpr->listAppend(
ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true));
}
} else {
dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
}

auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);

PTXBuilder::Operand *valOpr;
if (vec > 1) {
valOpr = ptxBuilderAtomicRMW.newListOperand();
for (unsigned ii = 0; ii < vec; ++ii) {
valOpr->listAppend(
ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId));
}
} else if (packed > 1) {
Value rmwVal = undef(packedTy);
for (int ii = 0; ii < packed; ++ii) {
rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii],
i32_val(ii));
}
valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
} else {
valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId);
}

auto scope = stringifyMemSyncScope(op.getScope()).str();
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope);
@@ -725,7 +793,7 @@ struct AtomicRMWOpConversion
rmwOp = "add";
rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
sTy = "f" + sBits;
sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : "";
break;
case RMWOp::MAX:
sTy = "s" + sBits;
@@ -750,15 +818,33 @@ struct AtomicRMWOpConversion
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.o(semStr).o(rmwOp).o(sTy);
atom.o(semStr).o(rmwOp).v(vec).o(sTy);
if (tensorTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
Type retType;
if (vec > 1) {
SmallVector<Type> retTys(vec, valueElemTy);
retType = struct_ty(retTys);
} else if (packed > 1) {
retType = packedTy;
} else {
retType = valueElemTy;
}

auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));

if (vec > 1) {
for (unsigned ii = 0; ii < vec; ++ii) {
resultVals[i + ii] = extract_val(valueElemTy, ret, ii);
}
} else if (packed > 1) {
for (unsigned ii = 0; ii < packed; ++ii) {
resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii));
}
} else {
resultVals[i] = ret;
}

} else {
auto ASMReturnTy = void_ty(ctx);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
