Skip to content

Commit

Permalink
Refactor database storage path handling across modules
Browse files Browse the repository at this point in the history
- Introduced `DatabaseStorage` class in `utilities/paths.py` to encapsulate
  logic for managing database storage paths.
  - Supports app-specific storage directories with default or custom
    configurations.
  - Ensures consistent handling and directory creation.

- Updated `knowledge_storage.py`:
  - Replaced direct calls to `db_storage_path()` with `DatabaseStorage` usage.
  - Adjusted initialization to accept `DatabaseStorage` as a parameter.

- Updated `kickoff_task_outputs_storage.py`:
  - Migrated to `DatabaseStorage` for path management.
  - Simplified constructor by removing hardcoded paths.

- Updated `ltm_sqlite_storage.py`:
  - Integrated `DatabaseStorage` for database path handling.
  - Enhanced consistency with other storage modules.

- Updated `rag_storage.py`:
  - Refactored to use `DatabaseStorage` for managing storage paths.
  - Improved maintainability by consolidating path logic.

- Removed outdated `db_storage_path()` function and related utilities in
  `utilities/paths.py`.

- Adjusted import paths and parameter handling in all affected modules.

- Reduced redundant code and improved modularity of storage path management.
  • Loading branch information
arnaudgelas committed Jan 16, 2025
1 parent 3dc4428 commit b418cb5
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 36 deletions.
8 changes: 5 additions & 3 deletions src/crewai/knowledge/storage/knowledge_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.paths import DatabaseStorage


@contextlib.contextmanager
Expand Down Expand Up @@ -48,9 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):

def __init__(
self,
db_storage: DatabaseStorage = DatabaseStorage(),
embedder_config: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
self.db_storage_path = db_storage.db_storage_path
self.collection_name = collection_name
self._set_embedder_config(embedder_config)

Expand Down Expand Up @@ -83,7 +85,7 @@ def search(
raise Exception("Collection not initialized")

def initialize_knowledge_storage(self):
base_path = os.path.join(db_storage_path(), "knowledge")
base_path = os.path.join(self.db_storage_path, "knowledge")
chroma_client = chromadb.PersistentClient(
path=base_path,
settings=Settings(allow_reset=True),
Expand All @@ -107,7 +109,7 @@ def initialize_knowledge_storage(self):
raise Exception("Failed to create or get collection")

def reset(self):
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
base_path = os.path.join(self.db_storage_path, KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = chromadb.PersistentClient(
path=base_path,
Expand Down
8 changes: 3 additions & 5 deletions src/crewai/memory/storage/kickoff_task_outputs_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
from crewai.task import Task
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path
from crewai.utilities.paths import DatabaseStorage


class KickoffTaskOutputsSQLiteStorage:
"""
An updated SQLite storage class for kickoff task outputs storage.
"""

def __init__(
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db"
) -> None:
self.db_path = db_path
def __init__(self, db_storage: DatabaseStorage = DatabaseStorage()) -> None:
self.db_path = f"{db_storage.db_storage_path}/latest_kickoff_task_outputs.db"
self._printer: Printer = Printer()
self._initialize_db()

Expand Down
8 changes: 3 additions & 5 deletions src/crewai/memory/storage/ltm_sqlite_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
from typing import Any, Dict, List, Optional, Union

from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
from crewai.utilities.paths import DatabaseStorage


class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""

def __init__(
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
) -> None:
self.db_path = db_path
def __init__(self, db_storage: DatabaseStorage = DatabaseStorage()) -> None:
self.db_path: str = f"{db_storage.db_storage_path}/long_term_memory_storage.db"
self._printer: Printer = Printer()
self._initialize_db()

Expand Down
15 changes: 11 additions & 4 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.paths import DatabaseStorage


@contextlib.contextmanager
Expand Down Expand Up @@ -40,7 +40,13 @@ class RAGStorage(BaseRAGStorage):
app: ClientAPI | None = None

def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
self,
type,
allow_reset=True,
embedder_config=None,
crew=None,
path=None,
db_storage=DatabaseStorage(),
):
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
Expand All @@ -53,6 +59,7 @@ def __init__(

self.allow_reset = allow_reset
self.path = path
self.db_storage_path = db_storage.db_storage_path
self._initialize_app()

def _set_embedder_config(self):
Expand Down Expand Up @@ -90,7 +97,7 @@ def _build_storage_file_name(self, type: str, file_name: str) -> str:
"""
Ensures file name does not exceed max allowed by OS
"""
base_path = f"{db_storage_path()}/{type}"
base_path = f"{self.db_storage_path}/{type}"

if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
Expand Down Expand Up @@ -152,7 +159,7 @@ def reset(self) -> None:
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}")
shutil.rmtree(f"{self.db_storage_path}/{self.type}")
self.app = None
self.collection = None
except Exception as e:
Expand Down
43 changes: 24 additions & 19 deletions src/crewai/utilities/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@

import appdirs

"""Path management utilities for CrewAI storage and configuration."""

def db_storage_path():
"""Returns the path for database storage."""
app_name = get_project_directory_name()
app_author = "CrewAI"
class DatabaseStorage:
def __init__(
self,
app_author: str = "CrewAI",
app_name: str = "",
data_dir: Path | None = None,
):
self.app_author = app_author
self.app_name = app_name if app_name else self._get_project_directoy_name()
self.db_storage_path = (
data_dir
if data_dir
else Path(appdirs.user_data_dir(self.app_name, self.app_author))
)
self.db_storage_path.mkdir(parents=True, exist_ok=True)

data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return data_dir
def _get_project_directoy_name(self) -> str:
"""Returns the current project directory name."""
project_directory_name = os.environ.get("CREWAI_STORAGE_DIR")


def get_project_directory_name():
"""Returns the current project directory name."""
project_directory_name = os.environ.get("CREWAI_STORAGE_DIR")

if project_directory_name:
return project_directory_name
else:
cwd = Path.cwd()
project_directory_name = cwd.name
return project_directory_name
if project_directory_name:
return project_directory_name
else:
cwd = Path.cwd()
project_directory_name = cwd.name
return project_directory_name

0 comments on commit b418cb5

Please sign in to comment.