enhance: upsert support autoid (#2173)
milvus-io/milvus#29258

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
smellthemoon and lixinguo authored Jul 12, 2024
1 parent 8e0a27b commit 6625af7
Showing 7 changed files with 68 additions and 46 deletions.
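
At the API surface, the net effect is that Collection.upsert no longer fails fast on collections whose primary key is auto-generated; the caller still has to supply explicit key values. A minimal sketch of the now-accepted call, assuming a local Milvus deployment and illustrative collection/field names (the server-side handling of the supplied keys is tracked in milvus-io/milvus#29258):

    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        connections,
    )

    connections.connect(host="localhost", port="19530")  # assumed local deployment

    schema = CollectionSchema([
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("vec", DataType.FLOAT_VECTOR, dim=2),
    ])
    coll = Collection("upsert_autoid_demo", schema)

    # Column-ordered data: the pk column is still required for upsert.
    # Before this commit, the client raised UpsertAutoIDTrueException here
    # whenever schema.auto_id was True.
    coll.upsert([[1, 2], [[0.1, 0.2], [0.3, 0.4]]])
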
3 changes: 1 addition & 2 deletions pymilvus/client/grpc_handler.py

@@ -626,7 +626,6 @@ def _prepare_batch_upsert_request(
         entities: List,
         partition_name: Optional[str] = None,
         timeout: Optional[float] = None,
-        is_insert: bool = True,
         **kwargs,
     ):
         param = kwargs.get("upsert_param")
@@ -661,7 +660,7 @@ def upsert(

         try:
             request = self._prepare_batch_upsert_request(
-                collection_name, entities, partition_name, timeout, False, **kwargs
+                collection_name, entities, partition_name, timeout, **kwargs
            )
             rf = self._stub.Upsert.future(request, timeout=timeout)
             if kwargs.get("_async", False) is True:

36 changes: 32 additions & 4 deletions pymilvus/client/prepare.py

@@ -25,7 +25,7 @@
     ResourceGroupConfig,
     get_consistency_level,
 )
-from .utils import traverse_info, traverse_rows_info
+from .utils import traverse_info, traverse_rows_info, traverse_upsert_info


 class Prepare:
@@ -462,7 +462,7 @@ def row_upsert_param(
         return cls._parse_row_request(request, fields_info, enable_dynamic, entities)

     @staticmethod
-    def _pre_batch_check(
+    def _pre_insert_batch_check(
         entities: List,
         fields_info: Any,
     ):
@@ -493,6 +493,34 @@ def _pre_batch_check(
             raise ParamError(msg)
         return location

+    @staticmethod
+    def _pre_upsert_batch_check(
+        entities: List,
+        fields_info: Any,
+    ):
+        for entity in entities:
+            if (
+                entity.get("name") is None
+                or entity.get("values") is None
+                or entity.get("type") is None
+            ):
+                raise ParamError(
+                    message="Missing param in entities, a field must have type, name and values"
+                )
+        if not fields_info:
+            raise ParamError(message="Missing collection meta to validate entities")
+
+        location, primary_key_loc = traverse_upsert_info(fields_info)
+
+        # though impossible from sdk
+        if primary_key_loc is None:
+            raise ParamError(message="primary key not found")
+
+        if len(entities) != len(fields_info):
+            msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}"
+            raise ParamError(msg)
+        return location
+
     @staticmethod
     def _parse_batch_request(
         request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
@@ -533,7 +561,7 @@ def batch_insert_param(
         partition_name: str,
         fields_info: Any,
     ):
-        location = cls._pre_batch_check(entities, fields_info)
+        location = cls._pre_insert_batch_check(entities, fields_info)
         tag = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag)

@@ -547,7 +575,7 @@ def batch_upsert_param(
         partition_name: str,
         fields_info: Any,
     ):
-        location = cls._pre_batch_check(entities, fields_info)
+        location = cls._pre_upsert_batch_check(entities, fields_info)
         tag = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag)
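
The split makes the upsert path require one entity per schema field, primary key included, even when that key is auto-id. A sketch of input that would pass the new pre-check; fields_info here is a hand-written stand-in for a collection description:

    from pymilvus import DataType
    from pymilvus.client.prepare import Prepare

    # Hypothetical collection meta: an auto-id pk plus one vector field.
    fields_info = [
        {"name": "id", "type": DataType.INT64, "is_primary": True, "auto_id": True},
        {"name": "vec", "type": DataType.FLOAT_VECTOR},
    ]

    # For upsert, len(entities) must equal len(fields_info), so the "id"
    # entity cannot be omitted the way the insert path allows for auto-id keys.
    entities = [
        {"name": "id", "type": DataType.INT64, "values": [1, 2]},
        {"name": "vec", "type": DataType.FLOAT_VECTOR, "values": [[0.1, 0.2], [0.3, 0.4]]},
    ]

    location = Prepare._pre_upsert_batch_check(entities, fields_info)
    # location maps field name to position: {"id": 0, "vec": 1}
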

11 changes: 11 additions & 0 deletions pymilvus/client/utils.py

@@ -265,6 +265,17 @@ def traverse_info(fields_info: Any):
     return location, primary_key_loc, auto_id_loc


+def traverse_upsert_info(fields_info: Any):
+    location, primary_key_loc = {}, None
+    for i, field in enumerate(fields_info):
+        if field.get("is_primary", False):
+            primary_key_loc = i
+
+        location[field["name"]] = i
+
+    return location, primary_key_loc
+
+
 def get_server_type(host: str):
     return ZILLIZ if (isinstance(host, str) and "zilliz" in host.lower()) else MILVUS
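
traverse_upsert_info is a trimmed sibling of traverse_info: it records every field's position, primary key included, and tracks no auto_id location at all. A quick check of what it returns, using a made-up fields_info:

    from pymilvus.client.utils import traverse_upsert_info

    fields_info = [
        {"name": "id", "is_primary": True, "auto_id": True},
        {"name": "vec"},
    ]

    location, primary_key_loc = traverse_upsert_info(fields_info)
    assert location == {"id": 0, "vec": 1}
    assert primary_key_loc == 0
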

5 changes: 1 addition & 4 deletions pymilvus/exceptions.py

@@ -141,10 +141,6 @@ class InvalidConsistencyLevel(MilvusException):
     """Raise when consistency level is invalid"""


-class UpsertAutoIDTrueException(MilvusException):
-    """Raise when upsert autoID is true"""
-
-
 class ExceptionsMessage:
     NoHostPort = "connection configuration must contain 'host' and 'port'."
     HostType = "Type of 'host' must be str."
@@ -234,3 +230,4 @@ class ExceptionsMessage:
     ClusteringKeyOnlyOne = "Expected only one clustering key field, got [%s, %s, ...]."
     IsClusteringKeyType = "Param is_clustering_key must be bool type."
     ClusteringKeyFieldType = "Param clustering_key_field must be str type."
+    UpsertPrimaryKeyEmpty = "Upsert need to assign pk"

8 changes: 2 additions & 6 deletions pymilvus/orm/collection.py

@@ -34,7 +34,6 @@
     IndexNotExistException,
     PartitionAlreadyExistException,
     SchemaNotReadyException,
-    UpsertAutoIDTrueException,
 )
 from pymilvus.grpc_gen import schema_pb2
 from pymilvus.settings import Config
@@ -511,7 +510,7 @@ def insert(
         )

         check_insert_schema(self.schema, data)
-        entities = Prepare.prepare_insert_data(data, self.schema)
+        entities = Prepare.prepare_data(data, self.schema)
         return conn.batch_insert(
             self._name,
             entities,
@@ -622,9 +621,6 @@ def upsert(
             10
         """

-        if self.schema.auto_id:
-            raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
-
         if not is_valid_insert_data(data):
             raise DataTypeNotSupportException(
                 message="The type of data should be List, pd.DataFrame or Dict"
@@ -643,7 +639,7 @@
             return MutationResult(res)

         check_upsert_schema(self.schema, data)
-        entities = Prepare.prepare_upsert_data(data, self.schema)
+        entities = Prepare.prepare_data(data, self.schema, False)
         res = conn.upsert(
             self._name,
             entities,

21 changes: 5 additions & 16 deletions pymilvus/orm/prepare.py

@@ -16,25 +16,24 @@
 import numpy as np
 import pandas as pd

-from pymilvus.client import utils
 from pymilvus.client.types import DataType
 from pymilvus.exceptions import (
     DataNotMatchException,
     DataTypeNotSupportException,
     ExceptionsMessage,
     ParamError,
-    UpsertAutoIDTrueException,
 )

 from .schema import CollectionSchema


 class Prepare:
     @classmethod
-    def prepare_insert_data(
+    def prepare_data(
         cls,
         data: Union[List, Tuple, pd.DataFrame],
         schema: CollectionSchema,
+        is_insert: bool = True,
     ) -> List:
         if not isinstance(data, (list, tuple, pd.DataFrame)):
             raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport)
@@ -46,12 +45,13 @@ def prepare_insert_data(
             if (
                 schema.auto_id
                 and schema.primary_field.name in data
+                and is_insert
                 and not data[schema.primary_field.name].isnull().all()
             ):
                 raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData)
             # TODO(SPARSE): support pd.SparseDtype for sparse float vector field
             for field in fields:
-                if field.is_primary and field.auto_id:
+                if field.is_primary and field.auto_id and is_insert:
                     continue
                 values = []
                 if field.name in list(data.columns):
@@ -63,7 +63,7 @@
             for i, field in enumerate(tmp_fields):
                 # TODO Goose: Checking auto_id and is_primary only, maybe different than
                 # schema.is_primary, schema.auto_id, need to check why and how schema is built.
-                if field.is_primary and field.auto_id:
+                if field.is_primary and field.auto_id and is_insert:
                     tmp_fields.pop(i)

             vec_dtype_checker = {
@@ -152,14 +152,3 @@ def prepare_insert_data(
             entities.append({"name": field.name, "type": field.dtype, "values": d})

         return entities
-
-    @classmethod
-    def prepare_upsert_data(
-        cls,
-        data: Union[List, Tuple, pd.DataFrame, utils.SparseMatrixInputType],
-        schema: CollectionSchema,
-    ) -> List:
-        if schema.auto_id:
-            raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
-
-        return cls.prepare_insert_data(data, schema)
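
The is_insert flag is what lets a single prepare_data serve both verbs: under insert semantics an auto-id primary key is skipped while packing entities, while under upsert semantics it is kept and its column must be present in the input. A hedged sketch with column-ordered lists and an illustrative schema; the exact packing behavior of the list branch is only partly visible in this diff:

    from pymilvus import CollectionSchema, DataType, FieldSchema
    from pymilvus.orm.prepare import Prepare

    schema = CollectionSchema([
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("vec", DataType.FLOAT_VECTOR, dim=2),
    ])
    vectors = [[0.1, 0.2], [0.3, 0.4]]

    # Insert (is_insert defaults to True): the auto-id pk field is dropped
    # from the packing loop, so only the vector column is expected.
    insert_entities = Prepare.prepare_data([vectors], schema)

    # Upsert: is_insert=False keeps the pk field, so an id column must be
    # supplied alongside the vectors.
    upsert_entities = Prepare.prepare_data([[1, 2], vectors], schema, False)
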

30 changes: 16 additions & 14 deletions pymilvus/orm/schema.py

@@ -28,7 +28,6 @@
     PartitionKeyException,
     PrimaryKeyException,
     SchemaNotReadyException,
-    UpsertAutoIDTrueException,
 )

 from .constants import COMMON_TYPE_PARAMS
@@ -485,28 +484,23 @@ def _check_insert_data(data: Union[List[List], pd.DataFrame]):
         raise DataTypeNotSupportException(message="data should be a list of list")


-def _check_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
-    tmp_fields = copy.deepcopy(schema.fields)
-    for i, field in enumerate(tmp_fields):
-        if field.is_primary and field.auto_id:
-            tmp_fields.pop(i)
-
-    field_cnt = len(tmp_fields)
+def _check_data_schema_cnt(fields: List, data: Union[List[List], pd.DataFrame]):
+    field_cnt = len(fields)
     is_dataframe = isinstance(data, pd.DataFrame)
     data_cnt = len(data.columns) if is_dataframe else len(data)
     if field_cnt != data_cnt:
         message = (
             f"The data don't match with schema fields, expect {field_cnt} list, got {len(data)}"
         )
         if is_dataframe:
-            i_name = [f.name for f in tmp_fields]
+            i_name = [f.name for f in fields]
             t_name = list(data.columns)
             message = f"The fields don't match with schema fields, expected: {i_name}, got {t_name}"

         raise DataNotMatchException(message=message)

     if is_dataframe:
-        for x, y in zip(list(data.columns), tmp_fields):
+        for x, y in zip(list(data.columns), fields):
             if x != y.name:
                 raise DataNotMatchException(
                     message=f"The name of field don't match, expected: {y.name}, got {x}"
@@ -524,17 +518,25 @@ def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
         columns.remove(schema.primary_field)
         data = data[[columns]]

-    _check_data_schema_cnt(schema, data)
+    tmp_fields = copy.deepcopy(schema.fields)
+    for i, field in enumerate(tmp_fields):
+        if field.is_primary and field.auto_id:
+            tmp_fields.pop(i)
+
+    _check_data_schema_cnt(tmp_fields, data)
     _check_insert_data(data)


 def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
     if schema is None:
         raise SchemaNotReadyException(message="Schema shouldn't be None")
-    if schema.auto_id:
-        raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
     if isinstance(data, pd.DataFrame):
+        if (
+            schema.primary_field.name not in data
+            or data[schema.primary_field.name].isnull().all()
+        ):
+            raise DataNotMatchException(message=ExceptionsMessage.UpsertPrimaryKeyEmpty)
         columns = list(data.columns)
         data = data[[columns]]

-    _check_data_schema_cnt(schema, data)
+    _check_data_schema_cnt(copy.deepcopy(schema.fields), data)
     _check_insert_data(data)
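
This DataFrame guard is what surfaces the new UpsertPrimaryKeyEmpty message: a frame that omits the primary-key column, or carries only nulls in it, is rejected before any RPC is issued. A small sketch against the same illustrative schema as above:

    import pandas as pd

    from pymilvus import CollectionSchema, DataType, FieldSchema
    from pymilvus.exceptions import DataNotMatchException
    from pymilvus.orm.schema import check_upsert_schema

    schema = CollectionSchema([
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("vec", DataType.FLOAT_VECTOR, dim=2),
    ])

    df = pd.DataFrame({"vec": [[0.1, 0.2], [0.3, 0.4]]})  # no "id" column

    try:
        check_upsert_schema(schema, df)
    except DataNotMatchException as exc:
        # exc.message == "Upsert need to assign pk"
        print(exc)
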
