diff --git a/tests/shared/core/training_data/test_visualization.py b/tests/shared/core/training_data/test_visualization.py index 50aac37e0114..e99d4ada8ee3 100644 --- a/tests/shared/core/training_data/test_visualization.py +++ b/tests/shared/core/training_data/test_visualization.py @@ -10,6 +10,8 @@ from rasa.shared.nlu.training_data.message import Message from rasa.shared.nlu.training_data.training_data import TrainingData +import pytest + def test_style_transfer(): r = visualization._transfer_style({"class": "dashed great"}, {"class": "myclass"}) @@ -190,19 +192,35 @@ def test_story_visualization_with_merging(domain: Domain): assert 20 < len(generated_graph.edges()) < 33 -def test_remove_auxiliary_nodes(): +@pytest.mark.parametrize("input_nodes, input_edges, remove_count, expected_nodes, expected_edges", [ + ( + [-2, -1, 0, 1, 2, 3, 4, 5], + [(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + 3, + set([0, 1, 2, 3, 4, 5, -1]), + [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)] + ), + ( + [-3, -2, -1, 0, 1, 2, 3, 4, 5], + [(-3, -2), (-2, -1), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + 4, + set([-3, -1, 0, 1, 2, 3, 4, 5]), + [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)] + ) +]) +def test_remove_auxiliary_nodes(input_nodes, input_edges, remove_count, expected_nodes, expected_edges): + import networkx as nx + # Create a sample graph graph = nx.MultiDiGraph() - graph.add_nodes_from([-2, -1, 0, 1, 2, 3, 4, 5]) - graph.add_edges_from([(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]) + graph.add_nodes_from(input_nodes) + graph.add_edges_from(input_edges) # Call the method to remove auxiliary nodes - visualization._remove_auxiliary_nodes(graph, 3) + visualization._remove_auxiliary_nodes(graph, remove_count) # Check if the expected nodes are removed - expected_nodes = set([0, 1, 2, 3, 4, 5, -1]) assert set(graph.nodes()) == expected_nodes, "Nodes mismatch" # Check if the edges are updated correctly - expected_edges = [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)] assert list(graph.edges()) == expected_edges, "Edges mismatch"