Merge pull request #67 from mlcommons/et-generator-fix
[chakra][et_generator] Fix bugs in et_generator
srinivas212 authored Nov 23, 2023
2 parents af9f890 + aeab0d4 commit 7a5faa8
Showing 1 changed file with 42 additions and 40 deletions.
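
The fix follows one pattern throughout the diff: every generated .et trace now opens with a GlobalMetadata record carrying the schema version, device-side ops are explicitly tagged with is_cpu_op=False, and the memory-node generators are renamed to one_remote_mem_*. A minimal sketch of the resulting trace layout — the import paths and the direct Node construction are illustrative assumptions, since the script builds nodes through a get_node helper not shown in this diff:

    # Sketch only: import paths below are assumed for illustration.
    from et_def.et_def_pb2 import GlobalMetadata, Node, COMP_NODE, AttributeProto as ChakraAttr
    from protolib import encodeMessage as encode_message

    with open("one_comp_node.0.et", "wb") as et:
        # Every trace now begins with a GlobalMetadata record (the core fix).
        encode_message(et, GlobalMetadata(version="0.0.4"))

        # Ops are then streamed as Nodes, each tagged as a device-side op.
        node = Node(id=0, name="COMP_NODE", type=COMP_NODE)
        node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
        node.duration_micros = 5  # placeholder runtime in microseconds
        encode_message(et, node)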
82 changes: 42 additions & 40 deletions utils/et_generator/et_generator.py
@@ -20,6 +20,7 @@
     BoolList,
     StringList,
     BytesList,
+    GlobalMetadata,
     AttributeProto as ChakraAttr,
     METADATA_NODE,
     MEM_LOAD_NODE,
@@ -58,6 +59,8 @@ def one_metadata_node_all_types(num_npus: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"one_metadata_node_all_types.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("METADATA_NODE", METADATA_NODE)

             node.attr.append(ChakraAttr(name="double", double_val=1.2345, doc_string="double"))
@@ -123,20 +126,26 @@ def one_metadata_node_all_types(num_npus: int) -> None:
             encode_message(et, node)


-def one_mem_load_node(num_npus: int, tensor_size: int) -> None:
+def one_remote_mem_load_node(num_npus: int, tensor_size: int) -> None:
     for npu_id in range(num_npus):
-        output_filename = f"one_mem_load_node.{npu_id}.et"
+        output_filename = f"one_remote_mem_load_node.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("MEM_LOAD_NODE", MEM_LOAD_NODE)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
             encode_message(et, node)


-def one_mem_store_node(num_npus: int, tensor_size: int) -> None:
+def one_remote_mem_store_node(num_npus: int, tensor_size: int) -> None:
     for npu_id in range(num_npus):
-        output_filename = f"one_mem_store_node.{npu_id}.et"
+        output_filename = f"one_remote_mem_store_node.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("MEM_STORE_NODE", MEM_STORE_NODE)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size))
             encode_message(et, node)

@@ -145,7 +154,10 @@ def one_comp_node(num_npus: int, runtime: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"one_comp_node.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("COMP_NODE", COMP_NODE)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             node.duration_micros = runtime
             encode_message(et, node)

@@ -154,59 +166,46 @@ def two_comp_nodes_independent(num_npus: int, runtime: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"two_comp_nodes_independent.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("COMP_NODE", COMP_NODE)
             node.duration_micros = runtime
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             encode_message(et, node)

             node = get_node("COMP_NODE", COMP_NODE)
             node.duration_micros = runtime
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             encode_message(et, node)


 def two_comp_nodes_dependent(num_npus: int, runtime: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"two_comp_nodes_dependent.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             parent_node = get_node("COMP_NODE", COMP_NODE)
             parent_node.duration_micros = runtime
+            parent_node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             encode_message(et, parent_node)

             child_node = get_node("COMP_NODE", COMP_NODE)
             child_node.duration_micros = runtime
             child_node.data_deps.append(parent_node.id)
+            child_node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
             encode_message(et, child_node)


-def one_comm_send_node(num_npus: int, comm_size: int) -> None:
-    for npu_id in range(num_npus):
-        output_filename = f"one_comm_send_node.{npu_id}.et"
-        with open(output_filename, "wb") as et:
-            node = get_node("COMM_SEND_NODE", COMM_SEND_NODE)
-            node.attr.append(ChakraAttr(name="src", int64_val=npu_id))
-            node.attr.append(ChakraAttr(name="dst", int64_val=(npu_id+1) % num_npus))
-            node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
-            encode_message(et, node)
-
-
-def one_comm_recv_node(num_npus: int, comm_size: int) -> None:
-    for npu_id in range(num_npus):
-        output_filename = f"one_comm_recv_node.{npu_id}.et"
-        with open(output_filename, "wb") as et:
-            node = get_node("COMM_RECV_NODE", COMM_RECV_NODE)
-            src_attr = ChakraAttr(name="src", uint64_val=(npu_id-1) % num_npus)
-            dst_attr = ChakraAttr(name="dst", uint64_val=npu_id)
-            size_attr = ChakraAttr(name="comm_size", uint64_val=comm_size)
-            node.attr.extend([src_attr, dst_attr, size_attr])
-            encode_message(et, node)
-
-
 def one_comm_coll_node_allreduce(num_npus: int, num_dims: int, comm_size: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"one_comm_coll_node_allreduce.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("ALL_REDUCE", COMM_COLL_NODE)
-            attr = get_comm_type_attr(ALL_REDUCE)
-            node.attr.append(attr)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
+            node.attr.append(get_comm_type_attr(ALL_REDUCE))
             node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
             attr = get_involved_dim_attr(num_dims)
             node.attr.append(attr)
@@ -217,9 +216,11 @@ def one_comm_coll_node_alltoall(num_npus: int, num_dims: int, comm_size: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"one_comm_coll_node_alltoall.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("ALL_TO_ALL", COMM_COLL_NODE)
-            attr = get_comm_type_attr(ALL_TO_ALL)
-            node.attr.append(attr)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
+            node.attr.append(get_comm_type_attr(ALL_TO_ALL))
             node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
             attr = get_involved_dim_attr(num_dims)
             node.attr.append(attr)
@@ -230,9 +231,11 @@ def one_comm_coll_node_allgather(num_npus: int, num_dims: int, comm_size: int) -> None:
     for npu_id in range(num_npus):
        output_filename = f"one_comm_coll_node_allgather.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("ALL_GATHER", COMM_COLL_NODE)
-            attr = get_comm_type_attr(ALL_GATHER)
-            node.attr.append(attr)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
+            node.attr.append(get_comm_type_attr(ALL_GATHER))
             node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
             attr = get_involved_dim_attr(num_dims)
             node.attr.append(attr)
@@ -243,9 +246,11 @@ def one_comm_coll_node_reducescatter(num_npus: int, num_dims: int, comm_size: int) -> None:
     for npu_id in range(num_npus):
         output_filename = f"one_comm_coll_node_reducescatter.{npu_id}.et"
         with open(output_filename, "wb") as et:
+            encode_message(et, GlobalMetadata(version="0.0.4"))
+
             node = get_node("REDUCE_SCATTER", COMM_COLL_NODE)
-            attr = get_comm_type_attr(REDUCE_SCATTER)
-            node.attr.append(attr)
+            node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
+            node.attr.append(get_comm_type_attr(REDUCE_SCATTER))
             node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
             attr = get_involved_dim_attr(num_dims)
             node.attr.append(attr)
@@ -290,16 +295,13 @@ def main() -> None:

     one_metadata_node_all_types(args.num_npus)

-    one_mem_load_node(args.num_npus, args.default_tensor_size)
-    one_mem_store_node(args.num_npus, args.default_tensor_size)
+    one_remote_mem_load_node(args.num_npus, args.default_tensor_size)
+    one_remote_mem_store_node(args.num_npus, args.default_tensor_size)

     one_comp_node(args.num_npus, args.default_runtime)
     two_comp_nodes_independent(args.num_npus, args.default_runtime)
     two_comp_nodes_dependent(args.num_npus, args.default_runtime)

-    one_comm_send_node(args.num_npus, args.default_comm_size)
-    one_comm_recv_node(args.num_npus, args.default_comm_size)
-
     one_comm_coll_node_allreduce(args.num_npus, args.num_dims, args.default_comm_size)
     one_comm_coll_node_alltoall(args.num_npus, args.num_dims, args.default_comm_size)
     one_comm_coll_node_allgather(args.num_npus, args.num_dims, args.default_comm_size)
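
For reference, a hedged invocation example — the flag names are inferred from the args attributes used in main() and may differ from the script's actual argparse definitions:

    # Emits one .et file per NPU for each scenario (e.g. one_comp_node.0.et
    # through one_comp_node.3.et), each beginning with GlobalMetadata.
    python utils/et_generator/et_generator.py --num_npus 4 --num_dims 2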
