[triton][tool] A CLI Tool for Tensor Layout Printing (#4486)
A CLI tool to print the layout of a tensor. Currently, only the `triton_gpu`
dialect's `DistributedEncoding` layouts (not `SharedEncoding`) can be printed,
via the `getLayoutStr` API exposed by the dialect library. In the future, we
could add layout printing for other backend HW targets (e.g., CPU).

Example usage:

```
triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"

triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt

triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
```
An input file usually looks like:
```
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
```


The core Triton team is a small number of people, and we receive many PRs (thank
you!). To help us review your code more quickly, **if you are a new contributor
(fewer than 3 PRs merged) we ask that you complete the following tasks and
include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with
`[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Yuanwei Fang <[email protected]>
fywkevin and Yuanwei Fang authored Aug 15, 2024
1 parent 1a20556 commit 45af9a9
Showing 3 changed files with 294 additions and 0 deletions.
7 changes: 7 additions & 0 deletions bin/CMakeLists.txt
@@ -90,3 +90,10 @@ target_link_libraries(triton-llvm-opt PRIVATE
  LLVMCodeGen
)
export_executable_symbols_for_plugins(triton-llvm-opt)


add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
target_link_libraries(triton-tensor-layout PRIVATE
  TritonGPUIR
  ${triton_libs}
)
229 changes: 229 additions & 0 deletions bin/triton-tensor-layout.cpp
@@ -0,0 +1,229 @@
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/MLIRContext.h"

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace mlir;

// A CLI tool to print the layout of a tensor.
//
// clang-format off
// Example usage:
//
// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
//
// An input file usually looks like:
// '''
// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
// '''
// clang-format on

//===--------------------------------------------------------------------===//
// CLI options
//===--------------------------------------------------------------------===//

cl::OptionCategory PrinterCategory("Available Print Options",
                                   "Options for the tensor layout printing.");

static cl::opt<std::string> InputFile(
    "i", cl::desc("File that contains the tensor data layout attributes"),
    cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory));

static cl::opt<std::string>
    OutputFile("o", cl::desc("Output file to write the layout into"),
               cl::init(""), cl::value_desc("filename"),
               cl::cat(PrinterCategory));

static cl::opt<std::string>
    DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"),
                  cl::value_desc("layout-string"), cl::init(""),
                  cl::cat(PrinterCategory));

static cl::list<std::string>
    AliasName("alias-names",
              cl::desc("A list of alias names (separated by comma) of the "
                       "layout attributes in the input file"),
              cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated,
              cl::ZeroOrMore, cl::cat(PrinterCategory));

static cl::opt<bool> UseHWPointOfView(
    "use-hw-view",
    llvm::cl::desc(
        "Print the layout in hardware point of view. This means the output is "
        "from the warp's perspective. Otherwise, the output is from the "
        "tensor's perspective (e.g., each element maps to xxx thread)."),
    cl::init(false), cl::cat(PrinterCategory));

static cl::opt<std::string> TensorStr(
    "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"),
    cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory));

//===--------------------------------------------------------------------===//
// Helper functions
//===--------------------------------------------------------------------===//

LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
  StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();

  // Dispatch to the corresponding dialect helper function to print the layout.
  if (dialectName == "triton_gpu") {
    os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
    return success();
  }

  llvm::errs() << "Unsupported tensor layout attribute: "
               << tensorType.getEncoding() << "\n";
  return failure();
}

LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
                                  ArrayRef<std::string> names,
                                  TensorType tensorTy, raw_string_ostream &ss) {
  if (filename.empty())
    return success();

  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(filename);
  if (std::error_code ec = fileOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
    return failure();
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
  ParserConfig config(context);
  auto asmState = AsmParserState();

  Block parsedIR;
  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
    llvm::errs() << "Failed to parse the input file: " << filename << "\n";
    return failure();
  }

  auto printLambda = [&](StringRef name, Attribute attr) {
    ss << "Print layout attribute: #" << name << " = " << attr << "\n";

    auto rankedTensorTy = RankedTensorType::get(
        tensorTy.getShape(), tensorTy.getElementType(), attr);

    return layoutPrint(rankedTensorTy, ss);
  };

  if (names.empty()) {
    // If no alias name is given, we print all layout attributes in the file.
    for (auto def : asmState.getAttributeAliasDefs()) {
      if (failed(printLambda(def.name, def.value)))
        return failure();
    }
  } else {
    // Print the layout attributes with the given alias names.
    for (auto alias : names) {
      auto def = asmState.getAttributeAliasDef(alias);
      if (!def) {
        llvm::errs() << "Can't find the layout attribute: " << alias << "\n";
        return failure();
      }

      if (failed(printLambda(alias, def->value)))
        return failure();

      ss << "\n";
    }
  }

  return success();
}

LogicalResult printLayoutFromString(MLIRContext *context,
                                    StringRef layoutAttrStr,
                                    TensorType tensorTy,
                                    raw_string_ostream &ss) {
  if (layoutAttrStr.empty())
    return success();

  Attribute layout = parseAttribute(layoutAttrStr, context);
  if (!layout) {
    llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
    return failure();
  }

  auto rankedTensorTy = RankedTensorType::get(
      tensorTy.getShape(), tensorTy.getElementType(), layout);

  ss << "Print layout attribute: " << layout << "\n";

  return layoutPrint(rankedTensorTy, ss);
}

//===--------------------------------------------------------------------===//
// Main entry point
//===--------------------------------------------------------------------===//

int main(int argc, char **argv) {
  cl::HideUnrelatedOptions(PrinterCategory);
  cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

  DialectRegistry registry;
  // Register all dialects that can print tensor layout.
  registry.insert<triton::gpu::TritonGPUDialect>();

  MLIRContext ctx(registry);
  ctx.loadAllAvailableDialects();

  if (TensorStr.empty()) {
    llvm::errs() << "Must specify the tensor type argument\n";
    return 1;
  }

  Type parsedTy = parseType(TensorStr, &ctx);
  if (!parsedTy) {
    llvm::errs() << "Failed to parse the tensor type argument: " << TensorStr
                 << "\n";
    return 1;
  }

  TensorType tensorType = dyn_cast<TensorType>(parsedTy);
  if (!tensorType) {
    llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n";
    return 1;
  }

  std::string storage;
  raw_string_ostream ss(storage);

  if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss)))
    return 1;

  if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss)))
    return 1;

  if (OutputFile.empty()) {
    llvm::outs() << ss.str();
  } else {
    std::error_code ec;
    llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text);
    if (ec) {
      llvm::errs() << "Error: " << ec.message() << " : unable to open "
                   << OutputFile << " for output\n";
      return 1;
    }
    outFs << ss.str();
    outFs.close();
  }

  return 0;
}