diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index c31c9ed822..b744ba6add 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -721,6 +721,8 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow") # Restore the state self._restore_state(stored_state) + else: + self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red") # Apply any additional inputs after restoration filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'} diff --git a/tests/test_flow_persistence.py b/tests/test_flow_persistence.py index 0d1cfe3c38..b944050a29 100644 --- a/tests/test_flow_persistence.py +++ b/tests/test_flow_persistence.py @@ -1,12 +1,12 @@ """Test flow state persistence functionality.""" import os -from typing import Dict, Optional +from typing import Dict import pytest from pydantic import BaseModel -from crewai.flow.flow import Flow, FlowState, start +from crewai.flow.flow import Flow, FlowState, start, listen from crewai.flow.persistence import persist from crewai.flow.persistence.sqlite import SQLiteFlowPersistence @@ -73,13 +73,14 @@ def test_flow_state_restoration(tmp_path): # First flow execution to create initial state class RestorableFlow(Flow[TestState]): - initial_state = TestState @start() @persist(persistence) def set_message(self): - self.state.message = "Original message" - self.state.counter = 42 + if self.state.message == "": + self.state.message = "Original message" + if self.state.counter == 0: + self.state.counter = 42 # Create and persist initial state flow1 = RestorableFlow(persistence=persistence) @@ -87,11 +88,11 @@ def set_message(self): original_uuid = flow1.state.id # Test case 1: Restore using restore_uuid with field override - flow2 = RestorableFlow( - persistence=persistence, - restore_uuid=original_uuid, - counter=43, # Override counter - ) + flow2 = RestorableFlow(persistence=persistence) + flow2.kickoff(inputs={ + "id": original_uuid, + "counter": 43 + }) # Verify state restoration and selective field override assert flow2.state.id == original_uuid @@ -99,48 +100,17 @@ def set_message(self): assert flow2.state.counter == 43 # Overridden # Test case 2: Restore using kwargs['id'] - flow3 = RestorableFlow( - persistence=persistence, - id=original_uuid, - message="Updated message", # Override message - ) + flow3 = RestorableFlow(persistence=persistence) + flow3.kickoff(inputs={ + "id": original_uuid, + "message": "Updated message" + }) # Verify state restoration and selective field override assert flow3.state.id == original_uuid - assert flow3.state.counter == 42 # Preserved + assert flow3.state.counter == 43 # Preserved assert flow3.state.message == "Updated message" # Overridden - # Test case 3: Verify error on conflicting IDs - with pytest.raises(ValueError) as exc_info: - RestorableFlow( - persistence=persistence, - restore_uuid=original_uuid, - id="different-id", # Conflict with restore_uuid - ) - assert "Conflicting IDs provided" in str(exc_info.value) - - # Test case 4: Verify error on non-existent restore_uuid - with pytest.raises(ValueError) as exc_info: - RestorableFlow( - persistence=persistence, - restore_uuid="non-existent-uuid", - ) - assert "No state found" in str(exc_info.value) - - # Test case 5: Allow new state creation with kwargs['id'] - new_uuid = "new-flow-id" - flow4 = RestorableFlow( - persistence=persistence, - id=new_uuid, - message="New message", - counter=100, - ) - - # Verify new state creation with provided ID - assert flow4.state.id == new_uuid - assert flow4.state.message == "New message" - assert flow4.state.counter == 100 - def test_multiple_method_persistence(tmp_path): """Test state persistence across multiple method executions.""" @@ -148,48 +118,59 @@ def test_multiple_method_persistence(tmp_path): persistence = SQLiteFlowPersistence(db_path) class MultiStepFlow(Flow[TestState]): - initial_state = TestState - @start() @persist(persistence) def step_1(self): - self.state.counter = 1 - self.state.message = "Step 1" - - @start() + if self.state.counter == 1: + self.state.counter = 99999 + self.state.message = "Step 99999" + else: + self.state.counter = 1 + self.state.message = "Step 1" + + @listen(step_1) @persist(persistence) def step_2(self): - self.state.counter = 2 - self.state.message = "Step 2" + if self.state.counter == 1: + self.state.counter = 2 + self.state.message = "Step 2" flow = MultiStepFlow(persistence=persistence) flow.kickoff() + flow2 = MultiStepFlow(persistence=persistence) + flow2.kickoff(inputs={"id": flow.state.id}) + # Load final state - final_state = persistence.load_state(flow.state.id) + final_state = flow2.state assert final_state is not None - assert final_state["counter"] == 2 - assert final_state["message"] == "Step 2" - - -def test_persistence_error_handling(tmp_path): - """Test error handling in persistence operations.""" - db_path = os.path.join(tmp_path, "test_flows.db") - persistence = SQLiteFlowPersistence(db_path) - - class InvalidFlow(Flow[TestState]): - # Missing id field in initial state - class InvalidState(BaseModel): - value: str = "" - - initial_state = InvalidState + assert final_state.counter == 2 + assert final_state.message == "Step 2" + class NoPersistenceMultiStepFlow(Flow[TestState]): @start() @persist(persistence) - def will_fail(self): - self.state.value = "test" + def step_1(self): + if self.state.counter == 1: + self.state.counter = 99999 + self.state.message = "Step 99999" + else: + self.state.counter = 1 + self.state.message = "Step 1" + + @listen(step_1) + def step_2(self): + if self.state.counter == 1: + self.state.counter = 2 + self.state.message = "Step 2" + + flow = NoPersistenceMultiStepFlow(persistence=persistence) + flow.kickoff() - with pytest.raises(ValueError) as exc_info: - flow = InvalidFlow(persistence=persistence) + flow2 = NoPersistenceMultiStepFlow(persistence=persistence) + flow2.kickoff(inputs={"id": flow.state.id}) - assert "must have an 'id' field" in str(exc_info.value) + # Load final state + final_state = flow2.state + assert final_state.counter == 99999 + assert final_state.message == "Step 99999"