Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: rasa/shared/core/training_data/visualization.py #12544

Merged
merged 24 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1d97d81
refactor: rasa/shared/core/training_data/visualization.py
arjun-234 Jun 23, 2023
80e2b8b
Merge branch 'main' into main
arjun-234 Jun 23, 2023
652692a
Merge branch 'main' into main
arjun-234 Jun 29, 2023
6c506a0
Merge branch 'main' into main
arjun-234 Jun 30, 2023
7470405
Merge branch 'main' into main
arjun-234 Jul 4, 2023
5350bfc
add: unit test for _remove_auxiliary_nodes()
arjun-234 Jul 5, 2023
ce414e5
Merge branch 'main' into main
arjun-234 Jul 5, 2023
4f375b9
Merge branch 'main' into main
arjun-234 Jul 5, 2023
0b321df
Merge branch 'main' into main
arjun-234 Jul 6, 2023
b213186
Merge branch 'main' into main
arjun-234 Jul 7, 2023
f1b1925
add: test cases of _remove_auxiliary_nodes()
arjun-234 Jul 8, 2023
5841b06
Merge branch 'main' into main
arjun-234 Jul 10, 2023
8ce124a
refactor: blank/long lines in function: test_remove_auxiliary_nodes
arjun-234 Jul 10, 2023
d4b956d
remove: break statement after removing node
arjun-234 Jul 10, 2023
ded049b
fix: black formatting issue
arjun-234 Jul 11, 2023
aabe510
Merge branch 'main' into main
arjun-234 Jul 11, 2023
8a73bc0
Merge branch 'main' into main
arjun-234 Jul 20, 2023
c863861
Merge branch 'main' into main
arjun-234 Jul 31, 2023
0f95d46
Merge branch 'main' into main
arjun-234 Aug 6, 2023
9cf9e6c
Merge branch 'RasaHQ:main' into main
arjun-234 Aug 7, 2023
2b7164f
Merge branch 'main' into main
arjun-234 Aug 8, 2023
cde48b7
Merge branch 'main' into main
vcidst Aug 8, 2023
6cd06c7
Merge branch 'main' into main
vcidst Aug 8, 2023
5553387
Merge branch 'main' into main
arjun-234 Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions rasa/shared/core/training_data/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,17 +533,18 @@ def _remove_auxiliary_nodes(

graph.remove_node(TMP_NODE_ID)

if not len(list(graph.predecessors(END_NODE_ID))):
if not graph.predecessors(END_NODE_ID):
graph.remove_node(END_NODE_ID)

# remove duplicated "..." nodes after merging
ps = set()
predecessors_seen = set()
for i in range(special_node_idx + 1, TMP_NODE_ID):
for pred in list(graph.predecessors(i)):
if pred in ps:
predecessors = graph.predecessors(i)
for pred in predecessors:
if pred in predecessors_seen:
graph.remove_node(i)
else:
ps.add(pred)
break
predecessors_seen.update(predecessors)


def visualize_stories(
Expand Down
36 changes: 36 additions & 0 deletions tests/shared/core/training_data/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData

import pytest


def test_style_transfer():
r = visualization._transfer_style({"class": "dashed great"}, {"class": "myclass"})
Expand Down Expand Up @@ -188,3 +190,37 @@ def test_story_visualization_with_merging(domain: Domain):
assert 15 < len(generated_graph.nodes()) < 33

assert 20 < len(generated_graph.edges()) < 33


@pytest.mark.parametrize("input_nodes, input_edges, remove_count, expected_nodes, expected_edges", [
(
[-2, -1, 0, 1, 2, 3, 4, 5],
[(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
3,
set([0, 1, 2, 3, 4, 5, -1]),
[(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
),
(
[-3, -2, -1, 0, 1, 2, 3, 4, 5],
[(-3, -2), (-2, -1), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
4,
set([-3, -1, 0, 1, 2, 3, 4, 5]),
[(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
)
])
def test_remove_auxiliary_nodes(input_nodes, input_edges, remove_count, expected_nodes, expected_edges):
import networkx as nx

# Create a sample graph
graph = nx.MultiDiGraph()
arjun-234 marked this conversation as resolved.
Show resolved Hide resolved
graph.add_nodes_from(input_nodes)
graph.add_edges_from(input_edges)

# Call the method to remove auxiliary nodes
visualization._remove_auxiliary_nodes(graph, remove_count)

# Check if the expected nodes are removed
assert set(graph.nodes()) == expected_nodes, "Nodes mismatch"

# Check if the edges are updated correctly
assert list(graph.edges()) == expected_edges, "Edges mismatch"