
Commit

Merge pull request #28 from mlcommons/bugfix
Support multi-gpu child ops and remove CPU/GPU op splitting
TaekyungHeo authored Apr 2, 2024
2 parents a815e0c + 97deba1 commit ccbc25d
Showing 2 changed files with 13 additions and 167 deletions.
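
In short: PyTorchNode previously tracked at most one GPU child (set_child_gpu / child_gpu), and the converter split each CPU op around that single child to model overlap. After this commit, every node keeps a gpu_children list, all GPU children are converted directly, and the splitting pass is removed. A minimal sketch of the data-model change follows; only the field and method names are taken from the diff below, everything else is illustrative (the real PyTorchNode carries many more attributes):

    from typing import List, Optional

    class NodeBefore:
        # Old shape: a single slot, so a second GPU child overwrote the first.
        def __init__(self) -> None:
            self.child_gpu: Optional["NodeBefore"] = None

        def set_child_gpu(self, node: Optional["NodeBefore"]) -> None:
            self.child_gpu = node

    class NodeAfter:
        # New shape: every GPU child launched under a CPU op is retained.
        def __init__(self) -> None:
            self.gpu_children: List["NodeAfter"] = []

        def add_gpu_child(self, node: "NodeAfter") -> None:
            self.gpu_children.append(node)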
165 changes: 4 additions & 161 deletions et_converter/pytorch2chakra_converter.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 
-import copy
 import json
 import logging
 from typing import Dict, List, Optional, Tuple, Set
@@ -157,16 +156,13 @@ def convert(self) -> None:

         self.open_chakra_execution_trace()
 
-        self.split_cpu_nodes_with_gpu_child()
-
         for pytorch_nid, pytorch_node in self.pytorch_nodes.items():
             if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)\
                     or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL):
                 chakra_node = self.convert_to_chakra_node(pytorch_node)
                 self.chakra_nodes[chakra_node.id] = chakra_node
 
-                if pytorch_node.child_gpu:
-                    pytorch_gpu_node = pytorch_node.child_gpu
+                for pytorch_gpu_node in pytorch_node.gpu_children:
                     chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node)
 
                     if chakra_node.type == COMM_COLL_NODE:
@@ -267,7 +263,7 @@ def _establish_parent_child_relationships(
             parent_node.add_child(pytorch_node)
 
             if pytorch_node.is_gpu_op():
-                parent_node.set_child_gpu(pytorch_node)
+                parent_node.add_gpu_child(pytorch_node)
 
             if pytorch_node.is_record_param_comms_op():
                 parent_node.record_param_comms_node = pytorch_node
@@ -312,160 +308,6 @@ def open_chakra_execution_trace(self) -> None:
             self.logger.error(err_msg)
             raise Exception(err_msg)
 
-    def split_cpu_nodes_with_gpu_child(self) -> None:
-        """
-        Decomposes CPU nodes with GPU child nodes to model execution overlap
-        accurately. This method addresses scenarios where a CPU node has a GPU
-        child node, with an overlap in their execution ending at the same time.
-
-        The method splits the CPU node into:
-        1. Non-Overlapping Part: Segment before the GPU node starts.
-        2. Overlapping Part: Segment overlapping with the GPU node.
-
-        Timeline Stages:
-        Stage 1 - Original Scenario:
-            |------------ CPU Node ------------|
-                              |--- GPU Node ---|
-
-        Stage 2 - After Split:
-            |-- Non-Overlap --|--- Overlap ----|
-                              |--- GPU Node ---|
-
-        Raises:
-            ValueError: If timestamps of GPU and CPU nodes are inconsistent.
-        """
-        self.logger.info("Decomposing CPU nodes with GPU child nodes.")
-        updated_pytorch_nodes: Dict[int, PyTorchNode] = {}
-        for cpu_node in self.pytorch_nodes.values():
-            if cpu_node.child_gpu is None:
-                new_cpu_node_id = self.id_assigner.assign_unique_id(cpu_node.id)
-                cpu_node.id = new_cpu_node_id
-                for child_node in cpu_node.children:
-                    child_node.parent = cpu_node.id
-                updated_pytorch_nodes[new_cpu_node_id] = cpu_node
-            else:
-                if cpu_node.exclusive_dur > 1:
-                    gpu_node = cpu_node.child_gpu
-                    cpu_node_first, cpu_node_second, updated_gpu_node =\
-                        self._split_cpu_node(cpu_node, gpu_node, updated_pytorch_nodes)
-                    updated_pytorch_nodes[cpu_node_first.id] = copy.deepcopy(cpu_node_first)
-                    updated_pytorch_nodes[cpu_node_second.id] = copy.deepcopy(cpu_node_second)
-                    updated_pytorch_nodes[updated_gpu_node.id] = copy.deepcopy(updated_gpu_node)
-                else:
-                    new_cpu_node_id = self.id_assigner.assign_unique_id(cpu_node.id)
-                    cpu_node.id = new_cpu_node_id
-                    for child_node in cpu_node.children:
-                        child_node.parent = cpu_node.id
-                    updated_pytorch_nodes[new_cpu_node_id] = cpu_node
-
-                    gpu_node = cpu_node.child_gpu
-                    gpu_node.parent = new_cpu_node_id
-                    new_gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id)
-                    updated_pytorch_nodes[new_gpu_node_id] = gpu_node
-
-        self.pytorch_nodes = updated_pytorch_nodes
-
-    def _split_cpu_node(
-        self, cpu_node: PyTorchNode, gpu_node: PyTorchNode,
-        updated_pytorch_nodes: Dict[int, PyTorchNode]
-    ) -> Tuple[PyTorchNode, PyTorchNode, PyTorchNode]:
-        """
-        Splits a CPU node based on the GPU node's timestamp.
-
-        Args:
-            cpu_node (PyTorchNode): Original CPU node to be split.
-            gpu_node (PyTorchNode): GPU node dictating the split.
-            updated_pytorch_nodes (Dict[int, PyTorchNode]): Updated PyTorch nodes.
-
-        Returns:
-            Tuple[PyTorchNode, PyTorchNode, PyTorchNode]: Two split nodes and
-                the updated GPU node.
-
-        Raises:
-            ValueError: For inconsistencies in the timestamps of the nodes.
-        """
-        original_cpu_info = f"Original CPU Node ID {cpu_node.id} ({cpu_node.name}), " \
-                            f"Inclusive Duration: {cpu_node.inclusive_dur}, " \
-                            f"Exclusive Duration: {cpu_node.exclusive_dur}."
-        self.logger.debug(original_cpu_info)
-        self.logger.debug(f"GPU Node ID {gpu_node.id} ({gpu_node.name}), "
-                          f"Inclusive Duration: {gpu_node.inclusive_dur}, "
-                          f"Exclusive Duration: {gpu_node.exclusive_dur}.")
-
-        cpu_node_first = copy.deepcopy(cpu_node)
-        cpu_node_first.id = self.id_assigner.assign_unique_id(cpu_node.id)
-        cpu_node_first.ts = cpu_node.ts
-        cpu_node_first.exclusive_dur = int(cpu_node.exclusive_dur / 2)
-        cpu_node_first.set_child_gpu(gpu_node)
-        if cpu_node_first.ts >= gpu_node.ts or cpu_node_first.inclusive_dur <= 0:
-            err_msg = (f"Invalid timestamps for the first split CPU node derived from {original_cpu_info}\n"
-                       f"\tFirst Split CPU Node Timestamp: {cpu_node_first.ts}, \n"
-                       f"\tGPU Node Timestamp: {gpu_node.ts}, \n"
-                       f"\tFirst Split CPU Node Inclusive Duration: {cpu_node_first.inclusive_dur}, \n"
-                       f"\tFirst Split CPU Node Exclusive Duration: {cpu_node_first.exclusive_dur}.")
-            self.logger.error(err_msg)
-            raise ValueError(err_msg)
-
-        if cpu_node.parent in self.pytorch_nodes:
-            self._update_parent_node_children(self.pytorch_nodes, cpu_node, cpu_node_first)
-        elif cpu_node.parent in updated_pytorch_nodes:
-            self._update_parent_node_children(updated_pytorch_nodes, cpu_node, cpu_node_first)
-
-        self.logger.debug(f"First Split CPU Node ID {cpu_node_first.id} ({cpu_node_first.name}), "
-                          f"Inclusive Duration: {cpu_node_first.inclusive_dur}, "
-                          f"Exclusive Duration: {cpu_node_first.exclusive_dur}.")
-
-        gpu_node_id = self.id_assigner.assign_unique_id(gpu_node.id)
-        gpu_node.id = gpu_node_id
-        gpu_node.parent = cpu_node_first.id
-
-        cpu_node_second = copy.deepcopy(cpu_node)
-        cpu_node_second.id = self.id_assigner.assign_unique_id(cpu_node.id)
-        cpu_node_second.ts = gpu_node.ts
-        cpu_node_second.exclusive_dur = int(cpu_node.exclusive_dur / 2)
-        cpu_node_second.set_child_gpu(None)
-        cpu_node_second.parent = cpu_node_first.id
-        for child_node in cpu_node.children:
-            child_node.parent = cpu_node_second.id
-            cpu_node_second.add_child(child_node)
-        if cpu_node_second.ts <= cpu_node_first.ts or cpu_node_second.inclusive_dur <= 0:
-            err_msg = (f"Invalid timestamps for the second split CPU node derived from {original_cpu_info}\n"
-                       f"\tFirst Split Timestamp: {cpu_node_first.ts}, \n"
-                       f"\tSecond Split Timestamp: {cpu_node_second.ts}, \n"
-                       f"\tSecond Split Inclusive Duration: {cpu_node_second.inclusive_dur}, "
-                       f"\tSecond Split Exclusive Duration: {cpu_node_second.exclusive_dur}.")
-            self.logger.error(err_msg)
-            raise ValueError(err_msg)
-
-        self.logger.debug(f"Second Split CPU Node ID {cpu_node_second.id} ({cpu_node_second.name}), "
-                          f"Inclusive Duration: {cpu_node_second.inclusive_dur}, "
-                          f"Exclusive Duration: {cpu_node_second.exclusive_dur}.")
-
-        cpu_node_first.add_child(cpu_node_second)
-        cpu_node_first.add_child(gpu_node)
-
-        return cpu_node_first, cpu_node_second, gpu_node
-
-    def _update_parent_node_children(self, parent_node_dict: Dict[int, PyTorchNode],
-                                     cpu_node: PyTorchNode,
-                                     cpu_node_first: PyTorchNode) -> None:
-        """
-        Updates the children of the parent node in the given dictionary.
-
-        This method removes the original CPU node from the parent's children list
-        and adds the first split node.
-
-        Args:
-            parent_node_dict (Dict[int, PyTorchNode]): Dictionary containing the
-                parent node.
-            cpu_node (PyTorchNode): Original CPU node being split.
-            cpu_node_first (PyTorchNode): First split node to add to the parent's
-                children.
-        """
-        parent_node = parent_node_dict[cpu_node.parent]
-        parent_node.children = [child for child in parent_node.children
-                                if child.id != cpu_node.id]
-        parent_node.children.extend([cpu_node_first])
-
     def convert_to_chakra_node(self, pytorch_node: PyTorchNode) -> ChakraNode:
         """
         Converts a PyTorchNode to a ChakraNode.
@@ -702,7 +544,8 @@ def remove_dangling_nodes(self) -> None:
             if node_id not in parent_ids and not node.data_deps:
                 dangling_nodes.append(node)
                 del self.chakra_nodes[node_id]
-                del self.pytorch_nodes[node_id]
+                if node_id in self.pytorch_nodes:
+                    del self.pytorch_nodes[node_id]
 
         if dangling_nodes:
             self.logger.info(f"Identified and removed {len(dangling_nodes)} dangling nodes:")
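With the splitting pass removed, convert() visits every GPU child of a CPU op directly. A hedged sketch of the resulting loop, using names from the hunks above; the COMM_COLL_NODE branch is elided, and the storage of converted GPU nodes is assumed to mirror the CPU path:

    def convert(self) -> None:
        self.open_chakra_execution_trace()
        for pytorch_nid, pytorch_node in self.pytorch_nodes.items():
            if pytorch_node.get_op_type() in (PyTorchNodeType.CPU_OP,
                                              PyTorchNodeType.LABEL):
                chakra_node = self.convert_to_chakra_node(pytorch_node)
                self.chakra_nodes[chakra_node.id] = chakra_node
                # Every GPU child is converted, not just a single child_gpu.
                for pytorch_gpu_node in pytorch_node.gpu_children:
                    chakra_gpu_node = self.convert_to_chakra_node(pytorch_gpu_node)
                    self.chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node  # assumed

The companion guard in remove_dangling_nodes() fits the same change: a Chakra node's ID is no longer guaranteed to be a key of pytorch_nodes, presumably because GPU children now hang off their parents rather than being re-registered as top-level entries, so the deletion is made conditional.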
15 changes: 9 additions & 6 deletions et_converter/pytorch_node.py
@@ -31,7 +31,7 @@ def __init__(self, node_data: Dict[str, Any]) -> None:
         self.node_data = node_data
         self.data_deps: List['PyTorchNode'] = []
         self.children: List['PyTorchNode'] = []
-        self.child_gpu: Optional['PyTorchNode'] = None
+        self.gpu_children: List['PyTorchNode'] = []
         self.record_param_comms_node: Optional['PyTorchNode'] = None
         self.nccl_node: Optional['PyTorchNode'] = None

@@ -419,7 +419,9 @@ def inclusive_dur(self) -> int:
         Returns:
             int: The inclusive duration of the node.
         """
-        return self.node_data["inclusive_dur"]
+        if "inclusive_dur" in self.node_data:
+            return self.node_data["inclusive_dur"]
+        return 0
 
     @inclusive_dur.setter
     def inclusive_dur(self, value: int) -> None:
@@ -543,14 +545,14 @@ def add_child(self, child_node: 'PyTorchNode') -> None:
"""
self.children.append(child_node)

def set_child_gpu(self, child_gpu_node: Optional['PyTorchNode']) -> None:
def add_gpu_child(self, gpu_child_node: 'PyTorchNode') -> None:
"""
Sets a child GPU node for this node.
Adds a child GPU node for this node.
Args:
child_gpu_node (Optional[PyTorchNode]): The child GPU node to be set.
gpu_child_node (Optional[PyTorchNode]): The child GPU node to be added.
"""
self.child_gpu = child_gpu_node
self.gpu_children.append(gpu_child_node)

def is_record_param_comms_op(self) -> bool:
"""
@@ -620,6 +622,7 @@ def get_data_type_size(data_type: str) -> int:
"Tensor(int64)": 8,
"Tensor(long)": 8,
"Tensor(c10::Half)": 2,
"Tensor(c10::BFloat16)": 2,
"Tensor(unsigned char)": 1,
"Tensor(long int)": 8,
# TODO: Add more types
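The pytorch_node.py side is mostly hardening: inclusive_dur now tolerates traces that omit the field, and the type-size table gains BFloat16 (2 bytes, like Half). A small self-contained sketch of both behaviors; inclusive_dur is a free function and SIZE_MAP an illustrative stand-in here, whereas in the converter they are a PyTorchNode property and the table inside get_data_type_size:

    from typing import Any, Dict

    def inclusive_dur(node_data: Dict[str, Any]) -> int:
        # A missing "inclusive_dur" key yields 0 instead of raising KeyError;
        # node_data.get("inclusive_dur", 0) would be an equivalent one-liner.
        if "inclusive_dur" in node_data:
            return node_data["inclusive_dur"]
        return 0

    # Bytes per element for the two 16-bit floating-point tensor types.
    SIZE_MAP = {"Tensor(c10::Half)": 2, "Tensor(c10::BFloat16)": 2}

    print(inclusive_dur({"inclusive_dur": 42}))   # 42
    print(inclusive_dur({}))                      # 0
    print(SIZE_MAP["Tensor(c10::BFloat16)"])      # 2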
