diff --git a/gptcache/manager/vector_data/lancedb.py b/gptcache/manager/vector_data/lancedb.py index b16e6f91..eb12e5d0 100644 --- a/gptcache/manager/vector_data/lancedb.py +++ b/gptcache/manager/vector_data/lancedb.py @@ -42,8 +42,7 @@ def __init__( def mul_add(self, datas: List[VectorData]): """Add multiple vectors to the LanceDB table""" - vectors, ids = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) - + vectors, vector_ids = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) # Infer the dimension of the vectors vector_dim = len(vectors[0]) if vectors else 0 @@ -56,7 +55,7 @@ def mul_add(self, datas: List[VectorData]): self._table = self._db.create_table(self._table_name, schema=schema) # Prepare data for insertion - data = [{"id": id, "vector": vector} for id, vector in zip(ids, vectors)] + data = [{"id": vector_id, "vector": vector} for vector_id, vector in zip(vector_ids, vectors)] self._table.add(data) def search(self, data: np.ndarray, top_k: int = -1): @@ -72,21 +71,13 @@ def search(self, data: np.ndarray, top_k: int = -1): def delete(self, ids: List[int]): """Delete vectors from the LanceDB table based on IDs""" - for id in ids: - self._table.delete(f"id = '{id}'") + for vector_id in ids: + self._table.delete(f"id = '{vector_id}'") def rebuild(self, ids: Optional[List[int]] = None): """Rebuild the index, if applicable""" return True - def flush(self): - """Flush changes to disk (if necessary)""" - pass - - def close(self): - """Close the connection to LanceDB""" - pass - def count(self): """Return the total number of vectors in the table""" - return len(self._table) \ No newline at end of file + return len(self._table) diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index ef943f81..2314654d 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -273,7 +273,7 @@ def get(name, **kwargs): from gptcache.manager.vector_data.weaviate import Weaviate url = kwargs.get("url", None) - auth_client_secret = kwargs.get("auth_client_secrets", None) + auth_client_secret = kwargs.get("auth_client_secret", None) timeout_config = kwargs.get("timeout_config", WEAVIATE_TIMEOUT_CONFIG) proxies = kwargs.get("proxies", None) trust_env = kwargs.get("trust_env", False)