Skip to content

Commit

Permalink
Set the path where all memories are saved
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudgelas committed Apr 30, 2024
1 parent 3d52575 commit ef4d048
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 11 deletions.
12 changes: 9 additions & 3 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.memory import get_memory_paths
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
Expand Down Expand Up @@ -72,6 +73,10 @@ class Crew(BaseModel):
default=False,
description="Whether the crew should use memory to store memories of it's execution",
)
memory_root_path: Optional[str] = Field(
default=None,
description="The root path for the memory storage.",
)
embedder: Optional[dict] = Field(
default={"provider": "openai"},
description="Configuration for the embedder to be used for the crew.",
Expand Down Expand Up @@ -166,9 +171,10 @@ def set_private_attrs(self) -> "Crew":
def create_crew_memory(self) -> "Crew":
"""Set private attributes."""
if self.memory:
self._long_term_memory = LongTermMemory()
self._short_term_memory = ShortTermMemory(embedder_config=self.embedder)
self._entity_memory = EntityMemory(embedder_config=self.embedder)
long_term_memory_path, short_term_memory_path, entity_memory_path = get_memory_paths(self.memory_root_path)
self._long_term_memory = LongTermMemory(db_path=long_term_memory_path)
self._short_term_memory = ShortTermMemory(db_path=short_term_memory_path, embedder_config=self.embedder)
self._entity_memory = EntityMemory(db_path=entity_memory_path, embedder_config=self.embedder)
return self

@model_validator(mode="after")
Expand Down
4 changes: 2 additions & 2 deletions src/crewai/memory/entity/entity_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class EntityMemory(Memory):
Inherits from the Memory class.
"""

def __init__(self, embedder_config=None):
def __init__(self, db_storage_path=None, embedder_config=None):
storage = RAGStorage(
type="entities", allow_reset=False, embedder_config=embedder_config
type="entities", db_storage_path=db_storage_path, allow_reset=False, embedder_config=embedder_config
)
super().__init__(storage)

Expand Down
7 changes: 5 additions & 2 deletions src/crewai/memory/long_term/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
from crewai.utilities.paths import db_storage_path


class LongTermMemory(Memory):
Expand All @@ -14,8 +15,10 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""

def __init__(self):
storage = LTMSQLiteStorage()
def __init__(self, db_path=None):
if db_path is None:
db_path = f"{db_storage_path()}/long_term_memory_storage.db"
storage = LTMSQLiteStorage(db_path)
super().__init__(storage)

def save(self, item: LongTermMemoryItem) -> None:
Expand Down
14 changes: 14 additions & 0 deletions src/crewai/memory/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
from pathlib import Path
from typing import Any, Dict

from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path


def get_memory_paths(memory_path: str = None):
if not memory_path:
long_term_memory_path = f"{db_storage_path()}/long_term_memory_storage.db"
short_term_memory_path = db_storage_path()
entity_memory_path = db_storage_path()
else:
long_term_memory_path = Path(memory_path) / "long_term_memory_storage.db"
short_term_memory_path = Path(memory_path)
entity_memory_path = Path(memory_path)
return long_term_memory_path, short_term_memory_path, entity_memory_path


class Memory:
Expand Down
4 changes: 2 additions & 2 deletions src/crewai/memory/short_term/short_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""

def __init__(self, embedder_config=None):
storage = RAGStorage(type="short_term", embedder_config=embedder_config)
def __init__(self, db_storage_path=None, embedder_config=None):
storage = RAGStorage(type="short_term", db_storage_path=db_storage_path, embedder_config=embedder_config)
super().__init__(storage)

def save(self, item: ShortTermMemoryItem) -> None:
Expand Down
6 changes: 4 additions & 2 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ class RAGStorage(Storage):
search efficiency.
"""

def __init__(self, type, allow_reset=True, embedder_config=None):
def __init__(self, type, db_storage_path=None, allow_reset=True, embedder_config=None):
super().__init__()
if (
not os.getenv("OPENAI_API_KEY")
and not os.getenv("OPENAI_BASE_URL") == "https://api.openai.com/v1"
):
os.environ["OPENAI_API_KEY"] = "fake"
if db_storage_path is None:
db_storage_path = db_storage_path()
config = {
"app": {
"config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
Expand All @@ -58,7 +60,7 @@ def __init__(self, type, allow_reset=True, embedder_config=None):
"provider": "chroma",
"config": {
"collection_name": type,
"dir": f"{db_storage_path()}/{type}",
"dir": f"{db_storage_path}/{type}",
"allow_reset": allow_reset,
},
},
Expand Down

0 comments on commit ef4d048

Please sign in to comment.