Skip to content

Commit

Permalink
single core mat mul working
Browse files Browse the repository at this point in the history
  • Loading branch information
hunhoffe committed Sep 16, 2024
1 parent 2952a49 commit 229acf4
Showing 1 changed file with 27 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import numpy as np
import sys

from aie.dialects.scf import for_ as range_
from aie.dialects.scf import yield_
from aie.dialects.scf import yield_, for_ as range_
from aie.dialects.aiex import npu_dma_memcpy_nd, npu_sync

from aie.api.dataflow.inout.inout import MyInOutProgram
Expand Down Expand Up @@ -93,9 +92,9 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized):
num_data_tiles = (M // m) * (N // n)

# input/output matrices TODO(erika): fix up types
memref_A_ty = ((M, K), dtype_in)
memref_B_ty = ((K, N), dtype_in)
memref_C_ty = ((M, N), dtype_out)
memref_A_ty = ([M * K], dtype_in)
memref_B_ty = ([K * N], dtype_in)
memref_C_ty = ([M * N], dtype_out)

# submatrices TODO(erika): fix up types
memref_a_ty = ((m, k), dtype_in)
Expand Down Expand Up @@ -128,9 +127,7 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized):
else []
),
)
inALink = MyObjectFifoLink(
[inA.second], [memA.first], coords=(0, 1)
) # TODO(erika): represent_memtile
inALink = MyObjectFifoLink([inA.second], [memA.first], coords=(0, 1))

# Input B
inB = MyObjectFifo(2, memref_b_ty)
Expand All @@ -148,9 +145,7 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized):
else []
),
)
inBLink = MyObjectFifoLink(
[inB.second], [memB.first], coords=(0, 1)
) # TODO(erika): represent memtile
inBLink = MyObjectFifoLink([inB.second], [memB.first], coords=(0, 1))

# Output C
memC = MyObjectFifo(2, memref_c_ty)
Expand All @@ -168,29 +163,29 @@ def my_matmul(M, K, N, m, k, n, dtype_in_str, dtype_out_str, vectorized):
else []
),
)
outCLink = MyObjectFifoLink(
[memC.second], [outC.first], coords=(0, 1)
) # TODO(erika): represent memtile
outCLink = MyObjectFifoLink([memC.second], [outC.first], coords=(0, 1))

def core_fn(a, b, c, zero, matmul):
for _ in (
range_(num_data_tiles) if num_data_tiles > 1 else range(1)
): # issue #1547
elem_out = c.acquire(1)
zero(elem_out)
for _ in range_(0xFFFFFFFF):
for _ in (
range_(num_data_tiles) if num_data_tiles > 1 else range(1)
): # issue #1547
elem_out = c.acquire(1)
zero(elem_out)

for _ in range_(K // k) if (K // k) > 1 else range(1): # issue #1547
elem_in_a = a.acquire(1)
elem_in_b = b.acquire(1)
matmul(elem_in_a, elem_in_b, elem_out)
a.release(1)
b.release(1)
if (K // k) > 1:
yield_([])
for _ in range_(K // k) if (K // k) > 1 else range(1): # issue #1547
elem_in_a = a.acquire(1)
elem_in_b = b.acquire(1)
matmul(elem_in_a, elem_in_b, elem_out)
a.release(1)
b.release(1)
if (K // k) > 1:
yield_([])

c.release(1)
if num_data_tiles > 1:
yield_([])
c.release(1)
if num_data_tiles > 1:
yield_([])
yield_([])

def sequence_fn(A, B, C, inA, inB, outC):
# only do 4 tile rows at a time before synchronizing, so we can reuse BDs
Expand Down Expand Up @@ -243,11 +238,11 @@ def sequence_fn(A, B, C, inA, inB, outC):
inout_program = MyInOutProgram(
sequence_fn,
[memref_A_ty, memref_B_ty, memref_C_ty],
[memA.second, memB.second, memC.first],
[inA.first, inB.first, outC.second],
coords=(0, 0),
)
worker_program = MyWorker(
core_fn, [inA.first, inB.first, outC.second, zero, matmul], coords=(0, 2)
core_fn, [memA.second, memB.second, memC.first, zero, matmul], coords=(0, 2)
)

my_program = MyProgram(
Expand Down

0 comments on commit 229acf4

Please sign in to comment.