diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 6dcc37e403..b744ba6add 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -447,14 +447,12 @@ class _FlowGeneric(cls): # type: ignore def __init__( self, persistence: Optional[FlowPersistence] = None, - restore_uuid: Optional[str] = None, **kwargs: Any, ) -> None: """Initialize a new Flow instance. Args: persistence: Optional persistence backend for storing flow states - restore_uuid: Optional UUID to restore state from persistence **kwargs: Additional state values to initialize or override """ # Initialize basic instance attributes @@ -464,64 +462,12 @@ def __init__( self._method_outputs: List[Any] = [] # List to store all method outputs self._persistence: Optional[FlowPersistence] = persistence - # Validate state model before initialization - if isinstance(self.initial_state, type): - if issubclass(self.initial_state, BaseModel) and not issubclass( - self.initial_state, FlowState - ): - # Check if model has id field - model_fields = getattr(self.initial_state, "model_fields", None) - if not model_fields or "id" not in model_fields: - raise ValueError("Flow state model must have an 'id' field") - - # Handle persistence and potential ID conflicts - stored_state = None - if self._persistence is not None: - if ( - restore_uuid - and kwargs - and "id" in kwargs - and restore_uuid != kwargs["id"] - ): - raise ValueError( - f"Conflicting IDs provided: restore_uuid='{restore_uuid}' " - f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration." - ) + # Initialize state with initial values + self._state = self._create_initial_state() - # Attempt to load state, prioritizing restore_uuid - if restore_uuid: - self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow") - stored_state = self._persistence.load_state(restore_uuid) - if not stored_state: - raise ValueError( - f"No state found for restore_uuid='{restore_uuid}'" - ) - elif kwargs and "id" in kwargs: - self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow") - stored_state = self._persistence.load_state(kwargs["id"]) - if not stored_state: - # For kwargs["id"], we allow creating new state if not found - self._state = self._create_initial_state() - if kwargs: - self._initialize_state(kwargs) - return - - # Initialize state based on persistence and kwargs - if stored_state: - # Create initial state and restore from persistence - self._state = self._create_initial_state() - self._restore_state(stored_state) - # Apply any additional kwargs to override specific fields - if kwargs: - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"} - if filtered_kwargs: - self._initialize_state(filtered_kwargs) - else: - # No stored state, create new state with initial values - self._state = self._create_initial_state() - # Apply any additional kwargs - if kwargs: - self._initialize_state(kwargs) + # Apply any additional kwargs + if kwargs: + self._initialize_state(kwargs) self._telemetry.flow_creation_span(self.__class__.__name__) @@ -635,18 +581,18 @@ def method_outputs(self) -> List[Any]: @property def flow_id(self) -> str: """Returns the unique identifier of this flow instance. - + This property provides a consistent way to access the flow's unique identifier regardless of the underlying state implementation (dict or BaseModel). - + Returns: str: The flow's unique identifier, or an empty string if not found - + Note: This property safely handles both dictionary and BaseModel state types, returning an empty string if the ID cannot be retrieved rather than raising an exception. - + Example: ```python flow = MyFlow() @@ -656,7 +602,7 @@ def flow_id(self) -> str: try: if not hasattr(self, '_state'): return "" - + if isinstance(self._state, dict): return str(self._state.get("id", "")) elif isinstance(self._state, BaseModel): @@ -731,7 +677,6 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None: """ # When restoring from persistence, use the stored ID stored_id = stored_state.get("id") - self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow") if not stored_id: raise ValueError("Stored state must have an 'id' field") @@ -755,6 +700,36 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None: raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}") def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: + """Start the flow execution. + + Args: + inputs: Optional dictionary containing input values and potentially a state ID to restore + """ + # Handle state restoration if ID is provided in inputs + if inputs and 'id' in inputs and self._persistence is not None: + restore_uuid = inputs['id'] + stored_state = self._persistence.load_state(restore_uuid) + + # Override the id in the state if it exists in inputs + if 'id' in inputs: + if isinstance(self._state, dict): + self._state['id'] = inputs['id'] + elif isinstance(self._state, BaseModel): + setattr(self._state, 'id', inputs['id']) + + if stored_state: + self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow") + # Restore the state + self._restore_state(stored_state) + else: + self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red") + + # Apply any additional inputs after restoration + filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'} + if filtered_inputs: + self._initialize_state(filtered_inputs) + + # Start flow execution self.event_emitter.send( self, event=FlowStartedEvent( @@ -762,10 +737,11 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: flow_name=self.__class__.__name__, ), ) - self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow") + self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta") - if inputs is not None: + if inputs is not None and 'id' not in inputs: self._initialize_state(inputs) + return asyncio.run(self.kickoff_async()) async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: @@ -1010,18 +986,18 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None: """Centralized logging method for flow events. - + This method provides a consistent interface for logging flow-related events, combining both console output with colors and proper logging levels. - + Args: message: The message to log color: Color to use for console output (default: yellow) Available colors: purple, red, bold_green, bold_purple, - bold_blue, yellow, bold_yellow + bold_blue, yellow, yellow level: Log level to use (default: info) Supported levels: info, warning - + Note: This method uses the Printer utility for colored console output and the standard logging module for log level support. @@ -1031,7 +1007,7 @@ def _log_flow_event(self, message: str, color: str = "yellow", level: str = "inf logger.info(message) elif level == "warning": logger.warning(message) - + def plot(self, filename: str = "crewai_flow") -> None: self._telemetry.flow_plotting_span( self.__class__.__name__, list(self._methods.keys()) diff --git a/src/crewai/flow/persistence/decorators.py b/src/crewai/flow/persistence/decorators.py index 58cf1e1112..ebf3778b73 100644 --- a/src/crewai/flow/persistence/decorators.py +++ b/src/crewai/flow/persistence/decorators.py @@ -54,57 +54,44 @@ async def async_method(self): class PersistenceDecorator: """Class to handle flow state persistence with consistent logging.""" - + _printer = Printer() # Class-level printer instance - + @classmethod def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None: """Persist flow state with proper error handling and logging. - + This method handles the persistence of flow state data, including proper error handling and colored console output for status updates. - + Args: flow_instance: The flow instance whose state to persist method_name: Name of the method that triggered persistence persistence_instance: The persistence backend to use - + Raises: ValueError: If flow has no state or state lacks an ID RuntimeError: If state persistence fails AttributeError: If flow instance lacks required state attributes - - Note: - Uses bold_yellow color for success messages and red for errors. - All operations are logged at appropriate levels (info/error). - - Example: - ```python - @persist - def my_flow_method(self): - # Method implementation - pass - # State will be automatically persisted after method execution - ``` """ try: state = getattr(flow_instance, 'state', None) if state is None: raise ValueError("Flow instance has no state") - + flow_uuid: Optional[str] = None if isinstance(state, dict): flow_uuid = state.get('id') elif isinstance(state, BaseModel): flow_uuid = getattr(state, 'id', None) - + if not flow_uuid: raise ValueError("Flow state must have an 'id' field for persistence") - + # Log state saving with consistent message - cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow") + cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan") logger.info(LOG_MESSAGES["save_state"].format(flow_uuid)) - + try: persistence_instance.save_state( flow_uuid=flow_uuid, @@ -154,44 +141,79 @@ class MyFlow(Flow[MyState]): def begin(self): pass """ + def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: """Decorator that handles both class and method decoration.""" actual_persistence = persistence or SQLiteFlowPersistence() if isinstance(target, type): # Class decoration - class_methods = {} + original_init = getattr(target, "__init__") + + @functools.wraps(original_init) + def new_init(self: Any, *args: Any, **kwargs: Any) -> None: + if 'persistence' not in kwargs: + kwargs['persistence'] = actual_persistence + original_init(self, *args, **kwargs) + + setattr(target, "__init__", new_init) + + # Store original methods to preserve their decorators + original_methods = {} + for name, method in target.__dict__.items(): - if callable(method) and hasattr(method, "__is_flow_method__"): - # Wrap each flow method with persistence - if asyncio.iscoroutinefunction(method): - @functools.wraps(method) - async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - method_coro = method(self, *args, **kwargs) - if asyncio.iscoroutine(method_coro): - result = await method_coro - else: - result = method_coro - PersistenceDecorator.persist_state(self, method.__name__, actual_persistence) + if callable(method) and ( + hasattr(method, "__is_start_method__") or + hasattr(method, "__trigger_methods__") or + hasattr(method, "__condition_type__") or + hasattr(method, "__is_flow_method__") or + hasattr(method, "__is_router__") + ): + original_methods[name] = method + + # Create wrapped versions of the methods that include persistence + for name, method in original_methods.items(): + if asyncio.iscoroutinefunction(method): + # Create a closure to capture the current name and method + def create_async_wrapper(method_name: str, original_method: Callable): + @functools.wraps(original_method) + async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + result = await original_method(self, *args, **kwargs) + PersistenceDecorator.persist_state(self, method_name, actual_persistence) return result - class_methods[name] = class_async_wrapper - else: - @functools.wraps(method) - def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result = method(self, *args, **kwargs) - PersistenceDecorator.persist_state(self, method.__name__, actual_persistence) + return method_wrapper + + wrapped = create_async_wrapper(name, method) + + # Preserve all original decorators and attributes + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + if hasattr(method, attr): + setattr(wrapped, attr, getattr(method, attr)) + setattr(wrapped, "__is_flow_method__", True) + + # Update the class with the wrapped method + setattr(target, name, wrapped) + else: + # Create a closure to capture the current name and method + def create_sync_wrapper(method_name: str, original_method: Callable): + @functools.wraps(original_method) + def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + result = original_method(self, *args, **kwargs) + PersistenceDecorator.persist_state(self, method_name, actual_persistence) return result - class_methods[name] = class_sync_wrapper + return method_wrapper + + wrapped = create_sync_wrapper(name, method) - # Preserve flow-specific attributes + # Preserve all original decorators and attributes for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): - setattr(class_methods[name], attr, getattr(method, attr)) - setattr(class_methods[name], "__is_flow_method__", True) + setattr(wrapped, attr, getattr(method, attr)) + setattr(wrapped, "__is_flow_method__", True) + + # Update the class with the wrapped method + setattr(target, name, wrapped) - # Update class with wrapped methods - for name, method in class_methods.items(): - setattr(target, name, method) return target else: # Method decoration @@ -208,6 +230,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> result = method_coro PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence) return result + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(method_async_wrapper, attr, getattr(method, attr)) @@ -219,6 +242,7 @@ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: result = method(flow_instance, *args, **kwargs) PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence) return result + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(method_sync_wrapper, attr, getattr(method, attr)) diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py index bdd091b2b3..7a6f134fa2 100644 --- a/src/crewai/flow/persistence/sqlite.py +++ b/src/crewai/flow/persistence/sqlite.py @@ -3,10 +3,9 @@ """ import json -import os import sqlite3 -import tempfile from datetime import datetime +from pathlib import Path from typing import Any, Dict, Optional, Union from pydantic import BaseModel @@ -16,34 +15,34 @@ class SQLiteFlowPersistence(FlowPersistence): """SQLite-based implementation of flow state persistence. - + This class provides a simple, file-based persistence implementation using SQLite. It's suitable for development and testing, or for production use cases with moderate performance requirements. """ - + db_path: str # Type annotation for instance variable - + def __init__(self, db_path: Optional[str] = None): """Initialize SQLite persistence. - + Args: db_path: Path to the SQLite database file. If not provided, uses db_storage_path() from utilities.paths. - + Raises: ValueError: If db_path is invalid """ from crewai.utilities.paths import db_storage_path # Get path from argument or default location - path = db_path or db_storage_path() - + path = db_path or str(Path(db_storage_path()) / "flow_states.db") + if not path: raise ValueError("Database path must be provided") - + self.db_path = path # Now mypy knows this is str self.init_db() - + def init_db(self) -> None: """Create the necessary tables if they don't exist.""" with sqlite3.connect(self.db_path) as conn: @@ -58,10 +57,10 @@ def init_db(self) -> None: """) # Add index for faster UUID lookups conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_flow_states_uuid + CREATE INDEX IF NOT EXISTS idx_flow_states_uuid ON flow_states(flow_uuid) """) - + def save_state( self, flow_uuid: str, @@ -69,7 +68,7 @@ def save_state( state_data: Union[Dict[str, Any], BaseModel], ) -> None: """Save the current flow state to SQLite. - + Args: flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed @@ -84,7 +83,7 @@ def save_state( raise ValueError( f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" ) - + with sqlite3.connect(self.db_path) as conn: conn.execute(""" INSERT INTO flow_states ( @@ -99,13 +98,13 @@ def save_state( datetime.utcnow().isoformat(), json.dumps(state_dict), )) - + def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: """Load the most recent state for a given flow UUID. - + Args: flow_uuid: Unique identifier for the flow instance - + Returns: The most recent state as a dictionary, or None if no state exists """ @@ -118,7 +117,7 @@ def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: LIMIT 1 """, (flow_uuid,)) row = cursor.fetchone() - + if row: return json.loads(row[0]) return None diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index ef99e7b86e..2a035833d8 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -23,7 +23,7 @@ def __init__( ) -> None: if db_path is None: # Get the parent directory of the default db path and create our db file there - db_path = str(Path(db_storage_path()).parent / "latest_kickoff_task_outputs.db") + db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db") self.db_path = db_path self._printer: Printer = Printer() self._initialize_db() diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 3d12087229..35f54e0e77 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -17,7 +17,7 @@ def __init__( ) -> None: if db_path is None: # Get the parent directory of the default db path and create our db file there - db_path = str(Path(db_storage_path()).parent / "long_term_memory_storage.db") + db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db") self.db_path = db_path self._printer: Printer = Printer() # Ensure parent directory exists diff --git a/src/crewai/utilities/paths.py b/src/crewai/utilities/paths.py index 5d91d1719e..853c612c34 100644 --- a/src/crewai/utilities/paths.py +++ b/src/crewai/utilities/paths.py @@ -7,7 +7,7 @@ def db_storage_path() -> str: """Returns the path for SQLite database storage. - + Returns: str: Full path to the SQLite database file """ @@ -16,7 +16,7 @@ def db_storage_path() -> str: data_dir = Path(appdirs.user_data_dir(app_name, app_author)) data_dir.mkdir(parents=True, exist_ok=True) - return str(data_dir / "crewai_flows.db") + return str(data_dir) def get_project_directory_name(): @@ -28,4 +28,4 @@ def get_project_directory_name(): else: cwd = Path.cwd() project_directory_name = cwd.name - return project_directory_name + return project_directory_name \ No newline at end of file diff --git a/src/crewai/utilities/printer.py b/src/crewai/utilities/printer.py index abebf6aae8..74ad9a30b4 100644 --- a/src/crewai/utilities/printer.py +++ b/src/crewai/utilities/printer.py @@ -21,6 +21,16 @@ def print(self, content: str, color: Optional[str] = None): self._print_yellow(content) elif color == "bold_yellow": self._print_bold_yellow(content) + elif color == "cyan": + self._print_cyan(content) + elif color == "bold_cyan": + self._print_bold_cyan(content) + elif color == "magenta": + self._print_magenta(content) + elif color == "bold_magenta": + self._print_bold_magenta(content) + elif color == "green": + self._print_green(content) else: print(content) @@ -44,3 +54,18 @@ def _print_yellow(self, content): def _print_bold_yellow(self, content): print("\033[1m\033[93m {}\033[00m".format(content)) + + def _print_cyan(self, content): + print("\033[96m {}\033[00m".format(content)) + + def _print_bold_cyan(self, content): + print("\033[1m\033[96m {}\033[00m".format(content)) + + def _print_magenta(self, content): + print("\033[35m {}\033[00m".format(content)) + + def _print_bold_magenta(self, content): + print("\033[1m\033[35m {}\033[00m".format(content)) + + def _print_green(self, content): + print("\033[32m {}\033[00m".format(content)) diff --git a/tests/test_flow_default_override.py b/tests/test_flow_default_override.py new file mode 100644 index 0000000000..f11b779821 --- /dev/null +++ b/tests/test_flow_default_override.py @@ -0,0 +1,112 @@ +"""Test that persisted state properly overrides default values.""" + +from crewai.flow.flow import Flow, FlowState, listen, start +from crewai.flow.persistence import persist + + +class PoemState(FlowState): + """Test state model with default values that should be overridden.""" + sentence_count: int = 1000 # Default that should be overridden + has_set_count: bool = False # Track whether we've set the count + poem_type: str = "" + + +def test_default_value_override(): + """Test that persisted state values override class defaults.""" + + @persist() + class PoemFlow(Flow[PoemState]): + initial_state = PoemState + + @start() + def set_sentence_count(self): + if self.state.has_set_count and self.state.sentence_count == 2: + self.state.sentence_count = 3 + + elif self.state.has_set_count and self.state.sentence_count == 1000: + self.state.sentence_count = 1000 + + elif self.state.has_set_count and self.state.sentence_count == 5: + self.state.sentence_count = 5 + + else: + self.state.sentence_count = 2 + self.state.has_set_count = True + + # First run - should set sentence_count to 2 + flow1 = PoemFlow() + flow1.kickoff() + original_uuid = flow1.state.id + assert flow1.state.sentence_count == 2 + + # Second run - should load sentence_count=2 instead of default 1000 + flow2 = PoemFlow() + flow2.kickoff(inputs={"id": original_uuid}) + assert flow2.state.sentence_count == 3 # Should load 2, not default 1000 + + # Fourth run - explicit override should work + flow3 = PoemFlow() + flow3.kickoff(inputs={ + "id": original_uuid, + "has_set_count": True, + "sentence_count": 5, # Override persisted value + }) + assert flow3.state.sentence_count == 5 # Should use override value + + # Third run - should not load sentence_count=2 instead of default 1000 + flow4 = PoemFlow() + flow4.kickoff(inputs={"has_set_count": True}) + assert flow4.state.sentence_count == 1000 # Should load 1000, not 2 + + +def test_multi_step_default_override(): + """Test default value override with multiple start methods.""" + + @persist() + class MultiStepPoemFlow(Flow[PoemState]): + initial_state = PoemState + + @start() + def set_sentence_count(self): + print("Setting sentence count") + if not self.state.has_set_count: + self.state.sentence_count = 3 + self.state.has_set_count = True + + @listen(set_sentence_count) + def set_poem_type(self): + print("Setting poem type") + if self.state.sentence_count == 3: + self.state.poem_type = "haiku" + elif self.state.sentence_count == 5: + self.state.poem_type = "limerick" + else: + self.state.poem_type = "free_verse" + + @listen(set_poem_type) + def finished(self): + print("finished") + + # First run - should set both sentence count and poem type + flow1 = MultiStepPoemFlow() + flow1.kickoff() + original_uuid = flow1.state.id + assert flow1.state.sentence_count == 3 + assert flow1.state.poem_type == "haiku" + + # Second run - should load persisted state and update poem type + flow2 = MultiStepPoemFlow() + flow2.kickoff(inputs={ + "id": original_uuid, + "sentence_count": 5 + }) + assert flow2.state.sentence_count == 5 + assert flow2.state.poem_type == "limerick" + + # Third run - new flow without persisted state should use defaults + flow3 = MultiStepPoemFlow() + flow3.kickoff(inputs={ + "id": original_uuid + }) + assert flow3.state.sentence_count == 5 + assert flow3.state.poem_type == "limerick" \ No newline at end of file diff --git a/tests/test_flow_persistence.py b/tests/test_flow_persistence.py index 0d1cfe3c38..e51806b058 100644 --- a/tests/test_flow_persistence.py +++ b/tests/test_flow_persistence.py @@ -1,12 +1,12 @@ """Test flow state persistence functionality.""" import os -from typing import Dict, Optional +from typing import Dict import pytest from pydantic import BaseModel -from crewai.flow.flow import Flow, FlowState, start +from crewai.flow.flow import Flow, FlowState, listen, start from crewai.flow.persistence import persist from crewai.flow.persistence.sqlite import SQLiteFlowPersistence @@ -73,13 +73,14 @@ def test_flow_state_restoration(tmp_path): # First flow execution to create initial state class RestorableFlow(Flow[TestState]): - initial_state = TestState @start() @persist(persistence) def set_message(self): - self.state.message = "Original message" - self.state.counter = 42 + if self.state.message == "": + self.state.message = "Original message" + if self.state.counter == 0: + self.state.counter = 42 # Create and persist initial state flow1 = RestorableFlow(persistence=persistence) @@ -87,11 +88,11 @@ def set_message(self): original_uuid = flow1.state.id # Test case 1: Restore using restore_uuid with field override - flow2 = RestorableFlow( - persistence=persistence, - restore_uuid=original_uuid, - counter=43, # Override counter - ) + flow2 = RestorableFlow(persistence=persistence) + flow2.kickoff(inputs={ + "id": original_uuid, + "counter": 43 + }) # Verify state restoration and selective field override assert flow2.state.id == original_uuid @@ -99,48 +100,17 @@ def set_message(self): assert flow2.state.counter == 43 # Overridden # Test case 2: Restore using kwargs['id'] - flow3 = RestorableFlow( - persistence=persistence, - id=original_uuid, - message="Updated message", # Override message - ) + flow3 = RestorableFlow(persistence=persistence) + flow3.kickoff(inputs={ + "id": original_uuid, + "message": "Updated message" + }) # Verify state restoration and selective field override assert flow3.state.id == original_uuid - assert flow3.state.counter == 42 # Preserved + assert flow3.state.counter == 43 # Preserved assert flow3.state.message == "Updated message" # Overridden - # Test case 3: Verify error on conflicting IDs - with pytest.raises(ValueError) as exc_info: - RestorableFlow( - persistence=persistence, - restore_uuid=original_uuid, - id="different-id", # Conflict with restore_uuid - ) - assert "Conflicting IDs provided" in str(exc_info.value) - - # Test case 4: Verify error on non-existent restore_uuid - with pytest.raises(ValueError) as exc_info: - RestorableFlow( - persistence=persistence, - restore_uuid="non-existent-uuid", - ) - assert "No state found" in str(exc_info.value) - - # Test case 5: Allow new state creation with kwargs['id'] - new_uuid = "new-flow-id" - flow4 = RestorableFlow( - persistence=persistence, - id=new_uuid, - message="New message", - counter=100, - ) - - # Verify new state creation with provided ID - assert flow4.state.id == new_uuid - assert flow4.state.message == "New message" - assert flow4.state.counter == 100 - def test_multiple_method_persistence(tmp_path): """Test state persistence across multiple method executions.""" @@ -148,48 +118,59 @@ def test_multiple_method_persistence(tmp_path): persistence = SQLiteFlowPersistence(db_path) class MultiStepFlow(Flow[TestState]): - initial_state = TestState - @start() @persist(persistence) def step_1(self): - self.state.counter = 1 - self.state.message = "Step 1" - - @start() + if self.state.counter == 1: + self.state.counter = 99999 + self.state.message = "Step 99999" + else: + self.state.counter = 1 + self.state.message = "Step 1" + + @listen(step_1) @persist(persistence) def step_2(self): - self.state.counter = 2 - self.state.message = "Step 2" + if self.state.counter == 1: + self.state.counter = 2 + self.state.message = "Step 2" flow = MultiStepFlow(persistence=persistence) flow.kickoff() + flow2 = MultiStepFlow(persistence=persistence) + flow2.kickoff(inputs={"id": flow.state.id}) + # Load final state - final_state = persistence.load_state(flow.state.id) + final_state = flow2.state assert final_state is not None - assert final_state["counter"] == 2 - assert final_state["message"] == "Step 2" - - -def test_persistence_error_handling(tmp_path): - """Test error handling in persistence operations.""" - db_path = os.path.join(tmp_path, "test_flows.db") - persistence = SQLiteFlowPersistence(db_path) - - class InvalidFlow(Flow[TestState]): - # Missing id field in initial state - class InvalidState(BaseModel): - value: str = "" - - initial_state = InvalidState + assert final_state.counter == 2 + assert final_state.message == "Step 2" + class NoPersistenceMultiStepFlow(Flow[TestState]): @start() @persist(persistence) - def will_fail(self): - self.state.value = "test" + def step_1(self): + if self.state.counter == 1: + self.state.counter = 99999 + self.state.message = "Step 99999" + else: + self.state.counter = 1 + self.state.message = "Step 1" + + @listen(step_1) + def step_2(self): + if self.state.counter == 1: + self.state.counter = 2 + self.state.message = "Step 2" + + flow = NoPersistenceMultiStepFlow(persistence=persistence) + flow.kickoff() - with pytest.raises(ValueError) as exc_info: - flow = InvalidFlow(persistence=persistence) + flow2 = NoPersistenceMultiStepFlow(persistence=persistence) + flow2.kickoff(inputs={"id": flow.state.id}) - assert "must have an 'id' field" in str(exc_info.value) + # Load final state + final_state = flow2.state + assert final_state.counter == 99999 + assert final_state.message == "Step 99999"