From 48b3abc28a778f68887b61600d68545dbb5b2027 Mon Sep 17 00:00:00 2001 From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com> Date: Tue, 11 Jun 2024 15:50:01 -0400 Subject: [PATCH] Refactor process_thread as a member method for improved testability --- src/trace_link/trace_linker.py | 135 ++++++++++++++++++++++----------- 1 file changed, 91 insertions(+), 44 deletions(-) diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py index fc043783..47605e99 100644 --- a/src/trace_link/trace_linker.py +++ b/src/trace_link/trace_linker.py @@ -104,8 +104,23 @@ def link(self, pytorch_et_file: str, kineto_file: str, output_file: str) -> None self.pytorch_et_file = pytorch_et_file self.kineto_file = kineto_file self.pytorch_ops, kineto_data = self.load_traces(pytorch_et_file, kineto_file) - self.update_kineto_data(kineto_data) - self.enforce_inter_thread_order() + + ( + self.kineto_cpu_ops, + self.kineto_tid_cpu_ops_map, + self.kineto_correlation_cuda_runtime_map, + self.kineto_gpu_ops, + self.kineto_id_arrow_op_map, + self.kineto_id_cuda_launch_op_map, + self.kineto_process_start_time, + self.kineto_process_end_time, + self.kineto_thread_info, + self.kineto_rf_id_to_kineto_op_map, + self.sorted_kineto_cpu_ops, + self.sorted_kineto_cpu_op_ts, + ) = self.update_kineto_data(kineto_data) + + self.kineto_tid_cpu_ops_map = self.enforce_inter_thread_order(self.kineto_tid_cpu_ops_map) self.link_traces() self.dump_pytorch_execution_trace_plus(output_file) @@ -373,27 +388,47 @@ def merge_overlapping_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[ return merged - def update_kineto_data(self, kineto_data: Dict) -> None: + def update_kineto_data(self, kineto_data: Dict) -> Tuple: """ Update the variables of the TraceLinker class using the data structures from the kineto_data dictionary. Args: kineto_data (Dict): Dictionary containing categorized operators and timing boundaries. + + Returns: + Tuple: Contains all updated variables from the kineto_data dictionary. """ - self.kineto_cpu_ops = kineto_data["kineto_cpu_ops"] - self.kineto_tid_cpu_ops_map = kineto_data["kineto_tid_cpu_ops_map"] - self.kineto_correlation_cuda_runtime_map = kineto_data["kineto_correlation_cuda_runtime_map"] - self.kineto_gpu_ops = kineto_data["kineto_gpu_ops"] - self.kineto_id_arrow_op_map = kineto_data["kineto_id_arrow_op_map"] - self.kineto_id_cuda_launch_op_map = kineto_data["kineto_id_cuda_launch_op_map"] - self.kineto_process_start_time = kineto_data["kineto_process_start_time"] - self.kineto_process_end_time = kineto_data["kineto_process_end_time"] - self.kineto_thread_info = kineto_data["kineto_thread_info"] - self.kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in self.kineto_cpu_ops if op.rf_id is not None} - self.sorted_kineto_cpu_ops = kineto_data["sorted_kineto_cpu_ops"] - self.sorted_kineto_cpu_op_ts = kineto_data["sorted_kineto_cpu_op_ts"] - - def enforce_inter_thread_order(self, threshold: int = 1000) -> None: + kineto_cpu_ops = kineto_data["kineto_cpu_ops"] + kineto_tid_cpu_ops_map = kineto_data["kineto_tid_cpu_ops_map"] + kineto_correlation_cuda_runtime_map = kineto_data["kineto_correlation_cuda_runtime_map"] + kineto_gpu_ops = kineto_data["kineto_gpu_ops"] + kineto_id_arrow_op_map = kineto_data["kineto_id_arrow_op_map"] + kineto_id_cuda_launch_op_map = kineto_data["kineto_id_cuda_launch_op_map"] + kineto_process_start_time = kineto_data["kineto_process_start_time"] + kineto_process_end_time = kineto_data["kineto_process_end_time"] + kineto_thread_info = kineto_data["kineto_thread_info"] + kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in kineto_cpu_ops if op.rf_id is not None} + sorted_kineto_cpu_ops = kineto_data["sorted_kineto_cpu_ops"] + sorted_kineto_cpu_op_ts = kineto_data["sorted_kineto_cpu_op_ts"] + + return ( + kineto_cpu_ops, + kineto_tid_cpu_ops_map, + kineto_correlation_cuda_runtime_map, + kineto_gpu_ops, + kineto_id_arrow_op_map, + kineto_id_cuda_launch_op_map, + kineto_process_start_time, + kineto_process_end_time, + kineto_thread_info, + kineto_rf_id_to_kineto_op_map, + sorted_kineto_cpu_ops, + sorted_kineto_cpu_op_ts, + ) + + def enforce_inter_thread_order( + self, kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]], threshold: int = 1000 + ) -> Dict[int, List[KinetoOperator]]: """ Enforce order between groups of operators in different threads. @@ -406,39 +441,18 @@ def enforce_inter_thread_order(self, threshold: int = 1000) -> None: on the last CPU operator from other threads, enforcing order and dependency across threads. Args: + kineto_tid_cpu_ops_map (Dict[int, List[KinetoOperator]]): Kineto CPU operators grouped by thread ID. threshold (int): Threshold for significant gap detection in microseconds, used to define group boundaries. + + Returns: + Dict[int, List[KinetoOperator]]: Updated map with enforced inter-thread order. """ self.logger.info("Enforcing inter-thread order in Kineto traces.") - def process_thread( - tid: int, - ops: List[KinetoOperator], - ops_by_tid: Dict[int, List[KinetoOperator]], - ) -> None: - self.logger.info(f"Thread {tid}: Identifying gaps for dependency linking with threshold {threshold}us.") - sorted_ops = sorted(ops, key=lambda op: op.timestamp) - last_cpu_node_rf_id = None - - for i, op in enumerate(sorted_ops): - if ( - i == 0 - or (sorted_ops[i].timestamp - sorted_ops[i - 1].timestamp - sorted_ops[i - 1].inclusive_dur) - > threshold - ): - last_cpu_node_rf_id = self.find_last_cpu_node_before_timestamp(ops_by_tid, tid, op.timestamp) - if last_cpu_node_rf_id: - self.logger.debug( - f"Thread {tid}: Linking op '{op.name}' to CPU node before gap with rf_id " - f"'{last_cpu_node_rf_id}'." - ) - - if last_cpu_node_rf_id: - op.inter_thread_dep = last_cpu_node_rf_id - with ThreadPoolExecutor() as executor: futures = { - executor.submit(process_thread, tid, ops, self.kineto_tid_cpu_ops_map): tid - for tid, ops in self.kineto_tid_cpu_ops_map.items() + executor.submit(self.process_thread, tid, ops, kineto_tid_cpu_ops_map, threshold): tid + for tid, ops in kineto_tid_cpu_ops_map.items() } for future in as_completed(futures): @@ -449,6 +463,39 @@ def process_thread( except Exception as e: self.logger.error(f"Error processing thread {tid}: {e}") + return kineto_tid_cpu_ops_map + + def process_thread( + self, tid: int, ops: List[KinetoOperator], ops_by_tid: Dict[int, List[KinetoOperator]], threshold: int + ) -> None: + """ + Process a single thread's operators to enforce inter-thread order. + + Args: + tid (int): Thread ID. + ops (List[KinetoOperator]): List of Kineto operators for the thread. + ops_by_tid (Dict[int, List[KinetoOperator]]): Kineto operators grouped by thread ID. + threshold (int): Threshold for significant gap detection in microseconds. + """ + self.logger.info(f"Thread {tid}: Identifying gaps for dependency linking with threshold {threshold}us.") + sorted_ops = sorted(ops, key=lambda op: op.timestamp) + last_cpu_node_rf_id = None + + for i, op in enumerate(sorted_ops): + if ( + i == 0 + or (sorted_ops[i].timestamp - sorted_ops[i - 1].timestamp - sorted_ops[i - 1].inclusive_dur) > threshold + ): + last_cpu_node_rf_id = self.find_last_cpu_node_before_timestamp(ops_by_tid, tid, op.timestamp) + if last_cpu_node_rf_id: + self.logger.debug( + f"Thread {tid}: Linking op '{op.name}' to CPU node before gap with rf_id " + f"'{last_cpu_node_rf_id}'." + ) + + if last_cpu_node_rf_id: + op.inter_thread_dep = last_cpu_node_rf_id + def find_last_cpu_node_before_timestamp( self, ops_by_tid: Dict[int, List[KinetoOperator]],