Skip to content

Commit

Permalink
fix: can't get vector when fp16/bf16 vectors as output fields (#1924)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Feb 6, 2024
1 parent 72f44be commit 24558a3
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 8 deletions.
26 changes: 23 additions & 3 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,19 @@ def get_fields_by_range(
)
continue

if dtype == DataType.BFLOAT16_VECTOR:
field2data[name] = (
vectors.bfloat16_vector[start * (dim * 2) : end * (dim * 2)],
field_meta,
)
continue

if dtype == DataType.FLOAT16_VECTOR:
field2data[name] = (
vectors.float16_vector[start * (dim * 2) : end * (dim * 2)],
field_meta,
)
continue
return field2data

def __iter__(self) -> SequenceIterator:
Expand Down Expand Up @@ -510,10 +523,17 @@ def __init__(
if len(data) <= i:
curr_field[fname] = None
# Get vectors
if field_meta.type in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
if field_meta.type in (
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
):
dim = field_meta.vectors.dim
dim = dim // 8 if field_meta.type == DataType.BINARY_VECTOR else dim

if field_meta.type in [DataType.BINARY_VECTOR]:
dim = dim // 8
elif field_meta.type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]:
dim = dim * 2
curr_field[fname] = data[i * dim : (i + 1) * dim]
continue

Expand Down
14 changes: 14 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,20 @@ def check_append(field_data: Any):
entity_row_data[field_data.field_name] = [
field_data.vectors.binary_vector[start_pos:end_pos]
]
elif field_data.type == DataType.BFLOAT16_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.bfloat16_vector) >= index * (dim * 2):
start_pos, end_pos = index * (dim * 2), (index + 1) * (dim * 2)
entity_row_data[field_data.field_name] = [
field_data.vectors.bfloat16_vector[start_pos:end_pos]
]
elif field_data.type == DataType.FLOAT16_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.float16_vector) >= index * (dim * 2):
start_pos, end_pos = index * (dim * 2), (index + 1) * (dim * 2)
entity_row_data[field_data.field_name] = [
field_data.vectors.float16_vector[start_pos:end_pos]
]

for field_data in fields_data:
check_append(field_data)
Expand Down
30 changes: 26 additions & 4 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,18 @@ def traverse_rows_info(fields_info: Any, entities: List):
if value is None:
raise ParamError(message=f"Field {field_name} don't match in entities[{j}]")

if field_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
if field_type in [
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
]:
field_dim = field["params"]["dim"]
entity_dim = len(value) if field_type == DataType.FLOAT_VECTOR else len(value) * 8

entity_dim = len(value)
if field_type in [DataType.BINARY_VECTOR]:
entity_dim = entity_dim * 8
elif field_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]:
entity_dim = int(entity_dim // 2)
if entity_dim != field_dim:
raise ParamError(
message=f"Collection field dim is {field_dim}"
Expand Down Expand Up @@ -282,7 +290,12 @@ def traverse_info(fields_info: Any, entities: List):
)

entity_dim, field_dim = 0, 0
if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
if entity_type in [
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
]:
field_dim = field["params"]["dim"]
entity_dim = len(entity["values"][0])

Expand All @@ -298,6 +311,15 @@ def traverse_info(fields_info: Any, entities: List):
f", but entities field dim is {entity_dim * 8}"
)

if (
entity_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]
and int(entity_dim // 2) != field_dim
):
raise ParamError(
message=f"Collection field dim is {field_dim}"
f", but entities field dim is {int(entity_dim // 2)}"
)

location[field["name"]] = i
match_flag = True
break
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def prepare_fields_from_dataframe(df: pd.DataFrame):
if new_dtype == DataType.BINARY_VECTOR:
vector_type_params["dim"] = len(values[i]) * 8
elif new_dtype in (DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR):
vector_type_params["dim"] = len(values[i]) / 2
vector_type_params["dim"] = int(len(values[i]) // 2)
else:
vector_type_params["dim"] = len(values[i])
column_params_map[col_names[i]] = vector_type_params
Expand Down
2 changes: 2 additions & 0 deletions tests/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,5 @@ def test_search_result_with_fields_data(self, pk):
assert 1 == r[0][1].int8_field
assert 2 == r[0][2].int16_field
assert [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] == r[0][1].int64_array_field
assert 32 == len(r[0][0].entity.bfloat16_vector_field)
assert 32 == len(r[0][0].entity.float16_vector_field)
13 changes: 13 additions & 0 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ def test_query(self, collection):
expr = "id in [ " + ids_expr + " ]"
res = collection.query(expr)

@pytest.mark.xfail
def test_query_with_output_fields(self, collection):
data = gen_list_data(default_nb)
ids = collection.insert(data)
assert len(ids) == default_nb
ids_expr = ",".join(str(x) for x in ids)
expr = "id in [ " + ids_expr + " ]"
output_fields = ["float_vector", "float16_vector", "bfloat16_vector"]
res = collection.query(expr, output_fields=output_fields)
assert len(res) == default_nb
for key in output_fields:
assert key in res

@pytest.mark.xfail
def test_partitions(self, collection):
assert len(collection.partitions) == 1
Expand Down

0 comments on commit 24558a3

Please sign in to comment.