From 14d7cbcd7d4b038d50c195f62739ead4f5e01c3a Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:31:40 -0500 Subject: [PATCH 1/8] lint: Add .flake8 --- .flake8 | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..f7523bce --- /dev/null +++ b/.flake8 @@ -0,0 +1,39 @@ +[flake8] +enable-extensions = G +select = B,C,E,F,G,P,SIM1,T4,W,B9,TOR0,TOR1,TOR2 +max-line-length = 120 +# C408 ignored because we like the dict keyword argument syntax +# E501 is not flexible enough, we're using B950 instead +ignore = + E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, + # fix these lints in the future + E275, + # shebang has extra meaning in fbcode lints, so I think it's not worth trying + # to line this up with executable bit + EXE001, + # these ignores are from flake8-bugbear; please fix! + B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907 + # these ignores are from flake8-comprehensions; please fix! + C407, + # these ignores are from flake8-logging-format; please fix! + G100,G101,G200,G201,G202 + # these ignores are from flake8-simplify. please fix or ignore with commented reason + SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, + # flake8-simplify code styles + SIM102,SIM103,SIM106,SIM112, + # TorchFix codes that don't make sense for PyTorch itself: + # removed and deprecated PyTorch functions. + TOR001,TOR101, + # TODO(kit1980): fix all TOR102 issues + # `torch.load` without `weights_only` parameter is unsafe + TOR102, + P201, +per-file-ignores = + __init__.py: F401 +optional-ascii-coding = True +exclude = + ./.git, + ./build, + ./et_def/et_def_pb2.py, + ./et_def/et_def_pb2_grpc.py, + ./third_party/utils/protolib.py, From 59348d01fe0a5d0cf9c5288924148c24fea3d749 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:48:39 -0500 Subject: [PATCH 2/8] lint: Add .pyre_configuration --- .gitignore | 3 ++- .pyre_configuration | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 .pyre_configuration diff --git a/.gitignore b/.gitignore index 9973a044..eb9278c8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ build/ __pycache__/ *.egg *.et -*.dot \ No newline at end of file +*.dot +.pyre diff --git a/.pyre_configuration b/.pyre_configuration new file mode 100644 index 00000000..69234f7f --- /dev/null +++ b/.pyre_configuration @@ -0,0 +1,7 @@ +{ + "source_directories": [ + "timeline_visualizer", + "et_converter" + ], + "search_path": [] +} From 2a0f3f4f69fffdd52222fe25fe0085ea8b8b47bc Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:32:44 -0500 Subject: [PATCH 3/8] et_generator: Resolve flake8 errors --- utils/et_generator/et_generator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/utils/et_generator/et_generator.py b/utils/et_generator/et_generator.py index cc2a94ee..f3b6503d 100644 --- a/utils/et_generator/et_generator.py +++ b/utils/et_generator/et_generator.py @@ -26,8 +26,6 @@ MEM_LOAD_NODE, MEM_STORE_NODE, COMP_NODE, - COMM_SEND_NODE, - COMM_RECV_NODE, COMM_COLL_NODE, ALL_REDUCE, ALL_TO_ALL, @@ -116,11 +114,11 @@ def one_metadata_node_all_types(num_npus: int) -> None: node.attr.append(ChakraAttr(name="bool_list", bool_list=bool_list)) node.attr.append(ChakraAttr(name="string", string_val="12345", doc_string="string")) - string_list = StringList(values=[str(12345+i) for i in range(10)]) + string_list = StringList(values=[str(12345 + i) for i in range(10)]) node.attr.append(ChakraAttr(name="string_list", string_list=string_list)) node.attr.append(ChakraAttr(name="bytes", bytes_val=bytes("12345", "utf-8"))) - bytes_list = BytesList(values=[bytes(str(12345+i), "utf-8") for i in range(10)]) + bytes_list = BytesList(values=[bytes(str(12345 + i), "utf-8") for i in range(10)]) node.attr.append(ChakraAttr(name="bytes_list", bytes_list=bytes_list)) encode_message(et, node) From 2f9f48bd628576cca0820f5bd9662c00d0043a38 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:37:25 -0500 Subject: [PATCH 4/8] et_converter: Resolve flake8 errors --- et_converter/et_converter.py | 123 +++++------ et_converter/flexflow2chakra_converter.py | 40 ++-- et_converter/pytorch2chakra_converter.py | 15 +- et_converter/pytorch_node.py | 8 - et_converter/pytorch_tensor.py | 2 +- et_converter/text2chakra_converter.py | 252 ++++++++++------------ 6 files changed, 191 insertions(+), 249 deletions(-) diff --git a/et_converter/et_converter.py b/et_converter/et_converter.py index 2dd52eee..588d7e70 100644 --- a/et_converter/et_converter.py +++ b/et_converter/et_converter.py @@ -12,8 +12,8 @@ def get_logger(log_filename: str) -> logging.Logger: formatter = logging.Formatter( - "%(levelname)s [%(asctime)s] %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p") + "%(levelname)s [%(asctime)s] %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p") file_handler = FileHandler(log_filename, mode="w") file_handler.setLevel(logging.DEBUG) @@ -32,63 +32,54 @@ def get_logger(log_filename: str) -> logging.Logger: def main() -> None: parser = argparse.ArgumentParser( - description="Execution Trace Converter" - ) + description="Execution Trace Converter") parser.add_argument( - "--input_type", - type=str, - default=None, - required=True, - help="Input execution trace type" - ) + "--input_type", + type=str, + default=None, + required=True, + help="Input execution trace type") parser.add_argument( - "--input_filename", - type=str, - default=None, - required=True, - help="Input execution trace filename" - ) + "--input_filename", + type=str, + default=None, + required=True, + help="Input execution trace filename") parser.add_argument( - "--output_filename", - type=str, - default=None, - required=True, - help="Output Chakra execution trace filename" - ) + "--output_filename", + type=str, + default=None, + required=True, + help="Output Chakra execution trace filename") parser.add_argument( - "--num_dims", - type=int, - default=None, - required=True, - help="Number of dimensions in the network topology" - ) + "--num_dims", + type=int, + default=None, + required=True, + help="Number of dimensions in the network topology") parser.add_argument( - "--num_npus", - type=int, - default=None, - required="Text" in sys.argv, - help="Number of NPUs in a system" - ) + "--num_npus", + type=int, + default=None, + required="Text" in sys.argv, + help="Number of NPUs in a system") parser.add_argument( - "--num_passes", - type=int, - default=None, - required="Text" in sys.argv, - help="Number of training passes" - ) + "--num_passes", + type=int, + default=None, + required="Text" in sys.argv, + help="Number of training passes") parser.add_argument( - "--npu_frequency", - type=int, - default=None, - required="FlexFlow" in sys.argv, - help="NPU frequency in MHz" - ) + "--npu_frequency", + type=int, + default=None, + required="FlexFlow" in sys.argv, + help="NPU frequency in MHz") parser.add_argument( - "--log_filename", - type=str, - default="debug.log", - help="Log filename" - ) + "--log_filename", + type=str, + default="debug.log", + help="Log filename") args = parser.parse_args() logger = get_logger(args.log_filename) @@ -97,27 +88,27 @@ def main() -> None: try: if args.input_type == "Text": converter = Text2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_dims, - args.num_npus, - args.num_passes, - logger) + args.input_filename, + args.output_filename, + args.num_dims, + args.num_npus, + args.num_passes, + logger) converter.convert() elif args.input_type == "FlexFlow": converter = FlexFlow2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_dims, - args.npu_frequency, - logger) + args.input_filename, + args.output_filename, + args.num_dims, + args.npu_frequency, + logger) converter.convert() elif args.input_type == "PyTorch": converter = PyTorch2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_dims, - logger) + args.input_filename, + args.output_filename, + args.num_dims, + logger) converter.convert() else: logger.error(f"{args.input_type} unsupported") diff --git a/et_converter/flexflow2chakra_converter.py b/et_converter/flexflow2chakra_converter.py index abb7884a..af69e630 100644 --- a/et_converter/flexflow2chakra_converter.py +++ b/et_converter/flexflow2chakra_converter.py @@ -36,28 +36,28 @@ def get_label(self, ff_node: Any) -> str: try: label = ff_node.get_attributes()["label"] return label.replace("\"", "")[1:-1] - except: - raise ValueError(f"Cannot retrieve label from a FlexFlow node") + except Exception: + raise ValueError("Cannot retrieve label from a FlexFlow node") def get_id(self, ff_node: Any) -> int: ff_node_name = ff_node.get_name() try: return int(ff_node_name.replace("node", "")) - except: + except Exception: raise ValueError(f"Cannot retrieve id from \"{ff_node_name}\"") def get_npu_id(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[0].strip().split("=")[1]) - except: + except Exception: raise ValueError(f"Cannot retrieve npu_id from \"{label}\"") def get_name(self, ff_node: Any) -> str: label = self.get_label(ff_node) try: return label.split("|")[1].strip() - except: + except Exception: raise ValueError(f"Cannot retrieve name from \"{label}\"") def get_node_type(self, ff_node: Any) -> int: @@ -70,7 +70,7 @@ def get_node_type(self, ff_node: Any) -> int: return COMM_SEND_NODE else: raise ValueError(f"Unsupported node_type, \"{node_type}\"") - except: + except Exception: raise ValueError(f"Cannot retrieve node_type from \"{label}\"") def get_runtime(self, ff_node: Any) -> int: @@ -78,28 +78,28 @@ def get_runtime(self, ff_node: Any) -> int: try: wall_clock_time = float(label.split("|")[4].strip().split("=")[1]) return int(round(wall_clock_time * self.num_cycles_per_sec)) - except: + except Exception: raise ValueError(f"Cannot retrieve runtime from \"{label}\"") def get_comm_src(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[4].strip().split("=")[1]) - except: + except Exception: raise ValueError(f"Cannot retrieve comm_src from \"{label}\"") def get_comm_dst(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[5].strip().split("=")[1]) - except: + except Exception: raise ValueError(f"Cannot retrieve comm_dst from \"{label}\"") def get_comm_size(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[6].strip().split("=")[1]) - except: + except Exception: raise ValueError(f"Cannot retrieve comm_size from \"{label}\"") def convert_FF_node_to_CK_node(self, ff_node: Any) -> Any: @@ -165,7 +165,7 @@ def convert(self) -> None: # communication nodes elif (ck_node.type == COMM_SEND_NODE): if (self.node_id_comm_info_dict[ck_node.id]["comm_src"] == npu_id)\ - or (self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id): + or (self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id): comm_src = self.node_id_comm_info_dict[ck_node.id]["comm_src"] comm_dst = self.node_id_comm_info_dict[ck_node.id]["comm_dst"] comm_key = f"{ck_node.id}_{comm_src}_{comm_dst}" @@ -187,18 +187,12 @@ def convert(self) -> None: ck_comm_node.type = COMM_RECV_NODE ck_comm_node.name += f"_{ck_node.name}" - ck_comm_node.attr.append( - ChakraAttr(name="comm_src", - int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_src"])) - ck_comm_node.attr.append( - ChakraAttr(name="comm_dst", - int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_dst"])) - ck_comm_node.attr.append( - ChakraAttr(name="comm_size", - int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_size"])) - ck_comm_node.attr.append( - ChakraAttr(name="comm_tag", - int64_val=comm_tag)) + ck_comm_node.attr.extend([ + ChakraAttr(name="comm_src", int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_src"]), + ChakraAttr(name="comm_dst", int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_dst"]), + ChakraAttr(name="comm_size", int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_size"]), + ChakraAttr(name="comm_tag", int64_val=comm_tag) + ]) per_npu_comm_nodes += 1 total_comm_nodes += 1 diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index e4db4e30..6225cec1 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 -import bisect import copy import json import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from chakra.third_party.utils.protolib import encodeMessage as encode_message from chakra.et_converter.pytorch_node import PyTorchNodeType, PyTorchNode @@ -12,7 +11,6 @@ GlobalMetadata, Node as ChakraNode, AttributeProto as ChakraAttr, - INVALID_NODE, COMP_NODE, COMM_COLL_NODE, ALL_REDUCE, @@ -164,7 +162,7 @@ def convert(self) -> None: for pytorch_nid, pytorch_node in self.pytorch_nodes.items(): if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\ - or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL): + or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL): chakra_node = self.convert_to_chakra_node(pytorch_node) self.chakra_nodes[chakra_node.id] = chakra_node @@ -180,8 +178,7 @@ def convert(self) -> None: ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size), ChakraAttr(name="involved_dim", - bool_list={"values": [True]*self.num_dims}) - ]) + bool_list={"values": [True] * self.num_dims})]) self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node @@ -352,7 +349,7 @@ def split_cpu_nodes_with_gpu_child(self) -> None: if cpu_node.exclusive_dur > 1: gpu_node = cpu_node.child_gpu cpu_node_first, cpu_node_second, updated_gpu_node =\ - self._split_cpu_node(cpu_node, gpu_node, updated_pytorch_nodes) + self._split_cpu_node(cpu_node, gpu_node, updated_pytorch_nodes) updated_pytorch_nodes[cpu_node_first.id] = copy.deepcopy(cpu_node_first) updated_pytorch_nodes[cpu_node_second.id] = copy.deepcopy(cpu_node_second) updated_pytorch_nodes[updated_gpu_node.id] = copy.deepcopy(updated_gpu_node) @@ -855,13 +852,13 @@ def simulate_execution(self) -> None: (node_id, self.chakra_nodes[node_id]) for node_id in self.chakra_nodes if not self.chakra_nodes[node_id].data_deps and - not self.pytorch_nodes[node_id].is_gpu_op() + not self.pytorch_nodes[node_id].is_gpu_op() ] ready_gpu_nodes = [ (node_id, self.chakra_nodes[node_id]) for node_id in self.chakra_nodes if not self.chakra_nodes[node_id].data_deps and - self.pytorch_nodes[node_id].is_gpu_op() + self.pytorch_nodes[node_id].is_gpu_op() ] ready_cpu_nodes.sort(key=lambda x: x[1].id) ready_gpu_nodes.sort(key=lambda x: x[1].id) diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py index 0ba972ac..47edfbd0 100644 --- a/et_converter/pytorch_node.py +++ b/et_converter/pytorch_node.py @@ -3,14 +3,6 @@ from enum import Enum from typing import Any, Dict, List, Optional -from chakra.et_def.et_def_pb2 import ( - ALL_REDUCE, - ALL_GATHER, - BROADCAST, - ALL_TO_ALL, - REDUCE_SCATTER, -) - class PyTorchNodeType(Enum): CPU_OP = 1 diff --git a/et_converter/pytorch_tensor.py b/et_converter/pytorch_tensor.py index 30c4074d..e46bbc71 100644 --- a/et_converter/pytorch_tensor.py +++ b/et_converter/pytorch_tensor.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Any, List +from typing import List class PyTorchTensor: diff --git a/et_converter/text2chakra_converter.py b/et_converter/text2chakra_converter.py index 6400321d..73573e38 100644 --- a/et_converter/text2chakra_converter.py +++ b/et_converter/text2chakra_converter.py @@ -47,7 +47,7 @@ def __init__( self.bwd_wg_update_time = str(col[11]) self.bwd_wg_comp_node = None self.bwd_wg_comm_node = None - except: + except Exception: raise ValueError(f"Cannot parse the following layer -- \"{line}\"") class Text2ChakraConverter: @@ -132,16 +132,12 @@ def get_comm_coll_node( comm_type: str, comm_size: int ) -> Any: - node = self.get_node( - f"COMM_COLL_NODE_{layer_name}_{comm_type}", - COMM_COLL_NODE) - node.attr.append( - ChakraAttr(name="comm_type", - int64_val=self.get_comm_type(comm_type))) - node.attr.append( - ChakraAttr(name="comm_size", - uint64_val = comm_size) - ) + node = self.get_node(f"COMM_COLL_NODE_{layer_name}_{comm_type}", + COMM_COLL_NODE) + node.attr.append(ChakraAttr(name="comm_type", + int64_val=self.get_comm_type(comm_type))) + node.attr.append(ChakraAttr(name="comm_size", + uint64_val=comm_size)) return node def add_parent( @@ -187,10 +183,9 @@ def convert_microbenchmark( encode_message(g, global_metadata) for i in range(self.num_passes): for layer in layers: - bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_wg_comm_type, + layer.bwd_wg_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): @@ -215,34 +210,31 @@ def convert_data_parallel( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", + layer.fwd_comp_time) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comp_node) - if layer.bwd_wg_comm_node != None: + self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comp_node) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) layer.fwd_comp_node = fwd_comp_node encode_message(g, fwd_comp_node) # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", + layer.bwd_wg_comp_time) if idx == 0: - if fwd_comp_node == None: + if fwd_comp_node is None: raise ValueError("fwd_comp_node is None") self.add_parent(bwd_wg_comp_node, fwd_comp_node) else: self.add_parent(bwd_wg_comp_node, - layers[len(layers)-idx].bwd_ig_comp_node) + layers[len(layers) - idx].bwd_ig_comp_node) encode_message(g, bwd_wg_comp_node) - bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_wg_comm_type, + layer.bwd_wg_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -253,9 +245,8 @@ def convert_data_parallel( encode_message(g, bwd_wg_comm_node) if idx != (len(layers) - 1): - bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", + layer.bwd_ig_comp_time) self.add_parent(bwd_ig_comp_node, bwd_wg_comp_node) layer.bwd_ig_comp_node = bwd_ig_comp_node encode_message(g, bwd_ig_comp_node) @@ -279,20 +270,18 @@ def convert_model_parallel( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", + layer.fwd_comp_time) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comm_node) - if layer.bwd_wg_comp_node != None: + self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comm_node) + if layer.bwd_wg_comp_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comp_node) layer.fwd_comp_node = fwd_comp_node encode_message(g, fwd_comp_node) - fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, + layer.fwd_comm_type, + layer.fwd_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -303,25 +292,23 @@ def convert_model_parallel( # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", + layer.bwd_ig_comp_time) if idx == 0: - if fwd_comm_node == None: + if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_wg_comp_node) + layers[len(layers) - idx].bwd_wg_comp_node) self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_ig_comm_node) + layers[len(layers) - idx].bwd_ig_comm_node) encode_message(g, bwd_ig_comp_node) if idx != (num_layers - 1): - bwd_ig_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_ig_comm_type, + layer.bwd_ig_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -330,9 +317,8 @@ def convert_model_parallel( layer.bwd_ig_comm_node = bwd_ig_comm_node encode_message(g, bwd_ig_comm_node) - bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", + layer.bwd_wg_comp_time) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) @@ -356,22 +342,20 @@ def convert_hybrid_data_model( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) - if layer.bwd_wg_comm_node != None: + fwd_comp_node = self.get_comp_node(layer.name, "FWD", + layer.fwd_comp_time) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comm_node) + self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comm_node) encode_message(g, fwd_comp_node) - fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, + layer.fwd_comm_type, + layer.fwd_comm_size) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(True) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(False) fwd_comm_node.attr.append(attr) self.add_parent(fwd_comm_node, fwd_comp_node) @@ -380,48 +364,44 @@ def convert_hybrid_data_model( # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", + layer.bwd_ig_comp_time) if idx == 0: - if fwd_comm_node == None: + if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_wg_comp_node) + layers[len(layers) - idx].bwd_wg_comp_node) self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_ig_comm_node) + layers[len(layers) - idx].bwd_ig_comm_node) encode_message(g, bwd_ig_comp_node) if idx != num_layers - 1: - bwd_ig_comm_node = self.get_comm_coll_node( - layer.name + "_IG_COMM_", - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node(layer.name + "_IG_COMM_", + layer.bwd_ig_comm_type, + layer.bwd_ig_comm_size) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(True) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(False) bwd_ig_comm_node.attr.append(attr) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layer.bwd_ig_comm_node = bwd_ig_comm_node encode_message(g, bwd_ig_comm_node) - bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", + layer.bwd_wg_comp_time) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) - bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_wg_comm_type, + layer.bwd_wg_comm_size) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(False) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(True) bwd_wg_comm_node.attr.append(attr) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) @@ -447,22 +427,20 @@ def convert_hybrid_model_data( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) - if layer.bwd_wg_comm_node != None: + fwd_comp_node = self.get_comp_node(layer.name, "FWD", + layer.fwd_comp_time) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comm_node) + self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comm_node) encode_message(g, fwd_comp_node) - fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, + layer.fwd_comm_type, + layer.fwd_comm_size) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(False) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(True) fwd_comm_node.attr.append(attr) self.add_parent(fwd_comm_node, fwd_comp_node) @@ -471,46 +449,42 @@ def convert_hybrid_model_data( # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", + layer.bwd_ig_comp_time) if idx == 0: - if fwd_comm_node == None: + if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: - self.add_parent(bwd_ig_comp_node, layers[len(layers)-idx].bwd_wg_comp_node) - self.add_parent(bwd_ig_comp_node, layers[len(layers)-idx].bwd_ig_comm_node) + self.add_parent(bwd_ig_comp_node, layers[len(layers) - idx].bwd_wg_comp_node) + self.add_parent(bwd_ig_comp_node, layers[len(layers) - idx].bwd_ig_comm_node) encode_message(g, bwd_ig_comp_node) if idx != num_layers - 1: - bwd_ig_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_ig_comm_type, + layer.bwd_ig_comm_size) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(False) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(True) bwd_ig_comm_node.attr.append(attr) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layer.bwd_ig_comm_node = bwd_ig_comm_node encode_message(g, bwd_ig_comm_node) - bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", + layer.bwd_wg_comp_time) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) - bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_wg_comm_type, + layer.bwd_wg_comm_size) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(True) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(False) bwd_wg_comm_node.attr.append(attr) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) @@ -537,25 +511,23 @@ def convert_hybrid_dlrm( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) - if layer.bwd_wg_comm_node != None: + fwd_comp_node = self.get_comp_node(layer.name, "FWD", + layer.fwd_comp_time) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) - elif layer.bwd_wg_comp_node != None: + elif layer.bwd_wg_comp_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comp_node) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comp_node) + self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comp_node) if idx == last_bottom_layer: self.add_parent(fwd_comp_node, layers[0].fwd_comm_node) layer.fwd_comp_node = fwd_comp_node encode_message(g, fwd_comp_node) if layer.fwd_comm_type == "ALLTOALL": - fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, + layer.fwd_comm_type, + layer.fwd_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -566,28 +538,26 @@ def convert_hybrid_dlrm( # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", + layer.bwd_wg_comp_time) if idx == 0: - if fwd_comp_node == None: + if fwd_comp_node is None: raise ValueError("fwd_comp_node is None") self.add_parent(bwd_wg_comp_node, fwd_comp_node) else: - if layers[len(layers)-idx].bwd_ig_comp_node != None: + if layers[len(layers) - idx].bwd_ig_comp_node is not None: self.add_parent(bwd_wg_comp_node, - layers[len(layers)-idx].bwd_ig_comp_node) - if layers[len(layers)-idx-1].bwd_ig_comm_node != None: + layers[len(layers) - idx].bwd_ig_comp_node) + if layers[len(layers) - idx - 1].bwd_ig_comm_node is not None: self.add_parent(bwd_wg_comp_node, - layers[len(layers)-idx-1].bwd_ig_comm_node) + layers[len(layers) - idx - 1].bwd_ig_comm_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) if layer.bwd_wg_comm_type != "NONE": - bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node(layer.name, + layer.bwd_wg_comm_type, + layer.bwd_wg_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -598,23 +568,21 @@ def convert_hybrid_dlrm( bwd_ig_comp_node = None if idx != (len(layers) - 1): - bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", + layer.bwd_ig_comp_time) self.add_parent(bwd_ig_comp_node, bwd_wg_comp_node) layer.bwd_ig_comp_node = bwd_ig_comp_node encode_message(g, bwd_ig_comp_node) if (len(layers) - idx - 1) == (last_bottom_layer + 1): - bwd_ig_comm_node = self.get_comm_coll_node( - layers[0].name, - layers[0].bwd_ig_comm_type, - layers[0].bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node(layers[0].name, + layers[0].bwd_ig_comm_type, + layers[0].bwd_ig_comm_size) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) bwd_ig_comm_node.attr.append(attr) - if bwd_ig_comp_node == None: + if bwd_ig_comp_node is None: raise ValueError("bwd_ig_comp_node is None") self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layers[0].bwd_ig_comm_node = bwd_ig_comm_node From f3769de5c9413a24be9f0aee0c79e590c830813f Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:57:17 -0500 Subject: [PATCH 5/8] timeline_visualizer: Resolve flake8 errors --- timeline_visualizer/timeline_visualizer.py | 105 +++++++-------------- 1 file changed, 32 insertions(+), 73 deletions(-) diff --git a/timeline_visualizer/timeline_visualizer.py b/timeline_visualizer/timeline_visualizer.py index afffb527..592d6a27 100644 --- a/timeline_visualizer/timeline_visualizer.py +++ b/timeline_visualizer/timeline_visualizer.py @@ -17,8 +17,8 @@ class TID(IntEnum): def get_logger(log_filename: str) -> logging.Logger: formatter = logging.Formatter( - "%(levelname)s [%(asctime)s] %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p") + "%(levelname)s [%(asctime)s] %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p") file_handler = FileHandler(log_filename, mode="w") file_handler.setLevel(logging.DEBUG) @@ -36,21 +36,17 @@ def get_logger(log_filename: str) -> logging.Logger: return logger def is_local_mem_node(node_name: str) -> bool: - if ("MEM_LOAD_NODE" in node_name)\ - and ("LOCAL_MEMORY" in node_name): + if ("MEM_LOAD_NODE" in node_name) and ("LOCAL_MEMORY" in node_name): return True - elif ("MEM_STORE_NODE" in node_name)\ - and ("LOCAL_MEMORY" in node_name): + elif ("MEM_STORE_NODE" in node_name) and ("LOCAL_MEMORY" in node_name): return True else: return False def is_remote_mem_node(node_name: str) -> bool: - if ("MEM_LOAD_NODE" in node_name)\ - and ("REMOTE_MEMORY" in node_name): + if ("MEM_LOAD_NODE" in node_name) and ("REMOTE_MEMORY" in node_name): return True - elif ("MEM_STORE_NODE" in node_name)\ - and ("REMOTE_MEMORY" in node_name): + elif ("MEM_STORE_NODE" in node_name) and ("REMOTE_MEMORY" in node_name): return True else: return False @@ -92,8 +88,8 @@ def parse_event( node_id = int(cols[3].split("=")[1]) node_name = cols[4].split("=")[1] return (trace_type, npu_id, curr_cycle, node_id, node_name) - except: - raise ValueError(f"Cannot parse the following event -- \"{line}\"") + except Exception as e: + raise ValueError(f"Cannot parse the following event -- \"{line}\": {e}") def get_trace_events( input_filename: str, @@ -106,12 +102,10 @@ def get_trace_events( with open(input_filename, "r") as f: for line in f: if ("issue" in line) or ("callback" in line): - (trace_type, npu_id, curr_cycle, node_id, node_name) =\ - parse_event(line) + (trace_type, npu_id, curr_cycle, node_id, node_name) = parse_event(line) if trace_type == "issue": - trace_dict[npu_id].update( - {node_id: [node_name, curr_cycle]}) + trace_dict[npu_id].update({node_id: [node_name, curr_cycle]}) elif trace_type == "callback": node_name = trace_dict[npu_id][node_id][0] tid = get_tid(node_name) @@ -120,16 +114,15 @@ def get_trace_events( duration_in_cycles = curr_cycle - issued_cycle duration_in_ms = duration_in_cycles / (npu_frequency * 1_000) - trace_events.append( - { - "pid": npu_id, - "tid": tid, - "ts": issued_ms, - "dur": duration_in_ms, - "ph": "X", - "name": node_name, - "args": {"ms": duration_in_ms} - }) + trace_events.append({ + "pid": npu_id, + "tid": tid, + "ts": issued_ms, + "dur": duration_in_ms, + "ph": "X", + "name": node_name, + "args": {"ms": duration_in_ms} + }) del trace_dict[npu_id][node_id] else: @@ -137,7 +130,6 @@ def get_trace_events( return trace_events - def write_trace_events( output_filename: str, num_npus: int, @@ -146,64 +138,31 @@ def write_trace_events( output_dict = { "meta_user": "aras", "traceEvents": trace_events, - "meta_user": "aras", "meta_cpu_count": num_npus } with open(output_filename, "w") as f: json.dump(output_dict, f) def main() -> None: - parser = argparse.ArgumentParser( - description="Timeline Visualizer" - ) - parser.add_argument( - "--input_filename", - type=str, - default=None, - required=True, - help="Input timeline filename" - ) - parser.add_argument( - "--output_filename", - type=str, - default=None, - required=True, - help="Output trace filename" - ) - parser.add_argument( - "--num_npus", - type=int, - default=None, - required=True, - help="Number of NPUs in a system" - ) - parser.add_argument( - "--npu_frequency", - type=int, - default=None, - required=True, - help="NPU frequency in MHz" - ) - parser.add_argument( - "--log_filename", - type=str, - default="debug.log", - help="Log filename" - ) + parser = argparse.ArgumentParser(description="Timeline Visualizer") + parser.add_argument("--input_filename", type=str, default=None, required=True, + help="Input timeline filename") + parser.add_argument("--output_filename", type=str, default=None, required=True, + help="Output trace filename") + parser.add_argument("--num_npus", type=int, default=None, required=True, + help="Number of NPUs in a system") + parser.add_argument("--npu_frequency", type=int, default=None, required=True, + help="NPU frequency in MHz") + parser.add_argument("--log_filename", type=str, default="debug.log", + help="Log filename") args = parser.parse_args() logger = get_logger(args.log_filename) logger.debug(" ".join(sys.argv)) try: - trace_events = get_trace_events( - args.input_filename, - args.num_npus, - args.npu_frequency) - write_trace_events( - args.output_filename, - args.num_npus, - trace_events) + trace_events = get_trace_events(args.input_filename, args.num_npus, args.npu_frequency) + write_trace_events(args.output_filename, args.num_npus, trace_events) except Exception as e: logger.error(str(e)) sys.exit(1) From 813a223ecee57accd40f6e996c1597d776ea1ab8 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:08:07 -0500 Subject: [PATCH 6/8] et_visualizer: Resolve flake8 errors --- et_visualizer/et_visualizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/et_visualizer/et_visualizer.py b/et_visualizer/et_visualizer.py index 4216c1fc..bf4bd54f 100644 --- a/et_visualizer/et_visualizer.py +++ b/et_visualizer/et_visualizer.py @@ -26,7 +26,7 @@ def escape_label(label: str) -> str: str: The escaped label string. """ # Define special characters to escape - special_chars = "{}()<>\[\]|&-" + special_chars = "{}()<>\\[\\]|&-" # Escape special characters return re.sub(f"([{special_chars}])", r"\\\1", label) From e536702f8cd26c04d99c873055bf6cb312e237e7 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 12:04:29 -0500 Subject: [PATCH 7/8] et_converter: Resolve pyre errors --- et_converter/flexflow2chakra_converter.py | 9 +++++---- et_converter/pytorch2chakra_converter.py | 9 ++++++--- et_converter/text2chakra_converter.py | 3 ++- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/et_converter/flexflow2chakra_converter.py b/et_converter/flexflow2chakra_converter.py index af69e630..282999da 100644 --- a/et_converter/flexflow2chakra_converter.py +++ b/et_converter/flexflow2chakra_converter.py @@ -7,6 +7,7 @@ from chakra.third_party.utils.protolib import encodeMessage as encode_message from chakra.et_def.et_def_pb2 import ( + NodeType as ChakraNodeType, Node as ChakraNode, AttributeProto as ChakraAttr, COMP_NODE, @@ -60,7 +61,7 @@ def get_name(self, ff_node: Any) -> str: except Exception: raise ValueError(f"Cannot retrieve name from \"{label}\"") - def get_node_type(self, ff_node: Any) -> int: + def get_node_type(self, ff_node: Any) -> ChakraNodeType: label = self.get_label(ff_node) try: node_type = label.split("|")[3].strip() @@ -137,7 +138,7 @@ def convert(self) -> None: src_id = int(edge.get_source().replace("node", "")) dst_id = int(edge.get_destination().replace("node", "")) ck_node = self.node_id_node_dict[dst_id] - ck_node.parent.append(src_id) + ck_node.data_deps.append(src_id) num_ff_edges += 1 self.logger.info(f"Converted {num_ff_nodes} nodes and {num_ff_edges} edges") @@ -198,10 +199,10 @@ def convert(self) -> None: total_comm_nodes += 1 # transfer dependencies - for parent_node_id in ck_node.parent: + for parent_node_id in ck_node.data_deps: parent_node = self.node_id_node_dict[parent_node_id] if self.node_id_npu_id_dict[parent_node.id] == npu_id: - ck_comm_node.parent.append(parent_node_id) + ck_comm_node.data_deps.append(parent_node_id) npu_id_node_id_node_dict[npu_id].update({node_id: ck_comm_node}) self.logger.info(f"NPU[{npu_id}]: {per_npu_comp_nodes} compute nodes and {per_npu_comm_nodes} communication nodes") diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 6225cec1..43a80730 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -3,12 +3,13 @@ import copy import json import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Set +from .pytorch_node import PyTorchNodeType, PyTorchNode from chakra.third_party.utils.protolib import encodeMessage as encode_message -from chakra.et_converter.pytorch_node import PyTorchNodeType, PyTorchNode from chakra.et_def.et_def_pb2 import ( GlobalMetadata, + NodeType as ChakraNodeType, Node as ChakraNode, AttributeProto as ChakraAttr, COMP_NODE, @@ -88,6 +89,7 @@ class PyTorch2ChakraConverter: Attributes: input_filename (str): Input file name containing PyTorch execution trace. output_filename (str): Output file name for the converted Chakra trace. + chakra_et(IO[bytes]): File handle for the Chakra execution trace output file. num_dims (int): Number of dimensions involved in the conversion process. logger (logging.Logger): Logger for logging information during conversion. id_assigner (UniqueIdAssigner): Object to manage unique ID assignments. @@ -127,6 +129,7 @@ def __init__( """ self.input_filename = input_filename self.output_filename = output_filename + self.chakra_et = None self.num_dims = num_dims self.logger = logger self.id_assigner = UniqueIdAssigner() @@ -507,7 +510,7 @@ def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: ]) return chakra_node - def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> int: + def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> ChakraNodeType: """ Determines the Chakra node type from a PyTorch node. diff --git a/et_converter/text2chakra_converter.py b/et_converter/text2chakra_converter.py index 73573e38..892e95b6 100644 --- a/et_converter/text2chakra_converter.py +++ b/et_converter/text2chakra_converter.py @@ -6,6 +6,7 @@ from typing import Any, List from chakra.third_party.utils.protolib import encodeMessage as encode_message from chakra.et_def.et_def_pb2 import ( + NodeType, Node, AttributeProto as ChakraAttr, COMP_NODE, @@ -92,7 +93,7 @@ def get_layers( def get_node( self, name: str, - node_type: int + node_type: NodeType ) -> Any: node = Node() node.id = self.next_node_id From adc5f9a6007b468428164e3fd7b5dcabeb11f769 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:03:05 -0500 Subject: [PATCH 8/8] lint: Add github action for python linting --- .github/workflows/python_lint.yml | 27 +++++++++++++++++++++++++++ .pyre_configuration | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/python_lint.yml diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml new file mode 100644 index 00000000..243d479f --- /dev/null +++ b/.github/workflows/python_lint.yml @@ -0,0 +1,27 @@ +name: Python Linting + +on: [push, pull_request] + +jobs: + lint-and-format: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - name: Install dependencies + run: | + pip install flake8 + pip install pyre-check + pip install . + + - name: Run Flake8 + run: flake8 . + + - name: Run Pyre Check + run: pyre check diff --git a/.pyre_configuration b/.pyre_configuration index 69234f7f..ee511cff 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -3,5 +3,5 @@ "timeline_visualizer", "et_converter" ], - "search_path": [] + "search_path": ["/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages"] }