Skip to content

Commit

Permalink
Support new DataType: Array (#1681)
Browse files Browse the repository at this point in the history
Signed-off-by: cai.zhang <[email protected]>
  • Loading branch information
xiaocai2333 authored Sep 22, 2023
1 parent 5a685eb commit 1f1f964
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 3 deletions.
12 changes: 11 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def __init__(self, raw: Any):
self.params = {}
self.is_partition_key = False
self.is_dynamic = False

# For array field
self.element_type = None
##
self.__pack(self._raw)

Expand All @@ -84,6 +85,7 @@ def __pack(self, raw: Any):
self.auto_id = raw.autoID
self.type = raw.data_type
self.is_partition_key = raw.is_partition_key
self.element_type = raw.element_type
try:
self.is_dynamic = raw.is_dynamic
except Exception:
Expand Down Expand Up @@ -122,6 +124,7 @@ def dict(self):
"description": self.description,
"type": self.type,
"params": self.params or {},
"element_type": self.element_type,
}

if self.is_partition_key:
Expand Down Expand Up @@ -561,6 +564,13 @@ def _pack(self, raw_list: List):
field.scalars.json_data.data.extend(
field_data.scalars.json_data.data[start_pos:end_pos]
)
elif field_data.type == DataType.ARRAY:
field.scalars.array_data.data.extend(
field_data.scalars.array_data.data[start_pos:end_pos]
)
field.scalars.array_data.element_type = (
field_data.scalars.array_data.element_type
)
elif field_data.type == DataType.FLOAT_VECTOR:
dim = field_data.vectors.dim
field.vectors.dim = dim
Expand Down
71 changes: 71 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,40 @@ def entity_to_json_arr(entity: Dict):
return convert_to_json_arr(entity.get("values", []))


def convert_to_array_arr(objs: List[Any], field_info: Any = None):
    """Convert a column of array values into a list of ``ScalarField`` messages.

    Bug fix: the previous implementation recursed into itself
    (``convert_to_array_arr(obj)``) instead of converting each element, so for
    any non-empty input it descended to the scalar leaves and raised
    ``TypeError`` (non-iterable ints/floats) or recursed without bound
    (strings); no protobuf message was ever produced. Each item of *objs* is
    one array value and must be packed by ``convert_to_array``, which needs
    the field schema to pick the element type.

    :param objs: list of array values (one Python list per row).
    :param field_info: field schema dict carrying ``element_type``. Defaults
        to ``None`` so existing single-argument callers keep working; in that
        case ``convert_to_array`` raises a descriptive ``ParamError`` instead
        of the previous opaque ``TypeError``.
    :return: list of ``ScalarField`` messages, one per input array.
    """
    return [convert_to_array(obj, field_info) for obj in objs]


def convert_to_array(obj: List[Any], field_info: Any):
    """Pack one array value (a Python list) into a ``ScalarField`` message.

    The destination sub-message of ``ScalarField`` is selected from the
    field's ``element_type``; an unsupported element type raises
    ``ParamError``.

    :param obj: the element values of a single array row.
    :param field_info: field schema dict providing ``element_type`` and
        ``name``.
    :return: a ``ScalarField`` with *obj* appended to the matching sub-field.
    :raises ParamError: when ``element_type`` is not a supported scalar type.
    """
    element_type = field_info.get("element_type", None)
    # Each supported element type maps onto exactly one ScalarField sub-message.
    dispatch = (
        ((DataType.BOOL,), "bool_data"),
        ((DataType.INT8, DataType.INT16, DataType.INT32), "int_data"),
        ((DataType.INT64,), "long_data"),
        ((DataType.FLOAT,), "float_data"),
        ((DataType.DOUBLE,), "double_data"),
        ((DataType.VARCHAR, DataType.STRING), "string_data"),
    )
    for accepted_types, attr_name in dispatch:
        if element_type in accepted_types:
            field_data = schema_types.ScalarField()
            getattr(field_data, attr_name).data.extend(obj)
            return field_data
    raise ParamError(
        message=f"UnSupported element type: {element_type} for Array field: {field_info.get('name')}"
    )


def entity_to_array_arr(entity: Dict):
    """Convert an insert-entity dict's "values" list into ScalarField messages.

    NOTE(review): annotation corrected from ``List[Any]`` — the body calls
    ``entity.get("values", [])``, so *entity* is a dict, matching the sibling
    ``entity_to_*_arr`` helpers (cf. ``entity_to_json_arr(entity: Dict)``).
    """
    return convert_to_array_arr(entity.get("values", []))


def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
field_type = field_data.type
if field_type == DataType.BOOL:
Expand All @@ -94,6 +128,8 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
)
elif field_type == DataType.JSON:
field_data.scalars.json_data.data.append(convert_to_json(field_value))
elif field_type == DataType.ARRAY:
field_data.scalars.array_data.data.append(convert_to_array(field_value, field_info))
else:
raise ParamError(message=f"UnSupported data type: {field_type}")

Expand Down Expand Up @@ -129,6 +165,8 @@ def entity_to_field_data(entity: Any, field_info: Any):
)
elif entity_type == DataType.JSON:
field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
elif entity_type == DataType.ARRAY:
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity))
else:
raise ParamError(message=f"UnSupported data type: {entity_type}")

Expand All @@ -154,6 +192,37 @@ def extract_dynamic_field_from_result(raw: Any):
return dynamic_field_name, dynamic_fields


def extract_array_row_data(field_data: Any, index: int):
    """Extract the ``index``-th array value of an ARRAY field as a plain list.

    The values are read from whichever ``ScalarField`` sub-message matches
    the field's ``element_type``; an unrecognized element type yields an
    empty list (same fall-through as before).

    :param field_data: a FieldData message whose ``scalars.array_data`` holds
        the column of arrays.
    :param index: row position within ``scalars.array_data.data``.
    :return: the array's elements copied into a new Python list.
    """
    element_type = field_data.scalars.array_data.element_type
    array = field_data.scalars.array_data.data[index]
    # element_type -> name of the sub-message holding that row's values.
    sources = (
        ((DataType.INT64,), "long_data"),
        ((DataType.BOOL,), "bool_data"),
        ((DataType.INT8, DataType.INT16, DataType.INT32), "int_data"),
        ((DataType.FLOAT,), "float_data"),
        ((DataType.DOUBLE,), "double_data"),
        ((DataType.STRING, DataType.VARCHAR), "string_data"),
    )
    for accepted_types, attr_name in sources:
        if element_type in accepted_types:
            return list(getattr(array, attr_name).data)
    # Unknown element type: return an empty row rather than raising.
    return []


# pylint: disable=R1702 (too-many-nested-blocks)
def extract_row_data_from_fields_data(
fields_data: Any,
Expand Down Expand Up @@ -216,6 +285,8 @@ def check_append(field_data: Any):
tmp_dict = {k: v for k, v in json_dict.items() if k in dynamic_fields}
entity_row_data.update(tmp_dict)
return
if field_data.type == DataType.ARRAY and len(field_data.scalars.array_data.data) >= index:
entity_row_data[field_data.field_name] = extract_array_row_data(field_data, index)

if field_data.type == DataType.FLOAT_VECTOR:
dim = field_data.vectors.dim
Expand Down
1 change: 1 addition & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_schema_from_collection_schema(
autoID=f.auto_id,
is_partition_key=f.is_partition_key,
is_dynamic=f.is_dynamic,
element_type=f.element_type,
)
for k, v in f.params.items():
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
Expand Down
1 change: 1 addition & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class DataType(IntEnum):

STRING = 20
VARCHAR = 21
ARRAY = 22
JSON = 23

BINARY_VECTOR = 100
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def len_of(field_data: Any) -> int:
if field_data.scalars.HasField("json_data"):
return len(field_data.scalars.json_data.data)

if field_data.scalars.HasField("array_data"):
return len(field_data.scalars.array_data.data)

raise MilvusException(message="Unsupported scalar type")

if field_data.HasField("vectors"):
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

COMMON_TYPE_PARAMS = ("dim", "max_length")
COMMON_TYPE_PARAMS = ("dim", "max_length", "max_capacity")

CALC_DIST_IDS = "ids"
CALC_DIST_FLOAT_VEC = "float_vectors"
Expand Down
11 changes: 10 additions & 1 deletion pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
if not isinstance(kwargs.get("is_partition_key", False), bool):
raise PartitionKeyException(message=ExceptionsMessage.IsPartitionKeyType)
self.is_partition_key = kwargs.get("is_partition_key", False)
self.element_type = kwargs.get("element_type", None)
self._parse_type_params()

def __repr__(self) -> str:
Expand All @@ -283,7 +284,12 @@ def __deepcopy__(self, memodict: Optional[Dict] = None):

def _parse_type_params(self):
# update self._type_params according to self._kwargs
if self._dtype not in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, DataType.VARCHAR):
if self._dtype not in (
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
DataType.VARCHAR,
DataType.ARRAY,
):
return
if not self._kwargs:
return
Expand All @@ -304,6 +310,7 @@ def construct_from_dict(cls, raw: Dict):
kwargs["auto_id"] = raw.get("auto_id", None)
kwargs["is_partition_key"] = raw.get("is_partition_key", False)
kwargs["is_dynamic"] = raw.get("is_dynamic", False)
kwargs["element_type"] = raw.get("element_type", None)
return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)

def to_dict(self):
Expand All @@ -321,6 +328,8 @@ def to_dict(self):
_dict["is_partition_key"] = True
if self.is_dynamic:
_dict["is_dynamic"] = self.is_dynamic
if self.dtype == DataType.ARRAY and self.element_type:
_dict["element_type"] = self.element_type
return _dict

def __getattr__(self, item: str):
Expand Down

0 comments on commit 1f1f964

Please sign in to comment.