From 0a70b58935a9214d958294d9b4d65f6c3abcaf97 Mon Sep 17 00:00:00 2001
From: MrPresent-Han <116052805+MrPresent-Han@users.noreply.github.com>
Date: Tue, 22 Aug 2023 16:02:16 +0800
Subject: [PATCH] refine iterator params to differentiate limit for clients
 (#26358 #26397) (#1651)

Signed-off-by: MrPresent-Han
---
 examples/iterator.py       |  55 +++++++++++++++----
 pymilvus/client/prepare.py |   9 +---
 pymilvus/orm/collection.py |  11 +++-
 pymilvus/orm/constants.py  |   5 +-
 pymilvus/orm/iterator.py   | 108 +++++++++++++++++++++++++++++------
 5 files changed, 152 insertions(+), 36 deletions(-)

diff --git a/examples/iterator.py b/examples/iterator.py
index f763faf33..9edd82a97 100644
--- a/examples/iterator.py
+++ b/examples/iterator.py
@@ -21,7 +21,6 @@
 DIM = 8
 CLEAR_EXIST = False
 
-
 def re_create_collection():
     if utility.has_collection(COLLECTION_NAME) and CLEAR_EXIST:
         utility.drop_collection(COLLECTION_NAME)
@@ -77,8 +76,7 @@ def prepare_data(collection):
 def query_iterate_collection_no_offset(collection):
     expr = f"10 <= {AGE} <= 14"
     query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
-                                               offset=0, limit=5, consistency_level=CONSISTENCY_LEVEL,
-                                               iteration_extension_reduce_rate=10)
+                                               offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL)
     page_idx = 0
     while True:
         res = query_iterator.next()
@@ -94,8 +92,23 @@ def query_iterate_collection_with_offset(collection):
     expr = f"10 <= {AGE} <= 14"
     query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
-                                               offset=10, limit=5, consistency_level=CONSISTENCY_LEVEL,
-                                               iteration_extension_reduce_rate=10)
+                                               offset=10, batch_size=50, consistency_level=CONSISTENCY_LEVEL)
+    page_idx = 0
+    while True:
+        res = query_iterator.next()
+        if len(res) == 0:
+            print("query iteration finished, close")
+            query_iterator.close()
+            break
+        for i in range(len(res)):
+            print(res[i])
+        page_idx += 1
+        print(f"page{page_idx}-------------------------")
+
+def query_iterate_collection_with_limit(collection):
+    expr = f"10 <= {AGE} <= 44"
+    query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
+                                               batch_size=80, limit=530, consistency_level=CONSISTENCY_LEVEL)
     page_idx = 0
     while True:
         res = query_iterator.next()
@@ -117,20 +130,42 @@ def search_iterator_collection(collection):
         "metric_type": "L2",
         "params": {"nprobe": 10, "radius": 1.0},
     }
-    search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, limit=5,
+    search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, batch_size=500,
                                                  output_fields=[USER_ID])
     page_idx = 0
     while True:
         res = search_iterator.next()
-        if len(res[0]) == 0:
+        if len(res) == 0:
             print("query iteration finished, close")
             search_iterator.close()
             break
-        for i in range(len(res[0])):
-            print(res[0][i])
+        for i in range(len(res)):
+            print(res[i])
         page_idx += 1
         print(f"page{page_idx}-------------------------")
 
+def search_iterator_collection_with_limit(collection):
+    SEARCH_NQ = 1
+    DIM = 8
+    rng = np.random.default_rng(seed=19530)
+    vectors_to_search = rng.random((SEARCH_NQ, DIM))
+    search_params = {
+        "metric_type": "L2",
+        "params": {"nprobe": 10, "radius": 1.0},
+    }
+    search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, batch_size=200, limit=755,
+                                                 output_fields=[USER_ID])
+    page_idx = 0
+    while True:
+        res = search_iterator.next()
+        if len(res) == 0:
+            print("search iteration finished, close")
+            search_iterator.close()
+            break
+        for i in range(len(res)):
+            print(res[i])
+        page_idx += 1
+        print(f"page{page_idx}-------------------------")
 
 def main():
     connections.connect("default", host=HOST, port=PORT)
@@ -138,7 +173,9 @@ def main():
     collection = prepare_data(collection)
     query_iterate_collection_no_offset(collection)
     query_iterate_collection_with_offset(collection)
+    query_iterate_collection_with_limit(collection)
     search_iterator_collection(collection)
+    search_iterator_collection_with_limit(collection)
 
 
 if __name__ == '__main__':
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index da7596518..f6257cfe5 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -13,7 +13,7 @@
 
 from . import blob, entity_helper, ts_utils
 from .check import check_pass_param, is_legal_collection_properties
-from .constants import DEFAULT_CONSISTENCY_LEVEL, ITERATION_EXTENSION_REDUCE_RATE
+from .constants import DEFAULT_CONSISTENCY_LEVEL
 from .types import DataType, PlaceholderType, get_consistency_level
 from .utils import traverse_info, traverse_rows_info
@@ -835,13 +835,6 @@ def query_request(
         req.query_params.append(
             common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing))
         )
-
-        use_iteration_extension_reduce_rate = kwargs.get(ITERATION_EXTENSION_REDUCE_RATE, 0)
-        req.query_params.append(
-            common_types.KeyValuePair(
-                key=ITERATION_EXTENSION_REDUCE_RATE, value=str(use_iteration_extension_reduce_rate)
-            )
-        )
         return req
 
     @classmethod
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py
index 975c66213..6d1d357f4 100644
--- a/pymilvus/orm/collection.py
+++ b/pymilvus/orm/collection.py
@@ -36,6 +36,7 @@
 from pymilvus.settings import Config
 
 from .connections import connections
+from .constants import UNLIMITED
 from .future import MutationFuture, SearchFuture
 from .index import Index
 from .iterator import QueryIterator, SearchIterator
@@ -798,7 +799,8 @@ def search_iterator(
         data: List,
         anns_field: str,
         param: Dict,
-        limit: int,
+        batch_size: Optional[int] = 1000,
+        limit: Optional[int] = UNLIMITED,
         expr: Optional[str] = None,
         partition_names: Optional[List[str]] = None,
         output_fields: Optional[List[str]] = None,
@@ -814,6 +816,7 @@
             data=data,
             ann_field=anns_field,
             param=param,
+            batch_size=batch_size,
             limit=limit,
             expr=expr,
             partition_names=partition_names,
@@ -919,15 +922,21 @@ def query(
 
     def query_iterator(
         self,
+        batch_size: Optional[int] = 1000,
+        limit: Optional[int] = UNLIMITED,
         expr: Optional[str] = None,
         output_fields: Optional[List[str]] = None,
         partition_names: Optional[List[str]] = None,
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        if expr is not None and not isinstance(expr, str):
+            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
         return QueryIterator(
             connection=self._get_connection(),
             collection_name=self._name,
+            batch_size=batch_size,
+            limit=limit,
             expr=expr,
             output_fields=output_fields,
             partition_names=partition_names,
diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py
index c59505e92..8d87f39d0 100644
--- a/pymilvus/orm/constants.py
+++ b/pymilvus/orm/constants.py
@@ -26,7 +26,8 @@
 CALC_DIST_DIM = "dim"
 
 OFFSET = "offset"
-LIMIT = "limit"
+MILVUS_LIMIT = "limit"
+BATCH_SIZE = "batch_size"
 ID = "id"
 METRIC_TYPE = "metric_type"
 PARAMS = "params"
@@ -43,3 +44,5 @@
 DEFAULT_MIN_COSINE_DISTANCE = -2.0
 MAX_FILTERED_IDS_COUNT_ITERATION = 100000
 INT64_MAX = 9223372036854775807
+MAX_BATCH_SIZE: int = 16384
+UNLIMITED: int = -1
diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py
index 20a930c1f..3d91c9b42 100644
--- a/pymilvus/orm/iterator.py
+++ b/pymilvus/orm/iterator.py
@@ -1,10 +1,15 @@
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, TypeVar
 
-from pymilvus.exceptions import MilvusException
+from pymilvus.client.abstract import ChunkedQueryResult, LoopBase
+from pymilvus.exceptions import (
+    MilvusException,
+    ParamError,
+)
 
 from .connections import Connections
 from .constants import (
+    BATCH_SIZE,
     CALC_DIST_COSINE,
     CALC_DIST_HAMMING,
     CALC_DIST_IP,
@@ -20,13 +25,15 @@
     FIELDS,
     INT64_MAX,
     ITERATION_EXTENSION_REDUCE_RATE,
-    LIMIT,
+    MAX_BATCH_SIZE,
     MAX_FILTERED_IDS_COUNT_ITERATION,
     METRIC_TYPE,
+    MILVUS_LIMIT,
     OFFSET,
     PARAMS,
     RADIUS,
     RANGE_FILTER,
+    UNLIMITED,
 )
 from .schema import CollectionSchema
 from .types import DataType
@@ -40,6 +47,8 @@ def __init__(
         self,
         connection: Connections,
         collection_name: str,
+        batch_size: Optional[int] = 1000,
+        limit: Optional[int] = -1,
         expr: Optional[str] = None,
         output_fields: Optional[List[str]] = None,
         partition_names: Optional[List[str]] = None,
@@ -54,11 +63,22 @@ def __init__(
         self._schema = schema
         self._timeout = timeout
         self._kwargs = kwargs
+        self.__check_set_batch_size(batch_size)
+        self._limit = limit
+        self._returned_count = 0
         self.__setup__pk_prop()
         self.__set_up_expr(expr)
         self.__seek()
         self._cache_id_in_use = NO_CACHE_ID
 
+    def __check_set_batch_size(self, batch_size: int):
+        if batch_size < 0:
+            raise ParamError(message="batch size cannot be less than zero")
+        if batch_size > MAX_BATCH_SIZE:
+            raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
+        self._kwargs[BATCH_SIZE] = batch_size
+        self._kwargs[MILVUS_LIMIT] = batch_size
+
     # rely on pk prop, so this method should be called after __setup__pk_prop
     def __set_up_expr(self, expr: str):
         if expr is not None:
@@ -77,7 +97,7 @@ def __seek(self):
         first_cursor_kwargs = self._kwargs.copy()
         first_cursor_kwargs[OFFSET] = 0
         # offset may be too large, needed to seek in multiple times
-        first_cursor_kwargs[LIMIT] = self._kwargs[OFFSET]
+        first_cursor_kwargs[MILVUS_LIMIT] = self._kwargs[OFFSET]
         first_cursor_kwargs[ITERATION_EXTENSION_REDUCE_RATE] = 0
 
         res = self._conn.query(
@@ -92,22 +112,22 @@ def __seek(self):
         self._kwargs[OFFSET] = 0
 
     def __maybe_cache(self, result: List):
-        if len(result) < 2 * self._kwargs[LIMIT]:
+        if len(result) < 2 * self._kwargs[BATCH_SIZE]:
             return
-        start = self._kwargs[LIMIT]
+        start = self._kwargs[BATCH_SIZE]
         cache_result = result[start:]
         cache_id = iterator_cache.cache(cache_result, NO_CACHE_ID)
         self._cache_id_in_use = cache_id
 
     def __is_res_sufficient(self, res: List):
-        return res is not None and len(res) >= self._kwargs[LIMIT]
+        return res is not None and len(res) >= self._kwargs[BATCH_SIZE]
 
     def next(self):
         cached_res = iterator_cache.fetch_cache(self._cache_id_in_use)
         ret = None
         if self.__is_res_sufficient(cached_res):
-            ret = cached_res[0 : self._kwargs[LIMIT]]
-            res_to_cache = cached_res[self._kwargs[LIMIT] :]
+            ret = cached_res[0 : self._kwargs[BATCH_SIZE]]
+            res_to_cache = cached_res[self._kwargs[BATCH_SIZE] :]
             iterator_cache.cache(res_to_cache, self._cache_id_in_use)
         else:
             iterator_cache.release_cache(self._cache_id_in_use)
@@ -121,10 +141,22 @@ def next(self):
                 **self._kwargs,
             )
             self.__maybe_cache(res)
-            ret = res[0 : min(self._kwargs[LIMIT], len(res))]
+            ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))]
+
+        ret = self.check_reached_limit(ret)
         self.__update_cursor(ret)
+        self._returned_count += len(ret)
         return ret
 
+    def check_reached_limit(self, ret: List):
+        if self._limit == UNLIMITED:
+            return ret
+        left_count = self._limit - self._returned_count
+        if left_count >= len(ret):
+            return ret
+        # has exceeded the limit, cut off the result and return
+        return ret[0:left_count]
+
     def __setup__pk_prop(self):
         fields = self._schema[FIELDS]
         for field in fields:
@@ -177,6 +209,28 @@ def default_radius(metrics: str):
     raise MilvusException(message="unknown metrics type for search iteration")
 
 
+class SearchPage(LoopBase):
+    """Since we only support nq=1 in search iteration, the search iterator response
+    differs from the raw response of a search operation"""
+
+    def __init__(self, res: List):
+        super().__init__()
+        self._res = res
+
+    def get_res(self):
+        return self._res
+
+    def __len__(self):
+        if self._res is not None:
+            return len(self._res)
+        return 0
+
+    def get__item(self, idx: Any):
+        if self._res is None:
+            return None
+        return self._res[idx]
+
+
 class SearchIterator:
     def __init__(
         self,
@@ -185,7 +239,8 @@ def __init__(
         data: List,
         ann_field: str,
         param: Dict,
-        limit: int,
+        batch_size: Optional[int] = 1000,
+        limit: Optional[int] = UNLIMITED,
         expr: Optional[str] = None,
         partition_names: Optional[List[str]] = None,
         output_fields: Optional[List[str]] = None,
@@ -195,13 +250,17 @@ def __init__(
         **kwargs,
     ) -> SearchIterator:
         if len(data) > 1:
-            raise MilvusException(message="Not support multiple vector iterator at present")
+            raise ParamError(
+                message="Not support search iteration over multiple vectors at present"
+            )
+        if len(data) == 0:
+            raise ParamError(message="vector_data for search cannot be empty")
         self._conn = connection
         self._iterator_params = {
             "collection_name": collection_name,
             "data": data,
             "ann_field": ann_field,
-            "limit": limit,
+            BATCH_SIZE: batch_size,
             "output_fields": output_fields,
             "partition_names": partition_names,
             "timeout": timeout,
@@ -214,11 +273,24 @@ def __init__(
         self._filtered_ids = []
         self._filtered_distance = None
         self._schema = schema
+        self._limit = limit
+        self._returned_count = 0
         self.__check_metrics()
         self.__check_radius()
         self.__seek()
         self.__setup__pk_prop()
 
+    def check_reached_limit(self, ret: ChunkedQueryResult):
+        if self._limit == UNLIMITED:
+            return SearchPage(ret[0])
+        left_count = self._limit - self._returned_count
+        if left_count >= len(ret[0]):
+            return SearchPage(ret[0])
+        # has exceeded the limit, cut off the result and return
+        left_ret_arr = [] if left_count == 0 else ret[0][0:left_count]
+        return SearchPage(left_ret_arr)
+
     def __check_set_params(self, param: Dict):
         if param is None:
             self._param = {}
@@ -252,17 +324,17 @@ def __seek(self):
         if self._kwargs.get(OFFSET, 0) != 0:
             raise MilvusException(message="Not support offset when searching iteration")
 
-    def __update_cursor(self, res: Any):
-        if len(res[0]) == 0:
+    def __update_cursor(self, res: SearchPage):
+        if len(res) == 0:
             return
-        last_hit = res[0][-1]
+        last_hit = res[-1]
         if last_hit is None:
             return
         self._distance_cursor[0] = last_hit.distance
         if self._distance_cursor[0] != self._filtered_distance:
             self._filtered_ids = []  # distance has changed, clear filter_ids array
             self._filtered_distance = self._distance_cursor[0]  # renew the distance for filtering
-        for hit in res[0]:
+        for hit in res:
             if hit.distance == last_hit.distance:
                 self._filtered_ids.append(hit.id)
         if len(self._filtered_ids) > MAX_FILTERED_IDS_COUNT_ITERATION:
@@ -280,7 +352,7 @@ def next(self):
             self._iterator_params["data"],
             self._iterator_params["ann_field"],
             next_params,
-            self._iterator_params["limit"],
+            self._iterator_params[BATCH_SIZE],
             next_expr,
             self._iterator_params["partition_names"],
             self._iterator_params["output_fields"],
@@ -289,7 +361,9 @@ def next(self):
             schema=self._schema,
             **self._kwargs,
         )
+        res = self.check_reached_limit(res)
         self.__update_cursor(res)
+        self._returned_count += len(res)
         return res
 
     # at present, the range_filter parameter means 'larger/less and equal',
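
After this change the two knobs are independent: batch_size sets how many entities each underlying request fetches per page (validated against MAX_BATCH_SIZE, 16384), while limit caps the total number of entities the iterator will ever return (UNLIMITED, i.e. -1, by default). Below is a minimal usage sketch of the new query_iterator signature; the connection settings, collection name, and field names are assumptions borrowed from examples/iterator.py and may differ in your setup:

    from pymilvus import Collection, connections

    connections.connect("default", host="localhost", port="19530")
    collection = Collection("test_iterator")  # assumed: collection created by the example above

    # batch_size: page size per fetch; limit: total cap across all pages
    iterator = collection.query_iterator(
        expr="10 <= age <= 44",        # assumed field name
        output_fields=["id", "age"],   # assumed field names
        batch_size=100,
        limit=250,
    )
    total = 0
    while True:
        page = iterator.next()
        if len(page) == 0:
            iterator.close()
            break
        total += len(page)
    print(total)  # never exceeds limit=250, even though more entities may match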