Commit
Fix error encoding METADATA node
JoongunPark committed Jan 5, 2025
1 parent 31b25d4 commit a01ce81
Showing 3 changed files with 19 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/converter/pytorch_converter.py
@@ -11,6 +11,7 @@
     COMM_RECV_NODE,
     COMM_SEND_NODE,
     COMP_NODE,
+    METADATA_NODE,
     REDUCE_SCATTER,
     GlobalMetadata,
 )
@@ -338,6 +339,8 @@ def get_protobuf_node_type_from_json_node(
         Returns:
             int: The corresponding Chakra node type.
         """
+        if json_node.is_metadata_op():
+            return METADATA_NODE
         if json_node.is_gpu_op():
             if "ncclDevKernel_SendRecv" in json_node.name:
                 parent_node = json_node_map[json_node.parent]
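
For orientation, here is a minimal sketch of how the type dispatch reads after this change, with the metadata check short-circuiting ahead of the GPU branch. The import path and the two fall-through returns are assumptions, since the hunk only shows the first few lines of the GPU branch; this is not the converter's real method body.

# Hedged sketch of the dispatch order after this change, not the real method.
# The import path and the fall-through returns below are assumptions; only the
# metadata-first check mirrors the hunk above.
from chakra.schema.protobuf.et_def_pb2 import COMM_COLL_NODE, COMP_NODE, METADATA_NODE


def node_type_sketch(json_node, json_node_map) -> int:
    if json_node.is_metadata_op():  # new: metadata ops are resolved before any GPU/CPU checks
        return METADATA_NODE
    if json_node.is_gpu_op():
        # the real code inspects json_node_map[json_node.parent] here to tell
        # ncclDevKernel_SendRecv sends from receives (elided in this sketch)
        return COMM_COLL_NODE  # assumed default for GPU communication kernels
    return COMP_NODE  # assumed default for everything else
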
9 changes: 9 additions & 0 deletions src/converter/pytorch_node.py
@@ -137,6 +137,15 @@ def get_op_type(self) -> PyTorchNodeType:
         else:
             return PyTorchNodeType.LABEL
 
+    def is_metadata_op(self) -> bool:
+        """
+        Check if the node is a METADATA operator.
+
+        Returns:
+            bool: True if the node is a METADATA operator, False otherwise.
+        """
+        return self.get_op_type() == PyTorchNodeType.METADATA
+
     def is_cpu_op(self) -> bool:
         """
         Check if the node is a CPU operator.
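
The new helper leans on get_op_type() being able to classify metadata events as PyTorchNodeType.METADATA. Below is a minimal, self-contained illustration of that contract; the enum members other than METADATA and LABEL, their values, and the NodeSketch class are assumptions for illustration only.

from enum import Enum


class PyTorchNodeType(Enum):
    # assumed shape; only METADATA and LABEL appear in the diff above
    CPU_OP = 1
    GPU_OP = 2
    LABEL = 3
    METADATA = 4


class NodeSketch:
    """Hypothetical stand-in with just enough surface to show the new check."""

    def __init__(self, op_type: PyTorchNodeType) -> None:
        self._op_type = op_type

    def get_op_type(self) -> PyTorchNodeType:
        return self._op_type

    def is_metadata_op(self) -> bool:
        return self.get_op_type() == PyTorchNodeType.METADATA


assert NodeSketch(PyTorchNodeType.METADATA).is_metadata_op()
assert not NodeSketch(PyTorchNodeType.LABEL).is_metadata_op()
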
11 changes: 7 additions & 4 deletions tests/converter/test_pytorch_converter.py
@@ -10,6 +10,7 @@
     BROADCAST,
     COMM_COLL_NODE,
     COMP_NODE,
+    METADATA_NODE,
     REDUCE_SCATTER,
 )
 from chakra.schema.protobuf.et_def_pb2 import Node as ChakraNode
@@ -167,17 +168,19 @@ def test_write_chakra_et(mock_file: MagicMock, sample_pytorch_data: Dict) -> None:
 @pytest.mark.parametrize(
     "pytorch_node_data, expected_type",
     [
-        ({"name": "ncclKernel", "is_gpu_op": True}, COMM_COLL_NODE),
-        ({"name": "ncclDevKernel", "is_gpu_op": True}, COMM_COLL_NODE),
-        ({"name": "c10d::all_reduce", "is_gpu_op": True}, COMP_NODE),
-        ({"name": "other_op", "is_gpu_op": False}, COMP_NODE),
+        ({"name": "process_group:init", "is_gpu_op": False, "is_metadata_op": True}, METADATA_NODE),
+        ({"name": "ncclKernel", "is_gpu_op": True, "is_metadata_op": False}, COMM_COLL_NODE),
+        ({"name": "ncclDevKernel", "is_gpu_op": True, "is_metadata_op": False}, COMM_COLL_NODE),
+        ({"name": "c10d::all_reduce", "is_gpu_op": True, "is_metadata_op": False}, COMP_NODE),
+        ({"name": "other_op", "is_gpu_op": False, "is_metadata_op": False}, COMP_NODE),
     ],
 )
 def test_get_protobuf_node_type_from_json_node(pytorch_node_data: Dict, expected_type: int) -> None:
     # Create a mock PyTorchNode with the required attributes
     pytorch_node = MagicMock(spec=PyTorchNode)
     pytorch_node.name = pytorch_node_data["name"]
     pytorch_node.is_gpu_op = MagicMock(return_value=pytorch_node_data["is_gpu_op"])
+    pytorch_node.is_metadata_op = MagicMock(return_value=pytorch_node_data["is_metadata_op"])
 
     # Create a mock json_node_map dictionary with actual PyTorchNode instances
     mock_pytorch_node_data = {
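
As a usage note, the updated parametrization can be exercised on its own with a scoped pytest run, e.g. pytest tests/converter/test_pytorch_converter.py -k test_get_protobuf_node_type_from_json_node (file path and test name taken from the diff above).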
