Skip to content

Commit

Permalink
refactor: update database connections to use storage_path
Browse files Browse the repository at this point in the history
Co-Authored-By: Joe Moura <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and Joe Moura committed Dec 28, 2024
1 parent a32036f commit 6145b6e
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 34 deletions.
9 changes: 5 additions & 4 deletions src/crewai/memory/contextual/contextual_memory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional
from crewai.task import Task

from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory

Expand All @@ -21,7 +22,7 @@ def __init__(
self.em = em
self.um = um

def build_context_for_task(self, task, context) -> str:
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
Expand All @@ -39,7 +40,7 @@ def build_context_for_task(self, task, context) -> str:
context.append(self._fetch_user_context(query))
return "\n".join(filter(None, context))

def _fetch_stm_context(self, query) -> str:
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
Expand All @@ -53,7 +54,7 @@ def _fetch_stm_context(self, query) -> str:
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""

def _fetch_ltm_context(self, task) -> Optional[str]:
def _fetch_ltm_context(self, task: str) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
Expand All @@ -72,7 +73,7 @@ def _fetch_ltm_context(self, task) -> Optional[str]:

return f"Historical Data:\n{formatted_results}" if ltm_results else ""

def _fetch_entity_context(self, query) -> str:
def _fetch_entity_context(self, query: str) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
Expand Down
97 changes: 86 additions & 11 deletions src/crewai/memory/storage/base_rag_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional
import os
from typing import Any, Dict, List, Optional, TypeVar
from abc import ABC, abstractmethod
from pathlib import Path

from crewai.utilities.paths import db_storage_path

Expand All @@ -19,15 +22,42 @@ def __init__(
allow_reset: bool = True,
embedder_config: Optional[Any] = None,
crew: Any = None,
):
) -> None:
"""Initialize the BaseRAGStorage.
Args:
type: Type of storage being used
storage_path: Optional custom path for storage location
allow_reset: Whether storage can be reset
embedder_config: Optional configuration for the embedder
crew: Optional crew instance this storage belongs to
Raises:
PermissionError: If storage path is not writable
OSError: If storage path cannot be created
"""
self.type = type
self.storage_path = storage_path if storage_path else db_storage_path()

# Validate storage path
try:
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
if not os.access(self.storage_path.parent, os.W_OK):
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
except OSError as e:
raise OSError(f"Failed to initialize storage path: {str(e)}")

self.allow_reset = allow_reset
self.embedder_config = embedder_config
self.crew = crew
self.agents = self._initialize_agents()

def _initialize_agents(self) -> str:
"""Initialize agent identifiers for storage.
Returns:
str: Underscore-joined string of sanitized agent role names
"""
if self.crew:
return "_".join(
[self._sanitize_role(agent.role) for agent in self.crew.agents]
Expand All @@ -36,12 +66,27 @@ def _initialize_agents(self) -> str:

@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
"""Sanitizes agent roles to ensure valid directory names.
Args:
role: The agent role name to sanitize
Returns:
str: Sanitized role name safe for use in paths
"""
pass

@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
"""Save a value with metadata to the storage.
Args:
value: The value to store
metadata: Additional metadata to store with the value
Raises:
OSError: If there is an error writing to storage
"""
pass

@abstractmethod
Expand All @@ -51,25 +96,55 @@ def search(
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
"""Search for entries in the storage."""
) -> List[Dict[str, Any]]:
"""Search for entries in the storage.
Args:
query: The search query string
limit: Maximum number of results to return
filter: Optional filter criteria
score_threshold: Minimum similarity score threshold
Returns:
List[Dict[str, Any]]: List of matching entries with their metadata
"""
pass

@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
"""Reset the storage.
Raises:
OSError: If there is an error clearing storage
PermissionError: If reset is not allowed
"""
pass

@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> Any:
"""Generate an embedding for the given text and metadata."""
) -> List[float]:
"""Generate an embedding for the given text and metadata.
Args:
text: Text to generate embedding for
metadata: Optional metadata to include in embedding
Returns:
List[float]: Vector embedding of the text
Raises:
ValueError: If text is empty or invalid
"""
pass

@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
def _initialize_app(self) -> None:
"""Initialize the vector db.
Raises:
OSError: If vector db initialization fails
"""
pass

def setup_config(self, config: Dict[str, Any]):
Expand Down
54 changes: 43 additions & 11 deletions src/crewai/memory/storage/kickoff_task_outputs_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional

from crewai.task import Task
Expand All @@ -13,12 +15,30 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage.
"""

def __init__(self, db_path: Optional[str] = None) -> None:
self.db_path = (
db_path
if db_path
else f"{db_storage_path()}/latest_kickoff_task_outputs.db"
def __init__(self, storage_path: Optional[Path] = None) -> None:
"""Initialize kickoff task outputs storage.
Args:
storage_path: Optional custom path for storage location
Raises:
PermissionError: If storage path is not writable
OSError: If storage path cannot be created
"""
self.storage_path = (
storage_path
if storage_path
else Path(f"{db_storage_path()}/latest_kickoff_task_outputs.db")
)

# Validate storage path
try:
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
if not os.access(self.storage_path.parent, os.W_OK):
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
except OSError as e:
raise OSError(f"Failed to initialize storage path: {str(e)}")

self._printer: Printer = Printer()
self._initialize_db()

Expand All @@ -27,7 +47,7 @@ def _initialize_db(self):
Initializes the SQLite database and creates LTM table
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute(
"""
Expand Down Expand Up @@ -57,9 +77,21 @@ def add(
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] = {},
):
) -> None:
"""Add a task output to storage.
Args:
task: The task whose output is being stored
output: The output data from the task
task_index: Index of this task in the sequence
was_replayed: Whether this was from a replay
inputs: Optional input data that led to this output
Raises:
sqlite3.Error: If there is an error saving to database
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute(
"""
Expand Down Expand Up @@ -92,7 +124,7 @@ def update(
Updates an existing row in the latest_kickoff_task_outputs table based on task_index.
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()

fields = []
Expand Down Expand Up @@ -121,7 +153,7 @@ def update(

def load(self) -> Optional[List[Dict[str, Any]]]:
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
Expand Down Expand Up @@ -157,7 +189,7 @@ def delete_all(self):
Deletes all rows from the latest_kickoff_task_outputs table.
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
conn.commit()
Expand Down
45 changes: 37 additions & 8 deletions src/crewai/memory/storage/ltm_sqlite_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from crewai.utilities import Printer
Expand All @@ -11,10 +13,26 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""

def __init__(self, db_path: Optional[str] = None) -> None:
self.db_path = (
db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db"
)
def __init__(self, storage_path: Optional[Path] = None) -> None:
"""Initialize LTM SQLite storage.
Args:
storage_path: Optional custom path for storage location
Raises:
PermissionError: If storage path is not writable
OSError: If storage path cannot be created
"""
self.storage_path = storage_path if storage_path else Path(f"{db_storage_path()}/latest_long_term_memories.db")

# Validate storage path
try:
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
if not os.access(self.storage_path.parent, os.W_OK):
raise PermissionError(f"No write permission for storage path: {self.storage_path}")
except OSError as e:
raise OSError(f"Failed to initialize storage path: {str(e)}")

self._printer: Printer = Printer()
self._initialize_db()

Expand All @@ -23,7 +41,7 @@ def _initialize_db(self):
Initializes the SQLite database and creates LTM table
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute(
"""
Expand Down Expand Up @@ -51,9 +69,20 @@ def save(
datetime: str,
score: Union[int, float],
) -> None:
"""Save a memory entry to long-term memory.
Args:
task_description: Description of the task this memory relates to
metadata: Additional data to store with the memory
datetime: Timestamp for when this memory was created
score: Relevance score for this memory (higher is more relevant)
Raises:
sqlite3.Error: If there is an error saving to the database
"""
"""Saves data to the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute(
"""
Expand All @@ -74,7 +103,7 @@ def load(
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute(
f"""
Expand Down Expand Up @@ -109,7 +138,7 @@ def reset(
) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(str(self.storage_path)) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM long_term_memories")
conn.commit()
Expand Down

0 comments on commit 6145b6e

Please sign in to comment.