diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 6ab6e61c92..7a81adf9b9 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -18,6 +18,7 @@ from crewai.agent import Agent from crewai.agents.cache import CacheHandler +from crewai.memory import get_memory_paths from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory @@ -72,6 +73,10 @@ class Crew(BaseModel): default=False, description="Whether the crew should use memory to store memories of it's execution", ) + memory_root_path: Optional[str] = Field( + default=None, + description="The root path for the memory storage.", + ) embedder: Optional[dict] = Field( default={"provider": "openai"}, description="Configuration for the embedder to be used for the crew.", @@ -166,9 +171,10 @@ def set_private_attrs(self) -> "Crew": def create_crew_memory(self) -> "Crew": """Set private attributes.""" if self.memory: - self._long_term_memory = LongTermMemory() - self._short_term_memory = ShortTermMemory(embedder_config=self.embedder) - self._entity_memory = EntityMemory(embedder_config=self.embedder) + long_term_memory_path, short_term_memory_path, entity_memory_path = get_memory_paths(self.memory_root_path) + self._long_term_memory = LongTermMemory(db_path=long_term_memory_path) + self._short_term_memory = ShortTermMemory(db_path=short_term_memory_path, embedder_config=self.embedder) + self._entity_memory = EntityMemory(db_path=entity_memory_path, embedder_config=self.embedder) return self @model_validator(mode="after") diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 43986cc43e..81c0b48461 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -10,9 +10,9 @@ class EntityMemory(Memory): Inherits from the Memory class. """ - def __init__(self, embedder_config=None): + def __init__(self, db_storage_path=None, embedder_config=None): storage = RAGStorage( - type="entities", allow_reset=False, embedder_config=embedder_config + type="entities", db_storage_path=db_storage_path, allow_reset=False, embedder_config=embedder_config ) super().__init__(storage) diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index 046e5aaf7f..30c3a3f0f1 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -3,6 +3,7 @@ from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.memory import Memory from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage +from crewai.utilities.paths import db_storage_path class LongTermMemory(Memory): @@ -14,8 +15,10 @@ class LongTermMemory(Memory): LongTermMemoryItem instances. """ - def __init__(self): - storage = LTMSQLiteStorage() + def __init__(self, db_path=None): + if db_path is None: + db_path = f"{db_storage_path()}/long_term_memory_storage.db" + storage = LTMSQLiteStorage(db_path) super().__init__(storage) def save(self, item: LongTermMemoryItem) -> None: diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 790b5cd498..71cf870447 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -1,6 +1,20 @@ +from pathlib import Path from typing import Any, Dict from crewai.memory.storage.interface import Storage +from crewai.utilities.paths import db_storage_path + + +def get_memory_paths(memory_path: str = None): + if not memory_path: + long_term_memory_path = f"{db_storage_path()}/long_term_memory_storage.db" + short_term_memory_path = db_storage_path() + entity_memory_path = db_storage_path() + else: + long_term_memory_path = Path(memory_path) / "long_term_memory_storage.db" + short_term_memory_path = Path(memory_path) + entity_memory_path = Path(memory_path) + return long_term_memory_path, short_term_memory_path, entity_memory_path class Memory: diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 9d0b4b25d7..f1070a8621 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -12,8 +12,8 @@ class ShortTermMemory(Memory): MemoryItem instances. """ - def __init__(self, embedder_config=None): - storage = RAGStorage(type="short_term", embedder_config=embedder_config) + def __init__(self, db_storage_path=None, embedder_config=None): + storage = RAGStorage(type="short_term", db_storage_path=db_storage_path, embedder_config=embedder_config) super().__init__(storage) def save(self, item: ShortTermMemoryItem) -> None: diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index fb936829e8..568ffdda1b 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -37,13 +37,15 @@ class RAGStorage(Storage): search efficiency. """ - def __init__(self, type, allow_reset=True, embedder_config=None): + def __init__(self, type, db_storage_path=None, allow_reset=True, embedder_config=None): super().__init__() if ( not os.getenv("OPENAI_API_KEY") and not os.getenv("OPENAI_BASE_URL") == "https://api.openai.com/v1" ): os.environ["OPENAI_API_KEY"] = "fake" + if db_storage_path is None: + db_storage_path = db_storage_path() config = { "app": { "config": {"name": type, "collect_metrics": False, "log_level": "ERROR"} @@ -58,7 +60,7 @@ def __init__(self, type, allow_reset=True, embedder_config=None): "provider": "chroma", "config": { "collection_name": type, - "dir": f"{db_storage_path()}/{type}", + "dir": f"{db_storage_path}/{type}", "allow_reset": allow_reset, }, },