diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml index 243d479f..c93ce45f 100644 --- a/.github/workflows/python_lint.yml +++ b/.github/workflows/python_lint.yml @@ -16,9 +16,10 @@ jobs: - name: Install dependencies run: | - pip install flake8 - pip install pyre-check - pip install . + pip install -r requirements-dev.txt + + - name: Run ruff + run: ruff format --check --diff . - name: Run Flake8 run: flake8 . diff --git a/et_converter/et_converter.py b/et_converter/et_converter.py index 7c8ead39..0933a814 100644 --- a/et_converter/et_converter.py +++ b/et_converter/et_converter.py @@ -9,10 +9,9 @@ from .text2chakra_converter import Text2ChakraConverter from .pytorch2chakra_converter import PyTorch2ChakraConverter + def get_logger(log_filename: str) -> logging.Logger: - formatter = logging.Formatter( - "%(levelname)s [%(asctime)s] %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p") + formatter = logging.Formatter("%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p") file_handler = FileHandler(log_filename, mode="w") file_handler.setLevel(logging.DEBUG) @@ -29,44 +28,23 @@ def get_logger(log_filename: str) -> logging.Logger: return logger + def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Converter") - parser.add_argument( - "--input_type", - type=str, - default=None, - required=True, - help="Input execution trace type") + parser = argparse.ArgumentParser(description="Execution Trace Converter") + parser.add_argument("--input_type", type=str, default=None, required=True, help="Input execution trace type") parser.add_argument( - "--input_filename", - type=str, - default=None, - required=True, - help="Input execution trace filename") + "--input_filename", type=str, default=None, required=True, help="Input execution trace filename" + ) parser.add_argument( - "--output_filename", - type=str, - default=None, - required=True, - help="Output Chakra execution trace filename") + "--output_filename", type=str, default=None, required=True, help="Output Chakra execution trace filename" + ) parser.add_argument( - "--num_npus", - type=int, - default=None, - required="Text" in sys.argv, - help="Number of NPUs in a system") + "--num_npus", type=int, default=None, required="Text" in sys.argv, help="Number of NPUs in a system" + ) parser.add_argument( - "--num_passes", - type=int, - default=None, - required="Text" in sys.argv, - help="Number of training passes") - parser.add_argument( - "--log_filename", - type=str, - default="debug.log", - help="Log filename") + "--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes" + ) + parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename") args = parser.parse_args() logger = get_logger(args.log_filename) @@ -75,17 +53,11 @@ def main() -> None: try: if args.input_type == "Text": converter = Text2ChakraConverter( - args.input_filename, - args.output_filename, - args.num_npus, - args.num_passes, - logger) + args.input_filename, args.output_filename, args.num_npus, args.num_passes, logger + ) converter.convert() elif args.input_type == "PyTorch": - converter = PyTorch2ChakraConverter( - args.input_filename, - args.output_filename, - logger) + converter = PyTorch2ChakraConverter(args.input_filename, args.output_filename, logger) converter.convert() else: logger.error(f"{args.input_type} unsupported") @@ -95,5 +67,6 @@ def main() -> None: logger.debug(traceback.format_exc()) sys.exit(1) + if __name__ == "__main__": main() diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 6ba98f60..078bfaa0 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -108,12 +108,7 @@ class PyTorch2ChakraConverter: dependencies. """ - def __init__( - self, - input_filename: str, - output_filename: str, - logger: logging.Logger - ) -> None: + def __init__(self, input_filename: str, output_filename: str, logger: logging.Logger) -> None: """ Initializes the PyTorch to Chakra converter. It sets up necessary attributes and prepares the environment for the conversion process. @@ -157,8 +152,9 @@ def convert(self) -> None: self.open_chakra_execution_trace() for pytorch_nid, pytorch_node in self.pytorch_nodes.items(): - if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\ - or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL): + if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP) or ( + pytorch_node.get_op_type() == PyTorchNodeType.LABEL + ): chakra_node = self.convert_to_chakra_node(pytorch_node) self.chakra_nodes[chakra_node.id] = chakra_node @@ -167,11 +163,12 @@ def convert(self) -> None: if chakra_node.type == COMM_COLL_NODE: collective_comm_type = self.get_collective_comm_type(pytorch_node.name) - chakra_gpu_node.attr.extend([ - ChakraAttr(name="comm_type", - int64_val=collective_comm_type), - ChakraAttr(name="comm_size", - int64_val=pytorch_gpu_node.comm_size)]) + chakra_gpu_node.attr.extend( + [ + ChakraAttr(name="comm_type", int64_val=collective_comm_type), + ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size), + ] + ) self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node @@ -229,14 +226,10 @@ def _parse_and_instantiate_nodes(self, pytorch_et_data: Dict) -> None: self.pytorch_finish_ts = pytorch_et_data["finish_ts"] pytorch_nodes = pytorch_et_data["nodes"] - pytorch_node_objects = { - node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes - } + pytorch_node_objects = {node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes} self._establish_parent_child_relationships(pytorch_node_objects) - def _establish_parent_child_relationships( - self, pytorch_node_objects: Dict[int, PyTorchNode] - ) -> None: + def _establish_parent_child_relationships(self, pytorch_node_objects: Dict[int, PyTorchNode]) -> None: """ Establishes parent-child relationships among PyTorch nodes and counts the node types. @@ -252,7 +245,7 @@ def _establish_parent_child_relationships( "gpu_op": 0, "record_param_comms_op": 0, "nccl_op": 0, - "root_op": 0 + "root_op": 0, } # Establish parent-child relationships @@ -271,8 +264,10 @@ def _establish_parent_child_relationships( if pytorch_node.is_nccl_op(): parent_node.nccl_node = pytorch_node - if pytorch_node.name in ["[pytorch|profiler|execution_graph|thread]", - "[pytorch|profiler|execution_trace|thread]"]: + if pytorch_node.name in [ + "[pytorch|profiler|execution_graph|thread]", + "[pytorch|profiler|execution_trace|thread]", + ]: self.pytorch_root_nids.append(pytorch_node.id) node_type_counts["root_op"] += 1 @@ -333,17 +328,19 @@ def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: chakra_node.outputs.values = str(pytorch_node.outputs) chakra_node.outputs.shapes = str(pytorch_node.output_shapes) chakra_node.outputs.types = str(pytorch_node.output_types) - chakra_node.attr.extend([ - ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id), - ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent), - ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id), - ChakraAttr(name="scope", int64_val=pytorch_node.scope), - ChakraAttr(name="tid", int64_val=pytorch_node.tid), - ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid), - ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema), - ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()), - ChakraAttr(name="ts", int64_val=pytorch_node.ts) - ]) + chakra_node.attr.extend( + [ + ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id), + ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent), + ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id), + ChakraAttr(name="scope", int64_val=pytorch_node.scope), + ChakraAttr(name="tid", int64_val=pytorch_node.tid), + ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid), + ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema), + ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()), + ChakraAttr(name="ts", int64_val=pytorch_node.ts), + ] + ) return chakra_node def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> ChakraNodeType: @@ -356,9 +353,7 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> C Returns: int: The corresponding Chakra node type. """ - if pytorch_node.is_gpu_op() and ( - "ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name - ): + if pytorch_node.is_gpu_op() and ("ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name): return COMM_COLL_NODE elif ("c10d::" in pytorch_node.name) or ("nccl:" in pytorch_node.name): return COMM_COLL_NODE @@ -392,8 +387,10 @@ def get_collective_comm_type(self, name: str) -> int: if key.lower() in name.lower(): return value - raise ValueError(f"'{name}' not found in collective communication mapping. " - "Please add this collective communication name to the mapping.") + raise ValueError( + f"'{name}' not found in collective communication mapping. " + "Please add this collective communication name to the mapping." + ) def is_root_node(self, node): """ @@ -412,8 +409,7 @@ def is_root_node(self, node): Returns: bool: True if the node is a root node, False otherwise. """ - if node.name in ["[pytorch|profiler|execution_graph|thread]", - "[pytorch|profiler|execution_trace|thread]"]: + if node.name in ["[pytorch|profiler|execution_graph|thread]", "[pytorch|profiler|execution_trace|thread]"]: return True def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: @@ -591,9 +587,7 @@ def dfs(node_id: int, path: List[int]) -> bool: bool: True if a cycle is detected, False otherwise. """ if node_id in stack: - cycle_nodes = " -> ".join( - [self.chakra_nodes[n].name for n in path + [node_id]] - ) + cycle_nodes = " -> ".join([self.chakra_nodes[n].name for n in path + [node_id]]) self.logger.error(f"Cyclic dependency detected: {cycle_nodes}") return True if node_id in visited: @@ -611,10 +605,7 @@ def dfs(node_id: int, path: List[int]) -> bool: for node_id in self.chakra_nodes: if dfs(node_id, []): - raise Exception( - f"Cyclic dependency detected starting from node " - f"{self.chakra_nodes[node_id].name}" - ) + raise Exception(f"Cyclic dependency detected starting from node " f"{self.chakra_nodes[node_id].name}") def write_chakra_et(self) -> None: """ @@ -642,7 +633,7 @@ def _write_global_metadata(self) -> None: ChakraAttr(name="pid", uint64_val=self.pytorch_pid), ChakraAttr(name="time", string_val=self.pytorch_time), ChakraAttr(name="start_ts", uint64_val=self.pytorch_start_ts), - ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts) + ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts), ] ) encode_message(self.chakra_et, global_metadata) @@ -684,21 +675,18 @@ def simulate_execution(self) -> None: execution based on the readiness determined by dependency resolution. A simplistic global clock is used to model the execution time. """ - self.logger.info("Simulating execution of Chakra nodes based on data " - "dependencies.") + self.logger.info("Simulating execution of Chakra nodes based on data " "dependencies.") # Initialize queues for ready CPU and GPU nodes ready_cpu_nodes = [ (node_id, self.chakra_nodes[node_id]) for node_id in self.chakra_nodes - if not self.chakra_nodes[node_id].data_deps and - not self.pytorch_nodes[node_id].is_gpu_op() + if not self.chakra_nodes[node_id].data_deps and not self.pytorch_nodes[node_id].is_gpu_op() ] ready_gpu_nodes = [ (node_id, self.chakra_nodes[node_id]) for node_id in self.chakra_nodes - if not self.chakra_nodes[node_id].data_deps and - self.pytorch_nodes[node_id].is_gpu_op() + if not self.chakra_nodes[node_id].data_deps and self.pytorch_nodes[node_id].is_gpu_op() ] ready_cpu_nodes.sort(key=lambda x: x[1].id) ready_gpu_nodes.sort(key=lambda x: x[1].id) @@ -709,8 +697,7 @@ def simulate_execution(self) -> None: current_time: int = 0 # Simulated global clock in microseconds - while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, - current_gpu_node]): + while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, current_gpu_node]): if ready_cpu_nodes and not current_cpu_node: cpu_node_id, cpu_node = ready_cpu_nodes.pop(0) current_cpu_node = (cpu_node_id, current_time) @@ -731,16 +718,18 @@ def simulate_execution(self) -> None: current_time += 1 - if current_cpu_node and current_time - current_cpu_node[1] >= \ - self.chakra_nodes[current_cpu_node[0]].duration_micros: - self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed " - f"at {current_time}us") + if ( + current_cpu_node + and current_time - current_cpu_node[1] >= self.chakra_nodes[current_cpu_node[0]].duration_micros + ): + self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed " f"at {current_time}us") current_cpu_node = None - if current_gpu_node and current_time - current_gpu_node[1] >= \ - self.chakra_nodes[current_gpu_node[0]].duration_micros: - self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed " - f"at {current_time}us") + if ( + current_gpu_node + and current_time - current_gpu_node[1] >= self.chakra_nodes[current_gpu_node[0]].duration_micros + ): + self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed " f"at {current_time}us") current_gpu_node = None for node_id in list(issued_nodes): diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py index 81f245b0..12b3977a 100644 --- a/et_converter/pytorch_node.py +++ b/et_converter/pytorch_node.py @@ -29,11 +29,11 @@ def __init__(self, node_data: Dict[str, Any]) -> None: PyTorch node. """ self.node_data = node_data - self.data_deps: List['PyTorchNode'] = [] - self.children: List['PyTorchNode'] = [] - self.gpu_children: List['PyTorchNode'] = [] - self.record_param_comms_node: Optional['PyTorchNode'] = None - self.nccl_node: Optional['PyTorchNode'] = None + self.data_deps: List["PyTorchNode"] = [] + self.children: List["PyTorchNode"] = [] + self.gpu_children: List["PyTorchNode"] = [] + self.record_param_comms_node: Optional["PyTorchNode"] = None + self.nccl_node: Optional["PyTorchNode"] = None def __repr__(self) -> str: """ @@ -527,7 +527,7 @@ def is_gpu_op(self) -> bool: """ return self.has_cat() - def add_data_dep(self, parent_node: 'PyTorchNode') -> None: + def add_data_dep(self, parent_node: "PyTorchNode") -> None: """ Adds a data-dependent parent node to this node. @@ -536,7 +536,7 @@ def add_data_dep(self, parent_node: 'PyTorchNode') -> None: """ self.data_deps.append(parent_node) - def add_child(self, child_node: 'PyTorchNode') -> None: + def add_child(self, child_node: "PyTorchNode") -> None: """ Adds a child node to this node. @@ -545,7 +545,7 @@ def add_child(self, child_node: 'PyTorchNode') -> None: """ self.children.append(child_node) - def add_gpu_child(self, gpu_child_node: 'PyTorchNode') -> None: + def add_gpu_child(self, gpu_child_node: "PyTorchNode") -> None: """ Adds a child GPU node for this node. diff --git a/et_converter/pytorch_tensor.py b/et_converter/pytorch_tensor.py index e46bbc71..631b54a0 100644 --- a/et_converter/pytorch_tensor.py +++ b/et_converter/pytorch_tensor.py @@ -32,9 +32,11 @@ def is_valid(self) -> bool: bool: True if tensor_data is a list of exactly five integers, False otherwise. """ - return (isinstance(self.tensor_data, list) and - len(self.tensor_data) == 6 and - all(isinstance(item, int) for item in self.tensor_data)) + return ( + isinstance(self.tensor_data, list) + and len(self.tensor_data) == 6 + and all(isinstance(item, int) for item in self.tensor_data) + ) @property def tensor_id(self) -> int: diff --git a/et_converter/text2chakra_converter.py b/et_converter/text2chakra_converter.py index 3d38547a..c49275a4 100644 --- a/et_converter/text2chakra_converter.py +++ b/et_converter/text2chakra_converter.py @@ -15,14 +15,12 @@ ALL_TO_ALL, ALL_GATHER, REDUCE_SCATTER, - GlobalMetadata + GlobalMetadata, ) + class Layer: - def __init__( - self, - line: str - ) -> None: + def __init__(self, line: str) -> None: try: col = line.strip().split() self.name = col[0] @@ -49,16 +47,12 @@ def __init__( self.bwd_wg_comp_node = None self.bwd_wg_comm_node = None except Exception: - raise ValueError(f"Cannot parse the following layer -- \"{line}\"") + raise ValueError(f'Cannot parse the following layer -- "{line}"') + class Text2ChakraConverter: def __init__( - self, - input_filename: str, - output_filename: str, - num_npus: int, - num_passes: int, - logger: logging.Logger + self, input_filename: str, output_filename: str, num_npus: int, num_passes: int, logger: logging.Logger ) -> None: self.input_filename = input_filename self.output_filename = output_filename @@ -73,26 +67,18 @@ def get_global_metadata(self): input_text = input_file.read() attr = [ ChakraAttr(name="schema", string_val="1.0.2-chakra.0.0.4"), - ChakraAttr(name="input_file", string_val=input_text) + ChakraAttr(name="input_file", string_val=input_text), ] metadata = GlobalMetadata(attr=attr) return metadata - def get_layers( - self, - f: TextIOWrapper, - num_layers: int - ) -> List[Layer]: + def get_layers(self, f: TextIOWrapper, num_layers: int) -> List[Layer]: layers = [] for line in f: layers.append(Layer(line)) return layers - def get_node( - self, - name: str, - node_type: NodeType - ) -> Any: + def get_node(self, name: str, node_type: NodeType) -> Any: node = Node() node.id = self.next_node_id self.next_node_id += 1 @@ -100,21 +86,12 @@ def get_node( node.type = node_type return node - def get_comp_node( - self, - layer_name: str, - phase: str, - comp_time: int - ) -> Any: - node = self.get_node("COMP_NODE_" + layer_name + "_" + phase, - COMP_NODE) + def get_comp_node(self, layer_name: str, phase: str, comp_time: int) -> Any: + node = self.get_node("COMP_NODE_" + layer_name + "_" + phase, COMP_NODE) node.duration_micros = comp_time return node - def get_comm_type( - self, - comm_type: str - ) -> int: + def get_comm_type(self, comm_type: str) -> int: if comm_type == "ALLREDUCE": return ALL_REDUCE elif comm_type == "ALLTOALL": @@ -125,25 +102,13 @@ def get_comm_type( return REDUCE_SCATTER return 0 - def get_comm_coll_node( - self, - layer_name: str, - comm_type: str, - comm_size: int - ) -> Any: - node = self.get_node(f"COMM_COLL_NODE_{layer_name}_{comm_type}", - COMM_COLL_NODE) - node.attr.append(ChakraAttr(name="comm_type", - int64_val=self.get_comm_type(comm_type))) - node.attr.append(ChakraAttr(name="comm_size", - uint64_val=comm_size)) + def get_comm_coll_node(self, layer_name: str, comm_type: str, comm_size: int) -> Any: + node = self.get_node(f"COMM_COLL_NODE_{layer_name}_{comm_type}", COMM_COLL_NODE) + node.attr.append(ChakraAttr(name="comm_type", int64_val=self.get_comm_type(comm_type))) + node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size)) return node - def add_parent( - self, - child_node: Any, - parent_node: Any - ) -> None: + def add_parent(self, child_node: Any, parent_node: Any) -> None: child_node.data_deps.append(parent_node.id) def convert(self) -> None: @@ -158,22 +123,17 @@ def convert(self) -> None: self.convert_data_parallel(f, num_layers) elif parallelism_type == "MODEL": self.convert_model_parallel(f, num_layers) - elif (parallelism_type == "HYBRID_DATA_MODEL"): + elif parallelism_type == "HYBRID_DATA_MODEL": self.convert_hybrid_data_model(f, num_layers) - elif (parallelism_type == "HYBRID_MODEL_DATA"): + elif parallelism_type == "HYBRID_MODEL_DATA": self.convert_hybrid_model_data(f, num_layers) - elif (parallelism_type == "HYBRID_DLRM")\ - or (parallelism_type == "HYBRID_DLRM_ENHANCED"): + elif (parallelism_type == "HYBRID_DLRM") or (parallelism_type == "HYBRID_DLRM_ENHANCED"): last_bottom_layer = int(first_line[1]) self.convert_hybrid_dlrm(f, num_layers, last_bottom_layer) else: raise ValueError(f"Unsupported parallelism type, {parallelism_type}") - def convert_microbenchmark( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_microbenchmark(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -182,16 +142,12 @@ def convert_microbenchmark( encode_message(g, global_metadata) for i in range(self.num_passes): for layer in layers: - bwd_wg_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) encode_message(g, bwd_wg_comm_node) - def convert_data_parallel( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_data_parallel(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -203,8 +159,7 @@ def convert_data_parallel( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node(layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", layer.fwd_comp_time) if idx != 0: self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comp_node) if layer.bwd_wg_comm_node is not None: @@ -214,28 +169,25 @@ def convert_data_parallel( # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", layer.bwd_wg_comp_time) if idx == 0: if fwd_comp_node is None: raise ValueError("fwd_comp_node is None") self.add_parent(bwd_wg_comp_node, fwd_comp_node) else: - self.add_parent(bwd_wg_comp_node, - layers[len(layers) - idx].bwd_ig_comp_node) + self.add_parent(bwd_wg_comp_node, layers[len(layers) - idx].bwd_ig_comp_node) encode_message(g, bwd_wg_comp_node) - bwd_wg_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) layer.bwd_wg_comm_node = bwd_wg_comm_node encode_message(g, bwd_wg_comm_node) if idx != (len(layers) - 1): - bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", layer.bwd_ig_comp_time) self.add_parent(bwd_ig_comp_node, bwd_wg_comp_node) layer.bwd_ig_comp_node = bwd_ig_comp_node encode_message(g, bwd_ig_comp_node) @@ -243,11 +195,7 @@ def convert_data_parallel( for layer in layers: layer.bwd_wg_comm_node = None - def convert_model_parallel( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_model_parallel(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -259,8 +207,7 @@ def convert_model_parallel( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node(layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", layer.fwd_comp_time) if idx != 0: self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comm_node) if layer.bwd_wg_comp_node is not None: @@ -268,38 +215,32 @@ def convert_model_parallel( layer.fwd_comp_node = fwd_comp_node encode_message(g, fwd_comp_node) - fwd_comm_node = self.get_comm_coll_node(layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, layer.fwd_comm_type, layer.fwd_comm_size) layer.fwd_comm_node = fwd_comm_node self.add_parent(fwd_comm_node, fwd_comp_node) encode_message(g, fwd_comm_node) # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", layer.bwd_ig_comp_time) if idx == 0: if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: - self.add_parent(bwd_ig_comp_node, - layers[len(layers) - idx].bwd_wg_comp_node) - self.add_parent(bwd_ig_comp_node, - layers[len(layers) - idx].bwd_ig_comm_node) + self.add_parent(bwd_ig_comp_node, layers[len(layers) - idx].bwd_wg_comp_node) + self.add_parent(bwd_ig_comp_node, layers[len(layers) - idx].bwd_ig_comm_node) encode_message(g, bwd_ig_comp_node) if idx != (num_layers - 1): - bwd_ig_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_ig_comm_type, layer.bwd_ig_comm_size + ) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layer.bwd_ig_comm_node = bwd_ig_comm_node encode_message(g, bwd_ig_comm_node) - bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", layer.bwd_wg_comp_time) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) @@ -307,11 +248,7 @@ def convert_model_parallel( for layer in layers: layer.bwd_wg_comp_node = None - def convert_hybrid_data_model( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_hybrid_data_model(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -323,53 +260,46 @@ def convert_hybrid_data_model( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node(layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", layer.fwd_comp_time) if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) if idx != 0: self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comm_node) encode_message(g, fwd_comp_node) - fwd_comm_node = self.get_comm_coll_node(layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, layer.fwd_comm_type, layer.fwd_comm_size) self.add_parent(fwd_comm_node, fwd_comp_node) layer.fwd_comm_node = fwd_comm_node encode_message(g, fwd_comm_node) # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", layer.bwd_ig_comp_time) if idx == 0: if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") self.add_parent(bwd_ig_comp_node, fwd_comm_node) else: - self.add_parent(bwd_ig_comp_node, - layers[len(layers) - idx].bwd_wg_comp_node) - self.add_parent(bwd_ig_comp_node, - layers[len(layers) - idx].bwd_ig_comm_node) + self.add_parent(bwd_ig_comp_node, layers[len(layers) - idx].bwd_wg_comp_node) + self.add_parent(bwd_ig_comp_node, layers[len(layers) - idx].bwd_ig_comm_node) encode_message(g, bwd_ig_comp_node) if idx != num_layers - 1: - bwd_ig_comm_node = self.get_comm_coll_node(layer.name + "_IG_COMM_", - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node( + layer.name + "_IG_COMM_", layer.bwd_ig_comm_type, layer.bwd_ig_comm_size + ) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layer.bwd_ig_comm_node = bwd_ig_comm_node encode_message(g, bwd_ig_comm_node) - bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", layer.bwd_wg_comp_time) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) - bwd_wg_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) layer.bwd_wg_comm_node = bwd_wg_comm_node encode_message(g, bwd_wg_comm_node) @@ -377,11 +307,7 @@ def convert_hybrid_data_model( for layer in layers: layer.bwd_wg_comm_node = None - def convert_hybrid_model_data( - self, - f: TextIOWrapper, - num_layers: int - ) -> None: + def convert_hybrid_model_data(self, f: TextIOWrapper, num_layers: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -393,25 +319,21 @@ def convert_hybrid_model_data( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node(layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", layer.fwd_comp_time) if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) if idx != 0: self.add_parent(fwd_comp_node, layers[idx - 1].fwd_comm_node) encode_message(g, fwd_comp_node) - fwd_comm_node = self.get_comm_coll_node(layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node(layer.name, layer.fwd_comm_type, layer.fwd_comm_size) self.add_parent(fwd_comm_node, fwd_comp_node) layer.fwd_comm_node = fwd_comm_node encode_message(g, fwd_comm_node) # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", layer.bwd_ig_comp_time) if idx == 0: if fwd_comm_node is None: raise ValueError("fwd_comm_node is None") @@ -422,22 +344,21 @@ def convert_hybrid_model_data( encode_message(g, bwd_ig_comp_node) if idx != num_layers - 1: - bwd_ig_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_ig_comm_type, - layer.bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_ig_comm_type, layer.bwd_ig_comm_size + ) self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) layer.bwd_ig_comm_node = bwd_ig_comm_node encode_message(g, bwd_ig_comm_node) - bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", layer.bwd_wg_comp_time) self.add_parent(bwd_wg_comp_node, bwd_ig_comp_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) - bwd_wg_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) layer.bwd_wg_comm_node = bwd_wg_comm_node encode_message(g, bwd_wg_comm_node) @@ -445,12 +366,7 @@ def convert_hybrid_model_data( for layer in layers: layer.bwd_wg_comm_node = None - def convert_hybrid_dlrm( - self, - f: TextIOWrapper, - num_layers: int, - last_bottom_layer: int - ) -> None: + def convert_hybrid_dlrm(self, f: TextIOWrapper, num_layers: int, last_bottom_layer: int) -> None: layers = self.get_layers(f, num_layers) for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) @@ -462,8 +378,7 @@ def convert_hybrid_dlrm( # forward pass for idx, layer in enumerate(layers): - fwd_comp_node = self.get_comp_node(layer.name, "FWD", - layer.fwd_comp_time) + fwd_comp_node = self.get_comp_node(layer.name, "FWD", layer.fwd_comp_time) if layer.bwd_wg_comm_node is not None: self.add_parent(fwd_comp_node, layer.bwd_wg_comm_node) elif layer.bwd_wg_comp_node is not None: @@ -476,51 +391,47 @@ def convert_hybrid_dlrm( encode_message(g, fwd_comp_node) if layer.fwd_comm_type == "ALLTOALL": - fwd_comm_node = self.get_comm_coll_node(layer.name, - layer.fwd_comm_type, - layer.fwd_comm_size) + fwd_comm_node = self.get_comm_coll_node( + layer.name, layer.fwd_comm_type, layer.fwd_comm_size + ) self.add_parent(fwd_comm_node, fwd_comp_node) layer.fwd_comm_node = fwd_comm_node encode_message(g, fwd_comm_node) # backward pass for idx, layer in enumerate(reversed(layers)): - bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", - layer.bwd_wg_comp_time) + bwd_wg_comp_node = self.get_comp_node(layer.name, "BWD_WG", layer.bwd_wg_comp_time) if idx == 0: if fwd_comp_node is None: raise ValueError("fwd_comp_node is None") self.add_parent(bwd_wg_comp_node, fwd_comp_node) else: if layers[len(layers) - idx].bwd_ig_comp_node is not None: - self.add_parent(bwd_wg_comp_node, - layers[len(layers) - idx].bwd_ig_comp_node) + self.add_parent(bwd_wg_comp_node, layers[len(layers) - idx].bwd_ig_comp_node) if layers[len(layers) - idx - 1].bwd_ig_comm_node is not None: - self.add_parent(bwd_wg_comp_node, - layers[len(layers) - idx - 1].bwd_ig_comm_node) + self.add_parent(bwd_wg_comp_node, layers[len(layers) - idx - 1].bwd_ig_comm_node) layer.bwd_wg_comp_node = bwd_wg_comp_node encode_message(g, bwd_wg_comp_node) if layer.bwd_wg_comm_type != "NONE": - bwd_wg_comm_node = self.get_comm_coll_node(layer.name, - layer.bwd_wg_comm_type, - layer.bwd_wg_comm_size) + bwd_wg_comm_node = self.get_comm_coll_node( + layer.name, layer.bwd_wg_comm_type, layer.bwd_wg_comm_size + ) self.add_parent(bwd_wg_comm_node, bwd_wg_comp_node) layer.bwd_wg_comm_node = bwd_wg_comm_node encode_message(g, bwd_wg_comm_node) bwd_ig_comp_node = None if idx != (len(layers) - 1): - bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", - layer.bwd_ig_comp_time) + bwd_ig_comp_node = self.get_comp_node(layer.name, "BWD_IG", layer.bwd_ig_comp_time) self.add_parent(bwd_ig_comp_node, bwd_wg_comp_node) layer.bwd_ig_comp_node = bwd_ig_comp_node encode_message(g, bwd_ig_comp_node) if (len(layers) - idx - 1) == (last_bottom_layer + 1): - bwd_ig_comm_node = self.get_comm_coll_node(layers[0].name, - layers[0].bwd_ig_comm_type, - layers[0].bwd_ig_comm_size) + bwd_ig_comm_node = self.get_comm_coll_node( + layers[0].name, layers[0].bwd_ig_comm_type, layers[0].bwd_ig_comm_size + ) if bwd_ig_comp_node is None: raise ValueError("bwd_ig_comp_node is None") self.add_parent(bwd_ig_comm_node, bwd_ig_comp_node) diff --git a/et_visualizer/et_visualizer.py b/et_visualizer/et_visualizer.py index bf4bd54f..aa924df8 100644 --- a/et_visualizer/et_visualizer.py +++ b/et_visualizer/et_visualizer.py @@ -5,10 +5,7 @@ import networkx as nx import re -from chakra.third_party.utils.protolib import ( - openFileRd as open_file_rd, - decodeMessage as decode_message -) +from chakra.third_party.utils.protolib import openFileRd as open_file_rd, decodeMessage as decode_message from chakra.et_def.et_def_pb2 import ( GlobalMetadata, Node, @@ -32,23 +29,11 @@ def escape_label(label: str) -> str: def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Visualizer" - ) - parser.add_argument( - "--input_filename", - type=str, - default=None, - required=True, - help="Input Chakra execution trace filename" - ) + parser = argparse.ArgumentParser(description="Execution Trace Visualizer") parser.add_argument( - "--output_filename", - type=str, - default=None, - required=True, - help="Output graph filename" + "--input_filename", type=str, default=None, required=True, help="Input Chakra execution trace filename" ) + parser.add_argument("--output_filename", type=str, default=None, required=True, help="Output graph filename") args = parser.parse_args() et = open_file_rd(args.input_filename) @@ -61,10 +46,7 @@ def main() -> None: decode_message(et, gm) while decode_message(et, node): escaped_label = escape_label(node.name) - f.node(name=f"{node.id}", - label=escaped_label, - id=str(node.id), - shape="record") + f.node(name=f"{node.id}", label=escaped_label, id=str(node.id), shape="record") # Handling data dependencies for data_dep_id in node.data_deps: @@ -75,11 +57,9 @@ def main() -> None: f.edge(str(ctrl_dep_id), str(node.id), arrowhead="tee") # using "tee" arrow for ctrl_deps if args.output_filename.endswith(".pdf"): - f.render(args.output_filename.replace(".pdf", ""), - format="pdf", cleanup=True) + f.render(args.output_filename.replace(".pdf", ""), format="pdf", cleanup=True) else: # ends with ".dot" - f.render(args.output_filename.replace(".dot", ""), - format="dot", cleanup=True) + f.render(args.output_filename.replace(".dot", ""), format="dot", cleanup=True) elif args.output_filename.endswith(".graphml"): G = nx.DiGraph() decode_message(et, gm) diff --git a/pyproject.toml b/pyproject.toml index 6c1d45b2..063b36dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,4 +35,14 @@ chakra_converter = "chakra.et_converter.et_converter:main" chakra_visualizer = "chakra.et_visualizer.et_visualizer:main" chakra_timeline_visualizer = "chakra.et_timeline_visualizer.et_timeline_visualizer:main" chakra_generator = "chakra.et_generator.et_generator:main" -chakra_jsonizer = "chakra.et_jsonizer.et_jsonizer:main" \ No newline at end of file +chakra_jsonizer = "chakra.et_jsonizer.et_jsonizer:main" + +[tool.ruff] +target-version = "py39" +line-length = 120 + +[tool.ruff.lint] +select = ["I", "B", "E", "F", "SIM", "W", "C90", "EXE"] + +[tool.ruff.format] +indent-style = "space" diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..ad762d48 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +ruff==0.3.5 +flake8==7.0.0 +pyre-check==0.9.19 diff --git a/setup.py b/setup.py index 4ac0ffe7..d90ea865 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ class build_grpc(build): - sub_commands = [('build_grpc', None)] + build.sub_commands + sub_commands = [("build_grpc", None)] + build.sub_commands -setup(cmdclass={'build': build_grpc}) +setup(cmdclass={"build": build_grpc}) diff --git a/third_party/utils/protolib.py b/third_party/utils/protolib.py index ea0ff097..dcfb7aab 100644 --- a/third_party/utils/protolib.py +++ b/third_party/utils/protolib.py @@ -71,6 +71,7 @@ import gzip import struct + def openFileRd(in_file): """ This opens the file passed as argument for reading using an appropriate @@ -81,7 +82,7 @@ def openFileRd(in_file): # First see if this file is gzipped try: # Opening the file works even if it is not a gzip file - proto_in = gzip.open(in_file, 'rb') + proto_in = gzip.open(in_file, "rb") # Force a check of the magic number by seeking in the # file. If we do not do it here the error will occur when @@ -89,12 +90,13 @@ def openFileRd(in_file): proto_in.seek(1) proto_in.seek(0) except IOError: - proto_in = open(in_file, 'rb') + proto_in = open(in_file, "rb") except IOError: print("Failed to open ", in_file, " for reading") exit(-1) return proto_in + def _DecodeVarint32(in_file): """ The decoding of the Varint32 is copied from @@ -106,24 +108,25 @@ def _DecodeVarint32(in_file): shift = 0 pos = 0 # Use a 32-bit mask - mask = 0xffffffff + mask = 0xFFFFFFFF while 1: c = in_file.read(1) if len(c) == 0: return (0, 0) - b = struct.unpack(' 0x7fffffffffffffff: - result -= (1 << 64) + if result > 0x7FFFFFFFFFFFFFFF: + result -= 1 << 64 result |= ~mask else: result &= mask return (result, pos) shift += 7 if shift >= 64: - raise IOError('Too many bytes when decoding varint.') + raise IOError("Too many bytes when decoding varint.") + def decodeMessage(in_file, message): """ @@ -140,19 +143,21 @@ def decodeMessage(in_file, message): except IOError: return False + def _EncodeVarint32(out_file, value): - """ - The encoding of the Varint32 is copied from - google.protobuf.internal.encoder and is only repeated here to - avoid depending on the internal functions in the library. - """ - bits = value & 0x7f - value >>= 7 - while value: - out_file.write(struct.pack('>= 7 - out_file.write(struct.pack('>= 7 + out_file.write(struct.pack(" logging.Logger: - formatter = logging.Formatter( - "%(levelname)s [%(asctime)s] %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p") + formatter = logging.Formatter("%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p") file_handler = FileHandler(log_filename, mode="w") file_handler.setLevel(logging.DEBUG) @@ -35,6 +35,7 @@ def get_logger(log_filename: str) -> logging.Logger: return logger + def is_local_mem_node(node_name: str) -> bool: if ("MEM_LOAD_NODE" in node_name) and ("LOCAL_MEMORY" in node_name): return True @@ -43,6 +44,7 @@ def is_local_mem_node(node_name: str) -> bool: else: return False + def is_remote_mem_node(node_name: str) -> bool: if ("MEM_LOAD_NODE" in node_name) and ("REMOTE_MEMORY" in node_name): return True @@ -51,20 +53,21 @@ def is_remote_mem_node(node_name: str) -> bool: else: return False + def is_comp_node(node_name: str) -> bool: if "COMP_NODE" in node_name: return True else: return False + def is_comm_node(node_name: str) -> bool: - if ("COMM_SEND_NODE" in node_name)\ - or ("COMM_RECV_NODE" in node_name)\ - or ("COMM_COLL_NODE" in node_name): + if ("COMM_SEND_NODE" in node_name) or ("COMM_RECV_NODE" in node_name) or ("COMM_COLL_NODE" in node_name): return True else: return False + def get_tid(node_name: str) -> TID: if is_local_mem_node(node_name): return TID.LOCAL_MEMORY @@ -77,9 +80,8 @@ def get_tid(node_name: str) -> TID: else: raise ValueError(f"Node type cannot be identified from {node_name}") -def parse_event( - line: str -) -> Tuple[str, int, int, int, str]: + +def parse_event(line: str) -> Tuple[str, int, int, int, str]: try: cols = line.strip().split(",") trace_type = cols[0] @@ -89,13 +91,10 @@ def parse_event( node_name = cols[4].split("=")[1] return (trace_type, npu_id, curr_cycle, node_id, node_name) except Exception as e: - raise ValueError(f"Cannot parse the following event -- \"{line}\": {e}") + raise ValueError(f'Cannot parse the following event -- "{line}": {e}') + -def get_trace_events( - input_filename: str, - num_npus: int, - npu_frequency: int -) -> List[Dict[str, Any]]: +def get_trace_events(input_filename: str, num_npus: int, npu_frequency: int) -> List[Dict[str, Any]]: trace_dict = {i: {} for i in range(num_npus)} trace_events = [] @@ -114,15 +113,17 @@ def get_trace_events( duration_in_cycles = curr_cycle - issued_cycle duration_in_ms = duration_in_cycles / (npu_frequency * 1_000) - trace_events.append({ - "pid": npu_id, - "tid": tid, - "ts": issued_ms, - "dur": duration_in_ms, - "ph": "X", - "name": node_name, - "args": {"ms": duration_in_ms} - }) + trace_events.append( + { + "pid": npu_id, + "tid": tid, + "ts": issued_ms, + "dur": duration_in_ms, + "ph": "X", + "name": node_name, + "args": {"ms": duration_in_ms}, + } + ) del trace_dict[npu_id][node_id] else: @@ -130,31 +131,20 @@ def get_trace_events( return trace_events -def write_trace_events( - output_filename: str, - num_npus: int, - trace_events: List[Dict[str, Any]] -) -> None: - output_dict = { - "meta_user": "aras", - "traceEvents": trace_events, - "meta_cpu_count": num_npus - } + +def write_trace_events(output_filename: str, num_npus: int, trace_events: List[Dict[str, Any]]) -> None: + output_dict = {"meta_user": "aras", "traceEvents": trace_events, "meta_cpu_count": num_npus} with open(output_filename, "w") as f: json.dump(output_dict, f) + def main() -> None: parser = argparse.ArgumentParser(description="Timeline Visualizer") - parser.add_argument("--input_filename", type=str, default=None, required=True, - help="Input timeline filename") - parser.add_argument("--output_filename", type=str, default=None, required=True, - help="Output trace filename") - parser.add_argument("--num_npus", type=int, default=None, required=True, - help="Number of NPUs in a system") - parser.add_argument("--npu_frequency", type=int, default=None, required=True, - help="NPU frequency in MHz") - parser.add_argument("--log_filename", type=str, default="debug.log", - help="Log filename") + parser.add_argument("--input_filename", type=str, default=None, required=True, help="Input timeline filename") + parser.add_argument("--output_filename", type=str, default=None, required=True, help="Output trace filename") + parser.add_argument("--num_npus", type=int, default=None, required=True, help="Number of NPUs in a system") + parser.add_argument("--npu_frequency", type=int, default=None, required=True, help="NPU frequency in MHz") + parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename") args = parser.parse_args() logger = get_logger(args.log_filename) @@ -167,5 +157,6 @@ def main() -> None: logger.error(str(e)) sys.exit(1) + if __name__ == "__main__": main() diff --git a/utils/et_generator/et_generator.py b/utils/et_generator/et_generator.py index 61f893c4..7ce90e7b 100644 --- a/utils/et_generator/et_generator.py +++ b/utils/et_generator/et_generator.py @@ -36,6 +36,7 @@ NODE_ID = 0 + def get_node(node_name: str, node_type: ChakraNodeType) -> ChakraNode: global NODE_ID node = ChakraNode() @@ -245,32 +246,12 @@ def one_comm_coll_node_reducescatter(num_npus: int, comm_size: int) -> None: def main() -> None: - parser = argparse.ArgumentParser( - description="Execution Trace Generator" - ) - parser.add_argument( - "--num_npus", - type=int, - default=64, - help="Number of NPUs" - ) - parser.add_argument( - "--default_runtime", - type=int, - default=5, - help="Default runtime of compute nodes" - ) - parser.add_argument( - "--default_tensor_size", - type=int, - default=1024, - help="Default tensor size of memory nodes" - ) + parser = argparse.ArgumentParser(description="Execution Trace Generator") + parser.add_argument("--num_npus", type=int, default=64, help="Number of NPUs") + parser.add_argument("--default_runtime", type=int, default=5, help="Default runtime of compute nodes") + parser.add_argument("--default_tensor_size", type=int, default=1024, help="Default tensor size of memory nodes") parser.add_argument( - "--default_comm_size", - type=int, - default=65536, - help="Default communication size of communication nodes" + "--default_comm_size", type=int, default=65536, help="Default communication size of communication nodes" ) args = parser.parse_args() diff --git a/utils/et_jsonizer/et_jsonizer.py b/utils/et_jsonizer/et_jsonizer.py index 641b82e1..613c8fb7 100644 --- a/utils/et_jsonizer/et_jsonizer.py +++ b/utils/et_jsonizer/et_jsonizer.py @@ -4,10 +4,7 @@ from google.protobuf.json_format import MessageToJson -from chakra.third_party.utils.protolib import ( - openFileRd as open_file_rd, - decodeMessage as decode_message -) +from chakra.third_party.utils.protolib import openFileRd as open_file_rd, decodeMessage as decode_message from chakra.et_def.et_def_pb2 import ( GlobalMetadata, @@ -16,26 +13,18 @@ def main() -> None: - parser = argparse.ArgumentParser( - description="Converts Chakra execution trace to JSON format." - ) + parser = argparse.ArgumentParser(description="Converts Chakra execution trace to JSON format.") parser.add_argument( - "--input_filename", - type=str, - required=True, - help="Specifies the input filename of the Chakra execution trace." + "--input_filename", type=str, required=True, help="Specifies the input filename of the Chakra execution trace." ) parser.add_argument( - "--output_filename", - type=str, - required=True, - help="Specifies the output filename for the JSON data." + "--output_filename", type=str, required=True, help="Specifies the output filename for the JSON data." ) args = parser.parse_args() execution_trace = open_file_rd(args.input_filename) node = ChakraNode() - with open(args.output_filename, 'w') as file: + with open(args.output_filename, "w") as file: global_metadata = GlobalMetadata() decode_message(execution_trace, global_metadata) file.write(MessageToJson(global_metadata))