Skip to content

Commit

Permalink
fix: Improve type safety in flow state handling with proper validation
Browse files Browse the repository at this point in the history
Co-Authored-By: Joe Moura <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and Joe Moura committed Jan 14, 2025
1 parent 4e0a7ba commit 212e60f
Showing 1 changed file with 68 additions and 9 deletions.
77 changes: 68 additions & 9 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,48 @@ class FlowState(BaseModel):
"""Base model for all flow states, ensuring each state has a unique ID."""
id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the flow state")

# Type variables with explicit bounds
T = TypeVar("T", bound=Union[FlowState, Dict[str, Any]])
DictStateType = Dict[str, Any]
ModelStateType = TypeVar('ModelStateType', bound=BaseModel)

def validate_state_type(state: Any, expected_type: Type[Union[dict, BaseModel]]) -> bool:
"""Validate that state matches expected type.
Args:
state: State instance to validate
expected_type: Expected type for the state
Returns:
True if state matches expected type, False otherwise
"""
if expected_type == dict:
return isinstance(state, dict)
return isinstance(state, expected_type)

def ensure_state_type(state: Any, expected_type: Union[Type[dict], Type[BaseModel]]) -> T:
"""Ensure state matches expected type with proper validation.
Args:
state: State instance to validate
expected_type: Expected type for the state (dict or BaseModel)
Returns:
Validated state instance
Raises:
TypeError: If state doesn't match expected type
ValueError: If state validation fails
"""
if expected_type == dict:
if not isinstance(state, dict):
raise TypeError("State must be a dictionary")
return cast(T, state)
elif issubclass(expected_type, BaseModel):
if not isinstance(state, expected_type):
raise TypeError(f"State must be instance of {expected_type.__name__}")
return cast(T, state)
raise TypeError("Expected type must be dict or BaseModel subclass")


def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
Expand Down Expand Up @@ -430,6 +471,14 @@ def __init__(
method = method.__get__(self, self.__class__)
self._methods[method_name] = method

@overload
def _create_initial_state(self: "Flow[DictStateType]") -> DictStateType:
...

@overload
def _create_initial_state(self: "Flow[ModelStateType]") -> ModelStateType:
...

def _create_initial_state(self) -> T:
"""Create and initialize flow state with UUID.
Expand All @@ -445,39 +494,49 @@ def _create_initial_state(self) -> T:
state_type = getattr(self, "_initial_state_T")
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
return state_type() # type: ignore
return ensure_state_type(state_type(), state_type)
elif issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
return StateWithId() # type: ignore
return ensure_state_type(StateWithId(), BaseModel)
elif state_type == dict:
return cast(T, {"id": str(uuid4())})
return ensure_state_type({"id": str(uuid4())}, dict)

# Handle case where no initial state is provided
if self.initial_state is None:
return cast(T, {"id": str(uuid4())})
return ensure_state_type({"id": str(uuid4())}, dict)

# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state())
state = self.initial_state()
return ensure_state_type(state, self.initial_state)
elif issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state())
state = self.initial_state()
return ensure_state_type(state, self.initial_state)
elif self.initial_state == dict:
return cast(T, {"id": str(uuid4())})
return ensure_state_type({"id": str(uuid4())}, dict)

# Handle dictionary instance case
if isinstance(self.initial_state, dict):
if "id" not in self.initial_state:
self.initial_state["id"] = str(uuid4())
return cast(T, self.initial_state)
return ensure_state_type(dict(self.initial_state), dict) # Create new dict to avoid mutations

return cast(T, self.initial_state)
# Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel):
if not hasattr(self.initial_state, "id"):
raise ValueError("Flow state model must have an 'id' field")
return ensure_state_type(self.initial_state, type(self.initial_state))

raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)

@property
def state(self) -> T:
Expand Down

0 comments on commit 212e60f

Please sign in to comment.