From 437c2c3e86e7b216a400c61e13fd3ff97da4278f Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Mon, 8 Jan 2024 21:34:20 -0500 Subject: [PATCH 01/13] et_converter: Refactor PyTorch2ChakraConverter --- et_converter/pytorch2chakra_converter.py | 1451 ++++++++++------------ et_converter/pytorch_node.py | 628 ++++++++++ et_converter/pytorch_tensor.py | 110 ++ 3 files changed, 1421 insertions(+), 768 deletions(-) create mode 100644 et_converter/pytorch_node.py create mode 100644 et_converter/pytorch_tensor.py diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 9bd7048f..e92491dc 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -4,11 +4,11 @@ import copy import json import logging - -from enum import Enum -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple from chakra.third_party.utils.protolib import encodeMessage as encode_message +from chakra.et_converter.pytorch_node import PyTorchNodeType, PyTorchNode +from chakra.et_converter.pytorch_tensor import PyTorchTensor, list_to_pytorch_tensor from chakra.et_def.et_def_pb2 import ( GlobalMetadata, Node as ChakraNode, @@ -24,13 +24,84 @@ ) -class PyTorchNodeType(Enum): - CPU_OP = 1 - GPU_OP = 2 - LABEL = 3 # Non-operator nodes +class UniqueIdAssigner: + """ + Class for assigning unique IDs. Generates a new unique ID for each call, + even with the same original ID, and keeps track of all assigned IDs. + + Attributes: + next_id (int): The next available unique ID. + original_to_assigned_ids (Dict[int, List[int]]): Mapping from original + IDs to lists of assigned unique IDs. + """ + + def __init__(self) -> None: + self.next_id = 0 + self.original_to_assigned_ids: Dict[int, List[int]] = {} + + def assign_unique_id(self, original_id: int) -> int: + """ + Generates and tracks a new unique ID for each call for a given original ID. + + Args: + original_id (int): The original ID to generate a unique ID for. + + Returns: + int: A new unique ID for the original ID. + """ + unique_id = self.next_id + self.next_id += 1 + + assigned_ids = self.original_to_assigned_ids.setdefault(original_id, []) + assigned_ids.append(unique_id) + + return unique_id + + def get_assigned_ids(self, original_id: int) -> List[int]: + """ + Retrieves all unique IDs assigned to a given original ID. + + Args: + original_id (int): The original ID to retrieve unique IDs for. + + Returns: + List[int]: List of unique IDs assigned to the original ID. + """ + return self.original_to_assigned_ids.get(original_id, []) class PyTorch2ChakraConverter: + """ + Converter class for transforming PyTorch execution traces into Chakra format. + + This class is responsible for converting the execution traces collected + from PyTorch into a format that is compatible with Chakra, a performance + analysis tool. It handles the intricate mappings and transformations + required to accurately represent the execution in a different format. + + Attributes: + input_filename (str): Input file name containing PyTorch execution trace. + output_filename (str): Output file name for the converted Chakra trace. + num_dims (int): Number of dimensions involved in the conversion process. + logger (logging.Logger): Logger for logging information during conversion. + id_assigner (UniqueIdAssigner): Object to manage unique ID assignments. + pytorch_schema (Optional[str]): Schema info of the PyTorch trace. 
+ pytorch_pid (Optional[int]): Process ID associated with the PyTorch trace. + pytorch_time (Optional[str]): Time info of the PyTorch trace. + pytorch_start_ts (Optional[int]): Start timestamp of the PyTorch trace. + pytorch_finish_ts (Optional[int]): Finish timestamp of the PyTorch trace. + pytorch_nodes (Dict[int, Any]): Map of PyTorch node IDs to nodes. + pytorch_root_nids (List[int]): List of root node IDs in the PyTorch trace. + pytorch_cpu_node_id_gpu_node_map (Dict[int, List[int]]): Map of PyTorch + CPU node IDs to GPU node IDs. + chakra_nodes (Dict[int, Any]): Map of Chakra node IDs to nodes. + phase_end_nids (List[int]): List of node IDs for phase dependencies. + input_storage_id_nid_map (Dict[int, int]): Map of input storage IDs to node IDs. + output_storage_id_nid_map (Dict[int, int]): Map of output storage IDs to node IDs. + input_tensor_id_nid_map (Dict[int, int]): Map of input tensor IDs to node IDs. + output_tensor_id_nid_map (Dict[int, int]): Map of output tensor IDs to node IDs. + """ + def __init__( self, input_filename: str, @@ -38,865 +109,709 @@ def __init__( num_dims: int, logger: logging.Logger ) -> None: - try: - self.pytorch_et = open(input_filename, "r") - except IOError as e: - raise Exception(f"Could not open file {input_filename}") - pytorch_et_data = json.load(self.pytorch_et) - self.pt_schema = pytorch_et_data["schema"] - self.pt_pid = pytorch_et_data["pid"] - self.pt_time = pytorch_et_data["time"] - self.pt_start_ts = pytorch_et_data["start_ts"] - self.pt_finish_ts = pytorch_et_data["finish_ts"] - self.pt_nodes = pytorch_et_data["nodes"] - - try: - self.chakra_et = open(output_filename, "wb") - except IOError as e: - raise Exception(f"Could not open file {output_filename}") + """ + Initializes the PyTorch to Chakra converter. It sets up necessary + attributes and prepares the environment for the conversion process. + Args: + input_filename (str): Name of the input file containing PyTorch execution trace. + output_filename (str): Name of the output file for the converted Chakra trace. + num_dims (int): Number of dimensions involved in the conversion process. + logger (logging.Logger): Logger for logging information during the conversion. + """ + self.input_filename = input_filename + self.output_filename = output_filename self.num_dims = num_dims self.logger = logger + self.id_assigner = UniqueIdAssigner() + self.initialize_attributes() + + def initialize_attributes(self) -> None: + # Initialize file and trace-related attributes + self.pytorch_schema = None + self.pytorch_pid = None + self.pytorch_time = None + self.pytorch_start_ts = None + self.pytorch_finish_ts = None + self.pytorch_nodes = None + self.pytorch_root_nids = [] + + # Initialize node mapping dictionaries + self.pytorch_cpu_node_id_gpu_node_map = {} + self.chakra_nodes = {} + + # Initialize lists for phase dependencies and data dependency maps + self.phase_end_nids = [] + + # Map of input storage IDs to node IDs: + # This dictionary tracks which nodes are consuming tensors based on their + # storage ID, establishing a link between tensor storage and node consumption. + self.input_storage_id_nid_map = {} + + # Map of output storage IDs to node IDs: + # Similar to input_storage_id_nid_map, but this tracks the production of + # tensors by nodes, associating tensor storage IDs with the nodes that + # produce them. + self.output_storage_id_nid_map = {} + + # Map of input tensor IDs to node IDs: + # This dictionary is used when storage IDs are not applicable. 
It tracks + # which nodes are consuming tensors by using tensor IDs, creating a link + # between tensor IDs and the nodes that consume them. + self.input_tensor_id_nid_map = {} + + # Map of output tensor IDs to node IDs: + # Similar to input_tensor_id_nid_map, but for tracking the output of tensors + # from nodes. It associates tensor IDs with the nodes that output them, + # used when storage IDs are not available. + self.output_tensor_id_nid_map = {} + + def convert(self) -> None: + """ + Converts PyTorch execution traces into the Chakra format. Orchestrates + the conversion process including trace loading, trace opening, phase + end node construction, node splitting, and node conversion. + """ + self.load_pytorch_execution_traces() + + self.open_chakra_execution_trace() + + self.construct_phase_end_nids() + + self.split_cpu_nodes_with_gpu_child() + + for pytorch_nid, pytorch_node in self.pytorch_nodes.items(): + if pytorch_node.is_cpu_op(): + self.update_input_tensor_map(pytorch_node.id, pytorch_node.inputs) + self.update_output_tensor_map(pytorch_node.id, pytorch_node.outputs) + + if pytorch_node.child_gpu: + pytorch_gpu_node = pytorch_node.child_gpu + self.update_input_tensor_map(pytorch_gpu_node.id, pytorch_gpu_node.inputs) + # Ignoring GPU->CPU dependencies for now since it creates unwanted dependencies. + + chakra_node = self.convert_to_chakra_node(pytorch_node) + self.chakra_nodes[chakra_node.id] = chakra_node + + if pytorch_node.child_gpu: + pytorch_gpu_node = pytorch_node.child_gpu + chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node) + + if chakra_node.type == COMM_COLL_NODE: + pytorch_nccl_node = self.get_nccl_node(pytorch_node) + chakra_gpu_node.attr.extend([ + ChakraAttr(name="comm_type", + int64_val=pytorch_nccl_node.collective_comm_type), + ChakraAttr(name="comm_size", + int64_val=pytorch_nccl_node.comm_size), + ChakraAttr(name="involved_dim", + bool_list={"values": [True]*self.num_dims}) + ]) + + chakra_gpu_node.data_deps.append(chakra_node.id) + self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node + + for data_dep_pytorch_node in pytorch_node.data_deps: + chakra_node.data_deps.append(data_dep_pytorch_node.id) + + dep_nid = self.get_prev_phase_end_nid(chakra_node) + if (dep_nid != -1) and (dep_nid not in chakra_node.data_deps): + chakra_node.data_deps.append(dep_nid) +>>>>>>> a4155fe (et_converter: Refactor PyTorch2ChakraConverter) - # All PyTorch CPU operators are kept in pt_cpu_node_dict. - # Mappings between PyTorch NIDs and PyTorch nodes. - self.pt_cpu_node_dict = {} - - # All PyTorch GPU operators are kept in pt_gpu_node_dict. - # Mappings between PyTorch CPU node IDs (parent) and PyTorch GPU nodes (children). - self.pt_gpu_node_dict = {} - - # All record_param_comms nodes are tracked in pt_record_param_comms_node_dict. - # Mappings between parent PyTorch NIDs and PyTorch record_param_comms nodes. - self.pt_record_param_comms_node_dict = {} - - # All PyTorch NCCL nodes are kept in pt_nccl_node_dict. - # Mappings between parent PyTorch NIDs and PyTorch NCCL nodes. - self.pt_nccl_node_dict = {} - - # All Chakra nodes are maintained in ck_node_dict. - # Mappings between Chakra NIDs and Chakra nodes. - self.ck_node_dict = {} - - # A list of NIDs to enforce dependencies between phases. - # Phase of training iteration may include forward-pass, back-prop, optimizer, etc. 
- # We assume a phase ops cannot start until after all ops of previous phases are executed - self.inter_phase_dependency = [] - - # --------------------------------------------------------------------- - # These four dictionaries are used for identifying data dependencies - # between operators. Data dependencies can be discovered by identifying - # tensor input-output relationships between operators. - # - # Tensors have two types of IDs: storage ID and tensor ID - # A storage ID is considered as valid when it is larger than zero. - # When a storage ID is valid, it should be used for identifying a tensor - # Otherwise, a tensor ID should be utilized. - # --------------------------------------------------------------------- - # Mapping between storage_id and nid - self.input_storage_id_nid_dict = {} # storage_id is an input of a node with nid - self.output_storage_id_nid_dict = {} # storage_id is an output of a node with nid - # Mapping between tensor_id and nid - self.input_tensor_id_nid_dict = {} # tensor_id is an input of a node with nid - self.output_tensor_id_nid_dict = {} # tensor_id is an output of a node with nid - - def __del__(self): - if self.pytorch_et and not self.pytorch_et.closed: - self.pytorch_et.close() - if self.chakra_et and not self.chakra_et.closed: - self.chakra_et.close() + self.identify_data_dependency() - @staticmethod - def is_valid_tensor( - obj: Any - ) -> bool: - """ - Returns true if a given object is a valid tensor. + self.write_chakra_et() - An object is a valid tensor object when it is a list and the length of - the list is six. - """ - return isinstance(obj, list) and (len(obj) == 6) + self.close_chakra_execution_trace() - @staticmethod - def get_storage_id_from_tensor( - tensor: List[Any] - ) -> int: - """ - Returns the storage ID of a tensor. + def load_pytorch_execution_traces(self) -> None: """ - if len(tensor) < 2: - raise IndexError("Index out of bounds") - return tensor[1] + Loads PyTorch execution traces from a file. - @staticmethod - def get_tensor_id_from_tensor( - tensor: List[Any] - ) -> int: - """ - Returns the tensor ID of a tensor. - """ - if len(tensor) < 1: - raise IndexError("Index out of bounds") - return tensor[0] + Reads and parses the PyTorch execution trace data from a file, creating + PyTorchNode objects and establishing node relationships. - def has_valid_storage_id( - self, - tensor: List[Any] - ) -> bool: + Raises: + Exception: If there is an IOError in opening the file. """ - Returns true if a given tensor has a valid storage ID. + self.logger.info("Loading PyTorch execution traces from file.") + try: + with open(self.input_filename, "r") as pytorch_et: + pytorch_et_data = json.load(pytorch_et) + self._parse_and_instantiate_nodes(pytorch_et_data) + except IOError as e: + self.logger.error(f"Error opening file {self.input_filename}: {e}") + raise Exception(f"Could not open file {self.input_filename}") - A storage ID is considered valid if it is larger than zero. - When a storage ID is valid, it should be used instead of a tensor ID. + def _parse_and_instantiate_nodes(self, pytorch_et_data: Dict) -> None: """ - storage_id = self.get_storage_id_from_tensor(tensor) - return storage_id > 0 + Parses and instantiates PyTorch nodes from execution trace data. - @staticmethod - def has_cat_field( - node: Dict[str, Any] - ) -> bool: - """ - Returns true if a PyTorch node has a category field. - """ - return "cat" in node.keys() + Args: + pytorch_et_data (Dict): The execution trace data. 
- @staticmethod - def get_cat_field( - node: Dict[str, Any] - ) -> bool: - """ - Returns the category field of a given PyTorch node. + Extracts node information, sorts nodes by timestamp, and establishes + parent-child relationships among them. """ - return node["cat"] + self.logger.info("Extracting and processing node data from execution trace.") + self.pytorch_schema = pytorch_et_data["schema"] + self.pytorch_pid = pytorch_et_data["pid"] + self.pytorch_time = pytorch_et_data["time"] + self.pytorch_start_ts = pytorch_et_data["start_ts"] + self.pytorch_finish_ts = pytorch_et_data["finish_ts"] - @staticmethod - def has_dur( - node: Dict[str, Any] - ) -> bool: - """ - Returns true if a PyTorch node has a duration field. + pytorch_nodes = pytorch_et_data["nodes"] + pytorch_node_objects = { + node_data["id"]: PyTorchNode(node_data) for node_data in pytorch_nodes + } + self._establish_parent_child_relationships(pytorch_node_objects) + + def _establish_parent_child_relationships( + self, pytorch_node_objects: Dict[int, PyTorchNode] + ) -> None: """ - return "dur" in node.keys() + Establishes parent-child relationships among PyTorch nodes and counts + the node types. + + Args: + pytorch_node_objects (Dict[int, PyTorchNode]): Dictionary of PyTorch + node objects. + """ + # Initialize counters for different types of nodes + node_type_counts = { + "total_op": 0, + "cpu_op": 0, + "gpu_op": 0, + "record_param_comms_op": 0, + "nccl_op": 0, + "root_op": 0 + } - def get_pytorch_node_type( - self, - node: Dict[str, Any] - ) -> PyTorchNodeType: - if self.is_gpu_op(node): - return PyTorchNodeType.GPU_OP - elif (node["op_schema"] or node["outputs"])\ - or ("c10d::" in node["name"] or ("nccl:" in node["name"])): - return PyTorchNodeType.CPU_OP - else: - return PyTorchNodeType.LABEL + # Establish parent-child relationships + for pytorch_node in pytorch_node_objects.values(): + parent_id = pytorch_node.parent + if parent_id in pytorch_node_objects: + parent_node = pytorch_node_objects[parent_id] + parent_node.add_child(pytorch_node) - @staticmethod - def is_record_param_comms_node( - node: Dict[str, Any] - ) -> bool: - """ - Returns true if a PyToch node has "record_param_comms" in its name. - """ - return "record_param_comms" in node["name"] + if pytorch_node.is_gpu_op(): + parent_node.set_child_gpu(pytorch_node) - @staticmethod - def is_nccl_node( - node: Dict[str, Any] - ) -> bool: - """ - Returns true if a PyToch node is a NCCL node. - """ - return "nccl:" in node["name"] + if pytorch_node.is_record_param_comms_op(): + parent_node.record_param_comms_node = pytorch_node - def is_cpu_op_with_dur( - self, - node: Dict[str, Any] - ) -> bool: - """ - Returns true if a PyTorch node is a CPU operator and has a duration field. - """ - return (self.get_pytorch_node_type(node) == PyTorchNodeType.CPU_OP)\ - and self.has_dur(node) + if pytorch_node.is_nccl_op(): + parent_node.nccl_node = pytorch_node - def is_cpu_op( - self, - node: Dict[str, Any] - ) -> bool: - """ - Takes a PyTorch node and returns true if the node is a CPU operator. - """ - return self.get_pytorch_node_type(node) == PyTorchNodeType.CPU_OP - - @staticmethod - def get_collective_comm_type( - node: Dict[str, Any] - ) -> int: - """ - Returns the collective communication type of a given PyTorch node. 
- """ - if "all_reduce" in node["name"]: - return ALL_REDUCE - elif "all_to_all" in node["name"]: - return ALL_TO_ALL - elif "all_gather" in node["name"]: - return ALL_GATHER - elif "reduce_scatter" in node["name"]: - return REDUCE_SCATTER - elif "broadcast" in node["name"]: - return BROADCAST - else: - node_name = node["name"] - raise ValueError(f"{node_name} is not supported") - return INVALID_COMM - - @staticmethod - def get_data_type_size( - data_type: str - ) -> int: - """ - Returns the data type size of a given data type in string. - - References - * https://pytorch.org/docs/stable/tensors.html - * https://github.com/pytorch/pytorch/blob/master/c10/util/Half.h - """ - data_type_size_dict = { - "Tensor(float32)": 4, - "Tensor(float)": 4, - "Tensor(float64)": 8, - "Tensor(double)": 8, - "Tensor(float16)": 2, - "Tensor(half)": 2, - "Tensor(bfloat16)": 2, - "Tensor(complex64)": 8, - "Tensor(complex128)": 16, - "Tensor(uint8)": 1, - "Tensor(int8)": 1, - "Tensor(int16)": 2, - "Tensor(short)": 2, - "Tensor(int32)": 4, - "Tensor(int)": 4, - "Tensor(int64)": 8, - "Tensor(long)": 8, - "Tensor(c10::Half)": 2, - "Tensor(unsigned char)": 1, - "Tensor(long int)": 8, - } - try: - data_type_size = data_type_size_dict[data_type] - return data_type_size - except: - raise ValueError(f"{data_type} is unsupported") + if pytorch_node.name in ["[pytorch|profiler|execution_graph|thread]", + "[pytorch|profiler|execution_trace|thread]"]: + self.pytorch_root_nids.append(pytorch_node.id) + node_type_counts["root_op"] += 1 - def get_chakra_node_type_from_pytorch_node( - self, - node: Dict[str, Any] - ) -> int: - if self.has_cat_field(node) and ("ncclKernel" in node["name"]): - return COMM_COLL_NODE - elif self.has_cat_field(node): - return COMP_NODE - elif ("c10d::" in node["name"]) or ("nccl:" in node["name"]): - return COMM_COLL_NODE - elif (node["op_schema"] != "") or node["outputs"]: - return COMP_NODE - return INVALID_NODE + # Collect statistics + node_type_counts["total_op"] += 1 + if pytorch_node.is_cpu_op(): + node_type_counts["cpu_op"] += 1 + if pytorch_node.is_gpu_op(): + node_type_counts["gpu_op"] += 1 + if pytorch_node.is_record_param_comms_op(): + node_type_counts["record_param_comms_op"] += 1 + if pytorch_node.is_nccl_op(): + node_type_counts["nccl_op"] += 1 - def has_gpu_op( - self, - nid: int - ) -> bool: - """ - Returns true if a Chakra node has any associated GPU operator. - """ - return nid in self.pt_gpu_node_dict.keys() + # Log the counts of each node type + for node_type, count in node_type_counts.items(): + self.logger.info(f"{node_type}: {count}") - def get_comm_size( - self, - node: Dict[str, Any] - ) -> int: - """ - Calculates the communication size for a given input_type and input_shape. - """ - comm_size = 1 - for input_types in node["input_types"]: - comm_size *= self.get_data_type_size(input_types) - for input_shape_outer in node["input_shapes"]: - for input_shape_inner in input_shape_outer: - comm_size = comm_size * input_shape_inner - return comm_size + self.pytorch_nodes = pytorch_node_objects - def sort_pytorch_nodes_with_starting_time( - self - ) -> None: + def open_chakra_execution_trace(self) -> None: """ - Sorts PyTorch nodes with their starting time ("ts"). + Opens the Chakra execution trace file for writing. - Sorting helps executing nodes with earlier starting time first. + Raises: + Exception: If there is an IOError in opening the file. 
""" - self.pt_nodes = sorted(self.pt_nodes, key=lambda kv: kv["ts"]) + self.logger.info(f"Opening Chakra execution trace file: {self.output_filename}") + try: + self.chakra_et = open(self.output_filename, "wb") + except IOError as e: + err_msg = f"Error opening file {self.output_filename}: {e}" + self.logger.error(err_msg) + raise Exception(err_msg) - def get_total_runtime_ms( - self, - pt_node_list: List[Any] - ) -> int: - """ - Returns the total runtime of PyTorch CPU operators with a duration field. + def construct_phase_end_nids(self) -> None: """ - total_runtime_ms = 0 - for pt_node in pt_node_list: - if self.is_cpu_op_with_dur(pt_node): - total_runtime_ms += pt_node["dur"] # in milliseconds - return total_runtime_ms + Identifies the dependencies between phases in the execution trace. - def get_prev_inter_phase_dep_nid( - self, - node: ChakraNode, - ) -> int: + Uses a depth-first search (DFS) approach starting from phase root nodes to find + the largest Node ID (NID) in each phase for dependency tracking. """ - Returns the NID of the latest operator of the previous phase. + self.logger.info("Constructing phase end node IDs.") + for node in self.pytorch_nodes.values(): + if self.is_phase_root_op(node): + largest_nid_within_phase = self.dfs(node) + if largest_nid_within_phase != -1: + self.phase_end_nids.append(largest_nid_within_phase) + self.phase_end_nids.sort() - Finds the closest but smaller value from inter_phase_dependency compared to node.id. + def is_phase_root_op(self, node: PyTorchNode) -> bool: """ - index = bisect.bisect_left(self.inter_phase_dependency, node.id) + Determines if a node is a root node of a phase. - if index == 0: - # All elements in the list are greater than node.id; no element satisfies the condition. - return -1 - else: - # The element at index-1 will be the closest, smaller value compared to node.id. - return self.inter_phase_dependency[index - 1] + Args: + node (PyTorchNode): The node to be checked. - @staticmethod - def find_root_nids( - nodes: List[Any] - ) -> int: + Returns: + bool: True if the node is a root node of a phase, False otherwise. """ - Finds a root node and return its NID. + return node.parent in self.pytorch_root_nids - * Assumption: There could be multiple root node in a given execution trace. + def dfs(self, node: PyTorchNode) -> int: """ - root_nids = [] - for node in nodes: - if "[pytorch|profiler|execution_graph|thread]" in node["name"]: - root_nids.append(node["id"]) - elif "[pytorch|profiler|execution_trace|thread]" in node["name"]: - root_nids.append(node["id"]) - if not root_nids: - raise ValueError("Cannot find a root NID") - return root_nids + Performs a depth-first search to find the largest Node ID (NID) in a subtree. - @staticmethod - def is_label_node( - node: Dict[str, Any] - ) -> bool: - """ - Returns true if a given PyTorch node is a label node. + Explores the subtree of the given node to find the largest NID among CPU operation nodes. - All label node names start with "## ". + Args: + node (PyTorchNode): The node from which the search starts. + + Returns: + int: The largest NID found in the subtree, or -1 if no CPU operation node is found. 
""" - return node["name"].startswith("## ") + if node.get_op_type() == PyTorchNodeType.GPU_OP: + return -1 + elif node.get_op_type() == PyTorchNodeType.CPU_OP: + return node.id + else: # PyTorchNodeType.LABEL or any other type + largest_nid = -1 + for child_node in node.children: + largest_nid = max(largest_nid, self.dfs(child_node)) + return largest_nid - def is_phase_root_node( - self, - root_nids: List[int], - node: Dict[str, Any] - ) -> bool: - return node["parent"] in root_nids + self.pytorch_nodes = updated_pytorch_nodes + + def split_cpu_nodes_with_gpu_child(self) -> None: + """ + Decomposes CPU nodes with GPU child nodes to model execution overlap + accurately. This method addresses scenarios where a CPU node has a GPU + child node, with an overlap in their execution ending at the same time. + The method splits the CPU node into: + 1. Non-Overlapping Part: Segment before the GPU node starts. + 2. Overlapping Part: Segment overlapping with the GPU node. + + Timeline Stages: + Stage 1 - Original Scenario: + |------------ CPU Node ------------| + |--- GPU Node ---| + + Stage 2 - After Split: + |-- Non-Overlap --|--- Overlap ----| + |--- GPU Node ---| + + Raises: + ValueError: If timestamps of GPU and CPU nodes are inconsistent. + """ + self.logger.info("Decomposing CPU nodes with GPU child nodes.") + updated_pytorch_nodes: Dict[int, PyTorchNode] = {} + for cpu_node in self.pytorch_nodes.values(): + if cpu_node.child_gpu is None: + new_cpu_node_id = self.id_assigner.assign_unique_id(cpu_node.id) + cpu_node.id = new_cpu_node_id + for child_node in cpu_node.children: + child_node.parent = cpu_node.id + updated_pytorch_nodes[new_cpu_node_id] = cpu_node + else: + gpu_node = cpu_node.child_gpu + if gpu_node.ts >= (cpu_node.ts + cpu_node.dur): + err_msg = f"Inconsistent timestamps for CPU node {cpu_node.id} and its GPU child" + self.logger.error(err_msg) + raise ValueError(err_msg) + + cpu_node_first, cpu_node_second, updated_gpu_node =\ + self._split_cpu_node(cpu_node, gpu_node) + updated_pytorch_nodes[cpu_node_first.id] = cpu_node_first + updated_pytorch_nodes[cpu_node_second.id] = cpu_node_second + updated_pytorch_nodes[updated_gpu_node.id] = updated_gpu_node + + self.pytorch_nodes = updated_pytorch_nodes + + self.update_phase_end_nids() + + def _split_cpu_node( + self, cpu_node: PyTorchNode, gpu_node: PyTorchNode + ) -> Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: + """ + Splits a CPU node based on the GPU node's timestamp. + + Args: + cpu_node (PyTorchNode): Original CPU node to be split. + gpu_node (PyTorchNode): GPU node dictating the split. + + Returns: + Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: Two split nodes and the updated GPU node. + + Raises: + ValueError: For inconsistencies in the timestamps of the nodes. + """ + original_cpu_info = f"Original CPU Node ID {cpu_node.id} ({cpu_node.name}), " \ + f"Duration: {cpu_node.dur}." 
+ self.logger.debug(original_cpu_info) + self.logger.debug(f"GPU Node ID {gpu_node.id} ({gpu_node.name}), " + f"Duration: {gpu_node.dur}.") + + cpu_node_first = copy.deepcopy(cpu_node) + cpu_node_first.id = self.id_assigner.assign_unique_id(cpu_node.id) + cpu_node_first.ts = cpu_node.ts + cpu_node_first.dur = gpu_node.ts - cpu_node.ts + cpu_node_first.set_child_gpu = gpu_node + for child_node in cpu_node_first.children: + child_node.parent = cpu_node_first.id + if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.dur <= 0: + err_msg = (f"Invalid timestamps for the first split CPU node derived from {original_cpu_info}\n" + f"\tFirst Split CPU Node Timestamp: {cpu_node_first.ts}, \n" + f"\tGPU Node Timestamp: {gpu_node.ts}, \n" + f"\tFirst Split CPU Node Duration: {cpu_node_first.dur}.") + self.logger.error(err_msg) + raise ValueError(err_msg) + + self.logger.debug(f"First Split CPU Node ID {cpu_node_first.id} ({cpu_node_first.name}), " + f"Duration: {cpu_node_first.dur}") + + gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id) + gpu_node.id = gpu_node_id + + cpu_node_second = copy.deepcopy(cpu_node) + cpu_node_second.id = self.id_assigner.assign_unique_id(cpu_node.id) + cpu_node_second.ts = gpu_node.ts + cpu_node_second.dur = cpu_node.dur - (gpu_node.ts - cpu_node.ts) + cpu_node_second.set_child_gpu(None) + cpu_node_second.add_data_dep(cpu_node_first) + for child_node in cpu_node_second.children: + child_node.parent = cpu_node_second.id + if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.dur <= 0: + err_msg = (f"Invalid timestamps for the second split CPU node derived from {original_cpu_info}\n" + f"\tFirst Split Timestamp: {cpu_node_first.ts}, \n" + f"\tSecond Split Timestamp: {cpu_node_second.ts}, \n" + f"\tSecond Split Duration: {cpu_node_second.dur}.") + self.logger.error(err_msg) + raise ValueError(err_msg) + + self.logger.debug(f"Second Split CPU Node ID {cpu_node_second.id} ({cpu_node_second.name}), " + f"Duration: {cpu_node_second.dur}.") + + return cpu_node_first, cpu_node_second, gpu_node + + def update_phase_end_nids(self) -> None: + """ + Updates the phase end node IDs with the largest new node ID assigned + during the splitting of CPU nodes with GPU children. Utilizes the + get_assigned_ids function from UniqueIdAssigner to find all new IDs and + selects the largest one for each original node ID. + + This ensures that the phase end boundaries are correctly maintained after + splitting the nodes. + """ + self.logger.info( + "Updating phase end node IDs with the largest new IDs after node splitting." + ) + updated_phase_end_nids = [] + for node_id in self.phase_end_nids: + assigned_ids = self.id_assigner.get_assigned_ids(node_id) + if assigned_ids: + updated_phase_end_nids.append(max(assigned_ids)) + updated_phase_end_nids.sort() + self.phase_end_nids = updated_phase_end_nids - def is_gpu_op( - self, - node: Dict[str, Any] - ) -> bool: + def update_input_tensor_map(self, nid: int, inputs: List[List[int]]) -> None: """ - Takes a PyTorch node and returns true if it is a GPU operator. + Updates input_storage_id_nid_map and input_tensor_id_nid_map with input + tensor information. + + Each dictionary is populated with mappings between storage ID (or tensor ID) + and node IDs. For example, if node 0 takes tensor 10 as an input, a new + mapping will be created like this `10: [0]`. - All GPU operators have a category field. + Args: + nid (int): Node ID associated with the input tensors. + inputs (List[List[int]]): List of input tensor data. 
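+
+        Note:
+            Each entry of `inputs` is a raw tensor descriptor that is parsed by
+            `list_to_pytorch_tensor`; judging from the previous converter this
+            is presumably a six-element list with the tensor ID at index 0 and
+            the storage ID at index 1, but the authoritative layout lives in
+            pytorch_tensor.py.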
""" - return self.has_cat_field(node) + for i in inputs: + tensor = list_to_pytorch_tensor(i) + if tensor.is_valid(): + if tensor.has_valid_storage_id(): + storage_id = tensor.storage_id + self.input_storage_id_nid_map.setdefault( + storage_id, [] + ).append(nid) + else: + tensor_id = tensor.tensor_id + self.input_tensor_id_nid_map.setdefault( + tensor_id, [] + ).append(nid) - def find_children_gpu_ops( - self, - root_cpu_nid: int, - cpu_node: Dict[str, Any], - ) -> None: + def update_output_tensor_map(self, nid: int, outputs: List[List[int]]) -> None: """ - Discovers all GPU operators under a CPU operator. + Updates output_storage_id_nid_map and output_tensor_id_nid_map with output + tensor information. + + Each dictionary is populated with mappings between storage ID (or tensor ID) + and node IDs. For example, if node 0 produces tensor 10 as an output, + a new mapping will be created like this `10: [0]`. - Once discovered, GPU operators are tracked in pt_gpu_node_dict. + Args: + nid (int): Node ID associated with the output tensors. + outputs (List[List[int]]): List of output tensor data. """ - cpu_nid = cpu_node["id"] - for node in self.pt_nodes: - if node["parent"] == cpu_nid: - if self.is_gpu_op(node): - self.pt_gpu_node_dict.setdefault(root_cpu_nid, []).append(node) + for o in outputs: + tensor = list_to_pytorch_tensor(o) + if tensor.is_valid(): + if tensor.has_valid_storage_id(): + storage_id = tensor.storage_id + self.output_storage_id_nid_map.setdefault( + storage_id, [] + ).append(nid) else: - # label or CPU operators - self.find_children_gpu_ops(root_cpu_nid, node) + tensor_id = tensor.tensor_id + self.output_tensor_id_nid_map.setdefault( + tensor_id, [] + ).append(nid) + + def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: + """ + Converts a PyTorchNode to a ChakraNode. + + Args: + pytorch_node (PyTorchNode): The PyTorch node to convert. + + Returns: + ChakraNode: The converted Chakra node. + """ + self.logger.debug(f"Converting PyTorch node ID {pytorch_node.id} to Chakra node.") + + chakra_node = ChakraNode() + chakra_node.id = pytorch_node.id + chakra_node.name = pytorch_node.name + chakra_node.type = self.get_chakra_node_type_from_pytorch_node(pytorch_node) + if pytorch_node.parent in self.chakra_nodes: + chakra_node.ctrl_deps.append(pytorch_node.parent) + chakra_node.duration_micros = pytorch_node.dur if pytorch_node.has_dur() else 0 + chakra_node.inputs.values = str(pytorch_node.inputs) + chakra_node.inputs.shapes = str(pytorch_node.input_shapes) + chakra_node.inputs.types = str(pytorch_node.input_types) + chakra_node.outputs.values = str(pytorch_node.outputs) + chakra_node.outputs.shapes = str(pytorch_node.output_shapes) + chakra_node.outputs.types = str(pytorch_node.output_types) + chakra_node.attr.extend([ + ChakraAttr(name="rf_id", int64_val=pytorch_node.rf_id), + ChakraAttr(name="fw_parent", int64_val=pytorch_node.fw_parent), + ChakraAttr(name="seq_id", int64_val=pytorch_node.seq_id), + ChakraAttr(name="scope", int64_val=pytorch_node.scope), + ChakraAttr(name="tid", int64_val=pytorch_node.tid), + ChakraAttr(name="fw_tid", int64_val=pytorch_node.fw_tid), + ChakraAttr(name="op_schema", string_val=pytorch_node.op_schema), + ChakraAttr(name="is_cpu_op", int32_val=not pytorch_node.is_gpu_op()), + ChakraAttr(name="ts", int64_val=pytorch_node.ts) + ]) + return chakra_node + + def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> int: + """ + Determines the Chakra node type from a PyTorch node. 
+ + Args: + pytorch_node (PyTorchNode): The PyTorch node to determine the type of. + + Returns: + int: The corresponding Chakra node type. + """ + if pytorch_node.is_gpu_op() and ( + "ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name + ): + return COMM_COLL_NODE + elif pytorch_node.is_gpu_op(): + return COMP_NODE + elif ("c10d::" in pytorch_node.name) or ("nccl:" in pytorch_node.name): + return COMM_COLL_NODE + elif (pytorch_node.op_schema != "") or pytorch_node.outputs: + return COMP_NODE + return INVALID_NODE - def dfs( - self, - node: Dict[str, Any], - root_nid: int, - ) -> int: - """ - Discovers all PyTorch CPU operators under a given node while populating - pt_cpu_node_dict, After that, returns the largest NID in the tree. + def get_nccl_node(self, node: PyTorchNode) -> PyTorchNode: """ - nid = node["id"] - node_type = self.get_pytorch_node_type(node) - if node_type == PyTorchNodeType.GPU_OP: - return -1 - elif node_type == PyTorchNodeType.CPU_OP: - self.pt_cpu_node_dict[nid] = node - self.find_children_gpu_ops(node["id"], node) - return nid - elif node_type == PyTorchNodeType.LABEL: - largest_nid = -1 - for child in self.pt_nodes: - # We should not call dfs for the root node or phase root nodes - # as they will be covered by other DFS calls. - if child["parent"] == nid: - largest_nid = max(largest_nid, self.dfs(child, root_nid)) - return largest_nid - else: - raise ValueError(f"Invalid node type: {node_type}") - return -1 + Returns a PyTorch NCCL node for a given Chakra CPU node. - def discover_pytorch_cpu_ops( - self - ) -> None: - """ - Discovers PyTorch CPU operators and populate pt_cpu_node_dict. + Critical for identifying communication type and size in communication nodes. + There are two primary cases to consider: when the given node is a parent + of a record_param_comms node or a NCCL node. - Run DFS on a root node and phase root nodes as they may have CPU operators. - DFS populates pt_cpu_node_dict and returns the largest NID within the phase. - """ - root_nids = self.find_root_nids(self.pt_nodes) - for node in self.pt_nodes: - if self.is_phase_root_node(root_nids, node): - largest_nid_within_phase = self.dfs(node, root_nids) - if largest_nid_within_phase != -1: - self.inter_phase_dependency.append(largest_nid_within_phase) + Args: + node (PyTorchNode): The parent node for which the NCCL node is needed. - # Make sure that the NIDs in inter_phase_dependency are in the increasing order. - self.inter_phase_dependency.sort() + Returns: + PyTorchNode: The corresponding NCCL node. - def assign_chakra_ids( - self, - total_assigned_ids: Dict[int,bool], - assigned_ids: List[int], - initial_id_to_assign: int - ) -> int: - """ - This function is used to assign unique ids to the ops. During the conversion, we may decompose an op into multiple - ops. So it is important to re-assign unique ids to all ops and make sure the ops that should be executed first have - smaller ids. - """ - orig_id = initial_id_to_assign - while True: - if initial_id_to_assign in total_assigned_ids.keys(): - initial_id_to_assign += 1 + Raises: + ValueError: If no corresponding NCCL node is found. 
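+
+        Note:
+            As implemented below, a record_param_comms child takes precedence:
+            its nccl_node is returned when present, and node.nccl_node is only
+            consulted when no record_param_comms node is attached.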
+ """ + self.logger.debug(f"Retrieving NCCL node for PyTorch node ID {node.id}.") + if node.record_param_comms_node: + record_param_comms_node = node.record_param_comms_node + if record_param_comms_node.nccl_node: + return record_param_comms_node.nccl_node else: - total_assigned_ids[initial_id_to_assign] = True - if orig_id in assigned_ids.keys(): - assigned_ids[orig_id].append(initial_id_to_assign) - else: - assigned_ids[orig_id] = [initial_id_to_assign] - return initial_id_to_assign + err_msg = "No NCCL node found in the record_param_comms node." + self.logger.error(err_msg) + raise ValueError(err_msg) + elif node.nccl_node: + return node.nccl_node + else: + err_msg = "No NCCL node associated with the given PyTorch node." + self.logger.error(err_msg) + raise ValueError(err_msg) - def merge_gpu_ops_with_cpu_ops( - self, - ) -> Any: - """ - This function decomposes the CPU ops that have GPU child ops into multiple sub_ops. - This required to allow running GPU ops and CPU ops at the same time. - """ - self.logger.info("Merge CPU ops with GPU ops") - - decomposed_nodes = [] - assigned_ids = {} - total_assigned_ids = {} - new_pt_gpu_node_dict = {} - decomposed_nodes_dep = {} - for nid, node in self.pt_cpu_node_dict.items(): - if self.has_gpu_op(nid): - self.pt_gpu_node_dict[nid] = sorted(self.pt_gpu_node_dict[nid], key=lambda kv: kv["ts"]) - - for gpu_node in self.pt_gpu_node_dict[nid]: - assert (node["ts"] + node["dur"]) > gpu_node["ts"] - - last_ts = node["ts"] - for i in range(len(self.pt_gpu_node_dict[nid])+1): - copy_node = copy.deepcopy(node) - copy_node["id"] = self.assign_chakra_ids(total_assigned_ids, assigned_ids, nid) - copy_node["name"] = copy_node["name"]+"("+str(i)+")" - if i < len(self.pt_gpu_node_dict[nid]): - self.pt_gpu_node_dict[nid][i]["id"] =\ - self.assign_chakra_ids( - total_assigned_ids, - assigned_ids, - self.pt_gpu_node_dict[nid][i]["id"]) - assert self.pt_gpu_node_dict[nid][i]["ts"] > copy_node["ts"] - copy_node["ts"] = last_ts - copy_node["dur"] = self.pt_gpu_node_dict[nid][i]["ts"]-last_ts - last_ts = self.pt_gpu_node_dict[nid][i]["ts"] - new_pt_gpu_node_dict.setdefault(copy_node["id"], []).append(self.pt_gpu_node_dict[nid][i]) - else: - copy_node["dur"] = copy_node["dur"]-(last_ts-copy_node["ts"]) - copy_node["ts"] = last_ts - last_ts = copy_node["ts"]+copy_node["dur"] - - assert (copy_node["ts"] >= 0) and (copy_node["dur"] > 0) - if i > 0: - assert copy_node["ts"] > decomposed_nodes[-1]["ts"] - decomposed_nodes_dep[copy_node["id"]] = decomposed_nodes[-1]["id"] - decomposed_nodes.append(copy_node) - else: - node["id"] = self.assign_chakra_ids(total_assigned_ids, assigned_ids, nid) - decomposed_nodes.append(node) + def get_prev_phase_end_nid(self, node: ChakraNode) -> int: + """ + Returns the Node ID (NID) of the latest node of the previous phase for + the given ChakraNode. - merged_pt_cpu_node_dict = { - decomposed_node["id"]: decomposed_node for decomposed_node in decomposed_nodes - } + This method is used to find the closest but smaller value from + phase_end_nids compared to the given node's ID. It helps in + determining the dependencies between different phases in the trace. - self.pt_cpu_node_dict = merged_pt_cpu_node_dict - self.pt_gpu_node_dict = new_pt_gpu_node_dict - return assigned_ids, decomposed_nodes_dep + Args: + node (ChakraNode): The node to find the previous phase dependency for. - def validate_pt_node_dict( - self, - ) -> None: + Returns: + int: NID of the latest node of the previous phase, or -1 if none. 
""" - Raises an exception if any anomaly is detected in pt_cpu_node_dict or - pt_gpu_node_dict. + self.logger.debug( + f"Finding previous inter-phase dependency for node ID {node.id}." + ) + index = bisect.bisect_left(self.phase_end_nids, node.id) - * NIDs of CPU nodes should be unique. - * CPU operators can have at most one GPU operator. - """ - seen_nids = set() - for nid, node in self.pt_cpu_node_dict.items(): - assert nid == node["id"] - if nid in seen_nids: - self.logger.error(f"NID {nid} is duplicate") - raise ValueError("Duplicate NID detected!") - seen_nids.add(nid) - if nid in self.pt_gpu_node_dict.keys(): - assert len(self.pt_gpu_node_dict[nid]) == 1 + if index == 0: + # All elements in the list are greater than node.id; + # no element satisfies the condition. + return -1 + else: + # The element at index-1 will be the closest, smaller value + # compared to node.id. + return self.phase_end_nids[index - 1] - def discover_pytorch_comm_ops( - self, - assigned_ids: List[int] - ) -> None: + def identify_data_dependency(self) -> None: """ - Discovers communication nodes and populate pt_record_param_comms_node_dict - and pt_nccl_node_dict. - """ - self.logger.info("Discover communication nodes") - for node in self.pt_nodes: - if self.is_record_param_comms_node(node): - if node["parent"] in assigned_ids.keys(): - nodes_to_assign = assigned_ids[node["parent"]] - for parent_id in nodes_to_assign: - self.pt_record_param_comms_node_dict.update({parent_id: node}) - else: - self.pt_record_param_comms_node_dict.update({node["parent"]: node}) - if self.is_nccl_node(node): - if node["parent"] in assigned_ids.keys(): - nodes_to_assign=assigned_ids[node["parent"]] - for parent_id in nodes_to_assign: - self.pt_nccl_node_dict.update({parent_id: node}) - else: - self.pt_nccl_node_dict.update({node["parent"]: node}) + Identifies data dependencies between nodes using tensor input/output + relationships. - for i in range(len(self.inter_phase_dependency)): - # If an op is decomposed into multiple sub_ops, we want to point to the last subop [-1] - self.inter_phase_dependency[i] = assigned_ids[self.inter_phase_dependency[i]][-1] - self.inter_phase_dependency.sort() - - def update_input_tensor_dict( - self, - nid: int, - inputs: str - ) -> int: + Determines the relationships based on whether the tensors use storage IDs + or tensor IDs. """ - Updates input_storage_id_nid_dict and input_tensor_id_nid_dict + self.logger.info("Identifying data dependencies among nodes.") + self.identify_data_dependency_with_storage_id() + self.identify_data_dependency_with_tensor_id() - Each dictionary is populcated with mappings between storage ID - (or tensor ID) and corresponding node IDs. If node 0 takes tensor 10 as - an input, a new mapping will be created like this `10: [0]` + def identify_data_dependency_with_storage_id(self) -> None: """ - for i in inputs: - if self.is_valid_tensor(i): - if self.has_valid_storage_id(i): - storage_id = self.get_storage_id_from_tensor(i) - self.input_storage_id_nid_dict.setdefault(storage_id, []).append(nid) - else: - tensor_id = self.get_tensor_id_from_tensor(i) - self.input_tensor_id_nid_dict.setdefault(tensor_id, []).append(nid) + Identifies data dependency between nodes based on storage IDs. - def update_output_tensor_dict( - self, - nid: int, - outputs: str - ) -> int: + Uses the mapping of input and output tensors to their storage IDs to + establish dependencies. """ - Updates output_storage_id_nid_dict and output_tensor_id_nid_dict. 
+ self.logger.info("Identifying data dependencies using storage IDs.") + self.update_data_dependencies( + self.input_storage_id_nid_map, + self.output_storage_id_nid_map) - Each dictionary is populcated with mappings between storage ID - (or tensor ID) and corresponding node IDs. If node 0 produces tensor 10 - as an output, a new mapping will be created like this `10: [0]`. + def identify_data_dependency_with_tensor_id(self) -> None: """ - for o in outputs: - if self.is_valid_tensor(o): - if self.has_valid_storage_id(o): - storage_id = self.get_storage_id_from_tensor(o) - self.output_storage_id_nid_dict.setdefault(storage_id, []).append(nid) - else: - tensor_id = self.get_tensor_id_from_tensor(o) - self.output_tensor_id_nid_dict.setdefault(tensor_id, []).append(nid) - - def convert_pytorch_node_to_chakra_node( - self, - pt_node: Dict[str, Any] - ) -> ChakraNode: - """ - Converts a PyToch node to a Chakra node. - """ - ck_node = ChakraNode() - ck_node.id = pt_node["id"] - ck_node.name = pt_node["name"] - ck_node.type = self.get_chakra_node_type_from_pytorch_node(pt_node) - ck_node.ctrl_deps.append(pt_node["parent"]) - if "dur" in pt_node.keys(): - ck_node.duration_micros = pt_node["dur"] - else: - ck_node.duration_micros = 0 - ck_node.inputs.values = str(pt_node["inputs"]) - ck_node.inputs.shapes = str(pt_node["input_shapes"]) - ck_node.inputs.types = str(pt_node["input_types"]) - ck_node.outputs.values = str(pt_node["outputs"]) - ck_node.outputs.shapes = str(pt_node["output_shapes"]) - ck_node.outputs.types = str(pt_node["output_types"]) - ck_node.attr.append( - ChakraAttr(name="is_cpu_op", - bool_val=self.is_cpu_op(pt_node))) - if "fw_parent" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="fw_parent", - int64_val=pt_node["fw_parent"])) - if "fw_tid" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="fw_tid", - int64_val=pt_node["fw_tid"])) - if "op_schema" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="op_schema", - string_val=pt_node["op_schema"])) - if "seq_id" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="seq_id", - int64_val=pt_node["seq_id"])) - if "rf_id" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="rf_id", - int64_val=pt_node["rf_id"])) - if "scope" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="scope", - int64_val=pt_node["scope"])) - if "tid" in pt_node.keys(): - ck_node.attr.append( - ChakraAttr(name="tid", - int64_val=pt_node["tid"])) - return ck_node - - def get_nccl_node( - self, - nid: int - ) -> Dict[str, Any]: - """ - Returns a PyTorch NCCL node for a given Chakra NID. - - For communication nodes, finding a corresponding NCCL node is critical - to identify the communication type and communication size. - - There are two cases: - (1) Given node is a parent of a record_param_comms node - * In this case, the corresponding NCCL node should be a child of - the record_param_comms_pt node. 
- (2) Given node is a parent of a NCCL node - """ - pt_nccl_node = None - if nid in self.pt_record_param_comms_node_dict.keys(): - pt_record_param_comms_node = self.pt_record_param_comms_node_dict[nid] - rpcp_nid = pt_record_param_comms_node["id"] - if rpcp_nid in self.pt_nccl_node_dict.keys(): - pt_nccl_node = self.pt_nccl_node_dict[rpcp_nid] - else: - raise ValueError( - f"NID {nid} has a pt_record_param_comms_node " - f"but it does not have a correspondin pt_nccl_node.") - elif nid in self.pt_nccl_node_dict.keys(): - pt_nccl_node = self.pt_nccl_node_dict[nid] - else: - raise ValueError( - f"NID {nid} does not have an entry in pt_record_param_comms_node_dict " - f"nor pt_nccl_node_dict" - ) - return pt_nccl_node + Identifies data dependency between nodes based on tensor IDs. - def add_gpu_chakra_node( - self, - ck_cpu_node: ChakraNode - ) -> None: - """ - Converts a PyTorch GPU node to a Chakra node and add it to ck_node_dict. - """ - assert ck_cpu_node.id in self.pt_gpu_node_dict.keys() - pt_gpu_node = self.pt_gpu_node_dict[ck_cpu_node.id][0] - if len(self.pt_gpu_node_dict[ck_cpu_node.id]) != 1: - raise ValueError(f"Chakra node {ck_cpu_node.id} has more than one GPU operators") - ck_gpu_node = self.convert_pytorch_node_to_chakra_node(pt_gpu_node) - if ck_cpu_node.type == COMM_COLL_NODE: - pt_nccl_node = self.get_nccl_node(ck_cpu_node.id) - ck_gpu_node.attr.append( - ChakraAttr(name="comm_type", - int64_val=self.get_collective_comm_type(pt_nccl_node))) - ck_gpu_node.attr.append( - ChakraAttr(name="comm_size", - int64_val=self.get_comm_size(pt_nccl_node))) - attr = ChakraAttr(name="involved_dim") - for _ in range(self.num_dims): - attr.bool_list.values.append(True) - ck_gpu_node.attr.append(attr) - ck_gpu_node.data_deps.append(ck_cpu_node.id) - self.ck_node_dict[ck_gpu_node.id] = ck_gpu_node - - def identify_data_dependency_with_storage_id( - self - ) -> None: - """ - Identifies data dependency between operators with storage IDs. + Establishes dependencies using tensor IDs for tensors without valid + storage IDs. """ - self.logger.info("Identify data dependency with storage IDs") - for input_storage_id, child_nids in self.input_storage_id_nid_dict.items(): - if input_storage_id in self.output_storage_id_nid_dict: - parent_nids = self.output_storage_id_nid_dict[input_storage_id] - for child_nid in child_nids: - for parent_nid in parent_nids: - child_node = self.ck_node_dict[child_nid] - if (parent_nid not in child_node.data_deps)\ - and (parent_nid < child_nid): - child_node.data_deps.append(parent_nid) + self.logger.info("Identifying data dependencies using tensor IDs.") + self.update_data_dependencies( + self.input_tensor_id_nid_map, + self.output_tensor_id_nid_map) - def identify_data_dependency_with_tensor_id( - self - ) -> None: + def update_data_dependencies(self, input_map: Dict[int, List[int]], + output_map: Dict[int, List[int]]) -> None: """ - Identifies data dependency between operators with tensor IDs. + Updates data dependencies for nodes based on input and output tensor maps. + + Args: + input_map (Dict[int, List[int]]): Map of input tensor IDs to node IDs. + output_map (Dict[int, List[int]]): Map of output tensor IDs to node IDs. 
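+
+        Example (illustrative IDs): with input_map == {7: [3]} and
+        output_map == {7: [1]}, node 3 gains a data dependency on node 1;
+        only producers whose IDs are smaller than the consumer's are linked.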
""" - self.logger.info("Identify data dependency with tensor IDs") - for input_tensor_id, child_nids in self.input_tensor_id_nid_dict.items(): - if input_tensor_id in self.output_tensor_id_nid_dict: - parent_nids = self.output_tensor_id_nid_dict[input_tensor_id] + self.logger.debug("Updating data dependencies for nodes.") + for input_id, child_nids in input_map.items(): + if input_id in output_map: + parent_nids = output_map[input_id] for child_nid in child_nids: for parent_nid in parent_nids: - child_node = self.ck_node_dict[child_nid] + child_node = self.chakra_nodes[child_nid] if (parent_nid not in child_node.data_deps)\ - and (parent_nid < child_nid): + and (parent_nid < child_nid): child_node.data_deps.append(parent_nid) - def identify_data_dependency( - self - ) -> None: + def write_chakra_et(self) -> None: """ - Identifies data dependency between operators using tensors. + Writes the Chakra execution trace by encoding global metadata and nodes. - Dependencies between operators can be identified by their tensor input/ - output relationships. A tensor can be identified by either a storage ID - or a tensor ID. Use the storage ID if it's valid; otherwise, use the - tensor ID. + Encodes and writes both the metadata and individual nodes to create a + complete execution trace. """ - self.logger.info("Identify data dependency") - self.identify_data_dependency_with_storage_id() - self.identify_data_dependency_with_tensor_id() + self.logger.info("Writing Chakra execution trace.") + self._write_global_metadata() + self._encode_and_write_nodes() + self.logger.info("Chakra execution trace writing completed.") - def write_chakra_et( - self, - ) -> None: - self.logger.info("Write Chakra trace") + def _write_global_metadata(self) -> None: + """ + Encodes and writes global metadata for the Chakra execution trace. - self.logger.info("Encode global metadata") - md = GlobalMetadata( + This process includes encoding metadata like schema, process ID, timestamps, + and other relevant information for the Chakra execution trace. + """ + self.logger.info("Encoding global metadata for Chakra execution trace.") + global_metadata = GlobalMetadata( attr=[ - ChakraAttr(name="schema", string_val=self.pt_schema), - ChakraAttr(name="pid", uint64_val=self.pt_pid), - ChakraAttr(name="time", string_val=self.pt_time), - ChakraAttr(name="start_ts", uint64_val=self.pt_start_ts), - ChakraAttr(name="finish_ts", uint64_val=self.pt_finish_ts) + ChakraAttr(name="schema", string_val=self.pytorch_schema), + ChakraAttr(name="pid", uint64_val=self.pytorch_pid), + ChakraAttr(name="time", string_val=self.pytorch_time), + ChakraAttr(name="start_ts", uint64_val=self.pytorch_start_ts), + ChakraAttr(name="finish_ts", uint64_val=self.pytorch_finish_ts) ] ) - encode_message(self.chakra_et, md) + encode_message(self.chakra_et, global_metadata) + + def _encode_and_write_nodes(self) -> None: + """ + Encodes and writes nodes for the Chakra execution trace. - self.logger.info("Encode nodes (operators)") + Each node from the PyTorch execution trace is encoded and written into the + Chakra format. This includes node IDs, names, types, dependencies, and + other attributes. + """ + self.logger.info("Encoding and writing nodes for Chakra execution trace.") seen_nids = set() - for nid in sorted(self.ck_node_dict.keys()): + for nid in sorted(self.chakra_nodes.keys()): if nid in seen_nids: - self.logger.error(f"NID {nid} is duplicate") - raise ValueError("Duplicate NID detected!") + err_msg = f"Duplicate NID {nid} detected in Chakra nodes." 
+ self.logger.error(err_msg) + raise ValueError(err_msg) seen_nids.add(nid) - ck_node = self.ck_node_dict[nid] - encode_message(self.chakra_et, ck_node) - - self.logger.info("All Chakra nodes are written to the output file") - - def convert( - self - ) -> None: - self.sort_pytorch_nodes_with_starting_time() - - self.discover_pytorch_cpu_ops() - - total_runtime_ns = self.get_total_runtime_ms(list(self.pt_cpu_node_dict.values())) * 1000 - self.logger.info(f"Total runtime exluding children operators: {total_runtime_ns} ns") - - assigned_ids, decomposed_nodes_dep = self.merge_gpu_ops_with_cpu_ops() - - self.validate_pt_node_dict() + chakra_node = self.chakra_nodes[nid] + encode_message(self.chakra_et, chakra_node) - self.discover_pytorch_comm_ops(assigned_ids) - - self.logger.info("Convert PyTorch nodes to Chakra nodes") - for pt_nid, pt_node in self.pt_cpu_node_dict.items(): - self.update_input_tensor_dict(pt_node["id"], pt_node["inputs"]) - self.update_output_tensor_dict(pt_node["id"], pt_node["outputs"]) - if pt_nid in self.pt_gpu_node_dict.keys(): - for pt_gpu_node in self.pt_gpu_node_dict[pt_nid]: - # Assumption: same input / output as its parent CPU operator - self.update_input_tensor_dict(pt_gpu_node["id"], pt_gpu_node["inputs"]) - # For now we ignore GPU->CPU dependencies since it creates unwanted dependencies. - # self.update_output_tensor_dict(pt_gpu_node["id"], pt_gpu_node["outputs"]) - - ck_node = self.convert_pytorch_node_to_chakra_node(pt_node) - self.ck_node_dict[ck_node.id] = ck_node - if self.has_gpu_op(ck_node.id): - self.add_gpu_chakra_node(ck_node) - - # Adding previous phase node dependency - dep_nid = self.get_prev_inter_phase_dep_nid(ck_node) - if (dep_nid != -1) and (dep_nid not in ck_node.data_deps): - ck_node.data_deps.append(dep_nid) - - # Adding decomposed nodes dependency - # When we decompose a CPU op into multiple sub_ops, these ops have linear dependeny with themselves - # For example, the first sub_op should be finished before the second sub_op. Here, we capture these dependencies. - if (pt_nid in decomposed_nodes_dep.keys())\ - and (decomposed_nodes_dep[pt_nid] not in ck_node.data_deps): - ck_node.data_deps.append(decomposed_nodes_dep[pt_nid]) - - self.identify_data_dependency() + def close_chakra_execution_trace(self) -> None: + """ + Closes the Chakra execution trace file if it is open. - self.write_chakra_et() + Ensures proper closure of the trace file to preserve data integrity. + """ + self.logger.info("Closing Chakra execution trace file.") + if self.chakra_et and not self.chakra_et.closed: + self.chakra_et.close() diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py new file mode 100644 index 00000000..b5068734 --- /dev/null +++ b/et_converter/pytorch_node.py @@ -0,0 +1,628 @@ +#!/usr/bin/env python3 + +from enum import Enum +from typing import Any, Dict, List, Optional + +from chakra.et_def.et_def_pb2 import ( + ALL_REDUCE, + ALL_GATHER, + BROADCAST, + ALL_TO_ALL, + REDUCE_SCATTER, +) + + +class PyTorchNodeType(Enum): + CPU_OP = 1 + GPU_OP = 2 + LABEL = 3 # Non-operator nodes + + +class PyTorchNode: + """ + Represents a node in a PyTorch execution trace. + + Attributes: + node_data (Dict[str, Any]): Data of the PyTorch node. + data_deps (List[PyTorchNode]): List of data-dependent parent nodes. + children (List[PyTorchNode]): List of child nodes. + """ + + def __init__(self, node_data: Dict[str, Any]) -> None: + """ + Initializes a PyTorchNode object with the provided node data. 
+ + Args: + node_data (Dict[str, Any]): Dictionary containing the data of the + PyTorch node. + """ + self.node_data = node_data + self.data_deps: List['PyTorchNode'] = [] + self.children: List['PyTorchNode'] = [] + self.child_gpu: Optional['PyTorchNode'] = None + self.record_param_comms_node: Optional['PyTorchNode'] = None + self.nccl_node: Optional['PyTorchNode'] = None + + def __repr__(self) -> str: + """ + Represent the PyTorchNode as a string. + Returns: + str: A detailed string representation of the PyTorchNode. + """ + return ( + f"PyTorchNode(" + f"id={self.id}, name={self.name}, " + f"op_type={self.get_op_type()}, " + f"timestamp={self.ts}, duration={self.dur})" + ) + + @property + def name(self) -> str: + """ + Returns the name of the node. + + Returns: + str: Name of the node. + """ + return self.node_data["name"] + + @name.setter + def name(self, value: str) -> None: + """ + Sets the name of the node. + + Args: + value (str): The new name of the node. + """ + self.node_data["name"] = value + + @property + def id(self) -> int: + """ + Returns the node ID. + + Returns: + int: ID of the node. + """ + return self.node_data["id"] + + @id.setter + def id(self, value: int) -> None: + """ + Sets the node ID. + + Args: + value (int): The new ID of the node. + """ + self.node_data["id"] = value + + @property + def rf_id(self) -> int: + """ + Returns the unique record function ID. + + Returns: + int: The unique record function ID. + """ + return self.node_data["rf_id"] + + @rf_id.setter + def rf_id(self, value: int) -> None: + """ + Sets the unique record function ID. + + Args: + value (int): The new unique record function ID. + """ + self.node_data["rf_id"] = value + + @property + def parent(self) -> int: + """ + Returns the parent node ID. + + Returns: + int: The parent node ID. + """ + return self.node_data["parent"] + + @parent.setter + def parent(self, value: int) -> None: + """ + Sets the parent node ID. + + Args: + value (int): The new parent node ID. + """ + self.node_data["parent"] = value + + @property + def fw_parent(self) -> int: + """ + Returns the parent node ID from the forward thread. + + Returns: + int: The parent node ID from the forward thread. + """ + return self.node_data["fw_parent"] + + @fw_parent.setter + def fw_parent(self, value: int) -> None: + """ + Sets the parent node ID from the forward thread. + + Args: + value (int): The new parent node ID from the forward thread. + """ + self.node_data["fw_parent"] = value + + @property + def seq_id(self) -> int: + """ + Returns the record function sequence ID used to correlate forward and + backward operators. + + Returns: + int: The record function sequence ID. + """ + return self.node_data["seq_id"] + + @seq_id.setter + def seq_id(self, value: int) -> None: + """ + Sets the record function sequence ID. + + Args: + value (int): The new sequence ID. + """ + self.node_data["seq_id"] = value + + @property + def scope(self) -> int: + """ + Returns the record scope. + + Returns: + int: The record scope. + """ + return self.node_data["scope"] + + @scope.setter + def scope(self, value: int) -> None: + """ + Sets the record scope. + + Args: + value (int): The new scope value. + """ + self.node_data["scope"] = value + + @property + def tid(self) -> int: + """ + Returns the record function thread ID. + + Returns: + int: The record function thread ID. + """ + return self.node_data["tid"] + + @tid.setter + def tid(self, value: int) -> None: + """ + Sets the record function thread ID. + + Args: + value (int): The new thread ID. 
+ """ + self.node_data["tid"] = value + + @property + def fw_tid(self) -> int: + """ + Returns the thread ID of the forward execution thread. + + Returns: + int: The thread ID of the forward execution thread. + """ + return self.node_data["fw_tid"] + + @fw_tid.setter + def fw_tid(self, value: int) -> None: + """ + Sets the thread ID of the forward execution thread. + + Args: + value (int): The new forward thread ID. + """ + self.node_data["fw_tid"] = value + + @property + def op_schema(self) -> str: + """ + Returns the PyTorch operator schema. + + Returns: + str: The PyTorch operator schema. + """ + return self.node_data["op_schema"] + + @op_schema.setter + def op_schema(self, value: str) -> None: + """ + Sets the PyTorch operator schema. + + Args: + value (str): The new operator schema. + """ + self.node_data["op_schema"] = value + + @property + def inputs(self) -> List[Any]: + """ + Returns the array of input arguments. + + Returns: + List[Any]: The array of input arguments. + """ + return self.node_data["inputs"] + + @inputs.setter + def inputs(self, value: List[Any]) -> None: + """ + Sets the array of input arguments. + + Args: + value (List[Any]): The new array of input arguments. + """ + self.node_data["inputs"] = value + + @property + def input_shapes(self) -> List[Any]: + """ + Returns the array of input shapes. + + Returns: + List[Any]: The array of input shapes. + """ + return self.node_data["input_shapes"] + + @input_shapes.setter + def input_shapes(self, value: List[Any]) -> None: + """ + Sets the array of input shapes. + + Args: + value (List[Any]): The new array of input shapes. + """ + self.node_data["input_shapes"] = value + + @property + def input_types(self) -> List[Any]: + """ + Returns the array of input types. + + Returns: + List[Any]: The array of input types. + """ + return self.node_data["input_types"] + + @input_types.setter + def input_types(self, value: List[Any]) -> None: + """ + Sets the array of input types. + + Args: + value (List[Any]): The new array of input types. + """ + self.node_data["input_types"] = value + + @property + def outputs(self) -> List[Any]: + """ + Returns the array of output arguments. + + Returns: + List[Any]: The array of output arguments. + """ + return self.node_data["outputs"] + + @outputs.setter + def outputs(self, value: List[Any]) -> None: + """ + Sets the array of output arguments. + + Args: + value (List[Any]): The new array of output arguments. + """ + self.node_data["outputs"] = value + + @property + def output_shapes(self) -> List[Any]: + """ + Returns the array of output shapes. + + Returns: + List[Any]: The array of output shapes. + """ + return self.node_data["output_shapes"] + + @output_shapes.setter + def output_shapes(self, value: List[Any]) -> None: + """ + Sets the array of output shapes. + + Args: + value (List[Any]): The new array of output shapes. + """ + self.node_data["output_shapes"] = value + + @property + def output_types(self) -> List[Any]: + """ + Returns the array of output types. + + Returns: + List[Any]: The array of output types. + """ + return self.node_data["output_types"] + + @output_types.setter + def output_types(self, value: List[Any]) -> None: + """ + Sets the array of output types. + + Args: + value (List[Any]): The new array of output types. + """ + self.node_data["output_types"] = value + + @property + def ts(self) -> int: + """ + Returns the timestamp of the node. + + Returns: + int: The timestamp of the node. 
+ """ + return self.node_data.get("ts", 0) + + @ts.setter + def ts(self, value: int) -> None: + """ + Sets the timestamp of the node. + + Args: + value (int): The new timestamp of the node. + """ + self.node_data["ts"] = value + + @property + def cat(self) -> str: + """ + Returns the category field of the node. + + Returns: + str: The category field of the node. + """ + return self.node_data.get("cat", "") + + @cat.setter + def cat(self, value: str) -> None: + """ + Sets the category field of the node. + + Args: + value (str): The new category field of the node. + """ + self.node_data["cat"] = value + + @property + def dur(self) -> int: + """ + Returns the duration of the node. + + Returns: + int: The duration of the node. + """ + return self.node_data["dur"] + + @dur.setter + def dur(self, value: int) -> None: + """ + Sets the duration of the node. + + Args: + value (int): The new duration of the node. + """ + self.node_data["dur"] = value + + def has_ts(self) -> bool: + """ + Checks if the node has a timestamp field. + + Returns: + bool: True if the node has a timestamp field, False otherwise. + """ + return "ts" in self.node_data + + def has_cat(self) -> bool: + """ + Checks if the node has a category field. + + Returns: + bool: True if the node has a category field, False otherwise. + """ + return "cat" in self.node_data + + def has_dur(self) -> bool: + """ + Checks if the node has a duration field. + + Returns: + bool: True if the node has a duration field, False otherwise. + """ + return "dur" in self.node_data + + def get_op_type(self) -> PyTorchNodeType: + """ + Determines the type of PyTorch operation. + + Returns: + PyTorchNodeType: The type of the PyTorch operation. + """ + if self.is_gpu_op(): + return PyTorchNodeType.GPU_OP + elif self.node_data.get("op_schema") or self.node_data.get("outputs"): + return PyTorchNodeType.CPU_OP + else: + return PyTorchNodeType.LABEL + + def is_cpu_op(self) -> bool: + """ + Checks if the node is a CPU operator. + + Returns: + bool: True if the node is a CPU operator, False otherwise. + """ + return self.get_op_type() == PyTorchNodeType.CPU_OP + + def is_gpu_op(self) -> bool: + """ + Checks if the node is a GPU operator. + + Returns: + bool: True if the node is a GPU operator, False otherwise. + """ + return self.has_cat() + + def add_data_dep(self, parent_node: 'PyTorchNode') -> None: + """ + Adds a data-dependent parent node to this node. + + Args: + parent_node (PyTorchNode): The parent node to be added. + """ + self.data_deps.append(parent_node) + + def add_child(self, child_node: 'PyTorchNode') -> None: + """ + Adds a child node to this node. + + Args: + child_node (PyTorchNode): The child node to be added. + """ + self.children.append(child_node) + + def set_child_gpu(self, child_gpu_node: Optional['PyTorchNode']) -> None: + """ + Sets a child GPU node for this node. + + Args: + child_gpu_node (Optional[PyTorchNode]): The child GPU node to be set. + """ + self.child_gpu = child_gpu_node + + def is_record_param_comms_op(self) -> bool: + """ + Checks if the node is a record_param_comms operator. + + Returns: + bool: True if the node is a record_param_comms operator, False otherwise. + """ + return "record_param_comms" in self.name + + def is_nccl_op(self) -> bool: + """ + Checks if the node is a NCCL operator. + + Returns: + bool: True if the node is a NCCL operator, False otherwise. + """ + return "nccl:" in self.name + + @property + def comm_size(self) -> int: + """ + Calculates the communication size for the given input types and shapes. 
+ + Returns: + int: The calculated communication size. + """ + comm_size = 1 + for input_type, input_shape in zip(self.input_types, self.input_shapes): + type_size = self.get_data_type_size(input_type) + shape_size = 1 + for dim in input_shape: + shape_size *= dim + comm_size += type_size * shape_size + return comm_size + + @property + def collective_comm_type(self) -> int: + """ + Returns the collective communication type of the node. + + Raises: + ValueError: If the communication type is not found in the mapping. + + Returns: + int: The collective communication type of the node. + """ + comm_type_mapping = { + "all_reduce": ALL_REDUCE, + "all_to_all": ALL_TO_ALL, + "all_gather": ALL_GATHER, + "reduce_scatter": REDUCE_SCATTER, + "broadcast": BROADCAST, + "AllReduce": ALL_REDUCE, + "Broadcast": BROADCAST, + # TODO: Add more cases + } + for key, value in comm_type_mapping.items(): + if key in self.node_data["name"]: + return value + + raise ValueError("Communication type not found in mapping.") + + @staticmethod + def get_data_type_size(data_type: str) -> int: + """ + Returns the data type size of a given data type in string. + + Args: + data_type (str): The data type as a string. + + Returns: + int: The size of the data type in bytes. + + Raises: + ValueError: If the data type is not supported. + """ + data_type_size_map = { + "Tensor(float32)": 4, + "Tensor(float)": 4, + "Tensor(float64)": 8, + "Tensor(double)": 8, + "Tensor(float16)": 2, + "Tensor(half)": 2, + "Tensor(bfloat16)": 2, + "Tensor(complex64)": 8, + "Tensor(complex128)": 16, + "Tensor(uint8)": 1, + "Tensor(int8)": 1, + "Tensor(int16)": 2, + "Tensor(short)": 2, + "Tensor(int32)": 4, + "Tensor(int)": 4, + "Tensor(int64)": 8, + "Tensor(long)": 8, + "Tensor(c10::Half)": 2, + "Tensor(unsigned char)": 1, + "Tensor(long int)": 8, + # TODO: Add more types + } + try: + return data_type_size_map[data_type] + except KeyError: + raise ValueError(f"Unsupported data type: {data_type}") diff --git a/et_converter/pytorch_tensor.py b/et_converter/pytorch_tensor.py new file mode 100644 index 00000000..30c4074d --- /dev/null +++ b/et_converter/pytorch_tensor.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +from typing import Any, List + + +class PyTorchTensor: + """ + Represents a tensor with its associated properties. + + Attributes: + tensor_data (List[int]): Data of the tensor including tensor_id, + storage_id, offset, number of elements, and size of each + element in bytes. + """ + + def __init__(self, tensor_data: List[int]) -> None: + """ + Initializes a PyTorchTensor object with the provided tensor data. + + Args: + tensor_data (List[int]): Data of the tensor including tensor_id, + storage_id, offset, number of elements, and size of each + element in bytes. + """ + self.tensor_data = tensor_data + + def is_valid(self) -> bool: + """ + Checks if the tensor data is valid. + + Returns: + bool: True if tensor_data is a list of exactly five integers, + False otherwise. + """ + return (isinstance(self.tensor_data, list) and + len(self.tensor_data) == 6 and + all(isinstance(item, int) for item in self.tensor_data)) + + @property + def tensor_id(self) -> int: + """ + Returns the tensor ID. + + Returns: + int: Tensor ID. + """ + return self.tensor_data[0] + + @property + def storage_id(self) -> int: + """ + Returns the storage ID. + + Returns: + int: Storage ID. + """ + return self.tensor_data[1] + + @property + def offset(self) -> int: + """ + Returns the offset. + + Returns: + int: Offset value. 
+ """ + return self.tensor_data[2] + + @property + def num_elem(self) -> int: + """ + Returns the number of elements in the tensor. + + Returns: + int: Number of elements. + """ + return self.tensor_data[3] + + @property + def elem_bytes(self) -> int: + """ + Returns the size of each element in bytes. + + Returns: + int: Size of each element in bytes. + """ + return self.tensor_data[4] + + def has_valid_storage_id(self) -> bool: + """ + Checks if the tensor has a valid storage ID. + + Returns: + bool: True if the storage ID is greater than 0, False otherwise. + """ + return self.storage_id > 0 + + +def list_to_pytorch_tensor(tensor_list: List[int]) -> PyTorchTensor: + """ + Converts a list representation of a tensor into a PyTorchTensor object. + + Args: + tensor_list (List[int]): Data representing a tensor, including + tensor_id, storage_id, offset, num_elem, elem_bytes. + + Returns: + PyTorchTensor: The PyTorchTensor object created from the data. + """ + return PyTorchTensor(tensor_list) From ca63f651f644f367a1b13a98252fab99e71d4751 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 18 Jan 2024 18:06:35 -0500 Subject: [PATCH 02/13] et_converter: Identify cycle dependencies for validation --- et_converter/pytorch2chakra_converter.py | 54 ++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index e92491dc..7e9abf13 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -221,6 +221,8 @@ def convert(self) -> None: self.identify_data_dependency() + self.identify_cyclic_dependencies() + self.write_chakra_et() self.close_chakra_execution_trace() @@ -756,6 +758,58 @@ def update_data_dependencies(self, input_map: Dict[int, List[int]], and (parent_nid < child_nid): child_node.data_deps.append(parent_nid) + def identify_cyclic_dependencies(self) -> None: + """ + Identifies if there are any cyclic dependencies among Chakra nodes. + + This method checks for cycles in the graph of Chakra nodes using a + depth-first search (DFS) algorithm. It logs an error message and raises + an exception if a cycle is detected, ensuring the graph is a Directed + Acyclic Graph (DAG). + + Raises: + Exception: If a cyclic dependency is detected among the Chakra nodes. + """ + visited = set() + stack = set() + + def dfs(node_id: int, path: List[int]) -> bool: + """ + Depth-first search to detect cycles. + + Args: + node_id (int): The node ID to start the DFS from. + path (List[int]): The path traversed so far, for tracing the cycle. + + Returns: + bool: True if a cycle is detected, False otherwise. + """ + if node_id in stack: + cycle_nodes = " -> ".join( + [self.chakra_nodes[n].name for n in path + [node_id]] + ) + self.logger.error(f"Cyclic dependency detected: {cycle_nodes}") + return True + if node_id in visited: + return False + + visited.add(node_id) + stack.add(node_id) + path.append(node_id) + for child_id in self.chakra_nodes[node_id].data_deps: + if dfs(child_id, path.copy()): + return True + stack.remove(node_id) + path.pop() + return False + + for node_id in self.chakra_nodes: + if dfs(node_id, []): + raise Exception( + f"Cyclic dependency detected starting from node " + f"{self.chakra_nodes[node_id].name}" + ) + def write_chakra_et(self) -> None: """ Writes the Chakra execution trace by encoding global metadata and nodes. 
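The cycle check introduced above walks the data_deps edges with a depth-first search and a visitation stack, rejecting the graph as soon as a back edge is found. A minimal standalone sketch of the same idea, operating on a plain dict from node ID to its dependency list (the names here are illustrative, not the converter's API):

    from typing import Dict, List

    def has_cycle(data_deps: Dict[int, List[int]]) -> bool:
        """Return True if the dependency graph contains a cycle."""
        visited, in_stack = set(), set()

        def dfs(nid: int) -> bool:
            if nid in in_stack:      # back edge: nid is already on the current DFS path
                return True
            if nid in visited:
                return False
            visited.add(nid)
            in_stack.add(nid)
            if any(dfs(dep) for dep in data_deps.get(nid, [])):
                return True
            in_stack.remove(nid)
            return False

        return any(dfs(nid) for nid in data_deps)

    # A chain 0 <- 1 <- 2 is a DAG; pointing 0 back at 2 closes a cycle.
    assert not has_cycle({0: [], 1: [0], 2: [1]})
    assert has_cycle({0: [2], 1: [0], 2: [1]})

The converter raises an exception instead of returning a flag, but the traversal and the two bookkeeping sets play the same roles as visited and stack in identify_cyclic_dependencies.
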
From 19550c8d3dca9d9893fe293b17a751d563a6dd8d Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 18 Jan 2024 22:03:12 -0500 Subject: [PATCH 03/13] et_converter: Remove phase end nid tracking --- et_converter/pytorch2chakra_converter.py | 114 ----------------------- 1 file changed, 114 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 7e9abf13..46b568ce 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -95,7 +95,6 @@ class PyTorch2ChakraConverter: pytorch_cpu_node_id_gpu_node_map (Dict[int, List[int]]): Map of PyTorch CPU node IDs to GPU node IDs. chakra_nodes (Dict[int, Any]): Map of Chakra node IDs to nodes. - phase_end_nids (List[int]): List of node IDs for phase dependencies. input_storage_id_nid_map (Dict[int, int]): Map of input storage IDs to node IDs. output_storage_id_nid_map (Dict[int, int]): Map of output storage IDs to node IDs. input_tensor_id_nid_map (Dict[int, int]): Map of input tensor IDs to node IDs. @@ -140,9 +139,6 @@ def initialize_attributes(self) -> None: self.pytorch_cpu_node_id_gpu_node_map = {} self.chakra_nodes = {} - # Initialize lists for phase dependencies and data dependency maps - self.phase_end_nids = [] - # Map of input storage IDs to node IDs: # This dictionary tracks which nodes are consuming tensors based on their # storage ID, establishing a link between tensor storage and node consumption. @@ -176,8 +172,6 @@ def convert(self) -> None: self.open_chakra_execution_trace() - self.construct_phase_end_nids() - self.split_cpu_nodes_with_gpu_child() for pytorch_nid, pytorch_node in self.pytorch_nodes.items(): @@ -214,11 +208,6 @@ def convert(self) -> None: for data_dep_pytorch_node in pytorch_node.data_deps: chakra_node.data_deps.append(data_dep_pytorch_node.id) - dep_nid = self.get_prev_phase_end_nid(chakra_node) - if (dep_nid != -1) and (dep_nid not in chakra_node.data_deps): - chakra_node.data_deps.append(dep_nid) ->>>>>>> a4155fe (et_converter: Refactor PyTorch2ChakraConverter) - self.identify_data_dependency() self.identify_cyclic_dependencies() @@ -343,57 +332,6 @@ def open_chakra_execution_trace(self) -> None: self.logger.error(err_msg) raise Exception(err_msg) - def construct_phase_end_nids(self) -> None: - """ - Identifies the dependencies between phases in the execution trace. - - Uses a depth-first search (DFS) approach starting from phase root nodes to find - the largest Node ID (NID) in each phase for dependency tracking. - """ - self.logger.info("Constructing phase end node IDs.") - for node in self.pytorch_nodes.values(): - if self.is_phase_root_op(node): - largest_nid_within_phase = self.dfs(node) - if largest_nid_within_phase != -1: - self.phase_end_nids.append(largest_nid_within_phase) - self.phase_end_nids.sort() - - def is_phase_root_op(self, node: PyTorchNode) -> bool: - """ - Determines if a node is a root node of a phase. - - Args: - node (PyTorchNode): The node to be checked. - - Returns: - bool: True if the node is a root node of a phase, False otherwise. - """ - return node.parent in self.pytorch_root_nids - - def dfs(self, node: PyTorchNode) -> int: - """ - Performs a depth-first search to find the largest Node ID (NID) in a subtree. - - Explores the subtree of the given node to find the largest NID among CPU operation nodes. - - Args: - node (PyTorchNode): The node from which the search starts. 
- - Returns: - int: The largest NID found in the subtree, or -1 if no CPU operation node is found. - """ - if node.get_op_type() == PyTorchNodeType.GPU_OP: - return -1 - elif node.get_op_type() == PyTorchNodeType.CPU_OP: - return node.id - else: # PyTorchNodeType.LABEL or any other type - largest_nid = -1 - for child_node in node.children: - largest_nid = max(largest_nid, self.dfs(child_node)) - return largest_nid - - self.pytorch_nodes = updated_pytorch_nodes - def split_cpu_nodes_with_gpu_child(self) -> None: """ Decomposes CPU nodes with GPU child nodes to model execution overlap @@ -439,8 +377,6 @@ def split_cpu_nodes_with_gpu_child(self) -> None: self.pytorch_nodes = updated_pytorch_nodes - self.update_phase_end_nids() - def _split_cpu_node( self, cpu_node: PyTorchNode, gpu_node: PyTorchNode ) -> Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: @@ -505,27 +441,6 @@ def _split_cpu_node( return cpu_node_first, cpu_node_second, gpu_node - def update_phase_end_nids(self) -> None: - """ - Updates the phase end node IDs with the largest new node ID assigned - during the splitting of CPU nodes with GPU children. Utilizes the - get_assigned_ids function from UniqueIdAssigner to find all new IDs and - selects the largest one for each original node ID. - - This ensures that the phase end boundaries are correctly maintained after - splitting the nodes. - """ - self.logger.info( - "Updating phase end node IDs with the largest new IDs after node splitting." - ) - updated_phase_end_nids = [] - for node_id in self.phase_end_nids: - assigned_ids = self.id_assigner.get_assigned_ids(node_id) - if assigned_ids: - updated_phase_end_nids.append(max(assigned_ids)) - updated_phase_end_nids.sort() - self.phase_end_nids = updated_phase_end_nids - def update_input_tensor_map(self, nid: int, inputs: List[List[int]]) -> None: """ Updates input_storage_id_nid_map and input_tensor_id_nid_map with input @@ -673,35 +588,6 @@ def get_nccl_node(self, node: PyTorchNode) -> PyTorchNode: self.logger.error(err_msg) raise ValueError(err_msg) - def get_prev_phase_end_nid(self, node: ChakraNode) -> int: - """ - Returns the Node ID (NID) of the latest node of the previous phase for - the given ChakraNode. - - This method is used to find the closest but smaller value from - phase_end_nids compared to the given node's ID. It helps in - determining the dependencies between different phases in the trace. - - Args: - node (ChakraNode): The node to find the previous phase dependency for. - - Returns: - int: NID of the latest node of the previous phase, or -1 if none. - """ - self.logger.debug( - f"Finding previous inter-phase dependency for node ID {node.id}." - ) - index = bisect.bisect_left(self.phase_end_nids, node.id) - - if index == 0: - # All elements in the list are greater than node.id; - # no element satisfies the condition. - return -1 - else: - # The element at index-1 will be the closest, smaller value - # compared to node.id. 
- return self.phase_end_nids[index - 1] - def identify_data_dependency(self) -> None: """ Identifies data dependencies between nodes using tensor input/output From 14f513ac4b0288c45f7158eaf3b09577e4fe1b06 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 18 Jan 2024 22:04:47 -0500 Subject: [PATCH 04/13] et_converter: Remove get_nccl_node --- et_converter/pytorch2chakra_converter.py | 38 ++---------------------- 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 46b568ce..27f6fbec 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -192,12 +192,11 @@ def convert(self) -> None: chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node) if chakra_node.type == COMM_COLL_NODE: - pytorch_nccl_node = self.get_nccl_node(pytorch_node) chakra_gpu_node.attr.extend([ ChakraAttr(name="comm_type", - int64_val=pytorch_nccl_node.collective_comm_type), + int64_val=pytorch_gpu_node.collective_comm_type), ChakraAttr(name="comm_size", - int64_val=pytorch_nccl_node.comm_size), + int64_val=pytorch_gpu_node.comm_size), ChakraAttr(name="involved_dim", bool_list={"values": [True]*self.num_dims}) ]) @@ -555,39 +554,6 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i return COMP_NODE return INVALID_NODE - def get_nccl_node(self, node: PyTorchNode) -> PyTorchNode: - """ - Returns a PyTorch NCCL node for a given Chakra CPU node. - - Critical for identifying communication type and size in communication nodes. - There are two primary cases to consider: when the given node is a parent - of a record_param_comms node or a NCCL node. - - Args: - node (PyTorchNode): The parent node for which the NCCL node is needed. - - Returns: - PyTorchNode: The corresponding NCCL node. - - Raises: - ValueError: If no corresponding NCCL node is found. - """ - self.logger.debug(f"Retrieving NCCL node for PyTorch node ID {node.id}.") - if node.record_param_comms_node: - record_param_comms_node = node.record_param_comms_node - if record_param_comms_node.nccl_node: - return record_param_comms_node.nccl_node - else: - err_msg = "No NCCL node found in the record_param_comms node." - self.logger.error(err_msg) - raise ValueError(err_msg) - elif node.nccl_node: - return node.nccl_node - else: - err_msg = "No NCCL node associated with the given PyTorch node." 
- self.logger.error(err_msg) - raise ValueError(err_msg) - def identify_data_dependency(self) -> None: """ Identifies data dependencies between nodes using tensor input/output From 266753c851d81d74329a0a375c03edd20468d424 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 18 Jan 2024 22:07:16 -0500 Subject: [PATCH 05/13] et_converter: Remove data dependency identification --- et_converter/pytorch2chakra_converter.py | 144 ----------------------- 1 file changed, 144 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 27f6fbec..317f5d0a 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -8,7 +8,6 @@ from chakra.third_party.utils.protolib import encodeMessage as encode_message from chakra.et_converter.pytorch_node import PyTorchNodeType, PyTorchNode -from chakra.et_converter.pytorch_tensor import PyTorchTensor, list_to_pytorch_tensor from chakra.et_def.et_def_pb2 import ( GlobalMetadata, Node as ChakraNode, @@ -95,10 +94,6 @@ class PyTorch2ChakraConverter: pytorch_cpu_node_id_gpu_node_map (Dict[int, List[int]]): Map of PyTorch CPU node IDs to GPU node IDs. chakra_nodes (Dict[int, Any]): Map of Chakra node IDs to nodes. - input_storage_id_nid_map (Dict[int, int]): Map of input storage IDs to node IDs. - output_storage_id_nid_map (Dict[int, int]): Map of output storage IDs to node IDs. - input_tensor_id_nid_map (Dict[int, int]): Map of input tensor IDs to node IDs. - output_tensor_id_nid_map (Dict[int, int]): Map of output tensor IDs to node IDs. """ def __init__( @@ -139,29 +134,6 @@ def initialize_attributes(self) -> None: self.pytorch_cpu_node_id_gpu_node_map = {} self.chakra_nodes = {} - # Map of input storage IDs to node IDs: - # This dictionary tracks which nodes are consuming tensors based on their - # storage ID, establishing a link between tensor storage and node consumption. - self.input_storage_id_nid_map = {} - - # Map of output storage IDs to node IDs: - # Similar to input_storage_id_nid_map, but this tracks the production of - # tensors by nodes, associating tensor storage IDs with the nodes that - # produce them. - self.output_storage_id_nid_map = {} - - # Map of input tensor IDs to node IDs: - # This dictionary is used when storage IDs are not applicable. It tracks - # which nodes are consuming tensors by using tensor IDs, creating a link - # between tensor IDs and the nodes that consume them. - self.input_tensor_id_nid_map = {} - - # Map of output tensor IDs to node IDs: - # Similar to input_tensor_id_nid_map, but for tracking the output of tensors - # from nodes. It associates tensor IDs with the nodes that output them, - # used when storage IDs are not available. - self.output_tensor_id_nid_map = {} - def convert(self) -> None: """ Converts PyTorch execution traces into the Chakra format. Orchestrates @@ -176,12 +148,8 @@ def convert(self) -> None: for pytorch_nid, pytorch_node in self.pytorch_nodes.items(): if pytorch_node.is_cpu_op(): - self.update_input_tensor_map(pytorch_node.id, pytorch_node.inputs) - self.update_output_tensor_map(pytorch_node.id, pytorch_node.outputs) - if pytorch_node.child_gpu: pytorch_gpu_node = pytorch_node.child_gpu - self.update_input_tensor_map(pytorch_gpu_node.id, pytorch_gpu_node.inputs) # Ignoring GPU->CPU dependencies for now since it creates unwanted dependencies. 
chakra_node = self.convert_to_chakra_node(pytorch_node) @@ -207,8 +175,6 @@ def convert(self) -> None: for data_dep_pytorch_node in pytorch_node.data_deps: chakra_node.data_deps.append(data_dep_pytorch_node.id) - self.identify_data_dependency() - self.identify_cyclic_dependencies() self.write_chakra_et() @@ -440,60 +406,6 @@ def _split_cpu_node( return cpu_node_first, cpu_node_second, gpu_node - def update_input_tensor_map(self, nid: int, inputs: List[List[int]]) -> None: - """ - Updates input_storage_id_nid_map and input_tensor_id_nid_map with input - tensor information. - - Each dictionary is populated with mappings between storage ID (or tensor ID) - and node IDs. For example, if node 0 takes tensor 10 as an input, a new - mapping will be created like this `10: [0]`. - - Args: - nid (int): Node ID associated with the input tensors. - inputs (List[List[int]]): List of input tensor data. - """ - for i in inputs: - tensor = list_to_pytorch_tensor(i) - if tensor.is_valid(): - if tensor.has_valid_storage_id(): - storage_id = tensor.storage_id - self.input_storage_id_nid_map.setdefault( - storage_id, [] - ).append(nid) - else: - tensor_id = tensor.tensor_id - self.input_tensor_id_nid_map.setdefault( - tensor_id, [] - ).append(nid) - - def update_output_tensor_map(self, nid: int, outputs: List[List[int]]) -> None: - """ - Updates output_storage_id_nid_map and output_tensor_id_nid_map with output - tensor information. - - Each dictionary is populated with mappings between storage ID (or tensor ID) - and node IDs. For example, if node 0 produces tensor 10 as an output, - a new mapping will be created like this `10: [0]`. - - Args: - nid (int): Node ID associated with the output tensors. - outputs (List[List[int]]): List of output tensor data. - """ - for o in outputs: - tensor = list_to_pytorch_tensor(o) - if tensor.is_valid(): - if tensor.has_valid_storage_id(): - storage_id = tensor.storage_id - self.output_storage_id_nid_map.setdefault( - storage_id, [] - ).append(nid) - else: - tensor_id = tensor.tensor_id - self.output_tensor_id_nid_map.setdefault( - tensor_id, [] - ).append(nid) - def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: """ Converts a PyTorchNode to a ChakraNode. @@ -554,62 +466,6 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i return COMP_NODE return INVALID_NODE - def identify_data_dependency(self) -> None: - """ - Identifies data dependencies between nodes using tensor input/output - relationships. - - Determines the relationships based on whether the tensors use storage IDs - or tensor IDs. - """ - self.logger.info("Identifying data dependencies among nodes.") - self.identify_data_dependency_with_storage_id() - self.identify_data_dependency_with_tensor_id() - - def identify_data_dependency_with_storage_id(self) -> None: - """ - Identifies data dependency between nodes based on storage IDs. - - Uses the mapping of input and output tensors to their storage IDs to - establish dependencies. - """ - self.logger.info("Identifying data dependencies using storage IDs.") - self.update_data_dependencies( - self.input_storage_id_nid_map, - self.output_storage_id_nid_map) - - def identify_data_dependency_with_tensor_id(self) -> None: - """ - Identifies data dependency between nodes based on tensor IDs. - - Establishes dependencies using tensor IDs for tensors without valid - storage IDs. 
- """ - self.logger.info("Identifying data dependencies using tensor IDs.") - self.update_data_dependencies( - self.input_tensor_id_nid_map, - self.output_tensor_id_nid_map) - - def update_data_dependencies(self, input_map: Dict[int, List[int]], - output_map: Dict[int, List[int]]) -> None: - """ - Updates data dependencies for nodes based on input and output tensor maps. - - Args: - input_map (Dict[int, List[int]]): Map of input tensor IDs to node IDs. - output_map (Dict[int, List[int]]): Map of output tensor IDs to node IDs. - """ - self.logger.debug("Updating data dependencies for nodes.") - for input_id, child_nids in input_map.items(): - if input_id in output_map: - parent_nids = output_map[input_id] - for child_nid in child_nids: - for parent_nid in parent_nids: - child_node = self.chakra_nodes[child_nid] - if (parent_nid not in child_node.data_deps)\ - and (parent_nid < child_nid): - child_node.data_deps.append(parent_nid) - def identify_cyclic_dependencies(self) -> None: """ Identifies if there are any cyclic dependencies among Chakra nodes. From 55d9692f0ca0d89ff25ab36580d65c2802e2a324 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 18 Jan 2024 22:01:33 -0500 Subject: [PATCH 06/13] et_converter: Bugfix in dependency construction --- et_converter/pytorch2chakra_converter.py | 186 ++++++++++++++++++++--- 1 file changed, 162 insertions(+), 24 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 317f5d0a..412f02f5 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -38,6 +38,15 @@ def __init__(self) -> None: self.next_id = 0 self.original_to_assigned_ids: Dict[int, List[int]] = {} + def set_next_id(self, next_id: int) -> None: + """ + Sets the starting next unique ID. + + Args: + next_id (int): The starting next unique ID to set. + """ + self.next_id = next_id + def assign_unique_id(self, original_id: int) -> int: """ Generates and tracks a new unique ID for each call for a given original ID. @@ -147,11 +156,8 @@ def convert(self) -> None: self.split_cpu_nodes_with_gpu_child() for pytorch_nid, pytorch_node in self.pytorch_nodes.items(): - if pytorch_node.is_cpu_op(): - if pytorch_node.child_gpu: - pytorch_gpu_node = pytorch_node.child_gpu - # Ignoring GPU->CPU dependencies for now since it creates unwanted dependencies. 
- + if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\ + or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL): chakra_node = self.convert_to_chakra_node(pytorch_node) self.chakra_nodes[chakra_node.id] = chakra_node @@ -169,11 +175,11 @@ def convert(self) -> None: bool_list={"values": [True]*self.num_dims}) ]) - chakra_gpu_node.data_deps.append(chakra_node.id) self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node - for data_dep_pytorch_node in pytorch_node.data_deps: - chakra_node.data_deps.append(data_dep_pytorch_node.id) + root_nodes = [node for node in self.chakra_nodes.values() if self.is_root_node(node)] + for root_node in root_nodes: + self.convert_ctrl_dep_to_data_dep(root_node) self.identify_cyclic_dependencies() @@ -196,6 +202,7 @@ def load_pytorch_execution_traces(self) -> None: with open(self.input_filename, "r") as pytorch_et: pytorch_et_data = json.load(pytorch_et) self._parse_and_instantiate_nodes(pytorch_et_data) + self.id_assigner.set_next_id(max(self.pytorch_nodes.keys()) + 1) except IOError as e: self.logger.error(f"Error opening file {self.input_filename}: {e}") raise Exception(f"Could not open file {self.input_filename}") @@ -329,21 +336,17 @@ def split_cpu_nodes_with_gpu_child(self) -> None: updated_pytorch_nodes[new_cpu_node_id] = cpu_node else: gpu_node = cpu_node.child_gpu - if gpu_node.ts >= (cpu_node.ts + cpu_node.dur): - err_msg = f"Inconsistent timestamps for CPU node {cpu_node.id} and its GPU child" - self.logger.error(err_msg) - raise ValueError(err_msg) - cpu_node_first, cpu_node_second, updated_gpu_node =\ - self._split_cpu_node(cpu_node, gpu_node) - updated_pytorch_nodes[cpu_node_first.id] = cpu_node_first - updated_pytorch_nodes[cpu_node_second.id] = cpu_node_second - updated_pytorch_nodes[updated_gpu_node.id] = updated_gpu_node + self._split_cpu_node(cpu_node, gpu_node, updated_pytorch_nodes) + updated_pytorch_nodes[cpu_node_first.id] = copy.deepcopy(cpu_node_first) + updated_pytorch_nodes[cpu_node_second.id] = copy.deepcopy(cpu_node_second) + updated_pytorch_nodes[updated_gpu_node.id] = copy.deepcopy(updated_gpu_node) self.pytorch_nodes = updated_pytorch_nodes def _split_cpu_node( - self, cpu_node: PyTorchNode, gpu_node: PyTorchNode + self, cpu_node: PyTorchNode, gpu_node: PyTorchNode, + updated_pytorch_nodes: Dict[int, PyTorchNode] ) -> Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: """ Splits a CPU node based on the GPU node's timestamp. @@ -351,9 +354,11 @@ def _split_cpu_node( Args: cpu_node (PyTorchNode): Original CPU node to be split. gpu_node (PyTorchNode): GPU node dictating the split. + updated_pytorch_nodes (Dict[int, PyTorchNode]): Updated PyTorch nodes. Returns: - Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: Two split nodes and the updated GPU node. + Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: Two split nodes and + the updated GPU node. Raises: ValueError: For inconsistencies in the timestamps of the nodes. 
@@ -368,9 +373,7 @@ def _split_cpu_node( cpu_node_first.id = self.id_assigner.assign_unique_id(cpu_node.id) cpu_node_first.ts = cpu_node.ts cpu_node_first.dur = gpu_node.ts - cpu_node.ts - cpu_node_first.set_child_gpu = gpu_node - for child_node in cpu_node_first.children: - child_node.parent = cpu_node_first.id + cpu_node_first.set_child_gpu(gpu_node) if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.dur <= 0: err_msg = (f"Invalid timestamps for the first split CPU node derived from {original_cpu_info}\n" f"\tFirst Split CPU Node Timestamp: {cpu_node_first.ts}, \n" @@ -379,20 +382,27 @@ def _split_cpu_node( self.logger.error(err_msg) raise ValueError(err_msg) + if cpu_node.parent in self.pytorch_nodes: + self._update_parent_node_children(self.pytorch_nodes, cpu_node, cpu_node_first) + elif cpu_node.parent in updated_pytorch_nodes: + self._update_parent_node_children(updated_pytorch_nodes, cpu_node, cpu_node_first) + self.logger.debug(f"First Split CPU Node ID {cpu_node_first.id} ({cpu_node_first.name}), " f"Duration: {cpu_node_first.dur}") gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id) gpu_node.id = gpu_node_id + gpu_node.parent = cpu_node_first.id cpu_node_second = copy.deepcopy(cpu_node) cpu_node_second.id = self.id_assigner.assign_unique_id(cpu_node.id) cpu_node_second.ts = gpu_node.ts cpu_node_second.dur = cpu_node.dur - (gpu_node.ts - cpu_node.ts) cpu_node_second.set_child_gpu(None) - cpu_node_second.add_data_dep(cpu_node_first) - for child_node in cpu_node_second.children: + cpu_node_second.parent = cpu_node_first.id + for child_node in cpu_node.children: child_node.parent = cpu_node_second.id + cpu_node_second.add_child(child_node) if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.dur <= 0: err_msg = (f"Invalid timestamps for the second split CPU node derived from {original_cpu_info}\n" f"\tFirst Split Timestamp: {cpu_node_first.ts}, \n" @@ -404,8 +414,32 @@ def _split_cpu_node( self.logger.debug(f"Second Split CPU Node ID {cpu_node_second.id} ({cpu_node_second.name}), " f"Duration: {cpu_node_second.dur}.") + cpu_node_first.add_child(cpu_node_second) + cpu_node_first.add_child(gpu_node) + return cpu_node_first, cpu_node_second, gpu_node + def _update_parent_node_children(self, parent_node_dict: Dict[int, PyTorchNode], + cpu_node: PyTorchNode, + cpu_node_first: PyTorchNode) -> None: + """ + Updates the children of the parent node in the given dictionary. + + This method removes the original CPU node from the parent's children list + and adds the first split node. + + Args: + parent_node_dict (Dict[int, PyTorchNode]): Dictionary containing the + parent node. + cpu_node (PyTorchNode): Original CPU node being split. + cpu_node_first (PyTorchNode): First split node to add to the parent's + children. + """ + parent_node = parent_node_dict[cpu_node.parent] + parent_node.children = [child for child in parent_node.children + if child.id != cpu_node.id] + parent_node.children.extend([cpu_node_first]) + def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: """ Converts a PyTorchNode to a ChakraNode. @@ -466,6 +500,110 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i return COMP_NODE return INVALID_NODE + def is_root_node(self, node): + """ + Determines whether a given node is a root node in the execution trace. + + In the context of PyTorch execution traces, root nodes are the starting + points of execution graphs or execution traces. 
These nodes typically do + not have parent nodes and act as the original sources of execution flow. + This method identifies such root nodes based on their names. Specifically, + nodes with names indicating they are part of the PyTorch execution graph or + execution trace threads are considered root nodes. + + Args: + node (ChakraNode): The node to be evaluated. + + Returns: + bool: True if the node is a root node, False otherwise. + """ + if node.name in ["[pytorch|profiler|execution_graph|thread]", + "[pytorch|profiler|execution_trace|thread]"]: + return True + + def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: + """ + Traverses nodes based on control dependencies (parent nodes) and encodes + data dependencies appropriately. This method is crucial for converting the + dependency structure from PyTorch execution traces to Chakra execution + traces. In PyTorch traces, control dependencies are represented by a + parent field in each node, denoting the parent node ID. This structure + indicates which functions (operators) are called by a particular operator. + + In contrast, Chakra execution traces, while retaining control dependencies + for compatibility, primarily rely on data dependencies to represent + relationships between nodes. Data dependencies in Chakra are more broadly + defined compared to those in PyTorch, where they are implicitly encoded in + tensor input-output relationships. In Chakra, data dependencies are explicit + and represent a general dependency between nodes. + + To convert PyTorch's control dependencies to Chakra's data dependencies, a + Depth-First Search (DFS) is performed. The DFS traversal starts from a given + Chakra node, traversing through its children (based on control + dependencies). During traversal, data dependencies are encoded by linking + nodes that have been visited in sequence. These dependencies form a chain, + mirroring the function call order from the PyTorch trace. + + Special attention is given to the types of nodes involved. CPU and label + nodes (non-GPU) in PyTorch can only depend on other CPU or label nodes. + However, GPU nodes can depend on any type of node. Thus, while traversing, + if a GPU node is encountered, it can establish a data dependency with the + last visited node of any type. For CPU and label nodes, the dependency is + only established with the last visited non-GPU node. This distinction + ensures that the converted dependencies accurately reflect the execution + dynamics of the original PyTorch trace within the Chakra framework. + + Args: + chakra_node (ChakraNode): The starting node for the traversal and + dependency processing. 
+ """ + visited = set() + stack = [chakra_node] + last_visited_non_gpu = None + last_visited_any = None + + while stack: + current_node = stack.pop() + if current_node.id in visited: + continue + + visited.add(current_node.id) + + # Determine the operator type of the current node + pytorch_node = self.pytorch_nodes.get(current_node.id) + if pytorch_node: + node_op_type = pytorch_node.get_op_type() + + if node_op_type == PyTorchNodeType.GPU_OP: + # GPU operators can depend on any type of operator + if last_visited_any: + if last_visited_any.id not in current_node.data_deps: + current_node.data_deps.append(last_visited_any.id) + self.logger.debug( + f"GPU Node ID {current_node.id} now has a data " + f"dependency on Node ID {last_visited_any.id}" + ) + last_visited_any = current_node + else: + # CPU operators depend on non-GPU operators + if last_visited_non_gpu: + if last_visited_non_gpu.id not in current_node.data_deps: + current_node.data_deps.append(last_visited_non_gpu.id) + self.logger.debug( + f"CPU Node ID {current_node.id} now has a data " + f"dependency on non-GPU Node ID " + f"{last_visited_non_gpu.id}" + ) + last_visited_non_gpu = current_node + last_visited_any = current_node + + # Add children to the stack + children_chakra_ids = [child.id for child in pytorch_node.children] + for child_chakra_id in sorted(children_chakra_ids, reverse=True): + child_chakra_node = self.chakra_nodes.get(child_chakra_id) + if child_chakra_node and child_chakra_node.id not in visited: + stack.append(child_chakra_node) + def identify_cyclic_dependencies(self) -> None: """ Identifies if there are any cyclic dependencies among Chakra nodes. From 3e0a94192c21dd43ace7f9367c29a25ea5edff25 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Sun, 21 Jan 2024 14:39:49 -0500 Subject: [PATCH 07/13] et_converter: Remove dangling nodes --- et_converter/pytorch2chakra_converter.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 412f02f5..8e0b7f6a 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -181,6 +181,8 @@ def convert(self) -> None: for root_node in root_nodes: self.convert_ctrl_dep_to_data_dep(root_node) + self.remove_dangling_nodes() + self.identify_cyclic_dependencies() self.write_chakra_et() @@ -604,6 +606,27 @@ def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: if child_chakra_node and child_chakra_node.id not in visited: stack.append(child_chakra_node) + def remove_dangling_nodes(self) -> None: + """ + Removes any dangling nodes from the chakra_nodes dictionary. + A node is considered dangling if it has no parents and no children. + """ + parent_ids = set() + for node in self.chakra_nodes.values(): + parent_ids.update(node.data_deps) + + dangling_nodes = [] + for node_id, node in list(self.chakra_nodes.items()): + if node_id not in parent_ids and not node.data_deps: + dangling_nodes.append(node) + del self.chakra_nodes[node_id] + del self.pytorch_nodes[node_id] + + if dangling_nodes: + self.logger.info(f"Identified and removed {len(dangling_nodes)} dangling nodes:") + for node in dangling_nodes: + self.logger.info(f" - Node ID {node.id}: {node.name}") + def identify_cyclic_dependencies(self) -> None: """ Identifies if there are any cyclic dependencies among Chakra nodes. 
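The conversion above effectively linearizes each call tree: every non-GPU node gains a data dependency on the previously visited non-GPU node, while a GPU node may depend on whatever node was visited last, regardless of type. A toy illustration of that chaining rule, using a hypothetical ToyNode class rather than the converter's real node types:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class ToyNode:
        id: int
        gpu: bool = False
        children: List["ToyNode"] = field(default_factory=list)
        data_deps: List[int] = field(default_factory=list)

    def linearize(root: ToyNode) -> None:
        """Simplified version of the chaining done in convert_ctrl_dep_to_data_dep."""
        last_non_gpu: Optional[ToyNode] = None
        last_any: Optional[ToyNode] = None
        stack = [root]
        while stack:
            node = stack.pop()
            if node.gpu:
                if last_any is not None:          # GPU ops may follow any node type
                    node.data_deps.append(last_any.id)
            else:
                if last_non_gpu is not None:      # CPU/label ops only follow non-GPU nodes
                    node.data_deps.append(last_non_gpu.id)
                last_non_gpu = node
            last_any = node
            stack.extend(reversed(node.children))  # pop order then matches original child order

    # A CPU op (1) that launches a GPU kernel (2), followed by a sibling CPU op (3).
    root = ToyNode(0, children=[ToyNode(1, children=[ToyNode(2, gpu=True)]), ToyNode(3)])
    linearize(root)
    # Result: 1 depends on 0, the GPU kernel 2 depends on 1, and 3 depends on 1
    # (the last non-GPU node), mirroring the function-call order of the trace.

This sketch omits the visited set, the inter-thread dependencies, and the removal of dangling nodes that the converter performs afterwards; it only shows why the resulting graph forms a chain along the call order.
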
From 776baa51ec7d3a1434229b4041a712897d52799c Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Wed, 24 Jan 2024 10:20:28 -0500 Subject: [PATCH 08/13] et_converter: Support inter-thread dependencies --- et_converter/pytorch2chakra_converter.py | 15 +++++++++++++++ et_converter/pytorch_node.py | 11 +++++++++++ 2 files changed, 26 insertions(+) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 8e0b7f6a..8db179f2 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -555,6 +555,13 @@ def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: ensures that the converted dependencies accurately reflect the execution dynamics of the original PyTorch trace within the Chakra framework. + Furthermore, inter-thread dependencies are explicitly encoded in the Chakra + execution traces. This feature allows for the representation of dependencies + across different CPU threads, which are observed in Kineto traces via + chrome://tracing. These dependencies are crucial for understanding the + interaction between CPU threads and ensuring accurate modeling and analysis + of concurrent operations within the Chakra framework. + Args: chakra_node (ChakraNode): The starting node for the traversal and dependency processing. @@ -587,6 +594,14 @@ def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: ) last_visited_any = current_node else: + if pytorch_node.inter_thread_dep: + for id in self.id_assigner.get_assigned_ids(pytorch_node.inter_thread_dep): + current_node.data_deps.append(id) + self.logger.debug( + f"CPU Node ID {current_node.id} now has an inter-thread data " + f"dependency on Node ID {id}" + ) + # CPU operators depend on non-GPU operators if last_visited_non_gpu: if last_visited_non_gpu.id not in current_node.data_deps: diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py index b5068734..e42898c9 100644 --- a/et_converter/pytorch_node.py +++ b/et_converter/pytorch_node.py @@ -437,6 +437,17 @@ def dur(self, value: int) -> None: """ self.node_data["dur"] = value + @property + def inter_thread_dep(self) -> Optional[int]: + """ + Returns the inter-thread dependency value of the node, if available. + + Returns: + Optional[int]: The inter-thread dependency value or None if not + available. + """ + return self.node_data.get("inter_thread_dep") + def has_ts(self) -> bool: """ Checks if the node has a timestamp field. From d3c3caca65a5f753ebe9632cdfb63199c1f84272 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Wed, 24 Jan 2024 10:55:39 -0500 Subject: [PATCH 09/13] et_converter: Support intra-stream dependencies --- et_converter/pytorch2chakra_converter.py | 92 ++++++++++++++---------- et_converter/pytorch_node.py | 4 ++ 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 8db179f2..9f44823e 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -555,6 +555,11 @@ def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: ensures that the converted dependencies accurately reflect the execution dynamics of the original PyTorch trace within the Chakra framework. + Additionally, this method enforces sequential dependencies between GPU + operators within the same stream. 
It ensures that the execution order of + GPU operators is preserved in the Chakra trace, reflecting the sequential + execution within the same GPU stream in the original PyTorch trace. + Furthermore, inter-thread dependencies are explicitly encoded in the Chakra execution traces. This feature allows for the representation of dependencies across different CPU threads, which are observed in Kineto traces via @@ -566,10 +571,11 @@ def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: chakra_node (ChakraNode): The starting node for the traversal and dependency processing. """ - visited = set() - stack = [chakra_node] - last_visited_non_gpu = None - last_visited_any = None + visited: Set[int] = set() + stack: List[ChakraNode] = [chakra_node] + last_visited_non_gpu: Optional[ChakraNode] = None + last_visited_any: Optional[ChakraNode] = None + last_gpu_in_stream: Dict[int, ChakraNode] = {} while stack: current_node = stack.pop() @@ -580,46 +586,56 @@ def convert_ctrl_dep_to_data_dep(self, chakra_node: ChakraNode) -> None: # Determine the operator type of the current node pytorch_node = self.pytorch_nodes.get(current_node.id) - if pytorch_node: - node_op_type = pytorch_node.get_op_type() - - if node_op_type == PyTorchNodeType.GPU_OP: - # GPU operators can depend on any type of operator - if last_visited_any: - if last_visited_any.id not in current_node.data_deps: - current_node.data_deps.append(last_visited_any.id) - self.logger.debug( - f"GPU Node ID {current_node.id} now has a data " - f"dependency on Node ID {last_visited_any.id}" - ) - last_visited_any = current_node - else: - if pytorch_node.inter_thread_dep: - for id in self.id_assigner.get_assigned_ids(pytorch_node.inter_thread_dep): + if not pytorch_node: + continue + + node_op_type = pytorch_node.get_op_type() + + if node_op_type == PyTorchNodeType.GPU_OP: + if last_visited_any: + if last_visited_any.id not in current_node.data_deps: + current_node.data_deps.append(last_visited_any.id) + self.logger.debug( + f"GPU Node ID {current_node.id} now has a data " + f"dependency on Node ID {last_visited_any.id}" + ) + + stream_id = pytorch_node.stream + if stream_id in last_gpu_in_stream: + if last_gpu_in_stream[stream_id].id not in current_node.data_deps: + current_node.data_deps.append(last_gpu_in_stream[stream_id].id) + self.logger.debug( + f"GPU Node ID {current_node.id} in stream {stream_id} now has a data " + f"dependency on GPU Node ID {last_gpu_in_stream[stream_id].id} in the same stream." 
+ ) + last_gpu_in_stream[stream_id] = current_node + last_visited_any = current_node + else: + if pytorch_node.inter_thread_dep: + for id in self.id_assigner.get_assigned_ids(pytorch_node.inter_thread_dep): + if id not in current_node.data_deps: current_node.data_deps.append(id) self.logger.debug( f"CPU Node ID {current_node.id} now has an inter-thread data " f"dependency on Node ID {id}" ) - # CPU operators depend on non-GPU operators - if last_visited_non_gpu: - if last_visited_non_gpu.id not in current_node.data_deps: - current_node.data_deps.append(last_visited_non_gpu.id) - self.logger.debug( - f"CPU Node ID {current_node.id} now has a data " - f"dependency on non-GPU Node ID " - f"{last_visited_non_gpu.id}" - ) - last_visited_non_gpu = current_node - last_visited_any = current_node - - # Add children to the stack - children_chakra_ids = [child.id for child in pytorch_node.children] - for child_chakra_id in sorted(children_chakra_ids, reverse=True): - child_chakra_node = self.chakra_nodes.get(child_chakra_id) - if child_chakra_node and child_chakra_node.id not in visited: - stack.append(child_chakra_node) + if last_visited_non_gpu: + if last_visited_non_gpu.id not in current_node.data_deps: + current_node.data_deps.append(last_visited_non_gpu.id) + self.logger.debug( + f"CPU Node ID {current_node.id} now has a data " + f"dependency on non-GPU Node ID {last_visited_non_gpu.id}" + ) + last_visited_non_gpu = current_node + last_visited_any = current_node + + # Add children to the stack + children_chakra_ids = [child.id for child in pytorch_node.children] + for child_chakra_id in sorted(children_chakra_ids, reverse=True): + child_chakra_node = self.chakra_nodes.get(child_chakra_id) + if child_chakra_node and child_chakra_node.id not in visited: + stack.append(child_chakra_node) def remove_dangling_nodes(self) -> None: """ diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py index e42898c9..1e442983 100644 --- a/et_converter/pytorch_node.py +++ b/et_converter/pytorch_node.py @@ -448,6 +448,10 @@ def inter_thread_dep(self) -> Optional[int]: """ return self.node_data.get("inter_thread_dep") + @property + def stream(self) -> int: + return self.node_data["stream"] + def has_ts(self) -> bool: """ Checks if the node has a timestamp field. 
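The intra-stream handling added here reduces to remembering the last GPU node seen on each stream and chaining the next kernel on that stream behind it. A standalone sketch of that rule over a hypothetical, issue-ordered list of (node_id, stream_id) pairs (illustrative only, not the converter's data model):

    from typing import Dict, List, Tuple

    def chain_by_stream(gpu_ops: List[Tuple[int, int]]) -> Dict[int, List[int]]:
        """Build a data_deps map that serializes GPU ops within each stream."""
        deps: Dict[int, List[int]] = {}
        last_in_stream: Dict[int, int] = {}       # stream_id -> last node_id seen
        for node_id, stream_id in gpu_ops:        # assumed already sorted by issue time
            deps[node_id] = []
            if stream_id in last_in_stream:
                deps[node_id].append(last_in_stream[stream_id])
            last_in_stream[stream_id] = node_id
        return deps

    # Kernels 10 and 12 share stream 7; kernel 11 runs on stream 20.
    print(chain_by_stream([(10, 7), (11, 20), (12, 7)]))
    # {10: [], 11: [], 12: [10]}  -> 12 waits for 10, 11 stays independent

In the converter the same role is played by the last_gpu_in_stream dictionary, combined with the usual dependency on the last visited node of any type.
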
From 6e6031747b49d59f955697d677ed8e605f306783 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Sat, 27 Jan 2024 13:17:20 -0500 Subject: [PATCH 10/13] et_converter: Fix collective comm identification --- et_converter/pytorch2chakra_converter.py | 34 +++++++++++++++++++++++- et_converter/pytorch_node.py | 27 ------------------- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 9f44823e..76a15ff0 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -166,9 +166,10 @@ def convert(self) -> None: chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node) if chakra_node.type == COMM_COLL_NODE: + collective_comm_type = self.get_collective_comm_type(pytorch_node.name) chakra_gpu_node.attr.extend([ ChakraAttr(name="comm_type", - int64_val=pytorch_gpu_node.collective_comm_type), + int64_val=collective_comm_type), ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size), ChakraAttr(name="involved_dim", @@ -502,6 +503,37 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i return COMP_NODE return INVALID_NODE + def get_collective_comm_type(self, name: str) -> int: + """ + Returns the collective communication type of the node. + + Args: + name (str): The name of the node. + + Raises: + ValueError: If the communication type is not found in the mapping. + + Returns: + int: The collective communication type of the node. + """ + comm_type_mapping = { + "all_reduce": ALL_REDUCE, + "all_to_all": ALL_TO_ALL, + "all_gather": ALL_GATHER, + "reduce_scatter": REDUCE_SCATTER, + "broadcast": BROADCAST, + "AllReduce": ALL_REDUCE, + "Broadcast": BROADCAST, + # Additional cases can be added here + } + + for key, value in comm_type_mapping.items(): + if key.lower() in name.lower(): + return value + + raise ValueError(f"'{name}' not found in collective communication mapping. " + "Please add this collective communication name to the mapping.") + def is_root_node(self, node): """ Determines whether a given node is a root node in the execution trace. diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py index 1e442983..ba50a926 100644 --- a/et_converter/pytorch_node.py +++ b/et_converter/pytorch_node.py @@ -573,33 +573,6 @@ def comm_size(self) -> int: comm_size += type_size * shape_size return comm_size - @property - def collective_comm_type(self) -> int: - """ - Returns the collective communication type of the node. - - Raises: - ValueError: If the communication type is not found in the mapping. - - Returns: - int: The collective communication type of the node. 
- """ - comm_type_mapping = { - "all_reduce": ALL_REDUCE, - "all_to_all": ALL_TO_ALL, - "all_gather": ALL_GATHER, - "reduce_scatter": REDUCE_SCATTER, - "broadcast": BROADCAST, - "AllReduce": ALL_REDUCE, - "Broadcast": BROADCAST, - # TODO: Add more cases - } - for key, value in comm_type_mapping.items(): - if key in self.node_data["name"]: - return value - - raise ValueError("Communication type not found in mapping.") - @staticmethod def get_data_type_size(data_type: str) -> int: """ From 92ee8068eb3c700b371214e9839c6c5015abab5b Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Wed, 31 Jan 2024 21:26:54 -0500 Subject: [PATCH 11/13] et_converter: Simulate execution of Chakra nodes based on data dependencies --- et_converter/pytorch2chakra_converter.py | 105 +++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 76a15ff0..f89b5724 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -103,6 +103,11 @@ class PyTorch2ChakraConverter: pytorch_cpu_node_id_gpu_node_map (Dict[int, List[int]]): Map of PyTorch CPU node IDs to GPU node IDs. chakra_nodes (Dict[int, Any]): Map of Chakra node IDs to nodes. + parent_to_children_map (Dict[int, List[int]]): Map of Chakra parent node + IDs to their child node + IDs. Used to simulate + execution based on data + dependencies. """ def __init__( @@ -143,6 +148,8 @@ def initialize_attributes(self) -> None: self.pytorch_cpu_node_id_gpu_node_map = {} self.chakra_nodes = {} + self.parent_to_children_map = {} + def convert(self) -> None: """ Converts PyTorch execution traces into the Chakra format. Orchestrates @@ -184,12 +191,16 @@ def convert(self) -> None: self.remove_dangling_nodes() + self.update_parent_to_children_map() + self.identify_cyclic_dependencies() self.write_chakra_et() self.close_chakra_execution_trace() + self.simulate_execution() + def load_pytorch_execution_traces(self) -> None: """ Loads PyTorch execution traces from a file. @@ -690,6 +701,18 @@ def remove_dangling_nodes(self) -> None: for node in dangling_nodes: self.logger.info(f" - Node ID {node.id}: {node.name}") + def update_parent_to_children_map(self) -> None: + """ + Updates the parent_to_children_map based on the data dependencies of each node. + This map is used to efficiently simulate node execution based on data dependencies. + """ + for node_id, node in self.chakra_nodes.items(): + for dep_id in node.data_deps: + # Ensure the dependency is registered as a parent of the current node + if dep_id not in self.parent_to_children_map: + self.parent_to_children_map[dep_id] = [] + self.parent_to_children_map[dep_id].append(node_id) + def identify_cyclic_dependencies(self) -> None: """ Identifies if there are any cyclic dependencies among Chakra nodes. @@ -801,3 +824,85 @@ def close_chakra_execution_trace(self) -> None: self.logger.info("Closing Chakra execution trace file.") if self.chakra_et and not self.chakra_et.closed: self.chakra_et.close() + + def simulate_execution(self) -> None: + """ + Simulates the execution of Chakra nodes based on data dependencies. + + This method considers both CPU and GPU nodes. Nodes are issued for + execution based on the readiness determined by dependency resolution. + A simplistic global clock is used to model the execution time. 
+ """ + self.logger.info("Simulating execution of Chakra nodes based on data " + "dependencies.") + + # Initialize queues for ready CPU and GPU nodes + ready_cpu_nodes = [ + (node_id, self.chakra_nodes[node_id]) + for node_id in self.chakra_nodes + if not self.chakra_nodes[node_id].data_deps and + not self.pytorch_nodes[node_id].is_gpu_op() + ] + ready_gpu_nodes = [ + (node_id, self.chakra_nodes[node_id]) + for node_id in self.chakra_nodes + if not self.chakra_nodes[node_id].data_deps and + self.pytorch_nodes[node_id].is_gpu_op() + ] + ready_cpu_nodes.sort(key=lambda x: x[1].id) + ready_gpu_nodes.sort(key=lambda x: x[1].id) + + issued_nodes: Set[int] = set() + current_cpu_node: Optional[Tuple[int, int]] = None + current_gpu_node: Optional[Tuple[int, int]] = None + + current_time: int = 0 # Simulated global clock in microseconds + + while any([ready_cpu_nodes, ready_gpu_nodes, current_cpu_node, + current_gpu_node]): + if ready_cpu_nodes and not current_cpu_node: + cpu_node_id, cpu_node = ready_cpu_nodes.pop(0) + current_cpu_node = (cpu_node_id, current_time) + issued_nodes.add(cpu_node_id) + self.logger.info( + f"Issuing CPU Node ID {cpu_node_id} ({cpu_node.name}) at " + f"{current_time}us with duration {cpu_node.duration_micros}us" + ) + + if ready_gpu_nodes and not current_gpu_node: + gpu_node_id, gpu_node = ready_gpu_nodes.pop(0) + current_gpu_node = (gpu_node_id, current_time) + issued_nodes.add(gpu_node_id) + self.logger.info( + f"Issuing GPU Node ID {gpu_node_id} ({gpu_node.name}) at " + f"{current_time}us with duration {gpu_node.duration_micros}us" + ) + + current_time += 1 + + if current_cpu_node and current_time - current_cpu_node[1] >= \ + self.chakra_nodes[current_cpu_node[0]].duration_micros: + self.logger.info(f"CPU Node ID {current_cpu_node[0]} completed " + f"at {current_time}us") + current_cpu_node = None + + if current_gpu_node and current_time - current_gpu_node[1] >= \ + self.chakra_nodes[current_gpu_node[0]].duration_micros: + self.logger.info(f"GPU Node ID {current_gpu_node[0]} completed " + f"at {current_time}us") + current_gpu_node = None + + for node_id in list(issued_nodes): + children_ids = self.parent_to_children_map.get(node_id, []) + for child_id in children_ids: + child_node = self.chakra_nodes[child_id] + child_node.data_deps.remove(node_id) + if not child_node.data_deps: + if not self.pytorch_nodes[child_id].is_gpu_op(): + ready_cpu_nodes.append((child_id, child_node)) + else: + ready_gpu_nodes.append((child_id, child_node)) + + issued_nodes.clear() + + self.logger.info("Simulation of Chakra node execution completed.") From 0e19270d8d5097e18cc2646b94da4feab938ac7a Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Thu, 1 Feb 2024 07:29:16 -0500 Subject: [PATCH 12/13] et_converter: Differentiate inclusive and exclusive duration --- et_converter/pytorch2chakra_converter.py | 52 ++++++++++++++++-------- et_converter/pytorch_node.py | 44 +++++++++++++++----- 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index f89b5724..8b4d1b7e 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -349,12 +349,24 @@ def split_cpu_nodes_with_gpu_child(self) -> None: child_node.parent = cpu_node.id updated_pytorch_nodes[new_cpu_node_id] = cpu_node else: - gpu_node = cpu_node.child_gpu - cpu_node_first, cpu_node_second, updated_gpu_node =\ - self._split_cpu_node(cpu_node, gpu_node, 
updated_pytorch_nodes) - updated_pytorch_nodes[cpu_node_first.id] = copy.deepcopy(cpu_node_first) - updated_pytorch_nodes[cpu_node_second.id] = copy.deepcopy(cpu_node_second) - updated_pytorch_nodes[updated_gpu_node.id] = copy.deepcopy(updated_gpu_node) + if cpu_node.exclusive_dur > 1: + gpu_node = cpu_node.child_gpu + cpu_node_first, cpu_node_second, updated_gpu_node =\ + self._split_cpu_node(cpu_node, gpu_node, updated_pytorch_nodes) + updated_pytorch_nodes[cpu_node_first.id] = copy.deepcopy(cpu_node_first) + updated_pytorch_nodes[cpu_node_second.id] = copy.deepcopy(cpu_node_second) + updated_pytorch_nodes[updated_gpu_node.id] = copy.deepcopy(updated_gpu_node) + else: + new_cpu_node_id = self.id_assigner.assign_unique_id(cpu_node.id) + cpu_node.id = new_cpu_node_id + for child_node in cpu_node.children: + child_node.parent = cpu_node.id + updated_pytorch_nodes[new_cpu_node_id] = cpu_node + + gpu_node = cpu_node.child_gpu + gpu_node.parent = new_cpu_node_id + new_gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id) + updated_pytorch_nodes[new_gpu_node_id] = gpu_node self.pytorch_nodes = updated_pytorch_nodes @@ -378,21 +390,24 @@ def _split_cpu_node( ValueError: For inconsistencies in the timestamps of the nodes. """ original_cpu_info = f"Original CPU Node ID {cpu_node.id} ({cpu_node.name}), " \ - f"Duration: {cpu_node.dur}." + f"Inclusive Duration: {cpu_node.inclusive_dur}, " \ + f"Exclusive Duration: {cpu_node.exclusive_dur}." self.logger.debug(original_cpu_info) self.logger.debug(f"GPU Node ID {gpu_node.id} ({gpu_node.name}), " - f"Duration: {gpu_node.dur}.") + f"Inclusive Duration: {gpu_node.inclusive_dur}, " + f"Exclusive Duration: {gpu_node.exclusive_dur}.") cpu_node_first = copy.deepcopy(cpu_node) cpu_node_first.id = self.id_assigner.assign_unique_id(cpu_node.id) cpu_node_first.ts = cpu_node.ts - cpu_node_first.dur = gpu_node.ts - cpu_node.ts + cpu_node_first.exclusive_dur = int(cpu_node.exclusive_dur / 2) cpu_node_first.set_child_gpu(gpu_node) - if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.dur <= 0: + if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.inclusive_dur <= 0: err_msg = (f"Invalid timestamps for the first split CPU node derived from {original_cpu_info}\n" f"\tFirst Split CPU Node Timestamp: {cpu_node_first.ts}, \n" f"\tGPU Node Timestamp: {gpu_node.ts}, \n" - f"\tFirst Split CPU Node Duration: {cpu_node_first.dur}.") + f"\tFirst Split CPU Node Inclusive Duration: {cpu_node_first.inclusive_dur}, \n" + f"\tFirst Split CPU Node Exclusive Duration: {cpu_node_first.exclusive_dur}.") self.logger.error(err_msg) raise ValueError(err_msg) @@ -402,7 +417,8 @@ def _split_cpu_node( self._update_parent_node_children(updated_pytorch_nodes, cpu_node, cpu_node_first) self.logger.debug(f"First Split CPU Node ID {cpu_node_first.id} ({cpu_node_first.name}), " - f"Duration: {cpu_node_first.dur}") + f"Inclusive Duration: {cpu_node_first.inclusive_dur}, " + f"Exclusive Duration: {cpu_node_first.exclusive_dur}.") gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id) gpu_node.id = gpu_node_id @@ -411,22 +427,24 @@ def _split_cpu_node( cpu_node_second = copy.deepcopy(cpu_node) cpu_node_second.id = self.id_assigner.assign_unique_id(cpu_node.id) cpu_node_second.ts = gpu_node.ts - cpu_node_second.dur = cpu_node.dur - (gpu_node.ts - cpu_node.ts) + cpu_node_second.exclusive_dur = int(cpu_node.exclusive_dur / 2) cpu_node_second.set_child_gpu(None) cpu_node_second.parent = cpu_node_first.id for child_node in cpu_node.children: child_node.parent = cpu_node_second.id 
cpu_node_second.add_child(child_node) - if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.dur <= 0: + if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.inclusive_dur <= 0: err_msg = (f"Invalid timestamps for the second split CPU node derived from {original_cpu_info}\n" f"\tFirst Split Timestamp: {cpu_node_first.ts}, \n" f"\tSecond Split Timestamp: {cpu_node_second.ts}, \n" - f"\tSecond Split Duration: {cpu_node_second.dur}.") + f"\tSecond Split Inclusive Duration: {cpu_node_second.inclusive_dur}, " + f"\tSecond Split Exclusive Duration: {cpu_node_second.exclusive_dur}.") self.logger.error(err_msg) raise ValueError(err_msg) self.logger.debug(f"Second Split CPU Node ID {cpu_node_second.id} ({cpu_node_second.name}), " - f"Duration: {cpu_node_second.dur}.") + f"Inclusive Duration: {cpu_node_second.inclusive_dur}, " + f"Exclusive Duration: {cpu_node_second.exclusive_dur}.") cpu_node_first.add_child(cpu_node_second) cpu_node_first.add_child(gpu_node) @@ -472,7 +490,7 @@ def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode: chakra_node.type = self.get_chakra_node_type_from_pytorch_node(pytorch_node) if pytorch_node.parent in self.chakra_nodes: chakra_node.ctrl_deps.append(pytorch_node.parent) - chakra_node.duration_micros = pytorch_node.dur if pytorch_node.has_dur() else 0 + chakra_node.duration_micros = pytorch_node.exclusive_dur chakra_node.inputs.values = str(pytorch_node.inputs) chakra_node.inputs.shapes = str(pytorch_node.input_shapes) chakra_node.inputs.types = str(pytorch_node.input_types) diff --git a/et_converter/pytorch_node.py b/et_converter/pytorch_node.py index ba50a926..0ba972ac 100644 --- a/et_converter/pytorch_node.py +++ b/et_converter/pytorch_node.py @@ -53,7 +53,9 @@ def __repr__(self) -> str: f"PyTorchNode(" f"id={self.id}, name={self.name}, " f"op_type={self.get_op_type()}, " - f"timestamp={self.ts}, duration={self.dur})" + f"timestamp={self.ts}, " + f"inclusive_duration={self.inclusive_dur}, " + f"exclusive_duration={self.exclusive_dur})" ) @property @@ -418,24 +420,44 @@ def cat(self, value: str) -> None: self.node_data["cat"] = value @property - def dur(self) -> int: + def inclusive_dur(self) -> int: """ - Returns the duration of the node. + Returns the inclusive duration of the node. Returns: - int: The duration of the node. + int: The inclusive duration of the node. """ - return self.node_data["dur"] + return self.node_data["inclusive_dur"] - @dur.setter - def dur(self, value: int) -> None: + @inclusive_dur.setter + def inclusive_dur(self, value: int) -> None: """ - Sets the duration of the node. + Sets the inclusive duration of the node. Args: - value (int): The new duration of the node. + value (int): The new inclusive duration of the node. """ - self.node_data["dur"] = value + self.node_data["inclusive_dur"] = value + + @property + def exclusive_dur(self) -> int: + """ + Returns the exclusive duration of the node. + + Returns: + int: The exclusive duration of the node. + """ + return self.node_data.get("exclusive_dur", 0) + + @exclusive_dur.setter + def exclusive_dur(self, value: int) -> None: + """ + Sets the exclusive duration of the node. + + Args: + value (int): The new exclusive duration of the node. + """ + self.node_data["exclusive_dur"] = value @property def inter_thread_dep(self) -> Optional[int]: @@ -477,7 +499,7 @@ def has_dur(self) -> bool: Returns: bool: True if the node has a duration field, False otherwise. 
""" - return "dur" in self.node_data + return "inclusive_dur" in self.node_data def get_op_type(self) -> PyTorchNodeType: """ From 4b85c074b22fa87ece5bfbf87ac96a17ad3f0dff Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 6 Feb 2024 07:21:13 -0500 Subject: [PATCH 13/13] et_converter: Identify non-comm nodes as comp nodes --- et_converter/pytorch2chakra_converter.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/et_converter/pytorch2chakra_converter.py b/et_converter/pytorch2chakra_converter.py index 8b4d1b7e..e4db4e30 100644 --- a/et_converter/pytorch2chakra_converter.py +++ b/et_converter/pytorch2chakra_converter.py @@ -524,13 +524,9 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i "ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name ): return COMM_COLL_NODE - elif pytorch_node.is_gpu_op(): - return COMP_NODE elif ("c10d::" in pytorch_node.name) or ("nccl:" in pytorch_node.name): return COMM_COLL_NODE - elif (pytorch_node.op_schema != "") or pytorch_node.outputs: - return COMP_NODE - return INVALID_NODE + return COMP_NODE def get_collective_comm_type(self, name: str) -> int: """