diff --git a/.github/workflows/code_checker.yml b/.github/workflows/code_checker.yml index 962648e6c..9a681c0e3 100644 --- a/.github/workflows/code_checker.yml +++ b/.github/workflows/code_checker.yml @@ -19,6 +19,7 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + cache: pip - name: Install requirements run: | pip install -r requirements.txt @@ -26,3 +27,5 @@ jobs: shell: bash run: | make lint + - name: Run pre-commit + uses: pre-commit/action@v3.0.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..2e4bdcd9b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ +exclude: "^docs/|/grpc_gen/|/.git/|/.tox/|/media/|/.pytest_cache/|^.idea/" +default_stages: [commit] + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-json + - id: check-xml + - id: check-executables-have-shebangs + - id: check-toml + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: check-symlinks + - id: detect-aws-credentials + args: [ '--allow-missing-credentials' ] + - id: detect-private-key + exclude: ^examples|(?:tests/ssl)/ + + - repo: https://github.com/asottile/pyupgrade + rev: v2.34.0 + hooks: + - id: pyupgrade + args: ['--py37-plus', '--keep-mock'] + + - repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black + + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + +# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date +ci: + autoupdate_schedule: weekly + skip: [] + submodules: false diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py index 99ee25810..c6135a0b3 100644 --- a/pymilvus/__init__.py +++ b/pymilvus/__init__.py @@ -10,64 +10,107 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -from .client.stub import Milvus +from .client import __version__ from .client.prepare import Prepare -from .client.types import Status, DataType, RangeType, IndexType, Replica, Group, Shard, BulkInsertState +from .client.stub import Milvus +from .client.types import ( + BulkInsertState, + DataType, + Group, + IndexType, + RangeType, + Replica, + Shard, + Status, +) from .exceptions import ( - ParamError, + ExceptionsMessage, MilvusException, MilvusUnavailableException, - ExceptionsMessage + ParamError, ) -from .client import __version__ - -from .settings import DEBUG_LOG_LEVEL, INFO_LOG_LEVEL, WARN_LOG_LEVEL, ERROR_LOG_LEVEL - +from .orm import utility from .orm.collection import Collection -from .orm.connections import connections, Connections - +from .orm.connections import Connections, connections +from .orm.default_config import ENV_CONNECTION_CONF, DefaultConfig +from .orm.future import MutationFuture, SearchFuture from .orm.index import Index from .orm.partition import Partition +from .orm.role import Role +from .orm.schema import CollectionSchema, FieldSchema +from .orm.search import Hit, Hits, SearchResult from .orm.utility import ( - loading_progress, - index_building_progress, - wait_for_loading_complete, - wait_for_index_building_complete, + create_user, + delete_user, + do_bulk_insert, + drop_collection, + get_bulk_insert_state, + get_query_segment_info, has_collection, has_partition, + hybridts_to_datetime, + hybridts_to_unixtime, + index_building_progress, + list_bulk_insert_tasks, list_collections, - drop_collection, - get_query_segment_info, + list_usernames, load_balance, - mkts_from_hybridts, mkts_from_unixtime, mkts_from_datetime, - hybridts_to_unixtime, hybridts_to_datetime, - do_bulk_insert, get_bulk_insert_state, list_bulk_insert_tasks, - reset_password, create_user, update_password, delete_user, list_usernames, + loading_progress, + mkts_from_datetime, + mkts_from_hybridts, + mkts_from_unixtime, + reset_password, + update_password, + wait_for_index_building_complete, + wait_for_loading_complete, ) +from .settings import DEBUG_LOG_LEVEL, ERROR_LOG_LEVEL, INFO_LOG_LEVEL, WARN_LOG_LEVEL -from .orm import utility -from .orm.default_config import DefaultConfig, ENV_CONNECTION_CONF - -from .orm.search import SearchResult, Hits, Hit -from .orm.schema import FieldSchema, CollectionSchema -from .orm.future import SearchFuture, MutationFuture -from .orm.role import Role __all__ = [ - 'Collection', 'Index', 'Partition', - 'connections', - 'loading_progress', 'index_building_progress', 'wait_for_loading_complete', 'has_collection', 'has_partition', - 'list_collections', 'wait_for_loading_complete', 'wait_for_index_building_complete', 'drop_collection', - 'mkts_from_hybridts', 'mkts_from_unixtime', 'mkts_from_datetime', - 'hybridts_to_unixtime', 'hybridts_to_datetime', - 'reset_password', 'create_user', 'update_password', 'delete_user', 'list_usernames', - 'SearchResult', 'Hits', 'Hit', 'Replica', 'Group', 'Shard', - 'FieldSchema', 'CollectionSchema', - 'SearchFuture', 'MutationFuture', - 'utility', 'DefaultConfig', 'ExceptionsMessage', 'MilvusUnavailableException', 'BulkInsertState', - 'Role', - - 'Milvus', 'Prepare', 'Status', 'DataType', - 'MilvusException', - '__version__' + "Collection", + "Index", + "Partition", + "connections", + "loading_progress", + "index_building_progress", + "wait_for_loading_complete", + "has_collection", + "has_partition", + "list_collections", + "wait_for_loading_complete", + "wait_for_index_building_complete", + "drop_collection", + "mkts_from_hybridts", + "mkts_from_unixtime", + "mkts_from_datetime", + "hybridts_to_unixtime", + "hybridts_to_datetime", + "reset_password", + "create_user", + "update_password", + "delete_user", + "list_usernames", + "SearchResult", + "Hits", + "Hit", + "Replica", + "Group", + "Shard", + "FieldSchema", + "CollectionSchema", + "SearchFuture", + "MutationFuture", + "utility", + "DefaultConfig", + "ExceptionsMessage", + "MilvusUnavailableException", + "BulkInsertState", + "Role", + "Milvus", + "Prepare", + "Status", + "DataType", + "MilvusException", + "__version__", ] diff --git a/pymilvus/client/__init__.py b/pymilvus/client/__init__.py index c7405010d..cf907dd71 100644 --- a/pymilvus/client/__init__.py +++ b/pymilvus/client/__init__.py @@ -1,20 +1,22 @@ -import subprocess import re -from pkg_resources import get_distribution, DistributionNotFound +import subprocess + +from pkg_resources import DistributionNotFound, get_distribution + -__version__ = '0.0.0.dev' +__version__ = "0.0.0.dev" try: - __version__ = get_distribution('pymilvus').version + __version__ = get_distribution("pymilvus").version except DistributionNotFound: # package is not installed pass def get_commit(version="", short=True) -> str: - """get commit return the commit for a specific version like `xxxxxx.dev12` """ + """get commit return the commit for a specific version like `xxxxxx.dev12`""" - version_info = r'((\d+)\.(\d+)\.(\d+))((rc)(\d+))?(\.dev(\d+))?' + version_info = r"((\d+)\.(\d+)\.(\d+))((rc)(\d+))?(\.dev(\d+))?" # 2.0.0rc9.dev12 # ('2.0.0', '2', '0', '0', 'rc9', 'rc', '9', '.dev12', '12') p = re.compile(version_info) @@ -27,23 +29,31 @@ def get_commit(version="", short=True) -> str: if match_version[7] is not None: if match_version[4] is not None: v = str(int(match_version[6]) - 1) - target_tag = 'v' + match_version[0] + match_version[5] + v + target_tag = "v" + match_version[0] + match_version[5] + v else: - target_tag = 'v' + ".".join(str(int("".join(match_version[1:4])) - 1).split("")) + target_tag = "v" + ".".join( + str(int("".join(match_version[1:4])) - 1).split("") + ) target_num = int(match_version[-1]) elif match_version[4] is not None: - target_tag = 'v' + match_version[0] + match_version[4] + target_tag = "v" + match_version[0] + match_version[4] target_num = 0 else: - target_tag = 'v' + match_version[0] + target_tag = "v" + match_version[0] target_num = 0 else: return f"Version: {target_v} isn't the right form" try: - cmd = ['git', 'rev-list', '--reverse', '--ancestry-path', f'{target_tag}^..HEAD'] + cmd = [ + "git", + "rev-list", + "--reverse", + "--ancestry-path", + f"{target_tag}^..HEAD", + ] print(f"git cmd: {' '.join(cmd)}") - result = subprocess.check_output(cmd).decode('ascii').strip().split('\n') + result = subprocess.check_output(cmd).decode("ascii").strip().split("\n") length = 7 if short else 40 return result[target_num][:length] diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index b83f9cdfb..e0a63edcc 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -1,11 +1,12 @@ import abc import numpy as np + +from ..exceptions import MilvusException +from ..grpc_gen import schema_pb2 from .configs import DefaultConfigs -from .types import DataType from .constants import DEFAULT_CONSISTENCY_LEVEL -from ..grpc_gen import schema_pb2 -from ..exceptions import MilvusException +from .types import DataType class LoopBase: @@ -84,6 +85,7 @@ def __pack(self, raw): for type_param in raw.type_params: if type_param.key == "params": import json + self.params[type_param.key] = json.loads(type_param.value) else: self.params[type_param.key] = type_param.value @@ -95,6 +97,7 @@ def __pack(self, raw): for index_param in raw.index_params: if index_param.key == "params": import json + index_dict[index_param.key] = json.loads(index_param.value) else: index_dict[index_param.key] = index_param.value @@ -187,7 +190,7 @@ def __init__(self, entity_id, entity_row_data, entity_score): self._distance = entity_score def __str__(self): - return f'id: {self._id}, distance: {self._distance}, entity: {self._row_data}' + return f"id: {self._id}, distance: {self._distance}, entity: {self._row_data}" def __getattr__(self, item): return self.value_of_field(item) @@ -210,7 +213,7 @@ def value_of_field(self, field): return self._row_data[field] def type_of_field(self, field): - raise NotImplementedError('TODO: support field in Hits') + raise NotImplementedError("TODO: support field in Hits") class Hit: @@ -269,22 +272,34 @@ def get__item(self, item): for field_data in self._raw.fields_data: if field_data.type == DataType.BOOL: if len(field_data.scalars.bool_data.data) >= item: - entity_row_data[field_data.field_name] = field_data.scalars.bool_data.data[item] + entity_row_data[ + field_data.field_name + ] = field_data.scalars.bool_data.data[item] elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): if len(field_data.scalars.int_data.data) >= item: - entity_row_data[field_data.field_name] = field_data.scalars.int_data.data[item] + entity_row_data[ + field_data.field_name + ] = field_data.scalars.int_data.data[item] elif field_data.type == DataType.INT64: if len(field_data.scalars.long_data.data) >= item: - entity_row_data[field_data.field_name] = field_data.scalars.long_data.data[item] + entity_row_data[ + field_data.field_name + ] = field_data.scalars.long_data.data[item] elif field_data.type == DataType.FLOAT: if len(field_data.scalars.float_data.data) >= item: - entity_row_data[field_data.field_name] = np.single(field_data.scalars.float_data.data[item]) + entity_row_data[field_data.field_name] = np.single( + field_data.scalars.float_data.data[item] + ) elif field_data.type == DataType.DOUBLE: if len(field_data.scalars.double_data.data) >= item: - entity_row_data[field_data.field_name] = field_data.scalars.double_data.data[item] + entity_row_data[ + field_data.field_name + ] = field_data.scalars.double_data.data[item] elif field_data.type == DataType.VARCHAR: if len(field_data.scalars.string_data.data) >= item: - entity_row_data[field_data.field_name] = field_data.scalars.string_data.data[item] + entity_row_data[ + field_data.field_name + ] = field_data.scalars.string_data.data[item] elif field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") # result[field_data.field_name] = field_data.scalars.string_data.data[index] @@ -293,16 +308,20 @@ def get__item(self, item): if len(field_data.vectors.float_vector.data) >= item * dim: start_pos = item * dim end_pos = item * dim + dim - entity_row_data[field_data.field_name] = [np.single(x) for x in - field_data.vectors.float_vector.data[ - start_pos:end_pos]] + entity_row_data[field_data.field_name] = [ + np.single(x) + for x in field_data.vectors.float_vector.data[ + start_pos:end_pos + ] + ] elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim if len(field_data.vectors.binary_vector.data) >= item * (dim / 8): start_pos = item * (dim / 8) end_pos = (item + 1) * (dim / 8) entity_row_data[field_data.field_name] = [ - field_data.vectors.binary_vector.data[start_pos:end_pos]] + field_data.vectors.binary_vector.data[start_pos:end_pos] + ] entity_score = self._distances[item] return Hit(entity_id, entity_row_data, entity_score) @@ -369,8 +388,10 @@ def err_index(self): return self._err_index def __str__(self): - return f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, " \ - f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})" + return ( + f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, " + f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})" + ) __repr__ = __str__ @@ -416,27 +437,39 @@ def _pack(self, raw): hit = schema_pb2.SearchResultData() start_pos = offset end_pos = offset + raw.results.topks[i] - hit.scores.append(raw.results.scores[start_pos: end_pos]) + hit.scores.append(raw.results.scores[start_pos:end_pos]) if raw.results.ids.HasField("int_id"): - hit.ids.append(raw.results.ids.int_id.data[start_pos: end_pos]) + hit.ids.append(raw.results.ids.int_id.data[start_pos:end_pos]) elif raw.results.ids.HasField("str_id"): - hit.ids.append(raw.results.ids.str_id.data[start_pos: end_pos]) + hit.ids.append(raw.results.ids.str_id.data[start_pos:end_pos]) for field_data in raw.result.fields_data: field = schema_pb2.FieldData() field.type = field_data.type field.field_name = field_data.field_name if field_data.type == DataType.BOOL: - field.scalars.bool_data.data.extend(field_data.scalars.bool_data.data[start_pos: end_pos]) + field.scalars.bool_data.data.extend( + field_data.scalars.bool_data.data[start_pos:end_pos] + ) elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - field.scalars.int_data.data.extend(field_data.scalars.int_data.data[start_pos: end_pos]) + field.scalars.int_data.data.extend( + field_data.scalars.int_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.INT64: - field.scalars.long_data.data.extend(field_data.scalars.long_data.data[start_pos: end_pos]) + field.scalars.long_data.data.extend( + field_data.scalars.long_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.FLOAT: - field.scalars.float_data.data.extend(field_data.scalars.float_data.data[start_pos: end_pos]) + field.scalars.float_data.data.extend( + field_data.scalars.float_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.DOUBLE: - field.scalars.double_data.data.extend(field_data.scalars.double_data.data[start_pos: end_pos]) + field.scalars.double_data.data.extend( + field_data.scalars.double_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.VARCHAR: - field.scalars.string_data.data.extend(field_data.scalars.string_data.data[start_pos: end_pos]) + field.scalars.string_data.data.extend( + field_data.scalars.string_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") # result[field_data.field_name] = field_data.scalars.string_data.data[index] @@ -444,12 +477,18 @@ def _pack(self, raw): dim = field.vectors.dim field.vectors.dim = dim field.vectors.float_vector.data.extend( - field_data.vectors.float_data.data[start_pos * dim: end_pos * dim]) + field_data.vectors.float_data.data[ + start_pos * dim : end_pos * dim + ] + ) elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim field.vectors.dim = dim - field.vectors.binary_vector.data.extend(field_data.vectors.binary_vector.data[ - start_pos * (dim / 8): end_pos * (dim / 8)]) + field.vectors.binary_vector.data.extend( + field_data.vectors.binary_vector.data[ + start_pos * (dim / 8) : end_pos * (dim / 8) + ] + ) hit.fields_data.append(field) self._hits.append(hit) offset += raw.results.topks[i] @@ -482,40 +521,66 @@ def _pack(self, raw_list): hit = schema_pb2.SearchResultData() start_pos = offset end_pos = offset + raw.results.topks[i] - hit.scores.extend(raw.results.scores[start_pos: end_pos]) + hit.scores.extend(raw.results.scores[start_pos:end_pos]) if raw.results.ids.HasField("int_id"): - hit.ids.int_id.data.extend(raw.results.ids.int_id.data[start_pos: end_pos]) + hit.ids.int_id.data.extend( + raw.results.ids.int_id.data[start_pos:end_pos] + ) elif raw.results.ids.HasField("str_id"): - hit.ids.str_id.data.extend(raw.results.ids.str_id.data[start_pos: end_pos]) + hit.ids.str_id.data.extend( + raw.results.ids.str_id.data[start_pos:end_pos] + ) for field_data in raw.results.fields_data: field = schema_pb2.FieldData() field.type = field_data.type field.field_name = field_data.field_name if field_data.type == DataType.BOOL: - field.scalars.bool_data.data.extend(field_data.scalars.bool_data.data[start_pos: end_pos]) - elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - field.scalars.int_data.data.extend(field_data.scalars.int_data.data[start_pos: end_pos]) + field.scalars.bool_data.data.extend( + field_data.scalars.bool_data.data[start_pos:end_pos] + ) + elif field_data.type in ( + DataType.INT8, + DataType.INT16, + DataType.INT32, + ): + field.scalars.int_data.data.extend( + field_data.scalars.int_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.INT64: - field.scalars.long_data.data.extend(field_data.scalars.long_data.data[start_pos: end_pos]) + field.scalars.long_data.data.extend( + field_data.scalars.long_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.FLOAT: - field.scalars.float_data.data.extend(field_data.scalars.float_data.data[start_pos: end_pos]) + field.scalars.float_data.data.extend( + field_data.scalars.float_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.DOUBLE: - field.scalars.double_data.data.extend(field_data.scalars.double_data.data[start_pos: end_pos]) + field.scalars.double_data.data.extend( + field_data.scalars.double_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.VARCHAR: - field.scalars.string_data.data.extend(field_data.scalars.string_data.data[start_pos: end_pos]) + field.scalars.string_data.data.extend( + field_data.scalars.string_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") # result[field_data.field_name] = field_data.scalars.string_data.data[index] elif field_data.type == DataType.FLOAT_VECTOR: dim = field_data.vectors.dim field.vectors.dim = dim - field.vectors.float_vector.data.extend(field_data.vectors.float_vector.data[ - start_pos * dim: end_pos * dim]) + field.vectors.float_vector.data.extend( + field_data.vectors.float_vector.data[ + start_pos * dim : end_pos * dim + ] + ) elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim field.vectors.dim = dim - field.vectors.binary_vector.data.extend(field_data.vectors.binary_vector.data[ - start_pos * (dim / 8): end_pos * (dim / 8)]) + field.vectors.binary_vector.data.extend( + field_data.vectors.binary_vector.data[ + start_pos * (dim / 8) : end_pos * (dim / 8) + ] + ) hit.fields_data.append(field) self._hits.append(hit) offset += raw.results.topks[i] @@ -525,7 +590,7 @@ def get__item(self, item): def _abstract(): - raise NotImplementedError('You need to override this function') + raise NotImplementedError("You need to override this function") class ConnectIntf: @@ -646,7 +711,9 @@ def add_vectors(self, table_name, records, ids, timeout, **kwargs): """ _abstract() - def search_vectors(self, table_name, top_k, nprobe, query_records, query_ranges, **kwargs): + def search_vectors( + self, table_name, top_k, nprobe, query_records, query_ranges, **kwargs + ): """ Query vectors in a table Should be implemented @@ -670,8 +737,9 @@ def search_vectors(self, table_name, top_k, nprobe, query_records, query_ranges, """ _abstract() - def search_vectors_in_files(self, table_name, file_ids, query_records, - top_k, nprobe, query_ranges, **kwargs): + def search_vectors_in_files( + self, table_name, file_ids, query_records, top_k, nprobe, query_ranges, **kwargs + ): """ Query vectors in a table, query vector in specified files Should be implemented diff --git a/pymilvus/client/asynch.py b/pymilvus/client/asynch.py index 5d74fc100..a16816bae 100644 --- a/pymilvus/client/asynch.py +++ b/pymilvus/client/asynch.py @@ -1,14 +1,15 @@ import abc import threading -from .abstract import QueryResult, ChunkedQueryResult, MutationResult from ..exceptions import MilvusException +from .abstract import ChunkedQueryResult, MutationResult, QueryResult from .types import Status # TODO: remove this to a common util def _parameter_is_empty(func): import inspect + sig = inspect.signature(func) # params = sig.parameters # todo: add more check to parameter, such as `default parameter`, @@ -26,29 +27,29 @@ def _parameter_is_empty(func): class AbstractFuture: @abc.abstractmethod def result(self, **kwargs): - '''Return deserialized result. + """Return deserialized result. It's a synchronous interface. It will wait executing until server respond or timeout occur(if specified). This API is thread-safe. - ''' + """ raise NotImplementedError() @abc.abstractmethod def cancel(self): - '''Cancle gRPC future. + """Cancle gRPC future. This API is thread-safe. - ''' + """ raise NotImplementedError() @abc.abstractmethod def done(self): - '''Wait for request done. + """Wait for request done. This API is thread-safe. - ''' + """ raise NotImplementedError() @@ -75,8 +76,7 @@ def __del__(self): @abc.abstractmethod def on_response(self, response): - ''' Parse response from gRPC server and return results. - ''' + """Parse response from gRPC server and return results.""" raise NotImplementedError() def _callback(self, **kwargs): @@ -171,7 +171,9 @@ def on_response(self, response): # TODO: if ChunkedFuture is more common later, consider using ChunkedFuture as Base Class, # then Future(future, done_cb, pre_exception) equal to ChunkedFuture([future], done_cb, pre_exception) class ChunkedSearchFuture(Future): - def __init__(self, future_list, done_callback=None, auto_id=True, pre_exception=None): + def __init__( + self, future_list, done_callback=None, auto_id=True, pre_exception=None + ): super().__init__(None, done_callback, pre_exception) self._auto_id = auto_id self._future_list = future_list @@ -198,7 +200,9 @@ def result(self, **kwargs): self.exception() if kwargs.get("raw", False) is True: # just return response object received from gRPC - raise AttributeError("Not supported to return response object received from gRPC") + raise AttributeError( + "Not supported to return response object received from gRPC" + ) if self._results: return self._results diff --git a/pymilvus/client/blob.py b/pymilvus/client/blob.py index f82001e83..742f69030 100644 --- a/pymilvus/client/blob.py +++ b/pymilvus/client/blob.py @@ -1,38 +1,50 @@ import struct + # reference: https://docs.python.org/3/library/struct.html#struct.pack + def boolToBytes(b): return struct.pack("?", b) + def int8ToBytes(i): return struct.pack("b", i) + def int16ToBytes(i): return struct.pack("h", i) + def int32ToBytes(i): return struct.pack("i", i) + def int64ToBytes(i): return struct.pack("q", i) + def floatToBytes(f): return struct.pack("f", f) + def doubleToBytes(d): return struct.pack("d", d) + def stringToBytes(s): - return bytes(s, encoding='utf8') + return bytes(s, encoding="utf8") + def vectorBinaryToBytes(v): return bytes(v) + def vectorFloatToBytes(v): # pack len(v) number of float - bs = struct.pack(f'{len(v)}f', *v) + bs = struct.pack(f"{len(v)}f", *v) return bs + def bytesToInt64(v): return struct.unpack("q", v)[0] diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index 959318c49..0ca366b2d 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -1,14 +1,11 @@ -import sys import datetime +import sys from typing import Any, Union + from ..exceptions import ParamError from ..grpc_gen import milvus_pb2 as milvus_types from .singleton_utils import Singleton -from .utils import ( - valid_index_types, - valid_binary_index_types, - valid_index_params_keys, -) +from .utils import valid_binary_index_types, valid_index_params_keys, valid_index_types def is_legal_address(addr: Any) -> bool: @@ -47,9 +44,7 @@ def is_legal_port(port: Any) -> bool: def is_legal_vector(array: Any) -> bool: - if not array or \ - not isinstance(array, list) or \ - len(array) == 0: + if not array or not isinstance(array, list) or len(array) == 0: return False # for v in array: @@ -60,9 +55,7 @@ def is_legal_vector(array: Any) -> bool: def is_legal_bin_vector(array: Any) -> bool: - if not array or \ - not isinstance(array, bytes) or \ - len(array) == 0: + if not array or not isinstance(array, bytes) or len(array) == 0: return False return True @@ -81,7 +74,7 @@ def int_or_str(item: Union[int, str]) -> str: def is_correct_date_str(param: str) -> bool: try: - datetime.datetime.strptime(param, '%Y-%m-%d') + datetime.datetime.strptime(param, "%Y-%m-%d") except ValueError: return False @@ -140,16 +133,18 @@ def is_legal_cmd(cmd: Any) -> bool: def parser_range_date(date: Union[str, datetime.date]) -> str: if isinstance(date, datetime.date): - return date.strftime('%Y-%m-%d') + return date.strftime("%Y-%m-%d") if isinstance(date, str): if not is_correct_date_str(date): - raise ParamError(message='Date string should be YY-MM-DD format!') + raise ParamError(message="Date string should be YY-MM-DD format!") return date - raise ParamError(message='Date should be YY-MM-DD format string or datetime.date, ' - 'or datetime.datetime object') + raise ParamError( + message="Date should be YY-MM-DD format string or datetime.date, " + "or datetime.datetime object" + ) def is_legal_date_range(start: str, end: str) -> bool: @@ -175,6 +170,7 @@ def is_legal_anns_field(field: Any) -> bool: def is_legal_search_data(data: Any) -> bool: import numpy as np + if not isinstance(data, (list, np.ndarray)): return False @@ -217,19 +213,21 @@ def is_legal_partition_name_array(tag_array: Any) -> bool: # https://milvus.io/cn/docs/v1.0.0/metric.md#floating def is_legal_index_metric_type(index_type: str, metric_type: str) -> bool: - if index_type not in ("FLAT", - "IVF_FLAT", - "IVF_SQ8", - # "IVF_SQ8_HYBRID", - "IVF_PQ", - "HNSW", - # "NSG", - "ANNOY", - "RHNSW_FLAT", - "RHNSW_PQ", - "RHNSW_SQ", - "AUTOINDEX", - "DISKANN"): + if index_type not in ( + "FLAT", + "IVF_FLAT", + "IVF_SQ8", + # "IVF_SQ8_HYBRID", + "IVF_PQ", + "HNSW", + # "NSG", + "ANNOY", + "RHNSW_FLAT", + "RHNSW_PQ", + "RHNSW_SQ", + "AUTOINDEX", + "DISKANN", + ): return False if metric_type not in ("L2", "IP"): return False @@ -239,7 +237,13 @@ def is_legal_index_metric_type(index_type: str, metric_type: str) -> bool: # https://milvus.io/cn/docs/v1.0.0/metric.md#binary def is_legal_binary_index_metric_type(index_type: str, metric_type: str) -> bool: if index_type == "BIN_FLAT": - if metric_type in ("JACCARD", "TANIMOTO", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"): + if metric_type in ( + "JACCARD", + "TANIMOTO", + "HAMMING", + "SUBSTRUCTURE", + "SUPERSTRUCTURE", + ): return True elif index_type == "BIN_IVF_FLAT": if metric_type in ("JACCARD", "TANIMOTO", "HAMMING"): @@ -276,8 +280,10 @@ def is_legal_role_name(role_name: Any) -> bool: def is_legal_operate_user_role_type(operate_user_role_type: Any) -> bool: - return operate_user_role_type in \ - (milvus_types.OperateUserRoleType.AddUserToRole, milvus_types.OperateUserRoleType.RemoveUserFromRole) + return operate_user_role_type in ( + milvus_types.OperateUserRoleType.AddUserToRole, + milvus_types.OperateUserRoleType.RemoveUserFromRole, + ) def is_legal_include_user_info(include_user_info: Any) -> bool: @@ -305,8 +311,10 @@ def is_legal_collection_properties(properties: Any) -> bool: def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool: - return operate_privilege_type in \ - (milvus_types.OperatePrivilegeType.Grant, milvus_types.OperatePrivilegeType.Revoke) + return operate_privilege_type in ( + milvus_types.OperatePrivilegeType.Grant, + milvus_types.OperatePrivilegeType.Revoke, + ) class ParamChecker(metaclass=Singleton): @@ -350,10 +358,14 @@ def check(self, key, value): else: raise ParamError(message=f"unknown param `{key}`") + def _get_param_checker(): return ParamChecker() -def check_pass_param(*_args: Any, **kwargs: Any) -> None: # pylint: disable=too-many-statements + +def check_pass_param( + *_args: Any, **kwargs: Any +) -> None: # pylint: disable=too-many-statements if kwargs is None: raise ParamError(message="Param should not be None") checker = _get_param_checker() @@ -366,26 +378,36 @@ def check_index_params(params): if not isinstance(params, dict): raise ParamError(message="Params must be a dictionary type") # params preliminary validate - if 'index_type' not in params: + if "index_type" not in params: raise ParamError(message="Params must contains key: 'index_type'") - if 'params' not in params: + if "params" not in params: raise ParamError(message="Params must contains key: 'params'") - if 'metric_type' not in params: + if "metric_type" not in params: raise ParamError(message="Params must contains key: 'metric_type'") - if not isinstance(params['params'], dict): + if not isinstance(params["params"], dict): raise ParamError(message="Params['params'] must be a dictionary type") - if params['index_type'] not in valid_index_types: - raise ParamError(message=f"Invalid index_type: {params['index_type']}, which must be one of: {str(valid_index_types)}") - for k in params['params'].keys(): + if params["index_type"] not in valid_index_types: + raise ParamError( + message=f"Invalid index_type: {params['index_type']}, which must be one of: {str(valid_index_types)}" + ) + for k in params["params"].keys(): if k not in valid_index_params_keys: raise ParamError(message=f"Invalid params['params'].key: {k}") - for v in params['params'].values(): + for v in params["params"].values(): if not isinstance(v, int): - raise ParamError(message=f"Invalid params['params'].value: {v}, which must be an integer") + raise ParamError( + message=f"Invalid params['params'].value: {v}, which must be an integer" + ) # filter invalid metric type - if params['index_type'] in valid_binary_index_types: - if not is_legal_binary_index_metric_type(params['index_type'], params['metric_type']): - raise ParamError(message=f"Invalid metric_type: {params['metric_type']}, which does not match the index type: {params['index_type']}") + if params["index_type"] in valid_binary_index_types: + if not is_legal_binary_index_metric_type( + params["index_type"], params["metric_type"] + ): + raise ParamError( + message=f"Invalid metric_type: {params['metric_type']}, which does not match the index type: {params['index_type']}" + ) else: - if not is_legal_index_metric_type(params['index_type'], params['metric_type']): - raise ParamError(message=f"Invalid metric_type: {params['metric_type']}, which does not match the index type: {params['index_type']}") + if not is_legal_index_metric_type(params["index_type"], params["metric_type"]): + raise ParamError( + message=f"Invalid metric_type: {params['metric_type']}, which does not match the index type: {params['index_type']}" + ) diff --git a/pymilvus/client/configs.py b/pymilvus/client/configs.py index 1c97fcfb3..24952e213 100644 --- a/pymilvus/client/configs.py +++ b/pymilvus/client/configs.py @@ -3,5 +3,5 @@ class DefaultConfigs: WaitTimeDurationWhenLoad = 0.5 # in seconds MaxVarCharLengthKey = "max_length" MaxVarCharLength = 65535 - EncodeProtocol = 'utf-8' + EncodeProtocol = "utf-8" IndexName = "" diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index f7a4f0c52..6a4a8c450 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -1,5 +1,6 @@ from ..grpc_gen.common_pb2 import ConsistencyLevel + LOGICAL_BITS = 18 LOGICAL_BITS_MASK = (1 << LOGICAL_BITS) - 1 EVENTUALLY_TS = 1 diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index bc713b2b6..b801675af 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -1,7 +1,7 @@ -from ..grpc_gen import schema_pb2 as schema_types -from .types import DataType from ..exceptions import ParamError +from ..grpc_gen import schema_pb2 as schema_types from .configs import DefaultConfigs +from .types import DataType def entity_type_to_dtype(entity_type): @@ -24,12 +24,14 @@ def check_str_arr(str_arr, max_len): if not isinstance(s, str): raise ParamError(message=f"expect string input, got: {type(s)}") if len(s) >= max_len: - raise ParamError(message=f"invalid input, length of string exceeds max length. length: {len(s)}, max length: {max_len}") + raise ParamError( + message=f"invalid input, length of string exceeds max length. length: {len(s)}, max length: {max_len}" + ) def entity_to_str_arr(entity, field_info, check=True): arr = [] - if DefaultConfigs.EncodeProtocol.lower() != 'utf-8'.lower(): + if DefaultConfigs.EncodeProtocol.lower() != "utf-8".lower(): for s in entity.get("values"): arr.append(s.encode(DefaultConfigs.EncodeProtocol)) else: @@ -68,9 +70,11 @@ def entity_to_field_data(entity, field_info): field_data.vectors.float_vector.data.extend(all_floats) elif entity_type in (DataType.BINARY_VECTOR,): field_data.vectors.dim = len(entity.get("values")[0]) * 8 - field_data.vectors.binary_vector = b''.join(entity.get("values")) + field_data.vectors.binary_vector = b"".join(entity.get("values")) elif entity_type in (DataType.VARCHAR,): - field_data.scalars.string_data.data.extend(entity_to_str_arr(entity, field_info, True)) + field_data.scalars.string_data.data.extend( + entity_to_str_arr(entity, field_info, True) + ) else: raise ParamError(message=f"UnSupported data type: {entity_type}") diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 0671c2653..641de7ae6 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -1,67 +1,56 @@ -import time -import json -import copy import base64 +import copy +import json +import time from urllib import parse import grpc import numpy as np from grpc._cython import cygrpc -from ..grpc_gen import milvus_pb2_grpc -from ..grpc_gen import milvus_pb2 as milvus_types -from ..grpc_gen import common_pb2 - -from .abstract import CollectionSchema, ChunkedQueryResult, MutationResult -from .check import ( - is_legal_host, - is_legal_port, - check_pass_param, - check_index_params, -) -from .prepare import Prepare -from .types import ( - Status, - IndexState, - DataType, - CompactionState, - State, - CompactionPlans, - Plan, - get_consistency_level, - Replica, Shard, Group, - GrantInfo, UserInfo, RoleInfo, - BulkInsertState, -) - -from .utils import ( - check_invalid_binary_vector, - len_of +from ..decorators import retry_on_rpc_failure +from ..exceptions import ( + AmbiguousIndexName, + DescribeCollectionException, + ExceptionsMessage, + MilvusException, + ParamError, ) - +from ..grpc_gen import common_pb2 +from ..grpc_gen import milvus_pb2 as milvus_types +from ..grpc_gen import milvus_pb2_grpc from ..settings import DefaultConfig as config -from .configs import DefaultConfigs -from . import ts_utils -from . import interceptor - +from . import interceptor, ts_utils +from .abstract import ChunkedQueryResult, CollectionSchema, MutationResult from .asynch import ( - SearchFuture, - MutationFuture, + ChunkedSearchFuture, CreateIndexFuture, FlushFuture, LoadPartitionsFuture, - ChunkedSearchFuture + MutationFuture, + SearchFuture, ) - -from ..exceptions import ( - ExceptionsMessage, - ParamError, - DescribeCollectionException, - MilvusException, - AmbiguousIndexName, +from .check import check_index_params, check_pass_param, is_legal_host, is_legal_port +from .configs import DefaultConfigs +from .prepare import Prepare +from .types import ( + BulkInsertState, + CompactionPlans, + CompactionState, + DataType, + GrantInfo, + Group, + IndexState, + Plan, + Replica, + RoleInfo, + Shard, + State, + Status, + UserInfo, + get_consistency_level, ) - -from ..decorators import retry_on_rpc_failure +from .utils import check_invalid_binary_vector, len_of class GrpcHandler: @@ -70,7 +59,9 @@ def __init__(self, uri=config.GRPC_URI, host="", port="", channel=None, **kwargs self._channel = channel addr = kwargs.get("address") - self._address = addr if addr is not None else self.__get_address(uri, host, port) + self._address = ( + addr if addr is not None else self.__get_address(uri, host, port) + ) self._log_level = None self._request_id = None self._set_authorization(**kwargs) @@ -98,7 +89,9 @@ def _set_authorization(self, **kwargs): self._server_name = kwargs.get("server_name", "") self._authorization_interceptor = None - self._setup_authorization_interceptor(kwargs.get("user", None), kwargs.get("password", None)) + self._setup_authorization_interceptor( + kwargs.get("user", None), kwargs.get("password", None) + ) def __enter__(self): return self @@ -112,68 +105,99 @@ def _wait_for_channel_ready(self, timeout=3): grpc.channel_ready_future(self._channel).result(timeout=timeout) return except grpc.FutureTimeoutError as e: - raise MilvusException(Status.CONNECT_FAILED, - f'Fail connecting to server on {self._address}. Timeout') from e - - raise MilvusException(Status.CONNECT_FAILED, 'No channel in handler, please setup grpc channel first') + raise MilvusException( + Status.CONNECT_FAILED, + f"Fail connecting to server on {self._address}. Timeout", + ) from e + + raise MilvusException( + Status.CONNECT_FAILED, + "No channel in handler, please setup grpc channel first", + ) def close(self): self._channel.close() def _setup_authorization_interceptor(self, user, password): if user and password: - authorization = base64.b64encode(f"{user}:{password}".encode('utf-8')) + authorization = base64.b64encode(f"{user}:{password}".encode()) key = "authorization" - self._authorization_interceptor = interceptor.header_adder_interceptor(key, authorization) + self._authorization_interceptor = interceptor.header_adder_interceptor( + key, authorization + ) def _setup_grpc_channel(self): - """ Create a ddl grpc channel """ + """Create a ddl grpc channel""" if self._channel is None: - opts = [(cygrpc.ChannelArgKey.max_send_message_length, -1), - (cygrpc.ChannelArgKey.max_receive_message_length, -1), - ('grpc.enable_retries', 1), - ('grpc.keepalive_time_ms', 55000), - ] + opts = [ + (cygrpc.ChannelArgKey.max_send_message_length, -1), + (cygrpc.ChannelArgKey.max_receive_message_length, -1), + ("grpc.enable_retries", 1), + ("grpc.keepalive_time_ms", 55000), + ] if not self._secure: self._channel = grpc.insecure_channel( self._address, options=opts, ) else: - if self._client_pem_path != "" and self._client_key_path != "" and self._ca_pem_path != "" \ - and self._server_name != "": - opts.append(('grpc.ssl_target_name_override', self._server_name, ),) - with open(self._client_pem_path, 'rb') as f: + if ( + self._client_pem_path != "" + and self._client_key_path != "" + and self._ca_pem_path != "" + and self._server_name != "" + ): + opts.append( + ( + "grpc.ssl_target_name_override", + self._server_name, + ), + ) + with open(self._client_pem_path, "rb") as f: certificate_chain = f.read() - with open(self._client_key_path, 'rb') as f: + with open(self._client_key_path, "rb") as f: private_key = f.read() - with open(self._ca_pem_path, 'rb') as f: + with open(self._ca_pem_path, "rb") as f: root_certificates = f.read() - creds = grpc.ssl_channel_credentials(root_certificates, private_key, certificate_chain) + creds = grpc.ssl_channel_credentials( + root_certificates, private_key, certificate_chain + ) elif self._server_pem_path != "" and self._server_name != "": - opts.append(('grpc.ssl_target_name_override', self._server_name,), ) - with open(self._server_pem_path, 'rb') as f: + opts.append( + ( + "grpc.ssl_target_name_override", + self._server_name, + ), + ) + with open(self._server_pem_path, "rb") as f: server_pem = f.read() creds = grpc.ssl_channel_credentials(root_certificates=server_pem) else: - creds = grpc.ssl_channel_credentials(root_certificates=None, private_key=None, - certificate_chain=None) - self._channel = grpc.secure_channel( - self._address, - creds, - options=opts - ) + creds = grpc.ssl_channel_credentials( + root_certificates=None, private_key=None, certificate_chain=None + ) + self._channel = grpc.secure_channel(self._address, creds, options=opts) # avoid to add duplicate headers. self._final_channel = self._channel if self._authorization_interceptor: - self._final_channel = grpc.intercept_channel(self._final_channel, self._authorization_interceptor) + self._final_channel = grpc.intercept_channel( + self._final_channel, self._authorization_interceptor + ) if self._log_level: - log_level_interceptor = interceptor.header_adder_interceptor("log_level", self._log_level) - self._final_channel = grpc.intercept_channel(self._final_channel, log_level_interceptor) + log_level_interceptor = interceptor.header_adder_interceptor( + "log_level", self._log_level + ) + self._final_channel = grpc.intercept_channel( + self._final_channel, log_level_interceptor + ) self._log_level = None if self._request_id: - request_id_interceptor = interceptor.header_adder_interceptor("client_request_id", self._request_id) - self._final_channel = grpc.intercept_channel(self._final_channel, request_id_interceptor) + request_id_interceptor = interceptor.header_adder_interceptor( + "client_request_id", self._request_id + ) + self._final_channel = grpc.intercept_channel( + self._final_channel, request_id_interceptor + ) self._request_id = None self._stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) @@ -187,7 +211,7 @@ def set_onetime_request_id(self, req_id): @property def server_address(self): - """ Server network address """ + """Server network address""" return self._address def reset_password(self, user, old_password, new_password, timeout=None): @@ -199,8 +223,12 @@ def reset_password(self, user, old_password, new_password, timeout=None): self._setup_grpc_channel() @retry_on_rpc_failure() - def create_collection(self, collection_name, fields, shards_num=2, timeout=None, **kwargs): - request = Prepare.create_collection_request(collection_name, fields, shards_num=shards_num, **kwargs) + def create_collection( + self, collection_name, fields, shards_num=2, timeout=None, **kwargs + ): + request = Prepare.create_collection_request( + collection_name, fields, shards_num=shards_num, **kwargs + ) rf = self._stub.CreateCollection.future(request, timeout=timeout) if kwargs.get("_async", False): @@ -239,7 +267,10 @@ def has_collection(self, collection_name, timeout=None, **kwargs): return True # TODO: Workaround for unreasonable describe collection results and error_code - if reply.status.error_code == common_pb2.UnexpectedError and "can\'t find collection" in reply.status.reason: + if ( + reply.status.error_code == common_pb2.UnexpectedError + and "can't find collection" in reply.status.reason + ): return False raise MilvusException(reply.status.error_code, reply.status.reason) @@ -329,7 +360,9 @@ def list_partitions(self, collection_name, timeout=None): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def get_partition_stats(self, collection_name, partition_name, timeout=None, **kwargs): + def get_partition_stats( + self, collection_name, partition_name, timeout=None, **kwargs + ): check_pass_param(collection_name=collection_name) req = Prepare.get_partition_stats_request(collection_name, partition_name) future = self._stub.GetPartitionStatistics.future(req, timeout=timeout) @@ -340,8 +373,10 @@ def get_partition_stats(self, collection_name, partition_name, timeout=None, **k raise MilvusException(status.error_code, status.reason) - def _prepare_batch_insert_request(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): - insert_param = kwargs.get('insert_param', None) + def _prepare_batch_insert_request( + self, collection_name, entities, partition_name=None, timeout=None, **kwargs + ): + insert_param = kwargs.get("insert_param", None) if insert_param and not isinstance(insert_param, milvus_types.RowBatch): raise ParamError(message="The value of key 'insert_param' is invalid") @@ -350,22 +385,33 @@ def _prepare_batch_insert_request(self, collection_name, entities, partition_nam collection_schema = kwargs.get("schema", None) if not collection_schema: - collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs) + collection_schema = self.describe_collection( + collection_name, timeout=timeout, **kwargs + ) fields_info = collection_schema["fields"] - request = insert_param if insert_param \ - else Prepare.batch_insert_param(collection_name, entities, partition_name, fields_info) + request = ( + insert_param + if insert_param + else Prepare.batch_insert_param( + collection_name, entities, partition_name, fields_info + ) + ) return request @retry_on_rpc_failure() - def batch_insert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + def batch_insert( + self, collection_name, entities, partition_name=None, timeout=None, **kwargs + ): if not check_invalid_binary_vector(entities): raise ParamError(message="Invalid binary vector data exists") try: - request = self._prepare_batch_insert_request(collection_name, entities, partition_name, timeout, **kwargs) + request = self._prepare_batch_insert_request( + collection_name, entities, partition_name, timeout, **kwargs + ) rf = self._stub.Insert.future(request, timeout=timeout) if kwargs.get("_async", False) is True: cb = kwargs.get("_callback", None) @@ -386,7 +432,9 @@ def batch_insert(self, collection_name, entities, partition_name=None, timeout=N raise err @retry_on_rpc_failure() - def delete(self, collection_name, expression, partition_name=None, timeout=None, **kwargs): + def delete( + self, collection_name, expression, partition_name=None, timeout=None, **kwargs + ): check_pass_param(collection_name=collection_name) try: req = Prepare.delete_request(collection_name, partition_name, expression) @@ -427,7 +475,9 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs): response = self._stub.Search(request, timeout=timeout) if response.status.error_code != 0: - raise MilvusException(response.status.error_code, response.status.reason) + raise MilvusException( + response.status.error_code, response.status.reason + ) raws.append(response) round_decimal = kwargs.get("round_decimal", -1) @@ -439,9 +489,21 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs): raise pre_err @retry_on_rpc_failure(retry_on_deadline=False) - def search(self, collection_name, data, anns_field, param, limit, - expression=None, partition_names=None, output_fields=None, - round_decimal=-1, timeout=None, schema=None, **kwargs): + def search( + self, + collection_name, + data, + anns_field, + param, + limit, + expression=None, + partition_names=None, + output_fields=None, + round_decimal=-1, + timeout=None, + schema=None, + **kwargs, + ): check_pass_param( limit=limit, round_decimal=round_decimal, @@ -450,24 +512,40 @@ def search(self, collection_name, data, anns_field, param, limit, partition_name_array=partition_names, output_fields=output_fields, travel_timestamp=kwargs.get("travel_timestamp", 0), - guarantee_timestamp=kwargs.get("guarantee_timestamp", 0) + guarantee_timestamp=kwargs.get("guarantee_timestamp", 0), ) if schema is None: - schema = self.describe_collection(collection_name, timeout=timeout, **kwargs) + schema = self.describe_collection( + collection_name, timeout=timeout, **kwargs + ) consistency_level = schema["consistency_level"] # overwrite the consistency level defined when user created the collection - consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level)) + consistency_level = get_consistency_level( + kwargs.get("consistency_level", consistency_level) + ) ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs) - requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, schema, - expression, partition_names, output_fields, round_decimal, - **kwargs) + requests = Prepare.search_requests_with_expr( + collection_name, + data, + anns_field, + param, + limit, + schema, + expression, + partition_names, + output_fields, + round_decimal, + **kwargs, + ) auto_id = schema["auto_id"] - return self._execute_search_requests(requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs) + return self._execute_search_requests( + requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs + ) @retry_on_rpc_failure() def get_query_segment_info(self, collection_name, timeout=30, **kwargs): @@ -511,34 +589,47 @@ def create_index(self, collection_name, field_name, params, timeout=None, **kwar index_name = kwargs.pop("index_name", DefaultConfigs.IndexName) copy_kwargs = copy.deepcopy(kwargs) - collection_desc = self.describe_collection(collection_name, timeout=timeout, **copy_kwargs) + collection_desc = self.describe_collection( + collection_name, timeout=timeout, **copy_kwargs + ) valid_field = False for fields in collection_desc["fields"]: if field_name != fields["name"]: continue valid_field = True - if fields["type"] != DataType.FLOAT_VECTOR and fields["type"] != DataType.BINARY_VECTOR: + if ( + fields["type"] != DataType.FLOAT_VECTOR + and fields["type"] != DataType.BINARY_VECTOR + ): break # check index params on vector field. check_index_params(params) if not valid_field: - raise MilvusException(message=f"cannot create index on non-existed field: {field_name}") + raise MilvusException( + message=f"cannot create index on non-existed field: {field_name}" + ) # sync flush _async = kwargs.get("_async", False) kwargs["_async"] = False - index_param = Prepare.create_index_request(collection_name, field_name, params, index_name=index_name) + index_param = Prepare.create_index_request( + collection_name, field_name, params, index_name=index_name + ) future = self._stub.CreateIndex.future(index_param, timeout=timeout) if _async: + def _check(): if kwargs.get("sync", True): - index_success, fail_reason = self.wait_for_creating_index(collection_name=collection_name, - index_name=index_name, - timeout=timeout, field_name=field_name) + index_success, fail_reason = self.wait_for_creating_index( + collection_name=collection_name, + index_name=index_name, + timeout=timeout, + field_name=field_name, + ) if not index_success: raise MilvusException(message=fail_reason) @@ -555,9 +646,12 @@ def _check(): raise MilvusException(status.error_code, status.reason) if kwargs.get("sync", True): - index_success, fail_reason = self.wait_for_creating_index(collection_name=collection_name, - index_name=index_name, - timeout=timeout, field_name=field_name) + index_success, fail_reason = self.wait_for_creating_index( + collection_name=collection_name, + index_name=index_name, + timeout=timeout, + field_name=field_name, + ) if not index_success: raise MilvusException(message=fail_reason) @@ -590,9 +684,11 @@ def describe_index(self, collection_name, index_name, timeout=None, **kwargs): if status.error_code != 0: raise MilvusException(status.error_code, status.reason) if len(response.index_descriptions) == 1: - info_dict = {kv.key: kv.value for kv in response.index_descriptions[0].params} - info_dict['field_name'] = response.index_descriptions[0].field_name - info_dict['index_name'] = response.index_descriptions[0].index_name + info_dict = { + kv.key: kv.value for kv in response.index_descriptions[0].params + } + info_dict["field_name"] = response.index_descriptions[0].field_name + info_dict["index_name"] = response.index_descriptions[0].index_name if info_dict.get("params", None): info_dict["params"] = json.loads(info_dict["params"]) return info_dict @@ -608,12 +704,17 @@ def get_index_build_progress(self, collection_name, index_name, timeout=None): if status.error_code == 0: if len(response.index_descriptions) == 1: index_desc = response.index_descriptions[0] - return {'total_rows': index_desc.total_rows, 'indexed_rows': index_desc.indexed_rows} + return { + "total_rows": index_desc.total_rows, + "indexed_rows": index_desc.indexed_rows, + } raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName) raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def get_index_state(self, collection_name: str, index_name: str, timeout=None, **kwargs): + def get_index_state( + self, collection_name: str, index_name: str, timeout=None, **kwargs + ): request = Prepare.describe_index_request(collection_name, index_name) rf = self._stub.DescribeIndex.future(request, timeout=timeout) response = rf.result() @@ -634,21 +735,29 @@ def get_index_state(self, collection_name: str, index_name: str, timeout=None, * raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName) @retry_on_rpc_failure() - def wait_for_creating_index(self, collection_name, index_name, timeout=None, **kwargs): + def wait_for_creating_index( + self, collection_name, index_name, timeout=None, **kwargs + ): start = time.time() while True: time.sleep(0.5) - state, fail_reason = self.get_index_state(collection_name, index_name, timeout=timeout, **kwargs) + state, fail_reason = self.get_index_state( + collection_name, index_name, timeout=timeout, **kwargs + ) if state == IndexState.Finished: return True, fail_reason if state == IndexState.Failed: return False, fail_reason end = time.time() if isinstance(timeout, int) and end - start > timeout: - raise MilvusException(message=f"collection {collection_name} create index {index_name} timeout in {timeout}s") + raise MilvusException( + message=f"collection {collection_name} create index {index_name} timeout in {timeout}s" + ) @retry_on_rpc_failure() - def load_collection(self, collection_name, replica_number=1, timeout=None, **kwargs): + def load_collection( + self, collection_name, replica_number=1, timeout=None, **kwargs + ): check_pass_param(collection_name=collection_name) request = Prepare.load_collection("", collection_name, replica_number) rf = self._stub.LoadCollection.future(request, timeout=timeout) @@ -661,7 +770,7 @@ def load_collection(self, collection_name, replica_number=1, timeout=None, **kwa @retry_on_rpc_failure() def load_collection_progress(self, collection_name, timeout=None): - """ Return loading progress of collection """ + """Return loading progress of collection""" progress = self.get_loading_progress(collection_name, timeout=timeout) return { "loading_progress": f"{progress:.0f}%", @@ -679,7 +788,9 @@ def can_loop(t) -> bool: if progress >= 100: return time.sleep(DefaultConfigs.WaitTimeDurationWhenLoad) - raise MilvusException(message=f"wait for loading collection timeout, collection: {collection_name}") + raise MilvusException( + message=f"wait for loading collection timeout, collection: {collection_name}" + ) @retry_on_rpc_failure() def release_collection(self, collection_name, timeout=None): @@ -691,12 +802,19 @@ def release_collection(self, collection_name, timeout=None): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def load_partitions(self, collection_name, partition_names, replica_number=1, timeout=None, **kwargs): - check_pass_param(collection_name=collection_name, partition_name_array=partition_names) - request = Prepare.load_partitions("", collection_name, partition_names, replica_number) + def load_partitions( + self, collection_name, partition_names, replica_number=1, timeout=None, **kwargs + ): + check_pass_param( + collection_name=collection_name, partition_name_array=partition_names + ) + request = Prepare.load_partitions( + "", collection_name, partition_names, replica_number + ) future = self._stub.LoadPartitions.future(request, timeout=timeout) if kwargs.get("_async", False): + def _check(): if kwargs.get("sync", True): self.wait_for_loading_partitions(collection_name, partition_names) @@ -718,30 +836,38 @@ def _check(): self.wait_for_loading_partitions(collection_name, partition_names) @retry_on_rpc_failure() - def wait_for_loading_partitions(self, collection_name, partition_names, timeout=None): + def wait_for_loading_partitions( + self, collection_name, partition_names, timeout=None + ): start = time.time() def can_loop(t) -> bool: return True if timeout is None else t <= (start + timeout) while can_loop(time.time()): - progress = self.get_loading_progress(collection_name, partition_names, timeout=timeout) + progress = self.get_loading_progress( + collection_name, partition_names, timeout=timeout + ) if progress >= 100: return time.sleep(DefaultConfigs.WaitTimeDurationWhenLoad) - raise MilvusException(message=f"wait for loading partition timeout, collection: {collection_name}, partitions: {partition_names}") + raise MilvusException( + message=f"wait for loading partition timeout, collection: {collection_name}, partitions: {partition_names}" + ) @retry_on_rpc_failure() def get_loading_progress(self, collection_name, partition_names=None, timeout=None): request = Prepare.get_loading_progress(collection_name, partition_names) - response = self._stub.GetLoadingProgress.future(request, timeout=timeout).result() + response = self._stub.GetLoadingProgress.future( + request, timeout=timeout + ).result() if response.status.error_code != 0: raise MilvusException(response.status.error_code, response.status.reason) return response.progress @retry_on_rpc_failure() def load_partitions_progress(self, collection_name, partition_names, timeout=None): - """ Return loading progress of partitions """ + """Return loading progress of partitions""" progress = self.get_loading_progress(collection_name, partition_names, timeout) return { "loading_progress": f"{progress:.0f}%", @@ -749,7 +875,9 @@ def load_partitions_progress(self, collection_name, partition_names, timeout=Non @retry_on_rpc_failure() def release_partitions(self, collection_name, partition_names, timeout=None): - check_pass_param(collection_name=collection_name, partition_name_array=partition_names) + check_pass_param( + collection_name=collection_name, partition_name_array=partition_names + ) request = Prepare.release_partitions("", collection_name, partition_names) rf = self._stub.ReleasePartitions.future(request, timeout=timeout) response = rf.result() @@ -797,7 +925,9 @@ def _wait_for_flushed(self, segment_ids, timeout=None, **kwargs): end = time.time() if timeout is not None: if end - start > timeout: - raise MilvusException(message=f"wait for flush timeout, segment ids: {segment_ids}") + raise MilvusException( + message=f"wait for flush timeout, segment ids: {segment_ids}" + ) if not flush_ret: time.sleep(0.5) @@ -834,7 +964,9 @@ def _check(): _check() @retry_on_rpc_failure() - def drop_index(self, collection_name, field_name, index_name, timeout=None, **kwargs): + def drop_index( + self, collection_name, field_name, index_name, timeout=None, **kwargs + ): check_pass_param(collection_name=collection_name, field_name=field_name) request = Prepare.drop_index_request(collection_name, field_name, index_name) future = self._stub.DropIndex.future(request, timeout=timeout) @@ -857,25 +989,48 @@ def fake_register_link(self, timeout=None): # TODO seems not in use @retry_on_rpc_failure() - def get(self, collection_name, ids, output_fields=None, partition_names=None, timeout=None): + def get( + self, + collection_name, + ids, + output_fields=None, + partition_names=None, + timeout=None, + ): # TODO: some check - request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names) + request = Prepare.retrieve_request( + collection_name, ids, output_fields, partition_names + ) future = self._stub.Retrieve.future(request, timeout=timeout) return future.result() @retry_on_rpc_failure() - def query(self, collection_name, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): + def query( + self, + collection_name, + expr, + output_fields=None, + partition_names=None, + timeout=None, + **kwargs, + ): if output_fields is not None and not isinstance(output_fields, (list,)): - raise ParamError(message="Invalid query format. 'output_fields' must be a list") + raise ParamError( + message="Invalid query format. 'output_fields' must be a list" + ) collection_schema = kwargs.get("schema", None) if not collection_schema: collection_schema = self.describe_collection(collection_name, timeout) consistency_level = collection_schema["consistency_level"] # overwrite the consistency level defined when user created the collection - consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level)) + consistency_level = get_consistency_level( + kwargs.get("consistency_level", consistency_level) + ) ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs) - request = Prepare.query_request(collection_name, expr, output_fields, partition_names, **kwargs) + request = Prepare.query_request( + collection_name, expr, output_fields, partition_names, **kwargs + ) future = self._stub.Query.future(request, timeout=timeout) response = future.result() @@ -901,17 +1056,29 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, result = {} for field_data in response.fields_data: if field_data.type == DataType.BOOL: - result[field_data.field_name] = field_data.scalars.bool_data.data[index] + result[field_data.field_name] = field_data.scalars.bool_data.data[ + index + ] elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - result[field_data.field_name] = field_data.scalars.int_data.data[index] + result[field_data.field_name] = field_data.scalars.int_data.data[ + index + ] elif field_data.type == DataType.INT64: - result[field_data.field_name] = field_data.scalars.long_data.data[index] + result[field_data.field_name] = field_data.scalars.long_data.data[ + index + ] elif field_data.type == DataType.FLOAT: - result[field_data.field_name] = np.single(field_data.scalars.float_data.data[index]) + result[field_data.field_name] = np.single( + field_data.scalars.float_data.data[index] + ) elif field_data.type == DataType.DOUBLE: - result[field_data.field_name] = field_data.scalars.double_data.data[index] + result[field_data.field_name] = field_data.scalars.double_data.data[ + index + ] elif field_data.type == DataType.VARCHAR: - result[field_data.field_name] = field_data.scalars.string_data.data[index] + result[field_data.field_name] = field_data.scalars.string_data.data[ + index + ] elif field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") # result[field_data.field_name] = field_data.scalars.string_data.data[index] @@ -919,20 +1086,34 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, dim = field_data.vectors.dim start_pos = index * dim end_pos = index * dim + dim - result[field_data.field_name] = [np.single(x) for x in - field_data.vectors.float_vector.data[start_pos:end_pos]] + result[field_data.field_name] = [ + np.single(x) + for x in field_data.vectors.float_vector.data[start_pos:end_pos] + ] elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim start_pos = index * (int(dim / 8)) end_pos = (index + 1) * (int(dim / 8)) - result[field_data.field_name] = field_data.vectors.binary_vector[start_pos:end_pos] + result[field_data.field_name] = field_data.vectors.binary_vector[ + start_pos:end_pos + ] results.append(result) return results @retry_on_rpc_failure() - def load_balance(self, collection_name: str, src_node_id, dst_node_ids, sealed_segment_ids, timeout=None, **kwargs): - req = Prepare.load_balance_request(collection_name, src_node_id, dst_node_ids, sealed_segment_ids) + def load_balance( + self, + collection_name: str, + src_node_id, + dst_node_ids, + sealed_segment_ids, + timeout=None, + **kwargs, + ): + req = Prepare.load_balance_request( + collection_name, src_node_id, dst_node_ids, sealed_segment_ids + ) future = self._stub.LoadBalance.future(req, timeout=timeout) status = future.result() if status.error_code != 0: @@ -955,7 +1136,9 @@ def compact(self, collection_name, timeout=None, **kwargs) -> int: return response.compactionID @retry_on_rpc_failure() - def get_compaction_state(self, compaction_id, timeout=None, **kwargs) -> CompactionState: + def get_compaction_state( + self, compaction_id, timeout=None, **kwargs + ) -> CompactionState: req = Prepare.get_compaction_state(compaction_id) future = self._stub.GetCompactionState.future(req, timeout=timeout) @@ -968,7 +1151,7 @@ def get_compaction_state(self, compaction_id, timeout=None, **kwargs) -> Compact State.new(response.state), response.executingPlanNo, response.timeoutPlanNo, - response.completedPlanNo + response.completedPlanNo, ) @retry_on_rpc_failure() @@ -976,7 +1159,9 @@ def wait_for_compaction_completed(self, compaction_id, timeout=None, **kwargs): start = time.time() while True: time.sleep(0.5) - compaction_state = self.get_compaction_state(compaction_id, timeout, **kwargs) + compaction_state = self.get_compaction_state( + compaction_id, timeout, **kwargs + ) if compaction_state.state == State.Completed: return True if compaction_state == State.UndefiedState: @@ -984,10 +1169,14 @@ def wait_for_compaction_completed(self, compaction_id, timeout=None, **kwargs): end = time.time() if timeout is not None: if end - start > timeout: - raise MilvusException(message=f"get compaction state timeout, compaction id: {compaction_id}") + raise MilvusException( + message=f"get compaction state timeout, compaction id: {compaction_id}" + ) @retry_on_rpc_failure() - def get_compaction_plans(self, compaction_id, timeout=None, **kwargs) -> CompactionPlans: + def get_compaction_plans( + self, compaction_id, timeout=None, **kwargs + ) -> CompactionPlans: req = Prepare.get_compaction_state_with_plans(compaction_id) future = self._stub.GetCompactionStateWithPlans.future(req, timeout=timeout) @@ -1003,7 +1192,9 @@ def get_compaction_plans(self, compaction_id, timeout=None, **kwargs) -> Compact @retry_on_rpc_failure() def get_replicas(self, collection_name, timeout=None, **kwargs) -> Replica: - collection_id = self.describe_collection(collection_name, timeout, **kwargs)["collection_id"] + collection_id = self.describe_collection(collection_name, timeout, **kwargs)[ + "collection_id" + ] req = Prepare.get_replicas(collection_id) future = self._stub.GetReplicas.future(req, timeout=timeout) @@ -1013,20 +1204,27 @@ def get_replicas(self, collection_name, timeout=None, **kwargs) -> Replica: groups = [] for replica in response.replicas: - shards = [Shard(s.dm_channel_name, s.node_ids, s.leaderID) for s in replica.shard_replicas] + shards = [ + Shard(s.dm_channel_name, s.node_ids, s.leaderID) + for s in replica.shard_replicas + ] groups.append(Group(replica.replicaID, shards, replica.node_ids)) return Replica(groups) @retry_on_rpc_failure() - def do_bulk_insert(self, collection_name, partition_name, files: list, timeout=None, **kwargs) -> int: + def do_bulk_insert( + self, collection_name, partition_name, files: list, timeout=None, **kwargs + ) -> int: req = Prepare.do_bulk_insert(collection_name, partition_name, files, **kwargs) future = self._stub.Import.future(req, timeout=timeout) response = future.result() if response.status.error_code != 0: raise MilvusException(response.status.error_code, response.status.reason) if len(response.tasks) == 0: - raise MilvusException(common_pb2.UNEXPECTED_ERROR, "no task id returned from server") + raise MilvusException( + common_pb2.UNEXPECTED_ERROR, "no task id returned from server" + ) return response.tasks[0] @retry_on_rpc_failure() @@ -1036,19 +1234,30 @@ def get_bulk_insert_state(self, task_id, timeout=None, **kwargs) -> BulkInsertSt resp = future.result() if resp.status.error_code != 0: raise MilvusException(resp.status.error_code, resp.status.reason) - state = BulkInsertState(task_id, resp.state, resp.row_count, resp.id_list, resp.infos, resp.create_ts) + state = BulkInsertState( + task_id, + resp.state, + resp.row_count, + resp.id_list, + resp.infos, + resp.create_ts, + ) return state @retry_on_rpc_failure() - def list_bulk_insert_tasks(self, limit, collection_name, timeout=None, **kwargs) -> list: + def list_bulk_insert_tasks( + self, limit, collection_name, timeout=None, **kwargs + ) -> list: req = Prepare.list_bulk_insert_tasks(limit, collection_name) future = self._stub.ListImportTasks.future(req, timeout=timeout) resp = future.result() if resp.status.error_code != 0: raise MilvusException(resp.status.error_code, resp.status.reason) - tasks = [BulkInsertState(t.id, t.state, t.row_count, t.id_list, t.infos, t.create_ts) - for t in resp.tasks] + tasks = [ + BulkInsertState(t.id, t.state, t.row_count, t.id_list, t.infos, t.create_ts) + for t in resp.tasks + ] return tasks @retry_on_rpc_failure() @@ -1097,15 +1306,18 @@ def drop_role(self, role_name, timeout=None, **kwargs): @retry_on_rpc_failure() def add_user_to_role(self, username, role_name, timeout=None, **kwargs): - req = Prepare.operate_user_role_request(username, role_name, milvus_types.OperateUserRoleType.AddUserToRole) + req = Prepare.operate_user_role_request( + username, role_name, milvus_types.OperateUserRoleType.AddUserToRole + ) resp = self._stub.OperateUserRole(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() def remove_user_from_role(self, username, role_name, timeout=None, **kwargs): - req = Prepare.operate_user_role_request(username, role_name, - milvus_types.OperateUserRoleType.RemoveUserFromRole) + req = Prepare.operate_user_role_request( + username, role_name, milvus_types.OperateUserRoleType.RemoveUserFromRole + ) resp = self._stub.OperateUserRole(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @@ -1143,17 +1355,31 @@ def select_all_user(self, include_role_info, timeout=None, **kwargs): return UserInfo(resp.results) @retry_on_rpc_failure() - def grant_privilege(self, role_name, object, object_name, privilege, timeout=None, **kwargs): - req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, - milvus_types.OperatePrivilegeType.Grant) + def grant_privilege( + self, role_name, object, object_name, privilege, timeout=None, **kwargs + ): + req = Prepare.operate_privilege_request( + role_name, + object, + object_name, + privilege, + milvus_types.OperatePrivilegeType.Grant, + ) resp = self._stub.OperatePrivilege(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def revoke_privilege(self, role_name, object, object_name, privilege, timeout=None, **kwargs): - req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, - milvus_types.OperatePrivilegeType.Revoke) + def revoke_privilege( + self, role_name, object, object_name, privilege, timeout=None, **kwargs + ): + req = Prepare.operate_privilege_request( + role_name, + object, + object_name, + privilege, + milvus_types.OperatePrivilegeType.Revoke, + ) resp = self._stub.OperatePrivilege(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @@ -1168,7 +1394,9 @@ def select_grant_for_one_role(self, role_name, timeout=None, **kwargs): return GrantInfo(resp.entities) @retry_on_rpc_failure() - def select_grant_for_role_and_object(self, role_name, object, object_name, timeout=None, **kwargs): + def select_grant_for_role_and_object( + self, role_name, object, object_name, timeout=None, **kwargs + ): req = Prepare.select_grant_request(role_name, object, object_name) resp = self._stub.SelectGrant(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: diff --git a/pymilvus/client/interceptor.py b/pymilvus/client/interceptor.py index 028a0f577..6db1b61b8 100644 --- a/pymilvus/client/interceptor.py +++ b/pymilvus/client/interceptor.py @@ -18,39 +18,45 @@ import grpc -class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor): - +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): def __init__(self, interceptor_function): super().__init__() self._fn = interceptor_function def intercept_unary_unary(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False) + client_call_details, iter((request,)), False, False + ) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_unary_stream(self, continuation, client_call_details, - request): + def intercept_unary_stream(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True) + client_call_details, iter((request,)), False, True + ) response_it = continuation(new_details, next(new_request_iterator)) return postprocess(response_it) if postprocess else response_it - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, False) + client_call_details, request_iterator, True, False + ) response = continuation(new_details, new_request_iterator) return postprocess(response) if postprocess else response - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, True) + client_call_details, request_iterator, True, True + ) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it @@ -60,27 +66,33 @@ def create(intercept_call): class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), - grpc.ClientCallDetails): + collections.namedtuple( + "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") + ), + grpc.ClientCallDetails, +): pass def header_adder_interceptor(header, value): - - def intercept_call(client_call_details, request_iterator, request_streaming, - response_streaming): + def intercept_call( + client_call_details, request_iterator, request_streaming, response_streaming + ): metadata = [] if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) - metadata.append(( - header, - value, - )) + metadata.append( + ( + header, + value, + ) + ) client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) return client_call_details, request_iterator, None return create(intercept_call) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 20b5d292d..34f89d609 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -1,26 +1,29 @@ -import copy import base64 +import copy from typing import Dict, Iterable, Union import ujson -from . import blob -from . import entity_helper -from .check import check_pass_param, is_legal_collection_properties -from .types import DataType, PlaceholderType, get_consistency_level -from .constants import DEFAULT_CONSISTENCY_LEVEL -from ..exceptions import ParamError, DataNotMatchException, ExceptionsMessage -from ..orm.schema import CollectionSchema - +from ..exceptions import DataNotMatchException, ExceptionsMessage, ParamError from ..grpc_gen import common_pb2 as common_types -from ..grpc_gen import schema_pb2 as schema_types from ..grpc_gen import milvus_pb2 as milvus_types +from ..grpc_gen import schema_pb2 as schema_types +from ..orm.schema import CollectionSchema +from . import blob, entity_helper +from .check import check_pass_param, is_legal_collection_properties +from .constants import DEFAULT_CONSISTENCY_LEVEL +from .types import DataType, PlaceholderType, get_consistency_level class Prepare: @classmethod - def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, Iterable], CollectionSchema], - shards_num=2, **kwargs) -> milvus_types.CreateCollectionRequest: + def create_collection_request( + cls, + collection_name: str, + fields: Union[Dict[str, Iterable], CollectionSchema], + shards_num=2, + **kwargs, + ) -> milvus_types.CreateCollectionRequest: """ :type fields: Union(Dict[str, Iterable], CollectionSchema) :param fields: (Required) @@ -35,39 +38,54 @@ def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, :return: milvus_types.CreateCollectionRequest """ if isinstance(fields, CollectionSchema): - schema = cls.get_schema_from_collection_schema(collection_name, fields, shards_num, **kwargs) + schema = cls.get_schema_from_collection_schema( + collection_name, fields, shards_num, **kwargs + ) else: schema = cls.get_schema(collection_name, fields, shards_num, **kwargs) - consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) + consistency_level = get_consistency_level( + kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + ) - req = milvus_types.CreateCollectionRequest(collection_name=collection_name, - schema=bytes(schema.SerializeToString()), - shards_num=shards_num, - consistency_level=consistency_level) + req = milvus_types.CreateCollectionRequest( + collection_name=collection_name, + schema=bytes(schema.SerializeToString()), + shards_num=shards_num, + consistency_level=consistency_level, + ) properties = kwargs.get("properties") if is_legal_collection_properties(properties): - properties = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items()] + properties = [ + common_types.KeyValuePair(key=str(k), value=str(v)) + for k, v in properties.items() + ] req.properties.extend(properties) return req @classmethod - def get_schema_from_collection_schema(cls, collection_name: str, fields: CollectionSchema, shards_num=2, **kwargs) -> milvus_types.CreateCollectionRequest: + def get_schema_from_collection_schema( + cls, collection_name: str, fields: CollectionSchema, shards_num=2, **kwargs + ) -> milvus_types.CreateCollectionRequest: coll_description = fields.description if not isinstance(coll_description, (str, bytes)): - raise ParamError(message=f"description [{coll_description}] has type {type(coll_description).__name__}, but expected one of: bytes, str") + raise ParamError( + message=f"description [{coll_description}] has type {type(coll_description).__name__}, but expected one of: bytes, str" + ) - schema = schema_types.CollectionSchema(name=collection_name, - autoID=fields.auto_id, - description=coll_description) + schema = schema_types.CollectionSchema( + name=collection_name, autoID=fields.auto_id, description=coll_description + ) for f in fields.fields: - field_schema = schema_types.FieldSchema(name=f.name, - data_type=f.dtype, - description=f.description, - is_primary_key=f.is_primary, - autoID=f.auto_id) + field_schema = schema_types.FieldSchema( + name=f.name, + data_type=f.dtype, + description=f.description, + is_primary_key=f.is_primary, + autoID=f.auto_id, + ) for k, v in f.params.items(): kv_pair = common_types.KeyValuePair(key=str(k), value=str(v)) field_schema.type_params.append(kv_pair) @@ -76,7 +94,9 @@ def get_schema_from_collection_schema(cls, collection_name: str, fields: Collect return schema @classmethod - def get_schema(cls, collection_name: str, fields: Dict[str, Iterable], shards_num=2, **kwargs) -> schema_types.CollectionSchema: + def get_schema( + cls, collection_name: str, fields: Dict[str, Iterable], shards_num=2, **kwargs + ) -> schema_types.CollectionSchema: if not isinstance(fields, dict): raise ParamError(message="Param fields must be a dict") @@ -86,18 +106,20 @@ def get_schema(cls, collection_name: str, fields: Dict[str, Iterable], shards_nu if len(all_fields) == 0: raise ParamError(message="Param fields value cannot be empty") - schema = schema_types.CollectionSchema(name=collection_name, - autoID=False, - description=fields.get('description', '')) + schema = schema_types.CollectionSchema( + name=collection_name, + autoID=False, + description=fields.get("description", ""), + ) primary_field = None auto_id_field = None for field in all_fields: - field_name = field.get('name') + field_name = field.get("name") if field_name is None: raise ParamError(message="You should specify the name of field!") - data_type = field.get('type') + data_type = field.get("type") if data_type is None: raise ParamError(message="You should specify the data type of field!") if not isinstance(data_type, (int, DataType)): @@ -108,31 +130,44 @@ def get_schema(cls, collection_name: str, fields: Dict[str, Iterable], shards_nu raise ParamError(message="is_primary must be boolean") if is_primary: if primary_field is not None: - raise ParamError(message="A collection should only have one primary field") + raise ParamError( + message="A collection should only have one primary field" + ) if DataType(data_type) not in [DataType.INT64, DataType.VARCHAR]: - raise ParamError(message="int64 and varChar are the only supported types of primary key") + raise ParamError( + message="int64 and varChar are the only supported types of primary key" + ) primary_field = field_name - auto_id = field.get('auto_id', False) + auto_id = field.get("auto_id", False) if not isinstance(auto_id, bool): raise ParamError(message="auto_id must be boolean") if auto_id: if auto_id_field is not None: - raise ParamError(message="A collection should only have one autoID field") + raise ParamError( + message="A collection should only have one autoID field" + ) if DataType(data_type) != DataType.INT64: - raise ParamError(message="int64 is the only supported type of automatic generated id") + raise ParamError( + message="int64 is the only supported type of automatic generated id" + ) auto_id_field = field_name - field_schema = schema_types.FieldSchema(name=field_name, - data_type=data_type, - description=field.get('description', ''), - is_primary_key=is_primary, - autoID=auto_id) + field_schema = schema_types.FieldSchema( + name=field_name, + data_type=data_type, + description=field.get("description", ""), + is_primary_key=is_primary, + autoID=auto_id, + ) - type_params = field.get('params', {}) + type_params = field.get("params", {}) if not isinstance(type_params, dict): raise ParamError(message="params should be dictionary type") - kvs = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in type_params.items()] + kvs = [ + common_types.KeyValuePair(key=str(k), value=str(v)) + for k, v in type_params.items() + ] field_schema.type_params.extend(kvs) schema.fields.append(field_schema) @@ -158,7 +193,9 @@ def alter_collection_request(cls, collection_name, properties): kv = common_types.KeyValuePair(key=k, value=str(properties[k])) kvs.append(kv) - return milvus_types.AlterCollectionRequest(collection_name=collection_name, properties=kvs) + return milvus_types.AlterCollectionRequest( + collection_name=collection_name, properties=kvs + ) @classmethod def collection_stats_request(cls, collection_name): @@ -169,7 +206,9 @@ def show_collections_request(cls, collection_names=None): req = milvus_types.ShowCollectionsRequest() if collection_names: if not isinstance(collection_names, (list,)): - raise ParamError(message=f"collection_names must be a list of strings, but got: {collection_names}") + raise ParamError( + message=f"collection_names must be a list of strings, but got: {collection_names}" + ) for collection_name in collection_names: check_pass_param(collection_name=collection_name) req.collection_names.extend(collection_names) @@ -178,27 +217,41 @@ def show_collections_request(cls, collection_names=None): @classmethod def create_partition_request(cls, collection_name, partition_name): - return milvus_types.CreatePartitionRequest(collection_name=collection_name, partition_name=partition_name) + return milvus_types.CreatePartitionRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod def drop_partition_request(cls, collection_name, partition_name): - return milvus_types.DropPartitionRequest(collection_name=collection_name, partition_name=partition_name) + return milvus_types.DropPartitionRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod def has_partition_request(cls, collection_name, partition_name): - return milvus_types.HasPartitionRequest(collection_name=collection_name, partition_name=partition_name) + return milvus_types.HasPartitionRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod def partition_stats_request(cls, collection_name, partition_name): - return milvus_types.PartitionStatsRequest(collection_name=collection_name, partition_name=partition_name) + return milvus_types.PartitionStatsRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod - def show_partitions_request(cls, collection_name, partition_names=None, type_in_memory=False): - check_pass_param(collection_name=collection_name, partition_name_array=partition_names) + def show_partitions_request( + cls, collection_name, partition_names=None, type_in_memory=False + ): + check_pass_param( + collection_name=collection_name, partition_name_array=partition_names + ) req = milvus_types.ShowPartitionsRequest(collection_name=collection_name) if partition_names: if not isinstance(partition_names, (list,)): - raise ParamError(message=f"partition_names must be a list of strings, but got: {partition_names}") + raise ParamError( + message=f"partition_names must be a list of strings, but got: {partition_names}" + ) for partition_name in partition_names: check_pass_param(partition_name=partition_name) req.partition_names.extend(partition_names) @@ -210,7 +263,9 @@ def show_partitions_request(cls, collection_name, partition_names=None, type_in_ @classmethod def get_loading_progress(cls, collection_name, partition_names=None): - check_pass_param(collection_name=collection_name, partition_name_array=partition_names) + check_pass_param( + collection_name=collection_name, partition_name_array=partition_names + ) req = milvus_types.GetLoadingProgressRequest(collection_name=collection_name) if partition_names: req.partition_names.extend(partition_names) @@ -231,19 +286,30 @@ def partition_name(cls, collection_name, partition_name): raise ParamError(message="collection_name must be of str type") if not isinstance(partition_name, str): raise ParamError(message="partition_name must be of str type") - return milvus_types.PartitionName(collection_name=collection_name, - tag=partition_name) + return milvus_types.PartitionName( + collection_name=collection_name, tag=partition_name + ) @classmethod - def batch_insert_param(cls, collection_name, entities, partition_name, fields_info=None, **kwargs): + def batch_insert_param( + cls, collection_name, entities, partition_name, fields_info=None, **kwargs + ): # insert_request.hash_keys won't be filled in client. It will be filled in proxy. - tag = partition_name or "_default" # should here? - insert_request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag) + tag = partition_name or "_default" # should here? + insert_request = milvus_types.InsertRequest( + collection_name=collection_name, partition_name=tag + ) for entity in entities: - if not entity.get("name", None) or not entity.get("values", None) or not entity.get("type", None): - raise ParamError(message="Missing param in entities, a field must have type, name and values") + if ( + not entity.get("name", None) + or not entity.get("values", None) + or not entity.get("type", None) + ): + raise ParamError( + message="Missing param in entities, a field must have type, name and values" + ) if not fields_info: raise ParamError(message="Missing collection meta to validate entities") @@ -265,50 +331,80 @@ def batch_insert_param(cls, collection_name, entities, partition_name, fields_in if field_name == entity_name: if field_type != entity_type: - raise ParamError(message=f"Collection field type is {field_type}" - f", but entities field type is {entity_type}") + raise ParamError( + message=f"Collection field type is {field_type}" + f", but entities field type is {entity_type}" + ) entity_dim, field_dim = 0, 0 if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: field_dim = field["params"]["dim"] entity_dim = len(entity["values"][0]) - if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim}") - - if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim * 8}") + if ( + entity_type + in [ + DataType.FLOAT_VECTOR, + ] + and entity_dim != field_dim + ): + raise ParamError( + message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim}" + ) + + if ( + entity_type + in [ + DataType.BINARY_VECTOR, + ] + and entity_dim * 8 != field_dim + ): + raise ParamError( + message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim * 8}" + ) location[field["name"]] = j match_flag = True break if not match_flag: - raise ParamError(message=f"Field {field['name']} don't match in entities") + raise ParamError( + message=f"Field {field['name']} don't match in entities" + ) # though impossible from sdk if primary_key_loc is None: raise ParamError(message="primary key not found") if auto_id_loc is None and len(entities) != len(fields_info): - raise ParamError(message=f"number of fields: {len(fields_info)}, number of entities: {len(entities)}") + raise ParamError( + message=f"number of fields: {len(fields_info)}, number of entities: {len(entities)}" + ) if auto_id_loc is not None and len(entities) + 1 != len(fields_info): - raise ParamError(message=f"number of fields: {len(fields_info)}, number of entities: {len(entities)}") + raise ParamError( + message=f"number of fields: {len(fields_info)}, number of entities: {len(entities)}" + ) row_num = 0 try: for entity in entities: current = len(entity.get("values")) if row_num not in (0, current): - raise ParamError(message="row num misaligned current[{current}]!= previous[{row_num}]") + raise ParamError( + message="row num misaligned current[{current}]!= previous[{row_num}]" + ) row_num = current - field_data = entity_helper.entity_to_field_data(entity, fields_info[location[entity.get("name")]]) + field_data = entity_helper.entity_to_field_data( + entity, fields_info[location[entity.get("name")]] + ) insert_request.fields_data.append(field_data) except (TypeError, ValueError) as e: - raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e + raise DataNotMatchException( + message=ExceptionsMessage.DataTypeInconsistent + ) from e insert_request.num_rows = row_num @@ -329,7 +425,9 @@ def check_str(instr, prefix): check_str(partition_name, "partition_name") check_str(expr, "expr") - request = milvus_types.DeleteRequest(collection_name=collection_name, expr=expr, partition_name=partition_name) + request = milvus_types.DeleteRequest( + collection_name=collection_name, expr=expr, partition_name=partition_name + ) return request @classmethod @@ -339,23 +437,38 @@ def _prepare_placeholders(cls, vectors, nq, tag, pl_type, is_binary, dimension=0 for i in range(0, nq): if is_binary: if len(vectors[i]) * 8 != dimension: - raise ParamError(message=f"The dimension of query entities[{vectors[i]*8}] is different from schema [{dimension}]") + raise ParamError( + message=f"The dimension of query entities[{vectors[i]*8}] is different from schema [{dimension}]" + ) pl.values.append(blob.vectorBinaryToBytes(vectors[i])) else: if len(vectors[i]) != dimension: - raise ParamError(message=f"The dimension of query entities[{vectors[i]*8}] is different from schema [{dimension}]") + raise ParamError( + message=f"The dimension of query entities[{vectors[i]*8}] is different from schema [{dimension}]" + ) pl.values.append(blob.vectorFloatToBytes(vectors[i])) return pl @classmethod - def search_request(cls, collection_name, query_entities, partition_names=None, fields=None, round_decimal=-1, **kwargs): + def search_request( + cls, + collection_name, + query_entities, + partition_names=None, + fields=None, + round_decimal=-1, + **kwargs, + ): schema = kwargs.get("schema", None) fields_schema = schema.get("fields", None) # list - fields_name_locs = {fields_schema[loc]["name"]: loc - for loc in range(len(fields_schema))} + fields_name_locs = { + fields_schema[loc]["name"]: loc for loc in range(len(fields_schema)) + } if not isinstance(query_entities, (dict,)): - raise ParamError(message="Invalid query format. 'query_entities' must be a dict") + raise ParamError( + message="Invalid query format. 'query_entities' must be a dict" + ) if fields is not None and not isinstance(fields, (list,)): raise ParamError(message="Invalid query format. 'fields' must be a list") @@ -382,7 +495,9 @@ def extract_vectors_param(param, placeholders, names, round_decimal): for pk, pv in param["vector"].items(): if "query" not in pv: - raise ParamError(message="param vector must contain 'query'") + raise ParamError( + message="param vector must contain 'query'" + ) placeholders[ph] = pv["query"] names[ph] = pk param["vector"][pk]["query"] = ph @@ -396,7 +511,9 @@ def extract_vectors_param(param, placeholders, names, round_decimal): for item in param: extract_vectors_param(item, placeholders, names, round_decimal) - extract_vectors_param(duplicated_entities, vector_placeholders, vector_names, round_decimal) + extract_vectors_param( + duplicated_entities, vector_placeholders, vector_names, round_decimal + ) request.dsl = ujson.dumps(duplicated_entities) plg = common_types.PlaceholderGroup() @@ -408,19 +525,25 @@ def extract_vectors_param(param, placeholders, names, round_decimal): fname = vector_names[tag] if fname not in fields_name_locs: raise ParamError(message=f"Field {fname} doesn't exist in schema") - dimension = int(fields_schema[fields_name_locs[fname]]["params"].get("dim", 0)) + dimension = int( + fields_schema[fields_name_locs[fname]]["params"].get("dim", 0) + ) if isinstance(vectors[0], bytes): pl.type = PlaceholderType.BinaryVector for vector in vectors: if dimension != len(vector) * 8: - raise ParamError(message="The dimension of query vector is different from schema") + raise ParamError( + message="The dimension of query vector is different from schema" + ) pl.values.append(blob.vectorBinaryToBytes(vector)) else: pl.type = PlaceholderType.FloatVector for vector in vectors: if dimension != len(vector): - raise ParamError(message="The dimension of query vector is different from schema") + raise ParamError( + message="The dimension of query vector is different from schema" + ) pl.values.append(blob.vectorFloatToBytes(vector)) # vector_values_bytes = service_msg_types.VectorValues.SerializeToString(vector_values) @@ -431,11 +554,25 @@ def extract_vectors_param(param, placeholders, names, round_decimal): return request @classmethod - def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, schema, expr=None, - partition_names=None, output_fields=None, round_decimal=-1, **kwargs): + def search_requests_with_expr( + cls, + collection_name, + data, + anns_field, + param, + limit, + schema, + expr=None, + partition_names=None, + output_fields=None, + round_decimal=-1, + **kwargs, + ): # TODO Move this impl into server side fields_schema = schema.get("fields", None) # list - fields_name_locs = {fields_schema[loc]["name"]: loc for loc in range(len(fields_schema))} + fields_name_locs = { + fields_schema[loc]["name"]: loc for loc in range(len(fields_schema)) + } requests = [] if len(data) <= 0: @@ -450,11 +587,15 @@ def search_requests_with_expr(cls, collection_name, data, anns_field, param, lim if anns_field not in fields_name_locs: raise ParamError(message=f"Field {anns_field} doesn't exist in schema") - dimension = int(fields_schema[fields_name_locs[anns_field]]["params"].get("dim", 0)) + dimension = int( + fields_schema[fields_name_locs[anns_field]]["params"].get("dim", 0) + ) params = param.get("params", {}) if not isinstance(params, dict): - raise ParamError(message=f"Search params must be a dict, got {type(params)}") + raise ParamError( + message=f"Search params must be a dict, got {type(params)}" + ) search_params = { "anns_field": anns_field, "topk": limit, @@ -488,15 +629,21 @@ def dump(v): request.dsl_type = common_types.DslType.BoolExprV1 if expr is not None: request.dsl = expr - request.search_params.extend([common_types.KeyValuePair(key=str(key), value=dump(value)) - for key, value in search_params.items()]) + request.search_params.extend( + [ + common_types.KeyValuePair(key=str(key), value=dump(value)) + for key, value in search_params.items() + ] + ) requests.append(request) return requests @classmethod def create_alias_request(cls, collection_name, alias): - return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias) + return milvus_types.CreateAliasRequest( + collection_name=collection_name, alias=alias + ) @classmethod def drop_alias_request(cls, alias): @@ -504,12 +651,17 @@ def drop_alias_request(cls, alias): @classmethod def alter_alias_request(cls, collection_name, alias): - return milvus_types.AlterAliasRequest(collection_name=collection_name, alias=alias) + return milvus_types.AlterAliasRequest( + collection_name=collection_name, alias=alias + ) @classmethod def create_index_request(cls, collection_name, field_name, params, **kwargs): - index_params = milvus_types.CreateIndexRequest(collection_name=collection_name, field_name=field_name, - index_name=kwargs.get("index_name", "")) + index_params = milvus_types.CreateIndexRequest( + collection_name=collection_name, + field_name=field_name, + index_name=kwargs.get("index_name", ""), + ) # index_params.collection_name = collection_name # index_params.field_name = field_name @@ -517,6 +669,7 @@ def create_index_request(cls, collection_name, field_name, params, **kwargs): def dump(tv): if isinstance(tv, dict): import json + return json.dumps(tv) return str(tv) @@ -532,43 +685,64 @@ def dump(tv): @classmethod def describe_index_request(cls, collection_name, index_name): - return milvus_types.DescribeIndexRequest(collection_name=collection_name, index_name=index_name) + return milvus_types.DescribeIndexRequest( + collection_name=collection_name, index_name=index_name + ) @classmethod def get_index_build_progress(cls, collection_name: str, index_name: str): - return milvus_types.GetIndexBuildProgressRequest(collection_name=collection_name, index_name=index_name) + return milvus_types.GetIndexBuildProgressRequest( + collection_name=collection_name, index_name=index_name + ) @classmethod def get_index_state_request(cls, collection_name: str, index_name: str): - return milvus_types.GetIndexStateRequest(collection_name=collection_name, index_name=index_name) + return milvus_types.GetIndexStateRequest( + collection_name=collection_name, index_name=index_name + ) @classmethod def load_collection(cls, db_name, collection_name, replica_number): - return milvus_types.LoadCollectionRequest(db_name=db_name, collection_name=collection_name, - replica_number=replica_number) + return milvus_types.LoadCollectionRequest( + db_name=db_name, + collection_name=collection_name, + replica_number=replica_number, + ) @classmethod def release_collection(cls, db_name, collection_name): - return milvus_types.ReleaseCollectionRequest(db_name=db_name, collection_name=collection_name) + return milvus_types.ReleaseCollectionRequest( + db_name=db_name, collection_name=collection_name + ) @classmethod def load_partitions(cls, db_name, collection_name, partition_names, replica_number): - return milvus_types.LoadPartitionsRequest(db_name=db_name, collection_name=collection_name, - partition_names=partition_names, - replica_number=replica_number) + return milvus_types.LoadPartitionsRequest( + db_name=db_name, + collection_name=collection_name, + partition_names=partition_names, + replica_number=replica_number, + ) @classmethod def release_partitions(cls, db_name, collection_name, partition_names): - return milvus_types.ReleasePartitionsRequest(db_name=db_name, collection_name=collection_name, - partition_names=partition_names) + return milvus_types.ReleasePartitionsRequest( + db_name=db_name, + collection_name=collection_name, + partition_names=partition_names, + ) @classmethod def get_collection_stats_request(cls, collection_name): - return milvus_types.GetCollectionStatisticsRequest(collection_name=collection_name) + return milvus_types.GetCollectionStatisticsRequest( + collection_name=collection_name + ) @classmethod def get_persistent_segment_info_request(cls, collection_name): - return milvus_types.GetPersistentSegmentInfoRequest(collectionName=collection_name) + return milvus_types.GetPersistentSegmentInfoRequest( + collectionName=collection_name + ) @classmethod def get_flush_state_request(cls, segment_ids): @@ -584,13 +758,18 @@ def flush_param(cls, collection_names): @classmethod def drop_index_request(cls, collection_name, field_name, index_name): - return milvus_types.DropIndexRequest(db_name="", collection_name=collection_name, field_name=field_name, - index_name=index_name) + return milvus_types.DropIndexRequest( + db_name="", + collection_name=collection_name, + field_name=field_name, + index_name=index_name, + ) @classmethod def get_partition_stats_request(cls, collection_name, partition_name): - return milvus_types.GetPartitionStatisticsRequest(db_name="", collection_name=collection_name, - partition_name=partition_name) + return milvus_types.GetPartitionStatisticsRequest( + db_name="", collection_name=collection_name, partition_name=partition_name + ) @classmethod def dummy_request(cls, request_type): @@ -599,35 +778,46 @@ def dummy_request(cls, request_type): @classmethod def retrieve_request(cls, collection_name, ids, output_fields, partition_names): ids = schema_types.IDs(int_id=schema_types.LongArray(data=ids)) - return milvus_types.RetrieveRequest(db_name="", - collection_name=collection_name, - ids=ids, - output_fields=output_fields, - partition_names=partition_names) - - @classmethod - def query_request(cls, collection_name, expr, output_fields, partition_names, **kwargs): - req = milvus_types.QueryRequest(db_name="", - collection_name=collection_name, - expr=expr, - output_fields=output_fields, - partition_names=partition_names, - guarantee_timestamp=kwargs.get("guarantee_timestamp", 0), - travel_timestamp=kwargs.get("travel_timestamp", 0), - ) + return milvus_types.RetrieveRequest( + db_name="", + collection_name=collection_name, + ids=ids, + output_fields=output_fields, + partition_names=partition_names, + ) + + @classmethod + def query_request( + cls, collection_name, expr, output_fields, partition_names, **kwargs + ): + req = milvus_types.QueryRequest( + db_name="", + collection_name=collection_name, + expr=expr, + output_fields=output_fields, + partition_names=partition_names, + guarantee_timestamp=kwargs.get("guarantee_timestamp", 0), + travel_timestamp=kwargs.get("travel_timestamp", 0), + ) limit = kwargs.get("limit", None) if limit is not None: - req.query_params.append(common_types.KeyValuePair(key="limit", value=str(limit))) + req.query_params.append( + common_types.KeyValuePair(key="limit", value=str(limit)) + ) offset = kwargs.get("offset", None) if offset is not None: - req.query_params.append(common_types.KeyValuePair(key="offset", value=str(offset))) + req.query_params.append( + common_types.KeyValuePair(key="offset", value=str(offset)) + ) return req @classmethod - def load_balance_request(cls, collection_name, src_node_id, dst_node_ids, sealed_segment_ids): + def load_balance_request( + cls, collection_name, src_node_id, dst_node_ids, sealed_segment_ids + ): request = milvus_types.LoadBalanceRequest( collectionName=collection_name, src_nodeID=src_node_id, @@ -680,7 +870,9 @@ def get_replicas(cls, collection_id: int): return request @classmethod - def do_bulk_insert(cls, collection_name: str, partition_name: str, files: list, **kwargs): + def do_bulk_insert( + cls, collection_name: str, partition_name: str, files: list, **kwargs + ): channel_names = kwargs.get("channel_names", None) req = milvus_types.ImportRequest( collection_name=collection_name, @@ -719,17 +911,20 @@ def list_bulk_insert_tasks(cls, limit, collection_name): @classmethod def create_user_request(cls, user, password): check_pass_param(user=user, password=password) - return milvus_types.CreateCredentialRequest(username=user, password=base64.b64encode(password.encode('utf-8'))) + return milvus_types.CreateCredentialRequest( + username=user, password=base64.b64encode(password.encode("utf-8")) + ) @classmethod def update_password_request(cls, user, old_password, new_password): check_pass_param(user=user) check_pass_param(password=old_password) check_pass_param(password=new_password) - return milvus_types.UpdateCredentialRequest(username=user, - oldPassword=base64.b64encode(old_password.encode('utf-8')), - newPassword=base64.b64encode(new_password.encode('utf-8')), - ) + return milvus_types.UpdateCredentialRequest( + username=user, + oldPassword=base64.b64encode(old_password.encode("utf-8")), + newPassword=base64.b64encode(new_password.encode("utf-8")), + ) @classmethod def delete_user_request(cls, user): @@ -744,7 +939,9 @@ def list_usernames_request(cls): @classmethod def create_role_request(cls, role_name): check_pass_param(role_name=role_name) - return milvus_types.CreateRoleRequest(entity=milvus_types.RoleEntity(name=role_name)) + return milvus_types.CreateRoleRequest( + entity=milvus_types.RoleEntity(name=role_name) + ) @classmethod def drop_role_request(cls, role_name): @@ -756,38 +953,50 @@ def operate_user_role_request(cls, username, role_name, operate_user_role_type): check_pass_param(user=username) check_pass_param(role_name=role_name) check_pass_param(operate_user_role_type=operate_user_role_type) - return milvus_types.OperateUserRoleRequest(username=username, role_name=role_name, type=operate_user_role_type) + return milvus_types.OperateUserRoleRequest( + username=username, role_name=role_name, type=operate_user_role_type + ) @classmethod def select_role_request(cls, role_name, include_user_info): if role_name: check_pass_param(role_name=role_name) check_pass_param(include_user_info=include_user_info) - return milvus_types.SelectRoleRequest(role=milvus_types.RoleEntity(name=role_name) if role_name else None, - include_user_info=include_user_info) + return milvus_types.SelectRoleRequest( + role=milvus_types.RoleEntity(name=role_name) if role_name else None, + include_user_info=include_user_info, + ) @classmethod def select_user_request(cls, username, include_role_info): if username: check_pass_param(user=username) check_pass_param(include_role_info=include_role_info) - return milvus_types.SelectUserRequest(user=milvus_types.UserEntity(name=username) if username else None, - include_role_info=include_role_info) + return milvus_types.SelectUserRequest( + user=milvus_types.UserEntity(name=username) if username else None, + include_role_info=include_role_info, + ) @classmethod - def operate_privilege_request(cls, role_name, object, object_name, privilege, operate_privilege_type): + def operate_privilege_request( + cls, role_name, object, object_name, privilege, operate_privilege_type + ): check_pass_param(role_name=role_name) check_pass_param(object=object) check_pass_param(object_name=object_name) check_pass_param(privilege=privilege) check_pass_param(operate_privilege_type=operate_privilege_type) return milvus_types.OperatePrivilegeRequest( - entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), - object=milvus_types.ObjectEntity(name=object), - object_name=object_name, - grantor=milvus_types.GrantorEntity( - privilege=milvus_types.PrivilegeEntity(name=privilege))), - type=operate_privilege_type) + entity=milvus_types.GrantEntity( + role=milvus_types.RoleEntity(name=role_name), + object=milvus_types.ObjectEntity(name=object), + object_name=object_name, + grantor=milvus_types.GrantorEntity( + privilege=milvus_types.PrivilegeEntity(name=privilege) + ), + ), + type=operate_privilege_type, + ) @classmethod def select_grant_request(cls, role_name, object, object_name): @@ -797,9 +1006,12 @@ def select_grant_request(cls, role_name, object, object_name): if object_name: check_pass_param(object_name=object_name) return milvus_types.SelectGrantRequest( - entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), - object=milvus_types.ObjectEntity(name=object) if object else None, - object_name=object_name if object_name else None)) + entity=milvus_types.GrantEntity( + role=milvus_types.RoleEntity(name=role_name), + object=milvus_types.ObjectEntity(name=object) if object else None, + object_name=object_name if object_name else None, + ) + ) @classmethod def get_server_version(cls): diff --git a/pymilvus/client/stub.py b/pymilvus/client/stub.py index bffebca54..8ebb4ff0d 100644 --- a/pymilvus/client/stub.py +++ b/pymilvus/client/stub.py @@ -1,17 +1,23 @@ from urllib import parse -from .grpc_handler import GrpcHandler +from ..decorators import deprecated from ..exceptions import MilvusException, ParamError -from .types import CompactionState, CompactionPlans, Replica, BulkInsertState from ..settings import DefaultConfig as config -from ..decorators import deprecated - from .check import is_legal_host, is_legal_port +from .grpc_handler import GrpcHandler +from .types import BulkInsertState, CompactionPlans, CompactionState, Replica class Milvus: @deprecated - def __init__(self, host=None, port=config.GRPC_PORT, uri=config.GRPC_URI, channel=None, **kwargs): + def __init__( + self, + host=None, + port=config.GRPC_PORT, + uri=config.GRPC_URI, + channel=None, + **kwargs, + ): self.address = self.__get_address(host, port, uri) self._handler = GrpcHandler(address=self.address, channel=channel, **kwargs) @@ -20,7 +26,7 @@ def __init__(self, host=None, port=config.GRPC_PORT, uri=config.GRPC_URI, channe def __get_address(self, host=None, port=config.GRPC_PORT, uri=config.GRPC_URI): if host is None and uri is None: - raise ParamError(message='Host and uri cannot both be None') + raise ParamError(message="Host and uri cannot both be None") if host is None: try: @@ -56,8 +62,10 @@ def close(self): self.handler.close() self._handler = None - def create_collection(self, collection_name, fields, shards_num=2, timeout=None, **kwargs): - """ Creates a collection. + def create_collection( + self, collection_name, fields, shards_num=2, timeout=None, **kwargs + ): + """Creates a collection. :param collection_name: The name of the collection. A collection name can only include numbers, letters, and underscores, and must not begin with a number. @@ -99,7 +107,13 @@ def create_collection(self, collection_name, fields, shards_num=2, timeout=None, :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.create_collection(collection_name, fields, shards_num=shards_num, timeout=timeout, **kwargs) + return handler.create_collection( + collection_name, + fields, + shards_num=shards_num, + timeout=timeout, + **kwargs, + ) def drop_collection(self, collection_name, timeout=None): """ @@ -168,7 +182,9 @@ def describe_collection(self, collection_name, timeout=None): with self._connection() as handler: return handler.describe_collection(collection_name, timeout=timeout) - def load_collection(self, collection_name, replica_number=1, timeout=None, **kwargs): + def load_collection( + self, collection_name, replica_number=1, timeout=None, **kwargs + ): """ Loads a specified collection from disk to memory. @@ -190,7 +206,9 @@ def load_collection(self, collection_name, replica_number=1, timeout=None, **kwa :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.load_collection(collection_name, replica_number, timeout=timeout, **kwargs) + return handler.load_collection( + collection_name, replica_number, timeout=timeout, **kwargs + ) def release_collection(self, collection_name, timeout=None): """ @@ -211,7 +229,9 @@ def release_collection(self, collection_name, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.release_collection(collection_name=collection_name, timeout=timeout) + return handler.release_collection( + collection_name=collection_name, timeout=timeout + ) def get_collection_stats(self, collection_name, timeout=None, **kwargs): """ @@ -233,7 +253,9 @@ def get_collection_stats(self, collection_name, timeout=None, **kwargs): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - stats = handler.get_collection_stats(collection_name, timeout=timeout, **kwargs) + stats = handler.get_collection_stats( + collection_name, timeout=timeout, **kwargs + ) result = {stat.key: stat.value for stat in stats} result["row_count"] = int(result["row_count"]) return result @@ -280,7 +302,9 @@ def create_partition(self, collection_name, partition_name, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.create_partition(collection_name, partition_name, timeout=timeout) + return handler.create_partition( + collection_name, partition_name, timeout=timeout + ) def drop_partition(self, collection_name, partition_name, timeout=None): """ @@ -306,7 +330,9 @@ def drop_partition(self, collection_name, partition_name, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.drop_partition(collection_name, partition_name, timeout=timeout) + return handler.drop_partition( + collection_name, partition_name, timeout=timeout + ) def has_partition(self, collection_name, partition_name, timeout=None): """ @@ -330,9 +356,13 @@ def has_partition(self, collection_name, partition_name, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.has_partition(collection_name, partition_name, timeout=timeout) + return handler.has_partition( + collection_name, partition_name, timeout=timeout + ) - def load_partitions(self, collection_name, partition_names, replica_number=1, timeout=None): + def load_partitions( + self, collection_name, partition_names, replica_number=1, timeout=None + ): """ Load specified partitions from disk to memory. @@ -357,9 +387,12 @@ def load_partitions(self, collection_name, partition_names, replica_number=1, ti :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.load_partitions(collection_name=collection_name, - partition_names=partition_names, - replica_number=replica_number, timeout=timeout) + return handler.load_partitions( + collection_name=collection_name, + partition_names=partition_names, + replica_number=replica_number, + timeout=timeout, + ) def release_partitions(self, collection_name, partition_names, timeout=None): """ @@ -383,8 +416,11 @@ def release_partitions(self, collection_name, partition_names, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.release_partitions(collection_name=collection_name, - partition_names=partition_names, timeout=timeout) + return handler.release_partitions( + collection_name=collection_name, + partition_names=partition_names, + timeout=timeout, + ) def list_partitions(self, collection_name, timeout=None): """ @@ -407,7 +443,9 @@ def list_partitions(self, collection_name, timeout=None): with self._connection() as handler: return handler.list_partitions(collection_name, timeout=timeout) - def get_partition_stats(self, collection_name, partition_name, timeout=None, **kwargs): + def get_partition_stats( + self, collection_name, partition_name, timeout=None, **kwargs + ): """ Returns partition statistics information. Example: {"row_count": 10} @@ -430,7 +468,9 @@ def get_partition_stats(self, collection_name, partition_name, timeout=None, **k :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - stats = handler.get_partition_stats(collection_name, partition_name, timeout=timeout, **kwargs) + stats = handler.get_partition_stats( + collection_name, partition_name, timeout=timeout, **kwargs + ) result = {stat.key: stat.value for stat in stats} result["row_count"] = int(result["row_count"]) return result @@ -463,7 +503,9 @@ def create_alias(self, collection_name, alias, timeout=None, **kwargs): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.create_alias(collection_name, alias, timeout=timeout, **kwargs) + return handler.create_alias( + collection_name, alias, timeout=timeout, **kwargs + ) def drop_alias(self, alias, timeout=None, **kwargs): """ @@ -522,7 +564,9 @@ def alter_alias(self, collection_name, alias, timeout=None, **kwargs): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.alter_alias(collection_name, alias, timeout=timeout, **kwargs) + return handler.alter_alias( + collection_name, alias, timeout=timeout, **kwargs + ) def create_index(self, collection_name, field_name, params, timeout=None, **kwargs): """ @@ -629,7 +673,9 @@ def create_index(self, collection_name, field_name, params, timeout=None, **kwar :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.create_index(collection_name, field_name, params, timeout=timeout, **kwargs) + return handler.create_index( + collection_name, field_name, params, timeout=timeout, **kwargs + ) def drop_index(self, collection_name, field_name, timeout=None): """ @@ -653,8 +699,12 @@ def drop_index(self, collection_name, field_name, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.drop_index(collection_name=collection_name, - field_name=field_name, index_name="", timeout=timeout) + return handler.drop_index( + collection_name=collection_name, + field_name=field_name, + index_name="", + timeout=timeout, + ) def describe_index(self, collection_name, index_name="", timeout=None): """ @@ -681,7 +731,9 @@ def describe_index(self, collection_name, index_name="", timeout=None): with self._connection() as handler: return handler.describe_index(collection_name, index_name, timeout=timeout) - def insert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + def insert( + self, collection_name, entities, partition_name=None, timeout=None, **kwargs + ): """ Inserts entities in a specified collection. @@ -715,9 +767,13 @@ def insert(self, collection_name, entities, partition_name=None, timeout=None, * :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.batch_insert(collection_name, entities, partition_name, timeout=timeout, **kwargs) + return handler.batch_insert( + collection_name, entities, partition_name, timeout=timeout, **kwargs + ) - def delete(self, collection_name, expr, partition_name=None, timeout=None, **kwargs): + def delete( + self, collection_name, expr, partition_name=None, timeout=None, **kwargs + ): """ Delete entities with an expression condition. And return results to show which primary key is deleted successfully @@ -743,7 +799,9 @@ def delete(self, collection_name, expr, partition_name=None, timeout=None, **kwa :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.delete(collection_name, expr, partition_name, timeout=timeout, **kwargs) + return handler.delete( + collection_name, expr, partition_name, timeout=timeout, **kwargs + ) def flush(self, collection_names=None, timeout=None, **kwargs): """ @@ -780,8 +838,20 @@ def flush(self, collection_names=None, timeout=None, **kwargs): with self._connection() as handler: return handler.flush(collection_names, timeout=timeout, **kwargs) - def search(self, collection_name, data, anns_field, param, limit, expression=None, partition_names=None, - output_fields=None, timeout=None, round_decimal=-1, **kwargs): + def search( + self, + collection_name, + data, + anns_field, + param, + limit, + expression=None, + partition_names=None, + output_fields=None, + timeout=None, + round_decimal=-1, + **kwargs, + ): """ Searches a collection based on the given expression and returns query results. @@ -839,8 +909,19 @@ def search(self, collection_name, data, anns_field, param, limit, expression=Non :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.search(collection_name, data, anns_field, param, limit, expression, partition_names, - output_fields, round_decimal=round_decimal, timeout=timeout, **kwargs) + return handler.search( + collection_name, + data, + anns_field, + param, + limit, + expression, + partition_names, + output_fields, + round_decimal=round_decimal, + timeout=timeout, + **kwargs, + ) def get_query_segment_info(self, collection_name, timeout=None, **kwargs): """ @@ -856,10 +937,12 @@ def get_query_segment_info(self, collection_name, timeout=None, **kwargs): :rtype: QuerySegmentInfo """ with self._connection() as handler: - return handler.get_query_segment_info(collection_name, timeout=timeout, **kwargs) + return handler.get_query_segment_info( + collection_name, timeout=timeout, **kwargs + ) def load_collection_progress(self, collection_name, timeout=None): - """ { + """{ 'loading_progress': '100%', 'num_loaded_partitions': 3, 'not_loaded_partitions': [], @@ -869,36 +952,54 @@ def load_collection_progress(self, collection_name, timeout=None): return handler.load_collection_progress(collection_name, timeout=timeout) def load_partitions_progress(self, collection_name, partition_names, timeout=None): - """ { + """{ 'loading_progress': '100%', 'num_loaded_partitions': 3, 'not_loaded_partitions': [], } """ with self._connection() as handler: - return handler.load_partitions_progress(collection_name, partition_names, timeout=timeout) + return handler.load_partitions_progress( + collection_name, partition_names, timeout=timeout + ) def wait_for_loading_collection_complete(self, collection_name, timeout=None): with self._connection() as handler: return handler.wait_for_loading_collection(collection_name, timeout=timeout) - def wait_for_loading_partitions_complete(self, collection_name, partition_names, timeout=None): + def wait_for_loading_partitions_complete( + self, collection_name, partition_names, timeout=None + ): with self._connection() as handler: - return handler.wait_for_loading_partitions(collection_name, partition_names, timeout=timeout) + return handler.wait_for_loading_partitions( + collection_name, partition_names, timeout=timeout + ) def get_index_build_progress(self, collection_name, index_name, timeout=None): with self._connection() as handler: - return handler.get_index_build_progress(collection_name, index_name, timeout=timeout) + return handler.get_index_build_progress( + collection_name, index_name, timeout=timeout + ) def wait_for_creating_index(self, collection_name, index_name, timeout=None): with self._connection() as handler: - return handler.wait_for_creating_index(collection_name, index_name, timeout=timeout) + return handler.wait_for_creating_index( + collection_name, index_name, timeout=timeout + ) def dummy(self, request_type, timeout=None): with self._connection() as handler: return handler.dummy(request_type, timeout=timeout) - def query(self, collection_name, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): + def query( + self, + collection_name, + expr, + output_fields=None, + partition_names=None, + timeout=None, + **kwargs, + ): """ Query with a set of criteria, and results in a list of records that match the query exactly. @@ -943,9 +1044,24 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.query(collection_name, expr, output_fields, partition_names, timeout=timeout, **kwargs) - - def load_balance(self, collection_name: str, src_node_id, dst_node_ids, sealed_segment_ids, timeout=None, **kwargs): + return handler.query( + collection_name, + expr, + output_fields, + partition_names, + timeout=timeout, + **kwargs, + ) + + def load_balance( + self, + collection_name: str, + src_node_id, + dst_node_ids, + sealed_segment_ids, + timeout=None, + **kwargs, + ): """ Do load balancing operation from source query node to destination query node. :param collection_name: The collection to balance. @@ -967,8 +1083,14 @@ def load_balance(self, collection_name: str, src_node_id, dst_node_ids, sealed_s :raises MilvusException: If sealed segments not exist. """ with self._connection() as handler: - return handler.load_balance(collection_name, src_node_id, dst_node_ids, sealed_segment_ids, - timeout=timeout, **kwargs) + return handler.load_balance( + collection_name, + src_node_id, + dst_node_ids, + sealed_segment_ids, + timeout=timeout, + **kwargs, + ) def compact(self, collection_name, timeout=None, **kwargs) -> int: """ @@ -988,7 +1110,9 @@ def compact(self, collection_name, timeout=None, **kwargs) -> int: with self._connection() as handler: return handler.compact(collection_name, timeout=timeout, **kwargs) - def get_compaction_state(self, compaction_id: int, timeout=None, **kwargs) -> CompactionState: + def get_compaction_state( + self, compaction_id: int, timeout=None, **kwargs + ) -> CompactionState: """ Get compaction states of a targeted compaction id @@ -1005,13 +1129,21 @@ def get_compaction_state(self, compaction_id: int, timeout=None, **kwargs) -> Co """ with self._connection() as handler: - return handler.get_compaction_state(compaction_id, timeout=timeout, **kwargs) + return handler.get_compaction_state( + compaction_id, timeout=timeout, **kwargs + ) - def wait_for_compaction_completed(self, compaction_id: int, timeout=None, **kwargs) -> CompactionState: + def wait_for_compaction_completed( + self, compaction_id: int, timeout=None, **kwargs + ) -> CompactionState: with self._connection() as handler: - return handler.wait_for_compaction_completed(compaction_id, timeout=timeout, **kwargs) + return handler.wait_for_compaction_completed( + compaction_id, timeout=timeout, **kwargs + ) - def get_compaction_plans(self, compaction_id: int, timeout=None, **kwargs) -> CompactionPlans: + def get_compaction_plans( + self, compaction_id: int, timeout=None, **kwargs + ) -> CompactionPlans: """ Get compaction states of a targeted compaction id @@ -1027,10 +1159,12 @@ def get_compaction_plans(self, compaction_id: int, timeout=None, **kwargs) -> Co :raises MilvusException: If compaction_id doesn't exist. """ with self._connection() as handler: - return handler.get_compaction_plans(compaction_id, timeout=timeout, **kwargs) + return handler.get_compaction_plans( + compaction_id, timeout=timeout, **kwargs + ) def get_replicas(self, collection_name: str, timeout=None, **kwargs) -> Replica: - """ Get replica infos of a collection + """Get replica infos of a collection :param collection_name: the name of the collection :type collection_name: str @@ -1046,8 +1180,15 @@ def get_replicas(self, collection_name: str, timeout=None, **kwargs) -> Replica: with self._connection() as handler: return handler.get_replicas(collection_name, timeout=timeout, **kwargs) - def do_bulk_insert(self, collection_name: str, partition_name: str, files: list, timeout=None, **kwargs) -> int: - """ do_bulk_insert inserts entities through files, currently supports row-based json file. + def do_bulk_insert( + self, + collection_name: str, + partition_name: str, + files: list, + timeout=None, + **kwargs, + ) -> int: + """do_bulk_insert inserts entities through files, currently supports row-based json file. User need to create the json file with a specified json format which is described in the official user guide. Let's say a collection has two fields: "id" and "vec"(dimension=8), the row-based json format is: {"rows": [ @@ -1083,7 +1224,9 @@ def do_bulk_insert(self, collection_name: str, partition_name: str, files: list, :raises BaseException: If the files input is illegal. """ with self._connection() as handler: - return handler.do_bulk_insert(collection_name, partition_name, files, timeout=timeout, **kwargs) + return handler.do_bulk_insert( + collection_name, partition_name, files, timeout=timeout, **kwargs + ) def get_bulk_insert_state(self, task_id, timeout=None, **kwargs) -> BulkInsertState: """get_bulk_insert_state returns state of a certain task_id @@ -1114,7 +1257,7 @@ def list_bulk_insert_tasks(self, timeout=None, **kwargs) -> list: return handler.list_bulk_insert_tasks(timeout=timeout, **kwargs) def create_user(self, user, password, timeout=None, **kwargs): - """ Create a user using the given user and password. + """Create a user using the given user and password. :param user: the user name. :type user: str :param password: the password. @@ -1140,10 +1283,12 @@ def update_password(self, user, old_password, new_password, timeout=None, **kwar :type new_password: str """ with self._connection() as handler: - handler.update_password(user, old_password, new_password, timeout=timeout, **kwargs) + handler.update_password( + user, old_password, new_password, timeout=timeout, **kwargs + ) def delete_user(self, user, timeout=None, **kwargs): - """ Delete user corresponding to the username. + """Delete user corresponding to the username. :param user: the user name. :type user: str :param timeout: The timeout for this method, unit: second @@ -1153,7 +1298,7 @@ def delete_user(self, user, timeout=None, **kwargs): handler.delete_user(user, timeout=timeout, **kwargs) def list_usernames(self, timeout=None, **kwargs): - """ List all usernames. + """List all usernames. :param timeout: The timeout for this method, unit: second :type timeout: int :return list of str: @@ -1163,7 +1308,7 @@ def list_usernames(self, timeout=None, **kwargs): return handler.list_usernames(timeout=timeout, **kwargs) def create_role(self, role_name, timeout=None, **kwargs): - """ Create Role + """Create Role :param role_name: the role name. :type role_name: str """ @@ -1171,7 +1316,7 @@ def create_role(self, role_name, timeout=None, **kwargs): handler.create_role(role_name, timeout=timeout, **kwargs) def drop_role(self, role_name, timeout=None, **kwargs): - """ Drop Role + """Drop Role :param role_name: role name. :type role_name: str """ @@ -1179,7 +1324,7 @@ def drop_role(self, role_name, timeout=None, **kwargs): handler.drop_role(role_name, timeout=timeout, **kwargs) def add_user_to_role(self, username, role_name, timeout=None, **kwargs): - """ Add User To Role + """Add User To Role :param username: user name. :type username: str :param role_name: role name. @@ -1189,27 +1334,31 @@ def add_user_to_role(self, username, role_name, timeout=None, **kwargs): handler.add_user_to_role(username, role_name, timeout=timeout, **kwargs) def remove_user_from_role(self, username, role_name, timeout=None, **kwargs): - """ Remove User From Role + """Remove User From Role :param username: user name. :type username: str :param role_name: role name. :type role_name: str """ with self._connection() as handler: - handler.remove_user_from_role(username, role_name, timeout=timeout, **kwargs) + handler.remove_user_from_role( + username, role_name, timeout=timeout, **kwargs + ) def select_one_role(self, role_name, include_user_info, timeout=None, **kwargs): - """ Select One Role Info + """Select One Role Info :param role_name: role name. :type role_name: str :param include_user_info: whether to obtain the user information associated with the role :type include_user_info: bool """ with self._connection() as handler: - handler.select_one_role(role_name, include_user_info, timeout=timeout, **kwargs) + handler.select_one_role( + role_name, include_user_info, timeout=timeout, **kwargs + ) def select_all_role(self, include_user_info, timeout=None, **kwargs): - """ Select All Role Info + """Select All Role Info :param include_user_info: whether to obtain the user information associated with roles :type include_user_info: bool """ @@ -1217,26 +1366,29 @@ def select_all_role(self, include_user_info, timeout=None, **kwargs): handler.select_all_role(include_user_info, timeout=timeout, **kwargs) def select_one_user(self, username, include_role_info, timeout=None, **kwargs): - """ Select One User Info + """Select One User Info :param username: user name. :type username: str :param include_role_info: whether to obtain the role information associated with the user :type include_role_info: bool """ with self._connection() as handler: - handler.select_one_user(username, include_role_info, timeout=timeout, **kwargs) + handler.select_one_user( + username, include_role_info, timeout=timeout, **kwargs + ) def select_all_user(self, include_role_info, timeout=None, **kwargs): - """ Select All User Info + """Select All User Info :param include_role_info: whether to obtain the role information associated with users :type include_role_info: bool """ with self._connection() as handler: handler.select_all_role(include_role_info, timeout=timeout, **kwargs) - def grant_privilege(self, role_name, object, object_name, privilege, - timeout=None, **kwargs): - """ Grant Privilege + def grant_privilege( + self, role_name, object, object_name, privilege, timeout=None, **kwargs + ): + """Grant Privilege :param role_name: role name. :type role_name: str :param object: object that will be granted the privilege. @@ -1247,12 +1399,14 @@ def grant_privilege(self, role_name, object, object_name, privilege, :type privilege: str """ with self._connection() as handler: - handler.grant_privilege(role_name, object, object_name, privilege, - timeout=timeout, **kwargs) - - def revoke_privilege(self, role_name, object, object_name, privilege, - timeout=None, **kwargs): - """ Revoke Privilege + handler.grant_privilege( + role_name, object, object_name, privilege, timeout=timeout, **kwargs + ) + + def revoke_privilege( + self, role_name, object, object_name, privilege, timeout=None, **kwargs + ): + """Revoke Privilege :param role_name: role name. :type role_name: str :param object: object that will be granted the privilege. @@ -1263,20 +1417,22 @@ def revoke_privilege(self, role_name, object, object_name, privilege, :type privilege: str """ with self._connection() as handler: - handler.revoke_privilege(role_name, object, object_name, privilege, - timeout=timeout, **kwargs) + handler.revoke_privilege( + role_name, object, object_name, privilege, timeout=timeout, **kwargs + ) def select_grant_for_one_role(self, role_name, timeout=None, **kwargs): - """ Select the grant info about the role + """Select the grant info about the role :param role_name: role name. :type role_name: str """ with self._connection() as handler: handler.select_grant_for_one_role(role_name, timeout=timeout, **kwargs) - def select_grant_for_role_and_object(self, role_name, object, object_name, - timeout=None, **kwargs): - """ Select the grant info about the role and specific object + def select_grant_for_role_and_object( + self, role_name, object, object_name, timeout=None, **kwargs + ): + """Select the grant info about the role and specific object :param role_name: role name. :type role_name: str :param object: object that will be selected the privilege info. @@ -1285,7 +1441,9 @@ def select_grant_for_role_and_object(self, role_name, object, object_name, :type object_name: str """ with self._connection() as handler: - handler.select_grant_for_role_and_object(role_name, object, object_name, timeout=timeout, **kwargs) + handler.select_grant_for_role_and_object( + role_name, object, object_name, timeout=timeout, **kwargs + ) def get_version(self, timeout=None, **kwargs): with self._connection() as handler: diff --git a/pymilvus/client/ts_utils.py b/pymilvus/client/ts_utils.py index 87a3eaa80..7f8cfd8c2 100644 --- a/pymilvus/client/ts_utils.py +++ b/pymilvus/client/ts_utils.py @@ -1,11 +1,10 @@ -import threading import datetime +import threading +from ..grpc_gen.common_pb2 import ConsistencyLevel +from .constants import BOUNDED_TS, EVENTUALLY_TS from .singleton_utils import Singleton from .utils import hybridts_to_unixtime -from .constants import EVENTUALLY_TS, BOUNDED_TS - -from ..grpc_gen.common_pb2 import ConsistencyLevel class GTsDict(metaclass=Singleton): @@ -77,7 +76,9 @@ def construct_guarantee_ts(consistency_level, collection_name, kwargs): elif consistency_level == ConsistencyLevel.Session: # Using the last write ts of the collection. # TODO: get a timestamp from server? - kwargs["guarantee_timestamp"] = get_collection_ts(collection_name) or get_eventually_ts() + kwargs["guarantee_timestamp"] = ( + get_collection_ts(collection_name) or get_eventually_ts() + ) elif consistency_level == ConsistencyLevel.Bounded: # Milvus will assign ts according to the server timestamp and a configured time interval kwargs["guarantee_timestamp"] = get_bounded_ts() diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 70bd51d29..696bc2a27 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -1,13 +1,10 @@ import time from enum import IntEnum -from ..grpc_gen.common_pb2 import ConsistencyLevel + +from ..exceptions import AutoIDException, ExceptionsMessage, InvalidConsistencyLevel from ..grpc_gen import common_pb2 -from ..exceptions import ( - AutoIDException, - ExceptionsMessage, - InvalidConsistencyLevel, -) from ..grpc_gen import milvus_pb2 as milvus_types +from ..grpc_gen.common_pb2 import ConsistencyLevel class Status: @@ -50,11 +47,11 @@ def __init__(self, code=SUCCESS, message="Success"): self.message = message def __repr__(self): - attr_list = [f'{key}={value}' for key, value in self.__dict__.items()] + attr_list = [f"{key}={value}" for key, value in self.__dict__.items()] return f"{self.__class__.__name__}({', '.join(attr_list)})" def __eq__(self, other): - """ Make Status comparable with self by code """ + """Make Status comparable with self by code""" if isinstance(other, int): return self.code == other @@ -186,7 +183,14 @@ class CompactionState: completed: number of plans successfully completed """ - def __init__(self, compaction_id: int, state: State, in_executing: int, in_timeout: int, completed: int): + def __init__( + self, + compaction_id: int, + state: State, + in_executing: int, + in_timeout: int, + completed: int, + ): self.compaction_id = compaction_id self.state = state self.in_executing = in_executing @@ -260,12 +264,16 @@ def get_consistency_level(consistency_level): if isinstance(consistency_level, int): if consistency_level in ConsistencyLevel.values(): return consistency_level - raise InvalidConsistencyLevel(message=f"invalid consistency level: {consistency_level}") + raise InvalidConsistencyLevel( + message=f"invalid consistency level: {consistency_level}" + ) if isinstance(consistency_level, str): try: return ConsistencyLevel.Value(consistency_level) except ValueError as e: - raise InvalidConsistencyLevel(message=f"invalid consistency level: {consistency_level}") from e + raise InvalidConsistencyLevel( + message=f"invalid consistency level: {consistency_level}" + ) from e raise InvalidConsistencyLevel(message="invalid consistency level") @@ -339,6 +347,7 @@ def groups(self): class BulkInsertState: """enum states of bulk insert task""" + ImportPending = 0 ImportFailed = 1 ImportStarted = 2 @@ -382,7 +391,9 @@ class BulkInsertState: ImportUnknownState: "Unknown", } - def __init__(self, task_id, state, row_count: int, id_ranges: list, infos, create_ts: int): + def __init__( + self, task_id, state, row_count: int, id_ranges: list, infos, create_ts: int + ): self._task_id = task_id self._state = state self._row_count = row_count @@ -400,8 +411,14 @@ def __repr__(self) -> str: - id_ranges : {}, - create_ts : {} >""" - return fmt.format(self._task_id, self.state_name, self.row_count, self.infos, - self.id_ranges, self.create_time_str) + return fmt.format( + self._task_id, + self.state_name, + self.row_count, + self.infos, + self.id_ranges, + self.create_time_str, + ) @property def task_id(self): @@ -505,9 +522,11 @@ def __init__(self, entity): self._privilege = entity.grantor.privilege.name def __repr__(self) -> str: - s = f"GrantItem: , , " \ - f", , " \ + s = ( + f"GrantItem: , , " + f", , " f"" + ) return s @property diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 477702be9..d3ec0f1f0 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -1,8 +1,9 @@ import datetime -from .types import DataType -from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK from ..exceptions import MilvusException +from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK +from .types import DataType + valid_index_types = [ "FLAT", @@ -19,7 +20,7 @@ "BIN_FLAT", "BIN_IVF_FLAT", "DISKANN", - "AUTOINDEX" + "AUTOINDEX", ] valid_index_params_keys = [ @@ -29,20 +30,17 @@ "M", "efConstruction", "PQM", - "n_trees" + "n_trees", ] -valid_binary_index_types = [ - "BIN_FLAT", - "BIN_IVF_FLAT" -] +valid_binary_index_types = ["BIN_FLAT", "BIN_IVF_FLAT"] valid_binary_metric_types = [ "JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", - "SUPERSTRUCTURE" + "SUPERSTRUCTURE", ] @@ -51,14 +49,18 @@ def hybridts_to_unixtime(ts): return physical / 1000.0 -def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): +def mkts_from_hybridts(hybridts, milliseconds=0.0, delta=None): if not isinstance(milliseconds, (int, float)): - raise MilvusException(message="parameter milliseconds should be type of int or float") + raise MilvusException( + message="parameter milliseconds should be type of int or float" + ) if isinstance(delta, datetime.timedelta): - milliseconds += (delta.microseconds / 1000.0) + milliseconds += delta.microseconds / 1000.0 elif delta is not None: - raise MilvusException(message="parameter delta should be type of datetime.timedelta") + raise MilvusException( + message="parameter delta should be type of datetime.timedelta" + ) if not isinstance(hybridts, int): raise MilvusException(message="parameter hybridts should be type of int") @@ -66,45 +68,53 @@ def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): logical = hybridts & LOGICAL_BITS_MASK physical = hybridts >> LOGICAL_BITS - new_ts = int((int((physical + milliseconds)) << LOGICAL_BITS) + logical) + new_ts = int((int(physical + milliseconds) << LOGICAL_BITS) + logical) return new_ts -def mkts_from_unixtime(epoch, milliseconds=0., delta=None): +def mkts_from_unixtime(epoch, milliseconds=0.0, delta=None): if not isinstance(epoch, (int, float)): raise MilvusException(message="parameter epoch should be type of int or float") if not isinstance(milliseconds, (int, float)): - raise MilvusException(message="parameter milliseconds should be type of int or float") + raise MilvusException( + message="parameter milliseconds should be type of int or float" + ) if isinstance(delta, datetime.timedelta): - milliseconds += (delta.microseconds / 1000.0) + milliseconds += delta.microseconds / 1000.0 elif delta is not None: - raise MilvusException(message="parameter delta should be type of datetime.timedelta") + raise MilvusException( + message="parameter delta should be type of datetime.timedelta" + ) - epoch += (milliseconds / 1000.0) + epoch += milliseconds / 1000.0 int_msecs = int(epoch * 1000 // 1) return int(int_msecs << LOGICAL_BITS) -def mkts_from_datetime(d_time, milliseconds=0., delta=None): +def mkts_from_datetime(d_time, milliseconds=0.0, delta=None): if not isinstance(d_time, datetime.datetime): - raise MilvusException(message="parameter d_time should be type of datetime.datetime") + raise MilvusException( + message="parameter d_time should be type of datetime.datetime" + ) - return mkts_from_unixtime(d_time.timestamp(), milliseconds=milliseconds, delta=delta) + return mkts_from_unixtime( + d_time.timestamp(), milliseconds=milliseconds, delta=delta + ) def check_invalid_binary_vector(entities) -> bool: for entity in entities: - if entity['type'] == DataType.BINARY_VECTOR: - if not isinstance(entity['values'], list) and len(entity['values']) == 0: + if entity["type"] == DataType.BINARY_VECTOR: + if not isinstance(entity["values"], list) and len(entity["values"]) == 0: return False - dim = len(entity['values'][0]) * 8 + dim = len(entity["values"][0]) * 8 if dim == 0: return False - for values in entity['values']: + for values in entity["values"]: if len(values) * 8 != dim: return False if not isinstance(values, bytes): @@ -142,7 +152,9 @@ def len_of(field_data) -> int: if field_data.vectors.HasField("float_vector"): total_len = len(field_data.vectors.float_vector.data) if total_len % dim != 0: - raise MilvusException(message=f"Invalid vector length: total_len={total_len}, dim={dim}") + raise MilvusException( + message=f"Invalid vector length: total_len={total_len}, dim={dim}" + ) return int(total_len / dim) total_len = len(field_data.vectors.binary_vector) diff --git a/pymilvus/decorators.py b/pymilvus/decorators.py index aee37de21..e44efc3f2 100644 --- a/pymilvus/decorators.py +++ b/pymilvus/decorators.py @@ -1,13 +1,14 @@ -import time import datetime -import logging import functools +import logging +import time import grpc from .exceptions import MilvusException, MilvusUnavailableException from .grpc_gen import common_pb2 + LOGGER = logging.getLogger(__name__) WARNING_COLOR = "\033[93m{}\033[0m" @@ -18,10 +19,17 @@ def inner(*args, **kwargs): dup_msg = "[WARNING] PyMilvus: class Milvus will be deprecated soon, please use Collection/utility instead" LOGGER.warning(WARNING_COLOR.format(dup_msg)) return func(*args, **kwargs) + return inner -def retry_on_rpc_failure(retry_times=10, initial_back_off=0.01, max_back_off=60, back_off_multiplier=3, retry_on_deadline=True): +def retry_on_rpc_failure( + retry_times=10, + initial_back_off=0.01, + max_back_off=60, + back_off_multiplier=3, + retry_on_deadline=True, +): # the default 7 retry_times will cost about 26s def wrapper(func): @functools.wraps(func) @@ -32,14 +40,16 @@ def handler(self, *args, **kwargs): _timeout = kwargs.get("timeout", None) _retry_on_rate_limit = kwargs.get("retry_on_rate_limit", True) - retry_timeout = _timeout if _timeout is not None and isinstance(_timeout, int) else None + retry_timeout = ( + _timeout if _timeout is not None and isinstance(_timeout, int) else None + ) counter = 1 back_off = initial_back_off start_time = time.time() def timeout(start_time) -> bool: - """ If timeout is valid, use timeout as the retry limits, - If timeout is None, use retry_times as the retry limits. + """If timeout is valid, use timeout as the retry limits, + If timeout is None, use retry_times as the retry limits. """ if retry_timeout is not None: return time.time() - start_time >= retry_timeout @@ -52,18 +62,31 @@ def timeout(start_time) -> bool: # DEADLINE_EXCEEDED means that the task wat not completed # UNAVAILABLE means that the service is not reachable currently # Reference: https://grpc.github.io/grpc/python/grpc.html#grpc-status-code - if e.code() != grpc.StatusCode.DEADLINE_EXCEEDED and e.code() != grpc.StatusCode.UNAVAILABLE: + if ( + e.code() != grpc.StatusCode.DEADLINE_EXCEEDED + and e.code() != grpc.StatusCode.UNAVAILABLE + ): raise MilvusException(message=str(e)) from e - if not retry_on_deadline and e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + if ( + not retry_on_deadline + and e.code() == grpc.StatusCode.DEADLINE_EXCEEDED + ): raise MilvusException(message=str(e)) from e if timeout(start_time): - timeout_msg = f"Retry timeout: {retry_timeout}s" if retry_timeout is not None \ + timeout_msg = ( + f"Retry timeout: {retry_timeout}s" + if retry_timeout is not None else f"Retry run out of {retry_times} retry times" + ) if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - raise MilvusException(message=f"rpc deadline exceeded: {timeout_msg}") from e + raise MilvusException( + message=f"rpc deadline exceeded: {timeout_msg}" + ) from e if e.code() == grpc.StatusCode.UNAVAILABLE: - raise MilvusUnavailableException(message=f"server Unavailable: {timeout_msg}") from e + raise MilvusUnavailableException( + message=f"server Unavailable: {timeout_msg}" + ) from e raise MilvusException(message=str(e)) from e if counter > 3: @@ -74,10 +97,15 @@ def timeout(start_time) -> bool: back_off = min(back_off * back_off_multiplier, max_back_off) except MilvusException as e: if timeout(start_time): - timeout_msg = f"Retry timeout: {retry_timeout}s" if retry_timeout is not None \ + timeout_msg = ( + f"Retry timeout: {retry_timeout}s" + if retry_timeout is not None else f"Retry run out of {retry_times} retry times" + ) LOGGER.warning(WARNING_COLOR.format(timeout_msg)) - raise MilvusException(e.code, f"{timeout_msg}, message={e.message}") from e + raise MilvusException( + e.code, f"{timeout_msg}, message={e.message}" + ) from e if _retry_on_rate_limit and e.code == common_pb2.RateLimit: time.sleep(back_off) back_off = min(back_off * back_off_multiplier, max_back_off) @@ -89,6 +117,7 @@ def timeout(start_time) -> bool: counter += 1 return handler + return wrapper @@ -109,17 +138,25 @@ def handler(*args, **kwargs): raise e except grpc.FutureTimeoutError as e: record_dict["gRPC timeout"] = str(datetime.datetime.now()) - LOGGER.error(f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: {e.code()}, {e.details()}>, ") + LOGGER.error( + f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: {e.code()}, {e.details()}>, " + ) raise e except grpc.RpcError as e: record_dict["gRPC error"] = str(datetime.datetime.now()) - LOGGER.error(f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: {e.code()}, {e.details()}>, ") + LOGGER.error( + f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: {e.code()}, {e.details()}>, " + ) raise e except Exception as e: record_dict["Exception"] = str(datetime.datetime.now()) - LOGGER.error(f"Unexcepted error: [{inner_name}], {e}, ") + LOGGER.error( + f"Unexcepted error: [{inner_name}], {e}, " + ) raise e + return handler + return wrapper @@ -135,5 +172,7 @@ def handler(self, *args, **kwargs): self.set_onetime_request_id(req_id) ret = func(self, *args, **kwargs) return ret + return handler + return wrapper diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index 0f1cfc1e0..03e00916d 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -37,35 +37,35 @@ def __str__(self): class ParamError(MilvusException): - """ Raise when params are incorrect """ + """Raise when params are incorrect""" class ConnectError(MilvusException): - """ Connect server fail """ + """Connect server fail""" class MilvusUnavailableException(MilvusException): - """ Raise when server's Unavaliable""" + """Raise when server's Unavaliable""" class CollectionNotExistException(MilvusException): - """ Raise when collections doesn't exist """ + """Raise when collections doesn't exist""" class DescribeCollectionException(MilvusException): - """ Raise when fail to describe collection """ + """Raise when fail to describe collection""" class PartitionNotExistException(MilvusException): - """ Raise when partition doesn't exist """ + """Raise when partition doesn't exist""" class PartitionAlreadyExistException(MilvusException): - """ Raise when create an exsiting partition """ + """Raise when create an exsiting partition""" class IndexNotExistException(MilvusException): - """ Raise when index doesn't exist """ + """Raise when index doesn't exist""" class AmbiguousIndexName(MilvusException): @@ -73,51 +73,51 @@ class AmbiguousIndexName(MilvusException): class CannotInferSchemaException(MilvusException): - """ Raise when cannot trasfer dataframe to schema """ + """Raise when cannot trasfer dataframe to schema""" class SchemaNotReadyException(MilvusException): - """ Raise when schema is wrong """ + """Raise when schema is wrong""" class DataTypeNotMatchException(MilvusException): - """ Raise when datatype dosen't match """ + """Raise when datatype dosen't match""" class DataTypeNotSupportException(MilvusException): - """ Raise when datatype isn't supported """ + """Raise when datatype isn't supported""" class DataNotMatchException(MilvusException): - """ Raise when insert data isn't match with schema """ + """Raise when insert data isn't match with schema""" class ConnectionNotExistException(MilvusException): - """ Raise when connections doesn't exist """ + """Raise when connections doesn't exist""" class ConnectionConfigException(MilvusException): - """ Raise when configs of connection are invalid """ + """Raise when configs of connection are invalid""" class PrimaryKeyException(MilvusException): - """ Raise when primarykey are invalid """ + """Raise when primarykey are invalid""" class FieldsTypeException(MilvusException): - """ Raise when fields is invalid """ + """Raise when fields is invalid""" class FieldTypeException(MilvusException): - """ Raise when one field is invalid """ + """Raise when one field is invalid""" class AutoIDException(MilvusException): - """ Raise when autoID is invalid """ + """Raise when autoID is invalid""" class InvalidConsistencyLevel(MilvusException): - """ Raise when consistency level is invalid """ + """Raise when consistency level is invalid""" class ExceptionsMessage: @@ -128,7 +128,9 @@ class ExceptionsMessage: AliasType = "Alias should be string, but %r is given." ConnLackConf = "You need to pass in the configuration of the connection named %r ." ConnectFirst = "should create connect first." - CollectionNotExistNoSchema = "Collection %r not exist, or you can pass in schema to create one." + CollectionNotExistNoSchema = ( + "Collection %r not exist, or you can pass in schema to create one." + ) NoSchema = "Should be passed into the schema." EmptySchema = "The field of the schema cannot be empty." SchemaType = "Schema type must be schema.CollectionSchema." diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index fdf3b18a8..3f6fdc165 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -13,43 +13,56 @@ import copy import json from typing import List + import pandas +from ..client.configs import DefaultConfigs +from ..client.constants import DEFAULT_CONSISTENCY_LEVEL +from ..client.types import ( + CompactionPlans, + CompactionState, + Replica, + cmp_consistency_level, + get_consistency_level, +) +from ..exceptions import ( + AutoIDException, + DataTypeNotMatchException, + ExceptionsMessage, + IndexNotExistException, + PartitionAlreadyExistException, + PartitionNotExistException, + SchemaNotReadyException, +) from .connections import connections +from .default_config import DefaultConfig +from .future import MutationFuture, SearchFuture +from .index import Index +from .mutation import MutationResult +from .partition import Partition +from .prepare import Prepare from .schema import ( CollectionSchema, FieldSchema, - parse_fields_from_data, check_insert_data_schema, check_schema, + parse_fields_from_data, ) -from .prepare import Prepare -from .partition import Partition -from .index import Index from .search import SearchResult -from .mutation import MutationResult from .types import DataType -from ..exceptions import ( - SchemaNotReadyException, - DataTypeNotMatchException, - PartitionAlreadyExistException, - PartitionNotExistException, - IndexNotExistException, - AutoIDException, - ExceptionsMessage, -) -from .future import SearchFuture, MutationFuture from .utility import _get_connection -from .default_config import DefaultConfig -from ..client.types import CompactionState, CompactionPlans, Replica, get_consistency_level, cmp_consistency_level -from ..client.constants import DEFAULT_CONSISTENCY_LEVEL -from ..client.configs import DefaultConfigs - class Collection: - def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", shards_num: int=2, **kwargs): - """ Constructs a collection by name, schema and other parameters. + def __init__( + self, + name: str, + schema: CollectionSchema = None, + using: str = "default", + shards_num: int = 2, + **kwargs + ): + """Constructs a collection by name, schema and other parameters. Args: name (``str``): the name of collection @@ -96,10 +109,14 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default has = conn.has_collection(self._name, **kwargs) if has: resp = conn.describe_collection(self._name, **kwargs) - s_consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + s_consistency_level = resp.get( + "consistency_level", DEFAULT_CONSISTENCY_LEVEL + ) arg_consistency_level = kwargs.get("consistency_level", s_consistency_level) if not cmp_consistency_level(s_consistency_level, arg_consistency_level): - raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent) + raise SchemaNotReadyException( + message=ExceptionsMessage.ConsistencyLevelInconsistent + ) server_schema = CollectionSchema.construct_from_dict(resp) self._consistency_level = s_consistency_level if schema is None: @@ -108,16 +125,24 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default if not isinstance(schema, CollectionSchema): raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType) if server_schema != schema: - raise SchemaNotReadyException(message=ExceptionsMessage.SchemaInconsistent) + raise SchemaNotReadyException( + message=ExceptionsMessage.SchemaInconsistent + ) self._schema = schema else: if schema is None: - raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name) + raise SchemaNotReadyException( + message=ExceptionsMessage.CollectionNotExistNoSchema % name + ) if isinstance(schema, CollectionSchema): check_schema(schema) - consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) - conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs) + consistency_level = get_consistency_level( + kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + ) + conn.create_collection( + self._name, schema, shards_num=self._shards_num, **kwargs + ) self._schema = schema self._consistency_level = consistency_level else: @@ -128,10 +153,10 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default def __repr__(self): _dict = { - 'name': self.name, - 'partitions': self.partitions, - 'description': self.description, - 'schema': self._schema, + "name": self.name, + "partitions": self.partitions, + "description": self.description, + "schema": self._schema, } r = [":\n-------------\n"] s = "<{}>: {}\n" @@ -176,17 +201,25 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): else: fields_schema = parse_fields_from_data(dataframe) if auto_id: - fields_schema.insert(pk_index, - FieldSchema(name=primary_field, dtype=DataType.INT64, is_primary=True, - auto_id=True, - **kwargs)) + fields_schema.insert( + pk_index, + FieldSchema( + name=primary_field, + dtype=DataType.INT64, + is_primary=True, + auto_id=True, + **kwargs + ), + ) for field in fields_schema: if auto_id is False and field.name == primary_field: field.is_primary = True field.auto_id = False if field.dtype == DataType.VARCHAR: - field.params[DefaultConfigs.MaxVarCharLengthKey] = int(DefaultConfigs.MaxVarCharLength) + field.params[DefaultConfigs.MaxVarCharLengthKey] = int( + DefaultConfigs.MaxVarCharLength + ) schema = CollectionSchema(fields=fields_schema) check_schema(schema) @@ -196,12 +229,12 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): @property def schema(self) -> CollectionSchema: - """CollectionSchema: schema of the collection. """ + """CollectionSchema: schema of the collection.""" return self._schema @property def aliases(self, **kwargs) -> list: - """List[str]: all the aliases of the collection. """ + """List[str]: all the aliases of the collection.""" conn = self._get_connection() resp = conn.describe_collection(self._name, **kwargs) aliases = resp["aliases"] @@ -209,12 +242,12 @@ def aliases(self, **kwargs) -> list: @property def description(self) -> str: - """str: a text description of the collection. """ + """str: a text description of the collection.""" return self._schema.description @property def name(self) -> str: - """str: the name of the collection. """ + """str: the name of the collection.""" return self._name @property @@ -255,7 +288,7 @@ def primary_field(self) -> FieldSchema: return self._schema.primary_field def flush(self, timeout=None, **kwargs): - """ Seal all segments in the collection. Inserts after flushing will be written into + """Seal all segments in the collection. Inserts after flushing will be written into new segments. Only sealed segments can be indexed. Args: @@ -280,7 +313,7 @@ def flush(self, timeout=None, **kwargs): conn.flush([self.name], timeout=timeout, **kwargs) def drop(self, timeout=None, **kwargs): - """ Drops the collection. The same as `utility.drop_collection()` + """Drops the collection. The same as `utility.drop_collection()` Args: timeout (float, optional): an optional duration of time in seconds to allow for the RPCs. @@ -304,7 +337,7 @@ def drop(self, timeout=None, **kwargs): conn.drop_collection(self._name, timeout=timeout, **kwargs) def set_properties(self, properties, timeout=None, **kwargs): - """ Set properties for the collection + """Set properties for the collection Args: properties (``dict``): collection properties. @@ -327,7 +360,7 @@ def set_properties(self, properties, timeout=None, **kwargs): conn.alter_collection(self.name, properties=properties, timeout=timeout) def load(self, partition_names=None, replica_number=1, timeout=None, **kwargs): - """ Load the data into memory. + """Load the data into memory. Args: partition_names (``List[str]``): The specified partitions to load. @@ -356,12 +389,20 @@ def load(self, partition_names=None, replica_number=1, timeout=None, **kwargs): """ conn = self._get_connection() if partition_names is not None: - conn.load_partitions(self._name, partition_names, replica_number=replica_number, timeout=timeout, **kwargs) + conn.load_partitions( + self._name, + partition_names, + replica_number=replica_number, + timeout=timeout, + **kwargs + ) else: - conn.load_collection(self._name, replica_number=replica_number, timeout=timeout, **kwargs) + conn.load_collection( + self._name, replica_number=replica_number, timeout=timeout, **kwargs + ) def release(self, timeout=None, **kwargs): - """ Releases the collection data from memory. + """Releases the collection data from memory. Args: timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. @@ -383,8 +424,14 @@ def release(self, timeout=None, **kwargs): conn = self._get_connection() conn.release_collection(self._name, timeout=timeout, **kwargs) - def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: - """ Insert data into the collection. + def insert( + self, + data: [List, pandas.DataFrame], + partition_name: str = None, + timeout=None, + **kwargs + ) -> MutationResult: + """Insert data into the collection. Args: data (``list/tuple/pandas.DataFrame``): The specified data to insert @@ -422,15 +469,21 @@ def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeo entities = Prepare.prepare_insert_data(data, self._schema) conn = self._get_connection() - res = conn.batch_insert(self._name, entities, partition_name, - timeout=timeout, schema=self._schema_dict, **kwargs) + res = conn.batch_insert( + self._name, + entities, + partition_name, + timeout=timeout, + schema=self._schema_dict, + **kwargs + ) if kwargs.get("_async", False): return MutationFuture(res) return MutationResult(res) def delete(self, expr, partition_name=None, timeout=None, **kwargs): - """ Delete entities with an expression condition. + """Delete entities with an expression condition. Args: expr (``str``): The specified data to insert. @@ -472,9 +525,20 @@ def delete(self, expr, partition_name=None, timeout=None, **kwargs): return MutationFuture(res) return MutationResult(res) - def search(self, data, anns_field, param, limit, expr=None, partition_names=None, - output_fields=None, timeout=None, round_decimal=-1, **kwargs): - """ Conducts a vector similarity search with an optional boolean expression as filter. + def search( + self, + data, + anns_field, + param, + limit, + expr=None, + partition_names=None, + output_fields=None, + timeout=None, + round_decimal=-1, + **kwargs + ): + """Conducts a vector similarity search with an optional boolean expression as filter. Args: data (``List[List[float]]``): The vectors of search data. @@ -603,18 +667,33 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 """ if expr is not None and not isinstance(expr, str): - raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) + raise DataTypeNotMatchException( + message=ExceptionsMessage.ExprType % type(expr) + ) conn = self._get_connection() - res = conn.search(self._name, data, anns_field, param, limit, expr, - partition_names, output_fields, round_decimal, timeout=timeout, - schema=self._schema_dict, **kwargs) + res = conn.search( + self._name, + data, + anns_field, + param, + limit, + expr, + partition_names, + output_fields, + round_decimal, + timeout=timeout, + schema=self._schema_dict, + **kwargs + ) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) - def query(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): - """ Query with expressions + def query( + self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs + ): + """Query with expressions Args: expr (``str``): The query expression. @@ -687,16 +766,25 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** - Query results: [{'film_id': 1, 'film_date': 2001}] """ if not isinstance(expr, str): - raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) + raise DataTypeNotMatchException( + message=ExceptionsMessage.ExprType % type(expr) + ) conn = self._get_connection() - res = conn.query(self._name, expr, output_fields, partition_names, - timeout=timeout, schema=self._schema_dict, **kwargs) + res = conn.query( + self._name, + expr, + output_fields, + partition_names, + timeout=timeout, + schema=self._schema_dict, + **kwargs + ) return res @property def partitions(self, **kwargs) -> List[Partition]: - """ List[Partition]: List of Partition object. + """List[Partition]: List of Partition object. Raises: MilvusException: If anything goes wrong. @@ -720,7 +808,7 @@ def partitions(self, **kwargs) -> List[Partition]: return partitions def partition(self, partition_name, **kwargs) -> Partition: - """ Get the existing partition object according to name. Return None if not existed. + """Get the existing partition object according to name. Return None if not existed. Args: partition_name (``str``): The name of the partition to get. @@ -747,7 +835,7 @@ def partition(self, partition_name, **kwargs) -> Partition: return Partition(self, partition_name, construct_only=True, **kwargs) def create_partition(self, partition_name, description="", **kwargs) -> Partition: - """ Create a new partition corresponding to name if not existed. + """Create a new partition corresponding to name if not existed. Args: partition_name (``str``): The name of the partition to create. @@ -773,11 +861,13 @@ def create_partition(self, partition_name, description="", **kwargs) -> Partitio {"name": "comedy", "collection_name": "test_collection_create_partition", "description": ""} """ if self.has_partition(partition_name, **kwargs) is True: - raise PartitionAlreadyExistException(message=ExceptionsMessage.PartitionAlreadyExist) + raise PartitionAlreadyExistException( + message=ExceptionsMessage.PartitionAlreadyExist + ) return Partition(self, partition_name, description=description, **kwargs) def has_partition(self, partition_name, timeout=None, **kwargs) -> bool: - """ Checks if a specified partition exists. + """Checks if a specified partition exists. Args: partition_name (``str``): The name of the partition to check. @@ -809,7 +899,7 @@ def has_partition(self, partition_name, timeout=None, **kwargs) -> bool: return conn.has_partition(self._name, partition_name, timeout=timeout, **kwargs) def drop_partition(self, partition_name, timeout=None, **kwargs): - """ Drop the partition in this collection. + """Drop the partition in this collection. Args: partition_name (``str``): The name of the partition to drop. @@ -837,9 +927,13 @@ def drop_partition(self, partition_name, timeout=None, **kwargs): False """ if self.has_partition(partition_name, **kwargs) is False: - raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) + raise PartitionNotExistException( + message=ExceptionsMessage.PartitionNotExist + ) conn = self._get_connection() - return conn.drop_partition(self._name, partition_name, timeout=timeout, **kwargs) + return conn.drop_partition( + self._name, partition_name, timeout=timeout, **kwargs + ) @property def indexes(self, **kwargs) -> List[Index]: @@ -865,7 +959,13 @@ def indexes(self, **kwargs) -> List[Index]: if info_dict.get("params", None): info_dict["params"] = json.loads(info_dict["params"]) - index_info = Index(self, index.field_name, info_dict, index_name=index.index_name, construct_only=True) + index_info = Index( + self, + index.field_name, + info_dict, + index_name=index.index_name, + construct_only=True, + ) indexes.append(index_info) return indexes @@ -906,7 +1006,9 @@ def index(self, **kwargs) -> Index: if tmp_index is not None: field_name = tmp_index.pop("field_name", None) index_name = tmp_index.pop("index_name", index_name) - return Index(self, field_name, tmp_index, construct_only=True, index_name=index_name) + return Index( + self, field_name, tmp_index, construct_only=True, index_name=index_name + ) raise IndexNotExistException(message=ExceptionsMessage.IndexNotExist) def create_index(self, field_name, index_params={}, timeout=None, **kwargs): @@ -945,10 +1047,12 @@ def create_index(self, field_name, index_params={}, timeout=None, **kwargs): Status(code=0, message='') """ conn = self._get_connection() - return conn.create_index(self._name, field_name, index_params, timeout=timeout, **kwargs) + return conn.create_index( + self._name, field_name, index_params, timeout=timeout, **kwargs + ) def has_index(self, timeout=None, **kwargs) -> bool: - """ Check whether a specified index exists. + """Check whether a specified index exists. Args: timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout @@ -977,12 +1081,15 @@ def has_index(self, timeout=None, **kwargs) -> bool: conn = self._get_connection() copy_kwargs = copy.deepcopy(kwargs) index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) - if conn.describe_index(self._name, index_name, timeout=timeout, **copy_kwargs) is None: + if ( + conn.describe_index(self._name, index_name, timeout=timeout, **copy_kwargs) + is None + ): return False return True def drop_index(self, timeout=None, **kwargs): - """ Drop index and its corresponding index files. + """Drop index and its corresponding index files. Args: timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout is set to None, client waits until server response or error occur. @@ -1013,13 +1120,21 @@ def drop_index(self, timeout=None, **kwargs): copy_kwargs = copy.deepcopy(kwargs) index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) conn = self._get_connection() - tmp_index = conn.describe_index(self._name, index_name, timeout=timeout, **copy_kwargs) + tmp_index = conn.describe_index( + self._name, index_name, timeout=timeout, **copy_kwargs + ) if tmp_index is not None: - index = Index(self, tmp_index['field_name'], tmp_index, construct_only=True, index_name=index_name) + index = Index( + self, + tmp_index["field_name"], + tmp_index, + construct_only=True, + index_name=index_name, + ) index.drop(timeout=timeout, **kwargs) def compact(self, timeout=None, **kwargs): - """ Compact merge the small segments in a collection + """Compact merge the small segments in a collection Args: timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout @@ -1032,7 +1147,7 @@ def compact(self, timeout=None, **kwargs): self.compaction_id = conn.compact(self._name, timeout=timeout, **kwargs) def get_compaction_state(self, timeout=None, **kwargs) -> CompactionState: - """ Get the current compaction state + """Get the current compaction state Args: timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout @@ -1045,7 +1160,7 @@ def get_compaction_state(self, timeout=None, **kwargs) -> CompactionState: return conn.get_compaction_state(self.compaction_id, timeout=timeout, **kwargs) def wait_for_compaction_completed(self, timeout=None, **kwargs) -> CompactionState: - """ Block until the current collection's compaction completed + """Block until the current collection's compaction completed Args: timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout @@ -1055,7 +1170,9 @@ def wait_for_compaction_completed(self, timeout=None, **kwargs) -> CompactionSta MilvusException: If anything goes wrong. """ conn = self._get_connection() - return conn.wait_for_compaction_completed(self.compaction_id, timeout=timeout, **kwargs) + return conn.wait_for_compaction_completed( + self.compaction_id, timeout=timeout, **kwargs + ) def get_compaction_plans(self, timeout=None, **kwargs) -> CompactionPlans: """Get the current compaction plans diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 7c178ff65..f2e7dda34 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -10,18 +10,21 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -import os import copy +import os import re import threading -from urllib import parse from typing import Tuple +from urllib import parse -from ..client.check import is_legal_host, is_legal_port, is_legal_address +from ..client.check import is_legal_address, is_legal_host, is_legal_port from ..client.grpc_handler import GrpcHandler - -from .default_config import DefaultConfig, ENV_CONNECTION_CONF -from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException +from ..exceptions import ( + ConnectionConfigException, + ConnectionNotExistException, + ExceptionsMessage, +) +from .default_config import ENV_CONNECTION_CONF, DefaultConfig def synchronized(func): @@ -56,13 +59,13 @@ def __new__(cls, *args, **kwargs): class Connections(metaclass=SingleInstanceMetaClass): - """ Class for managing all connections of milvus. Used as a singleton in this module. """ + """Class for managing all connections of milvus. Used as a singleton in this module.""" def __init__(self): - """ Constructs a default milvus alias config + """Constructs a default milvus alias config - default config will be read from env: MILVUS_DEFAULT_CONNECTION, - or "localhost:19530" + default config will be read from env: MILVUS_DEFAULT_CONNECTION, + or "localhost:19530" """ self._alias = {} @@ -71,7 +74,7 @@ def __init__(self): self.add_connection(default=self._read_default_config_from_os_env()) def _read_default_config_from_os_env(self): - """ Read default connection config from environment variable: MILVUS_DEFAULT_CONNECTION. + """Read default connection config from environment variable: MILVUS_DEFAULT_CONNECTION. Format is: [@]host[:] @@ -94,19 +97,18 @@ def _read_default_config_from_os_env(self): matched = rex.search(conf) if not matched: - raise ConnectionConfigException(message=ExceptionsMessage.EnvConfigErr % (ENV_CONNECTION_CONF, conf)) + raise ConnectionConfigException( + message=ExceptionsMessage.EnvConfigErr % (ENV_CONNECTION_CONF, conf) + ) user, host, port = matched.groups() user = user or "" port = port or DefaultConfig.DEFAULT_PORT - return { - "user": user, - "address": f"{host}:{port}" - } + return {"user": user, "address": f"{host}:{port}"} def add_connection(self, **kwargs): - """ Configures a milvus connection. + """Configures a milvus connection. Addresses priority in kwargs: address, uri, host and port @@ -140,11 +142,14 @@ def add_connection(self, **kwargs): config.get("address", ""), config.get("uri", ""), config.get("host", ""), - config.get("port", "")) + config.get("port", ""), + ) if alias in self._connected_alias: if self._alias[alias].get("address") != addr: - raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) + raise ConnectionConfigException( + message=ExceptionsMessage.ConnDiffConf % alias + ) alias_config = { "address": addr, @@ -153,10 +158,14 @@ def add_connection(self, **kwargs): self._alias[alias] = alias_config - def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> str: + def __get_full_address( + self, address: str = "", uri: str = "", host: str = "", port: str = "" + ) -> str: if address != "": if not is_legal_address(address): - raise ConnectionConfigException(message=f"Illegal address: {address}, should be in form 'localhost:19530'") + raise ConnectionConfigException( + message=f"Illegal address: {address}, should be in form 'localhost:19530'" + ) else: address = self.__generate_address(uri, host, port) @@ -168,12 +177,18 @@ def __generate_address(self, uri: str, host: str, port: str) -> str: try: parsed_uri = parse.urlparse(uri) except (Exception) as e: - raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None + raise ConnectionConfigException( + message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>" + ) from None if len(parsed_uri.netloc) == 0: raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) - addr = parsed_uri.netloc if ":" in parsed_uri.netloc else f"{parsed_uri.netloc}:{DefaultConfig.DEFAULT_PORT}" + addr = ( + parsed_uri.netloc + if ":" in parsed_uri.netloc + else f"{parsed_uri.netloc}:{DefaultConfig.DEFAULT_PORT}" + ) if not is_legal_address(addr): raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) return addr @@ -186,35 +201,43 @@ def __generate_address(self, uri: str, host: str, port: str) -> str: if not is_legal_port(port): raise ConnectionConfigException(message=ExceptionsMessage.PortType) if not 0 <= int(port) < 65535: - raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") + raise ConnectionConfigException( + message=f"port number {port} out of range, valid range [0, 65535)" + ) return f"{host}:{port}" def disconnect(self, alias: str): - """ Disconnects connection from the registry. + """Disconnects connection from the registry. :param alias: The name of milvus connection :type alias: str """ if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + raise ConnectionConfigException( + message=ExceptionsMessage.AliasType % type(alias) + ) if alias in self._connected_alias: self._connected_alias.pop(alias).close() def remove_connection(self, alias: str): - """ Removes connection from the registry. + """Removes connection from the registry. :param alias: The name of milvus connection :type alias: str """ if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + raise ConnectionConfigException( + message=ExceptionsMessage.AliasType % type(alias) + ) self.disconnect(alias) self._alias.pop(alias, None) - def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwargs): + def connect( + self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwargs + ): """ Constructs a milvus connection and register it under given alias. @@ -263,7 +286,9 @@ def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwa >>> connections.connect("test", host="localhost", port="19530") """ if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + raise ConnectionConfigException( + message=ExceptionsMessage.AliasType % type(alias) + ) def connect_milvus(**kwargs): gh = GrpcHandler(**kwargs) @@ -272,8 +297,8 @@ def connect_milvus(**kwargs): timeout = t if isinstance(t, int) else DefaultConfig.DEFAULT_CONNECT_TIMEOUT gh._wait_for_channel_ready(timeout=timeout) - kwargs.pop('password') - kwargs.pop('secure', None) + kwargs.pop("password") + kwargs.pop("secure", None) self._connected_alias[alias] = gh self._alias[alias] = copy.deepcopy(kwargs) @@ -289,7 +314,7 @@ def with_config(config: Tuple) -> bool: kwargs.pop("address", ""), kwargs.pop("uri", ""), kwargs.pop("host", ""), - kwargs.pop("port", "") + kwargs.pop("port", ""), ) if with_config(config): @@ -298,20 +323,24 @@ def with_config(config: Tuple) -> bool: if self.has_connection(alias): if self._alias[alias].get("address") != in_addr: - raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) + raise ConnectionConfigException( + message=ExceptionsMessage.ConnDiffConf % alias + ) connect_milvus(**kwargs, user=user, password=password) else: if alias not in self._alias: - raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) + raise ConnectionConfigException( + message=ExceptionsMessage.ConnLackConf % alias + ) connect_alias = dict(self._alias[alias].items()) connect_alias["user"] = user connect_milvus(**connect_alias, password=password, **kwargs) def list_connections(self) -> list: - """ List names of all connections. + """List names of all connections. :return list: Names of all connections. @@ -344,12 +373,14 @@ def get_connection_addr(self, alias: str): {'host': 'localhost', 'port': '19530'} """ if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + raise ConnectionConfigException( + message=ExceptionsMessage.AliasType % type(alias) + ) return self._alias.get(alias, {}) def has_connection(self, alias: str) -> bool: - """ Check if connection named alias exists. + """Check if connection named alias exists. :param alias: The name of milvus connection :type alias: str @@ -366,13 +397,17 @@ def has_connection(self, alias: str) -> bool: {'host': 'localhost', 'port': '19530'} """ if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + raise ConnectionConfigException( + message=ExceptionsMessage.AliasType % type(alias) + ) return alias in self._connected_alias def _fetch_handler(self, alias=DefaultConfig.DEFAULT_USING) -> GrpcHandler: - """ Retrieves a GrpcHandler by alias. """ + """Retrieves a GrpcHandler by alias.""" if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + raise ConnectionConfigException( + message=ExceptionsMessage.AliasType % type(alias) + ) conn = self._connected_alias.get(alias, None) if conn is None: diff --git a/pymilvus/orm/future.py b/pymilvus/orm/future.py index 8fef4df6b..72bc38eb4 100644 --- a/pymilvus/orm/future.py +++ b/pymilvus/orm/future.py @@ -11,8 +11,8 @@ # the License. -from .search import SearchResult from .mutation import MutationResult +from .search import SearchResult # TODO(dragondriver): how could we inherit the docstring elegantly? diff --git a/pymilvus/orm/index.py b/pymilvus/orm/index.py index 2329242ca..d8fc47aae 100644 --- a/pymilvus/orm/index.py +++ b/pymilvus/orm/index.py @@ -12,8 +12,8 @@ import copy -from ..exceptions import CollectionNotExistException, ExceptionsMessage from ..client.configs import DefaultConfigs +from ..exceptions import CollectionNotExistException, ExceptionsMessage class Index: @@ -59,6 +59,7 @@ def __init__(self, collection, field_name, index_params, **kwargs): >>> index.drop() """ from .collection import Collection + if not isinstance(collection, Collection): raise CollectionNotExistException(message=ExceptionsMessage.CollectionType) self._collection = collection @@ -71,7 +72,9 @@ def __init__(self, collection, field_name, index_params, **kwargs): return conn = self._get_connection() - conn.create_index(self._collection.name, self._field_name, self._index_params, **kwargs) + conn.create_index( + self._collection.name, self._field_name, self._index_params, **kwargs + ) indexes = conn.list_indexes(self._collection.name) for index in indexes: if index.field_name == self._field_name: @@ -137,7 +140,7 @@ def to_dict(self): "collection": self._collection._name, "field": self._field_name, "index_name": self._index_name, - "index_param": self.params + "index_param": self.params, } return _dict @@ -157,4 +160,10 @@ def drop(self, timeout=None, **kwargs): copy_kwargs = copy.deepcopy(kwargs) index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) conn = self._get_connection() - conn.drop_index(self._collection.name, self.field_name, index_name, timeout=timeout, **copy_kwargs) + conn.drop_index( + self._collection.name, + self.field_name, + index_name, + timeout=timeout, + **copy_kwargs + ) diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index 73a3cff3e..3b1664f0f 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -13,23 +13,23 @@ import copy import json +from ..client.types import Replica from ..exceptions import ( CollectionNotExistException, - PartitionNotExistException, ExceptionsMessage, + PartitionNotExistException, ) - +from .future import MutationFuture, SearchFuture +from .mutation import MutationResult from .prepare import Prepare from .search import SearchResult -from .mutation import MutationResult -from .future import SearchFuture, MutationFuture -from ..client.types import Replica class Partition: def __init__(self, collection, name, description="", **kwargs): # TODO: Need a place to store the description from .collection import Collection + if not isinstance(collection, Collection): raise CollectionNotExistException(message=ExceptionsMessage.CollectionType) self._collection = collection @@ -53,18 +53,20 @@ def __init__(self, collection, name, description="", **kwargs): self._schema_dict["consistency_level"] = self._consistency_level def __repr__(self): - return json.dumps({ - 'name': self.name, - 'collection_name': self._collection.name, - 'description': self.description, - }) + return json.dumps( + { + "name": self.name, + "collection_name": self._collection.name, + "description": self.description, + } + ) def _get_connection(self): return self._collection._get_connection() @property def description(self) -> str: - """ Return the description text. + """Return the description text. :return: Partition description :rtype: str @@ -152,13 +154,15 @@ def num_entities(self, **kwargs) -> int: 10 """ conn = self._get_connection() - stats = conn.get_partition_stats(collection_name=self._collection.name, partition_name=self._name, **kwargs) + stats = conn.get_partition_stats( + collection_name=self._collection.name, partition_name=self._name, **kwargs + ) result = {stat.key: stat.value for stat in stats} result["row_count"] = int(result["row_count"]) return result["row_count"] def flush(self, timeout=None, **kwargs): - """ Flush """ + """Flush""" conn = self._get_connection() conn.flush([self._collection.name], timeout=timeout, **kwargs) @@ -185,9 +189,18 @@ def drop(self, timeout=None, **kwargs): >>> partition.drop() """ conn = self._get_connection() - if conn.has_partition(self._collection.name, self._name, timeout=timeout, **kwargs) is False: - raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) - return conn.drop_partition(self._collection.name, self._name, timeout=timeout, **kwargs) + if ( + conn.has_partition( + self._collection.name, self._name, timeout=timeout, **kwargs + ) + is False + ): + raise PartitionNotExistException( + message=ExceptionsMessage.PartitionNotExist + ) + return conn.drop_partition( + self._collection.name, self._name, timeout=timeout, **kwargs + ) def load(self, replica_number=1, timeout=None, **kwargs): """ @@ -219,7 +232,13 @@ def load(self, replica_number=1, timeout=None, **kwargs): # if index_names is not None, raise Exception Not Supported conn = self._get_connection() if conn.has_partition(self._collection.name, self._name, **kwargs): - return conn.load_partitions(self._collection.name, [self._name], replica_number, timeout=timeout, **kwargs) + return conn.load_partitions( + self._collection.name, + [self._name], + replica_number, + timeout=timeout, + **kwargs + ) raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) def release(self, timeout=None, **kwargs): @@ -247,7 +266,9 @@ def release(self, timeout=None, **kwargs): """ conn = self._get_connection() if conn.has_partition(self._collection.name, self._name, **kwargs): - return conn.release_partitions(self._collection.name, [self._name], timeout=timeout, **kwargs) + return conn.release_partitions( + self._collection.name, [self._name], timeout=timeout, **kwargs + ) raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) def insert(self, data, timeout=None, **kwargs): @@ -294,17 +315,26 @@ def insert(self, data, timeout=None, **kwargs): """ conn = self._get_connection() if conn.has_partition(self._collection.name, self._name, **kwargs) is False: - raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) + raise PartitionNotExistException( + message=ExceptionsMessage.PartitionNotExist + ) # TODO: check insert data schema here? entities = Prepare.prepare_insert_data(data, self._collection.schema) - res = conn.batch_insert(self._collection.name, entities=entities, partition_name=self._name, - timeout=timeout, orm=True, schema=self._schema_dict, **kwargs) + res = conn.batch_insert( + self._collection.name, + entities=entities, + partition_name=self._name, + timeout=timeout, + orm=True, + schema=self._schema_dict, + **kwargs + ) if kwargs.get("_async", False): return MutationFuture(res) return MutationResult(res) def delete(self, expr, timeout=None, **kwargs): - """ Delete entities with an expression condition. + """Delete entities with an expression condition. :param expr: The expression to specify entities to be deleted :type expr: str @@ -343,14 +373,26 @@ def delete(self, expr, timeout=None, **kwargs): """ conn = self._get_connection() - res = conn.delete(self._collection.name, expr, self.name, timeout=timeout, **kwargs) + res = conn.delete( + self._collection.name, expr, self.name, timeout=timeout, **kwargs + ) if kwargs.get("_async", False): return MutationFuture(res) return MutationResult(res) - def search(self, data, anns_field, param, limit, - expr=None, output_fields=None, timeout=None, round_decimal=-1, **kwargs): - """ Conducts a vector similarity search with an optional boolean expression as filter. + def search( + self, + data, + anns_field, + param, + limit, + expr=None, + output_fields=None, + timeout=None, + round_decimal=-1, + **kwargs + ): + """Conducts a vector similarity search with an optional boolean expression as filter. Args: data (``List[List[float]]``): The vectors of search data. @@ -477,8 +519,20 @@ def search(self, data, anns_field, param, limit, - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 """ conn = self._get_connection() - res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields, - round_decimal=round_decimal, timeout=timeout, schema=self._schema_dict, **kwargs) + res = conn.search( + self._collection.name, + data, + anns_field, + param, + limit, + expr, + [self._name], + output_fields, + round_decimal=round_decimal, + timeout=timeout, + schema=self._schema_dict, + **kwargs + ) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) @@ -551,8 +605,15 @@ def query(self, expr, output_fields=None, timeout=None, **kwargs): - Query results: [{'film_id': 0, 'film_date': 2000}, {'film_id': 1, 'film_date': 2001}] """ conn = self._get_connection() - res = conn.query(self._collection.name, expr, output_fields, [self._name], - timeout=timeout, schema=self._schema_dict, **kwargs) + res = conn.query( + self._collection.name, + expr, + output_fields, + [self._name], + timeout=timeout, + schema=self._schema_dict, + **kwargs + ) return res def get_replicas(self, timeout=None, **kwargs) -> Replica: diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py index 3fc843555..c23763494 100644 --- a/pymilvus/orm/prepare.py +++ b/pymilvus/orm/prepare.py @@ -14,14 +14,20 @@ import numpy import pandas -from ..exceptions import DataNotMatchException, DataTypeNotSupportException, ExceptionsMessage +from ..exceptions import ( + DataNotMatchException, + DataTypeNotSupportException, + ExceptionsMessage, +) class Prepare: @classmethod def prepare_insert_data(cls, data, schema): if not isinstance(data, (list, tuple, pandas.DataFrame)): - raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport) + raise DataTypeNotSupportException( + message=ExceptionsMessage.DataTypeNotSupport + ) fields = schema.fields entities = [] # Entities @@ -31,25 +37,39 @@ def prepare_insert_data(cls, data, schema): if schema.auto_id: if schema.primary_field.name in data: if len(fields) != len(data.columns): - raise DataNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent) + raise DataNotMatchException( + message=ExceptionsMessage.FieldsNumInconsistent + ) if not data[schema.primary_field.name].isnull().all(): - raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData) + raise DataNotMatchException( + message=ExceptionsMessage.AutoIDWithData + ) else: if len(fields) != len(data.columns) + 1: - raise DataNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent) + raise DataNotMatchException( + message=ExceptionsMessage.FieldsNumInconsistent + ) else: if len(fields) != len(data.columns): - raise DataNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent) + raise DataNotMatchException( + message=ExceptionsMessage.FieldsNumInconsistent + ) for i, field in enumerate(fields): if field.is_primary and field.auto_id: continue - entities.append({"name": field.name, - "type": field.dtype, - "values": list(data[field.name])}) + entities.append( + { + "name": field.name, + "type": field.dtype, + "values": list(data[field.name]), + } + ) raw_lengths.append(len(data[field.name])) else: if schema.auto_id and len(data) != len(fields) - 1: - raise DataNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent) + raise DataNotMatchException( + message=ExceptionsMessage.FieldsNumInconsistent + ) tmp_fields = copy.deepcopy(fields) for i, field in enumerate(tmp_fields): @@ -63,15 +83,16 @@ def prepare_insert_data(cls, data, schema): if isinstance(data[i], numpy.ndarray): data[i] = data[i].tolist() - entities.append({ - "name": field.name, - "type": field.dtype, - "values": data[i]}) + entities.append( + {"name": field.name, "type": field.dtype, "values": data[i]} + ) raw_lengths.append(len(data[i])) # TODO Goose: check correctness AFTER copy is too expensive. lengths = list(set(raw_lengths)) if len(lengths) > 1: - raise DataNotMatchException(message=ExceptionsMessage.DataLengthsInconsistent) + raise DataNotMatchException( + message=ExceptionsMessage.DataLengthsInconsistent + ) return entities diff --git a/pymilvus/orm/role.py b/pymilvus/orm/role.py index 61202fdee..7e41cb4de 100644 --- a/pymilvus/orm/role.py +++ b/pymilvus/orm/role.py @@ -1,12 +1,13 @@ from .connections import connections + class Role: - """ Role, can be granted privileges which are allowed to execute some objects' apis. """ + """Role, can be granted privileges which are allowed to execute some objects' apis.""" def __init__(self, name: str, using="default", **kwargs): - """ Constructs a role by name - :param name: role name. - :type name: str + """Constructs a role by name + :param name: role name. + :type name: str """ self._name = name self._using = using @@ -20,7 +21,7 @@ def name(self): return self._name def create(self): - """ Create a role + """Create a role It will success if the role isn't existed, otherwise fail. :example: @@ -35,7 +36,7 @@ def create(self): return self._get_connection().create_role(self._name) def drop(self): - """ Drop a role + """Drop a role It will success if the role is existed, otherwise fail. :example: @@ -50,7 +51,7 @@ def drop(self): return self._get_connection().drop_role(self._name) def add_user(self, username: str): - """ Add user to role + """Add user to role The user will get permissions that the role are allowed to perform operations. :param username: user name. :type username: str @@ -67,7 +68,7 @@ def add_user(self, username: str): return self._get_connection().add_user_to_role(username, self._name) def remove_user(self, username: str): - """ Remove user from role + """Remove user from role The user will remove permissions that the role are allowed to perform operations. :param username: user name. :type username: str @@ -84,7 +85,7 @@ def remove_user(self, username: str): return self._get_connection().remove_user_from_role(username, self._name) def get_users(self): - """ Get all users who are added to the role. + """Get all users who are added to the role. :return a RoleInfo object which contains a RoleItem group According to the RoleItem, you can get a list of usernames. @@ -105,7 +106,7 @@ def get_users(self): return roles.groups[0].users def is_exist(self): - """ Check whether the role is existed. + """Check whether the role is existed. :return a bool value It will be True if the role is existed, otherwise False. @@ -121,7 +122,7 @@ def is_exist(self): return len(roles.groups) != 0 def grant(self, object: str, object_name: str, privilege: str): - """ Grant a privilege for the role + """Grant a privilege for the role :param object: object type. :type object: str :param object_name: identifies a specific object name. @@ -136,10 +137,12 @@ def grant(self, object: str, object_name: str, privilege: str): >>> role = Role(role_name) >>> role.grant("Collection", collection_name, "Insert") """ - return self._get_connection().grant_privilege(self._name, object, object_name, privilege) + return self._get_connection().grant_privilege( + self._name, object, object_name, privilege + ) def revoke(self, object: str, object_name: str, privilege: str): - """ Revoke a privilege for the role + """Revoke a privilege for the role :param object: object type. :type object: str :param object_name: identifies a specific object name. @@ -154,10 +157,12 @@ def revoke(self, object: str, object_name: str, privilege: str): >>> role = Role(role_name) >>> role.revoke("Collection", collection_name, "Insert") """ - return self._get_connection().revoke_privilege(self._name, object, object_name, privilege) + return self._get_connection().revoke_privilege( + self._name, object, object_name, privilege + ) def list_grant(self, object: str, object_name: str): - """ List a grant info for the role and the specific object + """List a grant info for the role and the specific object :param object: object type. :type object: str :param object_name: identifies a specific object name. @@ -175,10 +180,12 @@ def list_grant(self, object: str, object_name: str): >>> role = Role(role_name) >>> role.list_grant("Collection", collection_name) """ - return self._get_connection().select_grant_for_role_and_object(self._name, object, object_name) + return self._get_connection().select_grant_for_role_and_object( + self._name, object, object_name + ) def list_grants(self): - """ List a grant info for the role + """List a grant info for the role :return a GrantInfo object :rtype GrantInfo diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index 5d8b56d55..16cb71131 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -12,22 +12,23 @@ import copy from typing import List + import pandas from pandas.api.types import is_list_like -from .constants import COMMON_TYPE_PARAMS -from .types import DataType, map_numpy_dtype_to_datatype, infer_dtype_bydata from ..exceptions import ( + AutoIDException, CannotInferSchemaException, + DataNotMatchException, DataTypeNotSupportException, - PrimaryKeyException, + ExceptionsMessage, FieldsTypeException, FieldTypeException, - AutoIDException, - ExceptionsMessage, - DataNotMatchException, + PrimaryKeyException, SchemaNotReadyException, ) +from .constants import COMMON_TYPE_PARAMS +from .types import DataType, infer_dtype_bydata, map_numpy_dtype_to_datatype class CollectionSchema: @@ -45,7 +46,9 @@ def __init__(self, fields, description="", **kwargs): for field in self._fields: if field.is_primary: if primary_field is not None and primary_field != field.name: - raise PrimaryKeyException(message=ExceptionsMessage.PrimaryKeyOnlyOne) + raise PrimaryKeyException( + message=ExceptionsMessage.PrimaryKeyOnlyOne + ) self._primary_field = field primary_field = field.name @@ -59,10 +62,16 @@ def __init__(self, fields, description="", **kwargs): if "auto_id" in kwargs: if not isinstance(self._auto_id, bool): raise AutoIDException(message=ExceptionsMessage.AutoIDType) - if self._primary_field.auto_id is not None and self._primary_field.auto_id != self._auto_id: + if ( + self._primary_field.auto_id is not None + and self._primary_field.auto_id != self._auto_id + ): raise AutoIDException(message=ExceptionsMessage.AutoIDInconsistent) self._primary_field.auto_id = self._auto_id - if self._primary_field.auto_id and self._primary_field.dtype == DataType.VARCHAR: + if ( + self._primary_field.auto_id + and self._primary_field.dtype == DataType.VARCHAR + ): raise AutoIDException(message=ExceptionsMessage.AutoIDFieldType) else: if self._primary_field.auto_id is None: @@ -96,8 +105,10 @@ def __eq__(self, other): @classmethod def construct_from_dict(cls, raw): - fields = [FieldSchema.construct_from_dict(field_raw) for field_raw in raw['fields']] - return CollectionSchema(fields, raw.get('description', "")) + fields = [ + FieldSchema.construct_from_dict(field_raw) for field_raw in raw["fields"] + ] + return CollectionSchema(fields, raw.get("description", "")) @property # TODO: @@ -171,7 +182,9 @@ def __init__(self, name, dtype, description="", **kwargs): try: DataType(dtype) except ValueError: - raise DataTypeNotSupportException(message=ExceptionsMessage.FieldDtype) from None + raise DataTypeNotSupportException( + message=ExceptionsMessage.FieldDtype + ) from None if dtype == DataType.UNKNOWN: raise DataTypeNotSupportException(message=ExceptionsMessage.FieldDtype) self._dtype = dtype @@ -205,7 +218,11 @@ def __deepcopy__(self, memodict=None): def _parse_type_params(self): # update self._type_params according to self._kwargs - if self._dtype not in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, DataType.VARCHAR,): + if self._dtype not in ( + DataType.BINARY_VECTOR, + DataType.FLOAT_VECTOR, + DataType.VARCHAR, + ): return if not self._kwargs: return @@ -221,10 +238,10 @@ def _parse_type_params(self): def construct_from_dict(cls, raw): kwargs = {} kwargs.update(raw.get("params", {})) - kwargs['is_primary'] = raw.get("is_primary", False) + kwargs["is_primary"] = raw.get("is_primary", False) if raw.get("auto_id", None) is not None: - kwargs['auto_id'] = raw.get("auto_id", None) - return FieldSchema(raw['name'], raw['type'], raw['description'], **kwargs) + kwargs["auto_id"] = raw.get("auto_id", None) + return FieldSchema(raw["name"], raw["type"], raw["description"], **kwargs) def to_dict(self): _dict = { @@ -289,8 +306,10 @@ def dtype(self): return self._dtype -def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas.DataFrame]) -> None: - """ check if the insert data is consist with the collection schema +def check_insert_data_schema( + schema: CollectionSchema, data: [List[List], pandas.DataFrame] +) -> None: + """check if the insert data is consist with the collection schema Args: schema (CollectionSchema): the schema of the collection @@ -306,7 +325,9 @@ def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas if isinstance(data, pandas.DataFrame): if schema.primary_field.name in data: if not data[schema.primary_field.name].isnull().all(): - raise DataNotMatchException(message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}") + raise DataNotMatchException( + message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}" + ) data = data.drop(schema.primary_field.name, axis=1) infer_fields = parse_fields_from_data(data) @@ -320,18 +341,26 @@ def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas i_name = [f.name for f in infer_fields] t_name = [f.name for f in tmp_fields] - raise DataNotMatchException(message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}") + raise DataNotMatchException( + message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}" + ) for x, y in zip(infer_fields, tmp_fields): if x.dtype != y.dtype: - raise DataNotMatchException(message=f"The data type of field {y.name} doesn't match, expected: {y.dtype.name}, got {x.dtype.name}") + raise DataNotMatchException( + message=f"The data type of field {y.name} doesn't match, expected: {y.dtype.name}, got {x.dtype.name}" + ) if isinstance(data, pandas.DataFrame) and x.name != y.name: - raise DataNotMatchException(message=f"The name of field don't match, expected: {y.name}, got {x.name}") + raise DataNotMatchException( + message=f"The name of field don't match, expected: {y.name}, got {x.name}" + ) def parse_fields_from_data(data: [List[List], pandas.DataFrame]) -> List[FieldSchema]: if not isinstance(data, (pandas.DataFrame, list)): - raise DataTypeNotSupportException(message="The type of data should be list or pandas.DataFrame") + raise DataTypeNotSupportException( + message="The type of data should be list or pandas.DataFrame" + ) if isinstance(data, pandas.DataFrame): return parse_fields_from_dataframe(data) @@ -361,9 +390,9 @@ def parse_fields_from_dataframe(df: pandas.DataFrame) -> List[FieldSchema]: if new_dtype in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR): vector_type_params = {} if new_dtype == DataType.BINARY_VECTOR: - vector_type_params['dim'] = len(values[i]) * 8 + vector_type_params["dim"] = len(values[i]) * 8 else: - vector_type_params['dim'] = len(values[i]) + vector_type_params["dim"] = len(values[i]) column_params_map[col_names[i]] = vector_type_params data_types[i] = new_dtype diff --git a/pymilvus/orm/search.py b/pymilvus/orm/search.py index e2ee4c2cc..cadb87359 100644 --- a/pymilvus/orm/search.py +++ b/pymilvus/orm/search.py @@ -11,6 +11,7 @@ # the License. import abc + from ..client.abstract import Entity @@ -61,7 +62,7 @@ def on_result(self, res): class DocstringMeta(type): def __new__(cls, name, bases, attrs): doc_meta = attrs.pop("docstring", None) - new_cls = super(DocstringMeta, cls).__new__(cls, name, bases, attrs) + new_cls = super().__new__(cls, name, bases, attrs) if doc_meta: for member_name, member in attrs.items(): if member_name in doc_meta: diff --git a/pymilvus/orm/types.py b/pymilvus/orm/types.py index 0ae3e6cef..b2615e9f2 100644 --- a/pymilvus/orm/types.py +++ b/pymilvus/orm/types.py @@ -10,11 +10,18 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -from pandas.api.types import infer_dtype, is_list_like, is_scalar, is_float, is_array_like import numpy as np +from pandas.api.types import ( + infer_dtype, + is_array_like, + is_float, + is_list_like, + is_scalar, +) from ..client.types import DataType + dtype_str_map = { "string": DataType.VARCHAR, "floating": DataType.FLOAT, diff --git a/pymilvus/orm/utility.py b/pymilvus/orm/utility.py index 8f7b10631..691f140e3 100644 --- a/pymilvus/orm/utility.py +++ b/pymilvus/orm/utility.py @@ -10,16 +10,15 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -from .connections import connections - +from ..client.types import BulkInsertState +from ..client.utils import hybridts_to_unixtime as _hybridts_to_unixtime +from ..client.utils import mkts_from_datetime as _mkts_from_datetime from ..client.utils import mkts_from_hybridts as _mkts_from_hybridts from ..client.utils import mkts_from_unixtime as _mkts_from_unixtime -from ..client.utils import mkts_from_datetime as _mkts_from_datetime -from ..client.utils import hybridts_to_unixtime as _hybridts_to_unixtime -from ..client.types import BulkInsertState +from .connections import connections -def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): +def mkts_from_hybridts(hybridts, milliseconds=0.0, delta=None): """ Generate a hybrid timestamp based on an existing hybrid timestamp, timedelta and incremental time internval. @@ -55,7 +54,7 @@ def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): return _mkts_from_hybridts(hybridts, milliseconds=milliseconds, delta=delta) -def mkts_from_unixtime(epoch, milliseconds=0., delta=None): +def mkts_from_unixtime(epoch, milliseconds=0.0, delta=None): """ Generate a hybrid timestamp based on Unix Epoch time, timedelta and incremental time internval. @@ -83,7 +82,7 @@ def mkts_from_unixtime(epoch, milliseconds=0., delta=None): return _mkts_from_unixtime(epoch, milliseconds=milliseconds, delta=delta) -def mkts_from_datetime(d_time, milliseconds=0., delta=None): +def mkts_from_datetime(d_time, milliseconds=0.0, delta=None): """ Generate a hybrid timestamp based on datetime, timedelta and incremental time internval. @@ -133,6 +132,7 @@ def hybridts_to_datetime(hybridts, tz=None): >>> d = utility.hybridts_to_datetime(ts) """ import datetime + if tz is not None and not isinstance(tz, datetime.timezone): raise Exception("parameter tz should be type of datetime.timezone") epoch = _hybridts_to_unixtime(hybridts) @@ -165,8 +165,10 @@ def _get_connection(alias): return connections._fetch_handler(alias) -def loading_progress(collection_name, partition_names=None, using="default", timeout=None): - """ Show loading progress of sealed segments in percentage. +def loading_progress( + collection_name, partition_names=None, using="default", timeout=None +): + """Show loading progress of sealed segments in percentage. :param collection_name: The name of collection is loading :type collection_name: str @@ -198,13 +200,17 @@ def loading_progress(collection_name, partition_names=None, using="default", tim >>> utility.loading_progress("test_loading_progress") {'loading_progress': '100%'} """ - progress = _get_connection(using).get_loading_progress(collection_name, partition_names, timeout=timeout) + progress = _get_connection(using).get_loading_progress( + collection_name, partition_names, timeout=timeout + ) return { "loading_progress": f"{progress:.0f}%", } -def wait_for_loading_complete(collection_name, partition_names=None, timeout=None, using="default"): +def wait_for_loading_complete( + collection_name, partition_names=None, timeout=None, using="default" +): """ Block until loading is done or Raise Exception after timeout. @@ -237,11 +243,17 @@ def wait_for_loading_complete(collection_name, partition_names=None, timeout=Non >>> utility.wait_for_loading_complete("test_collection") """ if not partition_names or len(partition_names) == 0: - return _get_connection(using).wait_for_loading_collection(collection_name, timeout=timeout) - return _get_connection(using).wait_for_loading_partitions(collection_name, partition_names, timeout=timeout) + return _get_connection(using).wait_for_loading_collection( + collection_name, timeout=timeout + ) + return _get_connection(using).wait_for_loading_partitions( + collection_name, partition_names, timeout=timeout + ) -def index_building_progress(collection_name, index_name="", using="default", timeout=None): +def index_building_progress( + collection_name, index_name="", using="default", timeout=None +): """ Show # indexed entities vs. # total entities. @@ -284,10 +296,13 @@ def index_building_progress(collection_name, index_name="", using="default", tim >>> utility.index_building_progress("test_collection", c.name) """ return _get_connection(using).get_index_build_progress( - collection_name=collection_name, index_name=index_name, timeout=timeout) + collection_name=collection_name, index_name=index_name, timeout=timeout + ) -def wait_for_index_building_complete(collection_name, index_name="", timeout=None, using="default"): +def wait_for_index_building_complete( + collection_name, index_name="", timeout=None, using="default" +): """ Block until building is done or Raise Exception after timeout. @@ -330,7 +345,9 @@ def wait_for_index_building_complete(collection_name, index_name="", timeout=Non >>> utility.loading_progress("test_collection") """ - return _get_connection(using).wait_for_creating_index(collection_name, index_name, timeout=timeout)[0] + return _get_connection(using).wait_for_creating_index( + collection_name, index_name, timeout=timeout + )[0] def has_collection(collection_name, using="default", timeout=None): @@ -379,7 +396,9 @@ def has_partition(collection_name, partition_name, using="default", timeout=None >>> collection = Collection(name="test_collection", schema=schema) >>> utility.has_partition("_default") """ - return _get_connection(using).has_partition(collection_name, partition_name, timeout=timeout) + return _get_connection(using).has_partition( + collection_name, partition_name, timeout=timeout + ) def drop_collection(collection_name, timeout=None, using="default"): @@ -433,8 +452,15 @@ def list_collections(timeout=None, using="default") -> list: return _get_connection(using).list_collections(timeout=timeout) -def load_balance(collection_name: str, src_node_id, dst_node_ids=None, sealed_segment_ids=None, timeout=None, using="default"): - """ Do load balancing operation from source query node to destination query node. +def load_balance( + collection_name: str, + src_node_id, + dst_node_ids=None, + sealed_segment_ids=None, + timeout=None, + using="default", +): + """Do load balancing operation from source query node to destination query node. :param collection_name: The collection to balance. :type collection_name: str @@ -468,8 +494,9 @@ def load_balance(collection_name: str, src_node_id, dst_node_ids=None, sealed_se dst_node_ids = [] if sealed_segment_ids is None: sealed_segment_ids = [] - return _get_connection(using).\ - load_balance(collection_name, src_node_id, dst_node_ids, sealed_segment_ids, timeout=timeout) + return _get_connection(using).load_balance( + collection_name, src_node_id, dst_node_ids, sealed_segment_ids, timeout=timeout + ) def get_query_segment_info(collection_name, timeout=None, using="default"): @@ -501,11 +528,13 @@ def get_query_segment_info(collection_name, timeout=None, using="default"): >>> collection.load() # load collection to memory >>> res = utility.get_query_segment_info("test_get_segment_info") """ - return _get_connection(using).get_query_segment_info(collection_name, timeout=timeout) + return _get_connection(using).get_query_segment_info( + collection_name, timeout=timeout + ) def create_alias(collection_name: str, alias: str, timeout=None, using="default"): - """ Specify alias for a collection. + """Specify alias for a collection. Alias cannot be duplicated, you can't assign the same alias to different collections. But you can specify multiple aliases for a collection, for example: before create_alias("collection_1", "bob"): @@ -538,7 +567,7 @@ def create_alias(collection_name: str, alias: str, timeout=None, using="default" def drop_alias(alias: str, timeout=None, using="default"): - """ Delete the alias. + """Delete the alias. No need to provide collection name because an alias can only be assigned to one collection and the server knows which collection it belongs. For example: @@ -573,7 +602,7 @@ def drop_alias(alias: str, timeout=None, using="default"): def alter_alias(collection_name: str, alias: str, timeout=None, using="default"): - """ Change the alias of a collection to another collection. + """Change the alias of a collection to another collection. Raise error if the alias doesn't exist. Alias cannot be duplicated, you can't assign same alias to different collections. This api can change alias owner collection, for example: @@ -613,7 +642,7 @@ def alter_alias(collection_name: str, alias: str, timeout=None, using="default") def list_aliases(collection_name: str, timeout=None, using="default"): - """ Returns alias list of the collection. + """Returns alias list of the collection. :return list of str: The collection aliases, returned when the operation succeeds. @@ -637,8 +666,15 @@ def list_aliases(collection_name: str, timeout=None, using="default"): return aliases -def do_bulk_insert(collection_name: str, files: list, partition_name=None, timeout=None, using="default", **kwargs) -> int: - """ do_bulk_insert inserts entities through files, currently supports row-based json file. +def do_bulk_insert( + collection_name: str, + files: list, + partition_name=None, + timeout=None, + using="default", + **kwargs, +) -> int: + """do_bulk_insert inserts entities through files, currently supports row-based json file. User need to create the json file with a specified json format which is described in the official user guide. Let's say a collection has two fields: "id" and "vec"(dimension=8), the row-based json format is: {"rows": [ @@ -683,10 +719,14 @@ def do_bulk_insert(collection_name: str, files: list, partition_name=None, timeo >>> task_id = utility.do_bulk_insert(collection_name=collection.name, files=['data.json']) >>> print(task_id) """ - return _get_connection(using).do_bulk_insert(collection_name, partition_name, files, timeout=timeout, **kwargs) + return _get_connection(using).do_bulk_insert( + collection_name, partition_name, files, timeout=timeout, **kwargs + ) -def get_bulk_insert_state(task_id, timeout=None, using="default", **kwargs) -> BulkInsertState: +def get_bulk_insert_state( + task_id, timeout=None, using="default", **kwargs +) -> BulkInsertState: """get_bulk_insert_state returns state of a certain task_id :param task_id: the task id returned by bulk_insert @@ -702,10 +742,14 @@ def get_bulk_insert_state(task_id, timeout=None, using="default", **kwargs) -> B >>> if state.state == BulkInsertState.ImportFailed or state.state == BulkInsertState.ImportFailedAndCleaned: >>> print("task id:", state.task_id, "failed, reason:", state.failed_reason) """ - return _get_connection(using).get_bulk_insert_state(task_id, timeout=timeout, **kwargs) + return _get_connection(using).get_bulk_insert_state( + task_id, timeout=timeout, **kwargs + ) -def list_bulk_insert_tasks(limit=0, collection_name=None, timeout=None, using="default", **kwargs) -> list: +def list_bulk_insert_tasks( + limit=0, collection_name=None, timeout=None, using="default", **kwargs +) -> list: """list_bulk_insert_tasks lists all bulk load tasks :param limit: maximum number of tasks returned, list all tasks if the value is 0, else return the latest tasks @@ -723,10 +767,14 @@ def list_bulk_insert_tasks(limit=0, collection_name=None, timeout=None, using="d >>> tasks = utility.list_bulk_insert_tasks(collection_name=collection_name) >>> print(tasks) """ - return _get_connection(using).list_bulk_insert_tasks(limit, collection_name, timeout=timeout, **kwargs) + return _get_connection(using).list_bulk_insert_tasks( + limit, collection_name, timeout=timeout, **kwargs + ) -def reset_password(user: str, old_password: str, new_password: str, using="default", timeout=None): +def reset_password( + user: str, old_password: str, new_password: str, using="default", timeout=None +): """ Reset the user & password of the connection. You must provide the original password to check if the operation is valid. @@ -746,11 +794,13 @@ def reset_password(user: str, old_password: str, new_password: str, using="defau >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).reset_password(user, old_password, new_password, timeout=timeout) + return _get_connection(using).reset_password( + user, old_password, new_password, timeout=timeout + ) def create_user(user: str, password: str, using="default", timeout=None): - """ Create User using the given user and password. + """Create User using the given user and password. :param user: the user name. :type user: str :param password: the password. @@ -767,7 +817,9 @@ def create_user(user: str, password: str, using="default", timeout=None): return _get_connection(using).create_user(user, password, timeout=timeout) -def update_password(user: str, old_password, new_password: str, using="default", timeout=None): +def update_password( + user: str, old_password, new_password: str, using="default", timeout=None +): """ Update user password using the given user and password. You must provide the original password to check if the operation is valid. @@ -789,11 +841,13 @@ def update_password(user: str, old_password, new_password: str, using="default", >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).update_password(user, old_password, new_password, timeout=timeout) + return _get_connection(using).update_password( + user, old_password, new_password, timeout=timeout + ) def delete_user(user: str, using="default", timeout=None): - """ Delete User corresponding to the username. + """Delete User corresponding to the username. :param user: the user name. :type user: str @@ -808,7 +862,7 @@ def delete_user(user: str, using="default", timeout=None): def list_usernames(using="default", timeout=None): - """ List all usernames. + """List all usernames. :return list of str: The usernames in Milvus instances. @@ -822,7 +876,7 @@ def list_usernames(using="default", timeout=None): def list_roles(include_user_info: bool, using="default", timeout=None): - """ List All Role Info + """List All Role Info :param include_user_info: whether to obtain the user information associated with roles :type include_user_info: bool :return RoleInfo @@ -837,7 +891,7 @@ def list_roles(include_user_info: bool, using="default", timeout=None): def list_user(username: str, include_role_info: bool, using="default", timeout=None): - """ List One User Info + """List One User Info :param username: user name. :type username: str :param include_role_info: whether to obtain the role information associated with the user @@ -850,11 +904,13 @@ def list_user(username: str, include_role_info: bool, using="default", timeout=N >>> user = utility.list_user(username, include_role_info) >>> print(f"user info: {user}") """ - return _get_connection(using).select_one_user(username, include_role_info, timeout=timeout) + return _get_connection(using).select_one_user( + username, include_role_info, timeout=timeout + ) def list_users(include_role_info: bool, using="default", timeout=None): - """ List All User Info + """List All User Info :param include_role_info: whether to obtain the role information associated with users :type include_role_info: bool :return UserInfo @@ -867,8 +923,9 @@ def list_users(include_role_info: bool, using="default", timeout=None): """ return _get_connection(using).select_all_user(include_role_info, timeout=timeout) + def get_server_version(using="default", timeout=None) -> str: - """ get the running server's version + """get the running server's version :returns: server's version :rtype: str diff --git a/pymilvus/settings.py b/pymilvus/settings.py index b1a3b46cd..292cf7dda 100644 --- a/pymilvus/settings.py +++ b/pymilvus/settings.py @@ -15,20 +15,20 @@ class DefaultConfig: # logging COLORS = { - 'HEADER': '\033[95m', - 'INFO': '\033[92m', - 'DEBUG': '\033[94m', - 'WARNING': '\033[93m', - 'ERROR': '\033[95m', - 'CRITICAL': '\033[91m', - 'ENDC': '\033[0m', + "HEADER": "\033[95m", + "INFO": "\033[92m", + "DEBUG": "\033[94m", + "WARNING": "\033[93m", + "ERROR": "\033[95m", + "CRITICAL": "\033[91m", + "ENDC": "\033[0m", } class ColorFulFormatColMixin: def format_col(self, message_str, level_name): if level_name in COLORS: - message_str = COLORS.get(level_name) + message_str + COLORS.get('ENDC') + message_str = COLORS.get(level_name) + message_str + COLORS.get("ENDC") return message_str @@ -39,40 +39,40 @@ def format(self, record): return self.format_col(message_str, level_name=record.levelname) -LOG_LEVEL = 'WARNING' +LOG_LEVEL = "WARNING" # LOG_LEVEL = 'DEBUG' LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'level': LOG_LEVEL, + "version": 1, + "disable_existing_loggers": False, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": LOG_LEVEL, }, }, - 'loggers': { - 'milvus': { - 'handlers': ['console'], - 'level': LOG_LEVEL, + "loggers": { + "milvus": { + "handlers": ["console"], + "level": LOG_LEVEL, }, }, } -if LOG_LEVEL == 'DEBUG': - LOGGING['formatters'] = { - 'colorful_console': { - 'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)', - '()': ColorfulFormatter, +if LOG_LEVEL == "DEBUG": + LOGGING["formatters"] = { + "colorful_console": { + "format": "[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)", + "()": ColorfulFormatter, }, } - LOGGING['handlers']['milvus_console'] = { - 'class': 'logging.StreamHandler', - 'formatter': 'colorful_console', + LOGGING["handlers"]["milvus_console"] = { + "class": "logging.StreamHandler", + "formatter": "colorful_console", } - LOGGING['loggers']['milvus'] = { - 'handlers': ['milvus_console'], - 'level': LOG_LEVEL, + LOGGING["loggers"]["milvus"] = { + "handlers": ["milvus_console"], + "level": LOG_LEVEL, } logging.config.dictConfig(LOGGING) diff --git a/requirements.txt b/requirements.txt index a11778635..0d5574f83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ mmh3>=2.0,<=3.0.0 packaging==20.9 pep517==0.10.0 pyparsing==2.4.7 +pre-commit==2.21.0 six==1.16.0 toml==0.10.2 ujson>=2.0.0,<=5.4.0 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..648e7c523 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,12 @@ +[isort] +line_length = 88 +known_first_party = pymilvus +multi_line_output = 3 +default_section = THIRDPARTY +skip = venv/ +skip_glob = **/grpc_gen/** +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +profile = black +lines_after_imports = 2