From a01ce81b73f0e276eb67ae719227c891f73c5df3 Mon Sep 17 00:00:00 2001 From: JoongunPark <8554137+JoongunPark@users.noreply.github.com> Date: Sun, 1 Dec 2024 16:16:12 -0500 Subject: [PATCH] Fix error encoding METADATA node --- src/converter/pytorch_converter.py | 3 +++ src/converter/pytorch_node.py | 9 +++++++++ tests/converter/test_pytorch_converter.py | 11 +++++++---- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index 48f307c4..b5ace295 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -11,6 +11,7 @@ COMM_RECV_NODE, COMM_SEND_NODE, COMP_NODE, + METADATA_NODE, REDUCE_SCATTER, GlobalMetadata, ) @@ -338,6 +339,8 @@ def get_protobuf_node_type_from_json_node( Returns: int: The corresponding Chakra node type. """ + if json_node.is_metadata_op(): + return METADATA_NODE if json_node.is_gpu_op(): if "ncclDevKernel_SendRecv" in json_node.name: parent_node = json_node_map[json_node.parent] diff --git a/src/converter/pytorch_node.py b/src/converter/pytorch_node.py index 86b59acc..0729b6b7 100644 --- a/src/converter/pytorch_node.py +++ b/src/converter/pytorch_node.py @@ -137,6 +137,15 @@ def get_op_type(self) -> PyTorchNodeType: else: return PyTorchNodeType.LABEL + def is_metadata_op(self) -> bool: + """ + Check if the node is a METADATA operator. + + Returns: + bool: True if the node is a METADATA operator, False otherwise. + """ + return self.get_op_type() == PyTorchNodeType.METADATA + def is_cpu_op(self) -> bool: """ Check if the node is a CPU operator. 
diff --git a/tests/converter/test_pytorch_converter.py b/tests/converter/test_pytorch_converter.py index 88f2abf9..c01c7569 100644 --- a/tests/converter/test_pytorch_converter.py +++ b/tests/converter/test_pytorch_converter.py @@ -10,6 +10,7 @@ BROADCAST, COMM_COLL_NODE, COMP_NODE, + METADATA_NODE, REDUCE_SCATTER, ) from chakra.schema.protobuf.et_def_pb2 import Node as ChakraNode @@ -167,10 +168,11 @@ def test_write_chakra_et(mock_file: MagicMock, sample_pytorch_data: Dict) -> Non @pytest.mark.parametrize( "pytorch_node_data, expected_type", [ - ({"name": "ncclKernel", "is_gpu_op": True}, COMM_COLL_NODE), - ({"name": "ncclDevKernel", "is_gpu_op": True}, COMM_COLL_NODE), - ({"name": "c10d::all_reduce", "is_gpu_op": True}, COMP_NODE), - ({"name": "other_op", "is_gpu_op": False}, COMP_NODE), + ({"name": "process_group:init", "is_gpu_op": False, "is_metadata_op": True}, METADATA_NODE), + ({"name": "ncclKernel", "is_gpu_op": True, "is_metadata_op": False}, COMM_COLL_NODE), + ({"name": "ncclDevKernel", "is_gpu_op": True, "is_metadata_op": False}, COMM_COLL_NODE), + ({"name": "c10d::all_reduce", "is_gpu_op": True, "is_metadata_op": False}, COMP_NODE), + ({"name": "other_op", "is_gpu_op": False, "is_metadata_op": False}, COMP_NODE), ], ) def test_get_protobuf_node_type_from_json_node(pytorch_node_data: Dict, expected_type: int) -> None: @@ -178,6 +180,7 @@ def test_get_protobuf_node_type_from_json_node(pytorch_node_data: Dict, expected pytorch_node = MagicMock(spec=PyTorchNode) pytorch_node.name = pytorch_node_data["name"] pytorch_node.is_gpu_op = MagicMock(return_value=pytorch_node_data["is_gpu_op"]) + pytorch_node.is_metadata_op = MagicMock(return_value=pytorch_node_data["is_metadata_op"]) # Create a mock json_node_map dictionary with actual PyTorchNode instances mock_pytorch_node_data = {