Skip to content

Commit

Permalink
Improve embedding retrieval performance (#2300)
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo authored Oct 16, 2024
1 parent d5a3e59 commit 2040ac6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
11 changes: 10 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,16 @@ def get_fields_by_range(
dim, vectors = field.vectors.dim, field.vectors
field_meta.vectors.dim = dim
if dtype == DataType.FLOAT_VECTOR:
field2data[name] = vectors.float_vector.data[start * dim : end * dim], field_meta
if start == 0 and (end - start) * dim >= len(vectors.float_vector.data):
# If the range equals to the lenth of ectors.float_vector.data, direct return
# it to avoid a copy. This logic improves performance by 25% for the case
# retrival 1536 dim embeddings with topk=16384.
field2data[name] = vectors.float_vector.data, field_meta
else:
field2data[name] = (
vectors.float_vector.data[start * dim : end * dim],
field_meta,
)
continue

if dtype == DataType.BINARY_VECTOR:
Expand Down
10 changes: 7 additions & 3 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,9 +708,13 @@ def check_append(field_data: Any):
dim = field_data.vectors.dim
if len(field_data.vectors.float_vector.data) >= index * dim:
start_pos, end_pos = index * dim, (index + 1) * dim
entity_row_data[field_data.field_name] = [
np.single(x) for x in field_data.vectors.float_vector.data[start_pos:end_pos]
]
# Here we use numpy.array to convert the float64 values to numpy.float32 values,
# and return a list of numpy.float32 to users
# By using numpy.array, performance improved by 60% for topk=16384 dim=1536 case.
arr = np.array(
field_data.vectors.float_vector.data[start_pos:end_pos], dtype=np.float32
)
entity_row_data[field_data.field_name] = list(arr)
elif field_data.type == DataType.BINARY_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.binary_vector) >= index * (dim // 8):
Expand Down

0 comments on commit 2040ac6

Please sign in to comment.