Skip to content

Commit

Permalink
[LLVMGPU] Add debug print for contraction problem size. NFC. (iree-or…
Browse files Browse the repository at this point in the history
…g#17845)

Signed-off-by: Jakub Kuderski <[email protected]>
  • Loading branch information
kuhar authored Jul 10, 2024
1 parent 78c0051 commit 534928d
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
}

[[maybe_unused]] static void
debugPrintContractionInfo(unsigned numLoops,
debugPrintContractionInfo(StringRef label, unsigned numLoops,
linalg::ContractionDimensions contractionDims,
ArrayRef<int64_t> workgroupTileSizes) {
ArrayRef<int64_t> sizes) {
ArrayRef<unsigned> dimVals[] = {contractionDims.batch, contractionDims.m,
contractionDims.n, contractionDims.k};
std::string dimSymbols(numLoops, '*');
Expand All @@ -395,8 +395,8 @@ debugPrintContractionInfo(unsigned numLoops,
llvm::interleaveComma(dimSymbols, llvm::dbgs());
llvm::dbgs() << "]\n";

DBGS() << "Workgroup tile sizes: [";
llvm::interleaveComma(workgroupTileSizes, llvm::dbgs());
DBGS() << label << ": [";
llvm::interleaveComma(sizes, llvm::dbgs());
llvm::dbgs() << "]\n";
}

Expand All @@ -419,6 +419,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
return failure();
}

LLVM_DEBUG(debugPrintContractionInfo("Problem size", op.getNumLoops(),
*contractionDims, bounds));

// For now we are not being smart and trying to reshape dimensions to allow
// for better usage of intrinsics, and instead are tiling all dimensions
// except the inner most m, n, and k dimensions to 1.
Expand Down Expand Up @@ -577,8 +580,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
// Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
workgroupTileSizes[kDim] = schedule->kTileCount * schedule->kSize;

LLVM_DEBUG(debugPrintContractionInfo(op.getNumLoops(), *contractionDims,
workgroupTileSizes));
LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(),
*contractionDims, workgroupTileSizes));

TileSizesListType tileSizes;
tileSizes.push_back(workgroupTileSizes);
Expand Down

0 comments on commit 534928d

Please sign in to comment.