diff --git a/programming_examples/basic/matrix_multiplication/single_core/aie2.py b/programming_examples/basic/matrix_multiplication/single_core/aie2.py index 6ecdb7b409..e7feedaba4 100644 --- a/programming_examples/basic/matrix_multiplication/single_core/aie2.py +++ b/programming_examples/basic/matrix_multiplication/single_core/aie2.py @@ -10,8 +10,7 @@ import numpy as np import sys -from aie.dialects.scf import for_ as range_ -from aie.dialects.scf import yield_ +from aie.dialects.scf import yield_, for_ as range_ from aie.dialects.aiex import npu_dma_memcpy_nd, npu_sync from aie.api.dataflow.inout.inout import MyInOutProgram @@ -93,9 +92,9 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized): num_data_tiles = (M // m) * (N // n) # input/output matrices TODO(erika): fix up types - memref_A_ty = ((M, K), dtype_in) - memref_B_ty = ((K, N), dtype_in) - memref_C_ty = ((M, N), dtype_out) + memref_A_ty = ([M * K], dtype_in) + memref_B_ty = ([K * N], dtype_in) + memref_C_ty = ([M * N], dtype_out) # submatrices TODO(erika): fix up types memref_a_ty = ((m, k), dtype_in) @@ -128,9 +127,7 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized): else [] ), ) - inALink = MyObjectFifoLink( - [inA.second], [memA.first], coords=(0, 1) - ) # TODO(erika): represent_memtile + inALink = MyObjectFifoLink([inA.second], [memA.first], coords=(0, 1)) # Input B inB = MyObjectFifo(2, memref_b_ty) @@ -148,9 +145,7 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized): else [] ), ) - inBLink = MyObjectFifoLink( - [inB.second], [memB.first], coords=(0, 1) - ) # TODO(erika): represent memtile + inBLink = MyObjectFifoLink([inB.second], [memB.first], coords=(0, 1)) # Output C memC = MyObjectFifo(2, memref_c_ty) @@ -168,29 +163,29 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized): else [] ), ) - outCLink = MyObjectFifoLink( - [memC.second], [outC.first], coords=(0, 1) - ) # TODO(erika): represent memtile + outCLink = MyObjectFifoLink([memC.second], [outC.first], coords=(0, 1)) def core_fn(a, b, c, zero, matmul): - for _ in ( - range_(num_data_tiles) if num_data_tiles > 1 else range(1) - ): # issue #1547 - elem_out = c.acquire(1) - zero(elem_out) + for _ in range_(0xFFFFFFFF): + for _ in ( + range_(num_data_tiles) if num_data_tiles > 1 else range(1) + ): # issue #1547 + elem_out = c.acquire(1) + zero(elem_out) - for _ in range_(K // k) if (K // k) > 1 else range(1): # issue #1547 - elem_in_a = a.acquire(1) - elem_in_b = b.acquire(1) - matmul(elem_in_a, elem_in_b, elem_out) - a.release(1) - b.release(1) - if (K // k) > 1: - yield_([]) + for _ in range_(K // k) if (K // k) > 1 else range(1): # issue #1547 + elem_in_a = a.acquire(1) + elem_in_b = b.acquire(1) + matmul(elem_in_a, elem_in_b, elem_out) + a.release(1) + b.release(1) + if (K // k) > 1: + yield_([]) - c.release(1) - if num_data_tiles > 1: - yield_([]) + c.release(1) + if num_data_tiles > 1: + yield_([]) + yield_([]) def sequence_fn(A, B, C, inA, inB, outC): # only do 4 tile rows at a time before synchronizing, so we can reuse BDs @@ -243,11 +238,11 @@ def sequence_fn(A, B, C, inA, inB, outC): inout_program = MyInOutProgram( sequence_fn, [memref_A_ty, memref_B_ty, memref_C_ty], - [memA.second, memB.second, memC.first], + [inA.first, inB.first, outC.second], coords=(0, 0), ) worker_program = MyWorker( - core_fn, [inA.first, inB.first, outC.second, zero, matmul], coords=(0, 2) + core_fn, [memA.second, memB.second, memC.first, zero, matmul], coords=(0, 2) ) my_program = MyProgram(