Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateful flows #1931

Merged
merged 15 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 48 additions & 72 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,12 @@ class _FlowGeneric(cls): # type: ignore
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
restore_uuid: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.

Args:
persistence: Optional persistence backend for storing flow states
restore_uuid: Optional UUID to restore state from persistence
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
Expand All @@ -464,64 +462,12 @@ def __init__(
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence

# Validate state model before initialization
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, BaseModel) and not issubclass(
self.initial_state, FlowState
):
# Check if model has id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")

# Handle persistence and potential ID conflicts
stored_state = None
if self._persistence is not None:
if (
restore_uuid
and kwargs
and "id" in kwargs
and restore_uuid != kwargs["id"]
):
raise ValueError(
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
)
# Initialize state with initial values
self._state = self._create_initial_state()

# Attempt to load state, prioritizing restore_uuid
if restore_uuid:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
stored_state = self._persistence.load_state(restore_uuid)
if not stored_state:
raise ValueError(
f"No state found for restore_uuid='{restore_uuid}'"
)
elif kwargs and "id" in kwargs:
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
stored_state = self._persistence.load_state(kwargs["id"])
if not stored_state:
# For kwargs["id"], we allow creating new state if not found
self._state = self._create_initial_state()
if kwargs:
self._initialize_state(kwargs)
return

# Initialize state based on persistence and kwargs
if stored_state:
# Create initial state and restore from persistence
self._state = self._create_initial_state()
self._restore_state(stored_state)
# Apply any additional kwargs to override specific fields
if kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"}
if filtered_kwargs:
self._initialize_state(filtered_kwargs)
else:
# No stored state, create new state with initial values
self._state = self._create_initial_state()
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)

self._telemetry.flow_creation_span(self.__class__.__name__)

Expand Down Expand Up @@ -635,18 +581,18 @@ def method_outputs(self) -> List[Any]:
@property
def flow_id(self) -> str:
"""Returns the unique identifier of this flow instance.

This property provides a consistent way to access the flow's unique identifier
regardless of the underlying state implementation (dict or BaseModel).

Returns:
str: The flow's unique identifier, or an empty string if not found

Note:
This property safely handles both dictionary and BaseModel state types,
returning an empty string if the ID cannot be retrieved rather than raising
an exception.

Example:
```python
flow = MyFlow()
Expand All @@ -656,7 +602,7 @@ def flow_id(self) -> str:
try:
if not hasattr(self, '_state'):
return ""

if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
Expand Down Expand Up @@ -731,7 +677,6 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")

Expand All @@ -755,17 +700,48 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")

def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""Start the flow execution.

Args:
inputs: Optional dictionary containing input values and potentially a state ID to restore
"""
# Handle state restoration if ID is provided in inputs
if inputs and 'id' in inputs and self._persistence is not None:
restore_uuid = inputs['id']
stored_state = self._persistence.load_state(restore_uuid)

# Override the id in the state if it exists in inputs
if 'id' in inputs:
if isinstance(self._state, dict):
self._state['id'] = inputs['id']
elif isinstance(self._state, BaseModel):
setattr(self._state, 'id', inputs['id'])

if stored_state:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")

# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
if filtered_inputs:
self._initialize_state(filtered_inputs)

# Start flow execution
self.event_emitter.send(
self,
event=FlowStartedEvent(
type="flow_started",
flow_name=self.__class__.__name__,
),
)
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")

if inputs is not None:
if inputs is not None and 'id' not in inputs:
self._initialize_state(inputs)

return asyncio.run(self.kickoff_async())

async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
Expand Down Expand Up @@ -1010,18 +986,18 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non

def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
"""Centralized logging method for flow events.

This method provides a consistent interface for logging flow-related events,
combining both console output with colors and proper logging levels.

Args:
message: The message to log
color: Color to use for console output (default: yellow)
Available colors: purple, red, bold_green, bold_purple,
bold_blue, yellow, bold_yellow
bold_blue, yellow, yellow
level: Log level to use (default: info)
Supported levels: info, warning

Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
Expand All @@ -1031,7 +1007,7 @@ def _log_flow_event(self, message: str, color: str = "yellow", level: str = "inf
logger.info(message)
elif level == "warning":
logger.warning(message)

def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
Expand Down
120 changes: 72 additions & 48 deletions src/crewai/flow/persistence/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,57 +54,44 @@ async def async_method(self):

class PersistenceDecorator:
"""Class to handle flow state persistence with consistent logging."""

_printer = Printer() # Class-level printer instance

@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Persist flow state with proper error handling and logging.

This method handles the persistence of flow state data, including proper
error handling and colored console output for status updates.

Args:
flow_instance: The flow instance whose state to persist
method_name: Name of the method that triggered persistence
persistence_instance: The persistence backend to use

Raises:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes

Note:
Uses bold_yellow color for success messages and red for errors.
All operations are logged at appropriate levels (info/error).

Example:
```python
@persist
def my_flow_method(self):
# Method implementation
pass
# State will be automatically persisted after method execution
```
"""
try:
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)

if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")

# Log state saving with consistent message
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

try:
persistence_instance.save_state(
flow_uuid=flow_uuid,
Expand Down Expand Up @@ -154,44 +141,79 @@ class MyFlow(Flow[MyState]):
def begin(self):
pass
"""

def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()

if isinstance(target, type):
# Class decoration
class_methods = {}
original_init = getattr(target, "__init__")

@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
original_init(self, *args, **kwargs)

setattr(target, "__init__", new_init)

# Store original methods to preserve their decorators
original_methods = {}

for name, method in target.__dict__.items():
if callable(method) and hasattr(method, "__is_flow_method__"):
# Wrap each flow method with persistence
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
method_coro = method(self, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
):
original_methods[name] = method

# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
class_methods[name] = class_async_wrapper
else:
@functools.wraps(method)
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return method_wrapper

wrapped = create_async_wrapper(name, method)

# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)

# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
class_methods[name] = class_sync_wrapper
return method_wrapper

wrapped = create_sync_wrapper(name, method)

# Preserve flow-specific attributes
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(class_methods[name], attr, getattr(method, attr))
setattr(class_methods[name], "__is_flow_method__", True)
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)

# Update the class with the wrapped method
setattr(target, name, wrapped)

# Update class with wrapped methods
for name, method in class_methods.items():
setattr(target, name, method)
return target
else:
# Method decoration
Expand All @@ -208,6 +230,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) ->
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
Expand All @@ -219,6 +242,7 @@ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
Expand Down
Loading
Loading