[BACKEND] Extract MinMaxFOp pattern to TritonGPUToLLVM lib. Enable split/join/transpose tests on AMD backend. (#3203)
zahimoud authored Feb 27, 2024
1 parent b750a5c commit 6891634
Showing 6 changed files with 79 additions and 85 deletions.
@@ -213,9 +213,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {

protected:
ModuleAxisInfoAnalysis &axisAnalysisPass;

private:
int computeCapability;
};

} // namespace gpu
@@ -41,6 +41,12 @@ void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter,
void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateMinMaxFOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
bool hwNanPropagationSupported,
PatternBenefit benefit);
} // namespace triton
} // namespace mlir

65 changes: 65 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -593,8 +593,73 @@ struct SelectOpConversion
adaptor.getAttributes().getValue())};
}
};
template <typename OpTy>
struct MinMaxFOpConversion
: ElementwiseOpConversionBase<OpTy, MinMaxFOpConversion<OpTy>> {
using Base = ElementwiseOpConversionBase<OpTy, MinMaxFOpConversion<OpTy>>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static_assert(std::is_same<OpTy, arith::MinimumFOp>::value ||
std::is_same<OpTy, arith::MaximumFOp>::value,
"OpTy must be arith::MinimumFOp or arith::MaximumFOp");

// Choose the destination op based on the OpTy.
using DestOpNanProp =
typename std::conditional<std::is_same<OpTy, arith::MinimumFOp>::value,
LLVM::MinimumOp, LLVM::MaximumOp>::type;
using DestOpNoNanProp =
typename std::conditional<std::is_same<OpTy, arith::MinimumFOp>::value,
LLVM::MinNumOp, LLVM::MaxNumOp>::type;

explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
bool hwNanPropagationSupported,
PatternBenefit benefit = 1)
: Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
benefit),
hwNanPropagationSupported(hwNanPropagationSupported) {}

SmallVector<Value> createDestOps(OpTy op, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
if (hwNanPropagationSupported) {
return {rewriter.create<DestOpNanProp>(loc, elemTy, operands[0][0],
operands[0][1])};
}
// Handle workaround for NaN propagation, i.e. software emulation of NaN
// propagation. If any of the operands is NaN, return NaN.
auto lhs = operands[0][0];
auto rhs = operands[0][1];
auto lhsIsNan =
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::une, lhs, lhs);
auto rhsIsNan =
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::une, rhs, rhs);
auto isNan = rewriter.create<LLVM::OrOp>(loc, lhsIsNan, rhsIsNan);
auto nonNanRes = rewriter.create<DestOpNoNanProp>(loc, elemTy, lhs, rhs);

auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy);

// Select the result based on the isNan flag.
return {rewriter.create<LLVM::SelectOp>(loc, isNan, nan, nonNanRes)};
}

private:
bool hwNanPropagationSupported;
};
} // namespace

void mlir::triton::populateMinMaxFOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported,
PatternBenefit benefit) {
patterns.add<MinMaxFOpConversion<arith::MinimumFOp>>(
typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit);
patterns.add<MinMaxFOpConversion<arith::MaximumFOp>>(
typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit);
}

void mlir::triton::populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
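For context on the emulation path above: LLVM's minnum/maxnum intrinsics return the non-NaN operand when exactly one input is NaN, so NaN propagation has to be reconstructed around them, and fcmp une x, x is true precisely when x is NaN (NaN compares unordered to itself). A minimal scalar sketch of the semantics the lowering produces — plain C++ for illustration, not part of the commit:

#include <cmath>

// Mirrors MinMaxFOpConversion's software path: NaN wins if either input
// is NaN; otherwise defer to the non-propagating min (LLVM::MinNumOp,
// which behaves like fmin here).
float minimumf_emulated(float lhs, float rhs) {
  bool lhsIsNan = lhs != lhs;            // fcmp une lhs, lhs
  bool rhsIsNan = rhs != rhs;            // fcmp une rhs, rhs
  bool isNan = lhsIsNan || rhsIsNan;     // LLVM::OrOp
  float nonNanRes = std::fmin(lhs, rhs); // LLVM::MinNumOp
  return isNan ? NAN : nonNanRes;        // LLVM::SelectOp over a NaN constant
}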
24 changes: 2 additions & 22 deletions python/test/unit/language/test_core.py
@@ -917,8 +917,6 @@ def kernel():

@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16])
def test_transpose(dtype_x, device):
if is_hip():
pytest.skip('test_transpose not supported on HIP.')
SIZE = 128

@triton.jit
@@ -1503,8 +1501,6 @@ def kernel(in_out_ptr):


def test_join(device):
if is_hip():
pytest.skip("test_join not supported on HIP")

@triton.jit
def kernel(X, Y, Z, N: tl.constexpr):
@@ -1524,8 +1520,6 @@ def kernel(X, Y, Z, N: tl.constexpr):


def test_join_scalars(device):
if is_hip():
pytest.skip("test_join not supported on HIP")

@triton.jit
def kernel(X, Y, Z):
@@ -1544,8 +1538,6 @@ def kernel(X, Y, Z):


def test_join_with_mma(device):
if is_hip():
pytest.skip("test_join_with_mma not supported on HIP")

@triton.jit
def kernel(X, Z):
@@ -1599,8 +1591,6 @@ def kernel(X, Y, Z):


def test_split(device):
if is_hip():
pytest.skip("test_split not supported on HIP")

@triton.jit
def kernel(X, Z1, Z2, N: tl.constexpr):
@@ -1622,8 +1612,6 @@ def kernel(X, Z1, Z2, N: tl.constexpr):


def test_split_to_scalar(device):
if is_hip():
pytest.skip("test_split not supported on HIP")

@triton.jit
def kernel(X, Z1, Z2):
@@ -2652,8 +2640,6 @@ def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constex
@pytest.mark.parametrize("shape", [(2, 4), (16, 16)])
@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1])))
def test_trans_2d(dtype_str, shape, perm, device):
if is_hip():
pytest.skip('test_trans_2d for HIP currently broken')

@triton.jit
def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr,
@@ -2676,8 +2662,6 @@ def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1:
@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)])
@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3])))
def test_trans_4d(dtype_str, shape, perm, device):
if is_hip():
pytest.skip('test_trans_4d for HIP currently broken')

@triton.jit
def kernel(In, Out, #
@@ -3650,8 +3634,6 @@ def kernel():


def test_trans_reshape(device):
if is_hip():
pytest.skip('test_trans_reshape not supported on HIP.')

@triton.jit
def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr):
@@ -4803,10 +4785,8 @@ def mul_add(data):
@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL'])
@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp'])
def test_propagate_nan(dtype, propagate_nan, func, device):
if is_hip():
pytest.skip(
'test_propagate_nan for HIP currently broken in https://github.com/openai/triton. Use https://github.com/ROCmSoftwarePlatform/triton'
)
if is_hip() and func == 'clamp':
pytest.skip('test_propagate_nan is not supported for clamp in HIP')

@triton.jit
def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr):
@@ -1574,5 +1574,8 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
mlir::triton::populateElementwiseOpToLLVMPatterns(typeConverter, patterns,
axisInfoAnalysis, benefit);
mlir::triton::populateMinMaxFOpToLLVMPattern(
typeConverter, patterns, axisInfoAnalysis,
true /*hwNanPropagationSupported*/, benefit);
}
} // namespace AMD
@@ -887,62 +887,6 @@ struct ExpOpConversionApprox
}
};

template <typename OpTy>
struct MinMaxFOpConversion
: ElementwiseOpConversionBase<OpTy, MinMaxFOpConversion<OpTy>> {
using Base = ElementwiseOpConversionBase<OpTy, MinMaxFOpConversion<OpTy>>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static_assert(std::is_same<OpTy, arith::MinimumFOp>::value ||
std::is_same<OpTy, arith::MaximumFOp>::value,
"OpTy must be arith::MinimumFOp or arith::MaximumFOp");

// Choose the destination op based on the OpTy.
using DestOpNanProp =
typename std::conditional<std::is_same<OpTy, arith::MinimumFOp>::value,
LLVM::MinimumOp, LLVM::MaximumOp>::type;
using DestOpNoNanProp =
typename std::conditional<std::is_same<OpTy, arith::MinimumFOp>::value,
LLVM::MinNumOp, LLVM::MaxNumOp>::type;

explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
int computeCapability,
PatternBenefit benefit = 1)
: Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
benefit),
computeCapability(computeCapability) {}

SmallVector<Value> createDestOps(OpTy op, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
if (computeCapability >= 80) {
return {rewriter.create<DestOpNanProp>(loc, elemTy, operands[0][0],
operands[0][1])};
}
// Handle pre-80 compute capability.
// If any of the operands is NaN, return NaN.
auto lhs = operands[0][0];
auto rhs = operands[0][1];
auto lhsIsNan =
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::une, lhs, lhs);
auto rhsIsNan =
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::une, rhs, rhs);
auto isNan = rewriter.create<LLVM::OrOp>(loc, lhsIsNan, rhsIsNan);
auto nonNanRes = rewriter.create<DestOpNoNanProp>(loc, elemTy, lhs, rhs);

auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy);

// Select the result based on the isNan flag.
return {rewriter.create<LLVM::SelectOp>(loc, isNan, nan, nonNanRes)};
}

private:
int computeCapability;
};

struct ClampFOpConversion
: ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion> {
using Base = ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion>;
@@ -1146,8 +1090,7 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ClampFOpConversion>(typeConverter, axisInfoAnalysis,
computeCapability, benefit);
patterns.add<MinMaxFOpConversion<arith::MinimumFOp>>(
typeConverter, axisInfoAnalysis, computeCapability, benefit);
patterns.add<MinMaxFOpConversion<arith::MaximumFOp>>(
typeConverter, axisInfoAnalysis, computeCapability, benefit);
mlir::triton::populateMinMaxFOpToLLVMPattern(
typeConverter, patterns, axisInfoAnalysis,
computeCapability >= 80 /*hwNanPropagationSupported*/, benefit);
}
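With the pattern now living in the shared TritonGPUToLLVM lib, each backend only reports whether its hardware propagates NaN through min/max: the AMD path above passes true, while NVIDIA gates it on compute capability 80 (sm_80+, where NaN-propagating min/max instructions are available in PTX). A hypothetical third backend would wire it the same way; a sketch under that assumption — the function name below is illustrative, not from the commit:

void populateMyBackendElementwiseOpToLLVMPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
  // Pass false when the target lacks NaN-propagating min/max; the shared
  // pattern then falls back to the compare-and-select emulation above.
  mlir::triton::populateMinMaxFOpToLLVMPattern(
      typeConverter, patterns, axisInfoAnalysis,
      /*hwNanPropagationSupported=*/false, benefit);
}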
