Skip to content

Commit

Permalink
Merge branch 'main' into devin/1737272162-colored-logging
Browse files Browse the repository at this point in the history
  • Loading branch information
joaomdmoura authored Jan 19, 2025
2 parents ed2172a + cc018bf commit b5019a8
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 66 deletions.
9 changes: 2 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,20 @@ dependencies = [
"openai>=1.13.3",
"litellm==1.57.4",
"instructor>=1.3.3",

# Text Processing
"pdfplumber>=0.11.4",
"regex>=2024.9.11",

# Telemetry and Monitoring
"opentelemetry-api>=1.22.0",
"opentelemetry-sdk>=1.22.0",
"opentelemetry-exporter-otlp-proto-http>=1.22.0",

# Data Handling
"chromadb>=0.5.23",
"openpyxl>=3.1.5",
"pyvis>=0.3.2",

# Authentication and Security
"auth0-python>=4.7.1",
"python-dotenv>=1.0.0",

# Configuration and Utils
"click>=8.1.7",
"appdirs>=1.4.4",
Expand All @@ -40,7 +35,7 @@ dependencies = [
"uv>=0.4.25",
"tomli-w>=1.1.0",
"tomli>=2.0.2",
"blinker>=1.9.0"
"blinker>=1.9.0",
]

[project.urls]
Expand All @@ -49,7 +44,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = ["crewai-tools>=0.25.5"]
tools = ["crewai-tools>=0.32.1"]
embeddings = [
"tiktoken~=0.7.0"
]
Expand Down
2 changes: 1 addition & 1 deletion src/crewai/cli/templates/flow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

from crewai.flow.flow import Flow, listen, start
from crewai.flow import Flow, listen, start

from {{folder_name}}.crews.poem_crew.poem_crew import PoemCrew

Expand Down
6 changes: 4 additions & 2 deletions src/crewai/flow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from crewai.flow.flow import Flow
from crewai.flow.flow import Flow, start, listen, or_, and_, router
from crewai.flow.persistence import persist

__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]

__all__ = ["Flow"]
7 changes: 3 additions & 4 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
ValueError: If state validation fails
"""
"""Ensure state matches expected type with proper validation.
Args:
state: State instance to validate
expected_type: Expected type for the state
Returns:
Validated state instance
Raises:
TypeError: If state doesn't match expected type
ValueError: If state validation fails
Expand Down Expand Up @@ -619,7 +619,6 @@ class StateWithId(state_type, FlowState): # type: ignore
# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))

raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
Expand Down
29 changes: 12 additions & 17 deletions src/crewai/flow/persistence/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
```python
from crewai.flow.flow import Flow, start
from crewai.flow.persistence import persist, SQLiteFlowPersistence
class MyFlow(Flow):
@start()
@persist(SQLiteFlowPersistence())
def sync_method(self):
# Synchronous method implementation
pass
@start()
@persist(SQLiteFlowPersistence())
async def async_method(self):
Expand All @@ -23,18 +23,15 @@ async def async_method(self):

import asyncio
import functools
import inspect
import logging
from typing import (
Any,
Callable,
Dict,
Optional,
Type,
TypeVar,
Union,
cast,
get_type_hints,
)

from pydantic import BaseModel
Expand Down Expand Up @@ -133,36 +130,34 @@ def my_flow_method(self):

def persist(persistence: Optional[FlowPersistence] = None):
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
When applied at the class level, it automatically persists all flow method
states. When applied at the method level, it persists only that method's
state.
Args:
persistence: Optional FlowPersistence implementation to use.
If not provided, uses SQLiteFlowPersistence.
Returns:
A decorator that can be applied to either a class or method
Raises:
ValueError: If the flow state doesn't have an 'id' field
RuntimeError: If state persistence fails
Example:
@persist # Class-level persistence with default SQLite
class MyFlow(Flow[MyState]):
@start()
def begin(self):
pass
"""
# Helper function moved to PersistenceDecorator class

def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()

if isinstance(target, type):
# Class decoration
class_methods = {}
Expand All @@ -187,13 +182,13 @@ def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_sync_wrapper

# Preserve flow-specific attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(class_methods[name], attr, getattr(method, attr))
setattr(class_methods[name], "__is_flow_method__", True)

# Update class with wrapped methods
for name, method in class_methods.items():
setattr(target, name, method)
Expand All @@ -202,7 +197,7 @@ def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)

if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
Expand All @@ -229,5 +224,5 @@ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)

return decorator
60 changes: 30 additions & 30 deletions tests/test_flow_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel

from crewai.flow.flow import Flow, FlowState, start
from crewai.flow.persistence import FlowPersistence, persist
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence


Expand All @@ -21,20 +21,20 @@ def test_persist_decorator_saves_state(tmp_path):
"""Test that @persist decorator saves state in SQLite."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

class TestFlow(Flow[Dict[str, str]]):
initial_state = dict() # Use dict instance as initial state

@start()
@persist(persistence)
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid" # Ensure we have an ID for persistence

# Run flow and verify state is saved
flow = TestFlow(persistence=persistence)
flow.kickoff()

# Load state from DB and verify
saved_state = persistence.load_state(flow.state["id"])
assert saved_state is not None
Expand All @@ -45,20 +45,20 @@ def test_structured_state_persistence(tmp_path):
"""Test persistence with Pydantic model state."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

class StructuredFlow(Flow[TestState]):
initial_state = TestState

@start()
@persist(persistence)
def count_up(self):
self.state.counter += 1
self.state.message = f"Count is {self.state.counter}"

# Run flow and verify state changes are saved
flow = StructuredFlow(persistence=persistence)
flow.kickoff()

# Load and verify state
saved_state = persistence.load_state(flow.state.id)
assert saved_state is not None
Expand All @@ -70,46 +70,46 @@ def test_flow_state_restoration(tmp_path):
"""Test restoring flow state from persistence with various restoration methods."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

# First flow execution to create initial state
class RestorableFlow(Flow[TestState]):
initial_state = TestState

@start()
@persist(persistence)
def set_message(self):
self.state.message = "Original message"
self.state.counter = 42

# Create and persist initial state
flow1 = RestorableFlow(persistence=persistence)
flow1.kickoff()
original_uuid = flow1.state.id

# Test case 1: Restore using restore_uuid with field override
flow2 = RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
counter=43, # Override counter
)

# Verify state restoration and selective field override
assert flow2.state.id == original_uuid
assert flow2.state.message == "Original message" # Preserved
assert flow2.state.counter == 43 # Overridden

# Test case 2: Restore using kwargs['id']
flow3 = RestorableFlow(
persistence=persistence,
id=original_uuid,
message="Updated message", # Override message
)

# Verify state restoration and selective field override
assert flow3.state.id == original_uuid
assert flow3.state.counter == 42 # Preserved
assert flow3.state.message == "Updated message" # Overridden

# Test case 3: Verify error on conflicting IDs
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
Expand All @@ -118,15 +118,15 @@ def set_message(self):
id="different-id", # Conflict with restore_uuid
)
assert "Conflicting IDs provided" in str(exc_info.value)

# Test case 4: Verify error on non-existent restore_uuid
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid="non-existent-uuid",
)
assert "No state found" in str(exc_info.value)

# Test case 5: Allow new state creation with kwargs['id']
new_uuid = "new-flow-id"
flow4 = RestorableFlow(
Expand All @@ -135,7 +135,7 @@ def set_message(self):
message="New message",
counter=100,
)

# Verify new state creation with provided ID
assert flow4.state.id == new_uuid
assert flow4.state.message == "New message"
Expand All @@ -146,25 +146,25 @@ def test_multiple_method_persistence(tmp_path):
"""Test state persistence across multiple method executions."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

class MultiStepFlow(Flow[TestState]):
initial_state = TestState

@start()
@persist(persistence)
def step_1(self):
self.state.counter = 1
self.state.message = "Step 1"

@start()
@persist(persistence)
def step_2(self):
self.state.counter = 2
self.state.message = "Step 2"

flow = MultiStepFlow(persistence=persistence)
flow.kickoff()

# Load final state
final_state = persistence.load_state(flow.state.id)
assert final_state is not None
Expand All @@ -176,20 +176,20 @@ def test_persistence_error_handling(tmp_path):
"""Test error handling in persistence operations."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)

class InvalidFlow(Flow[TestState]):
# Missing id field in initial state
class InvalidState(BaseModel):
value: str = ""

initial_state = InvalidState

@start()
@persist(persistence)
def will_fail(self):
self.state.value = "test"

with pytest.raises(ValueError) as exc_info:
flow = InvalidFlow(persistence=persistence)

assert "must have an 'id' field" in str(exc_info.value)
Loading

0 comments on commit b5019a8

Please sign in to comment.