diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py
index 815fb934..fe21b6f9 100644
--- a/gptcache/manager/vector_data/manager.py
+++ b/gptcache/manager/vector_data/manager.py
@@ -76,6 +76,8 @@ class VectorBase:
     :type local_mode: bool
     :param local_data: required when local_mode is True.
     :type local_data: str
+    :param use_partition_key: if true, use partition key feature in milvus.
+    :type use_partition_key: bool
     :param url: the connection url for PostgreSQL database, defaults to
     'postgresql://postgres@localhost:5432/postgres'
     :type url: str
@@ -125,6 +127,7 @@ def get(name, **kwargs):
             search_params = kwargs.get("search_params", None)
             local_mode = kwargs.get("local_mode", False)
             local_data = kwargs.get("local_data", "./milvus_data")
+            use_partition_key = kwargs.get("use_partition_key", False)
             vector_base = Milvus(
                 host=host,
                 port=port,
@@ -138,6 +141,7 @@ def get(name, **kwargs):
                 search_params=search_params,
                 local_mode=local_mode,
                 local_data=local_data,
+                use_partition_key=use_partition_key
             )
         elif name == "faiss":
             from gptcache.manager.vector_data.faiss import Faiss
diff --git a/gptcache/manager/vector_data/milvus.py b/gptcache/manager/vector_data/milvus.py
index 9f286b3b..d69d974b 100644
--- a/gptcache/manager/vector_data/milvus.py
+++ b/gptcache/manager/vector_data/milvus.py
@@ -74,7 +74,8 @@ def __init__(
         index_params: dict = None,
         search_params: dict = None,
         local_mode: bool = False,
-        local_data: str = "./milvus_data"
+        local_data: str = "./milvus_data",
+        use_partition_key: bool = False
     ):
         if dimension <= 0:
             raise ValueError(
@@ -85,6 +86,7 @@ def __init__(
         self.dimension = dimension
         self.top_k = top_k
         self.index_params = index_params
+        self.use_partition_key = use_partition_key
         if self._local_mode:
             self._create_local(port, local_data)
         self._connect(host, port, user, password, secure)
@@ -131,16 +133,19 @@ def _create_collection(self, collection_name):
                     is_primary=True,
                     auto_id=False,
                 ),
-                FieldSchema(
-                    name="partition_key",
-                    dtype=DataType.VARCHAR,
-                    max_length=256,
-                    is_partition_key=True,
-                ),
                 FieldSchema(
                     name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.dimension
                 ),
             ]
+            if self.use_partition_key:
+                schema.append(
+                    FieldSchema(
+                        name="partition_key",
+                        dtype=DataType.VARCHAR,
+                        max_length=256,
+                        is_partition_key=True,
+                    )
+                )
             schema = CollectionSchema(schema)
             self.col = Collection(
                 collection_name,
@@ -170,8 +175,11 @@ def _create_collection(self, collection_name):
         self.col.load()
 
     def mul_add(self, datas: List[VectorData], **kwargs):
-        partition_key = kwargs.get("partition_key", "")
-        self.col.insert([{"id": data.id, "embedding": np.array(data.data).astype("float32"), "partition_key": partition_key} for data in datas])
+        if self.use_partition_key:
+            partition_key = kwargs.get("partition_key") or "default"
+            self.col.insert([{"id": data.id, "embedding": np.array(data.data).astype("float32"), "partition_key": partition_key} for data in datas])
+        else:
+            self.col.insert([{"id": data.id, "embedding": np.array(data.data).astype("float32")} for data in datas])
 
     def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
         if top_k == -1:
@@ -182,7 +190,7 @@ def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
             anns_field="embedding",
             param=self.search_params,
             limit=top_k,
-            expr=f'partition_key=="{partition_key}"' if partition_key else None,
+            expr=f'partition_key=="{partition_key}"' if (self.use_partition_key and partition_key) else None,
         )
         return list(zip(search_result[0].distances, search_result[0].ids))
 