diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..639bce4a --- /dev/null +++ b/.flake8 @@ -0,0 +1,35 @@ +[flake8] +enable-extensions = G +select = B,C,E,F,G,P,SIM1,T4,W,B9,TOR0,TOR1,TOR2 +max-line-length = 120 +# C408 ignored because we like the dict keyword argument syntax +# E501 is not flexible enough, we're using B950 instead +ignore = + E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, + # fix these lints in the future + E275, + # shebang has extra meaning in fbcode lints, so I think it's not worth trying + # to line this up with executable bit + EXE001, + # these ignores are from flake8-bugbear; please fix! + B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907 + # these ignores are from flake8-comprehensions; please fix! + C407, + # these ignores are from flake8-logging-format; please fix! + G100,G101,G200,G201,G202 + # these ignores are from flake8-simplify. please fix or ignore with commented reason + SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, + # flake8-simplify code styles + SIM102,SIM103,SIM106,SIM112, + # TorchFix codes that don't make sense for PyTorch itself: + # removed and deprecated PyTorch functions. + TOR001,TOR101, + # TODO(kit1980): fix all TOR102 issues + # `torch.load` without `weights_only` parameter is unsafe + TOR102, + P201, +per-file-ignores = + __init__.py: F401 +optional-ascii-coding = True +exclude = + ./.git, diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml new file mode 100644 index 00000000..a0d3208e --- /dev/null +++ b/.github/workflows/python_lint.yml @@ -0,0 +1,25 @@ +name: Python Linting + +on: [push, pull_request] + +jobs: + lint-and-format: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - name: Install dependencies + run: | + pip install black flake8 + + - name: Check code formatting with Black + run: black --check . + + - name: Run Flake8 + run: flake8 . diff --git a/et_converter/et_converter.py b/et_converter/et_converter.py index 2dd52eee..5fb4f70e 100644 --- a/et_converter/et_converter.py +++ b/et_converter/et_converter.py @@ -10,10 +10,11 @@ from .flexflow2chakra_converter import FlexFlow2ChakraConverter from .pytorch2chakra_converter import PyTorch2ChakraConverter + def get_logger(log_filename: str) -> logging.Logger: formatter = logging.Formatter( - "%(levelname)s [%(asctime)s] %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p") + "%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p" + ) file_handler = FileHandler(log_filename, mode="w") file_handler.setLevel(logging.DEBUG) @@ -30,64 +31,60 @@ def get_logger(log_filename: str) -> logging.Logger: return logger + def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Converter" - ) + parser = argparse.ArgumentParser(description="Execution Trace Converter") parser.add_argument( - "--input_type", - type=str, - default=None, - required=True, - help="Input execution trace type" + "--input_type", + type=str, + default=None, + required=True, + help="Input execution trace type", ) parser.add_argument( - "--input_filename", - type=str, - default=None, - required=True, - help="Input execution trace filename" + "--input_filename", + type=str, + default=None, + required=True, + help="Input execution trace filename", ) parser.add_argument( - "--output_filename", - type=str, - default=None, - required=True, - help="Output Chakra execution trace filename" + "--output_filename", + type=str, + default=None, + required=True, + help="Output Chakra execution trace filename", ) parser.add_argument( - "--num_dims", - type=int, - default=None, - required=True, - help="Number of dimensions in the network topology" + "--num_dims", + type=int, + default=None, + required=True, + help="Number of dimensions in the network topology", ) parser.add_argument( - "--num_npus", - type=int, - default=None, - required="Text" in sys.argv, - help="Number of NPUs in a system" + "--num_npus", + type=int, + default=None, + required="Text" in sys.argv, + help="Number of NPUs in a system", ) parser.add_argument( - "--num_passes", - type=int, - default=None, - required="Text" in sys.argv, - help="Number of training passes" + "--num_passes", + type=int, + default=None, + required="Text" in sys.argv, + help="Number of training passes", ) parser.add_argument( - "--npu_frequency", - type=int, - default=None, - required="FlexFlow" in sys.argv, - help="NPU frequency in MHz" + "--npu_frequency", + type=int, + default=None, + required="FlexFlow" in sys.argv, + help="NPU frequency in MHz", ) parser.add_argument( - "--log_filename", - type=str, - default="debug.log", - help="Log filename" + "--log_filename", type=str, default="debug.log", help="Log filename" ) args = parser.parse_args() @@ -97,27 +94,27 @@ def main() -> None: try: if args.input_type == "Text": converter = Text2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_dims, - args.num_npus, - args.num_passes, - logger) + args.input_filename, + args.output_filename, + args.num_dims, + args.num_npus, + args.num_passes, + logger, + ) converter.convert() elif args.input_type == "FlexFlow": converter = FlexFlow2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_dims, - args.npu_frequency, - logger) + args.input_filename, + args.output_filename, + args.num_dims, + args.npu_frequency, + logger, + ) converter.convert() elif args.input_type == "PyTorch": converter = PyTorch2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_dims, - logger) + args.input_filename, args.output_filename, args.num_dims, logger + ) converter.convert() else: logger.error(f"{args.input_type} unsupported") @@ -127,5 +124,6 @@ def main() -> None: logger.debug(traceback.format_exc()) sys.exit(1) + if __name__ == "__main__": main() diff --git a/et_converter/flexflow2chakra_converter.py b/et_converter/flexflow2chakra_converter.py index abb7884a..660836c6 100644 --- a/et_converter/flexflow2chakra_converter.py +++ b/et_converter/flexflow2chakra_converter.py @@ -14,6 +14,7 @@ COMM_RECV_NODE, ) + class FlexFlow2ChakraConverter: def __init__( self, @@ -21,7 +22,7 @@ def __init__( output_filename: str, num_dims: int, npu_frequency: int, - logger: logging.Logger + logger: logging.Logger, ) -> None: self.input_filename = input_filename self.output_filename = output_filename @@ -35,30 +36,30 @@ def __init__( def get_label(self, ff_node: Any) -> str: try: label = ff_node.get_attributes()["label"] - return label.replace("\"", "")[1:-1] - except: - raise ValueError(f"Cannot retrieve label from a FlexFlow node") + return label.replace('"', "")[1:-1] + except Exception: + raise ValueError("Cannot retrieve label from a FlexFlow node") def get_id(self, ff_node: Any) -> int: ff_node_name = ff_node.get_name() try: return int(ff_node_name.replace("node", "")) - except: - raise ValueError(f"Cannot retrieve id from \"{ff_node_name}\"") + except Exception: + raise ValueError(f'Cannot retrieve id from "{ff_node_name}"') def get_npu_id(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[0].strip().split("=")[1]) - except: - raise ValueError(f"Cannot retrieve npu_id from \"{label}\"") + except Exception: + raise ValueError(f'Cannot retrieve npu_id from "{label}"') def get_name(self, ff_node: Any) -> str: label = self.get_label(ff_node) try: return label.split("|")[1].strip() - except: - raise ValueError(f"Cannot retrieve name from \"{label}\"") + except Exception: + raise ValueError(f'Cannot retrieve name from "{label}"') def get_node_type(self, ff_node: Any) -> int: label = self.get_label(ff_node) @@ -69,38 +70,38 @@ def get_node_type(self, ff_node: Any) -> int: elif node_type == "COMM_SEND_RECV_NODE": return COMM_SEND_NODE else: - raise ValueError(f"Unsupported node_type, \"{node_type}\"") - except: - raise ValueError(f"Cannot retrieve node_type from \"{label}\"") + raise ValueError(f'Unsupported node_type, "{node_type}"') + except Exception: + raise ValueError(f'Cannot retrieve node_type from "{label}"') def get_runtime(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: wall_clock_time = float(label.split("|")[4].strip().split("=")[1]) return int(round(wall_clock_time * self.num_cycles_per_sec)) - except: - raise ValueError(f"Cannot retrieve runtime from \"{label}\"") + except Exception: + raise ValueError(f'Cannot retrieve runtime from "{label}"') def get_comm_src(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[4].strip().split("=")[1]) - except: - raise ValueError(f"Cannot retrieve comm_src from \"{label}\"") + except Exception: + raise ValueError(f'Cannot retrieve comm_src from "{label}"') def get_comm_dst(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[5].strip().split("=")[1]) - except: - raise ValueError(f"Cannot retrieve comm_dst from \"{label}\"") + except Exception: + raise ValueError(f'Cannot retrieve comm_dst from "{label}"') def get_comm_size(self, ff_node: Any) -> int: label = self.get_label(ff_node) try: return int(label.split("|")[6].strip().split("=")[1]) - except: - raise ValueError(f"Cannot retrieve comm_size from \"{label}\"") + except Exception: + raise ValueError(f'Cannot retrieve comm_size from "{label}"') def convert_FF_node_to_CK_node(self, ff_node: Any) -> Any: ck_node = ChakraNode() @@ -111,9 +112,15 @@ def convert_FF_node_to_CK_node(self, ff_node: Any) -> Any: ck_node.duration_micros = self.get_runtime(ff_node) elif ck_node.type == COMM_SEND_NODE: self.node_id_comm_info_dict[ck_node.id] = {} - self.node_id_comm_info_dict[ck_node.id]["comm_src"] = self.get_comm_src(ff_node) - self.node_id_comm_info_dict[ck_node.id]["comm_dst"] = self.get_comm_dst(ff_node) - self.node_id_comm_info_dict[ck_node.id]["comm_size"] = self.get_comm_size(ff_node) + self.node_id_comm_info_dict[ck_node.id]["comm_src"] = self.get_comm_src( + ff_node + ) + self.node_id_comm_info_dict[ck_node.id]["comm_dst"] = self.get_comm_dst( + ff_node + ) + self.node_id_comm_info_dict[ck_node.id]["comm_size"] = self.get_comm_size( + ff_node + ) self.node_id_npu_id_dict.update({ck_node.id: self.get_npu_id(ff_node)}) return ck_node @@ -163,9 +170,12 @@ def convert(self) -> None: total_comp_nodes += 1 # communication nodes - elif (ck_node.type == COMM_SEND_NODE): - if (self.node_id_comm_info_dict[ck_node.id]["comm_src"] == npu_id)\ - or (self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id): + elif ck_node.type == COMM_SEND_NODE: + if ( + self.node_id_comm_info_dict[ck_node.id]["comm_src"] == npu_id + ) or ( + self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id + ): comm_src = self.node_id_comm_info_dict[ck_node.id]["comm_src"] comm_dst = self.node_id_comm_info_dict[ck_node.id]["comm_dst"] comm_key = f"{ck_node.id}_{comm_src}_{comm_dst}" @@ -179,26 +189,47 @@ def convert(self) -> None: # create a new communication node ck_comm_node = ChakraNode() ck_comm_node.id = ck_node.id - if self.node_id_comm_info_dict[ck_node.id]["comm_src"] == npu_id: + if ( + self.node_id_comm_info_dict[ck_node.id]["comm_src"] + == npu_id + ): ck_comm_node.name = "COMM_SEND_NODE" ck_comm_node.type = COMM_SEND_NODE - elif self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id: + elif ( + self.node_id_comm_info_dict[ck_node.id]["comm_dst"] + == npu_id + ): ck_comm_node.name = "COMM_RECV_NODE" ck_comm_node.type = COMM_RECV_NODE ck_comm_node.name += f"_{ck_node.name}" ck_comm_node.attr.append( - ChakraAttr(name="comm_src", - int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_src"])) + ChakraAttr( + name="comm_src", + int64_val=self.node_id_comm_info_dict[ck_node.id][ + "comm_src" + ], + ) + ) ck_comm_node.attr.append( - ChakraAttr(name="comm_dst", - int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_dst"])) + ChakraAttr( + name="comm_dst", + int64_val=self.node_id_comm_info_dict[ck_node.id][ + "comm_dst" + ], + ) + ) ck_comm_node.attr.append( - ChakraAttr(name="comm_size", - int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_size"])) + ChakraAttr( + name="comm_size", + int64_val=self.node_id_comm_info_dict[ck_node.id][ + "comm_size" + ], + ) + ) ck_comm_node.attr.append( - ChakraAttr(name="comm_tag", - int64_val=comm_tag)) + ChakraAttr(name="comm_tag", int64_val=comm_tag) + ) per_npu_comm_nodes += 1 total_comm_nodes += 1 @@ -210,8 +241,12 @@ def convert(self) -> None: ck_comm_node.parent.append(parent_node_id) npu_id_node_id_node_dict[npu_id].update({node_id: ck_comm_node}) - self.logger.info(f"NPU[{npu_id}]: {per_npu_comp_nodes} compute nodes and {per_npu_comm_nodes} communication nodes") - self.logger.info(f"Total: {total_comp_nodes} compute nodes and {total_comm_nodes} communication nodes") + self.logger.info( + f"NPU[{npu_id}]: {per_npu_comp_nodes} compute nodes and {per_npu_comm_nodes} communication nodes" + ) + self.logger.info( + f"Total: {total_comp_nodes} compute nodes and {total_comm_nodes} communication nodes" + ) # write per-NPU Chakra graphs for npu_id in sorted(npu_id_node_id_node_dict.keys()): diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 2a1ece74..54ba9c1e 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -27,7 +27,7 @@ class PyTorchNodeType(Enum): CPU_OP = 1 GPU_OP = 2 - LABEL = 3 # Non-operator nodes + LABEL = 3 # Non-operator nodes class PyTorch2ChakraConverter: @@ -36,7 +36,7 @@ def __init__( input_filename: str, output_filename: str, num_dims: int, - logger: logging.Logger + logger: logging.Logger, ) -> None: try: self.pytorch_et = open(input_filename, "r") @@ -94,11 +94,13 @@ def __init__( # Otherwise, a tensor ID should be utilized. # --------------------------------------------------------------------- # Mapping between storage_id and nid - self.input_storage_id_nid_dict = {} # storage_id is an input of a node with nid - self.output_storage_id_nid_dict = {} # storage_id is an output of a node with nid + self.input_storage_id_nid_dict = {} # storage_id is an input of a node with nid + self.output_storage_id_nid_dict = ( + {} + ) # storage_id is an output of a node with nid # Mapping between tensor_id and nid - self.input_tensor_id_nid_dict = {} # tensor_id is an input of a node with nid - self.output_tensor_id_nid_dict = {} # tensor_id is an output of a node with nid + self.input_tensor_id_nid_dict = {} # tensor_id is an input of a node with nid + self.output_tensor_id_nid_dict = {} # tensor_id is an output of a node with nid def __del__(self): if self.pytorch_et and not self.pytorch_et.closed: @@ -107,9 +109,7 @@ def __del__(self): self.chakra_et.close() @staticmethod - def is_valid_tensor( - obj: Any - ) -> bool: + def is_valid_tensor(obj: Any) -> bool: """ Returns true if a given object is a valid tensor. @@ -119,9 +119,7 @@ def is_valid_tensor( return isinstance(obj, list) and (len(obj) == 6) @staticmethod - def get_storage_id_from_tensor( - tensor: List[Any] - ) -> int: + def get_storage_id_from_tensor(tensor: List[Any]) -> int: """ Returns the storage ID of a tensor. """ @@ -130,9 +128,7 @@ def get_storage_id_from_tensor( return tensor[1] @staticmethod - def get_tensor_id_from_tensor( - tensor: List[Any] - ) -> int: + def get_tensor_id_from_tensor(tensor: List[Any]) -> int: """ Returns the tensor ID of a tensor. """ @@ -140,10 +136,7 @@ def get_tensor_id_from_tensor( raise IndexError("Index out of bounds") return tensor[0] - def has_valid_storage_id( - self, - tensor: List[Any] - ) -> bool: + def has_valid_storage_id(self, tensor: List[Any]) -> bool: """ Returns true if a given tensor has a valid storage ID. @@ -154,85 +147,66 @@ def has_valid_storage_id( return storage_id > 0 @staticmethod - def has_cat_field( - node: Dict[str, Any] - ) -> bool: + def has_cat_field(node: Dict[str, Any]) -> bool: """ Returns true if a PyTorch node has a category field. """ return "cat" in node.keys() @staticmethod - def get_cat_field( - node: Dict[str, Any] - ) -> bool: + def get_cat_field(node: Dict[str, Any]) -> bool: """ Returns the category field of a given PyTorch node. """ return node["cat"] @staticmethod - def has_dur( - node: Dict[str, Any] - ) -> bool: + def has_dur(node: Dict[str, Any]) -> bool: """ Returns true if a PyTorch node has a duration field. """ return "dur" in node.keys() - def get_pytorch_node_type( - self, - node: Dict[str, Any] - ) -> PyTorchNodeType: + def get_pytorch_node_type(self, node: Dict[str, Any]) -> PyTorchNodeType: if self.is_gpu_op(node): return PyTorchNodeType.GPU_OP - elif (node["op_schema"] or node["outputs"])\ - or ("c10d::" in node["name"] or ("nccl:" in node["name"])): + elif (node["op_schema"] or node["outputs"]) or ( + "c10d::" in node["name"] or ("nccl:" in node["name"]) + ): return PyTorchNodeType.CPU_OP else: return PyTorchNodeType.LABEL @staticmethod - def is_record_param_comms_node( - node: Dict[str, Any] - ) -> bool: + def is_record_param_comms_node(node: Dict[str, Any]) -> bool: """ Returns true if a PyToch node has "record_param_comms" in its name. """ return "record_param_comms" in node["name"] @staticmethod - def is_nccl_node( - node: Dict[str, Any] - ) -> bool: + def is_nccl_node(node: Dict[str, Any]) -> bool: """ Returns true if a PyToch node is a NCCL node. """ return "nccl:" in node["name"] - def is_cpu_op_with_dur( - self, - node: Dict[str, Any] - ) -> bool: + def is_cpu_op_with_dur(self, node: Dict[str, Any]) -> bool: """ Returns true if a PyTorch node is a CPU operator and has a duration field. """ - return (self.get_pytorch_node_type(node) == PyTorchNodeType.CPU_OP)\ - and self.has_dur(node) + return ( + self.get_pytorch_node_type(node) == PyTorchNodeType.CPU_OP + ) and self.has_dur(node) - def is_cpu_op( - self, - node: Dict[str, Any] - ) -> bool: + def is_cpu_op(self, node: Dict[str, Any]) -> bool: """ Takes a PyTorch node and returns true if the node is a CPU operator. """ return self.get_pytorch_node_type(node) == PyTorchNodeType.CPU_OP @staticmethod - def get_collective_comm_type( - node: Dict[str, Any] - ) -> int: + def get_collective_comm_type(node: Dict[str, Any]) -> int: """ Returns the collective communication type of a given PyTorch node. """ @@ -252,9 +226,7 @@ def get_collective_comm_type( return INVALID_COMM @staticmethod - def get_data_type_size( - data_type: str - ) -> int: + def get_data_type_size(data_type: str) -> int: """ Returns the data type size of a given data type in string. @@ -263,37 +235,34 @@ def get_data_type_size( * https://github.com/pytorch/pytorch/blob/master/c10/util/Half.h """ data_type_size_dict = { - "Tensor(float32)": 4, - "Tensor(float)": 4, - "Tensor(float64)": 8, - "Tensor(double)": 8, - "Tensor(float16)": 2, - "Tensor(half)": 2, - "Tensor(bfloat16)": 2, - "Tensor(complex64)": 8, - "Tensor(complex128)": 16, - "Tensor(uint8)": 1, - "Tensor(int8)": 1, - "Tensor(int16)": 2, - "Tensor(short)": 2, - "Tensor(int32)": 4, - "Tensor(int)": 4, - "Tensor(int64)": 8, - "Tensor(long)": 8, - "Tensor(c10::Half)": 2, - "Tensor(unsigned char)": 1, - "Tensor(long int)": 8, + "Tensor(float32)": 4, + "Tensor(float)": 4, + "Tensor(float64)": 8, + "Tensor(double)": 8, + "Tensor(float16)": 2, + "Tensor(half)": 2, + "Tensor(bfloat16)": 2, + "Tensor(complex64)": 8, + "Tensor(complex128)": 16, + "Tensor(uint8)": 1, + "Tensor(int8)": 1, + "Tensor(int16)": 2, + "Tensor(short)": 2, + "Tensor(int32)": 4, + "Tensor(int)": 4, + "Tensor(int64)": 8, + "Tensor(long)": 8, + "Tensor(c10::Half)": 2, + "Tensor(unsigned char)": 1, + "Tensor(long int)": 8, } try: data_type_size = data_type_size_dict[data_type] return data_type_size - except: + except Exception: raise ValueError(f"{data_type} is unsupported") - def get_chakra_node_type_from_pytorch_node( - self, - node: Dict[str, Any] - ) -> int: + def get_chakra_node_type_from_pytorch_node(self, node: Dict[str, Any]) -> int: if self.has_cat_field(node) and ("ncclKernel" in node["name"]): return COMM_COLL_NODE elif self.has_cat_field(node): @@ -304,19 +273,13 @@ def get_chakra_node_type_from_pytorch_node( return COMP_NODE return INVALID_NODE - def has_gpu_op( - self, - nid: int - ) -> bool: + def has_gpu_op(self, nid: int) -> bool: """ Returns true if a Chakra node has any associated GPU operator. """ return nid in self.pt_gpu_node_dict.keys() - def get_comm_size( - self, - node: Dict[str, Any] - ) -> int: + def get_comm_size(self, node: Dict[str, Any]) -> int: """ Calculates the communication size for a given input_type and input_shape. """ @@ -328,9 +291,7 @@ def get_comm_size( comm_size = comm_size * input_shape_inner return comm_size - def sort_pytorch_nodes_with_starting_time( - self - ) -> None: + def sort_pytorch_nodes_with_starting_time(self) -> None: """ Sorts PyTorch nodes with their starting time ("ts"). @@ -338,17 +299,14 @@ def sort_pytorch_nodes_with_starting_time( """ self.pt_nodes = sorted(self.pt_nodes, key=lambda kv: kv["ts"]) - def get_total_runtime_ms( - self, - pt_node_list: List[Any] - ) -> int: + def get_total_runtime_ms(self, pt_node_list: List[Any]) -> int: """ Returns the total runtime of PyTorch CPU operators with a duration field. """ total_runtime_ms = 0 for pt_node in pt_node_list: if self.is_cpu_op_with_dur(pt_node): - total_runtime_ms += pt_node["dur"] # in milliseconds + total_runtime_ms += pt_node["dur"] # in milliseconds return total_runtime_ms def get_prev_inter_phase_dep_nid( @@ -370,9 +328,7 @@ def get_prev_inter_phase_dep_nid( return self.inter_phase_dependency[index - 1] @staticmethod - def find_root_nids( - nodes: List[Any] - ) -> int: + def find_root_nids(nodes: List[Any]) -> int: """ Finds a root node and return its NID. @@ -387,9 +343,7 @@ def find_root_nids( return root_nids @staticmethod - def is_label_node( - node: Dict[str, Any] - ) -> bool: + def is_label_node(node: Dict[str, Any]) -> bool: """ Returns true if a given PyTorch node is a label node. @@ -397,17 +351,10 @@ def is_label_node( """ return node["name"].startswith("## ") - def is_phase_root_node( - self, - root_nids: List[int], - node: Dict[str, Any] - ) -> bool: + def is_phase_root_node(self, root_nids: List[int], node: Dict[str, Any]) -> bool: return node["parent"] in root_nids - def is_gpu_op( - self, - node: Dict[str, Any] - ) -> bool: + def is_gpu_op(self, node: Dict[str, Any]) -> bool: """ Takes a PyTorch node and returns true if it is a GPU operator. @@ -463,9 +410,7 @@ def dfs( raise ValueError(f"Invalid node type: {node_type}") return -1 - def discover_pytorch_cpu_ops( - self - ) -> None: + def discover_pytorch_cpu_ops(self) -> None: """ Discovers PyTorch CPU operators and populate pt_cpu_node_dict. @@ -484,9 +429,9 @@ def discover_pytorch_cpu_ops( def assign_chakra_ids( self, - total_assigned_ids: Dict[int,bool], + total_assigned_ids: Dict[int, bool], assigned_ids: List[int], - initial_id_to_assign: int + initial_id_to_assign: int, ) -> int: """ This function is used to assign unique ids to the ops. During the conversion, we may decompose an op into multiple @@ -521,43 +466,56 @@ def merge_gpu_ops_with_cpu_ops( decomposed_nodes_dep = {} for nid, node in self.pt_cpu_node_dict.items(): if self.has_gpu_op(nid): - self.pt_gpu_node_dict[nid] = sorted(self.pt_gpu_node_dict[nid], key=lambda kv: kv["ts"]) + self.pt_gpu_node_dict[nid] = sorted( + self.pt_gpu_node_dict[nid], key=lambda kv: kv["ts"] + ) for gpu_node in self.pt_gpu_node_dict[nid]: assert (node["ts"] + node["dur"]) > gpu_node["ts"] last_ts = node["ts"] - for i in range(len(self.pt_gpu_node_dict[nid])+1): + for i in range(len(self.pt_gpu_node_dict[nid]) + 1): copy_node = copy.deepcopy(node) - copy_node["id"] = self.assign_chakra_ids(total_assigned_ids, assigned_ids, nid) - copy_node["name"] = copy_node["name"]+"("+str(i)+")" + copy_node["id"] = self.assign_chakra_ids( + total_assigned_ids, assigned_ids, nid + ) + copy_node["name"] = copy_node["name"] + "(" + str(i) + ")" if i < len(self.pt_gpu_node_dict[nid]): - self.pt_gpu_node_dict[nid][i]["id"] =\ - self.assign_chakra_ids( - total_assigned_ids, - assigned_ids, - self.pt_gpu_node_dict[nid][i]["id"]) + self.pt_gpu_node_dict[nid][i]["id"] = self.assign_chakra_ids( + total_assigned_ids, + assigned_ids, + self.pt_gpu_node_dict[nid][i]["id"], + ) assert self.pt_gpu_node_dict[nid][i]["ts"] > copy_node["ts"] copy_node["ts"] = last_ts - copy_node["dur"] = self.pt_gpu_node_dict[nid][i]["ts"]-last_ts + copy_node["dur"] = self.pt_gpu_node_dict[nid][i]["ts"] - last_ts last_ts = self.pt_gpu_node_dict[nid][i]["ts"] - new_pt_gpu_node_dict.setdefault(copy_node["id"], []).append(self.pt_gpu_node_dict[nid][i]) + new_pt_gpu_node_dict.setdefault(copy_node["id"], []).append( + self.pt_gpu_node_dict[nid][i] + ) else: - copy_node["dur"] = copy_node["dur"]-(last_ts-copy_node["ts"]) + copy_node["dur"] = copy_node["dur"] - ( + last_ts - copy_node["ts"] + ) copy_node["ts"] = last_ts - last_ts = copy_node["ts"]+copy_node["dur"] + last_ts = copy_node["ts"] + copy_node["dur"] assert (copy_node["ts"] >= 0) and (copy_node["dur"] > 0) if i > 0: assert copy_node["ts"] > decomposed_nodes[-1]["ts"] - decomposed_nodes_dep[copy_node["id"]] = decomposed_nodes[-1]["id"] + decomposed_nodes_dep[copy_node["id"]] = decomposed_nodes[-1][ + "id" + ] decomposed_nodes.append(copy_node) else: - node["id"] = self.assign_chakra_ids(total_assigned_ids, assigned_ids, nid) + node["id"] = self.assign_chakra_ids( + total_assigned_ids, assigned_ids, nid + ) decomposed_nodes.append(node) merged_pt_cpu_node_dict = { - decomposed_node["id"]: decomposed_node for decomposed_node in decomposed_nodes + decomposed_node["id"]: decomposed_node + for decomposed_node in decomposed_nodes } self.pt_cpu_node_dict = merged_pt_cpu_node_dict @@ -584,10 +542,7 @@ def validate_pt_node_dict( if nid in self.pt_gpu_node_dict.keys(): assert len(self.pt_gpu_node_dict[nid]) == 1 - def discover_pytorch_comm_ops( - self, - assigned_ids: List[int] - ) -> None: + def discover_pytorch_comm_ops(self, assigned_ids: List[int]) -> None: """ Discovers communication nodes and populate pt_record_param_comms_node_dict and pt_nccl_node_dict. @@ -603,22 +558,20 @@ def discover_pytorch_comm_ops( self.pt_record_param_comms_node_dict.update({node["parent"]: node}) if self.is_nccl_node(node): if node["parent"] in assigned_ids.keys(): - nodes_to_assign=assigned_ids[node["parent"]] - for parent_id in nodes_to_assign: - self.pt_nccl_node_dict.update({parent_id: node}) + nodes_to_assign = assigned_ids[node["parent"]] + for parent_id in nodes_to_assign: + self.pt_nccl_node_dict.update({parent_id: node}) else: self.pt_nccl_node_dict.update({node["parent"]: node}) for i in range(len(self.inter_phase_dependency)): # If an op is decomposed into multiple sub_ops, we want to point to the last subop [-1] - self.inter_phase_dependency[i] = assigned_ids[self.inter_phase_dependency[i]][-1] + self.inter_phase_dependency[i] = assigned_ids[ + self.inter_phase_dependency[i] + ][-1] self.inter_phase_dependency.sort() - def update_input_tensor_dict( - self, - nid: int, - inputs: str - ) -> int: + def update_input_tensor_dict(self, nid: int, inputs: str) -> int: """ Updates input_storage_id_nid_dict and input_tensor_id_nid_dict @@ -630,16 +583,14 @@ def update_input_tensor_dict( if self.is_valid_tensor(i): if self.has_valid_storage_id(i): storage_id = self.get_storage_id_from_tensor(i) - self.input_storage_id_nid_dict.setdefault(storage_id, []).append(nid) + self.input_storage_id_nid_dict.setdefault(storage_id, []).append( + nid + ) else: tensor_id = self.get_tensor_id_from_tensor(i) self.input_tensor_id_nid_dict.setdefault(tensor_id, []).append(nid) - def update_output_tensor_dict( - self, - nid: int, - outputs: str - ) -> int: + def update_output_tensor_dict(self, nid: int, outputs: str) -> int: """ Updates output_storage_id_nid_dict and output_tensor_id_nid_dict. @@ -651,14 +602,15 @@ def update_output_tensor_dict( if self.is_valid_tensor(o): if self.has_valid_storage_id(o): storage_id = self.get_storage_id_from_tensor(o) - self.output_storage_id_nid_dict.setdefault(storage_id, []).append(nid) + self.output_storage_id_nid_dict.setdefault(storage_id, []).append( + nid + ) else: tensor_id = self.get_tensor_id_from_tensor(o) self.output_tensor_id_nid_dict.setdefault(tensor_id, []).append(nid) def convert_pytorch_node_to_chakra_node( - self, - pt_node: Dict[str, Any] + self, pt_node: Dict[str, Any] ) -> ChakraNode: """ Converts a PyToch node to a Chakra node. @@ -679,42 +631,29 @@ def convert_pytorch_node_to_chakra_node( ck_node.outputs.shapes = str(pt_node["output_shapes"]) ck_node.outputs.types = str(pt_node["output_types"]) ck_node.attr.append( - ChakraAttr(name="is_cpu_op", - bool_val=self.is_cpu_op(pt_node))) + ChakraAttr(name="is_cpu_op", bool_val=self.is_cpu_op(pt_node)) + ) if "fw_parent" in pt_node.keys(): ck_node.attr.append( - ChakraAttr(name="fw_parent", - int64_val=pt_node["fw_parent"])) + ChakraAttr(name="fw_parent", int64_val=pt_node["fw_parent"]) + ) if "fw_tid" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="fw_tid", - int64_val=pt_node["fw_tid"])) + ck_node.attr.append(ChakraAttr(name="fw_tid", int64_val=pt_node["fw_tid"])) if "op_schema" in pt_node.keys(): ck_node.attr.append( - ChakraAttr(name="op_schema", - string_val=pt_node["op_schema"])) + ChakraAttr(name="op_schema", string_val=pt_node["op_schema"]) + ) if "seq_id" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="seq_id", - int64_val=pt_node["seq_id"])) + ck_node.attr.append(ChakraAttr(name="seq_id", int64_val=pt_node["seq_id"])) if "rf_id" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="rf_id", - int64_val=pt_node["rf_id"])) + ck_node.attr.append(ChakraAttr(name="rf_id", int64_val=pt_node["rf_id"])) if "scope" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="scope", - int64_val=pt_node["scope"])) + ck_node.attr.append(ChakraAttr(name="scope", int64_val=pt_node["scope"])) if "tid" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="tid", - int64_val=pt_node["tid"])) + ck_node.attr.append(ChakraAttr(name="tid", int64_val=pt_node["tid"])) return ck_node - def get_nccl_node( - self, - nid: int - ) -> Dict[str, Any]: + def get_nccl_node(self, nid: int) -> Dict[str, Any]: """ Returns a PyTorch NCCL node for a given Chakra NID. @@ -735,8 +674,9 @@ def get_nccl_node( pt_nccl_node = self.pt_nccl_node_dict[rpcp_nid] else: raise ValueError( - f"NID {nid} has a pt_record_param_comms_node " - f"but it does not have a correspondin pt_nccl_node.") + f"NID {nid} has a pt_record_param_comms_node " + f"but it does not have a correspondin pt_nccl_node." + ) elif nid in self.pt_nccl_node_dict.keys(): pt_nccl_node = self.pt_nccl_node_dict[nid] else: @@ -746,26 +686,28 @@ def get_nccl_node( ) return pt_nccl_node - def add_gpu_chakra_node( - self, - ck_cpu_node: ChakraNode - ) -> None: + def add_gpu_chakra_node(self, ck_cpu_node: ChakraNode) -> None: """ Converts a PyTorch GPU node to a Chakra node and add it to ck_node_dict. """ assert ck_cpu_node.id in self.pt_gpu_node_dict.keys() pt_gpu_node = self.pt_gpu_node_dict[ck_cpu_node.id][0] if len(self.pt_gpu_node_dict[ck_cpu_node.id]) != 1: - raise ValueError(f"Chakra node {ck_cpu_node.id} has more than one GPU operators") + raise ValueError( + f"Chakra node {ck_cpu_node.id} has more than one GPU operators" + ) ck_gpu_node = self.convert_pytorch_node_to_chakra_node(pt_gpu_node) if ck_cpu_node.type == COMM_COLL_NODE: pt_nccl_node = self.get_nccl_node(ck_cpu_node.id) ck_gpu_node.attr.append( - ChakraAttr(name="comm_type", - int64_val=self.get_collective_comm_type(pt_nccl_node))) + ChakraAttr( + name="comm_type", + int64_val=self.get_collective_comm_type(pt_nccl_node), + ) + ) ck_gpu_node.attr.append( - ChakraAttr(name="comm_size", - int64_val=self.get_comm_size(pt_nccl_node))) + ChakraAttr(name="comm_size", int64_val=self.get_comm_size(pt_nccl_node)) + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -773,9 +715,7 @@ def add_gpu_chakra_node( ck_gpu_node.data_deps.append(ck_cpu_node.id) self.ck_node_dict[ck_gpu_node.id] = ck_gpu_node - def identify_data_dependency_with_storage_id( - self - ) -> None: + def identify_data_dependency_with_storage_id(self) -> None: """ Identifies data dependency between operators with storage IDs. """ @@ -786,13 +726,12 @@ def identify_data_dependency_with_storage_id( for child_nid in child_nids: for parent_nid in parent_nids: child_node = self.ck_node_dict[child_nid] - if (parent_nid not in child_node.data_deps)\ - and (parent_nid < child_nid): + if (parent_nid not in child_node.data_deps) and ( + parent_nid < child_nid + ): child_node.data_deps.append(parent_nid) - def identify_data_dependency_with_tensor_id( - self - ) -> None: + def identify_data_dependency_with_tensor_id(self) -> None: """ Identifies data dependency between operators with tensor IDs. """ @@ -803,13 +742,12 @@ def identify_data_dependency_with_tensor_id( for child_nid in child_nids: for parent_nid in parent_nids: child_node = self.ck_node_dict[child_nid] - if (parent_nid not in child_node.data_deps)\ - and (parent_nid < child_nid): + if (parent_nid not in child_node.data_deps) and ( + parent_nid < child_nid + ): child_node.data_deps.append(parent_nid) - def identify_data_dependency( - self - ) -> None: + def identify_data_dependency(self) -> None: """ Identifies data dependency between operators using tensors. @@ -834,7 +772,7 @@ def write_chakra_et( ChakraAttr(name="pid", uint64_val=self.pt_pid), ChakraAttr(name="time", string_val=self.pt_time), ChakraAttr(name="start_ts", uint64_val=self.pt_start_ts), - ChakraAttr(name="finish_ts", uint64_val=self.pt_finish_ts) + ChakraAttr(name="finish_ts", uint64_val=self.pt_finish_ts), ] ) encode_message(self.chakra_et, md) @@ -851,15 +789,17 @@ def write_chakra_et( self.logger.info("All Chakra nodes are written to the output file") - def convert( - self - ) -> None: + def convert(self) -> None: self.sort_pytorch_nodes_with_starting_time() self.discover_pytorch_cpu_ops() - total_runtime_ns = self.get_total_runtime_ms(list(self.pt_cpu_node_dict.values())) * 1000 - self.logger.info(f"Total runtime exluding children operators: {total_runtime_ns} ns") + total_runtime_ns = ( + self.get_total_runtime_ms(list(self.pt_cpu_node_dict.values())) * 1000 + ) + self.logger.info( + f"Total runtime exluding children operators: {total_runtime_ns} ns" + ) assigned_ids, decomposed_nodes_dep = self.merge_gpu_ops_with_cpu_ops() @@ -874,7 +814,9 @@ def convert( if pt_nid in self.pt_gpu_node_dict.keys(): for pt_gpu_node in self.pt_gpu_node_dict[pt_nid]: # Assumption: same input / output as its parent CPU operator - self.update_input_tensor_dict(pt_gpu_node["id"], pt_gpu_node["inputs"]) + self.update_input_tensor_dict( + pt_gpu_node["id"], pt_gpu_node["inputs"] + ) # For now we ignore GPU->CPU dependencies since it creates unwanted dependencies. # self.update_output_tensor_dict(pt_gpu_node["id"], pt_gpu_node["outputs"]) @@ -891,9 +833,10 @@ def convert( # Adding decomposed nodes dependency # When we decompose a CPU op into multiple sub_ops, these ops have linear dependeny with themselves # For example, the first sub_op should be finished before the second sub_op. Here, we capture these dependencies. - if (pt_nid in decomposed_nodes_dep.keys())\ - and (decomposed_nodes_dep[pt_nid] not in ck_node.data_deps): - ck_node.data_deps.append(decomposed_nodes_dep[pt_nid]) + if (pt_nid in decomposed_nodes_dep.keys()) and ( + decomposed_nodes_dep[pt_nid] not in ck_node.data_deps + ): + ck_node.data_deps.append(decomposed_nodes_dep[pt_nid]) self.identify_data_dependency() diff --git a/et_converter/text2chakra_converter.py b/et_converter/text2chakra_converter.py index c5fb826c..160211ae 100644 --- a/et_converter/text2chakra_converter.py +++ b/et_converter/text2chakra_converter.py @@ -16,11 +16,9 @@ REDUCE_SCATTER, ) + class Layer: - def __init__( - self, - line: str - ) -> None: + def __init__(self, line: str) -> None: try: col = line.strip().split() self.name = col[0] @@ -46,8 +44,9 @@ def __init__( self.bwd_wg_update_time = str(col[11]) self.bwd_wg_comp_node = None self.bwd_wg_comm_node = None - except: - raise ValueError(f"Cannot parse the following layer -- \"{line}\"") + except Exception: + raise ValueError(f'Cannot parse the following layer -- "{line}"') + class Text2ChakraConverter: def __init__( @@ -57,7 +56,7 @@ def __init__( num_dims: int, num_npus: int, num_passes: int, - logger: logging.Logger + logger: logging.Logger, ) -> None: self.input_filename = input_filename self.output_filename = output_filename @@ -67,21 +66,13 @@ def __init__( self.logger = logger self.next_node_id = 0 - def get_layers( - self, - f: TextIOWrapper, - num_layers: int - ) -> List[Layer]: + def get_layers(self, f: TextIOWrapper, num_layers: int) -> List[Layer]: layers = [] for line in f: layers.append(Layer(line)) return layers - def get_node( - self, - name: str, - node_type: int - ) -> Any: + def get_node(self, name: str, node_type: int) -> Any: node = Node() node.id = self.next_node_id self.next_node_id += 1 @@ -89,21 +80,12 @@ def get_node( node.type = node_type return node - def get_comp_node( - self, - layer_name: str, - phase: str, - comp_time: int - ) -> Any: - node = self.get_node("COMP_NODE_" + layer_name + "_" + phase, - COMP_NODE) + def get_comp_node(self, layer_name: str, phase: str, comp_time: int) -> Any: + node = self.get_node("COMP_NODE_" + layer_name + "_" + phase, COMP_NODE) node.duration_micros = comp_time return node - def get_comm_type( - self, - comm_type: str - ) -> int: + def get_comm_type(self, comm_type: str) -> int: if comm_type == "ALLREDUCE": return ALL_REDUCE elif comm_type == "ALLTOALL": @@ -115,24 +97,15 @@ def get_comm_type( return 0 def get_comm_coll_node( - self, - layer_name: str, - comm_type: str, - comm_size: int + self, layer_name: str, comm_type: str, comm_size: int ) -> Any: - node = self.get_node( - f"COMM_COLL_NODE_{layer_name}_{comm_type}", - COMM_COLL_NODE) + node = self.get_node(f"COMM_COLL_NODE_{layer_name}_{comm_type}", COMM_COLL_NODE) node.attr.append( - ChakraAttr(name="comm_type", - int64_val=self.get_comm_type(comm_type))) + ChakraAttr(name="comm_type", int64_val=self.get_comm_type(comm_type)) + ) return node - def add_parent( - self, - child_node: Any, - parent_node: Any - ) -> None: + def add_parent(self, child_node: Any, parent_node: Any) -> None: child_node.parent.append(parent_node.id) def convert(self) -> None: @@ -147,22 +120,19 @@ def convert(self) -> None: self.convert_data_parallel(f, num_layers) elif parallelism_type == "MODEL": self.convert_model_parallel(f, num_layers) - elif (parallelism_type == "HYBRID_DATA_MODEL"): + elif parallelism_type == "HYBRID_DATA_MODEL": self.convert_hybrid_data_model(f, num_layers) - elif (parallelism_type == "HYBRID_MODEL_DATA"): + elif parallelism_type == "HYBRID_MODEL_DATA": self.convert_hybrid_model_data(f, num_layers) - elif (parallelism_type == "HYBRID_DLRM")\ - or (parallelism_type == "HYBRID_DLRM_ENHANCED"): + elif (parallelism_type == "HYBRID_DLRM") or ( + parallelism_type == "HYBRID_DLRM_ENHANCED" + ): last_bottom_layer = int(first_line[1]) self.convert_hybrid_dlrm(f, num_layers, last_bottom_layer) else: raise ValueError(f"Unsupported parallelism type, {parallelism_type}") - def convert_microbenchmark( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_microbenchmark(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -170,9 +140,8 @@ def convert_microbenchmark( for i in range(self.num_passes): for layer in layers: bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): @@ -181,11 +150,7 @@ def convert_microbenchmark( encode_message(g, bwd_wg_comm_node) - def convert_data_parallel( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_data_parallel(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -196,11 +161,13 @@ def convert_data_parallel( # forward pass for idx, layer in enumerate(layers): fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) + layer.name, "FWD", layer.fwd_comp_time + ) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comp_node) - if layer.bwd_wg_comm_node != None: + self.add_parent( + fwd_comp_node, layers[idx - 1].fwd_comp_node + ) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) layer.fwd_comp_node = fwd_comp_node encode_message(g, fwd_comp_node) @@ -208,21 +175,22 @@ def convert_data_parallel( # backward pass for idx, layer in enumerate(reversed(layers)): bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + layer.name, "BWD_WG", layer.bwd_wg_comp_time + ) if idx == 0: - if fwd_comp_node == None: + if fwd_comp_node is None: raise ValueError("fwd_comp_node is None") self.add_parent(bwd_wg_comp_node, fwd_comp_node) else: - self.add_parent(bwd_wg_comp_node, - layers[len(layers)-idx].bwd_ig_comp_node) + self.add_parent( + bwd_wg_comp_node, + layers[len(layers) - idx].bwd_ig_comp_node, + ) encode_message(g, bwd_wg_comp_node) bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -234,8 +202,8 @@ def convert_data_parallel( if idx != (len(layers) - 1): bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + layer.name, "BWD_IG", layer.bwd_ig_comp_time + ) self.add_parent(bwd_ig_comp_node, bwd_wg_comp_node) layer.bwd_ig_comp_node = bwd_ig_comp_node encode_message(g, bwd_ig_comp_node) @@ -243,11 +211,7 @@ def convert_data_parallel( for layer in layers: layer.bwd_wg_comm_node = None - def convert_model_parallel( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_model_parallel(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -258,19 +222,20 @@ def convert_model_parallel( # forward pass for idx, layer in enumerate(layers): fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) + layer.name, "FWD", layer.fwd_comp_time + ) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comm_node) - if layer.bwd_wg_comp_node != None: + self.add_parent( + fwd_comp_node, layers[idx - 1].fwd_comm_node + ) + if layer.bwd_wg_comp_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comp_node) layer.fwd_comp_node = fwd_comp_node encode_message(g, fwd_comp_node) fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + layer.name, layer.fwd_comm_type, layer.fwd_comm_size + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -282,24 +247,29 @@ def convert_model_parallel( # backward pass for idx, layer in enumerate(reversed(layers)): bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + layer.name, "BWD_IG", layer.bwd_ig_comp_time + ) if idx == 0: - if fwd_comm_node == None: + if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: - self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_wg_comp_node) - self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_ig_comm_node) + self.add_parent( + bwd_ig_comp_node, + layers[len(layers) - idx].bwd_wg_comp_node, + ) + self.add_parent( + bwd_ig_comp_node, + layers[len(layers) - idx].bwd_ig_comm_node, + ) encode_message(g, bwd_ig_comp_node) if idx != (num_layers - 1): bwd_ig_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + layer.name, + layer.bwd_ig_comm_type, + layer.bwd_ig_comm_size, + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -309,8 +279,8 @@ def convert_model_parallel( encode_message(g, bwd_ig_comm_node) bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + layer.name, "BWD_WG", layer.bwd_wg_comp_time + ) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) @@ -318,11 +288,7 @@ def convert_model_parallel( for layer in layers: layer.bwd_wg_comp_node = None - def convert_hybrid_data_model( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_hybrid_data_model(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -333,21 +299,22 @@ def convert_hybrid_data_model( # forward pass for idx, layer in enumerate(layers): fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) - if layer.bwd_wg_comm_node != None: + layer.name, "FWD", layer.fwd_comp_time + ) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comm_node) + self.add_parent( + fwd_comp_node, layers[idx - 1].fwd_comm_node + ) encode_message(g, fwd_comp_node) fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + layer.name, layer.fwd_comm_type, layer.fwd_comm_size + ) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(True) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(False) fwd_comm_node.attr.append(attr) self.add_parent(fwd_comm_node, fwd_comp_node) @@ -357,27 +324,32 @@ def convert_hybrid_data_model( # backward pass for idx, layer in enumerate(reversed(layers)): bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + layer.name, "BWD_IG", layer.bwd_ig_comp_time + ) if idx == 0: - if fwd_comm_node == None: + if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: - self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_wg_comp_node) - self.add_parent(bwd_ig_comp_node, - layers[len(layers)-idx].bwd_ig_comm_node) + self.add_parent( + bwd_ig_comp_node, + layers[len(layers) - idx].bwd_wg_comp_node, + ) + self.add_parent( + bwd_ig_comp_node, + layers[len(layers) - idx].bwd_ig_comm_node, + ) encode_message(g, bwd_ig_comp_node) if idx != num_layers - 1: bwd_ig_comm_node = self.get_comm_coll_node( - layer.name + "_IG_COMM_", - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + layer.name + "_IG_COMM_", + layer.bwd_ig_comm_type, + layer.bwd_ig_comm_size, + ) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(True) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(False) bwd_ig_comm_node.attr.append(attr) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) @@ -385,19 +357,18 @@ def convert_hybrid_data_model( encode_message(g, bwd_ig_comm_node) bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + layer.name, "BWD_WG", layer.bwd_wg_comp_time + ) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(False) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(True) bwd_wg_comm_node.attr.append(attr) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) @@ -407,11 +378,7 @@ def convert_hybrid_data_model( for layer in layers: layer.bwd_wg_comm_node = None - def convert_hybrid_model_data( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_hybrid_model_data(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -422,21 +389,22 @@ def convert_hybrid_model_data( # forward pass for idx, layer in enumerate(layers): fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) - if layer.bwd_wg_comm_node != None: + layer.name, "FWD", layer.fwd_comp_time + ) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comm_node) + self.add_parent( + fwd_comp_node, layers[idx - 1].fwd_comm_node + ) encode_message(g, fwd_comp_node) fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + layer.name, layer.fwd_comm_type, layer.fwd_comm_size + ) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(False) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(True) fwd_comm_node.attr.append(attr) self.add_parent(fwd_comm_node, fwd_comp_node) @@ -446,25 +414,32 @@ def convert_hybrid_model_data( # backward pass for idx, layer in enumerate(reversed(layers)): bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + layer.name, "BWD_IG", layer.bwd_ig_comp_time + ) if idx == 0: - if fwd_comm_node == None: + if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: - self.add_parent(bwd_ig_comp_node, layers[len(layers)-idx].bwd_wg_comp_node) - self.add_parent(bwd_ig_comp_node, layers[len(layers)-idx].bwd_ig_comm_node) + self.add_parent( + bwd_ig_comp_node, + layers[len(layers) - idx].bwd_wg_comp_node, + ) + self.add_parent( + bwd_ig_comp_node, + layers[len(layers) - idx].bwd_ig_comm_node, + ) encode_message(g, bwd_ig_comp_node) if idx != num_layers - 1: bwd_ig_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + layer.name, + layer.bwd_ig_comm_type, + layer.bwd_ig_comm_size, + ) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(False) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(True) bwd_ig_comm_node.attr.append(attr) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) @@ -472,19 +447,18 @@ def convert_hybrid_model_data( encode_message(g, bwd_ig_comm_node) bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + layer.name, "BWD_WG", layer.bwd_wg_comp_time + ) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) attr = ChakraAttr(name="involved_dim") attr.bool_list.values.append(True) - for _ in range(self.num_dims-1): + for _ in range(self.num_dims - 1): attr.bool_list.values.append(False) bwd_wg_comm_node.attr.append(attr) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) @@ -495,10 +469,7 @@ def convert_hybrid_model_data( layer.bwd_wg_comm_node = None def convert_hybrid_dlrm( - self, - f: TextIOWrapper, - num_layers: int, - last_bottom_layer: int + self, f: TextIOWrapper, num_layers: int, last_bottom_layer: int ) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): @@ -510,14 +481,16 @@ def convert_hybrid_dlrm( # forward pass for idx, layer in enumerate(layers): fwd_comp_node = self.get_comp_node( - layer.name, "FWD", - layer.fwd_comp_time) - if layer.bwd_wg_comm_node != None: + layer.name, "FWD", layer.fwd_comp_time + ) + if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) - elif layer.bwd_wg_comp_node != None: + elif layer.bwd_wg_comp_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comp_node) if idx != 0: - self.add_parent(fwd_comp_node, layers[idx-1].fwd_comp_node) + self.add_parent( + fwd_comp_node, layers[idx - 1].fwd_comp_node + ) if idx == last_bottom_layer: self.add_parent(fwd_comp_node, layers[0].fwd_comm_node) layer.fwd_comp_node = fwd_comp_node @@ -525,9 +498,8 @@ def convert_hybrid_dlrm( if layer.fwd_comm_type == "ALLTOALL": fwd_comm_node = self.get_comm_coll_node( - layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + layer.name, layer.fwd_comm_type, layer.fwd_comm_size + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -539,27 +511,35 @@ def convert_hybrid_dlrm( # backward pass for idx, layer in enumerate(reversed(layers)): bwd_wg_comp_node = self.get_comp_node( - layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + layer.name, "BWD_WG", layer.bwd_wg_comp_time + ) if idx == 0: - if fwd_comp_node == None: + if fwd_comp_node is None: raise ValueError("fwd_comp_node is None") self.add_parent(bwd_wg_comp_node, fwd_comp_node) else: - if layers[len(layers)-idx].bwd_ig_comp_node != None: - self.add_parent(bwd_wg_comp_node, - layers[len(layers)-idx].bwd_ig_comp_node) - if layers[len(layers)-idx-1].bwd_ig_comm_node != None: - self.add_parent(bwd_wg_comp_node, - layers[len(layers)-idx-1].bwd_ig_comm_node) + if layers[len(layers) - idx].bwd_ig_comp_node is not None: + self.add_parent( + bwd_wg_comp_node, + layers[len(layers) - idx].bwd_ig_comp_node, + ) + if ( + layers[len(layers) - idx - 1].bwd_ig_comm_node + is not None + ): + self.add_parent( + bwd_wg_comp_node, + layers[len(layers) - idx - 1].bwd_ig_comm_node, + ) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) if layer.bwd_wg_comm_type != "NONE": bwd_wg_comm_node = self.get_comm_coll_node( - layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + layer.name, + layer.bwd_wg_comm_type, + layer.bwd_wg_comm_size, + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) @@ -571,22 +551,23 @@ def convert_hybrid_dlrm( bwd_ig_comp_node = None if idx != (len(layers) - 1): bwd_ig_comp_node = self.get_comp_node( - layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + layer.name, "BWD_IG", layer.bwd_ig_comp_time + ) self.add_parent(bwd_ig_comp_node, bwd_wg_comp_node) layer.bwd_ig_comp_node = bwd_ig_comp_node encode_message(g, bwd_ig_comp_node) if (len(layers) - idx - 1) == (last_bottom_layer + 1): bwd_ig_comm_node = self.get_comm_coll_node( - layers[0].name, - layers[0].bwd_ig_comm_type, - layers[0].bwd_ig_comm_size) + layers[0].name, + layers[0].bwd_ig_comm_type, + layers[0].bwd_ig_comm_size, + ) attr = ChakraAttr(name="involved_dim") for _ in range(self.num_dims): attr.bool_list.values.append(True) bwd_ig_comm_node.attr.append(attr) - if bwd_ig_comp_node == None: + if bwd_ig_comp_node is None: raise ValueError("bwd_ig_comp_node is None") self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layers[0].bwd_ig_comm_node = bwd_ig_comm_node diff --git a/et_visualizer/et_visualizer.py b/et_visualizer/et_visualizer.py index 589db442..fedac8f5 100644 --- a/et_visualizer/et_visualizer.py +++ b/et_visualizer/et_visualizer.py @@ -6,28 +6,26 @@ from chakra.third_party.utils.protolib import ( openFileRd as open_file_rd, - decodeMessage as decode_message + decodeMessage as decode_message, ) from chakra.et_def.et_def_pb2 import Node def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Visualizer" - ) + parser = argparse.ArgumentParser(description="Execution Trace Visualizer") parser.add_argument( "--input_filename", type=str, default=None, required=True, - help="Input Chakra execution trace filename" + help="Input Chakra execution trace filename", ) parser.add_argument( "--output_filename", type=str, default=None, required=True, - help="Output graph filename" + help="Output graph filename", ) args = parser.parse_args() @@ -38,25 +36,30 @@ def main() -> None: if args.output_filename.endswith((".pdf", ".dot")): f = graphviz.Digraph() while decode_message(et, node): - f.node(name=f"{node.id}", - label=f"{node.name}", - id=str(node.id), - shape="record") + f.node( + name=f"{node.id}", label=f"{node.name}", id=str(node.id), shape="record" + ) # Handling data dependencies for data_dep_id in node.data_deps: - f.edge(str(data_dep_id), str(node.id), arrowhead="normal") # using "normal" arrow for data_deps + f.edge( + str(data_dep_id), str(node.id), arrowhead="normal" + ) # using "normal" arrow for data_deps # Handling control dependencies for ctrl_dep_id in node.ctrl_deps: - f.edge(str(ctrl_dep_id), str(node.id), arrowhead="tee") # using "tee" arrow for ctrl_deps + f.edge( + str(ctrl_dep_id), str(node.id), arrowhead="tee" + ) # using "tee" arrow for ctrl_deps if args.output_filename.endswith(".pdf"): - f.render(args.output_filename.replace(".pdf", ""), - format="pdf", cleanup=True) + f.render( + args.output_filename.replace(".pdf", ""), format="pdf", cleanup=True + ) else: # ends with ".dot" - f.render(args.output_filename.replace(".dot", ""), - format="dot", cleanup=True) + f.render( + args.output_filename.replace(".dot", ""), format="dot", cleanup=True + ) elif args.output_filename.endswith(".graphml"): G = nx.DiGraph() while decode_message(et, node): diff --git a/setup.py b/setup.py index 4ac0ffe7..d90ea865 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ class build_grpc(build): - sub_commands = [('build_grpc', None)] + build.sub_commands + sub_commands = [("build_grpc", None)] + build.sub_commands -setup(cmdclass={'build': build_grpc}) +setup(cmdclass={"build": build_grpc}) diff --git a/third_party/utils/protolib.py b/third_party/utils/protolib.py index ea0ff097..dcfb7aab 100644 --- a/third_party/utils/protolib.py +++ b/third_party/utils/protolib.py @@ -71,6 +71,7 @@ import gzip import struct + def openFileRd(in_file): """ This opens the file passed as argument for reading using an appropriate @@ -81,7 +82,7 @@ def openFileRd(in_file): # First see if this file is gzipped try: # Opening the file works even if it is not a gzip file - proto_in = gzip.open(in_file, 'rb') + proto_in = gzip.open(in_file, "rb") # Force a check of the magic number by seeking in the # file. If we do not do it here the error will occur when @@ -89,12 +90,13 @@ def openFileRd(in_file): proto_in.seek(1) proto_in.seek(0) except IOError: - proto_in = open(in_file, 'rb') + proto_in = open(in_file, "rb") except IOError: print("Failed to open ", in_file, " for reading") exit(-1) return proto_in + def _DecodeVarint32(in_file): """ The decoding of the Varint32 is copied from @@ -106,24 +108,25 @@ def _DecodeVarint32(in_file): shift = 0 pos = 0 # Use a 32-bit mask - mask = 0xffffffff + mask = 0xFFFFFFFF while 1: c = in_file.read(1) if len(c) == 0: return (0, 0) - b = struct.unpack(' 0x7fffffffffffffff: - result -= (1 << 64) + if result > 0x7FFFFFFFFFFFFFFF: + result -= 1 << 64 result |= ~mask else: result &= mask return (result, pos) shift += 7 if shift >= 64: - raise IOError('Too many bytes when decoding varint.') + raise IOError("Too many bytes when decoding varint.") + def decodeMessage(in_file, message): """ @@ -140,19 +143,21 @@ def decodeMessage(in_file, message): except IOError: return False + def _EncodeVarint32(out_file, value): - """ - The encoding of the Varint32 is copied from - google.protobuf.internal.encoder and is only repeated here to - avoid depending on the internal functions in the library. - """ - bits = value & 0x7f - value >>= 7 - while value: - out_file.write(struct.pack('>= 7 - out_file.write(struct.pack('>= 7 + out_file.write(struct.pack(" logging.Logger: formatter = logging.Formatter( - "%(levelname)s [%(asctime)s] %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p") + "%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p" + ) file_handler = FileHandler(log_filename, mode="w") file_handler.setLevel(logging.DEBUG) @@ -35,40 +37,43 @@ def get_logger(log_filename: str) -> logging.Logger: return logger + def is_local_mem_node(node_name: str) -> bool: - if ("MEM_LOAD_NODE" in node_name)\ - and ("LOCAL_MEMORY" in node_name): + if ("MEM_LOAD_NODE" in node_name) and ("LOCAL_MEMORY" in node_name): return True - elif ("MEM_STORE_NODE" in node_name)\ - and ("LOCAL_MEMORY" in node_name): + elif ("MEM_STORE_NODE" in node_name) and ("LOCAL_MEMORY" in node_name): return True else: return False + def is_remote_mem_node(node_name: str) -> bool: - if ("MEM_LOAD_NODE" in node_name)\ - and ("REMOTE_MEMORY" in node_name): + if ("MEM_LOAD_NODE" in node_name) and ("REMOTE_MEMORY" in node_name): return True - elif ("MEM_STORE_NODE" in node_name)\ - and ("REMOTE_MEMORY" in node_name): + elif ("MEM_STORE_NODE" in node_name) and ("REMOTE_MEMORY" in node_name): return True else: return False + def is_comp_node(node_name: str) -> bool: if "COMP_NODE" in node_name: return True else: return False + def is_comm_node(node_name: str) -> bool: - if ("COMM_SEND_NODE" in node_name)\ - or ("COMM_RECV_NODE" in node_name)\ - or ("COMM_COLL_NODE" in node_name): + if ( + ("COMM_SEND_NODE" in node_name) + or ("COMM_RECV_NODE" in node_name) + or ("COMM_COLL_NODE" in node_name) + ): return True else: return False + def get_tid(node_name: str) -> TID: if is_local_mem_node(node_name): return TID.LOCAL_MEMORY @@ -81,9 +86,8 @@ def get_tid(node_name: str) -> TID: else: raise ValueError(f"Node type cannot be identified from {node_name}") -def parse_event( - line: str -) -> Tuple[str, int, int, int, str]: + +def parse_event(line: str) -> Tuple[str, int, int, int, str]: try: cols = line.strip().split(",") trace_type = cols[0] @@ -92,13 +96,12 @@ def parse_event( node_id = int(cols[3].split("=")[1]) node_name = cols[4].split("=")[1] return (trace_type, npu_id, curr_cycle, node_id, node_name) - except: - raise ValueError(f"Cannot parse the following event -- \"{line}\"") + except Exception: + raise ValueError(f'Cannot parse the following event -- "{line}"') + def get_trace_events( - input_filename: str, - num_npus: int, - npu_frequency: int + input_filename: str, num_npus: int, npu_frequency: int ) -> List[Dict[str, Any]]: trace_dict = {i: {} for i in range(num_npus)} trace_events = [] @@ -106,12 +109,10 @@ def get_trace_events( with open(input_filename, "r") as f: for line in f: if ("issue" in line) or ("callback" in line): - (trace_type, npu_id, curr_cycle, node_id, node_name) =\ - parse_event(line) + (trace_type, npu_id, curr_cycle, node_id, node_name) = parse_event(line) if trace_type == "issue": - trace_dict[npu_id].update( - {node_id: [node_name, curr_cycle]}) + trace_dict[npu_id].update({node_id: [node_name, curr_cycle]}) elif trace_type == "callback": node_name = trace_dict[npu_id][node_id][0] tid = get_tid(node_name) @@ -121,15 +122,16 @@ def get_trace_events( duration_in_ms = duration_in_cycles / (npu_frequency * 1_000) trace_events.append( - { - "pid": npu_id, - "tid": tid, - "ts": issued_ms, - "dur": duration_in_ms, - "ph": "X", - "name": node_name, - "args": {"ms": duration_in_ms} - }) + { + "pid": npu_id, + "tid": tid, + "ts": issued_ms, + "dur": duration_in_ms, + "ph": "X", + "name": node_name, + "args": {"ms": duration_in_ms}, + } + ) del trace_dict[npu_id][node_id] else: @@ -139,56 +141,50 @@ def get_trace_events( def write_trace_events( - output_filename: str, - num_npus: int, - trace_events: List[Dict[str, Any]] + output_filename: str, num_npus: int, trace_events: List[Dict[str, Any]] ) -> None: output_dict = { "meta_user": "aras", "traceEvents": trace_events, "meta_user": "aras", - "meta_cpu_count": num_npus + "meta_cpu_count": num_npus, } with open(output_filename, "w") as f: json.dump(output_dict, f) + def main() -> None: - parser = argparse.ArgumentParser( - description="Timeline Visualizer" - ) + parser = argparse.ArgumentParser(description="Timeline Visualizer") parser.add_argument( - "--input_filename", - type=str, - default=None, - required=True, - help="Input timeline filename" + "--input_filename", + type=str, + default=None, + required=True, + help="Input timeline filename", ) parser.add_argument( - "--output_filename", - type=str, - default=None, - required=True, - help="Output trace filename" + "--output_filename", + type=str, + default=None, + required=True, + help="Output trace filename", ) parser.add_argument( - "--num_npus", - type=int, - default=None, - required=True, - help="Number of NPUs in a system" + "--num_npus", + type=int, + default=None, + required=True, + help="Number of NPUs in a system", ) parser.add_argument( - "--npu_frequency", - type=int, - default=None, - required=True, - help="NPU frequency in MHz" + "--npu_frequency", + type=int, + default=None, + required=True, + help="NPU frequency in MHz", ) parser.add_argument( - "--log_filename", - type=str, - default="debug.log", - help="Log filename" + "--log_filename", type=str, default="debug.log", help="Log filename" ) args = parser.parse_args() @@ -197,16 +193,13 @@ def main() -> None: try: trace_events = get_trace_events( - args.input_filename, - args.num_npus, - args.npu_frequency) - write_trace_events( - args.output_filename, - args.num_npus, - trace_events) + args.input_filename, args.num_npus, args.npu_frequency + ) + write_trace_events(args.output_filename, args.num_npus, trace_events) except Exception as e: logger.error(str(e)) sys.exit(1) + if __name__ == "__main__": main() diff --git a/utils/et_generator/et_generator.py b/utils/et_generator/et_generator.py index cc2a94ee..e4cc7a08 100644 --- a/utils/et_generator/et_generator.py +++ b/utils/et_generator/et_generator.py @@ -26,8 +26,6 @@ MEM_LOAD_NODE, MEM_STORE_NODE, COMP_NODE, - COMM_SEND_NODE, - COMM_RECV_NODE, COMM_COLL_NODE, ALL_REDUCE, ALL_TO_ALL, @@ -37,6 +35,7 @@ NODE_ID = 0 + def get_node(node_name: str, node_type: int) -> ChakraNode: global NODE_ID node = ChakraNode() @@ -63,35 +62,51 @@ def one_metadata_node_all_types(num_npus: int) -> None: node = get_node("METADATA_NODE", METADATA_NODE) - node.attr.append(ChakraAttr(name="double", double_val=1.2345, doc_string="double")) + node.attr.append( + ChakraAttr(name="double", double_val=1.2345, doc_string="double") + ) double_list = DoubleList(values=[1.2345, 2.3456]) node.attr.append(ChakraAttr(name="double_list", double_list=double_list)) - node.attr.append(ChakraAttr(name="float", float_val=1.2345, doc_string="float")) + node.attr.append( + ChakraAttr(name="float", float_val=1.2345, doc_string="float") + ) float_list = FloatList(values=[1.2345, 2.3456]) node.attr.append(ChakraAttr(name="float_list", float_list=float_list)) - node.attr.append(ChakraAttr(name="int32", int32_val=12345, doc_string="int32")) + node.attr.append( + ChakraAttr(name="int32", int32_val=12345, doc_string="int32") + ) int32_list = Int32List(values=[12345, 23456]) node.attr.append(ChakraAttr(name="int32_list", int32_list=int32_list)) - node.attr.append(ChakraAttr(name="int64", int64_val=9876543210, doc_string="int64")) + node.attr.append( + ChakraAttr(name="int64", int64_val=9876543210, doc_string="int64") + ) int64_list = Int64List(values=[9876543210, 1234567890]) node.attr.append(ChakraAttr(name="int64_list", int64_list=int64_list)) - node.attr.append(ChakraAttr(name="uint32", uint32_val=12345, doc_string="uint32")) + node.attr.append( + ChakraAttr(name="uint32", uint32_val=12345, doc_string="uint32") + ) uint32_list = Uint32List(values=[12345, 23456]) node.attr.append(ChakraAttr(name="uint32_list", uint32_list=uint32_list)) - node.attr.append(ChakraAttr(name="uint64", uint64_val=9876543210, doc_string="uint64")) + node.attr.append( + ChakraAttr(name="uint64", uint64_val=9876543210, doc_string="uint64") + ) uint64_list = Uint64List(values=[9876543210, 1234567890]) node.attr.append(ChakraAttr(name="uint64_list", uint64_list=uint64_list)) - node.attr.append(ChakraAttr(name="sint32", sint32_val=-12345, doc_string="sint32")) + node.attr.append( + ChakraAttr(name="sint32", sint32_val=-12345, doc_string="sint32") + ) sint32_list = Sint32List(values=[12345, -23456]) node.attr.append(ChakraAttr(name="sint32_list", sint32_list=sint32_list)) - node.attr.append(ChakraAttr(name="sint64", sint64_val=-9876543210, doc_string="sint64")) + node.attr.append( + ChakraAttr(name="sint64", sint64_val=-9876543210, doc_string="sint64") + ) sint64_list = Sint64List(values=[9876543210, -1234567890]) node.attr.append(ChakraAttr(name="sint64_list", sint64_list=sint64_list)) @@ -105,22 +120,32 @@ def one_metadata_node_all_types(num_npus: int) -> None: node.attr.append(ChakraAttr(name="sfixed32", sfixed32_val=-12345)) sfixed32_list = Sfixed32List(values=[12345, -23456]) - node.attr.append(ChakraAttr(name="sfixed32_list", sfixed32_list=sfixed32_list)) + node.attr.append( + ChakraAttr(name="sfixed32_list", sfixed32_list=sfixed32_list) + ) node.attr.append(ChakraAttr(name="sfixed64", sfixed64_val=-9876543210)) sfixed64_list = Sfixed64List(values=[9876543210, -1234567890]) - node.attr.append(ChakraAttr(name="sfixed64_list", sfixed64_list=sfixed64_list)) + node.attr.append( + ChakraAttr(name="sfixed64_list", sfixed64_list=sfixed64_list) + ) node.attr.append(ChakraAttr(name="bool", bool_val=True, doc_string="bool")) bool_list = BoolList(values=[i % 2 == 0 for i in range(10)]) node.attr.append(ChakraAttr(name="bool_list", bool_list=bool_list)) - node.attr.append(ChakraAttr(name="string", string_val="12345", doc_string="string")) - string_list = StringList(values=[str(12345+i) for i in range(10)]) + node.attr.append( + ChakraAttr(name="string", string_val="12345", doc_string="string") + ) + string_list = StringList(values=[str(12345 + i) for i in range(10)]) node.attr.append(ChakraAttr(name="string_list", string_list=string_list)) - node.attr.append(ChakraAttr(name="bytes", bytes_val=bytes("12345", "utf-8"))) - bytes_list = BytesList(values=[bytes(str(12345+i), "utf-8") for i in range(10)]) + node.attr.append( + ChakraAttr(name="bytes", bytes_val=bytes("12345", "utf-8")) + ) + bytes_list = BytesList( + values=[bytes(str(12345 + i), "utf-8") for i in range(10)] + ) node.attr.append(ChakraAttr(name="bytes_list", bytes_list=bytes_list)) encode_message(et, node) @@ -242,7 +267,9 @@ def one_comm_coll_node_allgather(num_npus: int, num_dims: int, comm_size: int) - encode_message(et, node) -def one_comm_coll_node_reducescatter(num_npus: int, num_dims: int, comm_size: int) -> None: +def one_comm_coll_node_reducescatter( + num_npus: int, num_dims: int, comm_size: int +) -> None: for npu_id in range(num_npus): output_filename = f"one_comm_coll_node_reducescatter.{npu_id}.et" with open(output_filename, "wb") as et: @@ -258,38 +285,31 @@ def one_comm_coll_node_reducescatter(num_npus: int, num_dims: int, comm_size: in def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Generator" - ) - parser.add_argument( - "--num_npus", - type=int, - default=64, - help="Number of NPUs" - ) + parser = argparse.ArgumentParser(description="Execution Trace Generator") + parser.add_argument("--num_npus", type=int, default=64, help="Number of NPUs") parser.add_argument( "--num_dims", type=int, default=2, - help="Number of dimensions in the network topology" + help="Number of dimensions in the network topology", ) parser.add_argument( "--default_runtime", type=int, default=5, - help="Default runtime of compute nodes" + help="Default runtime of compute nodes", ) parser.add_argument( "--default_tensor_size", type=int, default=1024, - help="Default tensor size of memory nodes" + help="Default tensor size of memory nodes", ) parser.add_argument( "--default_comm_size", type=int, default=65536, - help="Default communication size of communication nodes" + help="Default communication size of communication nodes", ) args = parser.parse_args() @@ -305,7 +325,9 @@ def main() -> None: one_comm_coll_node_allreduce(args.num_npus, args.num_dims, args.default_comm_size) one_comm_coll_node_alltoall(args.num_npus, args.num_dims, args.default_comm_size) one_comm_coll_node_allgather(args.num_npus, args.num_dims, args.default_comm_size) - one_comm_coll_node_reducescatter(args.num_npus, args.num_dims, args.default_comm_size) + one_comm_coll_node_reducescatter( + args.num_npus, args.num_dims, args.default_comm_size + ) if __name__ == "__main__": diff --git a/utils/et_jsonizer/et_jsonizer.py b/utils/et_jsonizer/et_jsonizer.py index 5dc5c08e..c13494e7 100644 --- a/utils/et_jsonizer/et_jsonizer.py +++ b/utils/et_jsonizer/et_jsonizer.py @@ -6,7 +6,7 @@ from chakra.third_party.utils.protolib import ( openFileRd as open_file_rd, - decodeMessage as decode_message + decodeMessage as decode_message, ) from chakra.et_def.et_def_pb2 import ( @@ -15,28 +15,26 @@ def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Jsonizer" - ) + parser = argparse.ArgumentParser(description="Execution Trace Jsonizer") parser.add_argument( "--input_filename", type=str, default=None, required=True, - help="Input Chakra execution trace filename" + help="Input Chakra execution trace filename", ) parser.add_argument( "--output_filename", type=str, default=None, required=True, - help="Output filename" + help="Output filename", ) args = parser.parse_args() et = open_file_rd(args.input_filename) node = ChakraNode() - with open(args.output_filename, 'w') as f: + with open(args.output_filename, "w") as f: while decode_message(et, node): f.write(MessageToJson(node)) et.close()