Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

normalize vector data to a standard form during insertion (#27469) #2322

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions examples/example_normalization_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import time

import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
MilvusClient
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 3000, 8


print(fmt.format("start connecting to Milvus"))
# this is milvus standalone
connection = connections.connect(
alias="default",
host='localhost', # or '0.0.0.0' or 'localhost'
port='19530'
)

client = MilvusClient(connections=connection)

has = utility.has_collection("hello_milvus")
print(f"Does collection hello_milvus exist in Milvus: {has}")
if has:
utility.drop_collection("hello_milvus")

fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
FieldSchema(name="random", dtype=DataType.DOUBLE),
FieldSchema(name="embeddings1", dtype=DataType.FLOAT_VECTOR, dim=dim),
FieldSchema(name="embeddings2", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs")

print(fmt.format("Create collection `hello_milvus`"))

print(fmt.format("Message for handling an invalid format in the normalization_fields value")) # you can try with other value like: dict,...
try:
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields='embeddings1')
except BaseException as e:
print(e)


print(fmt.format("Message for handling the invalid vector fields"))
try:
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embddings'])
except BaseException as e:
print(e)

print(fmt.format("Insert data, without conversion to standard form"))

hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong")

print(fmt.format("Start inserting a row"))
rng = np.random.default_rng(seed=19530)

row = {
"pk": "19530",
"random": 0.5,
"embeddings1": rng.random((1, dim), np.float32)[0],
"embeddings2": rng.random((1, dim), np.float32)[0]
}
hello_milvus.insert(row)
utility.drop_collection("hello_milvus")

print(fmt.format("Insert data, with conversion to standard form"))

hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embeddings1'])

print(fmt.format("Start inserting a row"))
rng = np.random.default_rng(seed=19530)

row = {
"pk": "19530",
"random": 0.5,
"embeddings1": rng.random((1, dim), np.float32)[0],
"embeddings2": rng.random((1, dim), np.float32)[0]
}
_row = row.copy()
hello_milvus.insert(row)

index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}}
hello_milvus.create_index("embeddings1", index_param)
hello_milvus.create_index("embeddings2", index_param)
hello_milvus.load()

original_vector = _row['embeddings1']
insert_vector = hello_milvus.query(
expr="pk == '19530'",
output_fields=["embeddings1"],
)[0]['embeddings1']

print(fmt.format("Mean and standard deviation before normalization."))
print("Mean: ", np.mean(original_vector))
print("Std: ", np.std(original_vector))

print(fmt.format("Mean and standard deviation after normalization."))
print("Mean: ", np.mean(insert_vector))
print("Std: ", np.std(insert_vector))


print(fmt.format("Start inserting entities"))

entities = [
[str(i) for i in range(num_entities)],
rng.random(num_entities).tolist(),
rng.random((num_entities, dim), np.float32),
rng.random((num_entities, dim), np.float32),
]

insert_result = hello_milvus.insert(entities)

insert_vector = hello_milvus.query(
expr="pk == '1'",
output_fields=["embeddings1"],
)[0]['embeddings1']

print(fmt.format("Mean and standard deviation after normalization."))
print("Mean: ", np.mean(insert_vector))
print("Std: ", np.std(insert_vector))

utility.drop_collection("hello_milvus")
18 changes: 18 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import ujson

from pymilvus.exceptions import MilvusException, ParamError
Expand Down Expand Up @@ -375,3 +376,20 @@ def is_scipy_sparse(cls, data: Any):
"csr_array",
"spmatrix",
]


def convert_to_standard_form(vector_data: Any) -> Any:
if len(vector_data.shape) == 1:
# Calculate the mean and standard deviation of the vector
mean = np.mean(vector_data)
std_dev = np.std(vector_data)

# Standardize the vector
return (vector_data - mean) / std_dev if std_dev != 0 else vector_data

# Calculate mean and standard deviation for each row
row_means = np.mean(vector_data, axis=1, keepdims=True)
row_stds = np.std(vector_data, axis=1, keepdims=True)

# Standardize each row independently
return np.where(row_stds != 0, (vector_data - row_means) / row_stds, vector_data)
2 changes: 2 additions & 0 deletions pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,5 @@ class ExceptionsMessage:
DefaultValueInvalid = (
"Default value cannot be None for a field that is defined as nullable == false."
)
InvalidVectorFields = "%s is not a valid vector field; expected %s"
InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s."
36 changes: 35 additions & 1 deletion pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
DataTypeNotSupportException,
ExceptionsMessage,
IndexNotExistException,
MilvusException,
PartitionAlreadyExistException,
SchemaNotReadyException,
)
Expand Down Expand Up @@ -93,6 +94,9 @@ def __init__(
If timeout is not set, the client keeps waiting until the server
responds or an error occurs.

* *normalization_fields* (``str/list``, optional)
Fields are selected to apply standard normalization.


Raises:
SchemaNotReadyException: if the schema is wrong.
Expand Down Expand Up @@ -156,6 +160,30 @@ def __init__(
self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level

self._normalization_fields = self._kwargs.get("normalization_fields", None)
if self._normalization_fields:
self._vector_fields = self._get_vector_fields()
if self._normalization_fields == "all":
self._normalization_fields = self._vector_fields
elif isinstance(self._normalization_fields, list):
for field in self._normalization_fields:
if field not in self._vector_fields:
raise MilvusException(
ExceptionsMessage.InvalidVectorFields
% (field, ", ".join(self._vector_fields))
)
else:
raise MilvusException(
ExceptionsMessage.InvalidNormalizationParam % (self._normalization_fields)
)

def _get_vector_fields(self):
vector_fields = []
for field in self._schema_dict.get("fields", []):
if field.get("params", {}).get("dim", None):
vector_fields.append(field.get("name"))
return vector_fields

def __repr__(self) -> str:
_dict = {
"name": self.name,
Expand Down Expand Up @@ -504,6 +532,9 @@ def insert(

conn = self._get_connection()
if is_row_based(data):
if self._normalization_fields:
for norm_fld in self._normalization_fields:
data[norm_fld] = utils.convert_to_standard_form(data[norm_fld])
return conn.insert_rows(
collection_name=self._name,
entities=data,
Expand All @@ -512,7 +543,10 @@ def insert(
schema=self._schema_dict,
**kwargs,
)

if self._normalization_fields:
for idx, fld in enumerate(self._schema_dict["fields"]):
if fld["name"] in self._normalization_fields:
data[idx] = utils.convert_to_standard_form(data[idx])
check_insert_schema(self.schema, data)
entities = Prepare.prepare_data(data, self.schema)
return conn.batch_insert(
Expand Down