Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pre-allocated buffers work again #1901

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 75 additions & 48 deletions lib/Dialect/AIE/Transforms/AIEAssignBuffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using namespace xilinx::AIE;
//===----------------------------------------------------------------------===//
LogicalResult checkAndPrintOverflow(TileOp tile, int address,
int maxDataMemorySize, int stacksize,
SmallVector<BufferOp, 4> buffers) {
SmallVector<BufferOp> &buffers) {
if (address > maxDataMemorySize) {
InFlightDiagnostic error =
tile.emitOpError("allocated buffers exceeded available memory\n");
Expand Down Expand Up @@ -65,17 +65,30 @@ LogicalResult basicAllocation(TileOp &tile) {
else
maxDataMemorySize = targetModel.getLocalMemorySize();

SmallVector<BufferOp, 4> buffers;
// Collect all the buffers for this tile.
SmallVector<BufferOp> buffers;
SmallVector<BufferOp> allocated_buffers;
// Collect all the buffers for this tile. If the buffer has an address, add
// it to allocated_buffers. Otherwise, add it to buffers.
device.walk<WalkOrder::PreOrder>([&](BufferOp buffer) {
if (buffer.getTileOp() == tile)
buffers.push_back(buffer);
if (buffer.getTileOp() == tile) {
if (buffer.getAddress())
allocated_buffers.push_back(buffer);
else
buffers.push_back(buffer);
}
});
// Sort by allocation size.

// Sort buffers by allocation size.
std::sort(buffers.begin(), buffers.end(), [](BufferOp a, BufferOp b) {
return a.getAllocationSize() > b.getAllocationSize();
});

// Sort allocated_buffers by address
std::sort(allocated_buffers.begin(), allocated_buffers.end(),
[](BufferOp a, BufferOp b) {
return a.getAddress().value() < b.getAddress().value();
});

// Address range owned by the MemTile is 0x80000.
// Address range owned by the tile is 0x8000 in
// AIE1 and 0x10000 in AIE2, but we need room at
Expand All @@ -87,9 +100,18 @@ LogicalResult basicAllocation(TileOp &tile) {
address += stacksize;
}

// As the next address to allocate is assigned, skip over any buffers
// from the allocated_buffers list.
auto current_alloc = allocated_buffers.begin();
for (auto buffer : buffers) {
if (buffer.getAddress())
buffer->emitWarning("Overriding existing address");
assert(!buffer.getAddress());
while (current_alloc != allocated_buffers.end() &&
address + buffer.getAllocationSize() >
current_alloc->getAddress().value()) {
address = current_alloc->getAddress().value() +
current_alloc->getAllocationSize();
current_alloc++;
}
buffer.setAddress(address);
address += buffer.getAllocationSize();
}
Expand Down Expand Up @@ -150,25 +172,29 @@ void setAndUpdateAddressInBank(BufferOp buffer, int64_t start_addr,
// returns true and if not, the function emits a warning that the address
// will be overwritten and returns false (which will cause the buffer to be
// added to the list of buffers without addresses, to be completed later on).
bool checkAndAddBufferWithAddress(BufferOp buffer, int numBanks,
std::vector<int64_t> &nextAddrInBanks,
std::vector<BankLimits> &bankLimits) {
if (auto addrAttr = buffer->getAttrOfType<IntegerAttr>("address")) {
int addr = addrAttr.getInt();
for (int i = 0; i < numBanks; i++) {
if (bankLimits[i].startAddr <= addr && addr < bankLimits[i].endAddr) {
if (addr >= nextAddrInBanks[i]) {
nextAddrInBanks[i] = addr + buffer.getAllocationSize();
buffer.setMemBank(i);
} else {
buffer->emitWarning("Overriding existing address");
return false;
}
}
}
return true;
FailureOr<bool>
checkAndAddBufferWithAddress(BufferOp buffer, int numBanks,
std::vector<int64_t> &nextAddrInBanks,
std::vector<BankLimits> &bankLimits) {
auto addrAttr = buffer->getAttrOfType<IntegerAttr>("address");
if (!addrAttr)
return false;

int addr = addrAttr.getInt();
for (int i = 0; i < numBanks; i++) {
// if the address is not within the bank, continue
if (addr < bankLimits[i].startAddr || addr >= bankLimits[i].endAddr)
continue;

// if the allocator already overwrote this address, fail
if (addr < nextAddrInBanks[i])
return buffer->emitOpError("would override allocated address");

// the allocator can accomadate this existing allocation
nextAddrInBanks[i] = addr + buffer.getAllocationSize();
buffer.setMemBank(i);
}
return false;
return true;
}

// Function that checks whether the given buffer already has a set mem_bank
Expand All @@ -177,22 +203,21 @@ bool checkAndAddBufferWithAddress(BufferOp buffer, int numBanks,
// function emits a warning that the mem_bank will be overwritten and returns
// false (which will cause the buffer to be added to the list of buffers
// without addresses, to be completed later on).
bool checkAndAddBufferWithMemBank(BufferOp buffer, int numBanks,
std::vector<int64_t> &nextAddrInBanks,
std::vector<BankLimits> &bankLimits) {
if (auto memBankAttr = buffer->getAttrOfType<IntegerAttr>("mem_bank")) {
int mem_bank = memBankAttr.getInt();
int64_t startAddr = nextAddrInBanks[mem_bank];
int64_t endAddr = startAddr + buffer.getAllocationSize();
if (endAddr <= bankLimits[mem_bank].endAddr) {
setAndUpdateAddressInBank(buffer, startAddr, endAddr, nextAddrInBanks);
} else {
buffer->emitWarning("Overriding existing mem_bank");
return false;
}
return true;
}
return false;
FailureOr<bool>
checkAndAddBufferWithMemBank(BufferOp buffer, int numBanks,
std::vector<int64_t> &nextAddrInBanks,
std::vector<BankLimits> &bankLimits) {
auto memBankAttr = buffer->getAttrOfType<IntegerAttr>("mem_bank");
if (!memBankAttr)
return false;

int mem_bank = memBankAttr.getInt();
int64_t startAddr = nextAddrInBanks[mem_bank];
int64_t endAddr = startAddr + buffer.getAllocationSize();
if (endAddr > bankLimits[mem_bank].endAddr)
return buffer->emitOpError("would override allocated address");
setAndUpdateAddressInBank(buffer, startAddr, endAddr, nextAddrInBanks);
return true;
}

// Function that given a buffer will iterate over all the memory banks
Expand Down Expand Up @@ -226,7 +251,7 @@ int setBufferAddress(BufferOp buffer, int numBanks, int startBankIndex,
}

LogicalResult checkAndPrintOverflow(TileOp tile, int numBanks, int stacksize,
SmallVector<BufferOp, 4> allBuffers,
SmallVector<BufferOp> &allBuffers,
std::vector<int64_t> &nextAddrInBanks,
std::vector<BankLimits> &bankLimits) {
bool foundOverflow = false;
Expand Down Expand Up @@ -311,8 +336,8 @@ LogicalResult simpleBankAwareAllocation(TileOp tile) {
}
fillBankLimits(numBanks, bankSize, bankLimits);

SmallVector<BufferOp, 4> buffersToAlloc;
SmallVector<BufferOp, 4> allBuffers;
SmallVector<BufferOp> buffersToAlloc;
SmallVector<BufferOp> allBuffers;
// Collect all the buffers for this tile.
device.walk<WalkOrder::PreOrder>([&](BufferOp buffer) {
if (buffer.getTileOp() == tile)
Expand All @@ -325,11 +350,13 @@ LogicalResult simpleBankAwareAllocation(TileOp tile) {
// the above.
for (auto buffer : allBuffers) {
if (buffer.getTileOp() == tile) {
bool has_addr = checkAndAddBufferWithAddress(buffer, numBanks,
auto has_addr = checkAndAddBufferWithAddress(buffer, numBanks,
nextAddrInBanks, bankLimits);
bool has_bank = checkAndAddBufferWithMemBank(buffer, numBanks,
auto has_bank = checkAndAddBufferWithMemBank(buffer, numBanks,
nextAddrInBanks, bankLimits);
if (!has_addr && !has_bank)
if (failed(has_addr) || failed(has_bank))
return failure();
if (!has_addr.value() && !has_bank.value())
buffersToAlloc.push_back(buffer);
}
}
Expand Down
7 changes: 6 additions & 1 deletion python/compiler/aiecc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,12 @@ def run_passes(pass_pipeline, mlir_module_str, outputfile=None, verbose=False):
print("Running:", pass_pipeline)
with Context() as ctx, Location.unknown():
module = Module.parse(mlir_module_str)
PassManager.parse(pass_pipeline).run(module.operation)
pm = PassManager.parse(pass_pipeline)
try:
pm.run(module.operation)
except Exception as e:
print("Error running pass pipeline: ", pass_pipeline, e)
raise e
mlir_module_str = str(module)
if outputfile:
with open(outputfile, "w") as g:
Expand Down
Loading