From 0f08aea6ba91cafc672fd883a907ab03ca623e1a Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sat, 25 Nov 2023 00:46:16 -0600 Subject: [PATCH] Bump LLVM (#779) --- include/aie/Dialect/AIE/IR/AIEDialect.h | 28 +- include/aie/Dialect/AIEVec/AIEVecUtils.h | 68 +- lib/Dialect/AIE/IR/AIEDialect.cpp | 169 ++- .../AIE/Transforms/AIECoreToStandard.cpp | 2 + .../AIE/Transforms/AIECreatePacketFlows.cpp | 2 +- .../Transforms/AIENormalizeAddressSpaces.cpp | 2 +- .../AIEObjectFifoRegisterProcess.cpp | 20 +- .../AIEObjectFifoStatefulTransform.cpp | 9 +- lib/Dialect/AIEVec/IR/AIEVecOps.cpp | 187 ++- .../TransformOps/AIEVecTransformOps.cpp | 42 +- .../AIEVec/Transforms/AIEVectorize.cpp | 277 ++-- .../DynamicSizeNoImplicitBroadcast.cpp | 6 +- .../Transforms/VectorToAIEVecConversions.cpp | 1152 +++++++---------- .../AIEX/Transforms/AIEHerdRouting.cpp | 115 +- lib/Dialect/AIEX/Utils/AIETokenAnalysis.cpp | 6 +- lib/Targets/ADFGenerateCppGraph.cpp | 99 +- lib/Targets/AIETargetXAIEV2.cpp | 78 +- lib/Targets/AIETargets.cpp | 30 +- .../AIEVecToCpp/TranslateAIEVecToCpp.cpp | 769 ++++++----- test/aievec/conv2d_i16_after_polygeist.mlir | 2 +- test/aievec/conv2d_i16_after_polygeist_2.mlir | 2 +- test/aievec/conv2d_i8_after_polygeist.mlir | 2 +- ...gemm64_int16_unroll32_after_polygeist.mlir | 2 +- utils/clone-llvm.sh | 2 +- 24 files changed, 1369 insertions(+), 1702 deletions(-) diff --git a/include/aie/Dialect/AIE/IR/AIEDialect.h b/include/aie/Dialect/AIE/IR/AIEDialect.h index 6aec480b4a..6fe66aea66 100644 --- a/include/aie/Dialect/AIE/IR/AIEDialect.h +++ b/include/aie/Dialect/AIE/IR/AIEDialect.h @@ -28,8 +28,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" -namespace xilinx { -namespace AIE { +namespace xilinx::AIE { // Check that the given DMA-like op (e.g. MemOp, ShimDMAOp) // has valid BDs. @@ -47,8 +46,7 @@ struct HasValidDMAChannels }; class TileOp; -} // namespace AIE -} // namespace xilinx +} // namespace xilinx::AIE /// Include the generated interface declarations. #include "aie/Dialect/AIE/IR/AIEInterfaces.h.inc" @@ -56,13 +54,11 @@ class TileOp; // Include dialect declarations such as parseAttributes, parseType #include "aie/Dialect/AIE/IR/AIEDialect.h.inc" -namespace xilinx { -namespace AIE { +namespace xilinx::AIE { void registerAIETranslations(); -} // namespace AIE -} // namespace xilinx +} // namespace xilinx::AIE //////////////////////////////////////////////////////////////////////////////// /////////////////////// Custom Types for the Dialect /////////////////////////// @@ -72,8 +68,7 @@ void registerAIETranslations(); #define GET_TYPEDEF_CLASSES 1 #include "aie/Dialect/AIE/IR/AIETypes.h.inc" -namespace xilinx { -namespace AIE { +namespace xilinx::AIE { namespace detail { struct AIEObjectFifoTypeStorage; } @@ -122,8 +117,7 @@ class AIEObjectFifoSubviewType Type getElementType(); }; -} // namespace AIE -} // namespace xilinx +} // namespace xilinx::AIE //////////////////////////////////////////////////////////////////////////////// // Custom Attributes /////////////////////////////////////////////////////////// @@ -136,8 +130,7 @@ class AIEObjectFifoSubviewType //////////////////// Custom Operations for the Dialect ///////////////////////// //////////////////////////////////////////////////////////////////////////////// -namespace xilinx { -namespace AIE { +namespace xilinx::AIE { #define GENERATE_TO_STRING(TYPE_WITH_INSERTION_OP) \ friend std::string to_string(const TYPE_WITH_INSERTION_OP &s) { \ @@ -224,7 +217,7 @@ parseObjectFifoProducerTile(mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand, DimTupleArrayAttr &dimensions); -void printObjectFifoProducerTile(mlir::OpAsmPrinter &_odsPrinter, +void printObjectFifoProducerTile(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::Value tile, mlir::Attribute dimensions); @@ -233,14 +226,13 @@ mlir::ParseResult parseObjectFifoConsumerTiles( llvm::SmallVector &tiles, DimTupleArrayArrayAttr &dimensions); -void printObjectFifoConsumerTiles(mlir::OpAsmPrinter &_odsPrinter, +void printObjectFifoConsumerTiles(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRange tiles, mlir::Attribute dimensions); uint64_t getBufferBaseAddress(mlir::Operation *bufOp); -} // namespace AIE -} // namespace xilinx +} // namespace xilinx::AIE // include TableGen generated Op definitions #define GET_OP_CLASSES diff --git a/include/aie/Dialect/AIEVec/AIEVecUtils.h b/include/aie/Dialect/AIEVec/AIEVecUtils.h index 95321a33d0..4cafaa3c79 100644 --- a/include/aie/Dialect/AIEVec/AIEVecUtils.h +++ b/include/aie/Dialect/AIEVec/AIEVecUtils.h @@ -54,7 +54,7 @@ inline unsigned getVectorLaneSize(mlir::VectorType type) { assert(type.getRank() > 0 && "Cannot handle rank-0 vectors"); auto dimSize = type.getDimSize(type.getRank() - 1); assert(dimSize >= 0 && "Vector dimension cannot be negative"); - return std::max(1u, (unsigned)dimSize); + return std::max(1u, static_cast(dimSize)); } // For a 1D vector, return its size in bits. For an nD vector, return the size @@ -75,29 +75,30 @@ inline bool isAIEOp(mlir::Operation *op) { inline mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIEML) { mlir::Type stype = type.getElementType(); - if (auto itype = stype.dyn_cast()) { + if (auto itype = llvm::dyn_cast(stype)) { // Integer vector types are sized for the appropriate accumulators assert(itype.getWidth() <= 64); - unsigned width = 0; - if (AIEML) { + unsigned width; + if (AIEML) width = itype.getWidth() <= 16 ? 32 : 64; - } else { + else width = itype.getWidth() <= 16 ? 48 : 80; - } mlir::Type ctype = mlir::IntegerType::get(itype.getContext(), width); return mlir::VectorType::get(type.getShape(), ctype); - } else if (auto ftype = stype.dyn_cast()) { - if (AIEML && ftype.getWidth() == 16) { + } + + if (auto ftype = llvm::dyn_cast(stype)) { + if (AIEML && ftype.getWidth() == 16) return mlir::VectorType::get(type.getShape(), - ftype.getF32(ftype.getContext())); - } + mlir::FloatType::getF32(ftype.getContext())); // Floating point vector types for aie1 are returned as is since the // floating point operations write back to registers and not accumulators return type; - } else - llvm::report_fatal_error("Unsupported destination type"); + } + + llvm::report_fatal_error("Unsupported destination type"); } // Linearize the exprVec as a strided access, but do not simplify @@ -109,24 +110,23 @@ flattenedStridedExpr(llvm::ArrayRef sizes, if (sizes.empty() || exprs.empty()) return nullptr; - if (llvm::is_contained(sizes, 0)) + if (is_contained(sizes, 0)) return getAffineConstantExpr(0, context); auto maps = mlir::AffineMap::inferFromExprList(exprs); - if (maps.empty()) { + if (maps.empty()) return nullptr; - } unsigned nSymbols = maps[0].getNumSymbols(); mlir::AffineExpr expr; bool dynamicPoisonBit = false; int64_t runningSize = 1; - for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { + for (auto en : zip(reverse(exprs), reverse(sizes))) { int64_t size = std::get<1>(en); - if (size == 0) continue; + mlir::AffineExpr dimExpr = std::get<0>(en); mlir::AffineExpr stride = dynamicPoisonBit ? getAffineSymbolExpr(nSymbols++, context) @@ -134,21 +134,17 @@ flattenedStridedExpr(llvm::ArrayRef sizes, expr = expr ? expr + dimExpr * stride : dimExpr * stride; if (size > 0) { runningSize *= size; - if (runningSize <= 0) { + if (runningSize <= 0) return nullptr; - } - } else { + } else dynamicPoisonBit = true; - } } return expr; } // Construct a linearized affine expression for the upd op. -inline mlir::AffineExpr -constructLinearizedAffineExprForUPDOp(aievec::UPDOp updOp) { - mlir::MemRefType memRefType = - updOp.getSource().getType().cast(); +inline mlir::AffineExpr constructLinearizedAffineExprForUPDOp(UPDOp updOp) { + auto memRefType = updOp.getSource().getType().cast(); mlir::MLIRContext *context = memRefType.getContext(); llvm::SmallVector exprVec; @@ -158,12 +154,12 @@ constructLinearizedAffineExprForUPDOp(aievec::UPDOp updOp) { if (auto apOf = value.getDefiningOp()) { mlir::AffineMap map = apOf.getAffineMap(); // Cannot create linearized mlir::AffineExpr for complicated index. - if (map.getNumResults() != 1) { + if (map.getNumResults() != 1) return nullptr; - } + llvm::SmallVector indexExprs; - for (auto index : apOf.getMapOperands()) { + for (auto index : apOf.getMapOperands()) if (auto cIdx = index.getDefiningOp()) { auto idxVal = cIdx.getValue().cast().getValue(); unsigned idx = idxVal.getSExtValue(); @@ -174,7 +170,6 @@ constructLinearizedAffineExprForUPDOp(aievec::UPDOp updOp) { getAffineDimExpr(indexToExprDimMap.size(), context); indexExprs.push_back(indexToExprDimMap[index]); } - } exprVec.push_back(map.getResult(0).replaceDims(indexExprs)); } else if (auto cOp = value.getDefiningOp()) { @@ -189,13 +184,11 @@ constructLinearizedAffineExprForUPDOp(aievec::UPDOp updOp) { } } - if (exprVec.empty()) { + if (exprVec.empty()) return nullptr; - } auto ret = flattenedStridedExpr(memRefType.getShape(), exprVec, memRefType.getContext()); - return ret; } @@ -208,17 +201,17 @@ extractBaseAndOffset(mlir::AffineExpr expr) { mlir::AffineExpr base = expr; int32_t offset = 0; - if (auto constExpr = expr.dyn_cast()) { + if (auto constExpr = llvm::dyn_cast(expr)) { base = nullptr; offset += constExpr.getValue(); - } else if (auto binopExpr = expr.dyn_cast()) { + } else if (auto binopExpr = llvm::dyn_cast(expr)) { if (binopExpr.getKind() == mlir::AffineExprKind::Add) { mlir::AffineExpr lhs = binopExpr.getLHS(), rhs = binopExpr.getRHS(); - if (auto constExpr = lhs.dyn_cast()) { + if (auto constExpr = llvm::dyn_cast(lhs)) { base = rhs; offset += constExpr.getValue(); } - if (auto constExpr = rhs.dyn_cast()) { + if (auto constExpr = llvm::dyn_cast(rhs)) { base = base == rhs ? nullptr : lhs; offset += constExpr.getValue(); } @@ -235,10 +228,9 @@ extractBaseAndOffset(mlir::AffineExpr expr) { // parent of the block. inline bool isAssumingNoImplicitBroadcastOfDynamicSizes(mlir::Block *block) { for (mlir::Operation *parentOp = block->getParentOp(); parentOp; - parentOp = parentOp->getParentOp()) { + parentOp = parentOp->getParentOp()) if (parentOp->hasAttr("tosa.no_implicit_broadcast_of_dynamic_sizes")) return true; - } return false; } diff --git a/lib/Dialect/AIE/IR/AIEDialect.cpp b/lib/Dialect/AIE/IR/AIEDialect.cpp index c99977b7cc..06d26ecae1 100644 --- a/lib/Dialect/AIE/IR/AIEDialect.cpp +++ b/lib/Dialect/AIE/IR/AIEDialect.cpp @@ -29,7 +29,7 @@ using namespace xilinx::AIE; namespace { -struct AIEInlinerInterface : public DialectInlinerInterface { +struct AIEInlinerInterface : DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; // We don't have any special restrictions on what can be inlined into // destination regions. Always allow it. @@ -48,11 +48,10 @@ struct AIEInlinerInterface : public DialectInlinerInterface { void handleTerminator(Operation *op, Block *newDest) const final {} // Handle the given inlined terminator by replacing it with a new operation // as necessary. Required when the region has only one block. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final {} + void handleTerminator(Operation *op, ArrayRef valuesToRepl) const {} }; -struct AIEDialectFoldInterface : public DialectFoldInterface { +struct AIEDialectFoldInterface : DialectFoldInterface { using DialectFoldInterface::DialectFoldInterface; /// Registered hook to check if the given region, which is attached to an @@ -131,7 +130,7 @@ struct UsesAreAccessable { namespace detail { /// This class represents the internal storage of the AIE `ObjectFifoType`. -struct AIEObjectFifoTypeStorage : public TypeStorage { +struct AIEObjectFifoTypeStorage : TypeStorage { /// The `KeyTy` is a required type that provides an interface for the storage /// instance. This type will be used when uniquing an instance of the type /// storage. @@ -183,7 +182,7 @@ Type AIEObjectFifoType::getElementType() { namespace detail { /// This class represents the internal storage of the AIE /// `ObjectFifoSubviewType`. -struct AIEObjectFifoSubviewTypeStorage : public TypeStorage { +struct AIEObjectFifoSubviewTypeStorage : TypeStorage { /// The `KeyTy` is a required type that provides an interface for the storage /// instance. This type will be used when uniquing an instance of the type /// storage. @@ -238,12 +237,11 @@ Type AIEObjectFifoSubviewType::getElementType() { /// AIE-type /// ::= `objectFifo` `<` type `>` /// ::= `objectFifoSubview` `<` type `>` -static OptionalParseResult aieTypeParser(MLIRContext *context, - DialectAsmParser &parser, +static OptionalParseResult aieTypeParser(DialectAsmParser &parser, StringRef name, Type &result) { if (name.equals("objectFifo")) { Type elementType; - llvm::SMLoc typeLoc = parser.getCurrentLocation(); + SMLoc typeLoc = parser.getCurrentLocation(); if (parser.parseLess() || parser.parseType(elementType) || parser.parseGreater()) return failure(); @@ -266,7 +264,7 @@ static OptionalParseResult aieTypeParser(MLIRContext *context, // Parse the element type of the struct. Type elementType; // Parse the current element type. - llvm::SMLoc typeLoc = parser.getCurrentLocation(); + SMLoc typeLoc = parser.getCurrentLocation(); if (parser.parseType(elementType)) return failure(); @@ -293,11 +291,9 @@ static OptionalParseResult aieTypeParser(MLIRContext *context, /// refer to a type defined in this dialect. static ParseResult parse(Type &result, StringRef name, DialectAsmParser &parser) { - auto *context = parser.getBuilder().getContext(); - OptionalParseResult parseResult; - parseResult = aieTypeParser(context, parser, name, result); - if (parseResult.has_value()) + if (OptionalParseResult parseResult = aieTypeParser(parser, name, result); + parseResult.has_value()) return parseResult.value(); parser.emitError(parser.getNameLoc(), "unknown AIE dialect type: \"") @@ -356,8 +352,8 @@ template struct HasSomeTerminator { for (auto ®ion : op->getRegions()) { for (auto &block : region) { if (!block.empty()) { - Operation *operation = &block.back(); - if (!llvm::isa_and_nonnull(operation)) + if (Operation *operation = &block.back(); + !llvm::isa_and_nonnull(operation)) return operation->emitOpError("is not an allowed terminator") .attachNote(op->getLoc()) .append("in this context: "); @@ -381,7 +377,7 @@ LogicalResult HasValidBDs::verifyTrait(Operation *op) { for (auto &block : element.getBody()) { if (!block.template getOps().empty()) { if (bdNum >= bdMax) { - auto bd = *(block.template getOps().begin()); + auto bd = *block.template getOps().begin(); return (op->emitOpError("has more than ") << bdMax << " blocks") .attachNote(bd.getLoc()) .append("no space for this bd: "); @@ -417,8 +413,8 @@ LogicalResult HasValidDMAChannels::verifyTrait(Operation *op) { // ObjectFifoCreateOp LogicalResult ObjectFifoCreateOp::verify() { if (isa(getElemNumber())) { - size_t numDepths = dyn_cast(getElemNumber()).size(); - if (numDepths != (getConsumerTiles().size() + 1)) // +1 for producer depth + if (size_t numDepths = dyn_cast(getElemNumber()).size(); + numDepths != getConsumerTiles().size() + 1) // +1 for producer depth return emitOpError("does not have enough depths specified for producer " "and for each consumer."); } @@ -438,10 +434,10 @@ TileOp ObjectFifoCreateOp::getProducerTileOp() { namespace xilinx::AIE { ParseResult parseObjectFifoProducerTile(OpAsmParser &parser, - OpAsmParser::UnresolvedOperand &tile, + OpAsmParser::UnresolvedOperand &operand, DimTupleArrayAttr &dimensions) { std::vector emptyDims = {}; - if (parser.parseOperand(tile)) + if (parser.parseOperand(operand)) return failure(); if (succeeded(parser.parseOptionalKeyword("toStream"))) { if (parser.parseCustomAttributeWithFallback( @@ -449,18 +445,18 @@ ParseResult parseObjectFifoProducerTile(OpAsmParser &parser, return failure(); } } else { - dimensions = DimTupleArrayAttr::get(parser.getContext(), - ArrayRef(emptyDims)); + dimensions = + DimTupleArrayAttr::get(parser.getContext(), ArrayRef(emptyDims)); } return success(); } -void printObjectFifoProducerTile(OpAsmPrinter &_odsPrinter, Operation *op, +void printObjectFifoProducerTile(OpAsmPrinter &printer, Operation *op, Value operand, DimTupleArrayAttr dimensions) { - _odsPrinter << operand; + printer << operand; if (!dimensions.empty()) { - _odsPrinter << " toStream "; - _odsPrinter.printStrippedAttrOrType(dimensions); + printer << " toStream "; + printer.printStrippedAttrOrType(dimensions); } } @@ -498,19 +494,19 @@ ParseResult parseObjectFifoConsumerTiles( return success(); } -void printObjectFifoConsumerTiles(OpAsmPrinter &odsPrinter, Operation *op, +void printObjectFifoConsumerTiles(OpAsmPrinter &printer, Operation *op, OperandRange tiles, DimTupleArrayArrayAttr dimsPerTileAttr) { size_t tileIdx = 0; for (auto tile : tiles) { - odsPrinter << tile; + printer << tile; if (dimsPerTileAttr && dimsPerTileAttr.size() == tiles.size() && dimsPerTileAttr[tileIdx] && !dimsPerTileAttr[tileIdx].empty()) { - odsPrinter << " fromStream "; - odsPrinter.printStrippedAttrOrType(dimsPerTileAttr[tileIdx]); + printer << " fromStream "; + printer.printStrippedAttrOrType(dimsPerTileAttr[tileIdx]); } if (tileIdx < tiles.size() - 1) { - odsPrinter << ", "; + printer << ", "; } tileIdx++; } @@ -524,8 +520,7 @@ LogicalResult ObjectFifoLinkOp::verify() { return emitError("ObjectFifoLinkOp does not support 'join' and " "'distribute' at the same time"); - auto sharedTile = getOptionalSharedTile(); - if (!sharedTile) + if (auto sharedTile = getOptionalSharedTile(); !sharedTile) return emitError("ObjectFifoLinkOp must have a link point, i.e., a " "shared tile between objectFifos"); @@ -570,8 +565,8 @@ LogicalResult ObjectFifoLinkOp::verify() { int outputSize = 0; for (auto fifoOut : getOutputObjectFifos()) { - if ((!fifoOut.getDimensionsToStream().empty()) && - (fifoOut.getConsumerTiles().size() > 1)) { + if (!fifoOut.getDimensionsToStream().empty() && + fifoOut.getConsumerTiles().size() > 1) { return emitOpError("currently does not support objectFifos with " "dimensionsToStream and multiple consumers."); } @@ -614,8 +609,8 @@ std::optional ObjectFifoLinkOp::getOptionalSharedTile() { } auto fifoIn = getInputObjectFifos(); - auto fifoOut = getOutputObjectFifos(); - if (!fifoIn.empty() && !fifoOut.empty()) + if (auto fifoOut = getOutputObjectFifos(); + !fifoIn.empty() && !fifoOut.empty()) for (auto consumerIn : fifoIn[0].getConsumerTiles()) if (consumerIn == fifoOut[0].getProducerTile()) return {fifoOut[0].getProducerTile()}; @@ -629,8 +624,8 @@ std::vector ObjectFifoLinkOp::getInputObjectFifos() { if (parent->hasTrait()) { for (auto sym : getFifoIns()) { auto name = dyn_cast(sym); - auto st = SymbolTable::lookupSymbolIn(parent, name); - if (st && isa(st)) + if (auto st = SymbolTable::lookupSymbolIn(parent, name); + st && isa(st)) inputObjFifos.push_back(dyn_cast(st)); } } @@ -645,8 +640,8 @@ std::vector ObjectFifoLinkOp::getOutputObjectFifos() { if (parent->hasTrait()) { for (auto sym : getFifoOuts()) { auto name = dyn_cast(sym); - auto st = SymbolTable::lookupSymbolIn(parent, name); - if (st && isa(st)) + if (auto st = SymbolTable::lookupSymbolIn(parent, name); + st && isa(st)) outputObjFifos.push_back(dyn_cast(st)); } } @@ -670,8 +665,8 @@ ObjectFifoCreateOp ObjectFifoRegisterExternalBuffersOp::getObjectFifo() { Operation *parent = getOperation(); while ((parent = parent->getParentOp())) { if (parent->hasTrait()) { - auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); - if (st && isa(st)) + if (auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); + st && isa(st)) return dyn_cast(st); } } @@ -715,8 +710,8 @@ ObjectFifoCreateOp ObjectFifoAcquireOp::getObjectFifo() { Operation *parent = getOperation(); while ((parent = parent->getParentOp())) { if (parent->hasTrait()) { - auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); - if (st && isa(st)) + if (auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); + st && isa(st)) return dyn_cast(st); } } @@ -760,8 +755,8 @@ ObjectFifoCreateOp ObjectFifoReleaseOp::getObjectFifo() { Operation *parent = getOperation(); while ((parent = parent->getParentOp())) { if (parent->hasTrait()) { - auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); - if (st && isa(st)) + if (auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); + st && isa(st)) return dyn_cast(st); } } @@ -770,12 +765,12 @@ ObjectFifoCreateOp ObjectFifoReleaseOp::getObjectFifo() { // ObjectFifoSubviewAccessOp LogicalResult ObjectFifoSubviewAccessOp::verify() { - auto parent = getOperation()->getParentOfType(); - if (parent == nullptr) + if (auto parent = getOperation()->getParentOfType(); + parent == nullptr) return emitOpError("must be called from inside a CoreOp"); - auto acqOp = getSubview().getDefiningOp(); - if (getIndex() >= acqOp.acqNumber()) + if (auto acqOp = getSubview().getDefiningOp(); + getIndex() >= acqOp.acqNumber()) return emitOpError("accessed farther than number of acquired elements " "(index out of bounds)."); @@ -805,8 +800,8 @@ ObjectFifoCreateOp ObjectFifoRegisterProcessOp::getObjectFifo() { Operation *parent = getOperation(); while ((parent = parent->getParentOp())) { if (parent->hasTrait()) { - auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); - if (st && isa(st)) + if (auto st = SymbolTable::lookupSymbolIn(parent, getObjFifoName()); + st && isa(st)) return dyn_cast(st); } } @@ -895,10 +890,9 @@ LogicalResult SwitchboxOp::verify() { << index << " for " << dir << " bundle " << stringifyWireBundle(bundle) << " must be less than " << bound; - else - return ops.emitOpError() - << dir << " bundle " << stringifyWireBundle(bundle) - << " not supported; index: " << index << ", bound: " << bound; + return ops.emitOpError() + << dir << " bundle " << stringifyWireBundle(bundle) + << " not supported; index: " << index << ", bound: " << bound; } return success(); }; @@ -965,7 +959,7 @@ LogicalResult SwitchboxOp::verify() { int arbiter = -1; for (auto val : connectOp.getAmsels()) { auto amsel = dyn_cast(val.getDefiningOp()); - if ((arbiter != -1) && (arbiter != amsel.arbiterIndex())) + if (arbiter != -1 && arbiter != amsel.arbiterIndex()) return connectOp.emitOpError( "a master port can only be tied to one arbiter"); arbiter = amsel.arbiterIndex(); @@ -1108,8 +1102,7 @@ int ShimDMAOp::colIndex() { return getTileOp().colIndex(); } int ShimDMAOp::rowIndex() { return getTileOp().rowIndex(); } LogicalResult PacketRulesOp::verify() { - Region &body = getRules(); - if (body.empty()) + if (Region &body = getRules(); body.empty()) return emitOpError("should have non-empty body"); return success(); } @@ -1163,7 +1156,7 @@ LogicalResult BufferOp::verify() { uint64_t xilinx::AIE::getBufferBaseAddress(Operation *bufOp) { if (auto buf = dyn_cast(bufOp)) return buf.address(); - else if (isa_and_nonnull(bufOp)) + if (isa_and_nonnull(bufOp)) llvm::report_fatal_error( "External buffer addresses are assigned at runtime."); llvm::report_fatal_error("unknown buffer type"); @@ -1230,7 +1223,7 @@ int MemOp::rowIndex() { return getTileOp().rowIndex(); } /// Returns the region on the current operation that is callable. This may /// return nullptr in the case of an external callable object, e.g. an external /// function. -Region *MemOp::getCallableRegion() { return &(getBody()); } +Region *MemOp::getCallableRegion() { return &getBody(); } // MemTileDMAOp LogicalResult MemTileDMAOp::verify() { @@ -1256,7 +1249,7 @@ LogicalResult MemTileDMAOp::verify() { // Move this code to the dialect // Set of blocks found to be reachable within a given region. llvm::SmallSet reachable; - llvm::SmallVector worklist; + SmallVector worklist; Block *firstBD = startOp.getSuccessor(0); reachable.insert(firstBD); worklist.push_back(firstBD); @@ -1272,10 +1265,10 @@ LogicalResult MemTileDMAOp::verify() { } } } - for (auto b : reachable) { - for (auto bd : b->getOps()) { - auto bufferOp = bd.getBufferOp(); - if (bufferOp.getTileOp().colIndex() != colIndex() || + for (Block *b : reachable) { + for (DMABDOp bd : b->getOps()) { + if (auto bufferOp = bd.getBufferOp(); + bufferOp.getTileOp().colIndex() != colIndex() || bufferOp.getTileOp().rowIndex() != rowIndex()) { InFlightDiagnostic err = bd.emitOpError() @@ -1288,8 +1281,8 @@ LogicalResult MemTileDMAOp::verify() { } } for (auto useLock : b->getOps()) { - auto lockOp = useLock.getLockOp(); - if (lockOp.getTileOp().colIndex() != colIndex() || + if (auto lockOp = useLock.getLockOp(); + lockOp.getTileOp().colIndex() != colIndex() || lockOp.getTileOp().rowIndex() != rowIndex()) { InFlightDiagnostic err = useLock.emitOpError() @@ -1316,8 +1309,8 @@ BufferOp DMABDOp::getBufferOp() { LogicalResult DMABDOp::verify() { if (auto memOp = getOperation()->getParentOfType()) { - auto bufferOp = getBufferOp(); - if (bufferOp.getTileOp().colIndex() != memOp.colIndex() || + if (auto bufferOp = getBufferOp(); + bufferOp.getTileOp().colIndex() != memOp.colIndex() || bufferOp.getTileOp().rowIndex() != memOp.rowIndex()) return emitOpError("can only access a buffer in the same tile."); } @@ -1335,7 +1328,7 @@ LogicalResult DMABDOp::verify() { for (int64_t memrefDim : buffer.getShape()) memrefSize *= 4 * memrefDim; - llvm::ArrayRef dims = *getDimensions(); + ArrayRef dims = *getDimensions(); size_t maxNDims = 3; if (isa_and_nonnull((*this)->getParentOp())) { maxNDims = 4; @@ -1388,7 +1381,7 @@ int MemTileDMAOp::rowIndex() { return getTileOp().rowIndex(); } /// Returns the region on the current operation that is callable. This may /// return nullptr in the case of an external callable object, e.g. an external /// function. -Region *MemTileDMAOp::getCallableRegion() { return &(getBody()); } +Region *MemTileDMAOp::getCallableRegion() { return &getBody(); } // SwitchboxOp TileOp SwitchboxOp::getTileOp() { @@ -1418,15 +1411,15 @@ int LockOp::colIndex() { return getTileOp().colIndex(); } int LockOp::rowIndex() { return getTileOp().rowIndex(); } LogicalResult LockOp::verify() { - auto result = UsesAreAccessable::verifyTrait(*this); - if (result.failed()) + if (auto result = UsesAreAccessable::verifyTrait(*this); result.failed()) return result; if (getLockID().has_value()) { const auto &targetModel = getTargetModel(getTileOp()); auto tileOp = getTileOp(); - int numLocks = targetModel.getNumLocks(tileOp.getCol(), tileOp.getRow()); - if (getLockID().value() >= numLocks) + if (int numLocks = + targetModel.getNumLocks(tileOp.getCol(), tileOp.getRow()); + getLockID().value() >= numLocks) return emitOpError("lock assigned invalid id (maximum is ") << numLocks - 1 << ")"; } @@ -1439,8 +1432,8 @@ struct UsesOneLockInDMABlock { auto block = op->getBlock(); int lockID = -1; for (auto op : block->getOps()) { - auto lock = dyn_cast(op.getLock().getDefiningOp()); - if (lock.getLockID().has_value()) { + if (auto lock = dyn_cast(op.getLock().getDefiningOp()); + lock.getLockID().has_value()) { if (lockID != -1 && lockID != lock.getLockIDValue()) return failure(); lockID = lock.getLockIDValue(); @@ -1475,8 +1468,8 @@ struct AccessesLocalLocks { static LogicalResult verifyTrait(Operation *op) { if (auto memOp = op->getParentOfType()) { auto useLock = dyn_cast(op); - auto lock = useLock.getLockOp(); - if (lock.getTileOp().colIndex() != memOp.colIndex() || + if (auto lock = useLock.getLockOp(); + lock.getTileOp().colIndex() != memOp.colIndex() || lock.getTileOp().rowIndex() != memOp.rowIndex()) return failure(); } @@ -1515,15 +1508,13 @@ LogicalResult UseLockOp::verify() { return success(); // Or it can be in a CoreOp, or some FuncOp called from a CoreOp - } else if (HasSomeParent::verifyTrait(*this) - .succeeded()) { + } + if (HasSomeParent::verifyTrait(*this).succeeded()) { return success(); - - } else { - return (*this)->emitOpError() - << "expects some parent op to be one of " - << "AIE::device, AIE::core, func::func, AIE::mem, or AIE::shimDMA"; } + return (*this)->emitOpError() + << "expects some parent op to be one of " + << "AIE::device, AIE::core, func::func, AIE::mem, or AIE::shimDMA"; } #include "aie/Dialect/AIE/IR/AIEEnums.cpp.inc" diff --git a/lib/Dialect/AIE/Transforms/AIECoreToStandard.cpp b/lib/Dialect/AIE/Transforms/AIECoreToStandard.cpp index e9356431fb..7065eadc75 100644 --- a/lib/Dialect/AIE/Transforms/AIECoreToStandard.cpp +++ b/lib/Dialect/AIE/Transforms/AIECoreToStandard.cpp @@ -384,6 +384,8 @@ struct AIECoreToStandardPass : AIECoreToStandardBase { case AIEArch::AIE2: triple = "aie2"; break; + default: + llvm::report_fatal_error("unsupported arch"); } // Ensure that we don't have an incorrect target triple. This may override diff --git a/lib/Dialect/AIE/Transforms/AIECreatePacketFlows.cpp b/lib/Dialect/AIE/Transforms/AIECreatePacketFlows.cpp index e188631ff1..404180bc90 100644 --- a/lib/Dialect/AIE/Transforms/AIECreatePacketFlows.cpp +++ b/lib/Dialect/AIE/Transforms/AIECreatePacketFlows.cpp @@ -319,7 +319,7 @@ struct AIERoutePacketFlowsPass Region &r = pktflow.getPorts(); Block &b = r.front(); int flowID = pktflow.IDInt(); - int xSrc, ySrc; + int xSrc = 0, ySrc = 0; Port sourcePort; for (Operation &Op : b.getOperations()) { diff --git a/lib/Dialect/AIE/Transforms/AIENormalizeAddressSpaces.cpp b/lib/Dialect/AIE/Transforms/AIENormalizeAddressSpaces.cpp index 44019c6ff8..1346337706 100644 --- a/lib/Dialect/AIE/Transforms/AIENormalizeAddressSpaces.cpp +++ b/lib/Dialect/AIE/Transforms/AIENormalizeAddressSpaces.cpp @@ -23,7 +23,7 @@ using namespace xilinx; using namespace xilinx::AIE; Type memRefToDefaultAddressSpace(Type t) { - if (auto memRefType = t.dyn_cast(); + if (auto memRefType = llvm::dyn_cast(t); memRefType && memRefType.getMemorySpace() != nullptr) return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), memRefType.getLayout(), nullptr /* Address Space */); diff --git a/lib/Dialect/AIE/Transforms/AIEObjectFifoRegisterProcess.cpp b/lib/Dialect/AIE/Transforms/AIEObjectFifoRegisterProcess.cpp index 77992d4fe1..580a477c91 100644 --- a/lib/Dialect/AIE/Transforms/AIEObjectFifoRegisterProcess.cpp +++ b/lib/Dialect/AIE/Transforms/AIEObjectFifoRegisterProcess.cpp @@ -127,21 +127,21 @@ struct AIEObjectFifoRegisterProcessPass builder.setInsertionPointToEnd(device.getBody()); ObjectFifoCreateOp objFifo = registerOp.getObjectFifo(); auto elementType = - objFifo.getElemType().dyn_cast().getElementType(); + llvm::dyn_cast(objFifo.getElemType()) + .getElementType(); if (consumersPerFifo.find(objFifo) == consumersPerFifo.end()) { std::queue consumers; - for (auto consumerTile : objFifo.getConsumerTiles()) { + for (auto consumerTile : objFifo.getConsumerTiles()) consumers.push(consumerTile); - } consumersPerFifo[objFifo] = consumers; } // identify tile on which to generate the pattern Value tile; - if (registerOp.getPort() == ObjectFifoPort::Produce) { + if (registerOp.getPort() == ObjectFifoPort::Produce) tile = objFifo.getProducerTile(); - } else if (registerOp.getPort() == ObjectFifoPort::Consume) { + else if (registerOp.getPort() == ObjectFifoPort::Consume) { assert(!consumersPerFifo[objFifo].empty() && "No more available consumer tiles for process."); tile = consumersPerFifo[objFifo].front(); @@ -150,12 +150,11 @@ struct AIEObjectFifoRegisterProcessPass // retrieve core associated to above tile or create new one CoreOp *core = nullptr; - for (auto coreOp : device.getOps()) { + for (auto coreOp : device.getOps()) if (coreOp.getTile() == tile) { core = &coreOp; break; } - } if (core == nullptr) { auto coreOp = builder.create(builder.getUnknownLoc(), builder.getIndexType(), tile); @@ -186,16 +185,14 @@ struct AIEObjectFifoRegisterProcessPass auto acqPattern = registerOp.getAcquirePattern().getValues(); std::vector acqVector; - for (auto i = acqPattern.begin(); i != acqPattern.end(); ++i) { + for (auto i = acqPattern.begin(); i != acqPattern.end(); ++i) acqVector.push_back(*i); - } auto relPattern = registerOp.getReleasePattern().getValues(); std::vector relVector; - for (auto i = relPattern.begin(); i != relPattern.end(); ++i) { + for (auto i = relPattern.begin(); i != relPattern.end(); ++i) relVector.push_back(*i); - } if (acqSize == 1) { // duplicate acquire pattern @@ -219,7 +216,6 @@ struct AIEObjectFifoRegisterProcessPass auto currRel = relVector[i]; if (i < acqSize - 1) { auto nextAcq = acqVector[i + 1]; - if (auto nextRel = relVector[i + 1]; currAcq.getInt() == nextAcq.getInt() && currRel.getInt() == nextRel.getInt()) { diff --git a/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp b/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp index 3c2a0b14cd..48f8d1de9e 100644 --- a/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp +++ b/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp @@ -769,7 +769,7 @@ struct AIEObjectFifoStatefulTransformPass lastDmaBlock->getTerminator()->setSuccessor(dmaBlock, 1); // create Bd blocks - Block *succ = nullptr; + Block *succ; Block *curr = bdBlock; size_t blockIndex = 0; for (size_t i = 0; i < numBlocks; i++) { @@ -965,15 +965,16 @@ struct AIEObjectFifoStatefulTransformPass .getDefiningOp() .getValue(); int64_t old_upper_value = - old_upper_bound.dyn_cast().getInt(); + llvm::dyn_cast(old_upper_bound).getInt(); auto old_lower_bound = forLoop.getLowerBound() .getDefiningOp() .getValue(); int64_t old_lower_value = - old_lower_bound.dyn_cast().getInt(); + llvm::dyn_cast(old_lower_bound).getInt(); auto old_step = forLoop.getStep().getDefiningOp().getValue(); - int64_t old_step_value = old_step.dyn_cast().getInt(); + int64_t old_step_value = + llvm::dyn_cast(old_step).getInt(); int64_t num_iter = (old_upper_value - old_lower_value) / old_step_value; diff --git a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp index 72241d6f4b..6aa48d3a0f 100644 --- a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp +++ b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp @@ -12,7 +12,7 @@ #include "aie/Dialect/AIEVec/IR/AIEVecOps.h" #include "aie/Dialect/AIEVec/AIEVecUtils.h" -#include "mlir/IR/AffineMap.h" + #include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" @@ -59,8 +59,8 @@ void UPDOp::print(OpAsmPrinter &p) { // Verify UPD op. LogicalResult UPDOp::verify() { // Verify the types: source is memref, and result is vector - MemRefType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + MemRefType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType) return emitError("requires memref type"); if (!resultType) @@ -71,7 +71,7 @@ LogicalResult UPDOp::verify() { // If this UPD op is linked to another UPD op, then verify that the linked // vector and the result vector match. if (getVector()) { - Type vecType = getVector().getType().dyn_cast(); + Type vecType = llvm::dyn_cast(getVector().getType()); if (vecType != resultType) return emitError("result types of linked UPD ops do not match"); } @@ -107,10 +107,10 @@ ParseResult UPDOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification - auto memrefType = types[0].dyn_cast(); + auto memrefType = llvm::dyn_cast(types[0]); if (!memrefType) return parser.emitError(typesLoc, "requires memref type"); - VectorType vectorType = types[1].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[1]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); auto indicesType = builder.getIndexType(); @@ -152,8 +152,8 @@ void CastOp::print(OpAsmPrinter &p) { // Verify Cast op. LogicalResult CastOp::verify() { // Verify the types - VectorType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType) return emitError("requires source vector type"); if (!resultType) @@ -190,10 +190,10 @@ ParseResult CastOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification of types - VectorType sourceType = types[0].dyn_cast(); + VectorType sourceType = llvm::dyn_cast(types[0]); if (!sourceType) return parser.emitError(typesLoc, "requires vector type"); - VectorType vectorType = types[1].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[1]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); @@ -248,8 +248,8 @@ void SRSOp::print(OpAsmPrinter &p) { // Verify SRS op. LogicalResult SRSOp::verify() { // Verify the types - VectorType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType) return emitError("requires accumulator type"); if (!resultType) @@ -299,15 +299,15 @@ ParseResult SRSOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires three types"); // Some verification of types - VectorType accType = types[0].dyn_cast(); + VectorType accType = llvm::dyn_cast(types[0]); if (!accType) return parser.emitError(typesLoc, "requires vector type"); - IntegerType shiftType = types[1].dyn_cast(); + IntegerType shiftType = llvm::dyn_cast(types[1]); if (!shiftType) return parser.emitError(typesLoc, "requires integer type"); - VectorType vectorType = types[2].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[2]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); @@ -353,8 +353,8 @@ void UPSOp::print(OpAsmPrinter &p) { // Verify UPS op. LogicalResult UPSOp::verify() { // Verify the types - VectorType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType) return emitError("requires vector type"); if (!resultType) @@ -403,10 +403,10 @@ ParseResult UPSOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification - VectorType vectorType = types[0].dyn_cast(); + VectorType vectorType = llvm::dyn_cast(types[0]); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - VectorType accType = types[1].dyn_cast(); + VectorType accType = llvm::dyn_cast(types[1]); if (!accType) return parser.emitError(typesLoc, "requires vector type"); @@ -436,8 +436,8 @@ void BroadcastOp::print(OpAsmPrinter &p) { // Verify Broadcast op. LogicalResult BroadcastOp::verify() { // Verify the types - VectorType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType) return emitError("requires vector type"); @@ -490,11 +490,11 @@ ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification - VectorType vecType = types[0].dyn_cast(); + VectorType vecType = llvm::dyn_cast(types[0]); if (!vecType) return parser.emitError(typesLoc, "requires vector type"); - VectorType resType = types[1].dyn_cast(); + VectorType resType = llvm::dyn_cast(types[1]); if (!resType) return parser.emitError(typesLoc, "requires vector type"); @@ -522,7 +522,7 @@ void BroadcastScalarOp::print(OpAsmPrinter &p) { LogicalResult BroadcastScalarOp::verify() { // Verify the types Type sourceType = getSource().getType(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return emitError("requires vector type"); @@ -548,7 +548,7 @@ ParseResult BroadcastScalarOp::parse(OpAsmParser &parser, if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); - if (result.attributes.getAttrs().size() != 0) + if (!result.attributes.getAttrs().empty()) return parser.emitError(typesLoc, "do not require attributes"); // Assert that there is two type (source and result vector) @@ -556,7 +556,7 @@ ParseResult BroadcastScalarOp::parse(OpAsmParser &parser, return parser.emitError(typesLoc, "requires two types"); // Some verification - VectorType resType = types[1].dyn_cast(); + VectorType resType = llvm::dyn_cast(types[1]); if (!resType) return parser.emitError(typesLoc, "requires vector type"); @@ -575,7 +575,7 @@ ParseResult BroadcastScalarOp::parse(OpAsmParser &parser, // some specializations to print those fields specifically for FMA op. // Print the accumulator -template inline void printAccumulator(OpAsmPrinter &p, T op); +template void printAccumulator(OpAsmPrinter &p, T op); template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAOp op) { p << ", " << op.getAcc(); } @@ -583,7 +583,7 @@ template <> inline void printAccumulator(OpAsmPrinter &p, aievec::MulOp op) {} // Mark fmsub indicator as elided if the FMA op is not fmsub template -inline void elideFMSubAttr(T op, SmallVector &elidedAttrs); +void elideFMSubAttr(T op, SmallVector &elidedAttrs); template <> inline void elideFMSubAttr(aievec::FMAOp op, SmallVector &elidedAttrs) { @@ -634,14 +634,13 @@ void aievec::FMAOp::print(OpAsmPrinter &p) { // Verify Mul and FMA op. template LogicalResult verifyMulFMAOp(T op) { // Verify the types - VectorType lhsType = op.getLhs().getType().template dyn_cast(); - VectorType rhsType = op.getRhs().getType().template dyn_cast(); + auto lhsType = op.getLhs().getType().template dyn_cast(); + auto rhsType = op.getRhs().getType().template dyn_cast(); if (!lhsType || !rhsType) return op.emitError("requires vector type"); - VectorType resultType = - op.getResult().getType().template dyn_cast(); + auto resultType = op.getResult().getType().template dyn_cast(); if (!resultType) return op.emitError("requires vector type"); @@ -733,15 +732,15 @@ ParseResult parseMulFMAOp(OpAsmParser &parser, OperationState &result, return parser.emitError(typesLoc, "requires three types"); // Some verification - VectorType lhsType = types[0].dyn_cast(); + VectorType lhsType = llvm::dyn_cast(types[0]); if (!lhsType) return parser.emitError(typesLoc, "requires vector type"); - VectorType rhsType = types[1].dyn_cast(); + VectorType rhsType = llvm::dyn_cast(types[1]); if (!rhsType) return parser.emitError(typesLoc, "requires vector type"); // Int ops use the accumulator while float ops use normal vector registers - VectorType accType = types[2].dyn_cast(); + VectorType accType = llvm::dyn_cast(types[2]); if (!accType) return parser.emitError(typesLoc, "requires vector type"); @@ -777,7 +776,7 @@ ParseResult FMAOp::parse(OpAsmParser &parser, OperationState &result) { // FMAElemOp and MULElemOp. // Print the accumulator -template inline void printAccumulator(OpAsmPrinter &p, T op); +template void printAccumulator(OpAsmPrinter &p, T op); template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAElemOp op) { p << ", " << op.getAcc(); @@ -787,7 +786,7 @@ inline void printAccumulator(OpAsmPrinter &p, aievec::MulElemOp op) {} // Mark fmsub indicator as elided if the FMAElem op is not fmsub template -inline void elideFMSubAttr(T op, SmallVector &elidedAttrs); +void elideFMSubAttr(T op, SmallVector &elidedAttrs); template <> inline void elideFMSubAttr(aievec::FMAElemOp op, SmallVector &elidedAttrs) { @@ -831,14 +830,13 @@ void aievec::FMAElemOp::print(OpAsmPrinter &p) { // Verify MulElem and FMAElem op. template LogicalResult verifyMulFMAElemOp(T op) { // Verify the types - VectorType lhsType = op.getLhs().getType().template dyn_cast(); - VectorType rhsType = op.getRhs().getType().template dyn_cast(); + auto lhsType = op.getLhs().getType().template dyn_cast(); + auto rhsType = op.getRhs().getType().template dyn_cast(); if (!lhsType || !rhsType) return op.emitError("requires vector type"); - VectorType resultType = - op.getResult().getType().template dyn_cast(); + auto resultType = op.getResult().getType().template dyn_cast(); if (!resultType) return op.emitError("requires vector type"); @@ -920,15 +918,15 @@ ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result, return parser.emitError(typesLoc, "requires three types"); // Some verification - VectorType lhsType = types[0].dyn_cast(); + VectorType lhsType = llvm::dyn_cast(types[0]); if (!lhsType) return parser.emitError(typesLoc, "requires vector type"); - VectorType rhsType = types[1].dyn_cast(); + VectorType rhsType = llvm::dyn_cast(types[1]); if (!rhsType) return parser.emitError(typesLoc, "requires vector type"); // Int ops use the accumulator while float ops use normal vector registers - VectorType accType = types[2].dyn_cast(); + VectorType accType = llvm::dyn_cast(types[2]); if (!accType) return parser.emitError(typesLoc, "requires vector type"); @@ -995,10 +993,9 @@ void aievec::SubOp::print(OpAsmPrinter &p) { // Verify Add and Sub op. template LogicalResult verifyAddSubOp(T op) { // Verify the types - VectorType resultType = - op.getResult().getType().template dyn_cast(); - VectorType lhsType = op.getLhs().getType().template dyn_cast(); - VectorType rhsType = op.getRhs().getType().template dyn_cast(); + auto resultType = op.getResult().getType().template dyn_cast(); + auto lhsType = op.getLhs().getType().template dyn_cast(); + auto rhsType = op.getRhs().getType().template dyn_cast(); if (!lhsType || !rhsType || !resultType) return op.emitError("requires vector type"); @@ -1039,13 +1036,13 @@ ParseResult parseAddSubOp(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires three types"); // Some verification - VectorType lhsType = types[0].dyn_cast(); + VectorType lhsType = llvm::dyn_cast(types[0]); if (!lhsType) return parser.emitError(typesLoc, "requires vector type"); - VectorType rhsType = types[1].dyn_cast(); + VectorType rhsType = llvm::dyn_cast(types[1]); if (!rhsType) return parser.emitError(typesLoc, "requires vector type"); - VectorType resultType = types[2].dyn_cast(); + VectorType resultType = llvm::dyn_cast(types[2]); if (!resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1091,15 +1088,15 @@ LogicalResult ConcatOp::verify() { // Verify the types VectorType sourceType = - getSources().getTypes().front().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + llvm::dyn_cast(getSources().getTypes().front()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType || !resultType) return emitError("requires vector type"); SmallVector srcs(getSources().begin(), getSources().end()); // All the sources must have the same type for (auto source : srcs) { - VectorType type = source.getType().dyn_cast(); + VectorType type = llvm::dyn_cast(source.getType()); if (!type) return emitError("requires vector type"); if (type != sourceType) @@ -1109,7 +1106,7 @@ LogicalResult ConcatOp::verify() { // The lanes in concatenated type must be the sum of lanes of source vector unsigned totalLanes = 0; for (auto source : srcs) { - VectorType type = source.getType().dyn_cast(); + VectorType type = llvm::dyn_cast(source.getType()); totalLanes += getVectorLaneSize(type); } @@ -1144,8 +1141,8 @@ ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification - VectorType sourceType = types[0].dyn_cast(); - VectorType resultType = types[1].dyn_cast(); + VectorType sourceType = llvm::dyn_cast(types[0]); + VectorType resultType = llvm::dyn_cast(types[1]); if (!sourceType || !resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1175,8 +1172,8 @@ void ExtOp::print(OpAsmPrinter &p) { // Verify Ext op. LogicalResult ExtOp::verify() { // Verify the types - VectorType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType || !resultType) return emitError("requires vector type"); @@ -1229,8 +1226,8 @@ ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification - VectorType sourceType = types[0].dyn_cast(); - VectorType resultType = types[1].dyn_cast(); + VectorType sourceType = llvm::dyn_cast(types[0]); + VectorType resultType = llvm::dyn_cast(types[1]); if (!sourceType || !resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1277,8 +1274,8 @@ void aievec::SelectOp::print(OpAsmPrinter &p) { // Verify select op. LogicalResult aievec::SelectOp::verify() { // Verify the types - VectorType resultType = getResult().getType().dyn_cast(); - VectorType xbuffType = getXbuff().getType().dyn_cast(); + VectorType resultType = llvm::dyn_cast(getResult().getType()); + VectorType xbuffType = llvm::dyn_cast(getXbuff().getType()); if (!resultType || !xbuffType) return emitError("requires vector type"); @@ -1291,7 +1288,7 @@ LogicalResult aievec::SelectOp::verify() { // If yuff is present, its vector type should be same as xbuff if (getYbuff()) { - VectorType ybuffType = getYbuff().getType().dyn_cast(); + VectorType ybuffType = llvm::dyn_cast(getYbuff().getType()); if (xbuffType != ybuffType) return emitError("types of xbuff and ybuff must match"); } @@ -1330,16 +1327,16 @@ ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires at least two type"); // Some verification - VectorType xbuffType = types[0].dyn_cast(); + VectorType xbuffType = llvm::dyn_cast(types[0]); if (!xbuffType) return parser.emitError(typesLoc, "requires vector type"); VectorType ybuffType; if (hasYbuff.succeeded()) { - ybuffType = types[1].dyn_cast(); + ybuffType = llvm::dyn_cast(types[1]); if (!ybuffType) return parser.emitError(typesLoc, "requires vector type"); } - VectorType resultType = types.back().dyn_cast(); + VectorType resultType = llvm::dyn_cast(types.back()); if (!resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1377,10 +1374,8 @@ void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp(p, *this); } // Verify Pack and Unpack op. template LogicalResult verifyPackUnpackOp(T op) { // Verify the types - VectorType sourceType = - op.getSource().getType().template dyn_cast(); - VectorType resultType = - op.getResult().getType().template dyn_cast(); + auto sourceType = op.getSource().getType().template dyn_cast(); + auto resultType = op.getResult().getType().template dyn_cast(); if (!sourceType || !resultType) return op.emitError("requires vector type"); @@ -1440,8 +1435,8 @@ ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires two types"); // Some verification - VectorType sourceType = types[0].dyn_cast(); - VectorType resultType = types[1].dyn_cast(); + VectorType sourceType = llvm::dyn_cast(types[0]); + VectorType resultType = llvm::dyn_cast(types[1]); if (!sourceType || !resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1483,13 +1478,13 @@ void ShiftOp::print(OpAsmPrinter &p) { // Verify Shift op. LogicalResult ShiftOp::verify() { // Verify the types - VectorType resultType = getResult().getType().dyn_cast(); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!resultType) return emitError("requires vector type"); // lhs, rhs and result must have the same type - VectorType lhsType = getLhs().getType().dyn_cast(); - VectorType rhsType = getRhs().getType().dyn_cast(); + VectorType lhsType = llvm::dyn_cast(getLhs().getType()); + VectorType rhsType = llvm::dyn_cast(getRhs().getType()); if (!lhsType || !rhsType) return emitError("requires vector type"); @@ -1527,10 +1522,10 @@ ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "requires four types"); // Some verification - VectorType lhsType = types[0].dyn_cast(); - VectorType rhsType = types[1].dyn_cast(); - IntegerType shiftType = types[2].dyn_cast(); - VectorType resultType = types[3].dyn_cast(); + VectorType lhsType = llvm::dyn_cast(types[0]); + VectorType rhsType = llvm::dyn_cast(types[1]); + IntegerType shiftType = llvm::dyn_cast(types[2]); + VectorType resultType = llvm::dyn_cast(types[3]); if (!lhsType || !rhsType || !resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1565,8 +1560,8 @@ void ShuffleOp::print(OpAsmPrinter &p) { // Verify Shuffle op. LogicalResult ShuffleOp::verify() { // Verify the types - VectorType sourceType = getSource().getType().dyn_cast(); - VectorType resultType = getResult().getType().dyn_cast(); + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + VectorType resultType = llvm::dyn_cast(getResult().getType()); if (!sourceType || !resultType) return emitError("requires vector type"); @@ -1609,8 +1604,8 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(typesLoc, "expects one attribute"); // Some verification - VectorType sourceType = types[0].dyn_cast(); - VectorType resultType = types[1].dyn_cast(); + VectorType sourceType = llvm::dyn_cast(types[0]); + VectorType resultType = llvm::dyn_cast(types[1]); if (!sourceType || !resultType) return parser.emitError(typesLoc, "requires vector type"); @@ -1631,7 +1626,7 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { // FMAConvOp and MULConvOp. // Print the accumulator -template inline void printAccumulator(OpAsmPrinter &p, T op); +template void printAccumulator(OpAsmPrinter &p, T op); template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAConvOp op) { p << ", " << op.getAcc(); @@ -1641,16 +1636,16 @@ inline void printAccumulator(OpAsmPrinter &p, aievec::MulConvOp op) {} // Mark fmsub indicator as elided if the FMAElem op is not fmsub template -inline void elideFMSubAttr(T op, SmallVector &elidedAttrs); +void elideFMSubAttr(T op, SmallVector &elidedAttrs); template <> -inline void elideFMSubAttr(aievec::FMAConvOp op, +inline void elideFMSubAttr(FMAConvOp op, SmallVector &elidedAttrs) { if (!op.getFmsub()) elidedAttrs.push_back(op.getSubAttrName()); } template <> -inline void elideFMSubAttr(aievec::MulConvOp op, +inline void elideFMSubAttr(MulConvOp op, SmallVector &elidedAttrs) {} // Print out MulConv and FMAConv op. @@ -1685,8 +1680,8 @@ void aievec::FMAConvOp::print(OpAsmPrinter &p) { // Verify MulConv and FMAConv op. template LogicalResult verifyMulFMAConvOp(T op) { // Verify the types - VectorType lhsType = op.getLhs().getType().template dyn_cast(); - VectorType rhsType = op.getRhs().getType().template dyn_cast(); + auto lhsType = op.getLhs().getType().template dyn_cast(); + auto rhsType = op.getRhs().getType().template dyn_cast(); if (!lhsType || !rhsType) return op.emitError("requires vector type"); @@ -1694,13 +1689,11 @@ template LogicalResult verifyMulFMAConvOp(T op) { unsigned M = op.getM(); unsigned N = op.getN(); - if (M <= 0 || N <= 0 || 2 * M < M + N - 1) { + if (M <= 0 || N <= 0 || 2 * M < M + N - 1) return op.emitError( "M and N should be larger than 0 and 2*M should be no less than M+N-1"); - } - VectorType resultType = - op.getResult().getType().template dyn_cast(); + auto resultType = op.getResult().getType().template dyn_cast(); if (!resultType) return op.emitError("requires vector type"); @@ -1781,15 +1774,15 @@ ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result, return parser.emitError(typesLoc, "requires three types"); // Some verification - VectorType lhsType = types[0].dyn_cast(); + VectorType lhsType = llvm::dyn_cast(types[0]); if (!lhsType) return parser.emitError(typesLoc, "requires vector type"); - VectorType rhsType = types[1].dyn_cast(); + VectorType rhsType = llvm::dyn_cast(types[1]); if (!rhsType) return parser.emitError(typesLoc, "requires vector type"); // Int ops use the accumulator - VectorType accType = types[2].dyn_cast(); + VectorType accType = llvm::dyn_cast(types[2]); if (!accType) return parser.emitError(typesLoc, "requires vector type"); diff --git a/lib/Dialect/AIEVec/TransformOps/AIEVecTransformOps.cpp b/lib/Dialect/AIEVec/TransformOps/AIEVecTransformOps.cpp index 35c650b0b6..1224260bf8 100644 --- a/lib/Dialect/AIEVec/TransformOps/AIEVecTransformOps.cpp +++ b/lib/Dialect/AIEVec/TransformOps/AIEVecTransformOps.cpp @@ -10,7 +10,6 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Utils/Utils.h" @@ -94,13 +93,16 @@ static bool vectorizeContractionOpBlock(OpBuilder &rewriter, Location loc, auto baC = static_cast(dstBlock.getArgument(2)); // Store vectorized values for op replacement llvm::DenseMap convertedValues; - convertedValues.try_emplace(static_cast(srcBlock.getArgument(0)), baA); - convertedValues.try_emplace(static_cast(srcBlock.getArgument(1)), baB); - convertedValues.try_emplace(static_cast(srcBlock.getArgument(2)), baC); + convertedValues.try_emplace(srcBlock.getArgument(0), baA); + convertedValues.try_emplace(srcBlock.getArgument(1), baB); + convertedValues.try_emplace(srcBlock.getArgument(2), baC); auto indexingMaps = rewriter.getAffineMapArrayAttr( - {AffineMap::getPermutationMap({1, 0, 2}, ctx).dropResults(0), - AffineMap::getPermutationMap({0, 2, 1}, ctx).dropResults(0), - AffineMap::getPermutationMap({2, 0, 1}, ctx).dropResults(0)}); + {AffineMap::getPermutationMap(ArrayRef{1, 0, 2}, ctx) + .dropResults(0), + AffineMap::getPermutationMap(ArrayRef{0, 2, 1}, ctx) + .dropResults(0), + AffineMap::getPermutationMap(ArrayRef{2, 0, 1}, ctx) + .dropResults(0)}); auto iteratorTypes = rewriter.getArrayAttr( {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), @@ -125,15 +127,14 @@ static bool vectorizeContractionOpBlock(OpBuilder &rewriter, Location loc, opA = convertedValues[rhsDefOp->getOperand(0)]; opB = convertedValues[rhsDefOp->getOperand(1)]; opC = convertedValues[lhs]; - } else { + } else return WalkResult::interrupt(); - } auto conOp = rewriter.create( loc, opA, opB, opC, indexingMaps, iteratorTypes); convertedValues.try_emplace(op->getResult(0), conOp.getResult()); return WalkResult::advance(); }) - .Case([&](auto mulOp) { + .Case([&](auto) { if (mulOpFound) return WalkResult::interrupt(); mulOpFound = true; @@ -170,14 +171,13 @@ static bool vectorizeContractionOpBlock(OpBuilder &rewriter, Location loc, } DiagnosedSilenceableFailure transform::VectorizeContractionOp::applyToOne( - transform::TransformRewriter &rewriter, linalg::GenericOp target, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { + TransformRewriter &rewriter, linalg::GenericOp target, + ApplyToEachResultList &results, TransformState &state) { auto ctx = target.getContext(); SmallVector inputs = target.getInputs(); - SmallVector outputs = target.getOutputs(); - if (inputs.size() != 2 || outputs.size() != 1) + if (SmallVector outputs = target.getOutputs(); + inputs.size() != 2 || outputs.size() != 1) return emitSilenceableError() << "payload is not a contraction."; // Split the iterators in two: inner contraction + remaining @@ -200,9 +200,15 @@ DiagnosedSilenceableFailure transform::VectorizeContractionOp::applyToOne( //=== // 1. Build the indexing maps for the operands of a GEMM contraction - auto mmAidxMap = AffineMap::getPermutationMap({1, 0, 2}, ctx).dropResults(0); - auto mmBidxMap = AffineMap::getPermutationMap({0, 2, 1}, ctx).dropResults(0); - auto mmCidxMap = AffineMap::getPermutationMap({2, 0, 1}, ctx).dropResults(0); + auto mmAidxMap = + AffineMap::getPermutationMap(ArrayRef{1, 0, 2}, ctx) + .dropResults(0); + auto mmBidxMap = + AffineMap::getPermutationMap(ArrayRef{0, 2, 1}, ctx) + .dropResults(0); + auto mmCidxMap = + AffineMap::getPermutationMap(ArrayRef{2, 0, 1}, ctx) + .dropResults(0); // 2. Get the indexing maps for the 2 innermost dimmensions of each operand SmallVector outerMostResults; diff --git a/lib/Dialect/AIEVec/Transforms/AIEVectorize.cpp b/lib/Dialect/AIEVec/Transforms/AIEVectorize.cpp index 5683c38f7d..5496d8f3c6 100644 --- a/lib/Dialect/AIEVec/Transforms/AIEVectorize.cpp +++ b/lib/Dialect/AIEVec/Transforms/AIEVectorize.cpp @@ -166,15 +166,14 @@ struct Scheme { //===----------------------------------------------------------------------===// // Combine the result of vector-related utilities into a single utility. -static inline AIEVecAttributes getVectorStats(VectorType type) { +static AIEVecAttributes getVectorStats(VectorType type) { return AIEVecAttributes(getVectorLaneSize(type), getVectorSizeInBits(type), type.getElementType(), getElementSizeInBits(type)); } // Get the vector stats for an operation's result. -static inline AIEVecAttributes getResultVecStats(Operation *op, - unsigned idx = 0) { - VectorType vtype = op->getResult(idx).getType().cast(); +static AIEVecAttributes getResultVecStats(Operation *op, unsigned idx = 0) { + auto vtype = op->getResult(idx).getType().cast(); return getVectorStats(vtype); } @@ -186,11 +185,11 @@ static Operation *getOperandDefOp(VectState *state, Operation *op, } // Get the vector stats for an operation's operand. -static inline AIEVecAttributes -getOperandVecStats(Operation *op, VectState *state, unsigned idx = 0) { +static AIEVecAttributes getOperandVecStats(Operation *op, VectState *state, + unsigned idx = 0) { assert(op->getNumOperands() > idx); Operation *defOp = getOperandDefOp(state, op, idx); - VectorType vtype = defOp->getResult(0).getType().cast(); + auto vtype = defOp->getResult(0).getType().cast(); auto ret = getVectorStats(vtype); // if the defining op is a transfer read, get the extent read from source if (auto readOp = dyn_cast(defOp)) { @@ -205,20 +204,20 @@ getOperandVecStats(Operation *op, VectState *state, unsigned idx = 0) { } // Get the number of rows and columns in the vector scheme. -static inline std::pair getNumRowsAndCols(Operation *op, - VectState *state) { +static std::pair getNumRowsAndCols(Operation *op, + VectState *state) { assert(op->getNumOperands() >= 2 && op->getNumResults() == 1); Operation *left = getOperandDefOp(state, op, 0); Operation *right = getOperandDefOp(state, op, 1); // Get the number of lanes - VectorType vtype = op->getResult(0).getType().cast(); + auto vtype = op->getResult(0).getType().cast(); int32_t lanes = getVectorLaneSize(vtype); // Get the data sizes for left and right operands - VectorType ltype = left->getResult(0).getType().cast(); - VectorType rtype = right->getResult(0).getType().cast(); + auto ltype = left->getResult(0).getType().cast(); + auto rtype = right->getResult(0).getType().cast(); int32_t lsize = getElementSizeInBits(ltype); int32_t rsize = getElementSizeInBits(rtype); @@ -247,8 +246,7 @@ static inline std::pair getNumRowsAndCols(Operation *op, // operand of Op2 has access extent [128,512], where these two accesses belong // to the same ReuseInterval, then the union is [0,512]. This union will be the // new access extent of the left operands of both Op1 and Op2. -static inline void fuseAccessExtent(Operation *Op1, Operation *Op2, - VectState *state) { +static void fuseAccessExtent(Operation *Op1, Operation *Op2, VectState *state) { // Assert that the input operations are of expected type assert([&] { bool expectedTypes = @@ -333,11 +331,11 @@ static bool isWellFormedVectorOp(Operation *Op) { return false; } - VectorType refType = operandsAndResults.back().getType().cast(); + auto refType = operandsAndResults.back().getType().cast(); Type scalarType = refType.getElementType(); unsigned refSize = getVectorLaneSize(refType); for (auto val : operandsAndResults) { - VectorType vtype = val.getType().cast(); + auto vtype = val.getType().cast(); // Check 2. All the vector sizes must be same if (refSize != getVectorLaneSize(vtype)) return false; @@ -426,7 +424,7 @@ static AffineExpr constructLinearizedAffineExpr(TransferReadOp readOp, SmallVector indices(readOp.getIndices().begin(), readOp.getIndices().end()); - MemRefType memRefType = readOp.getSource().getType().cast(); + auto memRefType = readOp.getSource().getType().cast(); MLIRContext *context = memRefType.getContext(); SmallVector exprVec; @@ -438,8 +436,7 @@ static AffineExpr constructLinearizedAffineExpr(TransferReadOp readOp, // If the access is a map via affine apply op (e.g., A[i+2], where the map // is d0 -> d0+2), push in the map after replacing all the dims with unique // index identifiers (e.g., let the unique identifier for index i be k0). - if (affine::AffineApplyOp apOf = - value.getDefiningOp()) { + if (auto apOf = value.getDefiningOp()) { AffineMap map = apOf.getAffineMap(); assert(map.getNumResults() == 1 && "Failed to create linearized affineExpr for complicated index"); @@ -495,21 +492,21 @@ static std::pair getBaseAndOffset(AffineExpr expr) { AffineExpr base = expr; int32_t offset = 0; // If expr is already a constant, the base is nullptr, and offset is expr - if (auto constExpr = expr.dyn_cast()) { + if (auto constExpr = llvm::dyn_cast(expr)) { base = nullptr; offset += constExpr.getValue(); } // If this is a binary '+' expression, compute the constant offset. Currently // this is just a simple FSM. This must evolve as we explore more complex // access patterns. - else if (auto binopExpr = expr.dyn_cast()) { + else if (auto binopExpr = llvm::dyn_cast(expr)) { if (binopExpr.getKind() == AffineExprKind::Add) { AffineExpr lhs = binopExpr.getLHS(), rhs = binopExpr.getRHS(); - if (auto constExpr = lhs.dyn_cast()) { + if (auto constExpr = llvm::dyn_cast(lhs)) { base = rhs; offset += constExpr.getValue(); } - if (auto constExpr = rhs.dyn_cast()) { + if (auto constExpr = llvm::dyn_cast(rhs)) { base = base == rhs ? nullptr : lhs; offset += constExpr.getValue(); } @@ -526,7 +523,7 @@ static aievec::CastOp generateCastOp(Value source, VectorType resType, bool isResAcc, VectState *state, Location loc) { // Create the Cast op - aievec::CastOp castOp = + auto castOp = state->builder.create(loc, resType, source, isResAcc); assert(castOp && "could not create srs op"); @@ -550,8 +547,8 @@ static aievec::SRSOp generateSRSOp(Value source, Type scalarType, auto shiftParamOp = state->builder.create( loc, state->builder.getI32IntegerAttr(state->shift)); // Create the SRS op - aievec::SRSOp srsOp = state->builder.create( - loc, srsType, source, shiftParamOp.getResult()); + auto srsOp = state->builder.create(loc, srsType, source, + shiftParamOp.getResult()); assert(srsOp && "could not create srs op"); return srsOp; @@ -567,7 +564,7 @@ static aievec::UPSOp generateUPSOp(Value source, VectState *state, "ups source should not be accumulator"); // Create a new UPS instruction - aievec::UPSOp upsOp = + auto upsOp = state->builder.create(loc, accType, source, state->shift); assert(upsOp && "could not create ups op"); @@ -577,9 +574,9 @@ static aievec::UPSOp generateUPSOp(Value source, VectState *state, // Generate and return a Broadcast op. static aievec::BroadcastOp generateBroadcastOp(Value source, int8_t idx, VectState *state, Location loc) { - VectorType type = source.getType().cast(); + auto type = source.getType().cast(); // Create a new Broadcast instruction - aievec::BroadcastOp broadcastOp = + auto broadcastOp = state->builder.create(loc, type, source, idx); assert(broadcastOp && "could not create broadcast op"); @@ -592,11 +589,11 @@ static aievec::ConcatOp generateConcatOp(SmallVector &sources, VectorType concatType = nullptr) { assert(sources.size() > 1 && "must concat at least two vectors"); - VectorType vecType = sources.back().getType().cast(); + auto vecType = sources.back().getType().cast(); assert([&] { for (auto source : sources) { - VectorType type = source.getType().cast(); + auto type = source.getType().cast(); if (type != vecType) { printf("sources of concat op not of same type\n"); return false; @@ -613,7 +610,7 @@ static aievec::ConcatOp generateConcatOp(SmallVector &sources, } // Create the concat op - aievec::ConcatOp concatOp = + auto concatOp = state->builder.create(loc, concatType, sources); assert(concatOp && "could not create concat op"); @@ -631,14 +628,14 @@ static aievec::SelectOp generateSelectOp(Value xbuff, AIEOpAttributes &opAttr, assert(opAttr.start.size() == opAttr.offset.size() && opAttr.start.size() == 2); - VectorType xtype = xbuff.getType().cast(); + auto xtype = xbuff.getType().cast(); // Verify that lanes is <= xtype lanes assert(lanes <= getVectorLaneSize(xtype)); // Create the result type VectorType resultType = createVectorType(lanes, xtype.getElementType()); // Create AIE dialect select op - aievec::SelectOp selectOp = state->builder.create( + auto selectOp = state->builder.create( loc, resultType, xbuff, opAttr.select, opAttr.start[0], opAttr.offset[0], opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1], opAttr.square[1], ybuff); @@ -651,14 +648,14 @@ static aievec::SelectOp generateSelectOp(Value xbuff, AIEOpAttributes &opAttr, // output, and idx defines which part of source is extracted. static aievec::ExtOp generateExtOp(Value source, unsigned lanes, int8_t idx, VectState *state, Location loc) { - VectorType stype = source.getType().cast(); + auto stype = source.getType().cast(); // Verify that lanes*idx is <= stype lanes assert(lanes * (idx + 1) <= getVectorLaneSize(stype)); // Create the result type VectorType resultType = createVectorType(lanes, stype.getElementType()); // Create AIE dialect ext op - aievec::ExtOp extOp = + auto extOp = state->builder.create(loc, resultType, source, idx); assert(extOp && "could not create ext op"); @@ -669,14 +666,13 @@ static aievec::ExtOp generateExtOp(Value source, unsigned lanes, int8_t idx, static aievec::PackOp generatePackOp(Value source, VectState *state, Location loc) { // Create the result type - VectorType stype = source.getType().cast(); + auto stype = source.getType().cast(); unsigned lanes = getVectorLaneSize(stype); - Type i8Type = mlir::IntegerType::get(source.getContext(), 8); + Type i8Type = IntegerType::get(source.getContext(), 8); VectorType resultType = createVectorType(lanes, i8Type); // Create AIE dialect pack op - aievec::PackOp packOp = - state->builder.create(loc, resultType, source); + auto packOp = state->builder.create(loc, resultType, source); assert(packOp && "could not create pack op"); return packOp; @@ -689,7 +685,7 @@ static aievec::AddOp generateAddOp(Operation *Op, AIEOpAttributes &opAttr, assert(opAttr.start.size() == opAttr.offset.size() && opAttr.start.size() == 2); - aievec::AddOp addOp = state->builder.create( + auto addOp = state->builder.create( Op->getLoc(), Op->getResult(0).getType(), Op->getOperand(0), Op->getOperand(1), opAttr.start[0], opAttr.offset[0], opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1], @@ -704,7 +700,7 @@ static aievec::SubOp generateSubOp(Operation *Op, AIEOpAttributes &opAttr, assert(opAttr.start.size() == opAttr.offset.size() && opAttr.start.size() == 2); - aievec::SubOp subOp = state->builder.create( + auto subOp = state->builder.create( Op->getLoc(), Op->getResult(0).getType(), Op->getOperand(0), Op->getOperand(1), opAttr.start[0], opAttr.offset[0], opAttr.offset_hi[0], opAttr.square[0], opAttr.start[1], opAttr.offset[1], opAttr.offset_hi[1], @@ -715,10 +711,10 @@ static aievec::SubOp generateSubOp(Operation *Op, AIEOpAttributes &opAttr, static aievec::ShiftOp generateShiftOp(Value lhs, Value rhs, int32_t shiftBytes, VectState *state, Location loc, VectorType resType = nullptr) { - VectorType vecType = rhs.getType().cast(); + auto vecType = rhs.getType().cast(); assert([&] { - VectorType type = lhs.getType().cast(); + auto type = lhs.getType().cast(); if (type != vecType) { printf("lhs and rhs do not have same type\n"); return false; @@ -732,7 +728,7 @@ static aievec::ShiftOp generateShiftOp(Value lhs, Value rhs, int32_t shiftBytes, resType = createVectorType(lanes, scalarType); } - arith::ConstantOp constOp = state->builder.create( + auto constOp = state->builder.create( loc, state->builder.getI32IntegerAttr(shiftBytes)); auto shiftOp = state->builder.create(loc, resType, lhs, rhs, constOp.getResult()); @@ -743,7 +739,7 @@ static aievec::ShiftOp generateShiftOp(Value lhs, Value rhs, int32_t shiftBytes, static aievec::ShuffleOp generateShuffleOp(Value source, VectState *state, Location loc, unsigned mode, VectorType resType = nullptr) { - VectorType vecType = source.getType().cast(); + auto vecType = source.getType().cast(); if (!resType) { unsigned lanes = 512 / getElementSizeInBits(vecType); @@ -774,14 +770,14 @@ static Operation *generateMulOrFMAConvOpForInt8(Operation *Op, Value rhs = state->sextTruncDefMap.count(Op->getOperand(0).getDefiningOp()) ? Op->getOperand(0).getDefiningOp()->getOperand(0) : Op->getOperand(0); - VectorType vType = lhs.getType().cast(); + auto vType = lhs.getType().cast(); Type stype = vType.getElementType(); - IntegerType itype = stype.cast(); + auto itype = stype.cast(); unsigned width = itype.getWidth() <= 8 ? 32 : 64; int32_t M = 32; int32_t N = 8; - Type ctype = mlir::IntegerType::get(itype.getContext(), width); + Type ctype = IntegerType::get(itype.getContext(), width); Type opType = VectorType::get(vType.getShape(), ctype); auto defOp = rhs.getDefiningOp(); state->builder.setInsertionPointAfter(defOp); @@ -848,7 +844,6 @@ static Operation *generateFMAOp(vector::FMAOp fmaOp, AIEOpAttributes &opAttr, // Check if this is an fmsub op, and if so, then we need to generate msc op bool isSub = state->mscOps.count(fmaOp); - Operation *xfmaOp = nullptr; // We need to generate a UPS op for the integer and AIEML path if the // accumulator is coming from a vector register. @@ -858,6 +853,7 @@ static Operation *generateFMAOp(vector::FMAOp fmaOp, AIEOpAttributes &opAttr, .getElementType() .isa(); + Operation *xfmaOp; if (AIEML && getVectorSizeInBits(rhs.getType().cast()) == 512) { if (!writesToAccumulator(acc.getDefiningOp())) { acc = generateUPSOp(acc, state, fmaOp->getLoc()); @@ -983,7 +979,7 @@ generateUPDOp(TransferReadOp readOp, // Create the upd vector type. To do so, we need the underlying element type. // We can divide the interval size by that to get the number of lanes in the // result vector of upd op. - VectorType vecType = readOp.getVector().getType().cast(); + auto vecType = readOp.getVector().getType().cast(); Type elementType = vecType.getElementType(); int32_t elementSizeInBits = getElementSizeInBits(vecType); int intervalWidthInBytes = intervalWidth / elementSizeInBits; @@ -1027,10 +1023,8 @@ generateUPDOp(TransferReadOp readOp, readOp.getIndices().end()); // Get the linearized access expression for the read to compute the offset AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state); - AffineExpr base; - int32_t offset; // Get the base and offset from linear access expr - std::tie(base, offset) = getBaseAndOffset(linearAccess); + auto [base, offset] = getBaseAndOffset(linearAccess); offset *= elementSizeInBits; // get the offset in bits // The insertion point depends on whether the region has a single block or @@ -1076,8 +1070,7 @@ generateUPDOp(TransferReadOp readOp, // If the transfer_read has some apply operations, then they also need to // be hoisted. for (auto &value : indices) { - if (affine::AffineApplyOp apOf = - value.getDefiningOp()) { + if (auto apOf = value.getDefiningOp()) { // Skip hoisting if already above in lexicographical order if (apOf->getBlock() == readOp->getBlock() && apOf->isBeforeInBlock(updOp)) @@ -1109,7 +1102,7 @@ static int32_t computeVecorizedLoopStepSize(Operation *op, VectState *state) { return 1; int32_t step = 0; - VectorType vectorType = readOp.getResult().getType().cast(); + auto vectorType = readOp.getResult().getType().cast(); SmallVector indices(readOp.getIndices().begin(), readOp.getIndices().end()); assert(vectorType && !indices.empty()); @@ -1123,7 +1116,7 @@ static int32_t computeVecorizedLoopStepSize(Operation *op, VectState *state) { // The vectorized (i.e., last) index of the permutation must correspond to a // loop nest. If not, this is a splat read. AffineExpr expr = readOp.getPermutationMap().getResults().back(); - if (auto dimExpr = expr.dyn_cast()) { + if (auto dimExpr = llvm::dyn_cast(expr)) { assert(dimExpr.getPosition() <= indices.size() && "Failed to find the permutation index in index map"); auto index = indices[dimExpr.getPosition()]; @@ -1132,7 +1125,7 @@ static int32_t computeVecorizedLoopStepSize(Operation *op, VectState *state) { [[maybe_unused]] bool found = false; for (auto loop : enclosingLoops) { auto iv = cast(loop).getInductionVar(); - auto invariants = mlir::affine::getInvariantAccesses(iv, indices); + auto invariants = affine::getInvariantAccesses(iv, indices); if (!invariants.count(index)) { assert( !found && @@ -1162,15 +1155,13 @@ int32_t computeStartInAIEVec(Operation *op, VectState *state) { auto readOp = cast(op); // Get the scalar element type's size in bits - VectorType vtype = readOp.getVector().getType().cast(); + auto vtype = readOp.getVector().getType().cast(); int32_t scalarSizeInBits = getElementSizeInBits(vtype); // Get the linearized access expr for this read AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state); - AffineExpr base; - int32_t offset; // get the base and offset from linear access expr - std::tie(base, offset) = getBaseAndOffset(linearAccess); + auto [base, offset] = getBaseAndOffset(linearAccess); offset *= scalarSizeInBits; // compute offset in bits // Now find the reuse interval to which this readOp belongs IntervalReuse *iv = state->getIntervalForOperation(op); @@ -1198,7 +1189,7 @@ static Operation *concatAndInterleave_i8xi8(Operation *source1, // v16int16 vector, since select operation does not operate on v16int8 // vector. Type i16Type = - mlir::IntegerType::get(source1->getResult(0).getType().getContext(), 16); + IntegerType::get(source1->getResult(0).getType().getContext(), 16); auto srsOp1 = generateSRSOp(source1->getResult(0), i16Type, state, loc); auto srsOp2 = generateSRSOp(source2->getResult(0), i16Type, state, loc); @@ -1288,8 +1279,8 @@ static bool canFuseMulAndAddOrSubIntoFMAOp(Operation *Op, VectState *state) { return false; // Check 7. All the vector sizes must be same - VectorType lhsType = lhs.getType().cast(); - VectorType rhsType = rhs.getType().cast(); + auto lhsType = lhs.getType().cast(); + auto rhsType = rhs.getType().cast(); VectorType accType = state->sextTruncDefMap.count( acc.getDefiningOp()->getOperand(0).getDefiningOp()) ? acc.getDefiningOp() @@ -1442,7 +1433,7 @@ static void fuseMulAndAddOrSubIntoFMAOp(Operation *Op, VectState *state) { // generates one MulConvOp or FMAConvOp for each vector dialect mul/fma op. static void generateMulOrFMAOp(Operation *Op, Scheme &scheme, AIEOpAttributes &opAttr, VectState *state, - std::string nextStart = "") { + const std::string &nextStart = "") { // Assert that we computed the attributes for both the operands assert(opAttr.start.size() == opAttr.offset.size() && opAttr.start.size() == 2); @@ -1458,15 +1449,15 @@ static void generateMulOrFMAOp(Operation *Op, Scheme &scheme, // Generate an AIE dialect mul/fma op from a vector dialect mul/fma op auto genOp = [&](Operation *Op, AIEOpAttributes &opAttr, VectState *state, bool i8xi8_pairedOp = false) { - Operation *repOp = nullptr; + Operation *repOp; // Create aievec::FMAOp corresponding to the vector::FMAOp - if (vector::FMAOp fmaOp = dyn_cast(Op)) + if (auto fmaOp = dyn_cast(Op)) repOp = generateFMAOp(fmaOp, opAttr, state, i8xi8_pairedOp); // Create aievec::MulOp corresponding to the vector::MulIOp - else if (MulIOp mulOp = dyn_cast(Op)) + else if (auto mulOp = dyn_cast(Op)) repOp = generateMulOp(mulOp, opAttr, state); // Create aievec::MulOp corresponding to the vector::MulFOp - else if (MulFOp mulOp = dyn_cast(Op)) + else if (auto mulOp = dyn_cast(Op)) repOp = generateMulOp(mulOp, opAttr, state); else llvm_unreachable("Operation not mul/fma op"); @@ -1483,9 +1474,9 @@ static void generateMulOrFMAOp(Operation *Op, Scheme &scheme, if (!nextStart.empty()) { if (AIEML && scheme.lanes == 32 && scheme.xbits == 8 && scheme.zbits == 8) { repOp = generateMulOrFMAConvOpForInt8(Op, opAttr, state); - if (llvm::any_of(repOp->getUsers(), notMulOrFMAOp)) { - Type i8Type = mlir::IntegerType::get( - repOp->getResult(0).getType().getContext(), 8); + if (any_of(repOp->getUsers(), notMulOrFMAOp)) { + Type i8Type = + IntegerType::get(repOp->getResult(0).getType().getContext(), 8); repOp = generateSRSOp(repOp->getResult(0), i8Type, state, repOp->getLoc()); } @@ -1499,7 +1490,7 @@ static void generateMulOrFMAOp(Operation *Op, Scheme &scheme, state->pairedOp[repOp] = pairedOp; // If any of the uses of incoming op is not a mul/fma op, then we need to // concatenate the paired ops and generate a v16xi8 vector. - if (llvm::any_of(Op->getUsers(), notMulOrFMAOp)) + if (any_of(Op->getUsers(), notMulOrFMAOp)) repOp = concatAndInterleave_i8xi8(repOp, pairedOp, state, Op->getLoc()); } } @@ -1546,23 +1537,21 @@ static void computeXbuffAttr_i16xi16( assert((accIncr <= 1 || colOffset <= 1) && "cannot generate offset and square for xbuff"); - std::string startStr, offsetStr, offsetHiStr, squareStr; - // Arch restriction: xstart should be a multiple of 2. int32_t m2start = (start / 2) * 2; - startStr = std::to_string(m2start); + std::string startStr = std::to_string(m2start); // m2Offset accounts for the extra 1 if the start is not a multiple of 2 int32_t m2Offset = start - m2start; // Compute hi and lo offsets to something resembling "0x_7_6_5_4" and // "0x_3_2_1_0" respectively. The '_' are 0 if colOffset is 1. - offsetStr = "0x"; + std::string offsetStr = "0x"; int32_t offset = std::max(colOffset, accIncr); for (int i = vecSize / 2 - 2; i >= 0; i -= 2) { offsetStr.push_back(offset <= 1 ? '0' : getHexValue((offset - 2) / 2)); offsetStr.push_back(getHexValue((i * accIncr) / 2)); } - offsetHiStr = "0x"; + std::string offsetHiStr = "0x"; for (int i = vecSize - 2, e = vecSize / 2; i >= e; i -= 2) { offsetHiStr.push_back(offset <= 1 ? '0' : getHexValue((offset - 2) / 2)); offsetHiStr.push_back(getHexValue((i * accIncr) / 2)); @@ -1574,7 +1563,7 @@ static void computeXbuffAttr_i16xi16( assert(m2Offset == 0 || (astep <= 1 && cstep <= 1)); SmallVector sqPattern = {astep + cstep, astep, cstep, 0}; - squareStr = "0x"; + std::string squareStr = "0x"; for (auto sq : sqPattern) squareStr.push_back(getHexValue(sq + m2Offset)); @@ -1594,10 +1583,10 @@ static void computeZbuffAttr_i16xi16( int32_t zeroOffset, // offset of 0 value in the filter int32_t colOffset, // zbuff access distance between vector cols AIEOpAttributes &opAttr) { - std::string startStr, offsetStr, offsetHiStr, stepStr; + std::string offsetStr, offsetHiStr; // zstart must be 4b value. assert(start < (AIEML ? 32 : 16) && "zstart must be 4b value"); - startStr = std::to_string(start); + std::string startStr = std::to_string(start); // If zbuff comes from splat, use default offsets if (accIncr == 0) @@ -1615,7 +1604,7 @@ static void computeZbuffAttr_i16xi16( // Compute step between columns int32_t step = colOffset == -1 ? zeroOffset - 1 - start : colOffset; assert(step >= 0 && "zstep cannot be negative"); - stepStr = std::to_string(step); + std::string stepStr = std::to_string(step); // And now we have everything to push into opAttr opAttr.start.push_back(startStr); @@ -1643,10 +1632,9 @@ static void computeXbuffAttr_i8xi8( int32_t colStep = 2 * colOffset; assert(colStep % 4 == 0 && "xstep must be multiple of 4"); - std::string startStr, offsetStr, squareStr, stepStr; // Arch restriction: xstart must be a multiple of 4 int32_t m4start = (start / 4) * 4; - startStr = std::to_string(m4start); + std::string startStr = std::to_string(m4start); // m4Offset accounts for the excess if start is not a multiple of 4 int32_t m4Offset = start - m4start; // Because of duplication, m4Offset can only be 0 or 2 @@ -1654,12 +1642,12 @@ static void computeXbuffAttr_i8xi8( // Compute offsetStr to something resembling "0x_0_0_0_0", where _ is // (colStep-4)/4. - offsetStr = "0x"; + std::string offsetStr = "0x"; for (int i = vecSize / 4 - 1; i >= 0; --i) { offsetStr.push_back(getHexValue(colStep / 4 - 1)); offsetStr += "0"; } - stepStr = std::to_string(colStep); + std::string stepStr = std::to_string(colStep); // Now compute the square for zbuff. We want a {0,x,0,x} pattern. int32_t offsetWithoutDup = colOffset / 2; @@ -1669,7 +1657,7 @@ static void computeXbuffAttr_i8xi8( assert(m4Offset == 0 || rstep <= 1); SmallVector sqPattern = {rstep, 0, rstep, 0}; - squareStr = "0x"; + std::string squareStr = "0x"; for (auto sq : sqPattern) squareStr.push_back(getHexValue(sq + m4Offset)); @@ -1693,22 +1681,21 @@ static void computeZbuffAttr_i8xi8( // The colOffset must be either <=1, or a multiple of 2 assert((colOffset <= 1 || colOffset % 2 == 0) && "zbuff value not supported"); - std::string startStr, offsetStr, squareStr, stepStr; // Arch restriction: zstart is a multiple of 2 int32_t m2start = (start / 2) * 2; - startStr = std::to_string(m2start); + std::string startStr = std::to_string(m2start); // m2Offset accounts for the extra 1 if the start is not a multiple of 2 int32_t m2Offset = start - m2start; // Compute offsetStr to something resembling "0x43322110". The usual pattern // is "0x_3_2_1_0", and the purpose is to fill the "_". - offsetStr = "0x"; + std::string offsetStr = "0x"; for (int i = vecSize / 4 - 1; i >= 0; --i) { int32_t val = i * accIncr + (colOffset + 1) / 2; offsetStr.push_back(getHexValue(val)); offsetStr.push_back(getHexValue(i * accIncr)); } - stepStr = std::to_string(2 * std::abs(colOffset)); + std::string stepStr = std::to_string(2 * std::abs(colOffset)); nextStart = std::to_string(m2start + 2 * accIncr * (vecSize / 4)); // Now compute the square for zbuff. We want a {0,1+x,y,y+1+x} pattern, where @@ -1717,7 +1704,7 @@ static void computeZbuffAttr_i8xi8( assert(m2Offset == 0 || rstep <= 1); SmallVector sqPattern = {accIncr + rstep, accIncr, rstep, 0}; - squareStr = "0x"; + std::string squareStr = "0x"; for (auto sq : sqPattern) squareStr.push_back(getHexValue(sq + m2Offset)); @@ -2028,13 +2015,11 @@ static void fuseFMAOpsForColumnTopology(func::FuncOp func, VectState *state) { llvm::SmallSet fusedOpSet; // Fuse FMA ops to exploit column topology - func.walk([&](mlir::Operation *op) { + func.walk([&](Operation *op) { if (isa(op)) { // Only process fma ops that are not already fused with another mul/fma if (!fusedOpSet.count(op)) { - // Get the rows and columns for this topology - int32_t lanes, cols; - std::tie(lanes, cols) = getNumRowsAndCols(op, state); + auto [lanes, cols] = getNumRowsAndCols(op, state); // Try fusing a linear chain of FMA ops (max length = cols) starting at // op. fuseFMAOps(op, fusedOpSet, cols, state); @@ -2057,11 +2042,11 @@ static bool matchAttributesAndDistanceForFusion(T1 curOp, T2 defOp) { curOp.getOffsetHi(1) == defOp.getOffsetHi(1) && curOp.getSquare(1) == defOp.getSquare(1) && curOp.getStep(1) == defOp.getStep(1) && - stoi((std::string)curOp.getStart(0)) - - stoi((std::string)defOp.getStart(0)) == + stoi(static_cast(curOp.getStart(0))) - + stoi(static_cast(defOp.getStart(0))) == 2 && - stoi((std::string)curOp.getStart(1)) - - stoi((std::string)defOp.getStart(1)) == + stoi(static_cast(curOp.getStart(1))) - + stoi(static_cast(defOp.getStart(1))) == 2; } @@ -2099,22 +2084,20 @@ static bool matchAttributesAndDistanceForFusion(T1 curOp, T2 defOp) { // aievec.fma_conv %8, %2, %7 {M = 16 : si32, N = 4 : si32} // Currently, we only support mul_conv_16x4 and mac_conv_16x4 intrinsics for // int16 type of AIE-ML architecture. -static bool canFuseMulFMAOpsForInt16(Operation *Op, VectState *state) { +static bool canFuseMulFMAOpsForInt16(Operation *Op) { // Check 1. This should be an aievec fma operation assert(isa(Op) && "operation must be an aievec fma op"); - aievec::FMAOp curOp = cast(Op); + auto curOp = cast(Op); // Check 2. Element type should be int16 - VectorType vType = Op->getOperand(1).getType().cast(); + auto vType = Op->getOperand(1).getType().cast(); Type stype = vType.getElementType(); - IntegerType itype = stype.dyn_cast(); + auto itype = llvm::dyn_cast(stype); if (!itype) return false; - unsigned width = itype.getWidth(); - - if (width != 16) + if (unsigned width = itype.getWidth(); width != 16) return false; // Check 3. acc operand of the Op should be a mul op or fma op @@ -2180,7 +2163,7 @@ static bool canFuseMulFMAOpsForInt16(Operation *Op, VectState *state) { // Rewrite a mul/fma and fma op as a aievec MUL_conv or FMA_Conv op static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) { - aievec::FMAOp curOp = cast(Op); + auto curOp = cast(Op); Value lhs = curOp->getOperand(0); @@ -2214,19 +2197,19 @@ static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) { // Get the def op of acc. It is either a mul op or a fma op. Operation *convOp = nullptr; Operation *mulOrFMAOp = Op->getOperand(2).getDefiningOp(); - aievec::MulOp mulOp = dyn_cast(mulOrFMAOp); - aievec::FMAOp fmaOp = dyn_cast(mulOrFMAOp); - int32_t zStart = 0; + auto mulOp = dyn_cast(mulOrFMAOp); + auto fmaOp = dyn_cast(mulOrFMAOp); + int32_t zStart; if (mulOp) { aievec::MulOp defOp = mulOp; - zStart = stoi((std::string)defOp.getStart(1)); + zStart = stoi(static_cast(defOp.getStart(1))); } else { aievec::FMAOp defOp = fmaOp; - zStart = stoi((std::string)defOp.getStart(1)); + zStart = stoi(static_cast(defOp.getStart(1))); } - VectorType vType = Op->getOperand(1).getType().cast(); + auto vType = Op->getOperand(1).getType().cast(); int32_t shiftBytes = zStart * getElementSizeInBits(vType) / 8; auto defOp = mulOp ? mulOp : fmaOp; @@ -2234,17 +2217,13 @@ static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) { Location loc = defOp->getLoc(); // Generate a shift_bytes operation for concatRhs if needed. - if (shiftBytes) { + if (shiftBytes) concatRhs = generateShiftOp(concatRhs, concatRhs, shiftBytes, state, loc); - } Type stype = vType.getElementType(); - unsigned width = 0; - IntegerType itype = stype.cast(); - - width = itype.getWidth() <= 8 ? 32 : 64; - - Type ctype = mlir::IntegerType::get(itype.getContext(), width); + auto itype = stype.cast(); + unsigned width = itype.getWidth() <= 8 ? 32 : 64; + Type ctype = IntegerType::get(itype.getContext(), width); Type opType = VectorType::get(vType.getShape(), ctype); Value acc = nullptr; // Curently, we only support 16x4 convolution intrinsics for int16 type @@ -2255,10 +2234,10 @@ static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) { // operation with index 1 lhs = curOp->getOperand(0); - if (mulOp) { + if (mulOp) convOp = state->builder.create(loc, opType, lhs, concatRhs, M, N); - } else { + else { acc = defOp->getOperand(2); bool isSub = state->mscOps.count(defOp); convOp = state->builder.create( @@ -2271,10 +2250,9 @@ static void fuseMulFMAOpsForInt16(Operation *Op, VectState *state) { } static void fuseMulFMAOpsByMulFMAConv(func::FuncOp func, VectState *state) { - func.walk([&](mlir::Operation *Op) { - if (isa(Op) && canFuseMulFMAOpsForInt16(Op, state)) { + func.walk([&](Operation *Op) { + if (isa(Op) && canFuseMulFMAOpsForInt16(Op)) fuseMulFMAOpsForInt16(Op, state); - } }); } @@ -2287,7 +2265,7 @@ static void fuseMulFMAOpsByMulFMAConv(func::FuncOp func, VectState *state) { static void generateAIEMulOrFMAOpsInFunc(func::FuncOp func, VectState *state) { // For each mul/fma op, compute the scheme-dependent operand attributes, and // generate corresponding AIE dialect ops. - func.walk([&](mlir::Operation *op) { + func.walk([&](Operation *op) { if (isa(op)) generateSchemeBasedMulOrFMAOp(op, state); }); @@ -2427,7 +2405,7 @@ static void generateSchemeBasedAddOrSubOp(Operation *Op, VectState *state) { // for the adds involving splat. If none of the operands of the add op is // splat, we must generate simple scheme add op. static void generateAIEAddOrSubOpsInFunc(func::FuncOp func, VectState *state) { - func.walk([&](mlir::Operation *op) { + func.walk([&](Operation *op) { if (isa(op)) generateSchemeBasedAddOrSubOp(op, state); }); @@ -2491,7 +2469,7 @@ static void insertSRSOp(Operation *Op, VectState *state) { // operation is non-AIE op, then we need to generate SRS op to move value // from accumulator to vector auto isNonAIEOp = [&](Operation *op) { return !isAIEOp(op); }; - if (!llvm::any_of(Op->getUsers(), isNonAIEOp)) + if (!any_of(Op->getUsers(), isNonAIEOp)) return; // Given an accumulator, one can use different srs intrinsic to generate @@ -2510,7 +2488,7 @@ static void insertSRSOp(Operation *Op, VectState *state) { // Get the underlying scalar element type of user op. If the user is a // write op, it won't have a result. So get the element type from memref. - Type scalarType = nullptr; + Type scalarType; MemRefType memRefType = nullptr; if (auto writeOp = dyn_cast(user)) { // Get the element type from the memref output @@ -2546,7 +2524,7 @@ static void insertSRSOp(Operation *Op, VectState *state) { user->replaceUsesOfWith(operand, castOp); break; } - aievec::SRSOp srsOp = nullptr; + aievec::SRSOp srsOp; if (!typeToSRSOpMap.count(scalarType)) { srsOp = generateSRSOp(Op->getResult(0), scalarType, state, Op->getLoc()); @@ -2566,7 +2544,7 @@ static void insertSRSOp(Operation *Op, VectState *state) { // Generate SRS op whenever we move data from an accumulator AIE dialect to a // vector. static void insertSRSOpsInFunc(func::FuncOp func, VectState *state) { - func.walk([&](mlir::Operation *op) { + func.walk([&](Operation *op) { // Insert an SRS op if the op outputs to an accumulator if (writesToAccumulator(op)) insertSRSOp(op, state); @@ -2663,7 +2641,7 @@ computeEnclosingLoopsPerBlock(affine::AffineForOp forOp, VectState *state, // this rule is the 8x8 bit scheme, where the xbuff is a bit more restrictive, // so we prefer splat as left operand of multiplication for 8x8 scheme. static void reassociateMulOpInFunc(func::FuncOp func, VectState *state) { - func.walk([&](mlir::Operation *op) { + func.walk([&](Operation *op) { // Only reassociate vector mul ops that are well formed. This also includes // the multiplication component in fma ops. if (isa(op) && isWellFormedVectorOp(op)) { @@ -2682,7 +2660,7 @@ static void reassociateMulOpInFunc(func::FuncOp func, VectState *state) { // commutativity of add op, and is only applied so that we can leverage the // same code functionality for generating mac and msc ops. static void reassociateAddOpInFunc(func::FuncOp func, VectState *state) { - func.walk([&](mlir::Operation *op) { + func.walk([&](Operation *op) { // Only reassociate vector add ops that are well formed. if (isa(op) && isWellFormedVectorOp(op)) { // addOp must have two operands and one result @@ -2768,10 +2746,8 @@ static void recordSextOps(func::FuncOp func, VectState *state) { static void computeReuse(TransferReadOp readOp, VectState *state) { // Construct a linearized access expression for the transfer_read AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state); - AffineExpr base; - int32_t offset; // Decompose the linear access into a base and constant offset value - std::tie(base, offset) = getBaseAndOffset(linearAccess); + auto [base, offset] = getBaseAndOffset(linearAccess); // Get the step size of the vectorized loop that encloses this read operation int32_t step = computeVecorizedLoopStepSize(readOp, state); @@ -2811,7 +2787,7 @@ static void computeReuse(TransferReadOp readOp, VectState *state) { } } - VectorType vecType = readOp.getVector().getType().cast(); + auto vecType = readOp.getVector().getType().cast(); if (AIEML && (getVectorSizeInBits(vecType) == 512 || getElementSizeInBits(vecType) == 8)) { minVecSize *= 2; @@ -2838,7 +2814,7 @@ static void computeReuse(TransferReadOp readOp, VectState *state) { // If no reuse is found, create a new IntervalReuse object with just this // operation's read access extent. if (!found) { - IntervalReuse *iv = new IntervalReuse(readOp, base); + auto iv = new IntervalReuse(readOp, base); iv->insertInterval(readOp, state->opToIntervalMap, offset, step, isSplat, minVecSize); state->reuseIntervals.push_back(iv); @@ -2846,7 +2822,7 @@ static void computeReuse(TransferReadOp readOp, VectState *state) { } static LogicalResult isUnalignedLoad(TransferReadOp readOp, VectState *state) { - VectorType vectorType = readOp.getResult().getType().cast(); + auto vectorType = readOp.getResult().getType().cast(); unsigned lanes = getVectorLaneSize(vectorType); AffineExpr linearAccess = constructLinearizedAffineExpr(readOp, state); @@ -2854,7 +2830,7 @@ static LogicalResult isUnalignedLoad(TransferReadOp readOp, VectState *state) { return success(); } - MemRefType memRefType = readOp.getSource().getType().cast(); + auto memRefType = readOp.getSource().getType().cast(); MLIRContext *context = memRefType.getContext(); ArrayRef sizes = memRefType.getShape(); int numDims = sizes.size(); @@ -2869,19 +2845,18 @@ static LogicalResult isUnalignedLoad(TransferReadOp readOp, VectState *state) { // If the lowest dim has iv, check whether its corresponding loop step is // divisible by the vector lanes. - int32_t step = 0; if (auto dimExpr = - getAffineDimExpr(numDims - 1, context).dyn_cast()) { + dyn_cast(getAffineDimExpr(numDims - 1, context))) { auto index = indices[dimExpr.getPosition()]; // Iterate over all enclosing loops, and find the one that is variant in // index. for (auto loop : enclosingLoops) { auto affineForOp = cast(loop); auto iv = affineForOp.getInductionVar(); - auto invariants = mlir::affine::getInvariantAccesses(iv, indices); + auto invariants = affine::getInvariantAccesses(iv, indices); if (!invariants.count(index)) { - step = affineForOp.getStepAsInt(); + int step = affineForOp.getStepAsInt(); if (step % lanes) { return readOp->emitError() << "Loop step of inner index of " << readOp->getName() @@ -2982,7 +2957,7 @@ static void reassociateOpsInFunc(func::FuncOp func, VectState *state) { reassociateAddOpInFunc(func, state); } -struct AIEVectorize : public AIEVectorizeBase { +struct AIEVectorize : AIEVectorizeBase { AIEVectorize() = default; void runOnOperation() override; }; @@ -3008,7 +2983,7 @@ void AIEVectorize::runOnOperation() { // Iterate over all the functions in this module, and vectorize them for (func::FuncOp func : module.getOps()) { // Create a new global state - VectState *state = + auto state = new VectState(func.getContext(), shiftParam, zeroOffset, dupFactor); // record the sext op and its operand's def op to sextTruncDefMap @@ -3074,6 +3049,6 @@ void AIEVectorize::runOnOperation() { postCanonicalizeIR(module); } -std::unique_ptr xilinx::aievec::createAIEVectorizePass() { +std::unique_ptr aievec::createAIEVectorizePass() { return std::make_unique(); } diff --git a/lib/Dialect/AIEVec/Transforms/DynamicSizeNoImplicitBroadcast.cpp b/lib/Dialect/AIEVec/Transforms/DynamicSizeNoImplicitBroadcast.cpp index 95ead5c9da..69ab7e4f83 100644 --- a/lib/Dialect/AIEVec/Transforms/DynamicSizeNoImplicitBroadcast.cpp +++ b/lib/Dialect/AIEVec/Transforms/DynamicSizeNoImplicitBroadcast.cpp @@ -39,7 +39,7 @@ using namespace xilinx::aievec; // when the CmpIOp compares the equality of a dynamic dimension's runtime size // to a constant 1, and is guarded by the attribute // `tosa.no_implicit_broadcast_of_dynamic_sizes`. -struct DynamicSizeNoImplicitBroadcastPattern : public RewritePattern { +struct DynamicSizeNoImplicitBroadcastPattern : RewritePattern { DynamicSizeNoImplicitBroadcastPattern(MLIRContext *context) : RewritePattern(arith::CmpIOp::getOperationName(), /*benefit=*/1, context) {} @@ -79,7 +79,7 @@ struct DynamicSizeNoImplicitBroadcastPattern : public RewritePattern { auto index = constIndexOp.getValue().cast().getValue().getZExtValue(); - auto inputDimType = lhsOp->getOperand(0).getType().dyn_cast(); + auto inputDimType = dyn_cast(lhsOp->getOperand(0).getType()); if (!inputDimType || !inputDimType.isDynamicDim(index)) return failure(); @@ -95,7 +95,7 @@ struct DynamicSizeNoImplicitBroadcastPattern : public RewritePattern { //============================================================================// struct DynamicSizeNoImplicitBroadcastPass - : public PassWrapper> { + : PassWrapper> { StringRef getArgument() const final { return "test-dynamic-size-no-implicit-broadcast"; diff --git a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp index 97fcd0281b..f9ce93be54 100644 --- a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp +++ b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp @@ -68,8 +68,8 @@ extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) { return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc); // If the MulIOp has been already translated to aievec::MulOp: - aievec::SRSOp lhsSrsOp = addLhs.getDefiningOp(); - aievec::SRSOp rhsSrsOp = addRhs.getDefiningOp(); + auto lhsSrsOp = addLhs.getDefiningOp(); + auto rhsSrsOp = addRhs.getDefiningOp(); aievec::MulOp aieMulOp = nullptr; if (lhsSrsOp) { aieMulOp = lhsSrsOp.getSource().getDefiningOp(); @@ -90,7 +90,7 @@ extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) { static std::optional convertValueToTargetTypeAieML(ConversionPatternRewriter &rewriter, Location loc, Value inputVal, VectorType tgtType) { - VectorType srcType = cast(inputVal.getType()); + auto srcType = cast(inputVal.getType()); auto srcElemType = srcType.getElementType(); unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth(); unsigned srcLaneSize = getVectorLaneSize(srcType); @@ -117,7 +117,7 @@ convertValueToTargetTypeAieML(ConversionPatternRewriter &rewriter, Location loc, loc, srcType, broadcastZeroOp->getResult(0), 0); SmallVector inputSources = {inputVal, extOp->getResult(0)}; - aievec::ConcatOp concatOp = + auto concatOp = rewriter.create(loc, tgtType, inputSources); return concatOp.getResult(); @@ -132,7 +132,9 @@ convertValueToTargetTypeAieML(ConversionPatternRewriter &rewriter, Location loc, auto castOp = rewriter.create( loc, tgtType, upsOp.getResult(), /*isResAcc*/ false); return castOp.getResult(); - } else if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) { + } + + if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) { // Case 2: vector<16xi8> to vector<16xi32> conversion by aievec.concat + // aievec.ups + aievec.cast + aievec.ext auto concatOutType = createVectorType(32, srcElemType); @@ -147,7 +149,9 @@ convertValueToTargetTypeAieML(ConversionPatternRewriter &rewriter, Location loc, auto extOp = rewriter.create(loc, tgtType, castOp.getResult(), 0); return extOp.getResult(); - } else if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) { + } + + if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) { // Case 3: vector<32xi8> to vector<32xi16> conversion by aievec.unpack auto unpackOp = rewriter.create(loc, tgtType, inputVal); return unpackOp.getResult(); @@ -165,8 +169,7 @@ buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy, int64_t rotation) { unsigned width = 0; auto elemTy = vTy.getElementType(); - auto intTy = dyn_cast(elemTy); - if (intTy) + if (auto intTy = dyn_cast(elemTy)) width = intTy.getWidth(); StringAttr attr0 = rewriter.getStringAttr("0"); StringAttr attr0x06040200 = rewriter.getStringAttr("0x06040200"); @@ -184,7 +187,7 @@ buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy, StringAttr ystartAttrName = rewriter.getStringAttr("ystart"); switch (width) { - case 16: + case 16: { if (rotation % 2) { int64_t xstart = rotation + 1; int64_t ystart = rotation - 1; @@ -198,19 +201,18 @@ buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy, {yoffsetsHiAttrName, rewriter.getStringAttr("0x0d0b0907")}, {ysquareAttrName, attr0x2103}, {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}}); - } else { - return SmallVector( - {{selectAttrName, attr0}, - {xoffsetsAttrName, attr0x06040200}, - {xoffsetsHiAttrName, attr0x0e0c0a08}, - {xsquareAttrName, attr0x3210}, - {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))}, - {yoffsetsAttrName, attr0}, - {yoffsetsHiAttrName, attr0}, - {ysquareAttrName, attr0}, - {ystartAttrName, attr0}}); } - break; + return SmallVector( + {{selectAttrName, attr0}, + {xoffsetsAttrName, attr0x06040200}, + {xoffsetsHiAttrName, attr0x0e0c0a08}, + {xsquareAttrName, attr0x3210}, + {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))}, + {yoffsetsAttrName, attr0}, + {yoffsetsHiAttrName, attr0}, + {ysquareAttrName, attr0}, + {ystartAttrName, attr0}}); + } case 32: return SmallVector( {{selectAttrName, attr0}, @@ -220,20 +222,21 @@ buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy, {yoffsetsAttrName, attr0}, {ysquareAttrName, attr0}, {ystartAttrName, attr0}}); + default: + llvm::report_fatal_error("Unexpected width!"); } + return {}; } -namespace xilinx { -namespace aievec { +namespace xilinx::aievec { SmallVector buildFMAOpSplatAttrForElemTy(aievec::FMAOp fmaOp, int64_t bcastPos, int64_t step = 1) { unsigned width = 0; auto elemTy = fmaOp.getLhs().getType().getElementType(); - auto intTy = dyn_cast(elemTy); - if (intTy) + if (auto intTy = dyn_cast(elemTy)) width = intTy.getWidth(); auto ctx = fmaOp.getContext(); switch (width) { @@ -282,12 +285,14 @@ SmallVector buildFMAOpSplatAttrForElemTy(aievec::FMAOp fmaOp, {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()}, {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()}, {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}}); + default: + llvm::report_fatal_error("Unexpected width!"); } + return {}; } -} // namespace aievec -} // namespace xilinx +} // namespace xilinx::aievec template static LogicalResult genAddElemAieML(ConversionPatternRewriter &rewriter, @@ -331,7 +336,7 @@ convertToIntegerPredicate(arith::CmpFPredicate pred) { case CmpFPredicate::ONE: return CmpIPredicate::ne; default: - llvm_unreachable("Unexpected predicate!"); + llvm::report_fatal_error("Unexpected predicate!"); } } @@ -376,15 +381,14 @@ static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter, "shiftIndex must be power of 2"); Location loc = srcOp.getLoc(); - VectorType vType = dyn_cast(curValue.getType()); + auto vType = dyn_cast(curValue.getType()); Type scalarType = vType.getElementType(); - SmallVector sources = {curValue}; Type vecType = curValue.getType(); DstOpTy curOp = nullptr; unsigned elWidth = scalarType.getIntOrFloatBitWidth(); for (int id = shiftIndex; id > 0; id /= 2) { - arith::ConstantOp constOp = rewriter.create( + auto constOp = rewriter.create( loc, rewriter.getI32IntegerAttr(id * elWidth / 8)); auto shiftBytesOp = rewriter.create( @@ -396,11 +400,10 @@ static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter, curValue = curOp.getResult(); } - arith::ConstantOp zeroConstOp = + auto zeroConstOp = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); rewriter.replaceOpWithNewOp(srcOp, scalarType, curOp, zeroConstOp.getResult()); - return; } //===----------------------------------------------------------------------===// @@ -410,8 +413,8 @@ static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter, // This pattern fold `vector.extract` and `vector.broadcast` into // `aievec.broadcast` for aie-ml struct FoldVectorExtractAndBroadcastToAIEBroadcast - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, @@ -426,12 +429,12 @@ struct FoldVectorExtractAndBroadcastToAIEBroadcast auto src = extOp.getVector(); auto pos = extOp.getStaticPosition(); int64_t posVal = pos[0]; - VectorType srcVecType = cast(src.getType()); - VectorType resultType = cast(bcastOp.getResult().getType()); + auto srcVecType = cast(src.getType()); + auto resultType = cast(bcastOp.getResult().getType()); if (srcVecType != resultType) { if (srcVecType.getNumElements() != 2 * resultType.getNumElements()) return failure(); - int8_t half = static_cast(posVal / resultType.getNumElements()); + auto half = static_cast(posVal / resultType.getNumElements()); posVal -= half * resultType.getNumElements(); src = rewriter .create(extOp.getLoc(), resultType, src, @@ -440,9 +443,9 @@ struct FoldVectorExtractAndBroadcastToAIEBroadcast } unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth(); - unsigned laneSize = getVectorLaneSize(resultType); - if (laneSize * elWidth == 512) { + if (unsigned laneSize = getVectorLaneSize(resultType); + laneSize * elWidth == 512) { // Common use case for the broadcast_elem intrinsic rewriter.replaceOpWithNewOp(bcastOp, resultType, src, posVal); @@ -460,7 +463,7 @@ struct FoldVectorExtractAndBroadcastToAIEBroadcast // e.g. need v32int32 due to the subsequent v32acc32 operation VectorType aievecBcastType = createVectorType(512 / elWidth, resultType.getElementType()); - int8_t half = static_cast(posVal / resultType.getNumElements()); + auto half = static_cast(posVal / resultType.getNumElements()); posVal -= half * resultType.getNumElements(); auto extOp = rewriter.create(bcastOp.getLoc(), aievecBcastType, src, @@ -479,21 +482,21 @@ struct FoldVectorExtractAndBroadcastToAIEBroadcast }; struct ConvertBroadcastToAIEBroadcast - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (auto extOp = adaptor.getSource().getDefiningOp()) + if (adaptor.getSource().getDefiningOp()) return failure(); // Only support broadcasting a single element for now if (!isa(adaptor.getSource().getType())) return failure(); - VectorType resultType = cast(bcastOp.getResult().getType()); + auto resultType = cast(bcastOp.getResult().getType()); Type scalarType = resultType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -503,14 +506,18 @@ struct ConvertBroadcastToAIEBroadcast rewriter.replaceOpWithNewOp(bcastOp, resultType, src); return success(); - } else if (laneSize * elWidth == 256) { + } + + if (laneSize * elWidth == 256) { VectorType vecType = createVectorType(512 / elWidth, scalarType); auto aieBcastOp = rewriter.create( bcastOp.getLoc(), vecType, src); rewriter.replaceOpWithNewOp(bcastOp, resultType, aieBcastOp.getResult(), 0); return success(); - } else if (laneSize * elWidth == 1024) { + } + + if (laneSize * elWidth == 1024) { VectorType vecType = createVectorType(512 / elWidth, scalarType); auto aieBcastOp = rewriter.create( bcastOp.getLoc(), vecType, src); @@ -527,18 +534,18 @@ struct ConvertBroadcastToAIEBroadcast // This pattern replaces `arith.muli`+`arith.addi` on vectors with // `aievec.mac_elem`. This pattern works for aie-ml. struct ConvertMulAddToAIEVecFMAElemOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam = 0) - : OpConversionPattern(context), shiftParam(shiftParam) {} + : OpConversionPattern(context), shiftParam(shiftParam) {} LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Verify it's a vector operation - VectorType resultType = dyn_cast(addOp.getType()); + auto resultType = dyn_cast(addOp.getType()); if (!resultType) return failure(); @@ -580,18 +587,18 @@ struct ConvertMulAddToAIEVecFMAElemOpPattern // This pattern replaces `arith.mulf` on vectors with // `aievec.mul_elem`. This pattern works for aie-ml. struct ConvertMulFToAIEVecMulElemOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam = 0) - : OpConversionPattern(context), shiftParam(shiftParam) {} + : OpConversionPattern(context), shiftParam(shiftParam) {} LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Verify it's a vector operation - VectorType resultType = dyn_cast(mulOp.getType()); + auto resultType = dyn_cast(mulOp.getType()); if (!resultType) return failure(); @@ -618,8 +625,8 @@ struct ConvertMulFToAIEVecMulElemOpPattern if (auto rvalExtOp = rval.getDefiningOp()) { rval = rvalExtOp->getOperand(0); } - VectorType lSrcType = cast(lval.getType()); - VectorType rSrcType = cast(rval.getType()); + auto lSrcType = cast(lval.getType()); + auto rSrcType = cast(rval.getType()); unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth(); unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth(); Type accType = getVectorOpDestType(lSrcType, /*AIEML =*/true); @@ -638,10 +645,9 @@ struct ConvertMulFToAIEVecMulElemOpPattern // Prepare lhr/rhs for the aievec.mul_elem op VectorType targetInputType = createVectorType(512 / lBitWidth, lSrcType.getElementType()); - if (rBitWidth > lBitWidth) { + if (rBitWidth > lBitWidth) targetInputType = createVectorType(512 / rBitWidth, rSrcType.getElementType()); - } auto lValConverted = convertValueToTargetTypeAieML(rewriter, mulOp.getLoc(), lval, targetInputType); auto rValConverted = convertValueToTargetTypeAieML(rewriter, mulOp.getLoc(), @@ -650,7 +656,7 @@ struct ConvertMulFToAIEVecMulElemOpPattern return failure(); // Create an aievec.mul_elem op - aievec::MulElemOp mulElemOp = rewriter.create( + auto mulElemOp = rewriter.create( mulOp.getLoc(), accType, *lValConverted, *rValConverted); // Create an aievec.cast or an aievec.srs op @@ -679,18 +685,18 @@ struct ConvertMulFToAIEVecMulElemOpPattern // This pattern replaces `arith.muli` on vectors with // `aievec.mul_elem`. This pattern works for aie-ml. struct ConvertMulIToAIEVecMulElemOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam = 0) - : OpConversionPattern(context), shiftParam(shiftParam) {} + : OpConversionPattern(context), shiftParam(shiftParam) {} LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Verify it's a vector operation - VectorType resultType = dyn_cast(mulOp.getType()); + auto resultType = dyn_cast(mulOp.getType()); if (!resultType) return failure(); @@ -717,8 +723,8 @@ struct ConvertMulIToAIEVecMulElemOpPattern if (auto rvalExtOp = rval.getDefiningOp()) { rval = rvalExtOp->getOperand(0); } - VectorType lSrcType = cast(lval.getType()); - VectorType rSrcType = cast(rval.getType()); + auto lSrcType = cast(lval.getType()); + auto rSrcType = cast(rval.getType()); unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth(); unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth(); Type accType = getVectorOpDestType(lSrcType, /*AIEML =*/true); @@ -741,7 +747,7 @@ struct ConvertMulIToAIEVecMulElemOpPattern return failure(); // Create an aievec.mul_elem op - aievec::MulElemOp mulElemOp = rewriter.create( + auto mulElemOp = rewriter.create( mulOp.getLoc(), accType, *lValConverted, *rValConverted); // Create an aievec.cast or an aievec.srs op @@ -769,8 +775,8 @@ struct ConvertMulIToAIEVecMulElemOpPattern // This pattern folds an extract + broadcast feeding into an `aievec::FMAOp` // into the op, using the shuffle attributes. -struct FoldBroadcastToFMAOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct FoldBroadcastToFMAOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(aievec::FMAOp fmaOp, OpAdaptor adaptor, @@ -816,14 +822,13 @@ struct FoldBroadcastToFMAOp : public OpConversionPattern { } }; -struct ConvertMulAddToAIEVecFMAOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertMulAddToAIEVecFMAOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(aievec::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType vecType = cast(addOp.getType()); + auto vecType = cast(addOp.getType()); auto res = extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs()); @@ -860,15 +865,15 @@ struct ConvertMulAddToAIEVecFMAOpPattern // it performs a naïve direct translation. This needs to be expanded to // support more complex scenarios. struct LowerVectorTransferReadToAIEUPD - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize) - : OpConversionPattern(context), - minVectorSize(minVectorSize), maxVectorSize(maxVectorSize), - vectorAlignment(alignment), maxLoadSize(maxLoadSize) {} + : OpConversionPattern(context), minVectorSize(minVectorSize), + maxVectorSize(maxVectorSize), vectorAlignment(alignment), + maxLoadSize(maxLoadSize) {} LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, @@ -928,7 +933,7 @@ struct LowerVectorTransferReadToAIEUPD // XXX: Notice that this template doesn't verify that the vector element type // XXX: is supported by the target architecture. template -struct OneToOneVectorOpToAIEVecOpPattern : public OpConversionPattern { +struct OneToOneVectorOpToAIEVecOpPattern : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -943,9 +948,8 @@ struct OneToOneVectorOpToAIEVecOpPattern : public OpConversionPattern { } }; -struct LowerVectorAddIOpToAIEVecAddOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct LowerVectorAddIOpToAIEVecAddOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, @@ -979,9 +983,8 @@ using LowerVectorSubIOpToAIEVecSubOp = using LowerVectorSubFOpToAIEVecSubOp = OneToOneVectorOpToAIEVecOpPattern; -struct LowerVectorMulIOpToAIEVecMulOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct LowerVectorMulIOpToAIEVecMulOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -1001,7 +1004,7 @@ struct LowerVectorMulIOpToAIEVecMulOp template struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp - : public OpConversionPattern { + : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -1048,10 +1051,9 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp rewriter.replaceOpWithNewOp(srcOp, srcOp.getType(), lhs, rhs); return success(); - } else { - return genAddElemAieML(rewriter, lhs, rhs, - resultType, srcOp); } + return genAddElemAieML(rewriter, lhs, rhs, resultType, + srcOp); } // If element width is 32, we need to consider sign extension cases @@ -1064,10 +1066,9 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp rewriter.replaceOpWithNewOp(srcOp, srcOp.getType(), lhs, rhs); return success(); - } else { - return genAddElemAieML(rewriter, lhs, rhs, - resultType, srcOp); } + return genAddElemAieML(rewriter, lhs, rhs, + resultType, srcOp); } if (lhsExt && rhsExt) { @@ -1125,7 +1126,9 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp rewriter.replaceOpWithNewOp( srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false); return success(); - } else if (bitWidth == 16) { + } + + if (bitWidth == 16) { accType = getVectorOpDestType(resultType, /*AIEML =*/true); auto lUpsOp = rewriter.create(srcOp.getLoc(), accType, lval); @@ -1195,9 +1198,9 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp auto extVal = lhsExt ? lval : rval; VectorType vType = cast(extVal.getType()); Type accType = getVectorOpDestType(vType, /*AIEML =*/true); - aievec::UPSOp upsOp = nullptr; - aievec::CastOp castOp = nullptr; + aievec::UPSOp upsOp; + aievec::CastOp castOp; if (lhsExt) { upsOp = rewriter.create(srcOp.getLoc(), accType, lval); @@ -1211,15 +1214,18 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp lval, /*isResAcc*/ true); } + auto elemOp = rewriter.create( srcOp.getLoc(), upsOp->getResult(0).getType(), upsOp->getResult(0), castOp->getResult(0)); rewriter.replaceOpWithNewOp( srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false); + return success(); } } + // v16bfloat16 Type accType = getVectorOpDestType(resultType, /*AIEML =*/true); auto lUpsOp = @@ -1233,8 +1239,10 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp srcOp.getLoc(), rewriter.getI32IntegerAttr(0)); rewriter.replaceOpWithNewOp( srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult()); + return success(); } + return failure(); } }; @@ -1253,8 +1261,7 @@ using LowerVectorSubFOpToAIEVecSubElemOp = aievec::SubElemOp>; template -struct LowerVectorMinMaxOpToAIEVecMinMaxOp - : public OpConversionPattern { +struct LowerVectorMinMaxOpToAIEVecMinMaxOp : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -1294,7 +1301,7 @@ using LowerVectorMaximumFOpToAIEVecMaxOp = LowerVectorMinMaxOpToAIEVecMinMaxOp; template -struct LowerVectorCmpOpToAIEVecCmpOp : public OpConversionPattern { +struct LowerVectorCmpOpToAIEVecCmpOp : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -1350,14 +1357,13 @@ using LowerVectorCmpIOpToAIEVecCmpOp = using LowerVectorCmpFOpToAIEVecCmpOp = LowerVectorCmpOpToAIEVecCmpOp; -struct LowerVectorSelectOpToAIEVecSelOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct LowerVectorSelectOpToAIEVecSelOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType resultType = dyn_cast(srcOp.getType()); + auto resultType = dyn_cast(srcOp.getType()); if (!resultType) return failure(); @@ -1388,20 +1394,18 @@ struct LowerVectorSelectOpToAIEVecSelOp } }; -struct LowerVectorReductionMinOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct LowerVectorReductionMinOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = srcOp.getKind(); - if (kind != vector::CombiningKind::MINSI && - kind != vector::CombiningKind::MINUI && - kind != vector::CombiningKind::MINF) + if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI && + kind != vector::CombiningKind::MINUI && + kind != vector::CombiningKind::MINF) return failure(); - VectorType vType = cast(srcOp.getVector().getType()); + auto vType = cast(srcOp.getVector().getType()); Type scalarType = vType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(vType); @@ -1416,20 +1420,18 @@ struct LowerVectorReductionMinOp } }; -struct LowerVectorReductionMaxOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct LowerVectorReductionMaxOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = srcOp.getKind(); - if (kind != vector::CombiningKind::MAXSI && - kind != vector::CombiningKind::MAXUI && - kind != vector::CombiningKind::MAXF) + if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI && + kind != vector::CombiningKind::MAXUI && + kind != vector::CombiningKind::MAXF) return failure(); - VectorType vType = cast(srcOp.getVector().getType()); + auto vType = cast(srcOp.getVector().getType()); Type scalarType = vType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(vType); @@ -1444,18 +1446,16 @@ struct LowerVectorReductionMaxOp } }; -struct LowerVectorReductionAddIntOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct LowerVectorReductionAddIntOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = srcOp.getKind(); - if (kind != vector::CombiningKind::ADD) + if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD) return failure(); - VectorType vType = cast(srcOp.getVector().getType()); + auto vType = cast(srcOp.getVector().getType()); Type scalarType = vType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(vType); @@ -1484,26 +1484,25 @@ struct LowerVectorReductionAddIntOp shiftIndex /= 2; generateAIEVecOpsForReductionOp( rewriter, srcOp, shiftIndex, addElemOp.getResult()); - } else { + } else generateAIEVecOpsForReductionOp( rewriter, srcOp, shiftIndex, srcOp.getVector()); - } + return success(); } }; struct LowerVectorReductionAddFloatOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = srcOp.getKind(); - if (kind != vector::CombiningKind::ADD) + if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD) return failure(); - VectorType vType = cast(srcOp.getVector().getType()); + auto vType = cast(srcOp.getVector().getType()); Type scalarType = vType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(vType); @@ -1520,7 +1519,7 @@ struct LowerVectorReductionAddFloatOp aievec::CastOp curOp = nullptr; for (int id = shiftIndex; id > 0; id /= 2) { - arith::ConstantOp constOp = rewriter.create( + auto constOp = rewriter.create( loc, rewriter.getI32IntegerAttr(id * elWidth / 8)); auto shiftBytesOp = rewriter.create( @@ -1536,11 +1535,10 @@ struct LowerVectorReductionAddFloatOp rCastOp.getResult()); curOp = rewriter.create(loc, vType, elemOp.getResult(), /*isResAcc*/ false); - curValue = curOp.getResult(); } - arith::ConstantOp zeroConstOp = + auto zeroConstOp = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); rewriter.replaceOpWithNewOp(srcOp, scalarType, curOp, zeroConstOp.getResult()); @@ -1549,17 +1547,16 @@ struct LowerVectorReductionAddFloatOp }; struct LowerVectorReductionAddBfloat16Op - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = srcOp.getKind(); - if (kind != vector::CombiningKind::ADD) + if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD) return failure(); - VectorType vType = cast(srcOp.getVector().getType()); + auto vType = cast(srcOp.getVector().getType()); Type scalarType = vType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(vType); @@ -1586,7 +1583,7 @@ struct LowerVectorReductionAddBfloat16Op aievec::AddElemOp curOp = nullptr; for (int id = shiftIndex; id > 0; id /= 2) { - arith::ConstantOp constOp = rewriter.create( + auto constOp = rewriter.create( loc, rewriter.getI32IntegerAttr(id * accWidth / 8)); auto shiftBytesOp = rewriter.create( loc, accType, curValue, curValue, constOp, true); @@ -1603,7 +1600,7 @@ struct LowerVectorReductionAddBfloat16Op auto concatOp = rewriter.create(loc, vecType, concatSources); - arith::ConstantOp zeroConstOp = + auto zeroConstOp = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); rewriter.replaceOpWithNewOp(srcOp, scalarType, concatOp, zeroConstOp.getResult()); @@ -1614,8 +1611,8 @@ struct LowerVectorReductionAddBfloat16Op // Convert a `vector.extract_strided_slice` op on 1D vectors into an // `aievec.select` + `aievec.ext` op. struct LowerVectorExtractStridedSliceOpAIEv1Pattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, @@ -1646,7 +1643,6 @@ struct LowerVectorExtractStridedSliceOpAIEv1Pattern rewriter.replaceOpWithNewOp(extractOp, extractOp.getType(), selectOp.getResult(), rewriter.getI8IntegerAttr(0)); - return success(); } }; @@ -1654,8 +1650,8 @@ struct LowerVectorExtractStridedSliceOpAIEv1Pattern // Convert a `vector.extract_strided_slice` op on 1D vectors into an // `aievec.shift` op. struct LowerVectorExtractStridedSliceOpAIEMLPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, @@ -1699,11 +1695,11 @@ struct LowerVectorExtractStridedSliceOpAIEMLPattern // Replaces a short UPD op with a wide one followed by an ext op of the bottom // half. -struct ExpandUPDToUPDAndExtPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ExpandUPDToUPDAndExtPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; ExpandUPDToUPDAndExtPattern(MLIRContext *context) - : OpConversionPattern(context) {} + : OpConversionPattern(context) {} LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, @@ -1729,11 +1725,10 @@ struct ExpandUPDToUPDAndExtPattern : public OpConversionPattern { // Replaces a wide UPD op followed by an ext op of the bottom half with a short // UPD op. -struct FuseExtIntoUPDPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct FuseExtIntoUPDPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - FuseExtIntoUPDPattern(MLIRContext *context) - : OpConversionPattern(context) {} + FuseExtIntoUPDPattern(MLIRContext *context) : OpConversionPattern(context) {} LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, @@ -1759,17 +1754,16 @@ struct FuseExtIntoUPDPattern : public OpConversionPattern { }; // Lower ExpOp to function call -struct ComputeExpOpByLUTPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeExpOpByLUTPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(adaptor.getOperand().getType()); + auto srcType = dyn_cast(adaptor.getOperand().getType()); - if (!srcType) { + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); @@ -1778,7 +1772,7 @@ struct ComputeExpOpByLUTPattern : public OpConversionPattern { return failure(); StringRef includeName = "lut_based_ops.h"; - ModuleOp moduleOp = expOp->getParentOfType(); + auto moduleOp = expOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); @@ -1787,7 +1781,7 @@ struct ComputeExpOpByLUTPattern : public OpConversionPattern { rewriter.setInsertionPoint(expOp); Type accType = getVectorOpDestType(srcType, /*AIEML =*/true); - auto funcOp = rewriter.create( + auto funcOp = rewriter.create( expOp.getLoc(), TypeRange{accType}, "getExpBf16", nullptr, nullptr, expOperands); auto shiftParamOp = rewriter.create( @@ -1806,212 +1800,192 @@ struct ComputeExpOpByLUTPattern : public OpConversionPattern { // %1 = arith.truncf %0 : f32 to bf16 // to - // %0 = emitc.call "getInvBf16"(%0) : f32 -> bf16; -struct ComputeInvOpByLUTPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeInvOpByLUTPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = adaptor.getLhs().getType(); if (!divOp->hasOneUse() || isa(srcType) || - !isa(srcType)) { + !isa(srcType)) return failure(); - } - if (!isa(*divOp->getUsers().begin())) { + if (!isa(*divOp->getUsers().begin())) return failure(); - } - FloatType fType = cast(srcType); - - if (fType.getWidth() != 32) { + auto fType = cast(srcType); + if (fType.getWidth() != 32) return failure(); - } auto constOp = dyn_cast(divOp.getLhs().getDefiningOp()); if (!constOp || constOp.getValue().cast().getValue().convertToDouble() != - 1.0f) { + 1.0f) return failure(); - } StringRef includeName = "lut_based_ops.h"; - ModuleOp moduleOp = divOp->getParentOfType(); + auto moduleOp = divOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); SmallVector invOperands = {adaptor.getRhs()}; - arith::TruncFOp truncOp = cast(*divOp->getUsers().begin()); + auto truncOp = cast(*divOp->getUsers().begin()); rewriter.setInsertionPoint(truncOp); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( truncOp, TypeRange{truncOp.getResult().getType()}, "getInvBf16", nullptr, nullptr, invOperands); rewriter.eraseOp(divOp); + return success(); } }; // Convert math.tanh to a function call to compute tanh(x) by look up tables -struct ComputeTanhOpByLUTPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeTanhOpByLUTPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(tanhOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(tanhOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || laneSize != 16) { + if (elWidth != 16 || laneSize != 16) return failure(); - } StringRef includeName = "lut_based_ops.h"; - ModuleOp moduleOp = tanhOp->getParentOfType(); + auto moduleOp = tanhOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(tanhOp); SmallVector tanhOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( tanhOp, TypeRange{tanhOp.getResult().getType()}, "getTanhBf16", nullptr, nullptr, tanhOperands); + return success(); } }; // Convert math.sqrt to a function call to compute sqrt(x) for v16bfloat16 and // v32bfloat16 types -struct ComputeSqrtOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeSqrtOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(sqrtOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(sqrtOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return failure(); - } StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = sqrtOp->getParentOfType(); + auto moduleOp = sqrtOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(sqrtOp); SmallVector sqrtOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( sqrtOp, TypeRange{sqrtOp.getResult().getType()}, "getSqrtBf16", nullptr, nullptr, sqrtOperands); + return success(); } }; // Convert math.rsqrt to a function call to compute 1.0f / sqrt(x) for // v16bfloat16 and v32bfloat16 types -struct ComputeRsqrtOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeRsqrtOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(rsqrtOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(rsqrtOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return failure(); - } StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = rsqrtOp->getParentOfType(); + auto moduleOp = rsqrtOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(rsqrtOp); SmallVector rsqrtOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, "getRsqrtBf16", nullptr, nullptr, rsqrtOperands); + return success(); } }; // Convert math.erf to a function call to compute erf(x) for v16bfloat16 and // v32bfloat16 types -struct ComputeErfOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeErfOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(erfOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(erfOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return failure(); - } StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = erfOp->getParentOfType(); + auto moduleOp = erfOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(erfOp); SmallVector erfOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( erfOp, TypeRange{erfOp.getResult().getType()}, "getErfBf16", nullptr, nullptr, erfOperands); + return success(); } }; @@ -2019,7 +1993,7 @@ struct ComputeErfOpPattern : public OpConversionPattern { // Convert math.absf and math.absi to a function call to compute abs(x) for // v16bfloat16, v32bfloat16, v16float, v16int32, v32int16 and v64int8 types template -struct ComputeAbsOpPattern : public OpConversionPattern { +struct ComputeAbsOpPattern : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -2027,16 +2001,17 @@ struct ComputeAbsOpPattern : public OpConversionPattern { matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = absOp->template getParentOfType(); + auto moduleOp = absOp->template getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(absOp); SmallVector absOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( absOp, TypeRange{absOp.getResult().getType()}, "getAbs", nullptr, nullptr, absOperands); + return success(); } }; @@ -2045,7 +2020,7 @@ using ComputeAbsFOpPattern = ComputeAbsOpPattern; using ComputeAbsIOpPattern = ComputeAbsOpPattern; template -struct LowerExtOpPattern : public OpConversionPattern { +struct LowerExtOpPattern : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -2064,10 +2039,10 @@ struct LowerExtOpPattern : public OpConversionPattern { extOp.getLoc(), rewriter.getI32IntegerAttr(0)); rewriter.replaceOpWithNewOp( extOp, dstType, upsOp.getResult(), shiftParamOp.getResult()); - } else { + } else rewriter.replaceOpWithNewOp( extOp, dstType, upsOp.getResult(), /*isResAcc*/ false); - } + return success(); } }; @@ -2076,7 +2051,7 @@ using LowerExtFOpPattern = LowerExtOpPattern; using LowerExtSIOpPattern = LowerExtOpPattern; template -struct LowerTruncOpPattern : public OpConversionPattern { +struct LowerTruncOpPattern : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -2106,6 +2081,7 @@ struct LowerTruncOpPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp( truncOp, dstType, castOp.getResult(), shiftParamOp.getResult()); } + return success(); } }; @@ -2129,21 +2105,18 @@ using LowerTruncIOpPattern = LowerTruncOpPattern; template static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) { auto constOp = dyn_cast(divfOp.getLhs().getDefiningOp()); - if (!constOp) { + if (!constOp) return false; - } auto cstDense = dyn_cast(constOp.getValue()); - if (!cstDense) { + if (!cstDense) return false; - } - if (cstDense.template getSplatValue().convertToFloat() != 1.0f) { + if (cstDense.template getSplatValue().convertToFloat() != 1.0f) return false; - } - Operation *addLvalOp = nullptr; - Operation *addRvalOp = nullptr; + Operation *addLvalOp; + Operation *addRvalOp; // divfOp's rval could be an arith::AddFOp or the pattern like- // %1 = aievec.ups %a // %2 = aievec.ups %b; @@ -2152,19 +2125,19 @@ static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) { auto addOp = dyn_cast(divfOp.getRhs().getDefiningOp()); if (!addOp) { auto srsOp = dyn_cast(divfOp.getRhs().getDefiningOp()); - if (!srsOp) { + if (!srsOp) return false; - } + auto addElemOp = dyn_cast(srsOp.getSource().getDefiningOp()); - if (!addElemOp) { + if (!addElemOp) return false; - } + auto lUpsOp = dyn_cast(addElemOp.getLhs().getDefiningOp()); auto rUpsOp = dyn_cast(addElemOp.getRhs().getDefiningOp()); - if (!lUpsOp || !rUpsOp) { + if (!lUpsOp || !rUpsOp) return false; - } + addLvalOp = lUpsOp.getSource().getDefiningOp(); addRvalOp = rUpsOp.getSource().getDefiningOp(); // One of add operation's operand is a constant op and another operand could @@ -2172,13 +2145,13 @@ static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) { auto addDefOp = isa(addLvalOp) ? dyn_cast(addRvalOp) : dyn_cast(addLvalOp); - if (!addDefOp) { + if (!addDefOp) addLvalOp = isa(addLvalOp) ? dyn_cast(addRvalOp) : dyn_cast(addLvalOp); - } else { + else addLvalOp = addDefOp.getSource().getDefiningOp(); - } + addRvalOp = isa(addLvalOp) ? lUpsOp.getSource().getDefiningOp() : rUpsOp.getSource().getDefiningOp(); @@ -2187,50 +2160,45 @@ static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) { addRvalOp = addOp.getRhs().getDefiningOp(); } - if (!addLvalOp || !addRvalOp) { + if (!addLvalOp || !addRvalOp) return false; - } if (!((isa(addLvalOp) && isa(addRvalOp)) || (isa(addRvalOp) && isa(addLvalOp)) || - (isa(addLvalOp) && - cast(addLvalOp).getCallee() == "getExpBf16" && + (isa(addLvalOp) && + cast(addLvalOp).getCallee() == "getExpBf16" && isa(addRvalOp)) || - (isa(addRvalOp) && - cast(addRvalOp).getCallee() == "getExpBf16" && - isa(addLvalOp)))) { + (isa(addRvalOp) && + cast(addRvalOp).getCallee() == "getExpBf16" && + isa(addLvalOp)))) return false; - } constOp = isa(addLvalOp) ? cast(addLvalOp) : cast(addRvalOp); cstDense = dyn_cast(constOp.getValue()); - if (!cstDense) { + if (!cstDense) return false; - } - - if (cstDense.template getSplatValue().convertToFloat() != 1.0f) { + if (cstDense.template getSplatValue().convertToFloat() != 1.0f) return false; - } auto expOp = isa(addLvalOp) ? cast(addLvalOp) - : (isa(addLvalOp) - ? cast(addLvalOp) + : (isa(addLvalOp) + ? cast(addLvalOp) : (isa(addRvalOp) ? cast(addRvalOp) - : cast(addRvalOp))); + : cast(addRvalOp))); - auto expOperand = isa(expOp) - ? cast(expOp).getOperand() - : *(cast(expOp).getOperands().begin()); + auto expOperand = + isa(expOp) + ? cast(expOp).getOperand() + : *(cast(expOp).getOperands().begin()); negOp = dyn_cast(expOperand.getDefiningOp()); - if (!negOp) { + if (!negOp) return false; - } return true; } @@ -2248,44 +2216,38 @@ static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) { // // to a function call to compute sigmoid value for v16bfloat16 and // v32bfloat16 types -struct ComputeSigmoidOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeSigmoidOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(adaptor.getLhs().getType()); - - if (!srcType) { + auto srcType = dyn_cast(adaptor.getLhs().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return failure(); - } arith::NegFOp negOp = nullptr; - if (!hasSigmoidComputationChain(adaptor, negOp)) { + if (!hasSigmoidComputationChain(adaptor, negOp)) return failure(); - } StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = divfOp->getParentOfType(); + auto moduleOp = divfOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(divfOp); SmallVector sigmoidOperands = {negOp.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( divfOp, TypeRange{adaptor.getLhs().getType()}, "getSigmoidBf16", nullptr, nullptr, sigmoidOperands); @@ -2294,106 +2256,95 @@ struct ComputeSigmoidOpPattern : public OpConversionPattern { }; // Convert math.ceil to a function call to compute ceil(x) for v16bfloat16 -struct ComputeCeilOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeCeilOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(ceilOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(ceilOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return failure(); - } StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = ceilOp->getParentOfType(); + auto moduleOp = ceilOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(ceilOp); SmallVector ceilOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( ceilOp, TypeRange{ceilOp.getResult().getType()}, "getCeilBf16", nullptr, nullptr, ceilOperands); + return success(); } }; // Convert math.floor to a function call to compute floor(x) for v16bfloat16 -struct ComputeFloorOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeFloorOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(floorOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(floorOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return failure(); - } StringRef includeName = "vec_math.h"; - ModuleOp moduleOp = floorOp->getParentOfType(); + auto moduleOp = floorOp->getParentOfType(); rewriter.setInsertionPointToStart( &moduleOp.getRegion().getBlocks().front()); rewriter.create(moduleOp.getLoc(), includeName, false); rewriter.setInsertionPoint(floorOp); SmallVector floorOperands = {adaptor.getOperand()}; - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( floorOp, TypeRange{floorOp.getResult().getType()}, "getFloorBf16", nullptr, nullptr, floorOperands); + return success(); } }; // Convert arith.negf to aievec.neg to negate the vector for v16bfloat16 and // v16float types. -struct ComputeNegOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeNegOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(negOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(negOp.getOperand().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } - unsigned laneSize = getVectorLaneSize(srcType); - if (laneSize != 16) { + if (unsigned laneSize = getVectorLaneSize(srcType); laneSize != 16) return failure(); - } Location loc = negOp.getLoc(); auto accType = getVectorOpDestType(srcType, /*AIEML =*/true); @@ -2402,10 +2353,8 @@ struct ComputeNegOpPattern : public OpConversionPattern { if (elWidth == 16) { auto upsOp = rewriter.create(loc, accType, adaptor.getOperand()); - auto aieNegOp = rewriter.create(loc, accType, upsOp.getResult()); - auto shiftParamOp = rewriter.create( negOp.getLoc(), rewriter.getI32IntegerAttr(0)); rewriter.replaceOpWithNewOp( @@ -2418,6 +2367,7 @@ struct ComputeNegOpPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp( negOp, srcType, aieNegOp.getResult(), /*isResAcc*/ false); } + return success(); } }; @@ -2425,48 +2375,42 @@ struct ComputeNegOpPattern : public OpConversionPattern { // Check whether the value of constant operation is int type and the dense value // is -1. static bool hasConstNegOneValue(arith::ConstantOp constOp, unsigned elWidth) { - if (!constOp) { + if (!constOp) return false; - } + auto cstDense = dyn_cast(constOp.getValue()); - if (!cstDense) { + if (!cstDense) return false; - } - if (elWidth == 32) { + if (elWidth == 32) return cstDense.getSplatValue() == -1; - } else if (elWidth == 16) { + if (elWidth == 16) return cstDense.getSplatValue() == -1; - } else if (elWidth == 8) { + if (elWidth == 8) return cstDense.getSplatValue() == -1; - } return false; } // Convert arith.xori to aievec.bxor to compute bitwise xor of two vectors for // integer types -struct ComputeBxorAndBnegOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ComputeBxorAndBnegOpPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(xorOp.getLhs().getType()); - if (!srcType) { + auto srcType = dyn_cast(xorOp.getLhs().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (laneSize * elWidth != 512) { + if (laneSize * elWidth != 512) return failure(); - } auto lhsConstOp = dyn_cast(xorOp.getLhs().getDefiningOp()); @@ -2480,16 +2424,16 @@ struct ComputeBxorAndBnegOpPattern : public OpConversionPattern { Value val = hasConstNegOneValue(lhsConstOp, elWidth) ? adaptor.getRhs() : adaptor.getLhs(); rewriter.replaceOpWithNewOp(xorOp, srcType, val); - } else { + } else rewriter.replaceOpWithNewOp( xorOp, srcType, adaptor.getLhs(), adaptor.getRhs()); - } + return success(); } }; template -struct ComputeBandAndBorOpPattern : public OpConversionPattern { +struct ComputeBandAndBorOpPattern : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename SrcOpTy::Adaptor; @@ -2497,18 +2441,17 @@ struct ComputeBandAndBorOpPattern : public OpConversionPattern { matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType srcType = dyn_cast(srcOp.getLhs().getType()); - if (!srcType) { + if (!srcType) return failure(); - } + Type scalarType = srcType.getElementType(); - if (!isa(scalarType)) { + if (!isa(scalarType)) return failure(); - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (laneSize * elWidth != 512) { + if (laneSize * elWidth != 512) return failure(); - } rewriter.replaceOpWithNewOp(srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs()); @@ -2526,86 +2469,70 @@ using ComputeBandOpPattern = // arithmetic right shift for integer types. Currently, only support the shift // value with a broadcast vector. struct ComputeSignedIntRightShiftOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType srcType = dyn_cast(adaptor.getLhs().getType()); - if (!srcType) { + auto srcType = dyn_cast(adaptor.getLhs().getType()); + if (!srcType) return failure(); - } Type scalarType = srcType.getElementType(); - unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - - if (laneSize * elWidth != 512) { + if (laneSize * elWidth != 512) return failure(); - } auto bcastOp = dyn_cast(adaptor.getRhs().getDefiningOp()); - - if (!bcastOp) { + if (!bcastOp) return failure(); - } - arith::ConstantOp constOp = rewriter.create( + auto constOp = rewriter.create( bcastOp.getLoc(), rewriter.getI32IntegerAttr(bcastOp.getIdx())); - auto extElemOp = rewriter.create( bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult()); - Location loc = rsOp.getLoc(); // The vector with v64int8 type can be divided into two v32int8 vectors and // be processed individually and be concatenated at the end. if (elWidth == 8) { VectorType halfSrcType = createVectorType(laneSize / 2, scalarType); - auto rsOpLow = rewriter.create(loc, halfSrcType, adaptor.getLhs(), 0); auto rsOpHigh = rewriter.create(loc, halfSrcType, adaptor.getLhs(), 1); - Type accType = getVectorOpDestType(halfSrcType, /*AIEML =*/true); - auto upsOpLow = rewriter.create(loc, accType, rsOpLow.getResult()); - auto srsOpLow = rewriter.create( loc, halfSrcType, upsOpLow.getResult(), extElemOp.getResult()); - auto upsOpHigh = rewriter.create(loc, accType, rsOpHigh.getResult()); - auto srsOpHigh = rewriter.create( loc, halfSrcType, upsOpHigh.getResult(), extElemOp.getResult()); - SmallVector inputSources = {srsOpLow.getResult(), srsOpHigh.getResult()}; rewriter.replaceOpWithNewOp(rsOp, srcType, inputSources); } else { Type accType = getVectorOpDestType(srcType, /*AIEML =*/true); - auto upsOp = rewriter.create(loc, accType, adaptor.getLhs()); - rewriter.replaceOpWithNewOp( rsOp, srcType, upsOp.getResult(), extElemOp.getResult()); } + return success(); } }; // Convert a `vector.contract` op to an `aievec.matmul` op for AIEml struct LowerVectorContractionOpToAIEVecMatMulPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, @@ -2626,20 +2553,20 @@ struct LowerVectorContractionOpToAIEVecMatMulPattern // There is a possibility that, when the linalg op is converted to // contractions, lower precisions operands are cast to the target // precission outside the contraction. For those cases, we check. - if (auto lhsExtSIOp = lhs.getDefiningOp()) { + if (auto lhsExtSIOp = lhs.getDefiningOp()) lhs = lhsExtSIOp.getIn(); - } else if (auto lhsExtUIOp = lhs.getDefiningOp()) { + else if (auto lhsExtUIOp = lhs.getDefiningOp()) lhs = lhsExtUIOp.getIn(); - } else if (auto lhsExtFOp = lhs.getDefiningOp()) { + else if (auto lhsExtFOp = lhs.getDefiningOp()) lhs = lhsExtFOp.getIn(); - } - if (auto rhsExtSIOp = rhs.getDefiningOp()) { + + if (auto rhsExtSIOp = rhs.getDefiningOp()) rhs = rhsExtSIOp.getIn(); - } else if (auto rhsExtUIOp = rhs.getDefiningOp()) { + else if (auto rhsExtUIOp = rhs.getDefiningOp()) rhs = rhsExtUIOp.getIn(); - } else if (auto rhsExtFOp = rhs.getDefiningOp()) { + else if (auto rhsExtFOp = rhs.getDefiningOp()) rhs = rhsExtFOp.getIn(); - } + matmulOp = rewriter.create( contractOp.getLoc(), contractOp.getResult().getType(), lhs, rhs, acc); @@ -2658,8 +2585,7 @@ struct LowerVectorContractionOpToAIEVecMatMulPattern // Pattern collection //===----------------------------------------------------------------------===// -static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns, - AnalysisManager &am) { +static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext(), 128, 512, 128, 256); // clang-format off @@ -2675,8 +2601,7 @@ static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns, // clang-format on } -static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns, - AnalysisManager &am) { +static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext(), 128, 1024, 256, 1024); // clang-format off @@ -2731,325 +2656,263 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns, // TODO: Review the validity of these legalizations beyond basic cases. static bool isInSigmoidOperationChain(math::ExpOp expOp) { - auto negOp = dyn_cast(expOp.getOperand().getDefiningOp()); - if (!negOp) { + if (auto negOp = dyn_cast(expOp.getOperand().getDefiningOp()); + !negOp) return false; - } arith::AddFOp addOp = nullptr; for (Operation *user : expOp->getUsers()) { addOp = dyn_cast(user); - if (addOp) { + if (addOp) break; - } } - if (!addOp) { + if (!addOp) return false; - } auto addLvalOp = addOp.getLhs().getDefiningOp(); auto addRvalOp = addOp.getRhs().getDefiningOp(); - if (!((isa(addLvalOp) && isa(addRvalOp)) || - (isa(addRvalOp) && isa(addLvalOp)))) { + (isa(addRvalOp) && isa(addLvalOp)))) return false; - } auto constOp = isa(addLvalOp) ? cast(addLvalOp) : cast(addRvalOp); auto cstDense = dyn_cast(constOp.getValue()); - if (!cstDense) { + if (!cstDense) return false; - } - if (cstDense.getSplatValue().convertToFloat() != 1.0f) { + if (cstDense.getSplatValue().convertToFloat() != 1.0f) return false; - } arith::DivFOp divOp = nullptr; for (Operation *user : addOp->getUsers()) { divOp = dyn_cast(user); - if (divOp) { + if (divOp) break; - } } - if (!divOp) { + if (!divOp) return false; - } constOp = dyn_cast(divOp.getLhs().getDefiningOp()); - if (!constOp) { + if (!constOp) return false; - } - cstDense = dyn_cast(constOp.getValue()); - if (!cstDense) { + if (!cstDense) return false; - } - - if (cstDense.getSplatValue().convertToFloat() != 1.0f) { + if (cstDense.getSplatValue().convertToFloat() != 1.0f) return false; - } + return true; } -static void configureAIEVecCommonLegalizations(ConversionTarget &target, - AnalysisManager &am) { +static void configureAIEVecCommonLegalizations(ConversionTarget &target) { target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](math::ExpOp expOp) { - VectorType srcType = dyn_cast(expOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(expOp.getOperand().getType()); + if (!srcType) return true; - } + Type scalarType = srcType.getElementType(); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(srcType); if (!isa(scalarType) || laneSize != 16 || elWidth != 16) return true; - - if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp)) { + if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp)) return true; - } return false; }); target.addDynamicallyLegalOp([](math::TanhOp tanhOp) { - VectorType srcType = dyn_cast(tanhOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(tanhOp.getOperand().getType()); + if (!srcType) return true; - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || laneSize != 16) { + if (elWidth != 16 || laneSize != 16) return true; - } return false; }); target.addDynamicallyLegalOp([](math::SqrtOp sqrtOp) { - VectorType srcType = dyn_cast(sqrtOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(sqrtOp.getOperand().getType()); + if (!srcType) return true; - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return true; - } return false; }); target.addDynamicallyLegalOp([](math::RsqrtOp rsqrtOp) { - VectorType srcType = dyn_cast(rsqrtOp.getOperand().getType()); + auto srcType = dyn_cast(rsqrtOp.getOperand().getType()); Type scalarType = srcType.getElementType(); - if (!srcType || !isa(scalarType)) { + if (!srcType || !isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return true; - } return false; }); target.addDynamicallyLegalOp([](math::ErfOp erfOp) { - VectorType srcType = dyn_cast(erfOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(erfOp.getOperand().getType()); + if (!srcType) return true; - } Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return true; - } return false; }); target.addDynamicallyLegalOp([](math::AbsFOp absfOp) { - VectorType srcType = dyn_cast(absfOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(absfOp.getOperand().getType()); + if (!srcType) return true; - } Type scalarType = srcType.getElementType(); unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth * laneSize != 512 && elWidth * laneSize != 256) { + if (elWidth * laneSize != 512 && elWidth * laneSize != 256) return true; - } return false; }); target.addDynamicallyLegalOp([](math::AbsIOp absiOp) { - VectorType srcType = dyn_cast(absiOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(absiOp.getOperand().getType()); + if (!srcType) return true; - } Type scalarType = srcType.getElementType(); unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth * laneSize != 512 && elWidth * laneSize != 256) { + if (elWidth * laneSize != 512 && elWidth * laneSize != 256) return true; - } return false; }); target.addDynamicallyLegalOp([](arith::DivFOp divfOp) { - VectorType srcType = dyn_cast(divfOp.getLhs().getType()); - - if (!srcType) { + if (auto srcType = dyn_cast(divfOp.getLhs().getType()); + !srcType) { Type scalarType = divfOp.getLhs().getType(); - if (!divfOp->hasOneUse() || !isa(scalarType)) { + if (!divfOp->hasOneUse() || !isa(scalarType)) return true; - } - - if (!isa(*divfOp->getUsers().begin())) { + if (!isa(*divfOp->getUsers().begin())) return true; - } - FloatType fType = cast(scalarType); - - if (fType.getWidth() != 32) { + auto fType = cast(scalarType); + if (fType.getWidth() != 32) return true; - } auto constOp = dyn_cast(divfOp.getLhs().getDefiningOp()); if (!constOp || constOp.getValue().cast().getValue().convertToDouble() != - 1.0f) { + 1.0f) return true; - } } else { - Type scalarType = srcType.getElementType(); - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return true; - } arith::NegFOp negOp = nullptr; - if (!hasSigmoidComputationChain(divfOp, negOp)) { + if (!hasSigmoidComputationChain(divfOp, negOp)) return true; - } } + return false; }); target.addDynamicallyLegalOp([](math::CeilOp ceilOp) { - VectorType srcType = dyn_cast(ceilOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(ceilOp.getOperand().getType()); + if (!srcType) return true; - } - Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return true; - } return false; }); target.addDynamicallyLegalOp([](math::FloorOp floorOp) { - VectorType srcType = dyn_cast(floorOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(floorOp.getOperand().getType()); + if (!srcType) return true; - } - Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); - if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) { + if (elWidth != 16 || (laneSize != 16 && laneSize != 32)) return true; - } return false; }); target.addDynamicallyLegalOp([](arith::NegFOp negOp) { - VectorType srcType = dyn_cast(negOp.getOperand().getType()); - if (!srcType) { + auto srcType = dyn_cast(negOp.getOperand().getType()); + if (!srcType) return true; - } - - Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (Type scalarType = srcType.getElementType(); !isa(scalarType)) return true; - } unsigned laneSize = getVectorLaneSize(srcType); return laneSize != 16; }); target.addDynamicallyLegalOp([](arith::XOrIOp xorOp) { - VectorType srcType = dyn_cast(xorOp.getLhs().getType()); - if (!srcType) { + auto srcType = dyn_cast(xorOp.getLhs().getType()); + if (!srcType) return true; - } - Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); @@ -3057,16 +2920,13 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target, }); target.addDynamicallyLegalOp([](arith::OrIOp orOp) { - VectorType srcType = dyn_cast(orOp.getLhs().getType()); - if (!srcType) { + auto srcType = dyn_cast(orOp.getLhs().getType()); + if (!srcType) return true; - } - Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); @@ -3074,11 +2934,9 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target, }); target.addDynamicallyLegalOp([](arith::ShRSIOp rsOp) { - VectorType srcType = dyn_cast(rsOp.getLhs().getType()); - if (!srcType) { + auto srcType = dyn_cast(rsOp.getLhs().getType()); + if (!srcType) return true; - } - Type scalarType = srcType.getElementType(); unsigned laneSize = getVectorLaneSize(srcType); @@ -3088,16 +2946,13 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target, }); target.addDynamicallyLegalOp([](arith::AndIOp andOp) { - VectorType srcType = dyn_cast(andOp.getLhs().getType()); - if (!srcType) { + auto srcType = dyn_cast(andOp.getLhs().getType()); + if (!srcType) return true; - } - Type scalarType = srcType.getElementType(); - - if (!isa(scalarType)) { + if (!isa(scalarType)) return true; - } + unsigned laneSize = getVectorLaneSize(srcType); unsigned elWidth = scalarType.getIntOrFloatBitWidth(); @@ -3114,8 +2969,7 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target, [](arith::SubFOp op) { return !isa(op.getType()); }); } -static void configureAIEVecV1Legalizations(ConversionTarget &target, - AnalysisManager &am) { +static void configureAIEVecV1Legalizations(ConversionTarget &target) { target.addDynamicallyLegalOp( [](arith::MulIOp op) { return !isa(op.getType()); }); target.addDynamicallyLegalOp( @@ -3127,9 +2981,9 @@ static void configureAIEVecV1Legalizations(ConversionTarget &target, concatOp = dyn_cast(op.getLhs().getDefiningOp()); if (!concatOp) return true; + vector::BroadcastOp srcBcast = nullptr; - auto lhsOp = concatOp.getSources()[0].getDefiningOp(); - if (lhsOp) + if (auto lhsOp = concatOp.getSources()[0].getDefiningOp()) srcBcast = dyn_cast(lhsOp); if (!srcBcast) { auto rhsOp = op.getRhs().getDefiningOp(); @@ -3137,13 +2991,14 @@ static void configureAIEVecV1Legalizations(ConversionTarget &target, return true; srcBcast = dyn_cast(rhsOp); } - if (srcBcast) { - auto srcOp = srcBcast.getSource().getDefiningOp(); - if (srcOp) + + if (srcBcast) + if (auto srcOp = srcBcast.getSource().getDefiningOp()) return !isa(srcOp); - } + return true; }); + target.addDynamicallyLegalOp([](aievec::AddOp op) { auto lSrsOp = op.getLhs().getDefiningOp(); auto rSrsOp = op.getRhs().getDefiningOp(); @@ -3153,8 +3008,7 @@ static void configureAIEVecV1Legalizations(ConversionTarget &target, target.addLegalDialect(); } -static void configureAIEVecV2Legalizations(ConversionTarget &target, - AnalysisManager &am) { +static void configureAIEVecV2Legalizations(ConversionTarget &target) { target.addLegalOp(); // A set recording the vector lane size and element width supported @@ -3172,9 +3026,9 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([=](arith::AddIOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -3184,9 +3038,8 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([=](arith::SubIOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -3196,27 +3049,26 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([](arith::AddFOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + unsigned laneSize = getVectorLaneSize(resultType); return laneSize != 16; }); target.addDynamicallyLegalOp([](arith::SubFOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + unsigned laneSize = getVectorLaneSize(resultType); return laneSize != 16; }); target.addDynamicallyLegalOp([](arith::MulIOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } auto isAddOp = [&](Operation *op) { return isa(op); }; // Verify it is not a part of MAC if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp)) @@ -3231,9 +3083,9 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([](arith::MulFOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto isAddOp = [&](Operation *op) { return isa(op); }; // Verify it is not a part of FMA if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp)) @@ -3242,14 +3094,14 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); - return (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32)); + return laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32); }); target.addDynamicallyLegalOp([=](arith::MinSIOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -3259,9 +3111,9 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([=](arith::MaxSIOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -3271,9 +3123,9 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([=](arith::MinimumFOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -3283,9 +3135,9 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([=](arith::MaximumFOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); @@ -3295,64 +3147,58 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, target.addDynamicallyLegalOp([=](arith::CmpIOp op) { auto lhsType = dyn_cast(op.getLhs().getType()); - if (!lhsType) { + if (!lhsType) return true; - } + auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(lhsType); - if (!(elWidthSet.count(lhsElWidth) && laneSize * lhsElWidth == 512)) { + if (!(elWidthSet.count(lhsElWidth) && laneSize * lhsElWidth == 512)) return true; - } return false; }); target.addDynamicallyLegalOp([=](arith::CmpFOp op) { auto lhsType = dyn_cast(op.getLhs().getType()); - if (!lhsType) { + if (!lhsType) return true; - } + auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(lhsType); - if (!(elWidthSet.count(lhsElWidth) && laneSize * lhsElWidth == 512)) { + if (!(elWidthSet.count(lhsElWidth) && laneSize * lhsElWidth == 512)) return true; - } return false; }); target.addDynamicallyLegalOp([=](arith::SelectOp op) { auto resultType = dyn_cast(op.getType()); - if (!resultType) { + if (!resultType) return true; - } + auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth(); unsigned laneSize = getVectorLaneSize(resultType); - if (!(elWidthSet.count(resultElWidth) && laneSize * resultElWidth == 512)) { + if (!(elWidthSet.count(resultElWidth) && laneSize * resultElWidth == 512)) return true; - } return false; }); target.addDynamicallyLegalOp( [=](vector::ReductionOp op) { - auto kind = op.getKind(); - - if (kind != vector::CombiningKind::ADD && - kind != vector::CombiningKind::MINSI && - kind != vector::CombiningKind::MINUI && - kind != vector::CombiningKind::MINF && - kind != vector::CombiningKind::MAXSI && - kind != vector::CombiningKind::MAXUI && - kind != vector::CombiningKind::MAXF) { + if (auto kind = op.getKind(); kind != vector::CombiningKind::ADD && + kind != vector::CombiningKind::MINSI && + kind != vector::CombiningKind::MINUI && + kind != vector::CombiningKind::MINF && + kind != vector::CombiningKind::MAXSI && + kind != vector::CombiningKind::MAXUI && + kind != vector::CombiningKind::MAXF) return true; - } - VectorType vType = dyn_cast(op.getVector().getType()); + auto vType = dyn_cast(op.getVector().getType()); if (!vType) return true; @@ -3385,8 +3231,7 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target, /// Lower incoming vector operations into their corresponding AIE vector /// intrinsics. -struct LowerVectorToAIEVec - : public PassWrapper> { +struct LowerVectorToAIEVec : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerVectorToAIEVec) LowerVectorToAIEVec() = default; @@ -3420,35 +3265,32 @@ struct LowerVectorToAIEVec MLIRContext *context = &getContext(); RewritePatternSet patterns(context); ConversionTarget target(*context); - AIEArch aieVersion = AIEArch::AIE; + auto aieVersion = AIEArch::AIE; if (!aieTarget.empty()) { std::string target = aieTarget; - if (target == "aieml") { + if (target == "aieml") aieVersion = AIEArch::AIE_ML; - } else if (target != "aie") { + else if (target != "aie") { op->emitError() << "unknown AIE target '" << aieTarget << "'"; - signalPassFailure(); - return; + return signalPassFailure(); } } - AnalysisManager am = getAnalysisManager(); - configureAIEVecCommonLegalizations(target, am); + configureAIEVecCommonLegalizations(target); if (aieVersion == AIEArch::AIE) { - populateAIEVecV1ConversionPatterns(patterns, am); - configureAIEVecV1Legalizations(target, am); + populateAIEVecV1ConversionPatterns(patterns); + configureAIEVecV1Legalizations(target); } else { - populateAIEVecV2ConversionPatterns(patterns, am); - configureAIEVecV2Legalizations(target, am); + populateAIEVecV2ConversionPatterns(patterns); + configureAIEVecV2Legalizations(target); } - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - signalPassFailure(); - } + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return signalPassFailure(); } }; -static std::unique_ptr<::mlir::Pass> +static std::unique_ptr createLowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options) { return std::make_unique(options); } @@ -3457,8 +3299,7 @@ createLowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options) { // Custom canonicalization passes //===--------------------------------------------------------------------------- -struct ProcessExtOpsPass - : public PassWrapper> { +struct ProcessExtOpsPass : PassWrapper> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -3468,122 +3309,97 @@ struct ProcessExtOpsPass LowerTruncIOpPattern>(patterns.getContext()); target.addLegalDialect(); target.addDynamicallyLegalOp([](arith::ExtFOp extfOp) { - VectorType srcType = dyn_cast(extfOp.getIn().getType()); - VectorType dstType = dyn_cast(extfOp.getOut().getType()); - if (!srcType || !dstType) { + auto srcType = dyn_cast(extfOp.getIn().getType()); + auto dstType = dyn_cast(extfOp.getOut().getType()); + if (!srcType || !dstType) return true; - } Type srcScalarType = srcType.getElementType(); Type dstScalarType = dstType.getElementType(); - - if (!isa(srcScalarType) || !isa(dstScalarType)) { + if (!isa(srcScalarType) || !isa(dstScalarType)) return true; - } unsigned srcLaneSize = getVectorLaneSize(srcType); unsigned dstLaneSize = getVectorLaneSize(dstType); - unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth(); unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth(); - if (srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 || - dstLaneSize != 16) { + dstLaneSize != 16) return true; - } return false; }); target.addDynamicallyLegalOp([](arith::ExtSIOp extsiOp) { - VectorType srcType = dyn_cast(extsiOp.getIn().getType()); - VectorType dstType = dyn_cast(extsiOp.getOut().getType()); - if (!srcType || !dstType) { + auto srcType = dyn_cast(extsiOp.getIn().getType()); + auto dstType = dyn_cast(extsiOp.getOut().getType()); + if (!srcType || !dstType) return true; - } Type srcScalarType = srcType.getElementType(); Type dstScalarType = dstType.getElementType(); - - if (!isa(srcScalarType) || - !isa(dstScalarType)) { + if (!isa(srcScalarType) || !isa(dstScalarType)) return true; - } unsigned srcLaneSize = getVectorLaneSize(srcType); unsigned dstLaneSize = getVectorLaneSize(dstType); - unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth(); unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth(); - if (!(srcLaneSize == 32 && (dstElWidth > srcElWidth) && - (dstLaneSize == srcLaneSize))) { + (dstLaneSize == srcLaneSize))) return true; - } return false; }); target.addDynamicallyLegalOp([](arith::TruncFOp truncfOp) { - VectorType srcType = dyn_cast(truncfOp.getIn().getType()); - VectorType dstType = dyn_cast(truncfOp.getOut().getType()); - if (!srcType || !dstType) { + auto srcType = dyn_cast(truncfOp.getIn().getType()); + auto dstType = dyn_cast(truncfOp.getOut().getType()); + if (!srcType || !dstType) return true; - } Type srcScalarType = srcType.getElementType(); Type dstScalarType = dstType.getElementType(); - - if (!isa(srcScalarType) || !isa(dstScalarType)) { + if (!isa(srcScalarType) || !isa(dstScalarType)) return true; - } unsigned srcLaneSize = getVectorLaneSize(srcType); unsigned dstLaneSize = getVectorLaneSize(dstType); - unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth(); unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth(); - if (srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 || - dstLaneSize != 16) { + dstLaneSize != 16) return true; - } return false; }); target.addDynamicallyLegalOp([](arith::TruncIOp trunciOp) { - VectorType srcType = dyn_cast(trunciOp.getIn().getType()); - VectorType dstType = dyn_cast(trunciOp.getOut().getType()); - if (!srcType || !dstType) { + auto srcType = dyn_cast(trunciOp.getIn().getType()); + auto dstType = dyn_cast(trunciOp.getOut().getType()); + if (!srcType || !dstType) return true; - } Type srcScalarType = srcType.getElementType(); Type dstScalarType = dstType.getElementType(); - - if (!isa(srcScalarType) || - !isa(dstScalarType)) { + if (!isa(srcScalarType) || !isa(dstScalarType)) return true; - } unsigned srcLaneSize = getVectorLaneSize(srcType); unsigned dstLaneSize = getVectorLaneSize(dstType); - unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth(); unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth(); if (!(srcLaneSize == 32 && (dstElWidth < srcElWidth) && - (dstLaneSize == srcLaneSize))) { + (dstLaneSize == srcLaneSize))) return true; - } return false; }); - auto op = getOperation(); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - signalPassFailure(); + if (auto op = getOperation(); + failed(applyPartialConversion(op, target, std::move(patterns)))) { + return signalPassFailure(); } } }; @@ -3592,8 +3408,7 @@ struct ProcessExtOpsPass // bottom half. This can be used together with SimplifyUPDOpsPass to find // additional common subexpressions with UPDs generated from unaligned // `transfer_read` ops. -struct ExtendUPDOpsPass - : public PassWrapper> { +struct ExtendUPDOpsPass : PassWrapper> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -3607,9 +3422,10 @@ struct ExtendUPDOpsPass llvm::all_of(op->getUsers(), [](Operation *op) { return isa(op); }); }); - auto op = getOperation(); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - signalPassFailure(); + + if (auto op = getOperation(); + failed(applyPartialConversion(op, target, std::move(patterns)))) { + return signalPassFailure(); } } }; @@ -3619,8 +3435,7 @@ struct ExtendUPDOpsPass // TODO: This pass can be extended to work with wide UPD ops that are used by // TODO: a single ext op of the top half, which might be a good opportunity to // TODO: further optimize wide UPDs. -struct SimplifyUPDOpsPass - : public PassWrapper> { +struct SimplifyUPDOpsPass : PassWrapper> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -3633,9 +3448,10 @@ struct SimplifyUPDOpsPass return !defOp || !isa(defOp) || !defOp->hasOneUse() || op.getIndex() != 0; }); - auto op = getOperation(); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - signalPassFailure(); + + if (auto op = getOperation(); + failed(applyPartialConversion(op, target, std::move(patterns)))) { + return signalPassFailure(); } } }; diff --git a/lib/Dialect/AIEX/Transforms/AIEHerdRouting.cpp b/lib/Dialect/AIEX/Transforms/AIEHerdRouting.cpp index 196dfa5121..c51215f31e 100644 --- a/lib/Dialect/AIEX/Transforms/AIEHerdRouting.cpp +++ b/lib/Dialect/AIEX/Transforms/AIEHerdRouting.cpp @@ -12,19 +12,19 @@ #include "aie/Dialect/AIEX/IR/AIEXDialect.h" #include "aie/Dialect/AIEX/Transforms/AIEXPasses.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Transforms/DialectConversion.h" +#define DEBUG_TYPE "aie-herd-routing" + using namespace mlir; using namespace xilinx; using namespace xilinx::AIE; using namespace xilinx::AIEX; -template -struct AIEOpRemoval : public OpConversionPattern { +template struct AIEOpRemoval : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename MyOp::Adaptor; @@ -60,8 +60,8 @@ std::optional getAvailableDestChannel(SmallVector &connects, // look for existing connect for (int i = 0; i < numChannels; i++) { - Port port = {destBundle, i}; - if (std::find(connects.begin(), connects.end(), + if (Port port = {destBundle, i}; + std::find(connects.begin(), connects.end(), Connect{sourcePort, port}) != connects.end()) return {i}; } @@ -70,8 +70,8 @@ std::optional getAvailableDestChannel(SmallVector &connects, for (int i = 0; i < numChannels; i++) { Port port = {destBundle, i}; SmallVector ports; - for (auto connect : connects) - ports.push_back(connect.dst); + for (auto [src, dst] : connects) + ports.push_back(dst); if (std::find(ports.begin(), ports.end(), port) == ports.end()) return {i}; @@ -88,21 +88,19 @@ void buildRoute(int xSrc, int ySrc, int xDest, int yDest, int xCur = xSrc; int yCur = ySrc; - WireBundle curBundle; - int curChannel; - WireBundle lastBundle; + WireBundle curBundle = WireBundle::Core; + int curChannel = 0; + WireBundle lastBundle = WireBundle::Core; Port lastPort = {sourceBundle, sourceChannel}; SmallVector congestion; - llvm::dbgs() << "Build route: " << xSrc << " " << ySrc << " --> " << xDest - << " " << yDest << '\n'; + LLVM_DEBUG(llvm::dbgs() << "Build route: " << xSrc << " " << ySrc << " --> " + << xDest << " " << yDest << '\n'); // traverse horizontally, then vertically - while (!((xCur == xDest) && (yCur == yDest))) { - llvm::dbgs() << "coord " << xCur << " " << yCur << '\n'; - + while (!(xCur == xDest && yCur == yDest)) { + LLVM_DEBUG(llvm::dbgs() << "coord " << xCur << " " << yCur << '\n'); TileID curCoord = {xCur, yCur}; - SmallVector moves; if (xCur < xDest) @@ -133,41 +131,37 @@ void buildRoute(int xSrc, int ySrc, int xDest, int yDest, if (move == lastBundle) continue; - if (move == WireBundle::East) { + if (move == WireBundle::East) xCur = xCur + 1; - // yCur = yCur; - } else if (move == WireBundle::West) { + // yCur = yCur; + else if (move == WireBundle::West) xCur = xCur - 1; - // yCur = yCur; - } else if (move == WireBundle::North) { + // yCur = yCur; + else if (move == WireBundle::North) // xCur = xCur; yCur = yCur + 1; - } else if (move == WireBundle::South) { + else if (move == WireBundle::South) // xCur = xCur; yCur = yCur - 1; - } if (std::find(congestion.begin(), congestion.end(), TileID{xCur, yCur}) != congestion.end()) continue; curBundle = move; - lastBundle = (move == WireBundle::East) ? WireBundle::West - : (move == WireBundle::West) ? WireBundle::East - : (move == WireBundle::North) ? WireBundle::South - : (move == WireBundle::South) ? WireBundle::North - : lastBundle; + lastBundle = move == WireBundle::East ? WireBundle::West + : move == WireBundle::West ? WireBundle::East + : move == WireBundle::North ? WireBundle::South + : move == WireBundle::South ? WireBundle::North + : lastBundle; break; } assert(curChannel >= 0 && "Could not find available destination port!"); - - llvm::dbgs() << "[" << stringifyWireBundle(lastPort.bundle) << " : " - << lastPort.channel - << "], " - "[" - << stringifyWireBundle(curBundle) << " : " << curChannel - << "]\n"; + LLVM_DEBUG(llvm::dbgs() + << "[" << stringifyWireBundle(lastPort.bundle) << " : " + << lastPort.channel << "], [" << stringifyWireBundle(curBundle) + << " : " << curChannel << "]\n"); Port curPort = {curBundle, curChannel}; Connect connect = {lastPort, curPort}; @@ -179,19 +173,17 @@ void buildRoute(int xSrc, int ySrc, int xDest, int yDest, lastPort = {lastBundle, curChannel}; } - llvm::dbgs() << "coord " << xCur << " " << yCur << '\n'; - llvm::dbgs() << "[" << stringifyWireBundle(lastPort.bundle) << " : " - << lastPort.channel - << "], " - "[" - << stringifyWireBundle(destBundle) << " : " << destChannel - << "]\n"; + LLVM_DEBUG(llvm::dbgs() << "coord " << xCur << " " << yCur << '\n'); + LLVM_DEBUG(llvm::dbgs() << "[" << stringifyWireBundle(lastPort.bundle) + << " : " << lastPort.channel << "], [" + << stringifyWireBundle(destBundle) << " : " + << destChannel << "]\n"); switchboxes[std::make_pair(herdOp, TileID{xCur, yCur})].push_back( {lastPort, Port{destBundle, destChannel}}); } -struct AIEHerdRoutingPass : public AIEHerdRoutingBase { +struct AIEHerdRoutingPass : AIEHerdRoutingBase { void runOnOperation() override { DeviceOp device = getOperation(); @@ -206,9 +198,8 @@ struct AIEHerdRoutingPass : public AIEHerdRoutingBase { DenseMap, SmallVector> switchboxes; - for (auto herd : device.getOps()) { + for (auto herd : device.getOps()) herds.push_back(herd); - } for (auto placeOp : device.getOps()) { placeOps.push_back(placeOp); @@ -225,10 +216,10 @@ struct AIEHerdRoutingPass : public AIEHerdRoutingBase { for (auto routeOp : device.getOps()) { routeOps.push_back(routeOp); - AIEX::SelectOp sourceHerds = - dyn_cast(routeOp.getSourceHerds().getDefiningOp()); - AIEX::SelectOp destHerds = - dyn_cast(routeOp.getDestHerds().getDefiningOp()); + auto sourceHerds = + dyn_cast(routeOp.getSourceHerds().getDefiningOp()); + auto destHerds = + dyn_cast(routeOp.getDestHerds().getDefiningOp()); WireBundle sourceBundle = routeOp.getSourceBundle(); WireBundle destBundle = routeOp.getDestBundle(); int sourceChannel = routeOp.getSourceChannelValue(); @@ -262,15 +253,11 @@ struct AIEHerdRoutingPass : public AIEHerdRoutingBase { assert(distances.count(std::make_pair(sourceHerd, destHerd)) == 1); - std::pair distance = - distances[std::make_pair(sourceHerd, destHerd)]; - int distX = distance.first; - int distY = distance.second; + auto [distX, distY] = distances[std::make_pair(sourceHerd, destHerd)]; // FIXME: this looks like it can be improved further ... - for (int xSrc = sourceStartX; xSrc < sourceEndX; xSrc += sourceStrideX) { - for (int ySrc = sourceStartY; ySrc < sourceEndY; - ySrc += sourceStrideY) { - for (int xDst = destStartX; xDst < destEndX; xDst += destStrideX) { + for (int xSrc = sourceStartX; xSrc < sourceEndX; xSrc += sourceStrideX) + for (int ySrc = sourceStartY; ySrc < sourceEndY; ySrc += sourceStrideY) + for (int xDst = destStartX; xDst < destEndX; xDst += destStrideX) for (int yDst = destStartY; yDst < destEndY; yDst += destStrideY) { // Build route (x0, y0) --> (x1, y1) int x0 = xSrc; @@ -299,9 +286,6 @@ struct AIEHerdRoutingPass : public AIEHerdRoutingBase { routes.push_back(route); } - } - } - } } for (const auto &swboxCfg : switchboxes) { @@ -315,17 +299,15 @@ struct AIEHerdRoutingPass : public AIEHerdRoutingBase { auto iterx = builder.create(builder.getUnknownLoc(), x, x + 1, 1); auto itery = builder.create(builder.getUnknownLoc(), y, y + 1, 1); - auto sel = builder.create(builder.getUnknownLoc(), herd, - iterx, itery); + auto sel = + builder.create(builder.getUnknownLoc(), herd, iterx, itery); auto swbox = builder.create(builder.getUnknownLoc(), sel); SwitchboxOp::ensureTerminator(swbox.getConnections(), builder, builder.getUnknownLoc()); Block &b = swbox.getConnections().front(); builder.setInsertionPoint(b.getTerminator()); - for (auto connect : connects) { - Port sourcePort = connect.src; - Port destPort = connect.dst; + for (auto [sourcePort, destPort] : connects) { WireBundle sourceBundle = sourcePort.bundle; int sourceChannel = sourcePort.channel; WireBundle destBundle = destPort.bundle; @@ -347,7 +329,6 @@ struct AIEHerdRoutingPass : public AIEHerdRoutingBase { } }; -std::unique_ptr> -xilinx::AIEX::createAIEHerdRoutingPass() { +std::unique_ptr> AIEX::createAIEHerdRoutingPass() { return std::make_unique(); } diff --git a/lib/Dialect/AIEX/Utils/AIETokenAnalysis.cpp b/lib/Dialect/AIEX/Utils/AIETokenAnalysis.cpp index 3a7b8b735b..fafa61b712 100644 --- a/lib/Dialect/AIEX/Utils/AIETokenAnalysis.cpp +++ b/lib/Dialect/AIEX/Utils/AIETokenAnalysis.cpp @@ -12,8 +12,6 @@ #include "aie/Dialect/AIE/IR/AIEDialect.h" #include "aie/Dialect/AIEX/IR/AIEXDialect.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" @@ -84,6 +82,8 @@ void xilinx::AIEX::TokenAnalysis::runAnalysis() { } } + int releaseValue = 0; + int acquireValue = 0; // Look for a pair of UseTokenOps (or UseTokenOp and MemcpyOp) such that one // releases and one acquires the same token + value. They form a chain of // releasing and acquiring a token. From the chains of tokens collected, we @@ -93,7 +93,6 @@ void xilinx::AIEX::TokenAnalysis::runAnalysis() { auto tokenRels = map.second; auto tokenAcqs = tokenAcqMap[tokenName]; for (auto ROp : tokenRels) { - int releaseValue; if (auto op = dyn_cast(ROp)) releaseValue = op.getTokenValue(); @@ -101,7 +100,6 @@ void xilinx::AIEX::TokenAnalysis::runAnalysis() { releaseValue = op.getReleaseTokenValue(); for (auto AOp : tokenAcqs) { - int acquireValue; if (auto op = dyn_cast(AOp)) acquireValue = op.getTokenValue(); diff --git a/lib/Targets/ADFGenerateCppGraph.cpp b/lib/Targets/ADFGenerateCppGraph.cpp index dbf142e6b7..915f044fbd 100644 --- a/lib/Targets/ADFGenerateCppGraph.cpp +++ b/lib/Targets/ADFGenerateCppGraph.cpp @@ -9,15 +9,17 @@ //===----------------------------------------------------------------------===// #include "AIETargets.h" + #include "aie/Dialect/ADF/ADFDialect.h" #include "aie/Dialect/ADF/ADFOps.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" + #include "llvm/Support/FileSystem.h" -#include + #include #include @@ -34,8 +36,7 @@ struct Indent { }; static void resetIndent() { currentindent = 0; } -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const struct Indent &indent) { +raw_ostream &operator<<(raw_ostream &os, const Indent &indent) { for (int i = 0; i < currentindent; ++i) os << " "; return os; @@ -49,51 +50,49 @@ struct GraphWriter { std::unordered_map kernelOp2VarName; StringRef getCTypeString(const Type &type) { - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - if (const auto &t = type.dyn_cast()) - return t.getMnemonic(); - assert(false); - return {}; + if (llvm::dyn_cast(type)) + return int8Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return int16Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return int32Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return int64Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return uint8Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return uint16Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return uint32Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return uint64Type::getMnemonic(); + if (llvm::dyn_cast(type)) + return floatType::getMnemonic(); + llvm::report_fatal_error("unknown type"); } - std::string getKernelTypeString(std::string direction, Type type) { - if (auto window = type.dyn_cast()) { + + std::string getKernelTypeString(const std::string &direction, Type type) { + if (auto window = llvm::dyn_cast(type)) return (direction + "_window_" + getCTypeString(window.getType()) + " *") .str(); - } else if (auto stream = type.dyn_cast()) { + if (auto stream = llvm::dyn_cast(type)) return (direction + "_stream_" + getCTypeString(stream.getType()) + " *") .str(); - } else if (auto stream = type.dyn_cast()) { + if (auto stream = llvm::dyn_cast(type)) return std::string(getCTypeString(stream.getType())); - } - assert(false); - return {}; + + llvm::report_fatal_error("unknown kernel type"); } std::string getConnectionTypeString(Type type) { - if (auto windowType = type.dyn_cast()) { + if (auto windowType = llvm::dyn_cast(type)) return std::string("window<") + std::to_string(windowType.getSize()) + "> "; - } else if (auto windowType = type.dyn_cast()) + if (llvm::dyn_cast(type)) return "stream"; - else if (auto windowType = type.dyn_cast()) + if (llvm::dyn_cast(type)) return "parameter"; - assert(false); - return {}; + llvm::report_fatal_error("unknown connection type"); } std::string getTempNetName() { @@ -118,7 +117,6 @@ struct GraphWriter { output << getTempNetName() << " (" << driverOp.getName() << ", " << targetKernelName << ".in[" << targetIndex << "]);\n"; } - // todo: kernel should not drive graph input, add an mlir verifier // condition } @@ -154,13 +152,13 @@ struct GraphWriter { output << getTempNetName() << " (" << sourceKernelName << ".out[" << sourceIndex << "], " << outputOp.getName() << ");\n"; } - // todo: kernel should not drive graph input, add an mlir verifier // condition } sourceIndex++; } } + void writeKernelFunctions(ModuleOp module) { output << "#include \n"; output << "#ifndef FUNCTION_KERNELS_H\n"; @@ -169,13 +167,10 @@ struct GraphWriter { for (Block &block : module.getBodyRegion()) for (auto funcOp : block.getOps()) { output << "void " << funcOp.getSymName() << "("; - FunctionType type = funcOp.getFunctionType(); - - for (unsigned i = 0; i < type.getNumInputs(); i++) { + for (unsigned i = 0; i < type.getNumInputs(); i++) output << getKernelTypeString("input", type.getInput(i)) << " in" << i << ", "; - } for (unsigned i = 0; i < type.getNumResults(); i++) { output << getKernelTypeString("output", type.getResult(i)) << " out" @@ -189,7 +184,7 @@ struct GraphWriter { output << "#endif\n\n"; } - void writeClass(ADF::GraphOp graph) { + void writeClass(GraphOp graph) { output << "#include \n"; output << "using namespace adf;\n"; output << "class " << graph.getName() << " : public graph {\n"; @@ -239,7 +234,6 @@ struct GraphWriter { } else if (auto graph = dyn_cast(op)) { visitOpResultUsers(graph); } else if (auto graph = dyn_cast(op)) { - ; // the graph output should have no users in adf, do nothing here } } // all op visited @@ -264,23 +258,14 @@ struct GraphWriter { } }; -mlir::LogicalResult xilinx::AIE::ADFGenerateCPPGraph(ModuleOp module, - raw_ostream &output) { +LogicalResult AIE::ADFGenerateCPPGraph(ModuleOp module, raw_ostream &output) { GraphWriter writer(output); resetIndent(); writer.writeKernelFunctions(module); for (Block &block : module.getBodyRegion()) - for (auto graphOp : block.getOps()) { + for (auto graphOp : block.getOps()) writer.writeClass(graphOp); - } - return mlir::success(); -} - -// }; - -// std::unique_ptr> -// xilinx::ADF::createADFGenerateCppGraphPass() { -// return std::make_unique(); -// } \ No newline at end of file + return success(); +} \ No newline at end of file diff --git a/lib/Targets/AIETargetXAIEV2.cpp b/lib/Targets/AIETargetXAIEV2.cpp index f1ae0f7885..575097d76d 100644 --- a/lib/Targets/AIETargetXAIEV2.cpp +++ b/lib/Targets/AIETargetXAIEV2.cpp @@ -27,8 +27,7 @@ using namespace xilinx; using namespace xilinx::AIE; using namespace xilinx::AIEX; -namespace xilinx { -namespace AIE { +namespace xilinx::AIE { // This string is output at the top of the lowered C++ code. const char *xaie_cpp_file_header = R"code( @@ -97,9 +96,6 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, int BaseAddrA = 0; bool hasA = false; bool hasB = false; - StringRef bufA = "0"; - StringRef bufB = "0"; - StringRef AbMode = disable; int ndims = 0; ArrayRef dims; // StringRef FifoMode = disable; // FIXME: when to enable FIFO mode? @@ -114,26 +110,24 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, // Memtile DMAs can access neighboring tiles. if (targetModel.isMemTile(col, row)) { - if (targetModel.isWest(col, row, bufferCol, bufferRow)) { + if (targetModel.isWest(col, row, bufferCol, bufferRow)) BaseAddrA += 0x0; - } else if (targetModel.isInternal(col, row, bufferCol, bufferRow)) { + else if (targetModel.isInternal(col, row, bufferCol, bufferRow)) BaseAddrA += targetModel.getMemTileSize() * 1; - } else if (targetModel.isEast(col, row, bufferCol, bufferRow)) { + else if (targetModel.isEast(col, row, bufferCol, bufferRow)) BaseAddrA += targetModel.getMemTileSize() * 2; - } } } + if (op.isA() || targetModel.isShimNOCTile(col, row)) { lenA = op.getLenValue(); bytesA = bufferType.getElementTypeBitWidth() / 8; offsetA = op.getOffsetValue() * bytesA; - bufA = "XAIEDMA_TILE_BD_ADDRA"; hasA = true; } if (op.isB()) { lenB = op.getLenValue(); bytesB = bufferType.getElementTypeBitWidth() / 8; - bufB = "XAIEDMA_TILE_BD_ADDRB"; hasB = true; } if (op.getDimensions()) { @@ -142,14 +136,12 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, } } - if (0 != ndims && AIEArch::AIE2 != targetModel.getTargetArch()) { + if (0 != ndims && AIEArch::AIE2 != targetModel.getTargetArch()) return memOp.emitOpError("DMA contains at least one multi-dimensional " "buffer descriptor. This is currently only " "supported for AIE-ML devices."); - } if (hasA && hasB) { - AbMode = enable; if (lenA != lenB) llvm::errs() << "ABmode must have matching lengths.\n"; if (bytesA != bytesB) @@ -166,14 +158,14 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, int lockID = lock.getLockIDValue(); // Memtile DMAs can access neighboring tiles. if (targetModel.isMemTile(col, row)) { - if (targetModel.isWest(col, row, lockCol, lockRow)) { + if (targetModel.isWest(col, row, lockCol, lockRow)) lockID += 0; - } else if (targetModel.isInternal(col, row, lockCol, lockRow)) { + else if (targetModel.isInternal(col, row, lockCol, lockRow)) lockID += targetModel.getNumLocks(lockCol, lockRow) * 1; - } else if (targetModel.isEast(col, row, lockCol, lockRow)) { + else if (targetModel.isEast(col, row, lockCol, lockRow)) lockID += targetModel.getNumLocks(lockCol, lockRow) * 2; - } } + if (op.acquire() || op.acquireGE()) { hasAcq = true; acqLockID = lockID; @@ -234,16 +226,14 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, << "/* QoS */ 0, " << "/* Cache */ 0, " << "/* Secure */ " << enable << "));\n"; - } else { + } else output << "__mlir_aie_try(XAie_DmaSetAddrLen(" << tileDMAInstRefStr(col, row, bdNum) << ", /* addrA */ " << "0x" << llvm::utohexstr(BaseAddrA + offsetA) << ", " << " /* len */ " << lenA << " * " << bytesA << "));\n"; - } - } else { + } else generateXAieDmaSetMultiDimAddr(output, ndims, dims, col, row, bdNum, BaseAddrA, offsetA, lenA, bytesA, "1"); - } if (block.getNumSuccessors() > 0) { Block *nextBlock = block.getSuccessors()[0]; // should have only one @@ -259,6 +249,7 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, << " /* nextbd */ " << nextBdNum << ", " << " /* enableNextBd */ " << enableNextBd << "));\n"; } + if (foundBdPacket) { output << "__mlir_aie_try(XAie_DmaSetPkt(" << tileDMAInstRefStr(col, row, bdNum) << ", " @@ -276,10 +267,8 @@ mlir::LogicalResult generateDMAConfig(OpType memOp, raw_ostream &output, for (auto &block : memOp.getBody()) { for (auto op : block.template getOps()) { int bdNum = blockMap[op.getDest()]; - - llvm::StringRef dmaDir = stringifyDMAChannelDir(op.getChannelDir()); + StringRef dmaDir = stringifyDMAChannelDir(op.getChannelDir()); int chNum = op.getChannelIndex(); - output << "__mlir_aie_try(XAie_DmaChannelPushBdToQueue(" << deviceInstRef << ", " << tileLocStr(col, row) << ", " << "/* ChNum */" << chNum @@ -307,9 +296,8 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { DenseMap tiles; DenseMap> buffers; - if (module.getOps().empty()) { + if (module.getOps().empty()) return module.emitOpError("expected AIE.device operation at toplevel"); - } DeviceOp targetOp = *(module.getOps().begin()); const auto &targetModel = targetOp.getTargetModel(); @@ -391,12 +379,11 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { << tileLocStr(col, row) << ", XAie_LockInit(l, 0x0), 0));\n"; if (auto coreOp = tileOp.getCoreOp()) { std::string fileName; - if (auto fileAttr = coreOp->getAttrOfType("elf_file")) { + if (auto fileAttr = coreOp->getAttrOfType("elf_file")) fileName = std::string(fileAttr.getValue()); - } else { + else fileName = std::string("core_") + std::to_string(col) + "_" + std::to_string(row) + ".elf"; - } output << "{\n" << "AieRC RC = XAie_LoadElf(" << deviceInstRef << ", " << tileLocStr(col, row) << ", " @@ -557,12 +544,11 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { for (auto &block : op.getBody()) { if (!block.getOps().empty()) { blockMap[&block] = bdNum; - uint64_t offset = 0; for (auto op : block.getOps()) { offset = op.getOffsetValue(); - auto buffer = cast( - op.getBuffer().getDefiningOp()); + auto buffer = + cast(op.getBuffer().getDefiningOp()); output << "u64 mlir_aie_external_get_addr_myBuffer_" << col << row << "_" << bdNum << "(void) {\n" @@ -599,11 +585,10 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { int row = tile.rowIndex(); int lockID = lock.getLockIDValue(); auto init = lock.getInit(); - if (init) { + if (init) output << "__mlir_aie_try(XAie_LockSetValue(" << deviceInstRef << ", " << tileLocStr(col, row) << ", " << "XAie_LockInit(" << lockID << ", " << *init << ")));\n"; - } } output << "return XAIE_OK;\n"; output << "} // mlir_aie_initialize_locks\n"; @@ -632,8 +617,8 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { output << "x = " << col << ";\n"; output << "y = " << row << ";\n"; } - } else if (AIEX::SelectOp sel = dyn_cast( - switchboxOp.getTile().getDefiningOp())) { + } else if (auto sel = + dyn_cast(switchboxOp.getTile().getDefiningOp())) { // parameterize streamswitch's configuration isParam = true; HerdOp sourceHerd = dyn_cast(sel.getStartHerd().getDefiningOp()); @@ -661,14 +646,13 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { << "; y += " << strideYValue << ") {\n"; } - for (auto connectOp : b.getOps()) { + for (auto connectOp : b.getOps()) output << "__mlir_aie_try(XAie_StrmConnCctEnable(" << deviceInstRef << ", " << tileLocStr("x", "y") << ", " << stringifyWireBundle(connectOp.getSourceBundle()).upper() << ", " << connectOp.sourceIndex() << ", " << stringifyWireBundle(connectOp.getDestBundle()).upper() << ", " << connectOp.destIndex() << "));\n"; - } for (auto connectOp : b.getOps()) { int mask = 0; @@ -746,7 +730,7 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { } for (auto connectOp : b.getOps()) { - if (connectOp.getSourceBundle() == WireBundle::North) { + if (connectOp.getSourceBundle() == WireBundle::North) // demux! output << "__mlir_aie_try(XAie_EnableAieToShimDmaStrmPort(" @@ -755,7 +739,7 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { // << // stringifyWireBundle(connectOp.sourceBundle()).upper() << connectOp.sourceIndex() << "));\n"; - } else if (connectOp.getDestBundle() == WireBundle::North) { + else if (connectOp.getDestBundle() == WireBundle::North) // mux output << "__mlir_aie_try(XAie_EnableShimDmaToAieStrmPort(" @@ -764,7 +748,6 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { // << // stringifyWireBundle(connectOp.sourceBundle()).upper() << connectOp.destIndex() << "));\n"; - } } } for (auto switchboxOp : targetOp.getOps()) { @@ -772,17 +755,15 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { Block &b = r.front(); bool isEmpty = b.getOps().empty(); int col = switchboxOp.getCol(); - if (!isEmpty) { + if (!isEmpty) output << "// Shim Switch column " << col << "\n"; - } - for (auto connectOp : b.getOps()) { + for (auto connectOp : b.getOps()) output << "__mlir_aie_try(XAie_StrmConnCctEnable(" << deviceInstRef << ", " << tileLocStr(col, 0) << ", " << stringifyWireBundle(connectOp.getSourceBundle()).upper() << ", " << connectOp.sourceIndex() << ", " << stringifyWireBundle(connectOp.getDestBundle()).upper() << ", " << connectOp.destIndex() << "));\n"; - } } output << "return XAIE_OK;\n"; @@ -805,7 +786,7 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { Type t = buf.getType(); Type et; std::string typestr; - if (auto memrefType = t.dyn_cast()) { + if (auto memrefType = llvm::dyn_cast(t)) { et = memrefType.getElementType(); if (et.isInteger(32)) typestr = "int32_t"; @@ -885,5 +866,4 @@ mlir::LogicalResult AIETranslateToXAIEV2(ModuleOp module, raw_ostream &output) { return success(); } -} // namespace AIE -} // namespace xilinx +} // namespace xilinx::AIE diff --git a/lib/Targets/AIETargets.cpp b/lib/Targets/AIETargets.cpp index 9e8c23dede..7ff86f8565 100644 --- a/lib/Targets/AIETargets.cpp +++ b/lib/Targets/AIETargets.cpp @@ -24,7 +24,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Target/LLVMIR/Import.h" -#include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/Support/JSON.h" @@ -42,14 +41,17 @@ static llvm::cl::opt llvm::cl::init(0)); llvm::json::Value attrToJSON(Attribute &attr) { - if (auto a = attr.dyn_cast()) { - return llvm::json::Value(a.getValue().str()); - } else if (auto arrayAttr = attr.dyn_cast()) { + if (auto a = llvm::dyn_cast(attr)) + return {a.getValue().str()}; + + if (auto arrayAttr = llvm::dyn_cast(attr)) { llvm::json::Array arrayJSON; for (auto a : arrayAttr) arrayJSON.push_back(attrToJSON(a)); return llvm::json::Value(std::move(arrayJSON)); - } else if (auto dictAttr = attr.dyn_cast()) { + } + + if (auto dictAttr = llvm::dyn_cast(attr)) { llvm::json::Object dictJSON; for (auto a : dictAttr) { auto ident = a.getName(); @@ -57,14 +59,15 @@ llvm::json::Value attrToJSON(Attribute &attr) { dictJSON[ident.str()] = attrToJSON(attr); } return llvm::json::Value(std::move(dictJSON)); - } else if (auto intAttr = attr.dyn_cast()) { + } + + if (auto intAttr = llvm::dyn_cast(attr)) return llvm::json::Value(intAttr.getInt()); - } else - return llvm::json::Value(std::string("")); + + return llvm::json::Value(std::string("")); } -namespace xilinx { -namespace AIE { +namespace xilinx::AIE { static void registerDialects(DialectRegistry ®istry) { registry.insert(); @@ -273,7 +276,7 @@ ENTRY(_main_init) SECTIONS { . = 0x0; - .text : { + .text : { /* the _main_init symbol from me_basic.o has to come at address zero. */ *me_basic.o(.text) . = 0x200; @@ -286,7 +289,7 @@ SECTIONS _dtors_end = .; *(.text) } > program - .data : { + .data : { *(.data*); *(.rodata*) } > data @@ -546,5 +549,4 @@ SECTIONS }, registerDialects); } -} // namespace AIE -} // namespace xilinx +} // namespace xilinx::AIE diff --git a/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp b/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp index 2d994be4bd..ad27923d46 100644 --- a/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp +++ b/lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp @@ -12,8 +12,10 @@ //===----------------------------------------------------------------------===// #include "TranslateAIEVecToCpp.h" + #include "aie/Dialect/AIEVec/AIEVecUtils.h" #include "aie/Dialect/AIEVec/IR/AIEVecOps.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" @@ -23,20 +25,18 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/Support/IndentedOstream.h" #include "mlir/Support/MathExtras.h" -#include "llvm/ADT/DenseMap.h" + #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" + #include #include @@ -55,9 +55,9 @@ using llvm::formatv; /// on each element doesn't return a string. template -inline LogicalResult -interleaveWithError(ForwardIterator begin, ForwardIterator end, - UnaryFunctor eachFn, NullaryFunctor betweenFn) { +LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, + UnaryFunctor eachFn, + NullaryFunctor betweenFn) { if (begin == end) return success(); if (failed(eachFn(*begin))) @@ -72,17 +72,15 @@ interleaveWithError(ForwardIterator begin, ForwardIterator end, } template -inline LogicalResult interleaveWithError(const Container &c, - UnaryFunctor eachFn, - NullaryFunctor betweenFn) { +LogicalResult interleaveWithError(const Container &c, UnaryFunctor eachFn, + NullaryFunctor betweenFn) { return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); } template -inline LogicalResult interleaveCommaWithError(const Container &c, - raw_ostream &os, - UnaryFunctor eachFn) { - return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); +LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, + UnaryFunctor eachFn) { + return interleaveWithError(c.begin(), c.end(), eachFn, [&] { os << ", "; }); } namespace { @@ -147,7 +145,8 @@ struct CppEmitter { std::string getNewName(std::string prefix = "v"); // Set the dim size at position index of the memref to the parameter - void setMemRefDimParam(Value memref, unsigned index, std::string parameter); + void setMemRefDimParam(Value memref, unsigned index, + const std::string ¶meter); // For the dynamic shaped memref, return the parametric size at index StringRef getMemRefDimParam(Value memref, unsigned index); @@ -187,11 +186,11 @@ struct CppEmitter { bool hasBlockLabel(Block &block); /// Returns the output stream. - raw_indented_ostream &ostream() { return os; }; + raw_indented_ostream &ostream() { return os; } /// Returns if all variables for op results and basic block arguments need to /// be declared at the beginning of a function. - bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; + bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; } private: using ValueMapper = llvm::ScopedHashTable; @@ -241,13 +240,12 @@ static bool skippedOp(Operation *op, CppEmitter &emitter, // skip op 2 : some aievec::srs for float types if (auto srsOp = dyn_cast(op)) { // Get the datatype of the source accumulator and result vector - VectorType accType = srsOp.getSource().getType().cast(); + auto accType = srsOp.getSource().getType().cast(); Type eltType = accType.getElementType(); - Value source = srsOp.getSource(); // If the underlying element types are float, then we do not really need an // srs op if source of srsOp has only one use. - if (!AIEML && eltType.isa() && - source.getDefiningOp()->hasOneUse()) { + if (Value source = srsOp.getSource(); !AIEML && eltType.isa() && + source.getDefiningOp()->hasOneUse()) { StringRef srcName = emitter.getOrCreateName(source); emitter.setName(srsOp->getResult(0), srcName); skip = true; @@ -256,13 +254,12 @@ static bool skippedOp(Operation *op, CppEmitter &emitter, // skip op 3 : some aievec::ups for float ops else if (auto upsOp = dyn_cast(op)) { // Get the datatype of the source vector and result accumulator - VectorType accType = upsOp.getResult().getType().cast(); + auto accType = upsOp.getResult().getType().cast(); Type eltType = accType.getElementType(); - Value source = upsOp.getSource(); // If the underlying element types are float, then we do not really need a // ups op if the source accumulator has only one use. - if (!AIEML && eltType.isa() && - source.getDefiningOp()->hasOneUse()) { + if (Value source = upsOp.getSource(); !AIEML && eltType.isa() && + source.getDefiningOp()->hasOneUse()) { StringRef srcName = emitter.getOrCreateName(source); emitter.setName(upsOp->getResult(0), srcName); skip = true; @@ -291,7 +288,7 @@ static bool skippedOp(Operation *op, CppEmitter &emitter, static LogicalResult parseMemRefDynamicDims(CppEmitter &emitter, func::FuncOp func) { // Step1: Walk over all the operations that are memref dimOp - func.walk([&](mlir::Operation *Op) { + func.walk([&](Operation *Op) { if (auto op = dyn_cast(Op)) { // Extract the source memref, result, and index Value source = op.getSource(); @@ -299,7 +296,7 @@ static LogicalResult parseMemRefDynamicDims(CppEmitter &emitter, auto indexOp = dyn_cast(op.getIndex().getDefiningOp()); assert(indexOp && "Failed to get the index value of dimOp"); // Get the constant index value - llvm::APInt idxVal = indexOp.getValue().cast().getValue(); + APInt idxVal = indexOp.getValue().cast().getValue(); unsigned index = idxVal.getZExtValue(); // Assign a printable name to the result StringRef name = emitter.getOrCreateName(result, "m"); @@ -310,7 +307,7 @@ static LogicalResult parseMemRefDynamicDims(CppEmitter &emitter, // Step2: Iterate over all the block arguments, and make sure that the memref // args have a parameter associated with the dynamic sized dimension for (BlockArgument arg : func.getArguments()) { - MemRefType argType = arg.getType().dyn_cast(); + auto argType = llvm::dyn_cast(arg.getType()); if (!argType) continue; for (unsigned dim = 0; dim < argType.getRank(); ++dim) { @@ -329,8 +326,7 @@ static LogicalResult parseMemRefDynamicDims(CppEmitter &emitter, // Print the memref dims, if the memref has dynamic shape static LogicalResult printMemRefDims(CppEmitter &emitter, BlockArgument arg) { raw_indented_ostream &os = emitter.ostream(); - MemRefType argType = arg.getType().dyn_cast(); - if (argType) { + if (auto argType = llvm::dyn_cast(arg.getType())) { for (unsigned dim = 0; dim < argType.getRank(); ++dim) { if (argType.isDynamicDim(dim)) { StringRef param = emitter.getMemRefDimParam(arg, dim); @@ -345,7 +341,7 @@ static LogicalResult printMemRefDims(CppEmitter &emitter, BlockArgument arg) { static LogicalResult createLinearizedAccess(CppEmitter &emitter, Value source, SmallVector indices, std::string &access) { - MemRefType memRefType = source.getType().dyn_cast(); + auto memRefType = llvm::dyn_cast(source.getType()); assert(memRefType && "cannot creating linearized expression for non-memref type"); ArrayRef stride = memRefType.getShape(); @@ -396,11 +392,9 @@ static LogicalResult createLinearizedAccess(CppEmitter &emitter, Value source, // Return true if the array accessed by this value is readonly static bool isReadOnly(Value read) { - for (auto *user : read.getUsers()) { - if (isa(user)) - return false; - } - return true; + return std::none_of( + read.getUsers().begin(), read.getUsers().end(), + [](auto *user) { return isa(user); }); } //===----------------------------------------------------------------------===// @@ -411,10 +405,10 @@ static bool isReadOnly(Value read) { static std::pair getTripCount(scf::ForOp forOp) { // If the upper and lower bounds are constant values, return the difference. auto lb = forOp.getLowerBound().getDefiningOp(); - auto ub = forOp.getUpperBound().getDefiningOp(); - if (lb && ub) { - llvm::APInt ubValue = ub.getValue().cast().getValue(); - llvm::APInt lbValue = lb.getValue().cast().getValue(); + if (auto ub = forOp.getUpperBound().getDefiningOp(); + lb && ub) { + APInt ubValue = ub.getValue().cast().getValue(); + APInt lbValue = lb.getValue().cast().getValue(); return std::make_pair(true, ubValue.getSExtValue() - lbValue.getSExtValue()); } @@ -424,7 +418,7 @@ static std::pair getTripCount(scf::ForOp forOp) { // Get the loop step size of the for operator static std::pair getStep(scf::ForOp forOp) { if (auto step = forOp.getStep().getDefiningOp()) { - llvm::APInt stepValue = step.getValue().cast().getValue(); + APInt stepValue = step.getValue().cast().getValue(); return std::make_pair(true, stepValue.getSExtValue()); } return std::make_pair(false, 0); @@ -481,7 +475,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::UPDOp updOp) { raw_indented_ostream &os = emitter.ostream(); Value result = updOp.getResult(); - VectorType resultType = result.getType().cast(); + auto resultType = result.getType().cast(); int32_t vecSizeInBits = getVectorSizeInBits(resultType); int32_t elementSizeInBits = getElementSizeInBits(resultType); @@ -541,7 +535,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::UPDOp updOp) { // If the source array of upd is read-only, load from restrict pointer bool readOnly = isReadOnly(source); std::string restrictPrefix = - readOnly ? ("r_" + emitter.getOrCreateName(result).str() + "_") : ""; + readOnly ? "r_" + emitter.getOrCreateName(result).str() + "_" : ""; // Create a restrict pointer if (readOnly && !vector) { if (failed(emitter.emitType(updOp->getLoc(), source.getType()))) @@ -591,7 +585,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::UPSOp upsOp) { if (!emitter.hasValueInScope(source)) return failure(); - VectorType accType = upsOp.getResult().getType().cast(); + auto accType = upsOp.getResult().getType().cast(); unsigned lanes = getVectorLaneSize(accType); Type eltType = accType.getElementType(); @@ -603,8 +597,8 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::UPSOp upsOp) { } // Determine if it is lups or ups based on accumulator type - auto iType = eltType.dyn_cast(); - auto fType = eltType.dyn_cast(); + auto iType = llvm::dyn_cast(eltType); + auto fType = llvm::dyn_cast(eltType); if (iType) { if (iType.getWidth() == 80) os << "l"; @@ -648,33 +642,30 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); // Get the datatype of the source and result vector - VectorType resType = castOp->getResult(0).getType().cast(); + auto resType = castOp->getResult(0).getType().cast(); Type eltType = resType.getElementType(); unsigned lanes = getVectorLaneSize(resType); raw_indented_ostream &os = emitter.ostream(); - unsigned width = 0; + unsigned width; if (isResAcc) { - if (eltType.isa()) { + if (eltType.isa()) os << "v" << lanes << "accfloat"; - } else { + else { width = getElementSizeInBits(resType); os << "v" << lanes << "acc" << width; } + } else if (eltType.isa()) { + width = eltType.cast().getWidth(); + os << "v" << lanes; + if (width == 16) + os << "bfloat16"; + else + os << "float"; } else { - if (eltType.isa()) { - width = eltType.cast().getWidth(); - os << "v" << lanes; - if (width == 16) { - os << "bfloat16"; - } else { - os << "float"; - } - } else { - width = getElementSizeInBits(resType); - os << "v" << lanes << "int" << width; - } + width = getElementSizeInBits(resType); + os << "v" << lanes << "int" << width; } os << "("; os << emitter.getOrCreateName(source); @@ -709,8 +700,8 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SRSOp srsOp) { Value shift = srsOp.getShift(); // Get the datatype of the source accumulator and result vector - VectorType accType = srsOp.getSource().getType().cast(); - VectorType resType = srsOp->getResult(0).getType().cast(); + auto accType = srsOp.getSource().getType().cast(); + auto resType = srsOp->getResult(0).getType().cast(); Type eltType = accType.getElementType(); unsigned lanes = getVectorLaneSize(resType); @@ -728,18 +719,16 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SRSOp srsOp) { // srs op. We can simply generate an assignment if (eltType.isa()) { if (AIEML) { - unsigned width = getElementSizeInBits(resType); - if (width == 32) { + if (unsigned width = getElementSizeInBits(resType); width == 32) os << "srs"; - } else if (width == 16) { + else if (width == 16) os << "to_v16bfloat16"; - } os << "("; os << emitter.getOrCreateName(source); os << ")"; - } else { + } else os << emitter.getOrCreateName(source); - } + return success(); } @@ -748,7 +737,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SRSOp srsOp) { unsigned resultWidth = getElementSizeInBits(accType); unsigned resWidth = getElementSizeInBits(resType); unsigned srcWidth = 0; - if (auto iType = eltType.dyn_cast()) + if (auto iType = llvm::dyn_cast(eltType)) srcWidth = iType.getWidth(); // Based on the datatypes, generate srs version @@ -758,21 +747,20 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SRSOp srsOp) { else if (srcWidth == 48 && resultWidth == 8) os << "b"; - if (AIEML) { + if (AIEML) os << "srs_to_v" << std::to_string(lanes) << "int" << std::to_string(resWidth); - } else { + else os << "srs"; - } os << "("; os << emitter.getOrCreateName(source); os << ", "; - if (srsOp.getShift().getType().cast().getWidth() != 32) { + if (srsOp.getShift().getType().cast().getWidth() != 32) os << "(int32_t)"; - } os << emitter.getOrCreateName(shift); os << ")"; + return success(); } @@ -798,6 +786,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ", "; os << std::to_string(idx); os << ")"; + return success(); } @@ -806,8 +795,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::BroadcastScalarOp broadcastScalarOp) { auto source = broadcastScalarOp.getSource(); - VectorType resType = - broadcastScalarOp.getResult().getType().cast(); + auto resType = broadcastScalarOp.getResult().getType().cast(); unsigned width = getElementSizeInBits(resType); unsigned lanes = getVectorLaneSize(resType); raw_indented_ostream &os = emitter.ostream(); @@ -821,12 +809,12 @@ printOperation(CppEmitter &emitter, if (eltType.isa()) { os << lanes << "int"; os << width; - } else if (width == 16) { + } else if (width == 16) os << lanes << "bfloat16"; - } else { + else os << lanes << "float"; - } os << "(" << emitter.getOrCreateName(source) << ")"; + return success(); } @@ -844,7 +832,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::ExtOp extOp) { if (!emitter.hasValueInScope(source)) return failure(); - VectorType resType = extOp.getResult().getType().cast(); + auto resType = extOp.getResult().getType().cast(); Type eltType = resType.getElementType(); unsigned lanes = getVectorLaneSize(resType); unsigned resWidth = getElementSizeInBits(resType); @@ -852,13 +840,12 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::ExtOp extOp) { // Print the version of ext for aie-ml if (AIEML) { os << "extract_v" << std::to_string(lanes); - if (eltType.isa()) { + if (eltType.isa()) os << "int" << std::to_string(resWidth); - } else if (resWidth == 16) { + else if (resWidth == 16) os << "bfloat16"; - } else { + else os << "float"; - } } else { // Print the version of ext for aie1 int32_t vecSizeInBits = getVectorSizeInBits(resType); @@ -903,6 +890,7 @@ static LogicalResult printOperation(CppEmitter &emitter, first = false; } os << ")"; + return success(); } @@ -934,6 +922,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); os << emitter.getOrCreateName(shift); os << ")"; + return success(); } @@ -959,6 +948,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ", "; os << std::to_string(mode); os << ")"; + return success(); } @@ -975,7 +965,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); // Determine if we want to geneate select32, or select16, or select8 - VectorType xbuffType = selectOp.getXbuff().getType().cast(); + auto xbuffType = selectOp.getXbuff().getType().cast(); int32_t elementSizeInBits = getElementSizeInBits(xbuffType); assert(elementSizeInBits == 16 || elementSizeInBits == 32 || elementSizeInBits == 64); @@ -1021,8 +1011,8 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ", " << selectOp.getYoffsetsHi(); if (!selectOp.getYsquare().empty()) os << ", " << selectOp.getYsquare(); - os << ")"; + return success(); } @@ -1038,7 +1028,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); // Determine the flavor of result - VectorType sourceType = packOp.getSource().getType().cast(); + auto sourceType = packOp.getSource().getType().cast(); Type scalarType = sourceType.getElementType(); os << (scalarType.isUnsignedInteger() ? "upack" : "pack"); os << "("; @@ -1047,6 +1037,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); os << emitter.getOrCreateName(source); os << ")"; + return success(); } @@ -1097,7 +1088,6 @@ static LogicalResult printMinMaxOperand(CppEmitter &emitter, T op, return failure(); raw_indented_ostream &os = emitter.ostream(); - os << emitter.getOrCreateName(operand); return success(); @@ -1117,7 +1107,6 @@ static LogicalResult printAddElemOrSubElemOperand(CppEmitter &emitter, T op, return failure(); raw_indented_ostream &os = emitter.ostream(); - os << emitter.getOrCreateName(operand); return success(); @@ -1174,13 +1163,9 @@ static LogicalResult printFMAOrMulElemOperand(CppEmitter &emitter, T op, return failure(); raw_indented_ostream &os = emitter.ostream(); - os << emitter.getOrCreateName(operand); - - if (size == 32 && iType) { - StringRef str = opNum == 0 ? "undef_v16int32()" : "broadcast_zero_s32()"; - os << ", " << str; - } + if (size == 32 && iType) + os << ", " << (opNum == 0 ? "undef_v16int32()" : "broadcast_zero_s32()"); return success(); } @@ -1188,7 +1173,6 @@ static LogicalResult printFMAOrMulElemOperand(CppEmitter &emitter, T op, // Print lhs or rhs operand of mul_conv/mac_conv intrinsic template static LogicalResult printFMAOrMulConvOperand(CppEmitter &emitter, T op, - int32_t M, int32_t N, unsigned opNum) { // We currently only support printing operands 0 and 1 if (opNum > 1) @@ -1200,7 +1184,6 @@ static LogicalResult printFMAOrMulConvOperand(CppEmitter &emitter, T op, return failure(); raw_indented_ostream &os = emitter.ostream(); - os << emitter.getOrCreateName(operand); return success(); @@ -1220,15 +1203,16 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::MulOp mulOp) { std::string opname; // Create opname based on the result type - VectorType resType = mulOp.getResult().getType().cast(); + auto resType = mulOp.getResult().getType().cast(); Type eltType = resType.getElementType(); if (!simpleScheme) { - if (auto iType = eltType.dyn_cast()) { + if (auto iType = llvm::dyn_cast(eltType)) { if (iType.getWidth() == 80) opname = "l"; } else if (eltType.isa()) opname = "fp"; } + opname += "mul"; if (!simpleScheme && !eltType.isa()) opname += std::to_string(getVectorLaneSize(resType)); @@ -1253,52 +1237,49 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::MulOp mulOp) { // Generate the MulElem op static LogicalResult printOperation(CppEmitter &emitter, - aievec::MulElemOp mul_elemOp) { - auto lhs = mul_elemOp.getLhs(); - auto rhs = mul_elemOp.getRhs(); + aievec::MulElemOp mulElemOp) { + auto lhs = mulElemOp.getLhs(); + auto rhs = mulElemOp.getRhs(); // The sources should have already been emitted if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs)) return failure(); - std::string opname; - opname = "mul_elem"; + std::string opname = "mul_elem"; // Create opname based on the source type - VectorType lhsType = mul_elemOp.getLhs().getType().cast(); + auto lhsType = mulElemOp.getLhs().getType().cast(); Type eltType = lhsType.getElementType(); int32_t lsize = getElementSizeInBits(lhsType); - auto iType = eltType.dyn_cast(); + auto iType = llvm::dyn_cast(eltType); if (iType) { - if (lsize == 32) { + if (lsize == 32) opname += "_16_2"; - } else if (lsize == 16) { + else if (lsize == 16) opname += "_32"; - } else if (lsize == 8) { + else if (lsize == 8) opname += "_32_2"; - } } else if (eltType.isa()) { - if (lsize == 32) { + if (lsize == 32) opname += "_16"; - } else if (lsize == 16) { + else if (lsize == 16) opname += "_16_2"; - } } raw_indented_ostream &os = emitter.ostream(); // Generate the initialization for the accumulator - if (failed(emitter.emitAssignPrefix(*mul_elemOp, true /*isAcc*/))) + if (failed(emitter.emitAssignPrefix(*mulElemOp, true /*isAcc*/))) return failure(); os << opname; os << "("; - if (failed(printFMAOrMulElemOperand(emitter, mul_elemOp, + if (failed(printFMAOrMulElemOperand(emitter, mulElemOp, iType, lsize, 1))) return failure(); os << ", "; - if (failed(printFMAOrMulElemOperand(emitter, mul_elemOp, + if (failed(printFMAOrMulElemOperand(emitter, mulElemOp, iType, lsize, 0))) return failure(); os << ")"; @@ -1308,47 +1289,45 @@ static LogicalResult printOperation(CppEmitter &emitter, // Generate the MulConv op static LogicalResult printOperation(CppEmitter &emitter, - aievec::MulConvOp mul_convOp) { - auto lhs = mul_convOp.getLhs(); - auto rhs = mul_convOp.getRhs(); + aievec::MulConvOp mulConvOp) { + auto lhs = mulConvOp.getLhs(); + auto rhs = mulConvOp.getRhs(); // The sources should have already been emitted if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs)) return failure(); - std::string opname; - opname = "mul_conv"; - // Create opname based on the source type - VectorType lhsType = mul_convOp.getLhs().getType().cast(); + auto lhsType = mulConvOp.getLhs().getType().cast(); Type eltType = lhsType.getElementType(); int32_t lsize = getElementSizeInBits(lhsType); - auto iType = eltType.dyn_cast(); + auto iType = llvm::dyn_cast(eltType); // Only support int16 and int8 cases if (!iType || !(lsize == 16 || lsize == 8)) { return failure(); } - int32_t M = mul_convOp.getM(); - int32_t N = mul_convOp.getN(); - opname += ("_" + std::to_string(M) + "x" + std::to_string(N)); + int32_t M = mulConvOp.getM(); + int32_t N = mulConvOp.getN(); + std::string opname = + "mul_conv_" + std::to_string(M) + "x" + std::to_string(N); raw_indented_ostream &os = emitter.ostream(); // Generate the initialization for the accumulator - if (failed(emitter.emitAssignPrefix(*mul_convOp, true /*isAcc*/))) + if (failed(emitter.emitAssignPrefix(*mulConvOp, true /*isAcc*/))) return failure(); os << opname; os << "("; - if (failed(printFMAOrMulConvOperand(emitter, mul_convOp, M, - N, 0))) + if (failed( + printFMAOrMulConvOperand(emitter, mulConvOp, 0))) return failure(); os << ", "; - if (failed(printFMAOrMulConvOperand(emitter, mul_convOp, M, - N, 1))) + if (failed( + printFMAOrMulConvOperand(emitter, mulConvOp, 1))) return failure(); os << ")"; @@ -1371,15 +1350,14 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::AddOp addOp) { return failure(); // Get the scalar type of result vector - VectorType resultType = addOp.getResult().getType().cast(); + auto resultType = addOp.getResult().getType().cast(); unsigned lanes = getVectorLaneSize(resultType); Type elementType = resultType.getElementType(); bool floatType = elementType.isa(); // Detemine if the add scheme is simple or complex - bool simpleScheme = addOp.getStart(0).empty(); - if (simpleScheme) { + if (addOp.getStart(0).empty()) { // Handle float type operation if (floatType) { os << "fpadd"; @@ -1406,6 +1384,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::AddOp addOp) { if (failed(printAddOrSubOperand(emitter, addOp, 1))) return failure(); os << ")"; + return success(); } @@ -1425,15 +1404,14 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SubOp subOp) { return failure(); // Get the scalar type of result vector - VectorType resultType = subOp.getResult().getType().cast(); + auto resultType = subOp.getResult().getType().cast(); unsigned lanes = getVectorLaneSize(resultType); Type elementType = resultType.getElementType(); bool floatType = elementType.isa(); // Detemine if the sub scheme is simple or complex - bool simpleScheme = subOp.getStart(0).empty(); - if (simpleScheme) { + if (subOp.getStart(0).empty()) { // Handle float type operation if (floatType) { os << "fpsub"; @@ -1460,6 +1438,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SubOp subOp) { if (failed(printAddOrSubOperand(emitter, subOp, 1))) return failure(); os << ")"; + return success(); } @@ -1485,6 +1464,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::MinOp minOp) { if (failed(printMinMaxOperand(emitter, minOp, 1))) return failure(); os << ")"; + return success(); } @@ -1510,6 +1490,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::MaxOp maxOp) { if (failed(printMinMaxOperand(emitter, maxOp, 1))) return failure(); os << ")"; + return success(); } @@ -1530,6 +1511,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::NegOp negOp) { os << "neg("; os << emitter.getOrCreateName(src); os << ")"; + return success(); } @@ -1551,6 +1533,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << "bneg("; os << emitter.getOrCreateName(src); os << ")"; + return success(); } @@ -1574,6 +1557,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::BxorOp xorOp) { os << ", "; os << emitter.getOrCreateName(rhs); os << ")"; + return success(); } @@ -1597,6 +1581,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::BandOp andOp) { os << ", "; os << emitter.getOrCreateName(rhs); os << ")"; + return success(); } @@ -1620,14 +1605,15 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::BorOp orOp) { os << ", "; os << emitter.getOrCreateName(rhs); os << ")"; + return success(); } // Generate the AddElem op static LogicalResult printOperation(CppEmitter &emitter, - aievec::AddElemOp add_elemOp) { - auto lhs = add_elemOp.getLhs(); - auto rhs = add_elemOp.getRhs(); + aievec::AddElemOp addElemOp) { + auto lhs = addElemOp.getLhs(); + auto rhs = addElemOp.getRhs(); // The sources should have already been emitted if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs)) @@ -1638,33 +1624,34 @@ static LogicalResult printOperation(CppEmitter &emitter, // Generate the initialization for the result // FIXME: move the logic to the op creation and add isAcc to the op attribute bool isAcc = false; - VectorType resType = cast(add_elemOp.getResult().getType()); + auto resType = cast(addElemOp.getResult().getType()); auto resElemType = resType.getElementType(); unsigned resBitWidth = resElemType.getIntOrFloatBitWidth(); unsigned resLaneSize = getVectorLaneSize(resType); - if (isa(resElemType) || (resBitWidth * resLaneSize == 1024)) + if (isa(resElemType) || resBitWidth * resLaneSize == 1024) isAcc = true; - if (failed(emitter.emitAssignPrefix(*add_elemOp, /*isAcc=*/isAcc))) + if (failed(emitter.emitAssignPrefix(*addElemOp, /*isAcc=*/isAcc))) return failure(); os << "add("; - if (failed(printAddElemOrSubElemOperand(emitter, - add_elemOp, 0))) + if (failed(printAddElemOrSubElemOperand(emitter, addElemOp, + 0))) return failure(); os << ", "; - if (failed(printAddElemOrSubElemOperand(emitter, - add_elemOp, 1))) + if (failed(printAddElemOrSubElemOperand(emitter, addElemOp, + 1))) return failure(); os << ")"; + return success(); } // Generate the SubElem op static LogicalResult printOperation(CppEmitter &emitter, - aievec::SubElemOp sub_elemOp) { - auto lhs = sub_elemOp.getLhs(); - auto rhs = sub_elemOp.getRhs(); + aievec::SubElemOp subElemOp) { + auto lhs = subElemOp.getLhs(); + auto rhs = subElemOp.getRhs(); // The sources should have already been emitted if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs)) @@ -1675,25 +1662,26 @@ static LogicalResult printOperation(CppEmitter &emitter, // Generate the initialization for the result // FIXME: move the logic to the op creation and add isAcc to the op attribute bool isAcc = false; - VectorType resType = cast(sub_elemOp.getResult().getType()); + auto resType = cast(subElemOp.getResult().getType()); auto resElemType = resType.getElementType(); unsigned resBitWidth = resElemType.getIntOrFloatBitWidth(); unsigned resLaneSize = getVectorLaneSize(resType); - if (isa(resElemType) || (resBitWidth * resLaneSize == 1024)) + if (isa(resElemType) || resBitWidth * resLaneSize == 1024) isAcc = true; - if (failed(emitter.emitAssignPrefix(*sub_elemOp, /*isAcc=*/isAcc))) + if (failed(emitter.emitAssignPrefix(*subElemOp, /*isAcc=*/isAcc))) return failure(); os << "sub("; - if (failed(printAddElemOrSubElemOperand(emitter, - sub_elemOp, 0))) + if (failed(printAddElemOrSubElemOperand(emitter, subElemOp, + 0))) return failure(); os << ", "; - if (failed(printAddElemOrSubElemOperand(emitter, - sub_elemOp, 1))) + if (failed(printAddElemOrSubElemOperand(emitter, subElemOp, + 1))) return failure(); os << ")"; + return success(); } @@ -1713,15 +1701,16 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::FMAOp fmaOp) { std::string opname; // Create opname based on the result type - VectorType resType = fmaOp.getResult().getType().cast(); + auto resType = fmaOp.getResult().getType().cast(); Type eltType = resType.getElementType(); if (!simpleScheme) { - if (auto iType = eltType.dyn_cast()) { + if (auto iType = llvm::dyn_cast(eltType)) { if (iType.getWidth() == 80) opname = "l"; } else if (eltType.isa()) opname = "fp"; } + opname += fmaOp.getFmsub() ? "msc" : "mac"; if (!simpleScheme && !eltType.isa()) opname += std::to_string(getVectorLaneSize(resType)); @@ -1750,38 +1739,35 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::FMAOp fmaOp) { // Generate the FMAElem op static LogicalResult printOperation(CppEmitter &emitter, - aievec::FMAElemOp fma_elemOp) { - auto acc = fma_elemOp.getAcc(); - auto lhs = fma_elemOp.getLhs(); - auto rhs = fma_elemOp.getRhs(); + aievec::FMAElemOp fmaElemOp) { + auto acc = fmaElemOp.getAcc(); + auto lhs = fmaElemOp.getLhs(); + auto rhs = fmaElemOp.getRhs(); // The sources should have already been emitted if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs)) return failure(); - std::string opname; - opname = fma_elemOp.getFmsub() ? "msc_elem" : "mac_elem"; + std::string opname = fmaElemOp.getFmsub() ? "msc_elem" : "mac_elem"; // Create opname based on the lhs and rhs type - VectorType lhsType = fma_elemOp.getLhs().getType().cast(); + auto lhsType = fmaElemOp.getLhs().getType().cast(); Type eltType = lhsType.getElementType(); int32_t lsize = getElementSizeInBits(lhsType); - auto iType = eltType.dyn_cast(); + auto iType = llvm::dyn_cast(eltType); if (iType) { - if (lsize == 32) { + if (lsize == 32) opname += "_16_2"; - } else if (lsize == 16) { + else if (lsize == 16) opname += "_32"; - } else if (lsize == 8) { + else if (lsize == 8) opname += "_32_2"; - } } else if (eltType.isa()) { - if (lsize == 32) { + if (lsize == 32) opname += "_16"; - } else if (lsize == 16) { + else if (lsize == 16) opname += "_16_2"; - } } raw_indented_ostream &os = emitter.ostream(); @@ -1791,11 +1777,11 @@ static LogicalResult printOperation(CppEmitter &emitter, os << " = "; os << opname; os << "("; - if (failed(printFMAOrMulElemOperand(emitter, fma_elemOp, + if (failed(printFMAOrMulElemOperand(emitter, fmaElemOp, iType, lsize, 1))) return failure(); os << ", "; - if (failed(printFMAOrMulElemOperand(emitter, fma_elemOp, + if (failed(printFMAOrMulElemOperand(emitter, fmaElemOp, iType, lsize, 0))) return failure(); os << ", "; @@ -1803,39 +1789,37 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ")"; // Finally, set the name of the result to the accumulator's name - emitter.setName(fma_elemOp->getResult(0), accName); + emitter.setName(fmaElemOp->getResult(0), accName); return success(); } // Generate the FMAConv op static LogicalResult printOperation(CppEmitter &emitter, - aievec::FMAConvOp fma_convOp) { - auto acc = fma_convOp.getAcc(); - auto lhs = fma_convOp.getLhs(); - auto rhs = fma_convOp.getRhs(); + aievec::FMAConvOp fmaConvOp) { + auto acc = fmaConvOp.getAcc(); + auto lhs = fmaConvOp.getLhs(); + auto rhs = fmaConvOp.getRhs(); // The sources should have already been emitted if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs)) return failure(); - std::string opname; - opname = fma_convOp.getFmsub() ? "msc_conv" : "mac_conv"; + std::string opname = fmaConvOp.getFmsub() ? "msc_conv" : "mac_conv"; // Create opname based on the lhs and rhs type - VectorType lhsType = fma_convOp.getLhs().getType().cast(); + auto lhsType = fmaConvOp.getLhs().getType().cast(); Type eltType = lhsType.getElementType(); int32_t lsize = getElementSizeInBits(lhsType); - auto iType = eltType.dyn_cast(); + auto iType = llvm::dyn_cast(eltType); // Only support int16 and int8 cases - if (!iType || !(lsize == 16 || lsize == 8)) { + if (!iType || !(lsize == 16 || lsize == 8)) return failure(); - } - int32_t M = fma_convOp.getM(); - int32_t N = fma_convOp.getN(); - opname += ("_" + std::to_string(M) + "x" + std::to_string(N)); + int32_t M = fmaConvOp.getM(); + int32_t N = fmaConvOp.getN(); + opname += "_" + std::to_string(M) + "x" + std::to_string(N); raw_indented_ostream &os = emitter.ostream(); @@ -1844,28 +1828,27 @@ static LogicalResult printOperation(CppEmitter &emitter, os << " = "; os << opname; os << "("; - if (failed(printFMAOrMulConvOperand(emitter, fma_convOp, M, - N, 0))) + if (failed( + printFMAOrMulConvOperand(emitter, fmaConvOp, 0))) return failure(); os << ", "; - if (failed(printFMAOrMulConvOperand(emitter, fma_convOp, M, - N, 1))) + if (failed( + printFMAOrMulConvOperand(emitter, fmaConvOp, 1))) return failure(); os << ", "; os << accName; os << ")"; // Finally, set the name of the result to the accumulator's name - emitter.setName(fma_convOp->getResult(0), accName); + emitter.setName(fmaConvOp->getResult(0), accName); return success(); } // Generate the comparison intrinsics(eq, ne, lt, le, gt, ge) for AIE-ML static LogicalResult printOperation(CppEmitter &emitter, aievec::CmpOp cmpOp) { - if (!AIEML) { + if (!AIEML) return failure(); - } // The lhs and rhs should have already been emitted Value lhs = cmpOp.getLhs(); @@ -1881,27 +1864,26 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::CmpOp cmpOp) { raw_indented_ostream &os = emitter.ostream(); StringRef pred = cmpOp.getPred(); - if (pred == "eq") { + if (pred == "eq") os << "eq"; - } else if (pred == "ne") { + else if (pred == "ne") os << "ne"; - } else if (pred == "slt" || pred == "ult") { + else if (pred == "slt" || pred == "ult") os << "lt"; - } else if (pred == "sle" || pred == "ule") { + else if (pred == "sle" || pred == "ule") os << "le"; - } else if (pred == "sgt" || pred == "ugt") { + else if (pred == "sgt" || pred == "ugt") os << "gt"; - } else if (pred == "sge" || pred == "uge") { + else if (pred == "sge" || pred == "uge") os << "ge"; - } else { + else return failure(); - } os << "("; - VectorType vType = lhs.getType().cast(); - Type eltType = vType.getElementType(); + auto vType = lhs.getType().cast(); - if (eltType.isa() && + if (Type eltType = vType.getElementType(); + eltType.isa() && (pred == "ult" || pred == "ule" || pred == "ugt" || pred == "uge")) { unsigned lanes = getVectorLaneSize(vType); unsigned width = getElementSizeInBits(vType); @@ -1919,14 +1901,14 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::CmpOp cmpOp) { os << emitter.getOrCreateName(rhs); } os << ")"; + return success(); } // Generate the sel intrinsic for AIE-ML static LogicalResult printOperation(CppEmitter &emitter, aievec::SelOp selOp) { - if (!AIEML) { + if (!AIEML) return failure(); - } // The lhs, rhs and sel should have already been emitted Value lhs = selOp.getLhs(); @@ -1950,6 +1932,7 @@ static LogicalResult printOperation(CppEmitter &emitter, aievec::SelOp selOp) { os << ", "; os << emitter.getOrCreateName(sel); os << ")"; + return success(); } @@ -1976,6 +1959,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ", "; os << emitter.getOrCreateName(index); os << ")"; + return success(); } @@ -2009,6 +1993,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ")"; os << " = "; os << emitter.getOrCreateName(vector); + return success(); } @@ -2034,6 +2019,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << emitter.getOrCreateName(memref); os << " = "; os << emitter.getOrCreateName(value); + return success(); } @@ -2051,7 +2037,6 @@ static LogicalResult printValueForwardOperation(CppEmitter &emitter, OpTy op) { return failure(); raw_indented_ostream &os = emitter.ostream(); - os << emitter.getOrCreateName(source); return success(); @@ -2079,10 +2064,9 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, // the FuncOp. if (emitter.shouldDeclareVariablesAtTop()) { // Skip the assignment if the emitc.constant has no value. - if (auto oAttr = value.dyn_cast()) { + if (auto oAttr = llvm::dyn_cast(value)) if (oAttr.getValue().empty()) return success(); - } if (failed(emitter.emitVariableAssignment(result))) return failure(); @@ -2090,12 +2074,11 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, } // Emit a variable declaration for an emitc.constant op without value. - if (auto oAttr = value.dyn_cast()) { + if (auto oAttr = llvm::dyn_cast(value)) if (oAttr.getValue().empty()) // The semicolon gets printed by the emitOperation function. return emitter.emitVariableDeclaration(result, /*trailingSemicolon=*/false); - } // Emit a variable declaration. if (failed(emitter.emitAssignPrefix(*operation))) @@ -2107,7 +2090,6 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); Attribute value = constantOp.getValue(); - return printConstantOp(emitter, operation, value); } @@ -2115,7 +2097,6 @@ static LogicalResult printOperation(CppEmitter &emitter, arith::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); Attribute value = constantOp.getValue(); - return printConstantOp(emitter, operation, value); } @@ -2124,8 +2105,7 @@ static LogicalResult printOperation(CppEmitter &emitter, raw_ostream &os = emitter.ostream(); Block &successor = *branchOp.getSuccessor(); - for (auto pair : - llvm::zip(branchOp.getOperands(), successor.getArguments())) { + for (auto pair : zip(branchOp.getOperands(), successor.getArguments())) { Value &operand = std::get<0>(pair); BlockArgument &argument = std::get<1>(pair); os << emitter.getOrCreateName(argument) << " = " @@ -2133,7 +2113,7 @@ static LogicalResult printOperation(CppEmitter &emitter, } os << "goto "; - if (!(emitter.hasBlockLabel(successor))) + if (!emitter.hasBlockLabel(successor)) return branchOp.emitOpError("unable to find label for successor block"); os << emitter.getOrCreateName(successor); return success(); @@ -2151,8 +2131,8 @@ static LogicalResult printOperation(CppEmitter &emitter, os.indent(); // If condition is true. - for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), - trueSuccessor.getArguments())) { + for (auto pair : + zip(condBranchOp.getTrueOperands(), trueSuccessor.getArguments())) { Value &operand = std::get<0>(pair); BlockArgument &argument = std::get<1>(pair); os << emitter.getOrCreateName(argument) << " = " @@ -2160,15 +2140,14 @@ static LogicalResult printOperation(CppEmitter &emitter, } os << "goto "; - if (!(emitter.hasBlockLabel(trueSuccessor))) { + if (!emitter.hasBlockLabel(trueSuccessor)) return condBranchOp.emitOpError("unable to find label for successor block"); - } os << emitter.getOrCreateName(trueSuccessor) << ";\n"; os.unindent() << "} else {\n"; os.indent(); // If condition is false. - for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), - falseSuccessor.getArguments())) { + for (auto pair : + zip(condBranchOp.getFalseOperands(), falseSuccessor.getArguments())) { Value &operand = std::get<0>(pair); BlockArgument &argument = std::get<1>(pair); os << emitter.getOrCreateName(argument) << " = " @@ -2176,12 +2155,12 @@ static LogicalResult printOperation(CppEmitter &emitter, } os << "goto "; - if (!(emitter.hasBlockLabel(falseSuccessor))) { + if (!emitter.hasBlockLabel(falseSuccessor)) return condBranchOp.emitOpError() << "unable to find label for successor block"; - } os << emitter.getOrCreateName(falseSuccessor) << ";\n"; os.unindent() << "}"; + return success(); } @@ -2194,10 +2173,12 @@ static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { if (failed(emitter.emitOperands(*callOp.getOperation()))) return failure(); os << ")"; + return success(); } -static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { +static LogicalResult printOperation(CppEmitter &emitter, + emitc::CallOpaqueOp callOp) { raw_ostream &os = emitter.ostream(); Operation &op = *callOp.getOperation(); if (callOp.getCallee() == "getTanhBf16" || @@ -2209,18 +2190,17 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { callOp.getCallee() == "getFloorBf16") { if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ false))) return failure(); - } else { - if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true))) - return failure(); - } + } else if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true))) + return failure(); + os << callOp.getCallee(); auto emitArgs = [&](Attribute attr) -> LogicalResult { - if (auto t = attr.dyn_cast()) { - // Index attributes are treated specially as operand index. + // Index attributes are treated specially as operand index. + if (auto t = llvm::dyn_cast(attr)) if (t.getType().isIndex()) { int64_t idx = t.getInt(); - if ((idx < 0) || (idx >= op.getNumOperands())) + if (idx < 0 || idx >= op.getNumOperands()) return op.emitOpError("invalid operand index"); if (!emitter.hasValueInScope(op.getOperand(idx))) return op.emitOpError("operand ") @@ -2228,7 +2208,6 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { os << emitter.getOrCreateName(op.getOperand(idx)); return success(); } - } if (failed(emitter.emitAttribute(op.getLoc(), attr))) return failure(); @@ -2252,15 +2231,16 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { if (failed(emittedArgs)) return failure(); os << ")"; + return success(); } static LogicalResult printOperation(CppEmitter &emitter, emitc::ApplyOp applyOp) { raw_ostream &os = emitter.ostream(); - Operation &op = *applyOp.getOperation(); - if (failed(emitter.emitAssignPrefix(op))) + if (Operation &op = *applyOp.getOperation(); + failed(emitter.emitAssignPrefix(op))) return failure(); os << applyOp.getApplicableOperator(); os << emitter.getOrCreateName(applyOp.getOperand()); @@ -2282,22 +2262,19 @@ static LogicalResult printOperation(CppEmitter &emitter, } static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { - raw_indented_ostream &os = emitter.ostream(); OperandRange operands = forOp.getInitArgs(); Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); Operation::result_range results = forOp.getResults(); - if (!emitter.shouldDeclareVariablesAtTop()) { - for (OpResult result : results) { + if (!emitter.shouldDeclareVariablesAtTop()) + for (OpResult result : results) if (failed(emitter.emitVariableDeclaration(result, /*trailingSemicolon=*/true))) return failure(); - } - } - for (auto pair : llvm::zip(iterArgs, operands)) { + for (auto pair : zip(iterArgs, operands)) { if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) return failure(); os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; @@ -2309,6 +2286,7 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { if (failed( emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) return failure(); + os << " "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " = "; @@ -2325,17 +2303,15 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { os << "chess_prepare_for_pipelining\n"; // Try to find the upper bound and step of the for operator. // If the bounds are found, print them - auto tc = getTripCount(forOp); - if (tc.first) { - auto step = getStep(forOp); - int64_t lb = - step.first && step.second > 0 ? floorDiv(tc.second, step.second) : 1; - int64_t ub = - step.first && step.second > 0 ? ceilDiv(tc.second, step.second) : 0; + if (auto [constantLoopBound, tripCount] = getTripCount(forOp); + constantLoopBound) { + auto [constantStep, step] = getStep(forOp); + int64_t lb = constantStep && step > 0 ? floorDiv(tripCount, step) : 1; + int64_t ub = constantStep && step > 0 ? ceilDiv(tripCount, step) : 0; os << "chess_loop_range("; os << std::to_string(lb); os << ", "; - if (step.first && step.second > 0) + if (constantStep && step > 0) os << std::to_string(ub); os << ")\n"; } @@ -2350,14 +2326,15 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { // the end of a loop iteration and set the result variables after the for // loop. for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { - bool trailingSemicolon = !isa(*it); - if (failed(emitter.emitOperation(*it, trailingSemicolon))) + if (bool trailingSemicolon = + !isa(*it); + failed(emitter.emitOperation(*it, trailingSemicolon))) return failure(); } Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); // Copy yield operands into iterArgs at the end of a loop iteration. - for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { + for (auto pair : zip(iterArgs, yieldOp->getOperands())) { BlockArgument iterArg = std::get<0>(pair); Value operand = std::get<1>(pair); os << emitter.getOrCreateName(iterArg) << " = " @@ -2367,7 +2344,7 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { os.unindent() << "}"; // Copy iterArgs into results after the for loop. - for (auto pair : llvm::zip(results, iterArgs)) { + for (auto pair : zip(results, iterArgs)) { OpResult result = std::get<0>(pair); BlockArgument iterArg = std::get<1>(pair); os << "\n" @@ -2381,13 +2358,11 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { raw_indented_ostream &os = emitter.ostream(); - if (!emitter.shouldDeclareVariablesAtTop()) { - for (OpResult result : ifOp.getResults()) { + if (!emitter.shouldDeclareVariablesAtTop()) + for (OpResult result : ifOp.getResults()) if (failed(emitter.emitVariableDeclaration(result, /*trailingSemicolon=*/true))) return failure(); - } - } os << "if ("; if (failed(emitter.emitOperands(*ifOp.getOperation()))) @@ -2396,26 +2371,23 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { os.indent(); Region &thenRegion = ifOp.getThenRegion(); - for (Operation &op : thenRegion.getOps()) { - // Note: This prints a superfluous semicolon if the terminating yield op has - // zero results. + // Note: This prints a superfluous semicolon if the terminating yield op has + // zero results. + for (Operation &op : thenRegion.getOps()) if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) return failure(); - } os.unindent() << "}"; - Region &elseRegion = ifOp.getElseRegion(); - if (!elseRegion.empty()) { + if (Region &elseRegion = ifOp.getElseRegion(); !elseRegion.empty()) { os << " else {\n"; os.indent(); - for (Operation &op : elseRegion.getOps()) { - // Note: This prints a superfluous semicolon if the terminating yield op - // has zero results. + // Note: This prints a superfluous semicolon if the terminating yield op + // has zero results. + for (Operation &op : elseRegion.getOps()) if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) return failure(); - } os.unindent() << "}"; } @@ -2427,10 +2399,9 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { raw_ostream &os = emitter.ostream(); Operation &parentOp = *yieldOp.getOperation()->getParentOp(); - if (yieldOp.getNumOperands() != parentOp.getNumResults()) { + if (yieldOp.getNumOperands() != parentOp.getNumResults()) return yieldOp.emitError("number of operands does not to match the number " "of the parent op's results"); - } if (failed(interleaveWithError( llvm::zip(parentOp.getResults(), yieldOp.getOperands()), @@ -2444,7 +2415,7 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { os << emitter.getOrCreateName(operand); return success(); }, - [&]() { os << ";\n"; }))) + [&] { os << ";\n"; }))) return failure(); return success(); @@ -2465,17 +2436,18 @@ static LogicalResult printOperation(CppEmitter &emitter, if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) return failure(); os << ")"; - return success(); } + + return success(); } static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { CppEmitter::Scope scope(emitter); - for (Operation &op : moduleOp) { + for (Operation &op : moduleOp) if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) return failure(); - } + return success(); } @@ -2483,10 +2455,10 @@ static LogicalResult printOperation(CppEmitter &emitter, func::FuncOp functionOp) { // We need to declare variables at top if the function has multiple blocks. if (!emitter.shouldDeclareVariablesAtTop() && - functionOp.getBlocks().size() > 1) { + functionOp.getBlocks().size() > 1) return functionOp.emitOpError( "with multiple blocks needs variables declared at top"); - } + CppEmitter::Scope scope(emitter); // Find any memref dim op in the function, and parse the dimension of each @@ -2508,8 +2480,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); // If it is a memref argument, we need to check if it has dynamic // shape. If so, the dimensions have to be printed out - MemRefType argType = dyn_cast(type); - if (argType) + if (auto argType = dyn_cast(type)) for (unsigned dim = 0; dim < argType.getRank(); ++dim) if (argType.isDynamicDim(dim)) os << ", size_t"; @@ -2543,10 +2514,9 @@ static LogicalResult printOperation(CppEmitter &emitter, functionOp.walk([&](Operation *op) -> WalkResult { for (OpResult result : op->getResults()) { if (failed(emitter.emitVariableDeclaration( - result, /*trailingSemicolon=*/true))) { - return WalkResult( - op->emitError("unable to declare result variable for op")); - } + result, /*trailingSemicolon=*/true))) + return { + op->emitError("unable to declare result variable for op")}; } return WalkResult::advance(); }); @@ -2556,9 +2526,8 @@ static LogicalResult printOperation(CppEmitter &emitter, Region::BlockListType &blocks = functionOp.getBlocks(); // Create label names for basic blocks. - for (Block &block : blocks) { + for (Block &block : blocks) emitter.getOrCreateName(block); - } // Declare variables for basic block arguments. for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { @@ -2568,33 +2537,31 @@ static LogicalResult printOperation(CppEmitter &emitter, return functionOp.emitOpError(" block argument #") << arg.getArgNumber() << " is out of scope"; if (failed( - emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { + emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) return failure(); - } os << " " << emitter.getOrCreateName(arg) << ";\n"; } } for (Block &block : blocks) { // Only print a label if there is more than one block. - if (blocks.size() > 1) { + if (blocks.size() > 1) if (failed(emitter.emitLabel(block))) return failure(); - } for (Operation &op : block.getOperations()) { // When generating code for an scf.if or std.cond_br op no semicolon needs // to be printed after the closing brace. // When generating code for an scf.for op, printing a trailing semicolon // is handled within the printOperation function. - bool trailingSemicolon = - !isa(op); - - if (failed(emitter.emitOperation( + if (bool trailingSemicolon = + !isa(op); + failed(emitter.emitOperation( op, /*trailingSemicolon=*/trailingSemicolon))) return failure(); } } os.unindent() << "}\n"; + return success(); } @@ -2626,7 +2593,7 @@ std::string CppEmitter::getNewName(std::string prefix) { /// Given a dynamic shaped memref, set its size at position 'index' to // parameter 'result' void CppEmitter::setMemRefDimParam(Value memref, unsigned index, - std::string parameter) { + const std::string ¶meter) { auto p = std::make_pair(memref, index); assert(!paramIndexMapper.count(p) && "memref dimension already set"); paramIndexMapper[p] = parameter; @@ -2643,7 +2610,7 @@ StringRef CppEmitter::getMemRefDimParam(Value memref, unsigned index) { /// associated with it bool CppEmitter::isMemRefDimParam(Value memref, unsigned index) { assert([&] { - MemRefType type = memref.getType().dyn_cast(); + auto type = llvm::dyn_cast(memref.getType()); if (!(type && type.isDynamicDim(index))) { printf("the dimension size at index is not dynamic\n"); return false; @@ -2666,13 +2633,12 @@ StringRef CppEmitter::getOrCreateName(Block &block, std::string prefix) { bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { switch (val) { case IntegerType::Signless: - return false; case IntegerType::Signed: return false; case IntegerType::Unsigned: return true; } - llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); + llvm::report_fatal_error("Unexpected IntegerType::SignednessSemantics"); } bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } @@ -2692,40 +2658,37 @@ static std::string getSplatValueOfIntDense(DenseIntElementsAttr dense) { // Get the first float value of a dense type value as a string. static std::string getSplatValueOfFloatDense(DenseFPElementsAttr dense, bool isBFloat = false) { - APFloat apFloat = dense.getSplatValue(); + auto apFloat = dense.getSplatValue(); float splatVal = apFloat.convertToFloat(); std::string firstValue = std::to_string(splatVal); - if (apFloat.isPosInfinity()) { - if (isBFloat) { + if (apFloat.isPosInfinity()) + if (isBFloat) // TODO: Clean this up; emitting largest finite value in lieu of infinity; // system headers do not provide a simple way to initialize a bfloat16 to // infinity. firstValue = std::to_string(0x1.FEp+127f); - } else { + else firstValue = std::to_string(std::numeric_limits::max()); - } - } else if (apFloat.isNegInfinity()) { - if (isBFloat) { + else if (apFloat.isNegInfinity()) + if (isBFloat) firstValue = std::to_string(-0x1.FEp+127f); - } else { + else firstValue = std::to_string(std::numeric_limits::lowest()); - } - } else if (!apFloat.isNonZero()) { + else if (!apFloat.isNonZero()) firstValue = "0"; - } return firstValue; } LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { auto printInt = [&](const APInt &val, bool isUnsigned) { - if (val.getBitWidth() == 1) { + if (val.getBitWidth() == 1) if (val.getBoolValue()) os << "true"; else os << "false"; - } else { + else { SmallString<128> strValue; val.toString(strValue, 10, !isUnsigned, false); os << strValue; @@ -2746,11 +2709,11 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { break; default: break; - }; + } os << strValue; - } else if (val.isNaN()) { + } else if (val.isNaN()) os << "NAN"; - } else if (val.isInfinity()) { + else if (val.isInfinity()) { if (val.isNegative()) os << "-"; os << "INFINITY"; @@ -2758,22 +2721,23 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { }; // Print floating point attributes. - if (auto fAttr = attr.dyn_cast()) { + if (auto fAttr = llvm::dyn_cast(attr)) { printFloat(fAttr.getValue()); return success(); } - if (auto dense = attr.dyn_cast()) { + + if (auto dense = llvm::dyn_cast(attr)) { if (AIEML && dense.isSplat()) { - if (auto vType = dense.getType().dyn_cast()) { - if (auto fType = vType.getElementType().dyn_cast()) { + if (auto vType = llvm::dyn_cast(dense.getType())) + if (auto fType = llvm::dyn_cast(vType.getElementType())) { unsigned width = fType.getWidth(); - std::string splatValue = ""; - if (width == 32) { + std::string splatValue; + if (width == 32) splatValue = getSplatValueOfFloatDense(dense); - } else if (width == 16) { + else if (width == 16) splatValue = getSplatValueOfFloatDense(dense, /*isBFloat*/ true); - } - if (width == 32 || (width == 16 && getVectorLaneSize(vType) == 32)) { + + if (width == 32 || (width == 16 && getVectorLaneSize(vType) == 32)) if (splatValue == "0") { os << "broadcast_zero_"; if (failed(emitType(loc, fType))) @@ -2790,11 +2754,11 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { os << splatValue; os << ")"; } - } else if (width == 16 && getVectorLaneSize(vType) == 16) { + else if (width == 16 && getVectorLaneSize(vType) == 16) { os << "extract_v16bfloat16("; - if (splatValue == "0") { + if (splatValue == "0") os << "broadcast_zero_bfloat16()"; - } else { + else { os << "broadcast_to_v32bfloat16"; os << "(("; if (failed(emitType(loc, fType))) @@ -2806,7 +2770,6 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { os << ", 0)"; } } - } // TODO: Deal with multiple dense value case for AIEML. } else { os << '{'; @@ -2817,19 +2780,20 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } // Print integer attributes. - if (auto iAttr = attr.dyn_cast()) { - if (auto iType = iAttr.getType().dyn_cast()) { + if (auto iAttr = llvm::dyn_cast(attr)) { + if (auto iType = llvm::dyn_cast(iAttr.getType())) { printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); return success(); } - if (auto iType = iAttr.getType().dyn_cast()) { + if (iAttr.getType().dyn_cast()) { printInt(iAttr.getValue(), false); return success(); } } - if (auto dense = attr.dyn_cast()) { - if (auto tType = dense.getType().dyn_cast()) { - if (auto iType = tType.getElementType().dyn_cast()) { + + if (auto dense = llvm::dyn_cast(attr)) { + if (auto tType = llvm::dyn_cast(dense.getType())) { + if (auto iType = llvm::dyn_cast(tType.getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, shouldMapToUnsigned(iType.getSignedness())); @@ -2837,7 +2801,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { os << '}'; return success(); } - if (auto iType = tType.getElementType().dyn_cast()) { + if (tType.getElementType().dyn_cast()) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, false); }); @@ -2845,8 +2809,9 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return success(); } } - if (auto vType = dense.getType().dyn_cast()) { - if (auto iType = vType.getElementType().dyn_cast()) { + + if (auto vType = llvm::dyn_cast(dense.getType())) { + if (auto iType = llvm::dyn_cast(vType.getElementType())) { unsigned width = iType.getWidth(); if (llvm::all_of(dense, [](const APInt &val) { return val == 0; })) { if (AIEML) { @@ -2867,14 +2832,13 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } if (AIEML && dense.isSplat()) { - std::string splatValue = ""; - if (width == 32) { + std::string splatValue; + if (width == 32) splatValue = getSplatValueOfIntDense(dense); - } else if (width == 16) { + else if (width == 16) splatValue = getSplatValueOfIntDense(dense); - } else if (width == 8) { + else if (width == 8) splatValue = getSplatValueOfIntDense(dense); - } os << "broadcast_to_"; if (failed(emitType(loc, vType))) return failure(); @@ -2894,7 +2858,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } return success(); } - if (auto iType = vType.getElementType().dyn_cast()) { + if (vType.getElementType().dyn_cast()) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, false); }); @@ -2905,13 +2869,13 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } // Print opaque attributes. - if (auto oAttr = attr.dyn_cast()) { + if (auto oAttr = llvm::dyn_cast(attr)) { os << oAttr.getValue(); return success(); } // Print symbolic reference attributes. - if (auto sAttr = attr.dyn_cast()) { + if (auto sAttr = llvm::dyn_cast(attr)) { if (sAttr.getNestedReferences().size() > 1) return emitError(loc, "attribute has more than 1 nested reference"); os << sAttr.getRootReference().getValue(); @@ -2919,7 +2883,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } // Print type attributes. - if (auto type = attr.dyn_cast()) + if (auto type = llvm::dyn_cast(attr)) return emitType(loc, type.getValue()); return emitError(loc, "cannot emit attribute of type ") << attr; @@ -2941,23 +2905,22 @@ CppEmitter::emitOperandsAndAttributes(Operation &op, if (failed(emitOperands(op))) return failure(); // Insert comma in between operands and non-filtered attributes if needed. - if (op.getNumOperands() > 0) { - for (NamedAttribute attr : op.getAttrs()) { - if (!llvm::is_contained(exclude, attr.getName().strref())) { + if (op.getNumOperands() > 0) + for (NamedAttribute attr : op.getAttrs()) + if (!is_contained(exclude, attr.getName().strref())) { os << ", "; break; } - } - } // Emit attributes. auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { - if (llvm::is_contained(exclude, attr.getName().strref())) + if (is_contained(exclude, attr.getName().strref())) return success(); os << "/* " << attr.getName().getValue() << " */"; if (failed(emitAttribute(op.getLoc(), attr.getValue()))) return failure(); return success(); }; + return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); } @@ -2967,22 +2930,23 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { "result variable for the operation has not been declared"); } os << getOrCreateName(result) << " = "; + return success(); } LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, bool trailingSemicolon, bool isAcc) { - if (hasValueInScope(result)) { + if (hasValueInScope(result)) return result.getDefiningOp()->emitError( "result variable for the operation already declared"); - } if (failed( emitType(result.getOwner()->getLoc(), result.getType(), true, isAcc))) return failure(); os << " " << getOrCreateName(result); if (trailingSemicolon) os << ";\n"; + return success(); } @@ -3004,12 +2968,11 @@ LogicalResult CppEmitter::emitAssignPrefix(Operation &op, bool isAcc) { break; } default: - if (!shouldDeclareVariablesAtTop()) { - for (OpResult result : op.getResults()) { + if (!shouldDeclareVariablesAtTop()) + for (OpResult result : op.getResults()) if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) return failure(); - } - } + os << "std::tie("; interleaveComma(op.getResults(), os, [&](Value result) { os << getOrCreateName(result); }); @@ -3034,13 +2997,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { return success(); LogicalResult status = - llvm::TypeSwitch(&op) + TypeSwitch(&op) // EmitC ops. - .Case( + .Case( [&](auto op) { return printOperation(*this, op); }) .Case([&](auto op) { - StringRef name = op.getInclude(); - if (!includeNames.count(name)) { + if (StringRef name = op.getInclude(); !includeNames.count(name)) { includeNames.insert(name); return printOperation(*this, op); } @@ -3067,16 +3029,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case( [&](auto op) { return printOperation(*this, op); }) - .Case< - aievec::AddOp, aievec::AddElemOp, aievec::ConcatOp, aievec::ExtOp, - aievec::FMAOp, aievec::MulOp, aievec::PackOp, aievec::SelectOp, - aievec::SRSOp, aievec::SubOp, aievec::SubElemOp, aievec::UPDOp, - aievec::UPSOp, aievec::FMAElemOp, aievec::MulElemOp, - aievec::BroadcastOp, aievec::BroadcastScalarOp, aievec::MulConvOp, - aievec::FMAConvOp, aievec::ShiftOp, aievec::ShuffleOp, - aievec::CastOp, aievec::MinOp, aievec::MaxOp, aievec::NegOp, - aievec::CmpOp, aievec::SelOp, aievec::ExtElemOp, aievec::BxorOp, - aievec::BnegOp, aievec::BandOp, aievec::BorOp, aievec::UnpackOp>( + .Case( [&](auto op) { return printOperation(*this, op); }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); @@ -3085,48 +3042,49 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (failed(status)) return failure(); os << (trailingSemicolon ? ";\n" : "\n"); + return success(); } LogicalResult CppEmitter::emitType(Location loc, Type type, bool stdintType, bool isAcc) { - if (auto iType = type.dyn_cast()) { + if (auto iType = llvm::dyn_cast(type)) switch (iType.getWidth()) { case 1: - return (os << "bool"), success(); + return os << "bool", success(); case 8: case 16: case 32: - case 64: + case 64: { if (shouldMapToUnsigned(iType.getSignedness())) - return (os << "uint" << iType.getWidth() << (stdintType ? "_t" : "")), - success(); - else - return (os << "int" << iType.getWidth() << (stdintType ? "_t" : "")), + return os << "uint" << iType.getWidth() << (stdintType ? "_t" : ""), success(); + return os << "int" << iType.getWidth() << (stdintType ? "_t" : ""), + success(); + } case 48: case 80: - return (os << "acc" << iType.getWidth()), success(); + return os << "acc" << iType.getWidth(), success(); default: return emitError(loc, "cannot emit integer type ") << type; } - } - if (auto fType = type.dyn_cast()) { + + if (auto fType = llvm::dyn_cast(type)) switch (fType.getWidth()) { case 16: - return (os << "bfloat16"), success(); + return os << "bfloat16", success(); case 32: - return (os << "float"), success(); + return os << "float", success(); case 64: - return (os << "double"), success(); + return os << "double", success(); default: return emitError(loc, "cannot emit float type ") << type; } - } - if (auto iType = type.dyn_cast()) - return (os << "size_t"), success(); - if (auto tType = type.dyn_cast()) { + if (llvm::dyn_cast(type)) + return os << "size_t", success(); + + if (auto tType = llvm::dyn_cast(type)) { if (!tType.hasRank()) return emitError(loc, "cannot emit unranked tensor type"); if (!tType.hasStaticShape()) @@ -3142,22 +3100,22 @@ LogicalResult CppEmitter::emitType(Location loc, Type type, bool stdintType, os << ">"; return success(); } - if (auto tType = type.dyn_cast()) + if (auto tType = llvm::dyn_cast(type)) return emitTupleType(loc, tType.getTypes()); - if (auto oType = type.dyn_cast()) { + if (auto oType = llvm::dyn_cast(type)) { os << oType.getValue(); return success(); } // Types added for AIE // MemRefType: printed as 'eltType'* - if (auto tType = type.dyn_cast()) { + if (auto tType = llvm::dyn_cast(type)) { if (failed(emitType(loc, tType.getElementType()))) return failure(); os << " * restrict"; return success(); } // VectorType: printed as v'lane''eltType' - if (auto tType = type.dyn_cast()) { + if (auto tType = llvm::dyn_cast(type)) { Type eltType = tType.getElementType(); if (tType.getRank() != 1) return failure(); @@ -3169,23 +3127,22 @@ LogicalResult CppEmitter::emitType(Location loc, Type type, bool stdintType, if (eltType.isa()) { // AIE-ML has `ups_to_v16acc32`, `ups_to_v16acc64`, `ups_to_v32acc32` // intrinsics - unsigned width = eltType.cast().getWidth(); - if ((dimSize == 16 && width == 64) || (dimSize == 32 && width == 32) || - (dimSize == 16 && width == 32)) { - return (os << "acc" << width), success(); - } else { - return failure(); - } - } else if (eltType.isa()) { - // AIE-ML only has a `ups_to_v16accfloat` intrinsic - return (os << "accfloat"), success(); + if (unsigned width = eltType.cast().getWidth(); + (dimSize == 16 && width == 64) || (dimSize == 32 && width == 32) || + (dimSize == 16 && width == 32)) + return os << "acc" << width, success(); + return failure(); } + if (eltType.isa()) + // AIE-ML only has a `ups_to_v16accfloat` intrinsic + return os << "accfloat", success(); } if (failed(emitType(loc, eltType, false))) return failure(); return success(); } + return emitError(loc, "cannot emit type ") << type; } diff --git a/test/aievec/conv2d_i16_after_polygeist.mlir b/test/aievec/conv2d_i16_after_polygeist.mlir index 519f881798..fb8d1b6337 100644 --- a/test/aievec/conv2d_i16_after_polygeist.mlir +++ b/test/aievec/conv2d_i16_after_polygeist.mlir @@ -1,6 +1,6 @@ // RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" --aie-vectorize="shift=10 zero-offset=4" -aieml=true -canonicalize -split-input-file | FileCheck %s -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { func.func @conv2d(%arg0: memref, %arg1: memref, %arg2: memref) attributes {llvm.linkage = #llvm.linkage} { affine.for %arg3 = 0 to 16 { affine.for %arg4 = 0 to 256 { diff --git a/test/aievec/conv2d_i16_after_polygeist_2.mlir b/test/aievec/conv2d_i16_after_polygeist_2.mlir index feb832771f..fdb14c028b 100644 --- a/test/aievec/conv2d_i16_after_polygeist_2.mlir +++ b/test/aievec/conv2d_i16_after_polygeist_2.mlir @@ -1,6 +1,6 @@ // RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" --aie-vectorize="shift=10 zero-offset=4" -aieml=true -canonicalize -split-input-file | FileCheck %s -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { func.func @conv2d(%arg0: memref, %arg1: memref, %arg2: memref) attributes {llvm.linkage = #llvm.linkage} { affine.for %arg3 = 0 to 16 { affine.for %arg4 = 0 to 256 { diff --git a/test/aievec/conv2d_i8_after_polygeist.mlir b/test/aievec/conv2d_i8_after_polygeist.mlir index 4dba7d61e0..de4949bff5 100644 --- a/test/aievec/conv2d_i8_after_polygeist.mlir +++ b/test/aievec/conv2d_i8_after_polygeist.mlir @@ -1,6 +1,6 @@ // RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=32" -aieml=true --aie-vectorize="shift=0 dup-factor=2" -canonicalize | FileCheck %s -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { func.func @conv2d(%arg0: memref, %arg1: memref, %arg2: memref) attributes {llvm.linkage = #llvm.linkage} { affine.for %arg3 = 0 to 16 { affine.for %arg4 = 0 to 256 { diff --git a/test/aievec/gemm64_int16_unroll32_after_polygeist.mlir b/test/aievec/gemm64_int16_unroll32_after_polygeist.mlir index fc2b305861..8e78b59b1f 100644 --- a/test/aievec/gemm64_int16_unroll32_after_polygeist.mlir +++ b/test/aievec/gemm64_int16_unroll32_after_polygeist.mlir @@ -1,6 +1,6 @@ // RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=32" -aieml=true --aie-vectorize -canonicalize | FileCheck %s -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "polygeist.target-cpu" = "x86-64", "polygeist.target-features" = "+cx8,+fxsr,+mmx,+sse,+sse2,+x87", "polygeist.tune-cpu" = "generic"} { func.func @matmul(%arg0: memref, %arg1: memref, %arg2: memref) attributes {llvm.linkage = #llvm.linkage} { affine.for %arg3 = 0 to 64 { affine.for %arg4 = 0 to 64 { diff --git a/utils/clone-llvm.sh b/utils/clone-llvm.sh index 455939162f..a650d10efb 100755 --- a/utils/clone-llvm.sh +++ b/utils/clone-llvm.sh @@ -15,7 +15,7 @@ # The LLVM commit to use. # TODO: create a branch or a tag instead, to avoid fetching main and # this commit later. -commithash=e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa +commithash=3422203abccd7ff4bd839011b780b41f8fa8ca8f here=$PWD