Refactor TraceLinker trace loading to take explicit file paths.

load_traces(), load_pytorch_et() and load_kineto_trace() now receive the
trace paths as parameters instead of reading self.pytorch_et_file and
self.kineto_file. load_traces() returns (pytorch_ops, kineto_data) rather
than assigning self.pytorch_ops and calling update_kineto_data() itself;
link() now performs those two steps with the returned values.

NOTE(review): the new return annotation uses Tuple and no hunk here touches
the import block; confirm Tuple is already imported in trace_linker.py.
NOTE(review): test_load_traces is tightened relative to the recorded commit
(it checks the forwarded paths and the returned pair), so the post-image
blob id on the second index line is stale.
NOTE(review): line breaks, indentation and docstring wrap points were
rebuilt from the hunk line counts; verify context lines against the tree.

diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py
index e3ea96c9..fc043783 100644
--- a/src/trace_link/trace_linker.py
+++ b/src/trace_link/trace_linker.py
@@ -103,29 +103,42 @@ def link(self, pytorch_et_file: str, kineto_file: str, output_file: str) -> None
         """
         self.pytorch_et_file = pytorch_et_file
         self.kineto_file = kineto_file
-        self.load_traces()
+        self.pytorch_ops, kineto_data = self.load_traces(pytorch_et_file, kineto_file)
+        self.update_kineto_data(kineto_data)
         self.enforce_inter_thread_order()
         self.link_traces()
         self.dump_pytorch_execution_trace_plus(output_file)
 
-    def load_traces(self) -> None:
-        """Load both PyTorch Execution Traces and Kineto Traces."""
-        self.pytorch_ops = self.load_pytorch_et()
-        kineto_data = self.load_kineto_trace()
-        self.update_kineto_data(kineto_data)
+    def load_traces(self, pytorch_et_file: str, kineto_file: str) -> Tuple[List[PyTorchOperator], Dict]:
+        """
+        Load both PyTorch Execution Traces and Kineto Traces.
+
+        Args:
+            pytorch_et_file (str): Path to the PyTorch execution trace file.
+            kineto_file (str): Path to the Kineto trace file.
+
+        Returns:
+            Tuple: A tuple containing the list of PyTorch operators and the kineto data dictionary.
+        """
+        pytorch_ops = self.load_pytorch_et(pytorch_et_file)
+        kineto_data = self.load_kineto_trace(kineto_file)
+        return pytorch_ops, kineto_data
 
-    def load_pytorch_et(self) -> List[PyTorchOperator]:
+    def load_pytorch_et(self, pytorch_et_file: str) -> List[PyTorchOperator]:
         """
         Load and process the PyTorch Execution Trace.
 
         This method handles multiple iterations in the trace and extracts the nodes, considering the specified
         annotation for segmenting the iterations.
 
-        Returns
+        Args:
+            pytorch_et_file (str): Path to the PyTorch execution trace file.
+
+        Returns:
             List[PyTorchOperator]: List of PyTorch operators.
         """
         self.logger.info("Starting to load PyTorch Execution Trace.")
-        pytorch_et = load_execution_trace_file(self.pytorch_et_file)
+        pytorch_et = load_execution_trace_file(pytorch_et_file)
         root_node = pytorch_et.get_nodes()[1]  # Root node is usually 1-based
         pytorch_ops = self.extract_pytorch_ops(root_node)
 
@@ -157,18 +170,21 @@ def traverse(node: PyTorchOperator):
         traverse(node)
         return sorted(nodes, key=lambda x: x.id)
 
-    def load_kineto_trace(self) -> Dict:
+    def load_kineto_trace(self, kineto_file: str) -> Dict:
         """
         Load and process the Kineto Trace.
 
         This method parses the Kineto trace file, creating KinetoOperator instances for each operator in the trace.
         It then categorizes and segments these operators for further processing and linking with PyTorch operators.
 
-        Returns
+        Args:
+            kineto_file (str): Path to the Kineto trace file.
+
+        Returns:
             Dict: Dictionary containing various data structures needed for linking traces.
         """
         self.logger.info("Starting to load Kineto Trace.")
-        kineto_trace_data = read_dictionary_from_json_file(self.kineto_file)
+        kineto_trace_data = read_dictionary_from_json_file(kineto_file)
         sorted_kineto_ops = sorted(
             [KinetoOperator(op) for op in kineto_trace_data["traceEvents"]],
             key=lambda op: op.timestamp,
diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py
index 7bf35a56..746753c8 100644
--- a/tests/trace_link/test_trace_linker.py
+++ b/tests/trace_link/test_trace_linker.py
@@ -25,15 +25,13 @@ def test_initialization(trace_linker):
 
 @patch("chakra.src.trace_link.trace_linker.TraceLinker.load_pytorch_et")
 @patch("chakra.src.trace_link.trace_linker.TraceLinker.load_kineto_trace")
-@patch("chakra.src.trace_link.trace_linker.TraceLinker.update_kineto_data")
-def test_load_traces(mock_update_kineto_data, mock_load_kineto_trace, mock_load_pytorch_et, trace_linker):
+def test_load_traces(mock_load_kineto_trace, mock_load_pytorch_et, trace_linker):
     mock_load_kineto_trace.return_value = {"sample_data": "data"}
-    trace_linker.pytorch_et_file = "path/to/pytorch_et.json"
-    trace_linker.kineto_file = "path/to/kineto.json"
-    trace_linker.load_traces()
-    mock_load_pytorch_et.assert_called_once()
-    mock_load_kineto_trace.assert_called_once()
-    mock_update_kineto_data.assert_called_once_with({"sample_data": "data"})
+    pytorch_ops, kineto_data = trace_linker.load_traces("path/to/pytorch_et.json", "path/to/kineto.json")
+    mock_load_pytorch_et.assert_called_once_with("path/to/pytorch_et.json")
+    mock_load_kineto_trace.assert_called_once_with("path/to/kineto.json")
+    assert pytorch_ops is mock_load_pytorch_et.return_value
+    assert kineto_data == {"sample_data": "data"}
 
 
 def test_construct_kineto_data_structures(trace_linker):