Skip to content

Commit

Permalink
Set the path where all memories are saved
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudgelas committed Jun 8, 2024
1 parent c908dff commit 39ecfbb
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 32 deletions.
10 changes: 2 additions & 8 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from pydantic_core import PydanticCustomError

from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.utilities import I18N, Logger, Prompts, RPMController
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess

Expand Down Expand Up @@ -213,13 +212,8 @@ def execute_task(
task=task_prompt, context=context
)

if self.crew and self.crew.memory:
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
)
memory = contextual_memory.build_context_for_task(task, context)
if self.crew and self.crew.memory and self.crew.contextual_memory:
memory = self.crew.contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)

Expand Down
30 changes: 21 additions & 9 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.memory.memory import get_memory_paths
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
Expand Down Expand Up @@ -59,10 +61,6 @@ class Crew(BaseModel):
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()

cache: bool = Field(default=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
tasks: List[Task] = Field(default_factory=list)
Expand All @@ -73,6 +71,10 @@ class Crew(BaseModel):
default=False,
description="Whether the crew should use memory to store memories of it's execution",
)
contextual_memory: Optional[InstanceOf[ContextualMemory]] = Field(
default=None,
description="The memory storage for the crew.",
)
embedder: Optional[dict] = Field(
default={"provider": "openai"},
description="Configuration for the embedder to be used for the crew.",
Expand Down Expand Up @@ -162,12 +164,22 @@ def set_private_attrs(self) -> "Crew":
@model_validator(mode="after")
def create_crew_memory(self) -> "Crew":
"""Set private attributes."""
if self.memory:
self._long_term_memory = LongTermMemory()
self._short_term_memory = ShortTermMemory(
crew=self, embedder_config=self.embedder
if self.memory and not self.contextual_memory:
long_term_memory_path, short_term_memory_path, entity_memory_path = (
get_memory_paths()
)
long_term_memory = LongTermMemory(db_path=long_term_memory_path)
short_term_memory = ShortTermMemory(
crew=self, db_path=short_term_memory_path, embedder_config=self.embedder
)
entity_memory = EntityMemory(
crew=self, db_path=entity_memory_path, embedder_config=self.embedder
)
self.contextual_memory = ContextualMemory(
short_term_memory=short_term_memory,
long_term_memory=long_term_memory,
entity_memory=entity_memory,
)
self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder)
return self

@model_validator(mode="after")
Expand Down
10 changes: 9 additions & 1 deletion src/crewai/memory/entity/entity_memory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
Expand All @@ -10,9 +12,15 @@ class EntityMemory(Memory):
Inherits from the Memory class.
"""

def __init__(self, crew=None, embedder_config=None):
def __init__(
self,
crew=None,
db_storage_path: Path | None = None,
embedder_config: dict | None = None,
):
storage = RAGStorage(
type="entities",
db_storage_path=db_storage_path,
allow_reset=False,
embedder_config=embedder_config,
crew=crew,
Expand Down
5 changes: 3 additions & 2 deletions src/crewai/memory/long_term/long_term_memory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Any, Dict

from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
Expand All @@ -14,8 +15,8 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""

def __init__(self):
storage = LTMSQLiteStorage()
def __init__(self, db_path: Path):
storage = LTMSQLiteStorage(db_path)
super().__init__(storage)

def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
Expand Down
12 changes: 12 additions & 0 deletions src/crewai/memory/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from pathlib import Path
from typing import Any, Dict, Optional

from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path


def get_memory_paths(memory_path: Path | None = None) -> tuple[Path, Path, Path]:
if not memory_path:
memory_path = db_storage_path()

long_term_memory_path = memory_path / "long_term_memory_storage.db"
short_term_memory_path = memory_path
entity_memory_path = memory_path
return long_term_memory_path, short_term_memory_path, entity_memory_path


class Memory:
Expand Down
14 changes: 12 additions & 2 deletions src/crewai/memory/short_term/short_term_memory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
Expand All @@ -12,9 +14,17 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""

def __init__(self, crew=None, embedder_config=None):
def __init__(
self,
db_storage_path: Path | None = None,
crew=None,
embedder_config=None,
):
storage = RAGStorage(
type="short_term", embedder_config=embedder_config, crew=crew
type="short_term",
db_storage_path=db_storage_path,
embedder_config=embedder_config,
crew=crew,
)
super().__init__(storage)

Expand Down
5 changes: 1 addition & 4 deletions src/crewai/memory/storage/ltm_sqlite_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
from typing import Any, Dict, List, Optional, Union

from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path


class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""

def __init__(
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
) -> None:
def __init__(self, db_path: str) -> None:
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()
Expand Down
17 changes: 14 additions & 3 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import io
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from embedchain import App
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.chroma import InvalidDimensionException

from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path


@contextlib.contextmanager
Expand Down Expand Up @@ -37,18 +37,29 @@ class RAGStorage(Storage):
search efficiency.
"""

def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
def __init__(
self,
type: str,
db_storage_path: Path | None = None,
allow_reset=True,
embedder_config=None,
crew=None,
):
super().__init__()
if (
not os.getenv("OPENAI_API_KEY")
and not os.getenv("OPENAI_BASE_URL") == "https://api.openai.com/v1"
):
os.environ["OPENAI_API_KEY"] = "fake"
if db_storage_path is None:
db_storage_path = db_storage_path()

agents = crew.agents if crew else []
agents = [agent.role for agent in agents]
agents = "_".join(agents)

dir = db_storage_path / {type} / {agents}

config = {
"app": {
"config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
Expand All @@ -63,7 +74,7 @@ def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
"provider": "chroma",
"config": {
"collection_name": type,
"dir": f"{db_storage_path()}/{type}/{agents}",
"dir": str(dir.absolute()),
"allow_reset": allow_reset,
},
},
Expand Down
2 changes: 1 addition & 1 deletion src/crewai/utilities/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import appdirs


def db_storage_path():
def db_storage_path() -> Path:
app_name = get_project_directory_name()
app_author = "CrewAI"

Expand Down
6 changes: 4 additions & 2 deletions tests/memory/long_term_memory_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest


from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem


@pytest.fixture
@pytest.fixture(scope="session")
def long_term_memory():
"""Fixture to create a LongTermMemory instance"""
return LongTermMemory()
data_dir = pytest.tmp_path_factory.mktemp("long_term_memory")
return LongTermMemory(data_dir)


def test_save_and_search(long_term_memory):
Expand Down

0 comments on commit 39ecfbb

Please sign in to comment.