From 61ec4b056f8f4f959c7a66d9c2fe1c283ff4a46a Mon Sep 17 00:00:00 2001 From: max Date: Wed, 13 Dec 2023 00:21:09 -0600 Subject: [PATCH] enable aievec python bindings --- include/aie/Dialect/AIEVec/IR/AIEVecOps.h | 1 + include/aie/Dialect/AIEVec/IR/AIEVecOps.td | 4 +- lib/Dialect/AIEVec/IR/AIEVecOps.cpp | 64 ++++++++++++++----- python/CMakeLists.txt | 9 +++ python/dialects/AIEVecBinding.td | 14 +++++ python/dialects/aievec.py | 7 +++ test/python/tosa_aievec.py | 73 ++++++++++++++++++++-- 7 files changed, 152 insertions(+), 20 deletions(-) create mode 100644 python/dialects/AIEVecBinding.td create mode 100644 python/dialects/aievec.py diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecOps.h b/include/aie/Dialect/AIEVec/IR/AIEVecOps.h index e686e43701..b1a26c3c43 100644 --- a/include/aie/Dialect/AIEVec/IR/AIEVecOps.h +++ b/include/aie/Dialect/AIEVec/IR/AIEVecOps.h @@ -14,6 +14,7 @@ #define AIE_DIALECT_AIEVEC_IR_AIEVECOPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "AIEVecDialect.h" diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td index 2486450563..7180728df1 100644 --- a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td +++ b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td @@ -16,6 +16,8 @@ include "aie/Dialect/AIE/IR/AIEAttrs.td" include "aie/Dialect/AIEVec/IR/AIEVecTypes.td" include "aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td" + +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Base class for AIE dialect ops. @@ -448,7 +450,7 @@ def AIEVec_UPDOp: def AIEVec_ConcatOp: AIEVec_Op<"concat", [ - Pure + Pure, InferTypeOpAdaptor, ]>, Arguments<(ins Variadic:$sources)>, Results<(outs AnyVector:$result)> { diff --git a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp index 6aa48d3a0f..dac1bc20f8 100644 --- a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp +++ b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp @@ -575,11 +575,14 @@ ParseResult BroadcastScalarOp::parse(OpAsmParser &parser, // some specializations to print those fields specifically for FMA op. // Print the accumulator -template void printAccumulator(OpAsmPrinter &p, T op); -template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAOp op) { +template +void printAccumulator(OpAsmPrinter &p, T op); +template <> +inline void printAccumulator(OpAsmPrinter &p, aievec::FMAOp op) { p << ", " << op.getAcc(); } -template <> inline void printAccumulator(OpAsmPrinter &p, aievec::MulOp op) {} +template <> +inline void printAccumulator(OpAsmPrinter &p, aievec::MulOp op) {} // Mark fmsub indicator as elided if the FMA op is not fmsub template @@ -595,7 +598,8 @@ inline void elideFMSubAttr(aievec::MulOp, SmallVector &elidedAttrs) {} // Print out Mul and FMA op. -template static void printMulFMAOp(OpAsmPrinter &p, T op) { +template +static void printMulFMAOp(OpAsmPrinter &p, T op) { // Print the left operand p << " " << op.getLhs(); // Print the right operand @@ -632,7 +636,8 @@ void aievec::FMAOp::print(OpAsmPrinter &p) { } // Verify Mul and FMA op. -template LogicalResult verifyMulFMAOp(T op) { +template +LogicalResult verifyMulFMAOp(T op) { // Verify the types auto lhsType = op.getLhs().getType().template dyn_cast(); auto rhsType = op.getRhs().getType().template dyn_cast(); @@ -776,7 +781,8 @@ ParseResult FMAOp::parse(OpAsmParser &parser, OperationState &result) { // FMAElemOp and MULElemOp. // Print the accumulator -template void printAccumulator(OpAsmPrinter &p, T op); +template +void printAccumulator(OpAsmPrinter &p, T op); template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAElemOp op) { p << ", " << op.getAcc(); @@ -799,7 +805,8 @@ inline void elideFMSubAttr(aievec::MulElemOp op, SmallVector &elidedAttrs) {} // Print out MulElem and FMAElem op. -template static void printMulFMAElemOp(OpAsmPrinter &p, T op) { +template +static void printMulFMAElemOp(OpAsmPrinter &p, T op) { // Print the left operand p << " " << op.getLhs(); // Print the right operand @@ -828,7 +835,8 @@ void aievec::FMAElemOp::print(OpAsmPrinter &p) { } // Verify MulElem and FMAElem op. -template LogicalResult verifyMulFMAElemOp(T op) { +template +LogicalResult verifyMulFMAElemOp(T op) { // Verify the types auto lhsType = op.getLhs().getType().template dyn_cast(); auto rhsType = op.getRhs().getType().template dyn_cast(); @@ -957,7 +965,8 @@ ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// // Print out Add and Sub op. -template void printAddSubOp(OpAsmPrinter &p, T op) { +template +void printAddSubOp(OpAsmPrinter &p, T op) { // Print the lhs operand p << " " << op.getLhs(); // Print the rhs operand @@ -991,7 +1000,8 @@ void aievec::SubOp::print(OpAsmPrinter &p) { } // Verify Add and Sub op. -template LogicalResult verifyAddSubOp(T op) { +template +LogicalResult verifyAddSubOp(T op) { // Verify the types auto resultType = op.getResult().getType().template dyn_cast(); auto lhsType = op.getLhs().getType().template dyn_cast(); @@ -1153,6 +1163,25 @@ ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) { return parser.addTypeToList(resultType, result.types); } +LogicalResult +ConcatOp::inferReturnTypes(MLIRContext *, std::optional, + ConcatOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + SmallVector srcs(adaptor.getSources().begin(), + adaptor.getSources().end()); + unsigned totalLength = 0; + for (auto source : srcs) { + VectorType type = llvm::dyn_cast(source.getType()); + assert(type.getRank() == 1 && + "only rank 1 vectors currently supported by concat"); + totalLength += type.getDimSize(0); + } + inferredReturnTypes.push_back(VectorType::get( + {totalLength}, + srcs[0].getType().dyn_cast().getElementType())); + return success(); +} + //===----------------------------------------------------------------------===// // ExtOp //===----------------------------------------------------------------------===// @@ -1356,7 +1385,8 @@ ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// // Print out Pack and Unpack op. -template static void printPackUnpackOp(OpAsmPrinter &p, T op) { +template +static void printPackUnpackOp(OpAsmPrinter &p, T op) { // Print the source vector p << " " << op.getSource(); @@ -1372,7 +1402,8 @@ void PackOp::print(OpAsmPrinter &p) { printPackUnpackOp(p, *this); } void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp(p, *this); } // Verify Pack and Unpack op. -template LogicalResult verifyPackUnpackOp(T op) { +template +LogicalResult verifyPackUnpackOp(T op) { // Verify the types auto sourceType = op.getSource().getType().template dyn_cast(); auto resultType = op.getResult().getType().template dyn_cast(); @@ -1626,7 +1657,8 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { // FMAConvOp and MULConvOp. // Print the accumulator -template void printAccumulator(OpAsmPrinter &p, T op); +template +void printAccumulator(OpAsmPrinter &p, T op); template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAConvOp op) { p << ", " << op.getAcc(); @@ -1649,7 +1681,8 @@ inline void elideFMSubAttr(MulConvOp op, SmallVector &elidedAttrs) {} // Print out MulConv and FMAConv op. -template static void printMulFMAConvOp(OpAsmPrinter &p, T op) { +template +static void printMulFMAConvOp(OpAsmPrinter &p, T op) { // Print the left operand p << " " << op.getLhs(); // Print the right operand @@ -1678,7 +1711,8 @@ void aievec::FMAConvOp::print(OpAsmPrinter &p) { } // Verify MulConv and FMAConv op. -template LogicalResult verifyMulFMAConvOp(T op) { +template +LogicalResult verifyMulFMAConvOp(T op) { // Verify the types auto lhsType = op.getLhs().getType().template dyn_cast(); auto rhsType = op.getRhs().getType().template dyn_cast(); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index a0fd22821c..956eaebc0d 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -58,6 +58,15 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME AIEX ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT AIEPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" + TD_FILE dialects/AIEVecBinding.td + SOURCES + dialects/aievec.py + DIALECT_NAME aievec +) + configure_file(compiler/aiecc/configure.py.in aie/compiler/aiecc/configure.py) set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/aie/compiler/aiecc/configure.py" diff --git a/python/dialects/AIEVecBinding.td b/python/dialects/AIEVecBinding.td new file mode 100644 index 0000000000..4eae05c0e0 --- /dev/null +++ b/python/dialects/AIEVecBinding.td @@ -0,0 +1,14 @@ +//===- AIEVecBinding.td --------------------------------------*- tablegen -*-===// +// +// Copyright (C) 2023, Advanced Micro Devices, Inc. +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef AIEVEC_BINDING_TD +#define AIEVEC_BINDING_TD + +include "aie/Dialect/AIEVec/IR/AIEVecOps.td" + +#endif // AIEVEC_BINDING_TD diff --git a/python/dialects/aievec.py b/python/dialects/aievec.py new file mode 100644 index 0000000000..aab7a35ebd --- /dev/null +++ b/python/dialects/aievec.py @@ -0,0 +1,7 @@ +from ._aievec_ops_gen import * + +from .._mlir_libs._aie import * +from .._mlir_libs import get_dialect_registry + +# Comes from _aie +register_dialect(get_dialect_registry()) diff --git a/test/python/tosa_aievec.py b/test/python/tosa_aievec.py index faa9d2c198..cf33ffcee2 100644 --- a/test/python/tosa_aievec.py +++ b/test/python/tosa_aievec.py @@ -2,14 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import inspect from pathlib import Path -from textwrap import dedent +# noinspection PyUnresolvedReferences +import aie.dialects.aievec from aie.extras.dialects import arith from aie.extras.dialects.func import func -from aie.extras.util import mlir_mod_ctx -from aie.ir import ShapedType +from aie.extras import types as T +from aie.ir import ShapedType, AffineMap +from aie.dialects import vector, aievec, scf from util import construct_and_print_module -from inspect import currentframe, getframeinfo # RUN: %python %s | FileCheck %s @@ -41,3 +42,67 @@ def demo_fun1(): assert hasattr(demo_fun1, "emit") assert inspect.ismethod(demo_fun1.emit) demo_fun1.emit() + + +@construct_and_print_module +def test_aievec(): + @func + def mul_mul( + A: T.memref(2048, T.f32()), + B: T.memref(2048, T.f32()), + C: T.memref(2048, T.f32()), + d: T.f32(), + ): + v0 = vector.broadcast(T.vector(8, T.f32()), d) + v1 = aievec.concat([v0, v0]) + for i in scf.for_(0, 2048, 8): + v2 = aievec.upd(T.vector(8, T.f32()), A, [i]) + v3 = aievec.upd(T.vector(8, T.f32()), B, [i]) + v4 = aievec.mul( + T.vector(8, T.f32()), + v1, + v2, + xoffsets="0x76543210", + xstart="0", + zoffsets="0x76543210", + zstart="0", + ) + v5 = aievec.concat([v4, v4]) + v6 = aievec.mul( + T.vector(8, T.f32()), + v5, + v3, + xoffsets="0x76543210", + xstart="0", + zoffsets="0x76543210", + zstart="0", + ) + vector.transfer_write( + None, + v6, + C, + [i], + AffineMap.get_identity(1), + in_bounds=[True], + ) + + scf.yield_([]) + + # CHECK-LABEL: func.func @mul_mul( + # CHECK-SAME: %[[VAL_0:.*]]: memref<2048xf32>, %[[VAL_1:.*]]: memref<2048xf32>, %[[VAL_2:.*]]: memref<2048xf32>, %[[VAL_3:.*]]: f32) { + # CHECK: %[[VAL_4:.*]] = vector.broadcast %[[VAL_3]] : f32 to vector<8xf32> + # CHECK: %[[VAL_5:.*]] = aievec.concat %[[VAL_4]], %[[VAL_4]] : vector<8xf32>, vector<16xf32> + # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index + # CHECK: %[[VAL_7:.*]] = arith.constant 2048 : index + # CHECK: %[[VAL_8:.*]] = arith.constant 8 : index + # CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] { + # CHECK: %[[VAL_10:.*]] = aievec.upd %[[VAL_0]]{{\[}}%[[VAL_9]]] {index = 0 : i8, offset = 0 : i32} : memref<2048xf32>, vector<8xf32> + # CHECK: %[[VAL_11:.*]] = aievec.upd %[[VAL_1]]{{\[}}%[[VAL_9]]] {index = 0 : i8, offset = 0 : i32} : memref<2048xf32>, vector<8xf32> + # CHECK: %[[VAL_12:.*]] = aievec.mul %[[VAL_5]], %[[VAL_10]] {xoffsets = "0x76543210", xstart = "0", zoffsets = "0x76543210", zstart = "0"} : vector<16xf32>, vector<8xf32>, vector<8xf32> + # CHECK: %[[VAL_13:.*]] = aievec.concat %[[VAL_12]], %[[VAL_12]] : vector<8xf32>, vector<16xf32> + # CHECK: %[[VAL_14:.*]] = aievec.mul %[[VAL_13]], %[[VAL_11]] {xoffsets = "0x76543210", xstart = "0", zoffsets = "0x76543210", zstart = "0"} : vector<16xf32>, vector<8xf32>, vector<8xf32> + # CHECK: vector.transfer_write %[[VAL_14]], %[[VAL_2]]{{\[}}%[[VAL_9]]] {in_bounds = [true]} : vector<8xf32>, memref<2048xf32> + # CHECK: } + # CHECK: return + # CHECK: } + mul_mul.emit()