Skip to content

Commit

Permalink
Support Milvus partition key
Browse files Browse the repository at this point in the history
Support the Milvus partition key when saving to and searching a collection.
  • Loading branch information
ziyi-curiousthing authored and SimFG committed Jan 18, 2024
1 parent b732ea2 commit 9adf6ab
Show file tree
Hide file tree
Showing 14 changed files with 48 additions and 31 deletions.
6 changes: 6 additions & 0 deletions gptcache/adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
raise NotInitError()
cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
context = kwargs.pop("cache_context", {})
partition_key = kwargs.pop("partition_key", None)
embedding_data = None
# you want to retry to send the request to chatgpt when the cache is negative

Expand Down Expand Up @@ -91,6 +92,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
top_k=kwargs.pop("top_k", 5)
if (user_temperature and not user_top_k)
else kwargs.pop("top_k", -1),
partition_key=partition_key,
)
if search_data_list is None:
search_data_list = []
Expand Down Expand Up @@ -263,6 +265,7 @@ def update_cache_func(handled_llm_data, question=None):
embedding_data,
extra_param=context.get("save_func", None),
session=session,
partition_key=partition_key,
)
if (
chat_cache.report.op_save.count > 0
Expand Down Expand Up @@ -304,6 +307,7 @@ async def aadapt(
raise NotInitError()
cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
context = kwargs.pop("cache_context", {})
partition_key = kwargs.pop("partition_key", None)
embedding_data = None
# you want to retry to send the request to chatgpt when the cache is negative

Expand Down Expand Up @@ -362,6 +366,7 @@ async def aadapt(
top_k=kwargs.pop("top_k", 5)
if (user_temperature and not user_top_k)
else kwargs.pop("top_k", -1),
partition_key=partition_key,
)
if search_data_list is None:
search_data_list = []
Expand Down Expand Up @@ -509,6 +514,7 @@ def update_cache_func(handled_llm_data, question=None):
embedding_data,
extra_param=context.get("save_func", None),
session=session,
partition_key=partition_key,
)
if (
chat_cache.report.op_save.count > 0
Expand Down
3 changes: 2 additions & 1 deletion gptcache/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def close():
if not os.getenv("IS_CI"):
gptcache_log.error(e)

def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None) -> None:
def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None, **kwargs) -> None:
"""Import data to GPTCache
:param questions: preprocessed question Data
Expand All @@ -101,6 +101,7 @@ def import_data(self, questions: List[Any], answers: List[Any], session_ids: Opt
answers=answers,
embedding_datas=[self.embedding_func(question) for question in questions],
session_ids=session_ids if session_ids else [None for _ in range(len(questions))],
**kwargs,
)

def flush(self):
Expand Down
12 changes: 8 additions & 4 deletions gptcache/manager/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def import_data(
answers: List[Any],
embedding_datas: List[Any],
session_ids: List[Optional[str]],
**kwargs,
):
pass

Expand Down Expand Up @@ -136,6 +137,7 @@ def import_data(
answers: List[Any],
embedding_datas: List[Any],
session_ids: List[Optional[str]],
**_,
):
if (
len(questions) != len(answers)
Expand Down Expand Up @@ -272,7 +274,7 @@ def save(self, question, answer, embedding_data, **kwargs):
"""
session = kwargs.get("session", None)
session_id = session.name if session else None
self.import_data([question], [answer], [embedding_data], [session_id])
self.import_data([question], [answer], [embedding_data], [session_id], **kwargs)

def _process_answer_data(self, answers: Union[Answer, List[Answer]]):
if isinstance(answers, Answer):
Expand Down Expand Up @@ -303,6 +305,7 @@ def import_data(
answers: List[Answer],
embedding_datas: List[Any],
session_ids: List[Optional[str]],
**kwargs,
):
if (
len(questions) != len(answers)
Expand Down Expand Up @@ -333,7 +336,8 @@ def import_data(
[
VectorData(id=ids[i], data=embedding_data)
for i, embedding_data in enumerate(embedding_datas)
]
],
**kwargs,
)
self.eviction_base.put(ids)

Expand Down Expand Up @@ -368,8 +372,8 @@ def hit_cache_callback(self, res_data, **kwargs):

def search(self, embedding_data, **kwargs):
embedding_data = normalize(embedding_data)
top_k = kwargs.get("top_k", -1)
return self.v.search(data=embedding_data, top_k=top_k)
top_k = kwargs.pop("top_k", -1)
return self.v.search(data=embedding_data, top_k=top_k, **kwargs)

def flush(self):
self.s.flush()
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class VectorBase(ABC):
"""VectorBase: base vector store interface"""

@abstractmethod
def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **kwargs):
pass

@abstractmethod
def search(self, data: np.ndarray, top_k: int):
def search(self, data: np.ndarray, top_k: int, **kwargs):
pass

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def __init__(
self._persist_directory = persist_directory
self._collection = self._client.get_or_create_collection(name=collection_name)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
    """Add a batch of vectors to the Chroma collection.

    :param datas: items exposing ``id`` and ``data``; each ``data`` is
        converted with ``.tolist()`` and each ``id`` with ``str()``.
    :param _: extra keyword arguments (for example ``partition_key``) are
        accepted and ignored by this store.
    """
    # An empty batch would make the two-name unpacking of ``zip(*...)`` below
    # raise ValueError, so treat it as a no-op instead.
    if not datas:
        return
    data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
    self._collection.add(embeddings=data_array, ids=id_array)

def search(self, data, top_k: int = -1):
def search(self, data, top_k: int = -1, **_):
if self._collection.count() == 0:
return []
if top_k == -1:
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/docarray_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, index_file_path: str, top_k: int):
self._index_file_path = index_file_path
self._top_k = top_k

def mul_add(self, datas: List[VectorData]) -> None:
def mul_add(self, datas: List[VectorData], **_) -> None:
"""
Add multiple vector data elements to the index.
Expand All @@ -48,7 +48,7 @@ def mul_add(self, datas: List[VectorData]) -> None:
self._index.index(docs)

def search(
self, data: np.ndarray, top_k: int = -1
self, data: np.ndarray, top_k: int = -1, **_
) -> Optional[List[Tuple[float, int]]]:
"""
Search for the nearest vector data elements in the index.
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def __init__(self, index_file_path, dimension, top_k):
if os.path.isfile(index_file_path):
self._index = faiss.read_index(index_file_path)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
    """Add a batch of vectors, keyed by their ids, to the Faiss index.

    :param datas: items exposing ``id`` and ``data``.
    :param _: extra keyword arguments (for example ``partition_key``) are
        accepted and ignored by this store.
    """
    # An empty batch would make the two-name unpacking of ``zip(*...)`` below
    # raise ValueError, so treat it as a no-op instead.
    if not datas:
        return
    data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
    np_data = np.array(data_array).astype("float32")
    ids = np.array(id_array)
    self._index.add_with_ids(np_data, ids)

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
if self._index.ntotal == 0:
return None
if top_k == -1:
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/hnswlib_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def add(self, key: int, data: np.ndarray):
np_data = np.array(data).astype("float32").reshape(1, -1)
self._index.add_items(np_data, np.array([key]))

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
    """Add a batch of vectors, keyed by their ids, to the hnswlib index.

    :param datas: items exposing ``id`` and ``data``.
    :param _: extra keyword arguments (for example ``partition_key``) are
        accepted and ignored by this store.
    """
    # An empty batch would make the two-name unpacking of ``zip(*...)`` below
    # raise ValueError, so treat it as a no-op instead.
    if not datas:
        return
    data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
    np_data = np.array(data_array).astype("float32")
    ids = np.array(id_array)
    self._index.add_items(np_data, ids)

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
np_data = np.array(data).astype("float32").reshape(1, -1)
if top_k == -1:
top_k = self._top_k
Expand Down
18 changes: 12 additions & 6 deletions gptcache/manager/vector_data/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ def _create_collection(self, collection_name):
is_primary=True,
auto_id=False,
),
FieldSchema(
name="partition_key",
dtype=DataType.VARCHAR,
max_length=256,
is_partition_key=True,
),
FieldSchema(
name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.dimension
),
Expand Down Expand Up @@ -163,20 +169,20 @@ def _create_collection(self, collection_name):

self.col.load()

def mul_add(self, datas: List[VectorData]):
data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
np_data = np.array(data_array).astype("float32")
entities = [id_array, np_data]
self.col.insert(entities)
def mul_add(self, datas: List[VectorData], **kwargs):
    """Insert a batch of vectors into the Milvus collection.

    Each item becomes one row with the fields ``id``, ``embedding``
    (converted to a float32 array) and ``partition_key``.

    :param datas: items exposing ``id`` and ``data``.
    :param kwargs: ``partition_key`` (optional) is stored on every row of the
        batch; other keyword arguments are ignored.
    """
    # ``kwargs.get("partition_key", "")`` would still return None when the
    # caller passes ``partition_key=None`` explicitly. The collection schema
    # declares this field as VARCHAR, so always send a string.
    partition_key = kwargs.get("partition_key") or ""
    # Nothing to insert for an empty batch.
    if not datas:
        return
    rows = [
        {
            "id": data.id,
            "embedding": np.array(data.data).astype("float32"),
            "partition_key": partition_key,
        }
        for data in datas
    ]
    self.col.insert(rows)

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
    """Search the Milvus collection for the nearest vectors to ``data``.

    :param data: query vector; it is reshaped to a single row before the search.
    :param top_k: maximum number of hits; ``-1`` falls back to ``self.top_k``.
    :param kwargs: ``partition_key`` (optional) restricts the search to rows
        whose ``partition_key`` field equals it. A missing or empty key means
        no filter is applied.
    :return: list of ``(distance, id)`` tuples for the single query vector.
    """
    if top_k == -1:
        top_k = self.top_k
    partition_key = kwargs.get("partition_key")
    expr = None
    if partition_key:
        # The key is embedded in a double-quoted string literal of the filter
        # expression. Escape backslashes first, then quotes, so a key that
        # contains either cannot terminate the literal early or change the
        # meaning of the expression.
        escaped_key = str(partition_key).replace("\\", "\\\\").replace('"', '\\"')
        expr = f'partition_key=="{escaped_key}"'
    search_result = self.col.search(
        data=data.reshape(1, -1).tolist(),
        anns_field="embedding",
        param=self.search_params,
        limit=top_k,
        expr=expr,
    )
    return list(zip(search_result[0].distances, search_result[0].ids))

Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _query(self, session):
def _format_data_for_search(self, data):
return f"[{','.join(map(str, data))}]"

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
np_data = np.array(data_array).astype("float32")
entities = [{"id": id, "embedding": embedding.tolist()} for id, embedding in zip(id_array, np_data)]
Expand All @@ -135,7 +135,7 @@ def mul_add(self, datas: List[VectorData]):
session.bulk_insert_mappings(self._store, entities)
session.commit()

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
if top_k == -1:
top_k = self.top_k

Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def _create_collection(
optimizers_config=optimizers_config,
)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
    """Upsert a batch of vectors into the Qdrant collection.

    :param datas: items exposing ``id`` and ``data``; each ``data`` is
        flattened with ``reshape(-1)`` and converted to a plain list.
    :param _: extra keyword arguments (for example ``partition_key``) are
        accepted and ignored by this store.
    """
    points = []
    for item in datas:
        flat_vector = item.data.reshape(-1).tolist()
        points.append(PointStruct(id=item.id, vector=flat_vector))
    # ``wait=False`` is passed through unchanged from the original call.
    self._client.upsert(
        collection_name=self._collection_name,
        points=points,
        wait=False,
    )

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
if top_k == -1:
top_k = self.top_k
reshaped_data = data.reshape(-1).tolist()
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/redis_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _create_collection(self, collection_name):
fields=schema, definition=definition
)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
pipe = self._client.pipeline()

for data in datas:
Expand All @@ -110,7 +110,7 @@ def mul_add(self, datas: List[VectorData]):

pipe.execute()

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
query = (
Query(
f"*=>[KNN {top_k if top_k > 0 else self.top_k} @vector $vec as score]"
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/usearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def __init__(
if os.path.isfile(self._index_file_path):
self._index.load(self._index_file_path)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
    """Add a batch of vectors, keyed by their ids, to the USearch index.

    :param datas: items exposing ``id`` and ``data``.
    :param _: extra keyword arguments (for example ``partition_key``) are
        accepted and ignored by this store.
    """
    # An empty batch would make the two-name unpacking of ``zip(*...)`` below
    # raise ValueError, so treat it as a no-op instead.
    if not datas:
        return
    data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
    np_data = np.array(data_array).astype("float32")
    ids = np.array(id_array, dtype=np.longlong)
    # The index takes ids first, then the vectors.
    self._index.add(ids, np_data)

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
if top_k == -1:
top_k = self._top_k
np_data = np.array(data).astype("float32").reshape(1, -1)
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _get_default_class_schema(self) -> dict:
"vectorIndexConfig": {"distance": "cosine"},
}

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **_):
with self.client.batch(batch_size=100, dynamic=True) as batch:
for data in datas:
properties = {
Expand All @@ -97,7 +97,7 @@ def mul_add(self, datas: List[VectorData]):
data_object=properties, class_name=self.class_name, vector=data.data
)

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **_):
if top_k == -1:
top_k = self.top_k

Expand Down

0 comments on commit 9adf6ab

Please sign in to comment.