Skip to content

Commit

Permalink
Merge branch 'main' into devin/1737272386-flow-override-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
joaomdmoura authored Jan 19, 2025
2 parents d3b6734 + 3e4f112 commit 48b0f3b
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 35 deletions.
66 changes: 66 additions & 0 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import logging
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -28,6 +29,9 @@
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
from crewai.utilities.printer import Printer

logger = logging.getLogger(__name__)


class FlowState(BaseModel):
Expand Down Expand Up @@ -424,6 +428,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""

_telemetry = Telemetry()
_printer = Printer()

_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
Expand Down Expand Up @@ -485,12 +490,14 @@ def __init__(

# Attempt to load state, prioritizing restore_uuid
if restore_uuid:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
stored_state = self._persistence.load_state(restore_uuid)
if not stored_state:
raise ValueError(
f"No state found for restore_uuid='{restore_uuid}'"
)
elif kwargs and "id" in kwargs:
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
stored_state = self._persistence.load_state(kwargs["id"])
# Don't return early if state not found - let the normal initialization flow handle it
# This ensures proper state initialization and override behavior
Expand Down Expand Up @@ -621,6 +628,39 @@ def method_outputs(self) -> List[Any]:
"""Returns the list of all outputs from executed methods."""
return self._method_outputs

@property
def flow_id(self) -> str:
"""Returns the unique identifier of this flow instance.
This property provides a consistent way to access the flow's unique identifier
regardless of the underlying state implementation (dict or BaseModel).
Returns:
str: The flow's unique identifier, or an empty string if not found
Note:
This property safely handles both dictionary and BaseModel state types,
returning an empty string if the ID cannot be retrieved rather than raising
an exception.
Example:
```python
flow = MyFlow()
print(f"Current flow ID: {flow.flow_id}") # Safely get flow ID
```
"""
try:
if not hasattr(self, '_state'):
return ""

if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
return str(getattr(self._state, "id", ""))
return ""
except (AttributeError, TypeError):
return "" # Safely handle any unexpected attribute access issues

def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.
Expand Down Expand Up @@ -687,6 +727,7 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")

Expand Down Expand Up @@ -717,6 +758,7 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
flow_name=self.__class__.__name__,
),
)
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")

if inputs is not None:
self._initialize_state(inputs)
Expand Down Expand Up @@ -962,6 +1004,30 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non

traceback.print_exc()

def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
"""Centralized logging method for flow events.
This method provides a consistent interface for logging flow-related events,
combining both console output with colors and proper logging levels.
Args:
message: The message to log
color: Color to use for console output (default: yellow)
Available colors: purple, red, bold_green, bold_purple,
bold_blue, yellow, bold_yellow
level: Log level to use (default: info)
Supported levels: info, warning
Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
"""
self._printer.print(message, color=color)
if level == "info":
logger.info(message)
elif level == "warning":
logger.warning(message)

def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
Expand Down
124 changes: 89 additions & 35 deletions src/crewai/flow/persistence/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,95 @@ async def async_method(self):

from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from crewai.utilities.printer import Printer

logger = logging.getLogger(__name__)
T = TypeVar("T")

# Constants for log messages
LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence"
}


class PersistenceDecorator:
"""Class to handle flow state persistence with consistent logging."""

_printer = Printer() # Class-level printer instance

@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper
error handling and colored console output for status updates.
Args:
flow_instance: The flow instance whose state to persist
method_name: Name of the method that triggered persistence
persistence_instance: The persistence backend to use
Raises:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes
Note:
Uses bold_yellow color for success messages and red for errors.
All operations are logged at appropriate levels (info/error).
Example:
```python
@persist
def my_flow_method(self):
# Method implementation
pass
# State will be automatically persisted after method execution
```
"""
try:
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)

if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")

# Log state saving with consistent message
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

try:
persistence_instance.save_state(
flow_uuid=flow_uuid,
method_name=method_name,
state_data=state,
)
except Exception as e:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e
except AttributeError:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise ValueError(error_msg)
except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise ValueError(error_msg) from e


def persist(persistence: Optional[FlowPersistence] = None):
"""Decorator to persist flow state.
Expand Down Expand Up @@ -69,37 +154,6 @@ class MyFlow(Flow[MyState]):
def begin(self):
pass
"""
def _persist_state(flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Helper to persist state with error handling."""
try:
# Get flow UUID from state
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)

if not flow_uuid:
raise ValueError(
"Flow state must have an 'id' field for persistence"
)

# Persist the state
persistence_instance.save_state(
flow_uuid=flow_uuid,
method_name=method_name,
state_data=state,
)
except Exception as e:
logger.error(
f"Failed to persist state for method {method_name}: {str(e)}"
)
raise RuntimeError(f"State persistence failed: {str(e)}") from e

def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
Expand All @@ -118,14 +172,14 @@ async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = await method_coro
else:
result = method_coro
_persist_state(self, method.__name__, actual_persistence)
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_async_wrapper
else:
@functools.wraps(method)
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = method(self, *args, **kwargs)
_persist_state(self, method.__name__, actual_persistence)
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_sync_wrapper

Expand All @@ -152,7 +206,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) ->
result = await method_coro
else:
result = method_coro
_persist_state(flow_instance, method.__name__, actual_persistence)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
Expand All @@ -163,7 +217,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) ->
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
_persist_state(flow_instance, method.__name__, actual_persistence)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
Expand Down

0 comments on commit 48b0f3b

Please sign in to comment.