Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clang tidy integration #172

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions .clang-tidy
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# NOTE(review): '#' comments inside a YAML literal block scalar are part of the
# Checks string, not YAML comments. Recent clang-tidy versions tolerate them in
# the glob list — confirm against the clang-tidy version used in CI.
Checks: |
  bugprone-*, # Enable checks for common bugs
  -bugprone-unchecked-optional-access, # This one crashes on our code
  -bugprone-reserved-identifier, # We have a few false-positives here
  -bugprone-easily-swappable-parameters, # Maybe in the future
  -bugprone-casting-through-void,
  -bugprone-narrowing-conversions, # Maybe in the future
  cert-*, # Enable CERT coding standard checks
  -cert-dcl37-c,
  -cert-dcl51-cpp,
  clang-analyzer-*,
  -clang-analyzer-osx*, # Not needed
  concurrency-*,
  cppcoreguidelines-*,
  -cppcoreguidelines-special-member-functions,
  -cppcoreguidelines-non-private-member-variables-in-classes, # Maybe in the future
  -cppcoreguidelines-avoid-magic-numbers, # Maybe in the future
  -cppcoreguidelines-avoid-c-arrays,
  -cppcoreguidelines-avoid-do-while, # LLVM macros have these inside
  -cppcoreguidelines-pro-type-const-cast,
  -cppcoreguidelines-slicing, # Unsure if this is a false-positive or not
  -cppcoreguidelines-narrowing-conversions,
  hicpp-*,
  -hicpp-braces-around-statements, # Maybe in the future
  -hicpp-special-member-functions,
  -hicpp-avoid-c-arrays,
  -hicpp-uppercase-literal-suffix,
  llvm-*,
  misc-*,
  -misc-include-cleaner,
  -misc-non-private-member-variables-in-classes,
  -misc-no-recursion,
  -misc-use-anonymous-namespace, # Maybe in the future
  modernize-*,
  -modernize-use-nodiscard,
  -modernize-use-trailing-return-type,
  -modernize-avoid-c-arrays, # Maybe in the future
  -modernize-return-braced-init-list, # False positives
  performance-*,
  readability-*,
  -readability-convert-member-functions-to-static, # Maybe in the future
  -readability-uppercase-literal-suffix,
  -readability-avoid-const-params-in-decls,
  -readability-magic-numbers,
  -readability-function-cognitive-complexity,
  -readability-braces-around-statements,
  -readability-identifier-length, # This one is too strict

WarningsAsErrors: '*'
# clang-tidy compiles this with llvm::Regex (POSIX ERE), which does NOT support
# lookbehind assertions such as (?<!\.h\.inc) — that pattern is rejected as an
# invalid regex. Anchoring on '\.h$' matches hand-written headers while
# naturally excluding generated TableGen output, which ends in '.h.inc'.
HeaderFilterRegex: '.*/triton_shared/include/.*\.h$'
FormatStyle: none
User: ''
CheckOptions:
  - key: readability-identifier-naming.ClassCase
    value: CamelCase
  - key: readability-identifier-naming.FunctionCase
    value: camelBack
10 changes: 8 additions & 2 deletions .github/workflows/test-plugin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,15 @@ jobs:
python3 -m pip install --upgrade pip
python3 -m pip install cmake==3.24 ninja pytest-xdist
sudo apt-get update -y
sudo apt-get install -y ccache clang lld
sudo apt-get install -y ccache clang lld bear clang-tidy
export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared"
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true bear -- pip install --no-build-isolation -vvv '.[tests]'

- name: Run clang-tidy
working-directory: triton_shared/
run: |
python3 filter_compile_commands.py
run-clang-tidy

- name: Run shared middle-layer lit tests
working-directory: triton_shared/triton/python
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ sudo apt-get install -y ccache clang lld
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'
```

To run with clang-tidy:

```sh
sudo apt install -y bear clang-tidy
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true bear -- pip install --no-build-isolation -vvv '.[tests]'
cd ../../
python filter_compile_commands.py
run-clang-tidy
```

To build with a virtualenv:

```
Expand Down
24 changes: 24 additions & 0 deletions filter_compile_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# This script trims compile_commands.json so it contains no files from the
# triton/ submodule, to avoid running clang-tidy on code we do not own.
import json
import os

input_file = 'triton/python/compile_commands.json'
output_file = 'compile_commands.json'

def filter_compile_commands(input_file, output_file,
                            exclude_substring='triton_shared/triton'):
    """Copy a compilation database, dropping entries from the triton submodule.

    Reads ``input_file`` (a compile_commands.json compilation database),
    removes every entry whose ``'file'`` path contains ``exclude_substring``,
    and writes the remaining entries to ``output_file``.

    Args:
        input_file: Path of the compile_commands.json to filter.
        output_file: Path the filtered database is written to.
        exclude_substring: Entries whose ``'file'`` value contains this
            substring are dropped. Defaults to the triton submodule path,
            preserving the original behavior.
    """
    with open(input_file, 'r') as f:
        compile_commands = json.load(f)

    # .get('file', '') keeps a malformed entry without a 'file' key from
    # raising KeyError; such entries are retained rather than crashing.
    filtered_commands = [
        entry for entry in compile_commands
        if exclude_substring not in entry.get('file', '')
    ]

    with open(output_file, 'w') as f:
        json.dump(filtered_commands, f, indent=2)

    print(f"Filtered compile_commands.json written to {output_file} with {len(filtered_commands)} entries.")

if __name__ == "__main__":
    # Script entry point: filter using the module-level default paths.
    filter_compile_commands(input_file, output_file)
2 changes: 1 addition & 1 deletion include/triton-shared/Analysis/PtrAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class PtrAnalysis {
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

static void
visitOperandRem(arith::RemSIOp mulOp, PtrState &state, const Location loc,
visitOperandRem(arith::RemSIOp remOp, PtrState &state, const Location loc,
ConversionPatternRewriter &rewriter,
const llvm::SmallDenseMap<Value, PtrState> &knownPtrs);

Expand Down
8 changes: 4 additions & 4 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class PtrAnalysis {
// result: its strides and offsets have to point to the corresponding stride
// and offset values returned by the loop.
PtrState reconcileLoopPtrState(
scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state,
scf::ForOp forOp, size_t iterArgIndex, const PtrState &state,
llvm::function_ref<Value(scf::ForOp op, size_t)> getReplacementVal);

public:
Expand Down Expand Up @@ -139,7 +139,7 @@ class PtrAnalysis {
LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state,
const Location loc, OpBuilder &builder);

LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrState &state,
LogicalResult visitOperandRem(arith::RemSIOp remOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of make_range.
Expand Down Expand Up @@ -254,9 +254,9 @@ class PtrAnalysis {
// strides, offsets, and modulos.
LogicalResult rewriteForOp(scf::ForOp op);

LogicalResult rewriteLoadOp(triton::LoadOp op);
LogicalResult rewriteLoadOp(triton::LoadOp op) const;

LogicalResult rewriteStoreOp(triton::StoreOp op);
LogicalResult rewriteStoreOp(triton::StoreOp op) const;

LogicalResult rewriteOp(Operation *op);
};
Expand Down
94 changes: 50 additions & 44 deletions lib/Analysis/MaskAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,48 @@

#include "mlir/Transforms/DialectConversion.h"

namespace mlir {

namespace triton {

namespace mlir::triton {

LogicalResult MaskState::parse(Value operand, const Location loc,
OpBuilder &builder) {
if (auto op = operand.getDefiningOp<arith::ConstantOp>()) {
return this->parseConstant(op, loc, builder);
} else if (isa<IntegerType>(operand.getType())) {
}
if (isa<IntegerType>(operand.getType())) {
return this->parseIntScalar(operand, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::AddIOp>()) {
}
if (auto op = operand.getDefiningOp<arith::AddIOp>()) {
return this->parseAdd(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::AndIOp>()) {
}
if (auto op = operand.getDefiningOp<arith::AndIOp>()) {
return this->parseAnd(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::CmpIOp>()) {
}
if (auto op = operand.getDefiningOp<arith::CmpIOp>()) {
return this->parseCmp(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::MakeRangeOp>()) {
}
if (auto op = operand.getDefiningOp<triton::MakeRangeOp>()) {
return this->parseMakeRange(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::BroadcastOp>()) {
}
if (auto op = operand.getDefiningOp<triton::BroadcastOp>()) {
return this->parseBroadcast(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
}
if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
return this->parseSplat(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::ExpandDimsOp>()) {
}
if (auto op = operand.getDefiningOp<triton::ExpandDimsOp>()) {
return this->parseExpandDims(op, loc, builder);
} else {
return failure();
}
return failure();
}

tensor::ExtractSliceOp MaskState::getExtractSlice(Value source,
const Location loc,
OpBuilder &builder) const {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
SmallVector<OpFoldResult> const offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> const strides(getRank(), builder.getIndexAttr(1));

auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets,
dims, strides);
Expand All @@ -58,8 +65,8 @@ tensor::ExtractSliceOp MaskState::getExtractSlice(Value source,
memref::SubViewOp MaskState::getSubview(Value source, const Location loc,
OpBuilder &builder) const {
auto sourceType = cast<MemRefType>(source.getType());
SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
SmallVector<OpFoldResult> const offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> const strides(getRank(), builder.getIndexAttr(1));
auto dstType =
memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides);

Expand Down Expand Up @@ -136,14 +143,14 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b,
std::pair<memref::SubViewOp, memref::SubViewOp>
MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc,
OpBuilder &builder) const {
OpFoldResult subviewRowFull = dims[0];
OpFoldResult subviewColFull = dims[1];
OpFoldResult col1 = builder.create<memref::DimOp>(loc, block1, 1).getResult();
OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, builder);
OpFoldResult subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, builder);

SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
OpFoldResult const subviewRowFull = dims[0];
OpFoldResult const subviewColFull = dims[1];
OpFoldResult const col1 = builder.create<memref::DimOp>(loc, block1, 1).getResult();
OpFoldResult const subviewCol1 = minOFRs(col1, subviewColFull, loc, builder);
OpFoldResult const subviewCol2 = subOFRs(subviewColFull, subviewCol1, loc, builder);

SmallVector<OpFoldResult> const offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> const strides(getRank(), builder.getIndexAttr(1));
auto sv1 = createSubview(block1, loc, builder, offsets,
{subviewRowFull, subviewCol1}, strides);
auto sv2 = createSubview(block2, loc, builder, offsets,
Expand All @@ -155,14 +162,14 @@ MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc,
std::pair<memref::SubViewOp, memref::SubViewOp>
MaskState::getStackedSubviews(Value block1, Value block2, const Location loc,
OpBuilder &builder) const {
OpFoldResult subviewRowFull = dims[0];
OpFoldResult subviewColFull = dims[1];
OpFoldResult row1 = builder.create<memref::DimOp>(loc, block1, 0).getResult();
OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder);
OpFoldResult subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, builder);

SmallVector<OpFoldResult> offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> strides(getRank(), builder.getIndexAttr(1));
OpFoldResult const subviewRowFull = dims[0];
OpFoldResult const subviewColFull = dims[1];
OpFoldResult const row1 = builder.create<memref::DimOp>(loc, block1, 0).getResult();
OpFoldResult const subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder);
OpFoldResult const subviewRow2 = subOFRs(subviewRowFull, subviewRow1, loc, builder);

SmallVector<OpFoldResult> const offsets(getRank(), builder.getIndexAttr(0));
SmallVector<OpFoldResult> const strides(getRank(), builder.getIndexAttr(1));
auto sv1 = createSubview(block1, loc, builder, offsets,
{subviewRow1, subviewColFull}, strides);
auto sv2 = createSubview(block2, loc, builder, offsets,
Expand All @@ -183,29 +190,28 @@ LogicalResult MaskState::addStates(const MaskState &lhsState,
const MaskState &rhsState, Location loc,
OpBuilder &builder) {
if (lhsState.scalar && rhsState.scalar) {
InFlightDiagnostic diag =
InFlightDiagnostic const diag =
emitError(loc) << "Unexpected case where both lhs and rhs are scalars";
return failure();
}

if (!lhsState.scalar && !rhsState.scalar) {
InFlightDiagnostic diag =
InFlightDiagnostic const diag =
emitError(loc)
<< "Unsupported scenario where neither lhs nor rhs is a scalar";
return failure();
}

if (lhsState.scalar)
return addStateScalar(rhsState, lhsState.scalar, loc, builder);
else
return addStateScalar(lhsState, rhsState.scalar, loc, builder);
return addStateScalar(lhsState, rhsState.scalar, loc, builder);
}

LogicalResult MaskState::minStates(const MaskState &lhsState,
const MaskState &rhsState, Location loc,
OpBuilder &builder) {
if (lhsState.getRank() != rhsState.getRank()) {
InFlightDiagnostic diag =
InFlightDiagnostic const diag =
emitError(loc)
<< "Unexpected case where lhs and rhs have different ranks";
return failure();
Expand Down Expand Up @@ -288,7 +294,7 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
assert(this->isEmpty());

if (cmpOp.getPredicate() != arith::CmpIPredicate::slt) {
InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi predicate";
InFlightDiagnostic const diag = emitError(loc) << "Unsupported cmpi predicate";
return failure();
}

Expand All @@ -307,7 +313,7 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
auto dimIntAttr = getIntAttr(lhsState.dims[i]);
if (!dimIntAttr || dimIntAttr.value() != 1) {
if (cmpDim != -1) {
InFlightDiagnostic diag = emitError(loc)
InFlightDiagnostic const diag = emitError(loc)
<< "Unsupported cmpi with more than one "
"dimension with size larger than 1";
return failure();
Expand Down Expand Up @@ -342,7 +348,7 @@ LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp,
auto stride = (end - start + shape[0] - 1) / shape[0];

if (stride != 1) {
InFlightDiagnostic diag =
InFlightDiagnostic const diag =
emitError(loc)
<< "stride must be 1 for make_range whose result is used "
"as load or store masks";
Expand Down Expand Up @@ -377,7 +383,7 @@ LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp,
for (size_t i = 0; i < srcShape.size(); i++) {
if (srcShape[i] == dstShape[i])
continue;
else if (srcShape[i] < dstShape[i])
if (srcShape[i] < dstShape[i])
this->dims[i] = builder.getIndexAttr(dstShape[i]);
else
llvm_unreachable("unexpected dimensions used in broadcast");
Expand All @@ -395,7 +401,7 @@ LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
auto dstShape = cast<ShapedType>(dst.getType()).getShape();

if (!isa<IntegerType>(src.getType())) {
InFlightDiagnostic diag =
InFlightDiagnostic const diag =
emitError(loc)
<< "splat source must be an integer scalar for load/store masks";
return failure();
Expand Down Expand Up @@ -428,5 +434,5 @@ LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp,
return success();
}

} // namespace triton
} // namespace mlir
} // namespace mlir::triton

Loading