Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Kking112 authored Jan 20, 2025
2 parents ab8c269 + ab2274c commit d7c9c40
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 220 deletions.
120 changes: 48 additions & 72 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,12 @@ class _FlowGeneric(cls): # type: ignore
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
restore_uuid: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
Args:
persistence: Optional persistence backend for storing flow states
restore_uuid: Optional UUID to restore state from persistence
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
Expand All @@ -464,64 +462,12 @@ def __init__(
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence

# Validate state model before initialization
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, BaseModel) and not issubclass(
self.initial_state, FlowState
):
# Check if model has id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")

# Handle persistence and potential ID conflicts
stored_state = None
if self._persistence is not None:
if (
restore_uuid
and kwargs
and "id" in kwargs
and restore_uuid != kwargs["id"]
):
raise ValueError(
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
)
# Initialize state with initial values
self._state = self._create_initial_state()

# Attempt to load state, prioritizing restore_uuid
if restore_uuid:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
stored_state = self._persistence.load_state(restore_uuid)
if not stored_state:
raise ValueError(
f"No state found for restore_uuid='{restore_uuid}'"
)
elif kwargs and "id" in kwargs:
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
stored_state = self._persistence.load_state(kwargs["id"])
if not stored_state:
# For kwargs["id"], we allow creating new state if not found
self._state = self._create_initial_state()
if kwargs:
self._initialize_state(kwargs)
return

# Initialize state based on persistence and kwargs
if stored_state:
# Create initial state and restore from persistence
self._state = self._create_initial_state()
self._restore_state(stored_state)
# Apply any additional kwargs to override specific fields
if kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"}
if filtered_kwargs:
self._initialize_state(filtered_kwargs)
else:
# No stored state, create new state with initial values
self._state = self._create_initial_state()
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)

self._telemetry.flow_creation_span(self.__class__.__name__)

Expand Down Expand Up @@ -635,18 +581,18 @@ def method_outputs(self) -> List[Any]:
@property
def flow_id(self) -> str:
"""Returns the unique identifier of this flow instance.
This property provides a consistent way to access the flow's unique identifier
regardless of the underlying state implementation (dict or BaseModel).
Returns:
str: The flow's unique identifier, or an empty string if not found
Note:
This property safely handles both dictionary and BaseModel state types,
returning an empty string if the ID cannot be retrieved rather than raising
an exception.
Example:
```python
flow = MyFlow()
Expand All @@ -656,7 +602,7 @@ def flow_id(self) -> str:
try:
if not hasattr(self, '_state'):
return ""

if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
Expand Down Expand Up @@ -731,7 +677,6 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")

Expand All @@ -755,17 +700,48 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")

def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""Start the flow execution.
Args:
inputs: Optional dictionary containing input values and potentially a state ID to restore
"""
# Handle state restoration if ID is provided in inputs
if inputs and 'id' in inputs and self._persistence is not None:
restore_uuid = inputs['id']
stored_state = self._persistence.load_state(restore_uuid)

# Override the id in the state if it exists in inputs
if 'id' in inputs:
if isinstance(self._state, dict):
self._state['id'] = inputs['id']
elif isinstance(self._state, BaseModel):
setattr(self._state, 'id', inputs['id'])

if stored_state:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")

# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
if filtered_inputs:
self._initialize_state(filtered_inputs)

# Start flow execution
self.event_emitter.send(
self,
event=FlowStartedEvent(
type="flow_started",
flow_name=self.__class__.__name__,
),
)
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")

if inputs is not None:
if inputs is not None and 'id' not in inputs:
self._initialize_state(inputs)

return asyncio.run(self.kickoff_async())

async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
Expand Down Expand Up @@ -1010,18 +986,18 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non

def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
"""Centralized logging method for flow events.
This method provides a consistent interface for logging flow-related events,
combining both console output with colors and proper logging levels.
Args:
message: The message to log
color: Color to use for console output (default: yellow)
Available colors: purple, red, bold_green, bold_purple,
bold_blue, yellow, bold_yellow
bold_blue, yellow, yellow
level: Log level to use (default: info)
Supported levels: info, warning
Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
Expand All @@ -1031,7 +1007,7 @@ def _log_flow_event(self, message: str, color: str = "yellow", level: str = "inf
logger.info(message)
elif level == "warning":
logger.warning(message)

def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
Expand Down
120 changes: 72 additions & 48 deletions src/crewai/flow/persistence/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,57 +54,44 @@ async def async_method(self):

class PersistenceDecorator:
"""Class to handle flow state persistence with consistent logging."""

_printer = Printer() # Class-level printer instance

@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper
error handling and colored console output for status updates.
Args:
flow_instance: The flow instance whose state to persist
method_name: Name of the method that triggered persistence
persistence_instance: The persistence backend to use
Raises:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes
Note:
Uses bold_yellow color for success messages and red for errors.
All operations are logged at appropriate levels (info/error).
Example:
```python
@persist
def my_flow_method(self):
# Method implementation
pass
# State will be automatically persisted after method execution
```
"""
try:
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)

if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")

# Log state saving with consistent message
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

try:
persistence_instance.save_state(
flow_uuid=flow_uuid,
Expand Down Expand Up @@ -154,44 +141,79 @@ class MyFlow(Flow[MyState]):
def begin(self):
pass
"""

def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()

if isinstance(target, type):
# Class decoration
class_methods = {}
original_init = getattr(target, "__init__")

@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
original_init(self, *args, **kwargs)

setattr(target, "__init__", new_init)

# Store original methods to preserve their decorators
original_methods = {}

for name, method in target.__dict__.items():
if callable(method) and hasattr(method, "__is_flow_method__"):
# Wrap each flow method with persistence
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
method_coro = method(self, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
):
original_methods[name] = method

# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
class_methods[name] = class_async_wrapper
else:
@functools.wraps(method)
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return method_wrapper

wrapped = create_async_wrapper(name, method)

# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)

# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
class_methods[name] = class_sync_wrapper
return method_wrapper

wrapped = create_sync_wrapper(name, method)

# Preserve flow-specific attributes
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(class_methods[name], attr, getattr(method, attr))
setattr(class_methods[name], "__is_flow_method__", True)
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)

# Update the class with the wrapped method
setattr(target, name, wrapped)

# Update class with wrapped methods
for name, method in class_methods.items():
setattr(target, name, method)
return target
else:
# Method decoration
Expand All @@ -208,6 +230,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) ->
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
Expand All @@ -219,6 +242,7 @@ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
Expand Down
Loading

0 comments on commit d7c9c40

Please sign in to comment.