Skip to content

Commit

Permalink
enable aievec python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 14, 2023
1 parent 7b3b21e commit 61ec4b0
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 20 deletions.
1 change: 1 addition & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define AIE_DIALECT_AIEVEC_IR_AIEVECOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "AIEVecDialect.h"
Expand Down
4 changes: 3 additions & 1 deletion include/aie/Dialect/AIEVec/IR/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
include "aie/Dialect/AIE/IR/AIEAttrs.td"
include "aie/Dialect/AIEVec/IR/AIEVecTypes.td"
include "aie/Dialect/AIEVec/IR/AIEVecTypeConstraints.td"

include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Base class for AIE dialect ops.
Expand Down Expand Up @@ -448,7 +450,7 @@ def AIEVec_UPDOp:

def AIEVec_ConcatOp:
AIEVec_Op<"concat", [
Pure
Pure, InferTypeOpAdaptor,
]>,
Arguments<(ins Variadic<AnyVector>:$sources)>,
Results<(outs AnyVector:$result)> {
Expand Down
64 changes: 49 additions & 15 deletions lib/Dialect/AIEVec/IR/AIEVecOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,11 +575,14 @@ ParseResult BroadcastScalarOp::parse(OpAsmParser &parser,
// some specializations to print those fields specifically for FMA op.

// Print the accumulator
template <typename T> void printAccumulator(OpAsmPrinter &p, T op);
template <> inline void printAccumulator(OpAsmPrinter &p, aievec::FMAOp op) {
template <typename T>
void printAccumulator(OpAsmPrinter &p, T op);
// FMAOp specialization: appends ", " followed by the op's accumulator operand
// (op.getAcc()) to the printer.
template <>
inline void printAccumulator(OpAsmPrinter &p, aievec::FMAOp op) {
p << ", " << op.getAcc();
}
template <> inline void printAccumulator(OpAsmPrinter &p, aievec::MulOp op) {}
// MulOp specialization: prints nothing (empty body), so the shared Mul/FMA
// printing code can call printAccumulator unconditionally.
template <>
inline void printAccumulator(OpAsmPrinter &p, aievec::MulOp op) {}

// Mark fmsub indicator as elided if the FMA op is not fmsub
template <typename T>
Expand All @@ -595,7 +598,8 @@ inline void elideFMSubAttr(aievec::MulOp,
SmallVector<StringRef, 10> &elidedAttrs) {}

// Print out Mul and FMA op.
template <typename T> static void printMulFMAOp(OpAsmPrinter &p, T op) {
template <typename T>
static void printMulFMAOp(OpAsmPrinter &p, T op) {
// Print the left operand
p << " " << op.getLhs();
// Print the right operand
Expand Down Expand Up @@ -632,7 +636,8 @@ void aievec::FMAOp::print(OpAsmPrinter &p) {
}

// Verify Mul and FMA op.
template <typename T> LogicalResult verifyMulFMAOp(T op) {
template <typename T>
LogicalResult verifyMulFMAOp(T op) {
// Verify the types
auto lhsType = op.getLhs().getType().template dyn_cast<VectorType>();
auto rhsType = op.getRhs().getType().template dyn_cast<VectorType>();
Expand Down Expand Up @@ -776,7 +781,8 @@ ParseResult FMAOp::parse(OpAsmParser &parser, OperationState &result) {
// FMAElemOp and MULElemOp.

// Print the accumulator
template <typename T> void printAccumulator(OpAsmPrinter &p, T op);
template <typename T>
void printAccumulator(OpAsmPrinter &p, T op);
template <>
inline void printAccumulator(OpAsmPrinter &p, aievec::FMAElemOp op) {
p << ", " << op.getAcc();
Expand All @@ -799,7 +805,8 @@ inline void elideFMSubAttr(aievec::MulElemOp op,
SmallVector<StringRef, 4> &elidedAttrs) {}

// Print out MulElem and FMAElem op.
template <typename T> static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
template <typename T>
static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
// Print the left operand
p << " " << op.getLhs();
// Print the right operand
Expand Down Expand Up @@ -828,7 +835,8 @@ void aievec::FMAElemOp::print(OpAsmPrinter &p) {
}

// Verify MulElem and FMAElem op.
template <typename T> LogicalResult verifyMulFMAElemOp(T op) {
template <typename T>
LogicalResult verifyMulFMAElemOp(T op) {
// Verify the types
auto lhsType = op.getLhs().getType().template dyn_cast<VectorType>();
auto rhsType = op.getRhs().getType().template dyn_cast<VectorType>();
Expand Down Expand Up @@ -957,7 +965,8 @@ ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//

// Print out Add and Sub op.
template <typename T> void printAddSubOp(OpAsmPrinter &p, T op) {
template <typename T>
void printAddSubOp(OpAsmPrinter &p, T op) {
// Print the lhs operand
p << " " << op.getLhs();
// Print the rhs operand
Expand Down Expand Up @@ -991,7 +1000,8 @@ void aievec::SubOp::print(OpAsmPrinter &p) {
}

// Verify Add and Sub op.
template <typename T> LogicalResult verifyAddSubOp(T op) {
template <typename T>
LogicalResult verifyAddSubOp(T op) {
// Verify the types
auto resultType = op.getResult().getType().template dyn_cast<VectorType>();
auto lhsType = op.getLhs().getType().template dyn_cast<VectorType>();
Expand Down Expand Up @@ -1153,6 +1163,25 @@ ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) {
return parser.addTypeToList(resultType, result.types);
}

/// Infers the result type of `aievec.concat` from its source operands.
///
/// The inferred type is a rank-1 vector whose length is the sum of the source
/// vector lengths and whose element type is that of the first source.
///
/// \param location optional location used to attach diagnostics on failure.
/// \param adaptor adaptor giving access to the `sources` operands.
/// \param inferredReturnTypes receives the single inferred result type.
/// \returns success() when a type was inferred; failure() (with a diagnostic
///          if a location is available) when there are no sources or a source
///          is not a rank-1 vector.
LogicalResult
ConcatOp::inferReturnTypes(MLIRContext *, std::optional<Location> location,
                           ConcatOp::Adaptor adaptor,
                           SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sources = adaptor.getSources();
  // Without at least one source there is no element type to infer from.
  if (sources.empty())
    return emitOptionalError(location,
                             "concat requires at least one source vector");

  Type elementType;
  int64_t totalLength = 0;
  for (Value source : sources) {
    auto type = llvm::dyn_cast<VectorType>(source.getType());
    // Report instead of asserting: inference may run on unverified operands
    // (e.g. from the Python builders), and asserts vanish in release builds.
    if (!type || type.getRank() != 1)
      return emitOptionalError(
          location, "only rank 1 vectors currently supported by concat");
    // The result element type is taken from the first source.
    if (!elementType)
      elementType = type.getElementType();
    totalLength += type.getDimSize(0);
  }

  inferredReturnTypes.push_back(VectorType::get({totalLength}, elementType));
  return success();
}

//===----------------------------------------------------------------------===//
// ExtOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1356,7 +1385,8 @@ ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//

// Print out Pack and Unpack op.
template <typename T> static void printPackUnpackOp(OpAsmPrinter &p, T op) {
template <typename T>
static void printPackUnpackOp(OpAsmPrinter &p, T op) {
// Print the source vector
p << " " << op.getSource();

Expand All @@ -1372,7 +1402,8 @@ void PackOp::print(OpAsmPrinter &p) { printPackUnpackOp<PackOp>(p, *this); }
void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp<UnpackOp>(p, *this); }

// Verify Pack and Unpack op.
template <typename T> LogicalResult verifyPackUnpackOp(T op) {
template <typename T>
LogicalResult verifyPackUnpackOp(T op) {
// Verify the types
auto sourceType = op.getSource().getType().template dyn_cast<VectorType>();
auto resultType = op.getResult().getType().template dyn_cast<VectorType>();
Expand Down Expand Up @@ -1626,7 +1657,8 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
// FMAConvOp and MULConvOp.

// Print the accumulator
template <typename T> void printAccumulator(OpAsmPrinter &p, T op);
template <typename T>
void printAccumulator(OpAsmPrinter &p, T op);
template <>
inline void printAccumulator(OpAsmPrinter &p, aievec::FMAConvOp op) {
p << ", " << op.getAcc();
Expand All @@ -1649,7 +1681,8 @@ inline void elideFMSubAttr(MulConvOp op,
SmallVector<StringRef, 4> &elidedAttrs) {}

// Print out MulConv and FMAConv op.
template <typename T> static void printMulFMAConvOp(OpAsmPrinter &p, T op) {
template <typename T>
static void printMulFMAConvOp(OpAsmPrinter &p, T op) {
// Print the left operand
p << " " << op.getLhs();
// Print the right operand
Expand Down Expand Up @@ -1678,7 +1711,8 @@ void aievec::FMAConvOp::print(OpAsmPrinter &p) {
}

// Verify MulConv and FMAConv op.
template <typename T> LogicalResult verifyMulFMAConvOp(T op) {
template <typename T>
LogicalResult verifyMulFMAConvOp(T op) {
// Verify the types
auto lhsType = op.getLhs().getType().template dyn_cast<VectorType>();
auto rhsType = op.getRhs().getType().template dyn_cast<VectorType>();
Expand Down
9 changes: 9 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME AIEX
)

# Python bindings for the aievec dialect: op wrappers are generated from
# dialects/AIEVecBinding.td and shipped together with dialects/aievec.py as
# part of AIEPythonSources.Dialects.
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT AIEPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
TD_FILE dialects/AIEVecBinding.td
SOURCES
dialects/aievec.py
DIALECT_NAME aievec
)

configure_file(compiler/aiecc/configure.py.in aie/compiler/aiecc/configure.py)
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
"${CMAKE_CURRENT_BINARY_DIR}/aie/compiler/aiecc/configure.py"
Expand Down
14 changes: 14 additions & 0 deletions python/dialects/AIEVecBinding.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===- AIEVecBinding.td --------------------------------------*- tablegen -*-===//
//
// Copyright (C) 2023, Advanced Micro Devices, Inc.
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef AIEVEC_BINDING_TD
#define AIEVEC_BINDING_TD

// This file only forwards to the AIEVec op definitions; it is the TD_FILE
// entry point handed to the Python binding declaration in python/CMakeLists.txt.
include "aie/Dialect/AIEVec/IR/AIEVecOps.td"

#endif // AIEVEC_BINDING_TD
7 changes: 7 additions & 0 deletions python/dialects/aievec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Python entry point for the aievec dialect: re-exports the op wrappers and
# registers the dialect at import time.

# Op wrappers for aievec (presumably generated from AIEVecBinding.td -- see
# python/CMakeLists.txt).
from ._aievec_ops_gen import *

# Wildcard import of the native _aie extension; this is what brings
# `register_dialect` into scope below.
from .._mlir_libs._aie import *
from .._mlir_libs import get_dialect_registry

# Comes from _aie
# Import-time side effect: registers with the registry returned by
# get_dialect_registry().
register_dialect(get_dialect_registry())
73 changes: 69 additions & 4 deletions test/python/tosa_aievec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import inspect
from pathlib import Path
from textwrap import dedent

# noinspection PyUnresolvedReferences
import aie.dialects.aievec
from aie.extras.dialects import arith
from aie.extras.dialects.func import func
from aie.extras.util import mlir_mod_ctx
from aie.ir import ShapedType
from aie.extras import types as T
from aie.ir import ShapedType, AffineMap
from aie.dialects import vector, aievec, scf
from util import construct_and_print_module
from inspect import currentframe, getframeinfo

# RUN: %python %s | FileCheck %s

Expand Down Expand Up @@ -41,3 +42,67 @@ def demo_fun1():
assert hasattr(demo_fun1, "emit")
assert inspect.ismethod(demo_fun1.emit)
demo_fun1.emit()


# Builds a function over three f32 memrefs and an f32 scalar using aievec ops
# (concat / upd / mul) and checks the emitted IR against the FileCheck
# directives embedded below. The directive comments are consumed by FileCheck
# and must not be reworded.
@construct_and_print_module
def test_aievec():
@func
def mul_mul(
A: T.memref(2048, T.f32()),
B: T.memref(2048, T.f32()),
C: T.memref(2048, T.f32()),
d: T.f32(),
):
v0 = vector.broadcast(T.vector(8, T.f32()), d)
# No result type is passed to concat: it is expected to be inferred
# (vector<16xf32> for two vector<8xf32> sources, per the directives below).
v1 = aievec.concat([v0, v0])
# Walk the 2048-element buffers in steps of 8 elements.
for i in scf.for_(0, 2048, 8):
v2 = aievec.upd(T.vector(8, T.f32()), A, [i])
v3 = aievec.upd(T.vector(8, T.f32()), B, [i])
v4 = aievec.mul(
T.vector(8, T.f32()),
v1,
v2,
xoffsets="0x76543210",
xstart="0",
zoffsets="0x76543210",
zstart="0",
)
v5 = aievec.concat([v4, v4])
v6 = aievec.mul(
T.vector(8, T.f32()),
v5,
v3,
xoffsets="0x76543210",
xstart="0",
zoffsets="0x76543210",
zstart="0",
)
# Store the 8-element result back into C at offset i.
vector.transfer_write(
None,
v6,
C,
[i],
AffineMap.get_identity(1),
in_bounds=[True],
)

scf.yield_([])

# CHECK-LABEL: func.func @mul_mul(
# CHECK-SAME: %[[VAL_0:.*]]: memref<2048xf32>, %[[VAL_1:.*]]: memref<2048xf32>, %[[VAL_2:.*]]: memref<2048xf32>, %[[VAL_3:.*]]: f32) {
# CHECK: %[[VAL_4:.*]] = vector.broadcast %[[VAL_3]] : f32 to vector<8xf32>
# CHECK: %[[VAL_5:.*]] = aievec.concat %[[VAL_4]], %[[VAL_4]] : vector<8xf32>, vector<16xf32>
# CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
# CHECK: %[[VAL_7:.*]] = arith.constant 2048 : index
# CHECK: %[[VAL_8:.*]] = arith.constant 8 : index
# CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] {
# CHECK: %[[VAL_10:.*]] = aievec.upd %[[VAL_0]]{{\[}}%[[VAL_9]]] {index = 0 : i8, offset = 0 : i32} : memref<2048xf32>, vector<8xf32>
# CHECK: %[[VAL_11:.*]] = aievec.upd %[[VAL_1]]{{\[}}%[[VAL_9]]] {index = 0 : i8, offset = 0 : i32} : memref<2048xf32>, vector<8xf32>
# CHECK: %[[VAL_12:.*]] = aievec.mul %[[VAL_5]], %[[VAL_10]] {xoffsets = "0x76543210", xstart = "0", zoffsets = "0x76543210", zstart = "0"} : vector<16xf32>, vector<8xf32>, vector<8xf32>
# CHECK: %[[VAL_13:.*]] = aievec.concat %[[VAL_12]], %[[VAL_12]] : vector<8xf32>, vector<16xf32>
# CHECK: %[[VAL_14:.*]] = aievec.mul %[[VAL_13]], %[[VAL_11]] {xoffsets = "0x76543210", xstart = "0", zoffsets = "0x76543210", zstart = "0"} : vector<16xf32>, vector<8xf32>, vector<8xf32>
# CHECK: vector.transfer_write %[[VAL_14]], %[[VAL_2]]{{\[}}%[[VAL_9]]] {in_bounds = [true]} : vector<8xf32>, memref<2048xf32>
# CHECK: }
# CHECK: return
# CHECK: }
mul_mul.emit()

0 comments on commit 61ec4b0

Please sign in to comment.