From b418cb57498934bf9bcbbaecd47bd608616fc67a Mon Sep 17 00:00:00 2001 From: Arnaud Gelas Date: Thu, 16 Jan 2025 09:31:28 +0100 Subject: [PATCH] Refactor database storage path handling across modules - Introduced `DatabaseStorage` class in `utilities/paths.py` to encapsulate logic for managing database storage paths. - Supports app-specific storage directories with default or custom configurations. - Ensures consistent handling and directory creation. - Updated `knowledge_storage.py`: - Replaced direct calls to `db_storage_path()` with `DatabaseStorage` usage. - Adjusted initialization to accept `DatabaseStorage` as a parameter. - Updated `kickoff_task_outputs_storage.py`: - Migrated to `DatabaseStorage` for path management. - Simplified constructor by removing hardcoded paths. - Updated `ltm_sqlite_storage.py`: - Integrated `DatabaseStorage` for database path handling. - Enhanced consistency with other storage modules. - Updated `rag_storage.py`: - Refactored to use `DatabaseStorage` for managing storage paths. - Improved maintainability by consolidating path logic. - Removed outdated `db_storage_path()` function and related utilities in `utilities/paths.py`. - Adjusted import paths and parameter handling in all affected modules. - Reduced redundant code and improved modularity of storage path management. --- .../knowledge/storage/knowledge_storage.py | 8 ++-- .../storage/kickoff_task_outputs_storage.py | 8 ++-- .../memory/storage/ltm_sqlite_storage.py | 8 ++-- src/crewai/memory/storage/rag_storage.py | 15 +++++-- src/crewai/utilities/paths.py | 43 +++++++++++-------- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 4a70c59971..28c0189aa5 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -16,7 +16,7 @@ from crewai.utilities import EmbeddingConfigurator from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.logger import Logger -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import DatabaseStorage @contextlib.contextmanager @@ -48,9 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage): def __init__( self, + db_storage: DatabaseStorage = DatabaseStorage(), embedder_config: Optional[Dict[str, Any]] = None, collection_name: Optional[str] = None, ): + self.db_storage_path = db_storage.db_storage_path self.collection_name = collection_name self._set_embedder_config(embedder_config) @@ -83,7 +85,7 @@ def search( raise Exception("Collection not initialized") def initialize_knowledge_storage(self): - base_path = os.path.join(db_storage_path(), "knowledge") + base_path = os.path.join(self.db_storage_path, "knowledge") chroma_client = chromadb.PersistentClient( path=base_path, settings=Settings(allow_reset=True), @@ -107,7 +109,7 @@ def initialize_knowledge_storage(self): raise Exception("Failed to create or get collection") def reset(self): - base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY) + base_path = os.path.join(self.db_storage_path, KNOWLEDGE_DIRECTORY) if not self.app: self.app = chromadb.PersistentClient( path=base_path, diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index 26905191cb..869ef7e765 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -5,7 +5,7 @@ from crewai.task import Task from crewai.utilities import Printer from crewai.utilities.crew_json_encoder import CrewJSONEncoder -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import DatabaseStorage class KickoffTaskOutputsSQLiteStorage: @@ -13,10 +13,8 @@ class KickoffTaskOutputsSQLiteStorage: An updated SQLite storage class for kickoff task outputs storage. """ - def __init__( - self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db" - ) -> None: - self.db_path = db_path + def __init__(self, db_storage: DatabaseStorage = DatabaseStorage()) -> None: + self.db_path = f"{db_storage.db_storage_path}/latest_kickoff_task_outputs.db" self._printer: Printer = Printer() self._initialize_db() diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 93d993ee67..1373b1ce05 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union from crewai.utilities import Printer -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import DatabaseStorage class LTMSQLiteStorage: @@ -11,10 +11,8 @@ class LTMSQLiteStorage: An updated SQLite storage class for LTM data storage. """ - def __init__( - self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db" - ) -> None: - self.db_path = db_path + def __init__(self, db_storage: DatabaseStorage = DatabaseStorage()) -> None: + self.db_path: str = f"{db_storage.db_storage_path}/long_term_memory_storage.db" self._printer: Printer = Printer() self._initialize_db() diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index fd4c77838c..022c536047 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -11,7 +11,7 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities import EmbeddingConfigurator from crewai.utilities.constants import MAX_FILE_NAME_LENGTH -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import DatabaseStorage @contextlib.contextmanager @@ -40,7 +40,13 @@ class RAGStorage(BaseRAGStorage): app: ClientAPI | None = None def __init__( - self, type, allow_reset=True, embedder_config=None, crew=None, path=None + self, + type, + allow_reset=True, + embedder_config=None, + crew=None, + path=None, + db_storage=DatabaseStorage(), ): super().__init__(type, allow_reset, embedder_config, crew) agents = crew.agents if crew else [] @@ -53,6 +59,7 @@ def __init__( self.allow_reset = allow_reset self.path = path + self.db_storage_path = db_storage.db_storage_path self._initialize_app() def _set_embedder_config(self): @@ -90,7 +97,7 @@ def _build_storage_file_name(self, type: str, file_name: str) -> str: """ Ensures file name does not exceed max allowed by OS """ - base_path = f"{db_storage_path()}/{type}" + base_path = f"{self.db_storage_path}/{type}" if len(file_name) > MAX_FILE_NAME_LENGTH: logging.warning( @@ -152,7 +159,7 @@ def reset(self) -> None: try: if self.app: self.app.reset() - shutil.rmtree(f"{db_storage_path()}/{self.type}") + shutil.rmtree(f"{self.db_storage_path}/{self.type}") self.app = None self.collection = None except Exception as e: diff --git a/src/crewai/utilities/paths.py b/src/crewai/utilities/paths.py index 9bf167ee6c..9153706fe3 100644 --- a/src/crewai/utilities/paths.py +++ b/src/crewai/utilities/paths.py @@ -3,25 +3,30 @@ import appdirs -"""Path management utilities for CrewAI storage and configuration.""" -def db_storage_path(): - """Returns the path for database storage.""" - app_name = get_project_directory_name() - app_author = "CrewAI" +class DatabaseStorage: + def __init__( + self, + app_author: str = "CrewAI", + app_name: str = "", + data_dir: Path | None = None, + ): + self.app_author = app_author + self.app_name = app_name if app_name else self._get_project_directoy_name() + self.db_storage_path = ( + data_dir + if data_dir + else Path(appdirs.user_data_dir(self.app_name, self.app_author)) + ) + self.db_storage_path.mkdir(parents=True, exist_ok=True) - data_dir = Path(appdirs.user_data_dir(app_name, app_author)) - data_dir.mkdir(parents=True, exist_ok=True) - return data_dir + def _get_project_directoy_name(self) -> str: + """Returns the current project directory name.""" + project_directory_name = os.environ.get("CREWAI_STORAGE_DIR") - -def get_project_directory_name(): - """Returns the current project directory name.""" - project_directory_name = os.environ.get("CREWAI_STORAGE_DIR") - - if project_directory_name: - return project_directory_name - else: - cwd = Path.cwd() - project_directory_name = cwd.name - return project_directory_name + if project_directory_name: + return project_directory_name + else: + cwd = Path.cwd() + project_directory_name = cwd.name + return project_directory_name