Skip to content

Commit

Permalink
add: test cases of _remove_auxiliary_nodes()
Browse files Browse the repository at this point in the history
In this commit, I have implemented example test cases to test method: _remove_auxiliary_nodes in @pytest.mark.parametrize. Also,fixed the import error of networkx.
  • Loading branch information
arjun-234 committed Jul 8, 2023
1 parent b213186 commit f1b1925
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions tests/shared/core/training_data/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData

import pytest


def test_style_transfer():
r = visualization._transfer_style({"class": "dashed great"}, {"class": "myclass"})
Expand Down Expand Up @@ -190,19 +192,35 @@ def test_story_visualization_with_merging(domain: Domain):
assert 20 < len(generated_graph.edges()) < 33


def test_remove_auxiliary_nodes():
@pytest.mark.parametrize("input_nodes, input_edges, remove_count, expected_nodes, expected_edges", [
(
[-2, -1, 0, 1, 2, 3, 4, 5],
[(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
3,
set([0, 1, 2, 3, 4, 5, -1]),
[(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
),
(
[-3, -2, -1, 0, 1, 2, 3, 4, 5],
[(-3, -2), (-2, -1), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
4,
set([-3, -1, 0, 1, 2, 3, 4, 5]),
[(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
)
])
def test_remove_auxiliary_nodes(input_nodes, input_edges, remove_count, expected_nodes, expected_edges):
import networkx as nx

# Create a sample graph
graph = nx.MultiDiGraph()
graph.add_nodes_from([-2, -1, 0, 1, 2, 3, 4, 5])
graph.add_edges_from([(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)])
graph.add_nodes_from(input_nodes)
graph.add_edges_from(input_edges)

# Call the method to remove auxiliary nodes
visualization._remove_auxiliary_nodes(graph, 3)
visualization._remove_auxiliary_nodes(graph, remove_count)

# Check if the expected nodes are removed
expected_nodes = set([0, 1, 2, 3, 4, 5, -1])
assert set(graph.nodes()) == expected_nodes, "Nodes mismatch"

# Check if the edges are updated correctly
expected_edges = [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
assert list(graph.edges()) == expected_edges, "Edges mismatch"

0 comments on commit f1b1925

Please sign in to comment.