diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 78235b8bb..6acbb6934 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -468,6 +468,19 @@ def get_fields_by_range( ) continue + if dtype == DataType.BFLOAT16_VECTOR: + field2data[name] = ( + vectors.bfloat16_vector[start * (dim * 2) : end * (dim * 2)], + field_meta, + ) + continue + + if dtype == DataType.FLOAT16_VECTOR: + field2data[name] = ( + vectors.float16_vector[start * (dim * 2) : end * (dim * 2)], + field_meta, + ) + continue return field2data def __iter__(self) -> SequenceIterator: @@ -510,10 +523,17 @@ def __init__( if len(data) <= i: curr_field[fname] = None # Get vectors - if field_meta.type in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): + if field_meta.type in ( + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + DataType.BFLOAT16_VECTOR, + DataType.FLOAT16_VECTOR, + ): dim = field_meta.vectors.dim - dim = dim // 8 if field_meta.type == DataType.BINARY_VECTOR else dim - + if field_meta.type in [DataType.BINARY_VECTOR]: + dim = dim // 8 + elif field_meta.type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]: + dim = dim * 2 curr_field[fname] = data[i * dim : (i + 1) * dim] continue diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index 72eca8c42..b0da5f496 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -325,6 +325,20 @@ def check_append(field_data: Any): entity_row_data[field_data.field_name] = [ field_data.vectors.binary_vector[start_pos:end_pos] ] + elif field_data.type == DataType.BFLOAT16_VECTOR: + dim = field_data.vectors.dim + if len(field_data.vectors.bfloat16_vector) >= index * (dim * 2): + start_pos, end_pos = index * (dim * 2), (index + 1) * (dim * 2) + entity_row_data[field_data.field_name] = [ + field_data.vectors.bfloat16_vector[start_pos:end_pos] + ] + elif field_data.type == DataType.FLOAT16_VECTOR: + dim = field_data.vectors.dim + if len(field_data.vectors.float16_vector) >= index * (dim * 2): + start_pos, end_pos = index * (dim * 2), (index + 1) * (dim * 2) + entity_row_data[field_data.field_name] = [ + field_data.vectors.float16_vector[start_pos:end_pos] + ] for field_data in fields_data: check_append(field_data) diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 889f5b884..a114cfebe 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -240,10 +240,18 @@ def traverse_rows_info(fields_info: Any, entities: List): if value is None: raise ParamError(message=f"Field {field_name} don't match in entities[{j}]") - if field_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: + if field_type in [ + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + DataType.BFLOAT16_VECTOR, + DataType.FLOAT16_VECTOR, + ]: field_dim = field["params"]["dim"] - entity_dim = len(value) if field_type == DataType.FLOAT_VECTOR else len(value) * 8 - + entity_dim = len(value) + if field_type in [DataType.BINARY_VECTOR]: + entity_dim = entity_dim * 8 + elif field_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]: + entity_dim = int(entity_dim // 2) if entity_dim != field_dim: raise ParamError( message=f"Collection field dim is {field_dim}" @@ -282,7 +290,12 @@ def traverse_info(fields_info: Any, entities: List): ) entity_dim, field_dim = 0, 0 - if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: + if entity_type in [ + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + DataType.BFLOAT16_VECTOR, + DataType.FLOAT16_VECTOR, + ]: field_dim = field["params"]["dim"] entity_dim = len(entity["values"][0]) @@ -298,6 +311,15 @@ def traverse_info(fields_info: Any, entities: List): f", but entities field dim is {entity_dim * 8}" ) + if ( + entity_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR] + and int(entity_dim // 2) != field_dim + ): + raise ParamError( + message=f"Collection field dim is {field_dim}" + f", but entities field dim is {int(entity_dim // 2)}" + ) + location[field["name"]] = i match_flag = True break diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index c5674a31d..5b0902047 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -507,7 +507,7 @@ def prepare_fields_from_dataframe(df: pd.DataFrame): if new_dtype == DataType.BINARY_VECTOR: vector_type_params["dim"] = len(values[i]) * 8 elif new_dtype in (DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR): - vector_type_params["dim"] = len(values[i]) / 2 + vector_type_params["dim"] = int(len(values[i]) // 2) else: vector_type_params["dim"] = len(values[i]) column_params_map[col_names[i]] = vector_type_params diff --git a/tests/test_abstract.py b/tests/test_abstract.py index 29be073bf..0edb9099b 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -202,3 +202,5 @@ def test_search_result_with_fields_data(self, pk): assert 1 == r[0][1].int8_field assert 2 == r[0][2].int16_field assert [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] == r[0][1].int64_array_field + assert 32 == len(r[0][0].entity.bfloat16_vector_field) + assert 32 == len(r[0][0].entity.float16_vector_field) diff --git a/tests/test_collection.py b/tests/test_collection.py index 53a37a168..8800680fa 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -146,6 +146,19 @@ def test_query(self, collection): expr = "id in [ " + ids_expr + " ]" res = collection.query(expr) + @pytest.mark.xfail + def test_query_with_output_fields(self, collection): + data = gen_list_data(default_nb) + ids = collection.insert(data) + assert len(ids) == default_nb + ids_expr = ",".join(str(x) for x in ids) + expr = "id in [ " + ids_expr + " ]" + output_fields = ["float_vector", "float16_vector", "bfloat16_vector"] + res = collection.query(expr, output_fields=output_fields) + assert len(res) == default_nb + for key in output_fields: + assert key in res + @pytest.mark.xfail def test_partitions(self, collection): assert len(collection.partitions) == 1