-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[triton][tool] A CLI Tool for Tensor Layout Printing (#4486)
A CLI tool to print the layout of a tensor. Currently, only triton_gpu's `DistributedEncoding` (no `SharedEncoding`) tensor layout print is supported via the exposed `getLayoutStr` API from the dialect library. In the future, we could also add more tensor layout print from other backend HW targets (e.g., CPU). Example usage: ``` triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view ``` An input file usually looks like: ``` #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> ``` The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [] I have not added any `lit` tests. 
- [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: Yuanwei Fang <[email protected]>
- Loading branch information
Showing
3 changed files
with
294 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
#include "mlir/AsmParser/AsmParser.h" | ||
#include "mlir/AsmParser/AsmParserState.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
|
||
#include "triton/Dialect/TritonGPU/IR/Dialect.h" | ||
|
||
#include "llvm/Support/CommandLine.h" | ||
#include "llvm/Support/ErrorOr.h" | ||
#include "llvm/Support/FileSystem.h" | ||
#include "llvm/Support/MemoryBuffer.h" | ||
#include "llvm/Support/SourceMgr.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
||
using namespace llvm; | ||
using namespace mlir; | ||
|
||
// A CLI tool to print the layout of a tensor.
//
// clang-format off
// Example usage:
//
//   triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
//
//   triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
//
//   triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
//
// An input file usually looks like:
// '''
// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
// '''
// clang-format on

//===--------------------------------------------------------------------===//
// CLI options
//===--------------------------------------------------------------------===//

// Category grouping every option of this tool so that -help (after
// cl::HideUnrelatedOptions) only shows flags relevant to layout printing.
cl::OptionCategory PrinterCategory("Available Print Options",
                                   "Options for the tensor layout printing.");

// -i: optional .mlir file whose attribute alias definitions
// (e.g. `#blocked = #triton_gpu.blocked<...>`) supply the layouts to print.
static cl::opt<std::string> InputFile(
    "i", cl::desc("File that contains the tensor data layout attributes"),
    cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory));

// -o: optional output file; when empty the result goes to stdout (see main).
static cl::opt<std::string>
    OutputFile("o", cl::desc("Output file to write the layout into"),
               cl::init(""), cl::value_desc("filename"),
               cl::cat(PrinterCategory));

// -l: a single layout attribute given inline on the command line,
// as an alternative (or in addition) to -i.
static cl::opt<std::string>
    DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"),
                  cl::value_desc("layout-string"), cl::init(""),
                  cl::cat(PrinterCategory));

// -alias-names: restrict printing to these alias names from the input file;
// when empty, every alias definition found in the file is printed.
static cl::list<std::string>
    AliasName("alias-names",
              cl::desc("A list of alias names (separated by comma) of the "
                       "layout attributes in the input file"),
              cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated,
              cl::ZeroOrMore, cl::cat(PrinterCategory));

// -use-hw-view: forwarded to getLayoutStr to choose between the hardware
// (warp-centric) view and the default tensor-centric view.
static cl::opt<bool> UseHWPointOfView(
    "use-hw-view",
    llvm::cl::desc(
        "Print the layout in hardware point of view. This means the output is "
        "from the warp's perspective. Otherwise, the output is from the "
        "tensor's perspective (e.g., each element maps to xxx thread)."),
    cl::init(false), cl::cat(PrinterCategory));

// -t: the tensor type (shape + element type) the layouts are applied to.
// Required; validated in main.
static cl::opt<std::string> TensorStr(
    "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"),
    cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory));
|
||
//===--------------------------------------------------------------------===// | ||
// Helper functions | ||
//===--------------------------------------------------------------------===// | ||
|
||
LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { | ||
StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace(); | ||
|
||
// Dispatch to the corresponding dialect helper function to print the layout. | ||
if (dialectName == "triton_gpu") { | ||
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); | ||
return success(); | ||
} | ||
|
||
llvm::errs() << "Unsupported tensor layout attribute: " | ||
<< tensorType.getEncoding() << "\n"; | ||
return failure(); | ||
} | ||
|
||
/// Parse the MLIR file \p filename, look up its attribute alias definitions,
/// and print the layout of each selected alias applied to \p tensorTy.
///
/// \p names selects which aliases to print; when empty, every alias
/// definition in the file is printed. An empty \p filename is a no-op
/// success so the -i flag stays optional.
LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
                                  ArrayRef<std::string> names,
                                  TensorType tensorTy, raw_string_ostream &ss) {
  if (filename.empty())
    return success();

  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> bufferOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(filename);
  if (std::error_code ec = bufferOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
    return failure();
  }

  llvm::SourceMgr srcMgr;
  srcMgr.AddNewSourceBuffer(std::move(*bufferOrErr), llvm::SMLoc());
  ParserConfig parseConfig(context);
  // The parser state records the `#name = #attr` alias definitions we query.
  AsmParserState parserState;

  Block topLevelBlock;
  if (failed(
          parseAsmSourceFile(srcMgr, &topLevelBlock, parseConfig, &parserState))) {
    llvm::errs() << "Fail to parse the input file: " << filename << "\n";
    return failure();
  }

  // Attach `attr` as the encoding of the requested tensor shape and print it.
  auto emitOne = [&](StringRef aliasName, Attribute attr) {
    ss << "Print layout attribute: #" << aliasName << " = " << attr << "\n";
    auto typedTensor = RankedTensorType::get(tensorTy.getShape(),
                                             tensorTy.getElementType(), attr);
    return layoutPrint(typedTensor, ss);
  };

  if (names.empty()) {
    // No aliases requested: print every alias definition found in the file.
    for (const auto &def : parserState.getAttributeAliasDefs())
      if (failed(emitOne(def.name, def.value)))
        return failure();
    return success();
  }

  // Print only the requested aliases; a missing alias is an error.
  for (const auto &aliasName : names) {
    auto def = parserState.getAttributeAliasDef(aliasName);
    if (!def) {
      llvm::errs() << "Can't find the layout attribute: " << aliasName << "\n";
      return failure();
    }

    if (failed(emitOne(aliasName, def->value)))
      return failure();

    ss << "\n";
  }

  return success();
}
|
||
/// Parse \p layoutAttrStr (the -l option) as an MLIR attribute, attach it as
/// the encoding of \p tensorTy's shape, and print the layout into \p ss.
///
/// An empty string is a no-op success so the -l flag stays optional.
LogicalResult printLayoutFromString(MLIRContext *context,
                                    StringRef layoutAttrStr,
                                    TensorType tensorTy,
                                    raw_string_ostream &ss) {
  if (layoutAttrStr.empty())
    return success();

  Attribute parsedAttr = parseAttribute(layoutAttrStr, context);
  if (!parsedAttr) {
    llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
    return failure();
  }

  ss << "Print layout attribute: " << parsedAttr << "\n";

  auto typedTensor = RankedTensorType::get(
      tensorTy.getShape(), tensorTy.getElementType(), parsedAttr);
  return layoutPrint(typedTensor, ss);
}
|
||
//===--------------------------------------------------------------------===// | ||
// Main entry point | ||
//===--------------------------------------------------------------------===// | ||
|
||
int main(int argc, char **argv) { | ||
cl::HideUnrelatedOptions(PrinterCategory); | ||
cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); | ||
|
||
DialectRegistry registry; | ||
// Register all dialects that can print tensor layout. | ||
registry.insert<triton::gpu::TritonGPUDialect>(); | ||
|
||
MLIRContext ctx(registry); | ||
ctx.loadAllAvailableDialects(); | ||
|
||
if (TensorStr.empty()) { | ||
llvm::errs() << "Must specify the tensor type argument\n"; | ||
return 1; | ||
} | ||
|
||
Type parsedTy = parseType(TensorStr, &ctx); | ||
if (!parsedTy) { | ||
llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr | ||
<< "\n"; | ||
return 1; | ||
} | ||
|
||
TensorType tensorType = dyn_cast<TensorType>(parsedTy); | ||
if (!tensorType) { | ||
llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n"; | ||
return 1; | ||
} | ||
|
||
std::string storage; | ||
raw_string_ostream ss(storage); | ||
|
||
if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss))) | ||
return 1; | ||
|
||
if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss))) | ||
return 1; | ||
|
||
if (OutputFile.empty()) { | ||
llvm::outs() << ss.str(); | ||
} else { | ||
std::error_code ec; | ||
llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text); | ||
if (ec) { | ||
llvm::errs() << "Error: " << ec.message() << " : unable to open " | ||
<< OutputFile << " for output\n"; | ||
return 1; | ||
} | ||
outFs << ss.str(); | ||
outFs.close(); | ||
} | ||
|
||
return 0; | ||
} |
Oops, something went wrong.