From 636bc0a99dda4a013880d1619dc5d849ea1257ee Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Sun, 15 Oct 2023 02:26:35 -0700 Subject: [PATCH] [feature]: Improve pinecone db integration (#806) --- configs/pinecone.yaml | 6 +++ docs/components/vector-databases.mdx | 18 ++++++++- embedchain/apps/person_app.py | 5 ++- embedchain/config/vectordb/pinecone.py | 10 +++-- embedchain/embedchain.py | 2 + embedchain/factory.py | 2 + .../vectordb/{pineconedb.py => pinecone.py} | 40 ++++++++++--------- pyproject.toml | 3 +- tests/apps/test_apps.py | 4 +- tests/apps/test_person_app.py | 3 +- tests/bots/test_poe.py | 3 +- tests/embedchain/test_add.py | 2 + tests/llm/test_base_llm.py | 1 + .../{test_pinecone_db.py => test_pinecone.py} | 32 +++++++-------- 14 files changed, 85 insertions(+), 46 deletions(-) create mode 100644 configs/pinecone.yaml rename embedchain/vectordb/{pineconedb.py => pinecone.py} (84%) rename tests/vectordb/{test_pinecone_db.py => test_pinecone.py} (84%) diff --git a/configs/pinecone.yaml b/configs/pinecone.yaml new file mode 100644 index 0000000000..24e33c11a8 --- /dev/null +++ b/configs/pinecone.yaml @@ -0,0 +1,6 @@ +vectordb: + provider: pinecone + config: + metric: cosine + vector_dimension: 1536 + collection_name: my-pinecone-index diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index e6caa6efab..09c53a53c5 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -147,7 +147,23 @@ _Coming soon_ ## Pinecone -_Coming soon_ +In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/). + +```python main.py +from embedchain import App + +# load pinecone configuration from yaml file +app = App.from_config(yaml_path="config.yaml") +``` + +```yaml config.yaml +vectordb: + provider: pinecone + config: + metric: cosine + vector_dimension: 1536 + collection_name: my-pinecone-index +``` ## Qdrant diff --git a/embedchain/apps/person_app.py b/embedchain/apps/person_app.py index 4511e4a69e..9d9e07ba55 100644 --- a/embedchain/apps/person_app.py +++ b/embedchain/apps/person_app.py @@ -2,8 +2,9 @@ from embedchain.apps.app import App from embedchain.apps.open_source_app import OpenSourceApp -from embedchain.config import BaseLlmConfig, AppConfig -from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY +from embedchain.config import AppConfig, BaseLlmConfig +from embedchain.config.llm.base import (DEFAULT_PROMPT, + DEFAULT_PROMPT_WITH_HISTORY) from embedchain.helper.json_serializable import register_deserializable diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py index 2e1334e346..7bd462ae46 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vectordb/pinecone.py @@ -1,18 +1,20 @@ -from typing import Optional +from typing import Dict, Optional from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.helper.json_serializable import register_deserializable @register_deserializable -class PineconeDbConfig(BaseVectorDbConfig): +class PineconeDBConfig(BaseVectorDbConfig): def __init__( self, collection_name: Optional[str] = None, dir: Optional[str] = None, - dimension: Optional[int] = 1536, + vector_dimension: int = 1536, metric: Optional[str] = "cosine", + **extra_params: Dict[str, any], ): - self.dimension = dimension self.metric = metric + self.vector_dimension = vector_dimension + self.extra_params = extra_params super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 0a52c168f5..fa15358999 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -403,6 +403,8 @@ def load_and_embed( skip_embedding=(chunker.data_type == DataType.IMAGES), ) count_new_chunks = self.db.count() - chunks_before_addition + + print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")) return list(documents), metadatas, ids, count_new_chunks def _format_result(self, results): diff --git a/embedchain/factory.py b/embedchain/factory.py index 79663962f3..b3bbf0cc72 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -69,11 +69,13 @@ class VectorDBFactory: "chroma": "embedchain.vectordb.chroma.ChromaDB", "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", + "pinecone": "embedchain.vectordb.pinecone.PineconeDB", } provider_to_config_class = { "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", + "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", } @classmethod diff --git a/embedchain/vectordb/pineconedb.py b/embedchain/vectordb/pinecone.py similarity index 84% rename from embedchain/vectordb/pineconedb.py rename to embedchain/vectordb/pinecone.py index 2aa12f1b96..df4a82f822 100644 --- a/embedchain/vectordb/pineconedb.py +++ b/embedchain/vectordb/pinecone.py @@ -1,4 +1,3 @@ -import copy import os from typing import Dict, List, Optional @@ -6,38 +5,38 @@ import pinecone except ImportError: raise ImportError( - "Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`" + "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`" ) from None -from embedchain.config.vectordb.pinecone import PineconeDbConfig +from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.helper.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB @register_deserializable -class PineconeDb(BaseVectorDB): - BATCH_SIZE = 100 - +class PineconeDB(BaseVectorDB): """ Pinecone as vector database """ + BATCH_SIZE = 100 + def __init__( self, - config: Optional[PineconeDbConfig] = None, + config: Optional[PineconeDBConfig] = None, ): """Pinecone as vector database. :param config: Pinecone database config, defaults to None - :type config: PineconeDbConfig, optional + :type config: PineconeDBConfig, optional :raises ValueError: No config provided """ if config is None: - self.config = PineconeDbConfig() + self.config = PineconeDBConfig() else: - if not isinstance(config, PineconeDbConfig): + if not isinstance(config, PineconeDBConfig): raise TypeError( - "config is not a `PineconeDbConfig` instance. " + "config is not a `PineconeDBConfig` instance. " "Please make sure the type is right and that you are passing an instance." ) self.config = config @@ -57,11 +56,14 @@ def _setup_pinecone_index(self): pinecone.init( api_key=os.environ.get("PINECONE_API_KEY"), environment=os.environ.get("PINECONE_ENV"), + **self.config.extra_params, ) self.index_name = self._get_index_name() indexes = pinecone.list_indexes() if indexes is None or self.index_name not in indexes: - pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension) + pinecone.create_index( + name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension + ) return pinecone.Index(self.index_name) def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): @@ -81,7 +83,6 @@ def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = result = self.client.fetch(ids=ids[i : i + 1000]) batch_existing_ids = list(result.get("vectors").keys()) existing_ids.extend(batch_existing_ids) - return {"ids": existing_ids} def add( @@ -102,15 +103,15 @@ def add( :type ids: List[str] """ docs = [] - if embeddings is None: - embeddings = self.embedder.embedding_fn(documents) + print("Adding documents to Pinecone...") + + embeddings = self.embedder.embedding_fn(documents) for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): - metadata["text"] = text docs.append( { "id": id, "values": embedding, - "metadata": copy.deepcopy(metadata), + "metadata": {**metadata, "text": text}, } ) @@ -120,13 +121,14 @@ def add( def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: """ query contents from vector database based on vector similarity - :param input_query: list of query string :type input_query: List[str] :param n_results: no of similar documents to fetch from database :type n_results: int :param where: Optional. to filter data :type where: Dict[str, any] + :param skip_embedding: Optional. if True, input_query is already embedded + :type skip_embedding: bool :return: Database contents that are the result of the query :rtype: List[str] """ @@ -177,4 +179,4 @@ def _get_index_name(self) -> str: :return: Pinecone index :rtype: str """ - return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-") + return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-") diff --git a/pyproject.toml b/pyproject.toml index 32b4a977e4..c33341ff42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.0.70" +version = "0.0.71" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = ["Taranjeet Singh, Deshraj Yadav"] license = "Apache License" @@ -132,6 +132,7 @@ click = "^8.1.3" isort = "^5.12.0" pytest-cov = "^4.1.0" responses = "^0.23.3" +mock = "^5.1.0" [tool.poetry.extras] streamlit = ["streamlit"] diff --git a/tests/apps/test_apps.py b/tests/apps/test_apps.py index 2e311242dc..0bafce8605 100644 --- a/tests/apps/test_apps.py +++ b/tests/apps/test_apps.py @@ -1,9 +1,11 @@ import os + import pytest import yaml from embedchain import App, CustomApp, Llama2App, OpenSourceApp -from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig +from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig, + BaseLlmConfig, ChromaDbConfig) from embedchain.embedder.base import BaseEmbedder from embedchain.llm.base import BaseLlm from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig diff --git a/tests/apps/test_person_app.py b/tests/apps/test_person_app.py index dc846508e6..f57899a212 100644 --- a/tests/apps/test_person_app.py +++ b/tests/apps/test_person_app.py @@ -1,7 +1,8 @@ import pytest + from embedchain.apps.app import App from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp -from embedchain.config import BaseLlmConfig, AppConfig +from embedchain.config import AppConfig, BaseLlmConfig from embedchain.config.llm.base import DEFAULT_PROMPT diff --git a/tests/bots/test_poe.py b/tests/bots/test_poe.py index 09ae1d6d50..031eeac246 100644 --- a/tests/bots/test_poe.py +++ b/tests/bots/test_poe.py @@ -1,8 +1,9 @@ import argparse + import pytest +from fastapi_poe.types import ProtocolMessage, QueryRequest from embedchain.bots.poe import PoeBot, start_command -from fastapi_poe.types import QueryRequest, ProtocolMessage @pytest.fixture diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index 7974b9c472..b12509438b 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -1,5 +1,7 @@ import os + import pytest + from embedchain import App from embedchain.config import AddConfig, AppConfig, ChunkerConfig from embedchain.models.data_type import DataType diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py index e74f8a0c75..9112053d51 100644 --- a/tests/llm/test_base_llm.py +++ b/tests/llm/test_base_llm.py @@ -1,4 +1,5 @@ import pytest + from embedchain.llm.base import BaseLlm, BaseLlmConfig diff --git a/tests/vectordb/test_pinecone_db.py b/tests/vectordb/test_pinecone.py similarity index 84% rename from tests/vectordb/test_pinecone_db.py rename to tests/vectordb/test_pinecone.py index f17252c0a0..bf0a485d78 100644 --- a/tests/vectordb/test_pinecone_db.py +++ b/tests/vectordb/test_pinecone.py @@ -4,30 +4,30 @@ from embedchain import App from embedchain.config import AppConfig from embedchain.embedder.base import BaseEmbedder -from embedchain.vectordb.pineconedb import PineconeDb +from embedchain.vectordb.pinecone import PineconeDB -class TestPineconeDb: - @patch("embedchain.vectordb.pineconedb.pinecone") +class TestPinecone: + @patch("embedchain.vectordb.pinecone.pinecone") def test_init(self, pinecone_mock): - """Test that the PineconeDb can be initialized.""" - # Create a PineconeDb instance - PineconeDb() + """Test that the PineconeDB can be initialized.""" + # Create a PineconeDB instance + PineconeDB() # Assert that the Pinecone client was initialized pinecone_mock.init.assert_called_once() pinecone_mock.list_indexes.assert_called_once() pinecone_mock.Index.assert_called_once() - @patch("embedchain.vectordb.pineconedb.pinecone") + @patch("embedchain.vectordb.pinecone.pinecone") def test_set_embedder(self, pinecone_mock): """Test that the embedder can be set.""" # Set the embedder embedder = BaseEmbedder() - # Create a PineconeDb instance - db = PineconeDb() + # Create a PineconeDB instance + db = PineconeDB() app_config = AppConfig(collect_metrics=False) App(config=app_config, db=db, embedder=embedder) @@ -35,7 +35,7 @@ def test_set_embedder(self, pinecone_mock): assert db.embedder == embedder pinecone_mock.init.assert_called_once() - @patch("embedchain.vectordb.pineconedb.pinecone") + @patch("embedchain.vectordb.pinecone.pinecone") def test_add_documents(self, pinecone_mock): """Test that documents can be added to the database.""" pinecone_client_mock = pinecone_mock.Index.return_value @@ -46,7 +46,7 @@ def test_add_documents(self, pinecone_mock): vectors = [[0, 0, 0], [1, 1, 1]] embedding_function.return_value = vectors # Create a PineconeDb instance - db = PineconeDb() + db = PineconeDB() app_config = AppConfig(collect_metrics=False) App(config=app_config, db=db, embedder=base_embedder) @@ -63,7 +63,7 @@ def test_add_documents(self, pinecone_mock): # Assert that the Pinecone client was called to upsert the documents pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args) - @patch("embedchain.vectordb.pineconedb.pinecone") + @patch("embedchain.vectordb.pinecone.pinecone") def test_query_documents(self, pinecone_mock): """Test that documents can be queried from the database.""" pinecone_client_mock = pinecone_mock.Index.return_value @@ -73,8 +73,8 @@ def test_query_documents(self, pinecone_mock): base_embedder.set_embedding_fn(embedding_function) vectors = [[0, 0, 0]] embedding_function.return_value = vectors - # Create a PineconeDb instance - db = PineconeDb() + # Create a PineconeDB instance + db = PineconeDB() app_config = AppConfig(collect_metrics=False) App(config=app_config, db=db, embedder=base_embedder) @@ -88,11 +88,11 @@ def test_query_documents(self, pinecone_mock): vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True ) - @patch("embedchain.vectordb.pineconedb.pinecone") + @patch("embedchain.vectordb.pinecone.pinecone") def test_reset(self, pinecone_mock): """Test that the database can be reset.""" # Create a PineconeDb instance - db = PineconeDb() + db = PineconeDB() app_config = AppConfig(collect_metrics=False) App(config=app_config, db=db, embedder=BaseEmbedder())