Skip to content

Commit

Permalink
Refactor trace_link for better maintainability and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jul 13, 2024
1 parent c2dd55a commit d50225b
Show file tree
Hide file tree
Showing 10 changed files with 663 additions and 630 deletions.
14 changes: 6 additions & 8 deletions USER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,20 @@ $ pip uninstall chakra

## Tools Overview
### Execution Trace Link (chakra_trace_link)
Merge PyTorch Chakra host trace and Kineto trace to encode GPU operators into the output execution trace.

Merge Chakra host execution trace and Chakra device execution trace to encode GPU operators into the output execution trace.
```bash
$ chakra_trace_link \
--pytorch-et-file /path/to/pytorch_et \
--kineto-file /path/to/kineto \
--output-file /path/to/merged_et
--chakra-host-trace /path/to/chakra_host_trace \
--chakra-device-trace /path/to/chakra_device_trace \
--output-file /path/to/chakra_host_device_trace.json
```

### Execution Trace Converter (chakra_converter)
Converts the merged execution traces into the Chakra schema.

```bash
$ chakra_converter \
--input_filename /path/to/merged_et \
--output_filename /path/to/chakra_et \
--input_filename /path/to/chakra_host_device_trace.json \
--output_filename /path/to/chakra_trace \
--input_type <input_type>
```

Expand Down
251 changes: 251 additions & 0 deletions src/trace_link/chakra_device_trace_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import logging
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple

from et_replay.lib.utils import read_dictionary_from_json_file

from .kineto_operator import KinetoOperator


class ChakraDeviceTraceLoader:
"""Loads Chakra device traces."""

def load(
self, chakra_device_trace: str
) -> Tuple[
List[KinetoOperator],
Dict[int, List[KinetoOperator]],
Dict[int, KinetoOperator],
List[KinetoOperator],
Dict[int, KinetoOperator],
Dict[int, KinetoOperator],
int,
int,
Dict[int, Tuple[int, int]],
Dict[int, KinetoOperator],
List[KinetoOperator],
List[int],
]:
"""
Load and process the Chakra device trace.
Args:
chakra_device_trace (str): Path to the Chakra device trace file.
Returns:
Tuple containing various data structures needed for linking traces.
"""
logging.debug(f"Starting to load Chakra device trace from file: {chakra_device_trace}.")
chakra_trace_data = read_dictionary_from_json_file(chakra_device_trace)
sorted_kineto_ops = sorted(
[KinetoOperator(op) for op in chakra_trace_data["traceEvents"]],
key=lambda op: op.timestamp,
)

dev_data = self.construct_dev_data_structures(sorted_kineto_ops, chakra_device_trace)
self.calculate_exclusive_dur(dev_data["kineto_tid_cpu_ops_map"])

dev_data["sorted_kineto_cpu_ops"] = sorted(dev_data["kineto_cpu_ops"], key=lambda op: op.timestamp)
dev_data["sorted_kineto_cpu_op_ts"] = [op.timestamp for op in dev_data["sorted_kineto_cpu_ops"]]

logging.debug(
f"Processed Chakra device trace with {len(dev_data['kineto_cpu_ops'])} CPU ops, "
f"{len(dev_data['kineto_id_cuda_launch_op_map'])} CPU launcher ops, "
f"and {len(dev_data['kineto_gpu_ops'])} GPU ops."
)
logging.debug("Chakra device trace has been loaded and processed successfully.")
return (
dev_data["kineto_cpu_ops"],
dev_data["kineto_tid_cpu_ops_map"],
dev_data["kineto_correlation_cuda_runtime_map"],
dev_data["kineto_gpu_ops"],
dev_data["kineto_id_arrow_op_map"],
dev_data["kineto_id_cuda_launch_op_map"],
dev_data["kineto_process_start_time"],
dev_data["kineto_process_end_time"],
dev_data["kineto_thread_info"],
dev_data["kineto_rf_id_to_kineto_op_map"],
dev_data["sorted_kineto_cpu_ops"],
dev_data["sorted_kineto_cpu_op_ts"],
)

def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_file: str) -> Dict:
"""
Construct necessary data structures required for trace linking from the provided Kineto operators.
This method identifies process start time, end time, thread start time, and end time, and also categorizes
operators into CPU, GPU, and other relevant groups.
Args:
kineto_ops (List[KinetoOperator]): List of Kineto operators to categorize.
trace_file (str): Path to the trace file for logging purposes.
Returns:
Dict: Dictionary containing categorized operators and timing boundaries.
"""
logging.debug("Categorizing Kineto operators and calculating timing boundaries.")
process_start_time = sys.maxsize
process_end_time = 0
thread_info = {}

kineto_cpu_ops = []
kineto_tid_cpu_ops_map = {}
kineto_correlation_cuda_runtime_map = {}
kineto_gpu_ops = []
kineto_id_arrow_op_map = {}
kineto_id_cuda_launch_op_map = {}

for op in kineto_ops:
if op.is_cpu_op():
kineto_cpu_ops.append(op)
kineto_tid_cpu_ops_map.setdefault(op.tid, []).append(op)
logging.debug(f"Added CPU or user annotation op: {op.name}")

elif op.is_kernel_launch_op():
kineto_id_cuda_launch_op_map[op.external_id] = op
if op.correlation in kineto_correlation_cuda_runtime_map:
error_msg = (
f"Duplicate correlation ID {op.correlation} found in kineto_id_cuda_launch_op_map. "
"The kineto_id_cuda_launch_op_map works as a mapping to link GPU operators with the launcher "
"CPU operator for the GPU operator. The correlation field works as a link, and this map has a "
"mapping between the correlation and the launcher operator. Each kernel launch operator "
"should have a unique correlation ID for linking it to a GPU operator. Therefore, duplicated "
"correlation is not expected in the map. Please review the file manually to see if the "
f"operator has an invalid correlation value in file: {trace_file}."
)
logging.error(error_msg)
raise ValueError(error_msg)
kineto_correlation_cuda_runtime_map[op.correlation] = op
logging.debug(f"Added CPU launcher op: {op.name}")

elif op.is_gpu_op():
kineto_gpu_ops.append(op)
logging.debug(f"Added GPU op: {op.name}")

elif op.is_ac2g_op(): # arrow from CPU to GPU
assert (op.phase == "s") or (op.phase == "f")
if op.id is None:
error_msg = (
f"'id' field is None in Kineto operator: {op} in file: {trace_file}. This is unexpected as "
"'id' should generally be populated for 'ac2g' operators. Please verify the validity of "
"the Kineto trace and the operator data."
)
logging.error(error_msg)
raise KeyError(error_msg)

kineto_id_arrow_op_map[op.id] = op

# Update timing boundaries
if op.tid is not None:
process_start_time = min(process_start_time, op.timestamp)
process_end_time = max(process_end_time, op.timestamp + op.inclusive_dur)
thread_start_end = thread_info.setdefault(op.tid, [sys.maxsize, 0])
thread_start_end[0] = min(thread_start_end[0], op.timestamp)
thread_start_end[1] = max(thread_start_end[1], op.timestamp + op.inclusive_dur)

kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in kineto_cpu_ops if op.rf_id is not None}

return {
"kineto_cpu_ops": kineto_cpu_ops,
"kineto_tid_cpu_ops_map": kineto_tid_cpu_ops_map,
"kineto_correlation_cuda_runtime_map": kineto_correlation_cuda_runtime_map,
"kineto_gpu_ops": kineto_gpu_ops,
"kineto_id_arrow_op_map": kineto_id_arrow_op_map,
"kineto_id_cuda_launch_op_map": kineto_id_cuda_launch_op_map,
"kineto_process_start_time": process_start_time,
"kineto_process_end_time": process_end_time,
"kineto_thread_info": thread_info,
"kineto_rf_id_to_kineto_op_map": kineto_rf_id_to_kineto_op_map,
"sorted_kineto_cpu_ops": [],
"sorted_kineto_cpu_op_ts": [],
}

def calculate_exclusive_dur(self, kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]]) -> None:
"""
Calculate the exclusive duration of each operator in the Kineto traces in parallel.
The exclusive duration is defined as the total duration of the operator minus any time spent in child operators,
effectively representing the time spent exclusively in that operator.
Args:
kineto_tid_cpu_ops_map (Dict[int, List[KinetoOperator]]): Map of thread IDs to their corresponding Kineto
operators.
"""
logging.debug("Calculating exclusive durations for Kineto operators in parallel.")

def process_ops_for_thread(ops: List[KinetoOperator]) -> None:
logging.debug(f"Processing {len(ops)} operators in thread.")
sorted_ops = sorted(ops, key=lambda op: (op.timestamp, op.inclusive_dur))
for i, op in enumerate(sorted_ops):
exclusive_dur = op.inclusive_dur
overlapping_regions = []

# Identify overlapping regions with child operators
for child_op in sorted_ops[i + 1 :]:
if child_op.timestamp >= op.timestamp and (child_op.timestamp + child_op.inclusive_dur) <= (
op.timestamp + op.inclusive_dur
):
overlap_start = child_op.timestamp
overlap_end = child_op.timestamp + child_op.inclusive_dur
overlapping_regions.append((overlap_start, overlap_end))
if (op.timestamp + op.inclusive_dur) < child_op.timestamp:
break

# Merge overlapping regions and calculate exclusive duration
merged_regions = self.merge_overlapping_intervals(overlapping_regions)
for start, end in merged_regions:
exclusive_dur -= end - start

# Check if exclusive_dur is not negative or zero
if exclusive_dur < 0:
error_msg = (
f"Exclusive duration calculation error for node '{op.name}' "
f"(ts: {op.timestamp}, inclusive_dur: {op.inclusive_dur}, rf_id: {op.rf_id}): "
f"Duration cannot be less than zero."
)
logging.error(error_msg)
raise ValueError(error_msg)

op.exclusive_dur = exclusive_dur
logging.debug(
f"Node '{op.name}' (ts: {op.timestamp}, inclusive_dur: {op.inclusive_dur}, "
f"rf_id: {op.rf_id}) exclusive duration: {op.exclusive_dur} microseconds."
)

with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_ops_for_thread, ops) for ops in kineto_tid_cpu_ops_map.values()]

for future in as_completed(futures):
future.result() # Wait for all threads to complete and handle any exceptions

logging.debug("Exclusive durations for Kineto operators calculated successfully.")

@staticmethod
def merge_overlapping_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
"""
Merge overlapping intervals into a single interval.
Args:
intervals (List[Tuple[int, int]]): List of intervals.
Returns:
List[Tuple[int, int]]: List of merged intervals.
"""
if not intervals:
return []

# Sort intervals based on the start time
intervals.sort(key=lambda x: x[0])
merged = [intervals[0]]

for current in intervals:
prev = merged[-1]
if current[0] <= prev[1]:
# There is overlap, merge the current interval with the previous one
merged[-1] = (prev[0], max(prev[1], current[1]))
else:
# No overlap, add the current interval
merged.append(current)

return merged
57 changes: 57 additions & 0 deletions src/trace_link/chakra_host_trace_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import sys
from typing import List

from et_replay.lib.execution_trace import Node as PyTorchOperator
from et_replay.lib.utils import load_execution_trace_file

# Increase the recursion limit for deep Chakra host execution traces.
sys.setrecursionlimit(10**6)


class ChakraHostTraceLoader:
"""Loads Chakra host traces."""

def load(self, chakra_host_trace_file: str) -> List[PyTorchOperator]:
"""
Load and process the Chakra Host Execution Trace.
Args:
chakra_host_trace_file (str): Path to the PyTorch execution trace file.
Returns:
List[PyTorchOperator]: List of PyTorch operators.
"""
logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
chakra_host_trace = load_execution_trace_file(chakra_host_trace_file)

root_node = chakra_host_trace.get_nodes()[1] # Root node is usually 1-based
chakra_host_ops = self.extract_chakra_host_ops(root_node)
logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.")
logging.debug("Chakra host execution trace has been loaded and processed successfully.")

return chakra_host_ops

def extract_chakra_host_ops(self, node: PyTorchOperator) -> List[PyTorchOperator]:
"""
Extract and sort nodes from the PyTorch execution trace recursively.
This method traverses the execution trace starting from the provided node, extracting all the operator nodes
recursively, and then returns them sorted by their identifiers.
Args:
node (PyTorchOperator): Starting node for extraction.
Returns:
List[PyTorchOperator]: Sorted list of extracted PyTorchOperator nodes.
"""
nodes = []

def traverse(node: PyTorchOperator):
nodes.append(node)
for child in node.children:
traverse(child)

traverse(node)
logging.debug(f"Traversed {len(nodes)} nodes from root node ID: {node.id}")
return sorted(nodes, key=lambda x: x.id)
10 changes: 5 additions & 5 deletions src/trace_link/kineto_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class KinetoOperator:
external_id (int): An external identifier associated with the operator.
ev_idx (int): Event index of the operator.
tid (int): Thread identifier where the operator was executed.
pytorch_op (Optional[PyTorchOperator]): Corresponding PyTorch operator object.
parent_pytorch_op_id (Optional[int]): ID of the parent PyTorch operator.
host_op (Optional[PyTorchOperator]): Corresponding PyTorch operator object.
parent_host_op_id (Optional[int]): ID of the parent PyTorch operator.
inter_thread_dep (Optional[int]): Identifier for inter-thread dependencies.
stream (Optional[int]): CUDA stream identifier associated with the operator.
rf_id (Optional[int]): Record function identifier.
Expand All @@ -45,8 +45,8 @@ def __init__(self, kineto_op: Dict[str, Any]) -> None:
self.external_id: int = int(kineto_op.get("args", {}).get("External id", -1))
self.ev_idx: int = int(kineto_op.get("args", {}).get("Ev Idx", -1))
self.tid: int = kineto_op.get("tid", 0)
self.pytorch_op: Optional[PyTorchOperator] = None
self.parent_pytorch_op_id: Optional[int] = None
self.host_op: Optional[PyTorchOperator] = None
self.parent_host_op_id: Optional[int] = None
self.inter_thread_dep: Optional[int] = None
self.stream: Optional[int] = kineto_op.get("args", {}).get("stream", None)
self.rf_id: Optional[int] = kineto_op.get("args", {}).get("Record function id", None)
Expand All @@ -64,7 +64,7 @@ def __repr__(self) -> str:
f"phase={self.phase}, inclusive_dur={self.inclusive_dur}, "
f"exclusive_dur={self.exclusive_dur}, timestamp={self.timestamp}, "
f"external_id={self.external_id}, ev_idx={self.ev_idx}, tid={self.tid}, "
f"parent_pytorch_op_id={self.parent_pytorch_op_id}, inter_thread_dep={self.inter_thread_dep}, "
f"parent_host_op_id={self.parent_host_op_id}, inter_thread_dep={self.inter_thread_dep}, "
f"stream={self.stream}, rf_id={self.rf_id}, correlation={self.correlation})"
)

Expand Down
Loading

0 comments on commit d50225b

Please sign in to comment.