Skip to content

Commit

Permalink
fixing peristed stateful flows
Browse files Browse the repository at this point in the history
  • Loading branch information
joaomdmoura committed Jan 20, 2025
1 parent 5c71b09 commit 85bac81
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 76 deletions.
2 changes: 2 additions & 0 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,8 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")

# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
Expand Down
133 changes: 57 additions & 76 deletions tests/test_flow_persistence.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Test flow state persistence functionality."""

import os
from typing import Dict, Optional
from typing import Dict

import pytest
from pydantic import BaseModel

from crewai.flow.flow import Flow, FlowState, start
from crewai.flow.flow import Flow, FlowState, start, listen
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

Expand Down Expand Up @@ -73,123 +73,104 @@ def test_flow_state_restoration(tmp_path):

# First flow execution to create initial state
class RestorableFlow(Flow[TestState]):
initial_state = TestState

@start()
@persist(persistence)
def set_message(self):
self.state.message = "Original message"
self.state.counter = 42
if self.state.message == "":
self.state.message = "Original message"
if self.state.counter == 0:
self.state.counter = 42

# Create and persist initial state
flow1 = RestorableFlow(persistence=persistence)
flow1.kickoff()
original_uuid = flow1.state.id

# Test case 1: Restore using restore_uuid with field override
flow2 = RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
counter=43, # Override counter
)
flow2 = RestorableFlow(persistence=persistence)
flow2.kickoff(inputs={
"id": original_uuid,
"counter": 43
})

# Verify state restoration and selective field override
assert flow2.state.id == original_uuid
assert flow2.state.message == "Original message" # Preserved
assert flow2.state.counter == 43 # Overridden

# Test case 2: Restore using kwargs['id']
flow3 = RestorableFlow(
persistence=persistence,
id=original_uuid,
message="Updated message", # Override message
)
flow3 = RestorableFlow(persistence=persistence)
flow3.kickoff(inputs={
"id": original_uuid,
"message": "Updated message"
})

# Verify state restoration and selective field override
assert flow3.state.id == original_uuid
assert flow3.state.counter == 42 # Preserved
assert flow3.state.counter == 43 # Preserved
assert flow3.state.message == "Updated message" # Overridden

# Test case 3: Verify error on conflicting IDs
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
id="different-id", # Conflict with restore_uuid
)
assert "Conflicting IDs provided" in str(exc_info.value)

# Test case 4: Verify error on non-existent restore_uuid
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid="non-existent-uuid",
)
assert "No state found" in str(exc_info.value)

# Test case 5: Allow new state creation with kwargs['id']
new_uuid = "new-flow-id"
flow4 = RestorableFlow(
persistence=persistence,
id=new_uuid,
message="New message",
counter=100,
)

# Verify new state creation with provided ID
assert flow4.state.id == new_uuid
assert flow4.state.message == "New message"
assert flow4.state.counter == 100


def test_multiple_method_persistence(tmp_path):
"""Test state persistence across multiple method executions."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

class MultiStepFlow(Flow[TestState]):
initial_state = TestState

@start()
@persist(persistence)
def step_1(self):
self.state.counter = 1
self.state.message = "Step 1"

@start()
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"

@listen(step_1)
@persist(persistence)
def step_2(self):
self.state.counter = 2
self.state.message = "Step 2"
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"

flow = MultiStepFlow(persistence=persistence)
flow.kickoff()

flow2 = MultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})

# Load final state
final_state = persistence.load_state(flow.state.id)
final_state = flow2.state
assert final_state is not None
assert final_state["counter"] == 2
assert final_state["message"] == "Step 2"


def test_persistence_error_handling(tmp_path):
"""Test error handling in persistence operations."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

class InvalidFlow(Flow[TestState]):
# Missing id field in initial state
class InvalidState(BaseModel):
value: str = ""

initial_state = InvalidState
assert final_state.counter == 2
assert final_state.message == "Step 2"

class NoPersistenceMultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def will_fail(self):
self.state.value = "test"
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"

@listen(step_1)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"

flow = NoPersistenceMultiStepFlow(persistence=persistence)
flow.kickoff()

with pytest.raises(ValueError) as exc_info:
flow = InvalidFlow(persistence=persistence)
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})

assert "must have an 'id' field" in str(exc_info.value)
# Load final state
final_state = flow2.state
assert final_state.counter == 99999
assert final_state.message == "Step 99999"

0 comments on commit 85bac81

Please sign in to comment.