Skip to content

Commit

Permalink
Refactor process_thread as a member method for improved testability
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jun 11, 2024
1 parent 979c3e0 commit 48b3abc
Showing 1 changed file with 91 additions and 44 deletions.
135 changes: 91 additions & 44 deletions src/trace_link/trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,23 @@ def link(self, pytorch_et_file: str, kineto_file: str, output_file: str) -> None
self.pytorch_et_file = pytorch_et_file
self.kineto_file = kineto_file
self.pytorch_ops, kineto_data = self.load_traces(pytorch_et_file, kineto_file)
self.update_kineto_data(kineto_data)
self.enforce_inter_thread_order()

(
self.kineto_cpu_ops,
self.kineto_tid_cpu_ops_map,
self.kineto_correlation_cuda_runtime_map,
self.kineto_gpu_ops,
self.kineto_id_arrow_op_map,
self.kineto_id_cuda_launch_op_map,
self.kineto_process_start_time,
self.kineto_process_end_time,
self.kineto_thread_info,
self.kineto_rf_id_to_kineto_op_map,
self.sorted_kineto_cpu_ops,
self.sorted_kineto_cpu_op_ts,
) = self.update_kineto_data(kineto_data)

self.kineto_tid_cpu_ops_map = self.enforce_inter_thread_order(self.kineto_tid_cpu_ops_map)
self.link_traces()
self.dump_pytorch_execution_trace_plus(output_file)

Expand Down Expand Up @@ -373,27 +388,47 @@ def merge_overlapping_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[

return merged

def update_kineto_data(self, kineto_data: Dict) -> None:
def update_kineto_data(self, kineto_data: Dict) -> Tuple:
"""
Update the variables of the TraceLinker class using the data structures from the kineto_data dictionary.
Args:
kineto_data (Dict): Dictionary containing categorized operators and timing boundaries.
Returns:
Tuple: Contains all updated variables from the kineto_data dictionary.
"""
self.kineto_cpu_ops = kineto_data["kineto_cpu_ops"]
self.kineto_tid_cpu_ops_map = kineto_data["kineto_tid_cpu_ops_map"]
self.kineto_correlation_cuda_runtime_map = kineto_data["kineto_correlation_cuda_runtime_map"]
self.kineto_gpu_ops = kineto_data["kineto_gpu_ops"]
self.kineto_id_arrow_op_map = kineto_data["kineto_id_arrow_op_map"]
self.kineto_id_cuda_launch_op_map = kineto_data["kineto_id_cuda_launch_op_map"]
self.kineto_process_start_time = kineto_data["kineto_process_start_time"]
self.kineto_process_end_time = kineto_data["kineto_process_end_time"]
self.kineto_thread_info = kineto_data["kineto_thread_info"]
self.kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in self.kineto_cpu_ops if op.rf_id is not None}
self.sorted_kineto_cpu_ops = kineto_data["sorted_kineto_cpu_ops"]
self.sorted_kineto_cpu_op_ts = kineto_data["sorted_kineto_cpu_op_ts"]

def enforce_inter_thread_order(self, threshold: int = 1000) -> None:
kineto_cpu_ops = kineto_data["kineto_cpu_ops"]
kineto_tid_cpu_ops_map = kineto_data["kineto_tid_cpu_ops_map"]
kineto_correlation_cuda_runtime_map = kineto_data["kineto_correlation_cuda_runtime_map"]
kineto_gpu_ops = kineto_data["kineto_gpu_ops"]
kineto_id_arrow_op_map = kineto_data["kineto_id_arrow_op_map"]
kineto_id_cuda_launch_op_map = kineto_data["kineto_id_cuda_launch_op_map"]
kineto_process_start_time = kineto_data["kineto_process_start_time"]
kineto_process_end_time = kineto_data["kineto_process_end_time"]
kineto_thread_info = kineto_data["kineto_thread_info"]
kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in kineto_cpu_ops if op.rf_id is not None}
sorted_kineto_cpu_ops = kineto_data["sorted_kineto_cpu_ops"]
sorted_kineto_cpu_op_ts = kineto_data["sorted_kineto_cpu_op_ts"]

return (
kineto_cpu_ops,
kineto_tid_cpu_ops_map,
kineto_correlation_cuda_runtime_map,
kineto_gpu_ops,
kineto_id_arrow_op_map,
kineto_id_cuda_launch_op_map,
kineto_process_start_time,
kineto_process_end_time,
kineto_thread_info,
kineto_rf_id_to_kineto_op_map,
sorted_kineto_cpu_ops,
sorted_kineto_cpu_op_ts,
)

def enforce_inter_thread_order(
self, kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]], threshold: int = 1000
) -> Dict[int, List[KinetoOperator]]:
"""
Enforce order between groups of operators in different threads.
Expand All @@ -406,39 +441,18 @@ def enforce_inter_thread_order(self, threshold: int = 1000) -> None:
on the last CPU operator from other threads, enforcing order and dependency across threads.
Args:
kineto_tid_cpu_ops_map (Dict[int, List[KinetoOperator]]): Kineto CPU operators grouped by thread ID.
threshold (int): Threshold for significant gap detection in microseconds, used to define group boundaries.
Returns:
Dict[int, List[KinetoOperator]]: Updated map with enforced inter-thread order.
"""
self.logger.info("Enforcing inter-thread order in Kineto traces.")

def process_thread(
tid: int,
ops: List[KinetoOperator],
ops_by_tid: Dict[int, List[KinetoOperator]],
) -> None:
self.logger.info(f"Thread {tid}: Identifying gaps for dependency linking with threshold {threshold}us.")
sorted_ops = sorted(ops, key=lambda op: op.timestamp)
last_cpu_node_rf_id = None

for i, op in enumerate(sorted_ops):
if (
i == 0
or (sorted_ops[i].timestamp - sorted_ops[i - 1].timestamp - sorted_ops[i - 1].inclusive_dur)
> threshold
):
last_cpu_node_rf_id = self.find_last_cpu_node_before_timestamp(ops_by_tid, tid, op.timestamp)
if last_cpu_node_rf_id:
self.logger.debug(
f"Thread {tid}: Linking op '{op.name}' to CPU node before gap with rf_id "
f"'{last_cpu_node_rf_id}'."
)

if last_cpu_node_rf_id:
op.inter_thread_dep = last_cpu_node_rf_id

with ThreadPoolExecutor() as executor:
futures = {
executor.submit(process_thread, tid, ops, self.kineto_tid_cpu_ops_map): tid
for tid, ops in self.kineto_tid_cpu_ops_map.items()
executor.submit(self.process_thread, tid, ops, kineto_tid_cpu_ops_map, threshold): tid
for tid, ops in kineto_tid_cpu_ops_map.items()
}

for future in as_completed(futures):
Expand All @@ -449,6 +463,39 @@ def process_thread(
except Exception as e:
self.logger.error(f"Error processing thread {tid}: {e}")

return kineto_tid_cpu_ops_map

def process_thread(
self, tid: int, ops: List[KinetoOperator], ops_by_tid: Dict[int, List[KinetoOperator]], threshold: int
) -> None:
"""
Process a single thread's operators to enforce inter-thread order.
Args:
tid (int): Thread ID.
ops (List[KinetoOperator]): List of Kineto operators for the thread.
ops_by_tid (Dict[int, List[KinetoOperator]]): Kineto operators grouped by thread ID.
threshold (int): Threshold for significant gap detection in microseconds.
"""
self.logger.info(f"Thread {tid}: Identifying gaps for dependency linking with threshold {threshold}us.")
sorted_ops = sorted(ops, key=lambda op: op.timestamp)
last_cpu_node_rf_id = None

for i, op in enumerate(sorted_ops):
if (
i == 0
or (sorted_ops[i].timestamp - sorted_ops[i - 1].timestamp - sorted_ops[i - 1].inclusive_dur) > threshold
):
last_cpu_node_rf_id = self.find_last_cpu_node_before_timestamp(ops_by_tid, tid, op.timestamp)
if last_cpu_node_rf_id:
self.logger.debug(
f"Thread {tid}: Linking op '{op.name}' to CPU node before gap with rf_id "
f"'{last_cpu_node_rf_id}'."
)

if last_cpu_node_rf_id:
op.inter_thread_dep = last_cpu_node_rf_id

def find_last_cpu_node_before_timestamp(
self,
ops_by_tid: Dict[int, List[KinetoOperator]],
Expand Down

0 comments on commit 48b3abc

Please sign in to comment.