Skip to content

Commit

Permalink
Refactor load_traces for improved testability
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jun 11, 2024
1 parent 663cac6 commit 979c3e0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
40 changes: 28 additions & 12 deletions src/trace_link/trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,42 @@ def link(self, pytorch_et_file: str, kineto_file: str, output_file: str) -> None
"""
self.pytorch_et_file = pytorch_et_file
self.kineto_file = kineto_file
self.load_traces()
self.pytorch_ops, kineto_data = self.load_traces(pytorch_et_file, kineto_file)
self.update_kineto_data(kineto_data)
self.enforce_inter_thread_order()
self.link_traces()
self.dump_pytorch_execution_trace_plus(output_file)

def load_traces(self) -> None:
"""Load both PyTorch Execution Traces and Kineto Traces."""
self.pytorch_ops = self.load_pytorch_et()
kineto_data = self.load_kineto_trace()
self.update_kineto_data(kineto_data)
def load_traces(self, pytorch_et_file: str, kineto_file: str) -> Tuple[List[PyTorchOperator], Dict]:
"""
Load both PyTorch Execution Traces and Kineto Traces.
Args:
pytorch_et_file (str): Path to the PyTorch execution trace file.
kineto_file (str): Path to the Kineto trace file.
Returns:
Tuple: A tuple containing the list of PyTorch operators and the kineto data dictionary.
"""
pytorch_ops = self.load_pytorch_et(pytorch_et_file)
kineto_data = self.load_kineto_trace(kineto_file)
return pytorch_ops, kineto_data

def load_pytorch_et(self) -> List[PyTorchOperator]:
def load_pytorch_et(self, pytorch_et_file: str) -> List[PyTorchOperator]:
"""
Load and process the PyTorch Execution Trace.
This method handles multiple iterations in the trace and extracts the nodes, considering the specified
annotation for segmenting the iterations.
Returns
Args:
pytorch_et_file (str): Path to the PyTorch execution trace file.
Returns:
List[PyTorchOperator]: List of PyTorch operators.
"""
self.logger.info("Starting to load PyTorch Execution Trace.")
pytorch_et = load_execution_trace_file(self.pytorch_et_file)
pytorch_et = load_execution_trace_file(pytorch_et_file)

root_node = pytorch_et.get_nodes()[1] # Root node is usually 1-based
pytorch_ops = self.extract_pytorch_ops(root_node)
Expand Down Expand Up @@ -157,18 +170,21 @@ def traverse(node: PyTorchOperator):
traverse(node)
return sorted(nodes, key=lambda x: x.id)

def load_kineto_trace(self) -> Dict:
def load_kineto_trace(self, kineto_file: str) -> Dict:
"""
Load and process the Kineto Trace.
This method parses the Kineto trace file, creating KinetoOperator instances for each operator in the trace.
It then categorizes and segments these operators for further processing and linking with PyTorch operators.
Returns
Args:
kineto_file (str): Path to the Kineto trace file.
Returns:
Dict: Dictionary containing various data structures needed for linking traces.
"""
self.logger.info("Starting to load Kineto Trace.")
kineto_trace_data = read_dictionary_from_json_file(self.kineto_file)
kineto_trace_data = read_dictionary_from_json_file(kineto_file)
sorted_kineto_ops = sorted(
[KinetoOperator(op) for op in kineto_trace_data["traceEvents"]],
key=lambda op: op.timestamp,
Expand Down
8 changes: 2 additions & 6 deletions tests/trace_link/test_trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,11 @@ def test_initialization(trace_linker):

@patch("chakra.src.trace_link.trace_linker.TraceLinker.load_pytorch_et")
@patch("chakra.src.trace_link.trace_linker.TraceLinker.load_kineto_trace")
@patch("chakra.src.trace_link.trace_linker.TraceLinker.update_kineto_data")
def test_load_traces(mock_update_kineto_data, mock_load_kineto_trace, mock_load_pytorch_et, trace_linker):
def test_load_traces(mock_load_kineto_trace, mock_load_pytorch_et, trace_linker):
mock_load_kineto_trace.return_value = {"sample_data": "data"}
trace_linker.pytorch_et_file = "path/to/pytorch_et.json"
trace_linker.kineto_file = "path/to/kineto.json"
trace_linker.load_traces()
trace_linker.load_traces("path/to/pytorch_et.json", "path/to/kineto.json")
mock_load_pytorch_et.assert_called_once()
mock_load_kineto_trace.assert_called_once()
mock_update_kineto_data.assert_called_once_with({"sample_data": "data"})


def test_construct_kineto_data_structures(trace_linker):
Expand Down

0 comments on commit 979c3e0

Please sign in to comment.