From 1f1f964082fea477e325360ad69db090bab19697 Mon Sep 17 00:00:00 2001
From: "cai.zhang" <cai.zhang@zilliz.com>
Date: Fri, 22 Sep 2023 19:43:18 +0800
Subject: [PATCH] Support new DataType: Array (#1681)

Signed-off-by: cai.zhang <cai.zhang@zilliz.com>
---
 pymilvus/client/abstract.py      | 12 +++++-
 pymilvus/client/entity_helper.py | 71 ++++++++++++++++++++++++++++++++
 pymilvus/client/prepare.py       |  1 +
 pymilvus/client/types.py         |  1 +
 pymilvus/client/utils.py         |  3 ++
 pymilvus/orm/constants.py        |  2 +-
 pymilvus/orm/schema.py           | 11 ++++-
 7 files changed, 98 insertions(+), 3 deletions(-)

diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py
index c54f04135..3dfb3e751 100644
--- a/pymilvus/client/abstract.py
+++ b/pymilvus/client/abstract.py
@@ -72,7 +72,8 @@ def __init__(self, raw: Any):
         self.params = {}
         self.is_partition_key = False
         self.is_dynamic = False
-
+        # For array field
+        self.element_type = None
         ##
         self.__pack(self._raw)
 
@@ -84,6 +85,7 @@ def __pack(self, raw: Any):
         self.auto_id = raw.autoID
         self.type = raw.data_type
         self.is_partition_key = raw.is_partition_key
+        self.element_type = raw.element_type
         try:
             self.is_dynamic = raw.is_dynamic
         except Exception:
@@ -122,6 +124,7 @@ def dict(self):
             "description": self.description,
             "type": self.type,
             "params": self.params or {},
+            "element_type": self.element_type,
         }
 
         if self.is_partition_key:
@@ -561,6 +564,13 @@ def _pack(self, raw_list: List):
                    field.scalars.json_data.data.extend(
                        field_data.scalars.json_data.data[start_pos:end_pos]
                    )
+                elif field_data.type == DataType.ARRAY:
+                    field.scalars.array_data.data.extend(
+                        field_data.scalars.array_data.data[start_pos:end_pos]
+                    )
+                    field.scalars.array_data.element_type = (
+                        field_data.scalars.array_data.element_type
+                    )
                 elif field_data.type == DataType.FLOAT_VECTOR:
                     dim = field_data.vectors.dim
                     field.vectors.dim = dim
diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py
index 4ca41ab1c..fb8debba7 100644
--- a/pymilvus/client/entity_helper.py
+++ b/pymilvus/client/entity_helper.py
@@ -70,6 +70,40 @@ def entity_to_json_arr(entity: Dict):
     return convert_to_json_arr(entity.get("values", []))
 
 
+def convert_to_array_arr(objs: List[Any], field_info: Any):
+    return [convert_to_array(obj, field_info) for obj in objs]
+
+
+def convert_to_array(obj: List[Any], field_info: Any):
+    field_data = schema_types.ScalarField()
+    element_type = field_info.get("element_type", None)
+    if element_type == DataType.BOOL:
+        field_data.bool_data.data.extend(obj)
+        return field_data
+    if element_type in (DataType.INT8, DataType.INT16, DataType.INT32):
+        field_data.int_data.data.extend(obj)
+        return field_data
+    if element_type == DataType.INT64:
+        field_data.long_data.data.extend(obj)
+        return field_data
+    if element_type == DataType.FLOAT:
+        field_data.float_data.data.extend(obj)
+        return field_data
+    if element_type == DataType.DOUBLE:
+        field_data.double_data.data.extend(obj)
+        return field_data
+    if element_type in (DataType.VARCHAR, DataType.STRING):
+        field_data.string_data.data.extend(obj)
+        return field_data
+    raise ParamError(
+        message=f"UnSupported element type: {element_type} for Array field: {field_info.get('name')}"
+    )
+
+
+def entity_to_array_arr(entity: List[Any], field_info: Any):
+    return convert_to_array_arr(entity.get("values", []), field_info)
+
+
 def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
     field_type = field_data.type
     if field_type == DataType.BOOL:
@@ -94,6 +128,8 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
         )
     elif field_type == DataType.JSON:
         field_data.scalars.json_data.data.append(convert_to_json(field_value))
+    elif field_type == DataType.ARRAY:
+        field_data.scalars.array_data.data.append(convert_to_array(field_value, field_info))
     else:
         raise ParamError(message=f"UnSupported data type: {field_type}")
 
@@ -129,6 +165,8 @@ def entity_to_field_data(entity: Any, field_info: Any):
         )
     elif entity_type == DataType.JSON:
         field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
+    elif entity_type == DataType.ARRAY:
+        field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
     else:
         raise ParamError(message=f"UnSupported data type: {entity_type}")
 
@@ -154,6 +192,37 @@ def extract_dynamic_field_from_result(raw: Any):
     return dynamic_field_name, dynamic_fields
 
 
+def extract_array_row_data(field_data: Any, index: int):
+    array = field_data.scalars.array_data.data[index]
+    row = []
+    if field_data.scalars.array_data.element_type == DataType.INT64:
+        row.extend(array.long_data.data)
+        return row
+    if field_data.scalars.array_data.element_type == DataType.BOOL:
+        row.extend(array.bool_data.data)
+        return row
+    if field_data.scalars.array_data.element_type in (
+        DataType.INT8,
+        DataType.INT16,
+        DataType.INT32,
+    ):
+        row.extend(array.int_data.data)
+        return row
+    if field_data.scalars.array_data.element_type == DataType.FLOAT:
+        row.extend(array.float_data.data)
+        return row
+    if field_data.scalars.array_data.element_type == DataType.DOUBLE:
+        row.extend(array.double_data.data)
+        return row
+    if field_data.scalars.array_data.element_type in (
+        DataType.STRING,
+        DataType.VARCHAR,
+    ):
+        row.extend(array.string_data.data)
+        return row
+    return row
+
+
 # pylint: disable=R1702 (too-many-nested-blocks)
 def extract_row_data_from_fields_data(
     fields_data: Any,
@@ -216,6 +285,8 @@ def check_append(field_data: Any):
             tmp_dict = {k: v for k, v in json_dict.items() if k in dynamic_fields}
             entity_row_data.update(tmp_dict)
             return
+        if field_data.type == DataType.ARRAY and len(field_data.scalars.array_data.data) >= index:
+            entity_row_data[field_data.field_name] = extract_array_row_data(field_data, index)
 
         if field_data.type == DataType.FLOAT_VECTOR:
             dim = field_data.vectors.dim
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 5a05c2bc8..3a946f9ea 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -116,6 +116,7 @@ def get_schema_from_collection_schema(
             autoID=f.auto_id,
             is_partition_key=f.is_partition_key,
             is_dynamic=f.is_dynamic,
+            element_type=f.element_type,
         )
         for k, v in f.params.items():
             kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py
index 40736277e..3e540ed54 100644
--- a/pymilvus/client/types.py
+++ b/pymilvus/client/types.py
@@ -81,6 +81,7 @@ class DataType(IntEnum):
 
     STRING = 20
     VARCHAR = 21
+    ARRAY = 22
     JSON = 23
 
     BINARY_VECTOR = 100
diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py
index aa054aa03..f1c928cf5 100644
--- a/pymilvus/client/utils.py
+++ b/pymilvus/client/utils.py
@@ -151,6 +151,9 @@ def len_of(field_data: Any) -> int:
         if field_data.scalars.HasField("json_data"):
             return len(field_data.scalars.json_data.data)
 
+        if field_data.scalars.HasField("array_data"):
+            return len(field_data.scalars.array_data.data)
+
         raise MilvusException(message="Unsupported scalar type")
 
     if field_data.HasField("vectors"):
diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py
index e85dcfc3a..57a5d1648 100644
--- a/pymilvus/orm/constants.py
+++ b/pymilvus/orm/constants.py
@@ -10,7 +10,7 @@
 # or implied. See the License for the specific language governing permissions and limitations under
 # the License.
 
-COMMON_TYPE_PARAMS = ("dim", "max_length")
+COMMON_TYPE_PARAMS = ("dim", "max_length", "max_capacity")
 
 CALC_DIST_IDS = "ids"
 CALC_DIST_FLOAT_VEC = "float_vectors"
diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py
index 802f3314d..bdab6b6f9 100644
--- a/pymilvus/orm/schema.py
+++ b/pymilvus/orm/schema.py
@@ -271,6 +271,7 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
         if not isinstance(kwargs.get("is_partition_key", False), bool):
             raise PartitionKeyException(message=ExceptionsMessage.IsPartitionKeyType)
         self.is_partition_key = kwargs.get("is_partition_key", False)
+        self.element_type = kwargs.get("element_type", None)
         self._parse_type_params()
 
     def __repr__(self) -> str:
@@ -283,7 +284,12 @@ def __deepcopy__(self, memodict: Optional[Dict] = None):
 
     def _parse_type_params(self):
         # update self._type_params according to self._kwargs
-        if self._dtype not in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, DataType.VARCHAR):
+        if self._dtype not in (
+            DataType.BINARY_VECTOR,
+            DataType.FLOAT_VECTOR,
+            DataType.VARCHAR,
+            DataType.ARRAY,
+        ):
             return
         if not self._kwargs:
             return
@@ -304,6 +310,7 @@ def construct_from_dict(cls, raw: Dict):
         kwargs["auto_id"] = raw.get("auto_id", None)
         kwargs["is_partition_key"] = raw.get("is_partition_key", False)
         kwargs["is_dynamic"] = raw.get("is_dynamic", False)
+        kwargs["element_type"] = raw.get("element_type", None)
         return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)
 
     def to_dict(self):
@@ -321,6 +328,8 @@ def to_dict(self):
             _dict["is_partition_key"] = True
         if self.is_dynamic:
             _dict["is_dynamic"] = self.is_dynamic
+        if self.dtype == DataType.ARRAY and self.element_type:
+            _dict["element_type"] = self.element_type
         return _dict
 
     def __getattr__(self, item: str):