Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wrap up and tests for flow trigger action #12938

Merged
merged 2 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def action_for_name_or_text(
return FormAction(action_name_or_text, action_endpoint)

if action_name_or_text.startswith(FLOW_PREFIX):
from rasa.core.actions.flows import FlowTriggerAction
from rasa.core.actions.action_trigger_flow import ActionTriggerFlow

return FlowTriggerAction(action_name_or_text)
return ActionTriggerFlow(action_name_or_text)
return RemoteAction(action_name_or_text, action_endpoint)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,54 +22,81 @@
structlogger = structlog.get_logger(__name__)


class FlowTriggerAction(action.Action):
"""Action which implements and executes the form logic."""
class ActionTriggerFlow(action.Action):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed to align with other trigger action

"""Action which triggers a flow by putting it on the dialogue stack."""

def __init__(self, flow_action_name: Text) -> None:
"""Creates a `FlowTriggerAction`.
"""Creates a `ActionTriggerFlow`.

Args:
flow_action_name: Name of the flow.
"""
super().__init__()

if not flow_action_name.startswith(FLOW_PREFIX):
raise ValueError(
f"Flow action name '{flow_action_name}' needs to start with "
f"'{FLOW_PREFIX}'."
)

self._flow_name = flow_action_name[len(FLOW_PREFIX) :]
self._flow_action_name = flow_action_name

def name(self) -> Text:
"""Return the flow name."""
return self._flow_action_name

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Trigger the flow."""
def create_event_to_start_flow(self, tracker: DialogueStateTracker) -> Event:
"""Create an event to start the flow.

Args:
tracker: The tracker to start the flow on.

Returns:
The event to start the flow."""
stack = DialogueStack.from_tracker(tracker)
if not stack.is_empty():
frame_type = FlowStackFrameType.INTERRUPT
else:
frame_type = FlowStackFrameType.REGULAR
frame_type = (
FlowStackFrameType.REGULAR
if stack.is_empty()
else FlowStackFrameType.INTERRUPT
)

stack.push(
UserFlowStackFrame(
flow_id=self._flow_name,
frame_type=frame_type,
)
)
return stack.persist_as_event()

def create_events_to_set_flow_slots(self, metadata: Dict[str, Any]) -> List[Event]:
"""Create events to set the flow slots.

Set additional slots to prefill information for the flow.

Args:
metadata: The metadata to set the slots from.

Returns:
The events to set the flow slots.
"""
slots_to_be_set = metadata.get("slots", {}) if metadata else {}
slot_set_events: List[Event] = [
SlotSet(key, value) for key, value in slots_to_be_set.items()
]
return [SlotSet(key, value) for key, value in slots_to_be_set.items()]

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Trigger the flow."""
events: List[Event] = [self.create_event_to_start_flow(tracker)]
events.extend(self.create_events_to_set_flow_slots(metadata))

events: List[Event] = [
stack.persist_as_event(),
] + slot_set_events
if tracker.active_loop_name:
# end any active loop to ensure we are progressing the started flow
events.append(ActiveLoop(None))

return events
92 changes: 92 additions & 0 deletions tests/core/actions/test_action_trigger_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
from rasa.core.actions.action_trigger_flow import ActionTriggerFlow
from rasa.core.channels import CollectingOutputChannel
from rasa.core.nlg import TemplatedNaturalLanguageGenerator
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
from rasa.dialogue_understanding.stack.frames.flow_stack_frame import (
FlowStackFrameType,
UserFlowStackFrame,
)
from rasa.shared.core.constants import DIALOGUE_STACK_SLOT
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActiveLoop, SlotSet
from rasa.shared.core.trackers import DialogueStateTracker


async def test_action_trigger_flow():
tracker = DialogueStateTracker.from_events("test", [])
action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})
events = await action.run(channel, nlg, tracker, Domain.empty())
assert len(events) == 1
event = events[0]
assert isinstance(event, SlotSet)
assert event.key == DIALOGUE_STACK_SLOT
assert len(event.value) == 1
assert event.value[0]["type"] == UserFlowStackFrame.type()
assert event.value[0]["flow_id"] == "foo"
assert event.value[0]["frame_type"] == FlowStackFrameType.REGULAR.value


async def test_action_trigger_flow_with_slots():
tracker = DialogueStateTracker.from_events("test", [])
action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})
events = await action.run(
channel, nlg, tracker, Domain.empty(), metadata={"slots": {"foo": "bar"}}
)

event = events[0]
assert isinstance(event, SlotSet)
assert event.key == DIALOGUE_STACK_SLOT
assert len(event.value) == 1
assert event.value[0]["type"] == UserFlowStackFrame.type()
assert event.value[0]["flow_id"] == "foo"

assert len(events) == 2
event = events[1]
assert isinstance(event, SlotSet)
assert event.key == "foo"
assert event.value == "bar"


async def test_action_trigger_fails_if_name_is_invalid():
with pytest.raises(ValueError):
ActionTriggerFlow("foo")


async def test_action_trigger_ends_an_active_loop_on_the_tracker():
tracker = DialogueStateTracker.from_events("test", [ActiveLoop("loop_foo")])
action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})
events = await action.run(channel, nlg, tracker, Domain.empty())

assert len(events) == 2
assert isinstance(events[1], ActiveLoop)
assert events[1].name is None


async def test_action_trigger_uses_interrupt_flow_type_if_stack_already_contains_flow():
user_frame = UserFlowStackFrame(
flow_id="my_flow", step_id="collect_bar", frame_id="some-frame-id"
)
stack = DialogueStack(frames=[user_frame])
tracker = DialogueStateTracker.from_events("test", [stack.persist_as_event()])

action = ActionTriggerFlow("flow_foo")
channel = CollectingOutputChannel()
nlg = TemplatedNaturalLanguageGenerator({})

events = await action.run(channel, nlg, tracker, Domain.empty())

assert len(events) == 1
event = events[0]
assert isinstance(event, SlotSet)
assert event.key == DIALOGUE_STACK_SLOT
assert len(event.value) == 2
assert event.value[1]["type"] == UserFlowStackFrame.type()
assert event.value[1]["flow_id"] == "foo"
assert event.value[1]["frame_type"] == FlowStackFrameType.INTERRUPT.value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also test that it's a regular frame when there is nothing there yet :)

Loading