Skip to content

Commit

Permalink
refactor: rasa/shared/core/training_data/visualization.py (#12544)
Browse files Browse the repository at this point in the history
* refactor: rasa/shared/core/training_data/visualization.py

In the refactored version of the function `_remove_auxiliary_nodes`, I've replaced the conversion of graph.predecessors(i) into a list with direct use of the iterator it returns. Additionally, I've introduced the predecessors_seen set to efficiently keep track of predecessors that have already been seen. When a duplicated predecessor is found, we can remove the node and break out of the inner loop. This optimization reduces the time complexity of checking for duplicated nodes to approximately O((TMP_NODE_ID - special_node_idx) + out_degree(node)).

* add: unit test for _remove_auxiliary_nodes()

* add: test cases for _remove_auxiliary_nodes()

In this commit, I have added example test cases for the method `_remove_auxiliary_nodes` using @pytest.mark.parametrize. Also, fixed the networkx import error.

* refactor: blank/long lines in function: test_remove_auxiliary_nodes

* remove: break statement after removing node

In this version, I have removed the break statement from the loop so that it can remove all the predecessor nodes that are in predecessors_seen without breaking out of the loop.

* fix: black formatting issue
  • Loading branch information
arjun-234 authored Aug 8, 2023
1 parent 16a8034 commit 491a392
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
12 changes: 6 additions & 6 deletions rasa/shared/core/training_data/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,17 +533,17 @@ def _remove_auxiliary_nodes(

graph.remove_node(TMP_NODE_ID)

if not len(list(graph.predecessors(END_NODE_ID))):
if not graph.predecessors(END_NODE_ID):
graph.remove_node(END_NODE_ID)

# remove duplicated "..." nodes after merging
ps = set()
predecessors_seen = set()
for i in range(special_node_idx + 1, TMP_NODE_ID):
for pred in list(graph.predecessors(i)):
if pred in ps:
predecessors = graph.predecessors(i)
for pred in predecessors:
if pred in predecessors_seen:
graph.remove_node(i)
else:
ps.add(pred)
predecessors_seen.update(predecessors)


def visualize_stories(
Expand Down
41 changes: 41 additions & 0 deletions tests/shared/core/training_data/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData

import pytest


def test_style_transfer():
r = visualization._transfer_style({"class": "dashed great"}, {"class": "myclass"})
Expand Down Expand Up @@ -188,3 +190,42 @@ def test_story_visualization_with_merging(domain: Domain):
assert 15 < len(generated_graph.nodes()) < 33

assert 20 < len(generated_graph.edges()) < 33


@pytest.mark.parametrize(
    "input_nodes, input_edges, remove_count, expected_nodes, expected_edges",
    [
        # A chain 0 -> ... -> 5 with two negative-id nodes feeding into node 0.
        # Node -2 and its outgoing edge are expected to be gone afterwards.
        (
            [-2, -1, 0, 1, 2, 3, 4, 5],
            [(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
            3,
            {-1, 0, 1, 2, 3, 4, 5},
            [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
        ),
        # A single chain -3 -> -2 -> -1 -> 0 -> ... -> 5. Removing node -2 is
        # expected to drop both edges touching it and leave -3 isolated.
        (
            [-3, -2, -1, 0, 1, 2, 3, 4, 5],
            [(-3, -2), (-2, -1), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
            4,
            {-3, -1, 0, 1, 2, 3, 4, 5},
            [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
        ),
    ],
)
def test_remove_auxiliary_nodes(
    input_nodes, input_edges, remove_count, expected_nodes, expected_edges
):
    """Check the nodes and edges left behind by `_remove_auxiliary_nodes`.

    Args:
        input_nodes: Node ids used to build the graph under test.
        input_edges: Directed `(source, target)` edges added to the graph.
        remove_count: Second positional argument passed to
            `visualization._remove_auxiliary_nodes`.
        expected_nodes: Set of node ids that must remain after the call.
        expected_edges: Edges that must remain after the call.
    """
    # Imported locally, as in the original test, rather than at module level.
    import networkx as nx

    # Build the graph under test from the parametrized nodes and edges.
    graph = nx.MultiDiGraph()
    graph.add_nodes_from(input_nodes)
    graph.add_edges_from(input_edges)

    # The function mutates `graph` in place; its return value is not used here.
    # NOTE(review): both cases only check that node -2 and its edges vanish;
    # confirm against the constants in `visualization` whether the
    # duplicated-node removal loop is exercised at all.
    visualization._remove_auxiliary_nodes(graph, remove_count)

    # Check if the expected nodes are removed
    assert set(graph.nodes()) == expected_nodes, "Nodes mismatch"

    # Check if the edges are updated correctly. Compare sorted edge lists so
    # the assertion does not depend on networkx's edge iteration order.
    assert sorted(graph.edges()) == sorted(expected_edges), "Edges mismatch"

0 comments on commit 491a392

Please sign in to comment.