Skip to content

Commit

Permalink
[feature]: Improve pinecone db integration (#806)
Browse files Browse the repository at this point in the history
  • Loading branch information
deshraj authored Oct 15, 2023
1 parent a7a61fa commit 636bc0a
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 46 deletions.
6 changes: 6 additions & 0 deletions configs/pinecone.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Pinecone vector-database configuration for embedchain.
# NOTE(review): the mapping indentation was stripped in the pasted copy;
# restored here to match the equivalent snippet in docs/components/vector-databases.mdx.
vectordb:
  provider: pinecone
  config:
    # Similarity metric used when the Pinecone index is created.
    metric: cosine
    # Dimensionality of the stored embedding vectors.
    vector_dimension: 1536
    # Combined with vector_dimension to derive the Pinecone index name.
    collection_name: my-pinecone-index
18 changes: 17 additions & 1 deletion docs/components/vector-databases.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,23 @@ _Coming soon_
## Pinecone
_Coming soon_
In order to use Pinecone as a vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV`, which you can find on the [Pinecone dashboard](https://app.pinecone.io/).

```python main.py
from embedchain import App
# load pinecone configuration from yaml file
app = App.from_config(yaml_path="config.yaml")
```

```yaml config.yaml
vectordb:
provider: pinecone
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
```

## Qdrant

Expand Down
5 changes: 3 additions & 2 deletions embedchain/apps/person_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from embedchain.apps.app import App
from embedchain.apps.open_source_app import OpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper.json_serializable import register_deserializable


Expand Down
10 changes: 6 additions & 4 deletions embedchain/config/vectordb/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import Optional
from typing import Dict, Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable


@register_deserializable
class PineconeDBConfig(BaseVectorDbConfig):
    """Configuration for the Pinecone vector database.

    NOTE(review): this span was a marker-stripped diff interleaving the old
    (``PineconeDbConfig`` / ``dimension``) and new (``PineconeDBConfig`` /
    ``vector_dimension``) versions; reconstructed here as the post-commit form.
    """

    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        vector_dimension: int = 1536,
        metric: Optional[str] = "cosine",
        **extra_params: Dict[str, any],
    ):
        """
        :param collection_name: name used (with vector_dimension) to derive the
            Pinecone index name, defaults to None
        :param dir: local persistence directory passed to the base config,
            defaults to None
        :param vector_dimension: dimensionality of stored embeddings,
            defaults to 1536
        :param metric: similarity metric for index creation, defaults to "cosine"
        :param extra_params: forwarded verbatim to ``pinecone.init()`` by
            ``PineconeDB._setup_pinecone_index()``
        """
        self.metric = metric
        self.vector_dimension = vector_dimension
        # NOTE(review): annotation ``Dict[str, any]`` uses the builtin ``any``;
        # ``typing.Any`` was probably intended — harmless at runtime.
        self.extra_params = extra_params
        super().__init__(collection_name=collection_name, dir=dir)
2 changes: 2 additions & 0 deletions embedchain/embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ def load_and_embed(
skip_embedding=(chunker.data_type == DataType.IMAGES),
)
count_new_chunks = self.db.count() - chunks_before_addition

print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks

def _format_result(self, results):
Expand Down
2 changes: 2 additions & 0 deletions embedchain/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ class VectorDBFactory:
"chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
}

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
import copy
import os
from typing import Dict, List, Optional

try:
import pinecone
except ImportError:
raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
"Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
) from None

from embedchain.config.vectordb.pinecone import PineconeDbConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB


@register_deserializable
class PineconeDb(BaseVectorDB):
BATCH_SIZE = 100

class PineconeDB(BaseVectorDB):
"""
Pinecone as vector database
"""

BATCH_SIZE = 100

def __init__(
self,
config: Optional[PineconeDbConfig] = None,
config: Optional[PineconeDBConfig] = None,
):
"""Pinecone as vector database.
:param config: Pinecone database config, defaults to None
:type config: PineconeDbConfig, optional
:type config: PineconeDBConfig, optional
:raises ValueError: No config provided
"""
if config is None:
self.config = PineconeDbConfig()
self.config = PineconeDBConfig()
else:
if not isinstance(config, PineconeDbConfig):
if not isinstance(config, PineconeDBConfig):
raise TypeError(
"config is not a `PineconeDbConfig` instance. "
"config is not a `PineconeDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
Expand All @@ -57,11 +56,14 @@ def _setup_pinecone_index(self):
pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"),
**self.config.extra_params,
)
self.index_name = self._get_index_name()
indexes = pinecone.list_indexes()
if indexes is None or self.index_name not in indexes:
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
pinecone.create_index(
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
)
return pinecone.Index(self.index_name)

def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
Expand All @@ -81,7 +83,6 @@ def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] =
result = self.client.fetch(ids=ids[i : i + 1000])
batch_existing_ids = list(result.get("vectors").keys())
existing_ids.extend(batch_existing_ids)

return {"ids": existing_ids}

def add(
Expand All @@ -102,15 +103,15 @@ def add(
:type ids: List[str]
"""
docs = []
if embeddings is None:
embeddings = self.embedder.embedding_fn(documents)
print("Adding documents to Pinecone...")

embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
metadata["text"] = text
docs.append(
{
"id": id,
"values": embedding,
"metadata": copy.deepcopy(metadata),
"metadata": {**metadata, "text": text},
}
)

Expand All @@ -120,13 +121,14 @@ def add(
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:param skip_embedding: Optional. if True, input_query is already embedded
:type skip_embedding: bool
:return: Database contents that are the result of the query
:rtype: List[str]
"""
Expand Down Expand Up @@ -177,4 +179,4 @@ def _get_index_name(self) -> str:
:return: Pinecone index
:rtype: str
"""
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.0.70"
version = "0.0.71"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = ["Taranjeet Singh, Deshraj Yadav"]
license = "Apache License"
Expand Down Expand Up @@ -132,6 +132,7 @@ click = "^8.1.3"
isort = "^5.12.0"
pytest-cov = "^4.1.0"
responses = "^0.23.3"
mock = "^5.1.0"

[tool.poetry.extras]
streamlit = ["streamlit"]
Expand Down
4 changes: 3 additions & 1 deletion tests/apps/test_apps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os

import pytest
import yaml

from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
Expand Down
3 changes: 2 additions & 1 deletion tests/apps/test_person_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from embedchain.apps.app import App
from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.config.llm.base import DEFAULT_PROMPT


Expand Down
3 changes: 2 additions & 1 deletion tests/bots/test_poe.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import argparse

import pytest
from fastapi_poe.types import ProtocolMessage, QueryRequest

from embedchain.bots.poe import PoeBot, start_command
from fastapi_poe.types import QueryRequest, ProtocolMessage


@pytest.fixture
Expand Down
2 changes: 2 additions & 0 deletions tests/embedchain/test_add.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import pytest

from embedchain import App
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType
Expand Down
1 change: 1 addition & 0 deletions tests/llm/test_base_llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from embedchain.llm.base import BaseLlm, BaseLlmConfig


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,38 @@
from embedchain import App
from embedchain.config import AppConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pineconedb import PineconeDb
from embedchain.vectordb.pinecone import PineconeDB


class TestPineconeDb:
@patch("embedchain.vectordb.pineconedb.pinecone")
class TestPinecone:
@patch("embedchain.vectordb.pinecone.pinecone")
def test_init(self, pinecone_mock):
    """Test that the PineconeDB can be initialized."""
    # NOTE(review): span contained duplicated pre/post-rename comment lines
    # from the stripped diff; de-duplicated to the post-commit version.
    # Create a PineconeDB instance
    PineconeDB()

    # Assert that the Pinecone client was initialized and the index opened
    pinecone_mock.init.assert_called_once()
    pinecone_mock.list_indexes.assert_called_once()
    pinecone_mock.Index.assert_called_once()

@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_set_embedder(self, pinecone_mock):
    """Test that the embedder can be set."""
    # Set the embedder
    embedder = BaseEmbedder()

    # Create a PineconeDB instance and wire it into an App, which is what
    # assigns the embedder onto the db.
    db = PineconeDB()
    app_config = AppConfig(collect_metrics=False)
    App(config=app_config, db=db, embedder=embedder)

    # Assert that the embedder was set
    assert db.embedder == embedder
    pinecone_mock.init.assert_called_once()

@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_add_documents(self, pinecone_mock):
"""Test that documents can be added to the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
Expand All @@ -46,7 +46,7 @@ def test_add_documents(self, pinecone_mock):
vectors = [[0, 0, 0], [1, 1, 1]]
embedding_function.return_value = vectors
# Create a PineconeDb instance
db = PineconeDb()
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)

Expand All @@ -63,7 +63,7 @@ def test_add_documents(self, pinecone_mock):
# Assert that the Pinecone client was called to upsert the documents
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)

@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_query_documents(self, pinecone_mock):
"""Test that documents can be queried from the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
Expand All @@ -73,8 +73,8 @@ def test_query_documents(self, pinecone_mock):
base_embedder.set_embedding_fn(embedding_function)
vectors = [[0, 0, 0]]
embedding_function.return_value = vectors
# Create a PineconeDb instance
db = PineconeDb()
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)

Expand All @@ -88,11 +88,11 @@ def test_query_documents(self, pinecone_mock):
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
)

@patch("embedchain.vectordb.pineconedb.pinecone")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_reset(self, pinecone_mock):
"""Test that the database can be reset."""
# Create a PineconeDb instance
db = PineconeDb()
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=BaseEmbedder())

Expand Down

0 comments on commit 636bc0a

Please sign in to comment.