Commit
add sparse float vector support: insert, search, query, delete; supports various sparse matrix representations (#1902). Also supports row-based insertion, so sparse vectors are supported in MilvusClient as well.
Signed-off-by: Buqian Zheng <[email protected]>
1 parent bf21bf3, commit 818f290.
Showing 17 changed files with 516 additions and 60 deletions.
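The two new example files below use two concrete representations for sparse float vectors: a scipy.sparse matrix for column-based insertion with the ORM Collection API, and a per-entity dict of {dimension_index: value} for row-based insertion via MilvusClient. As a quick orientation (a sketch based only on what the demos below do, not an exhaustive list of accepted formats), the inputs look like this:

# Sketch: the two sparse-vector input shapes used by the demos in this commit.
from scipy.sparse import rand

# column-based insert (Collection.insert): a scipy.sparse matrix,
# one row per entity (hello_sparse.py below uses csr format)
matrix_csr = rand(1000, 3000, density=0.005, format='csr')

# row-based insert (MilvusClient.insert): one dict per entity, where the
# sparse field value is {dimension_index: float_value}
row = {"random": 0.5, "embeddings": {12: 0.34, 978: 0.88, 2045: 0.19}}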
@@ -0,0 +1,154 @@
# hello_sparse.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus,
# while operating on sparse float vectors.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from scipy.sparse import rand
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection,
)

fmt = "=== {:30} ==="
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim, density = 1000, 3000, 0.005

def log(msg):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)

# -----------------------------------------------------------------------------
# connect to Milvus
log(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

has = utility.has_collection("hello_sparse")
log(f"Does collection hello_sparse exist in Milvus: {has}")

# -----------------------------------------------------------------------------
# create collection with a sparse float vector column
hello_sparse = None
if not has:
    fields = [
        FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
        FieldSchema(name="random", dtype=DataType.DOUBLE),
        FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR),
    ]
    schema = CollectionSchema(fields, "hello_sparse is the simplest demo to introduce sparse float vector usage")
    log(fmt.format("Create collection `hello_sparse`"))
    hello_sparse = Collection("hello_sparse", schema, consistency_level="Strong")
else:
    hello_sparse = Collection("hello_sparse")

log(f"hello_sparse has {hello_sparse.num_entities} entities({hello_sparse.num_entities/1000000}M), indexed {hello_sparse.has_index()}")

# -----------------------------------------------------------------------------
# insert
log(fmt.format("Start creating entities to insert"))
rng = np.random.default_rng(seed=19530)
# generating a large random sparse matrix can take a while
matrix_csr = rand(num_entities, dim, density=density, format='csr')
entities = [
    rng.random(num_entities).tolist(),
    matrix_csr,
]

log(fmt.format("Start inserting entities"))
insert_result = hello_sparse.insert(entities)

# -----------------------------------------------------------------------------
# create index
if not hello_sparse.has_index():
    log(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
    index = {
        "index_type": "SPARSE_INVERTED_INDEX",
        "metric_type": "IP",
        "params": {
            "drop_ratio_build": 0.2,
        }
    }
    hello_sparse.create_index("embeddings", index)

log(fmt.format("Start loading"))
hello_sparse.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
log(fmt.format("Start searching based on vector similarity"))
# use the last row of the inserted csr matrix as the query vector
vectors_to_search = entities[-1][-1:]
search_params = {
    "metric_type": "IP",
    "params": {
        "drop_ratio_search": 0.2,
    }
}

start_time = time.time()
result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "random", "embeddings"])
end_time = time.time()

for hits in result:
    for hit in hits:
        print(f"hit: {hit}")
log(search_latency_fmt.format(end_time - start_time))
# -----------------------------------------------------------------------------
# query based on scalar filtering (boolean, int, etc.)
print(fmt.format("Start querying with `random > 0.5`")) | ||
|
||
start_time = time.time() | ||
result = hello_sparse.query(expr="random > 0.5", output_fields=["random", "embeddings"]) | ||
end_time = time.time() | ||
|
||
print(f"query result:\n-{result[0]}") | ||
print(search_latency_fmt.format(end_time - start_time)) | ||
|
||
# ----------------------------------------------------------------------------- | ||
# pagination | ||
r1 = hello_sparse.query(expr="random > 0.5", limit=4, output_fields=["random"]) | ||
r2 = hello_sparse.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"]) | ||
print(f"query pagination(limit=4):\n\t{r1}") | ||
print(f"query pagination(offset=1, limit=3):\n\t{r2}") | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
# hybrid search | ||
print(fmt.format("Start hybrid searching with `random > 0.5`")) | ||
|
||
start_time = time.time() | ||
result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5", output_fields=["random"]) | ||
end_time = time.time() | ||
|
||
for hits in result: | ||
for hit in hits: | ||
print(f"hit: {hit}, random field: {hit.entity.get('random')}") | ||
print(search_latency_fmt.format(end_time - start_time)) | ||
|
||
# ----------------------------------------------------------------------------- | ||
# delete entities by PK | ||
# You can delete entities by their PK values using boolean expressions. | ||
ids = insert_result.primary_keys | ||
|
||
expr = f'pk in ["{ids[0]}" , "{ids[1]}"]' | ||
print(fmt.format(f"Start deleting with expr `{expr}`")) | ||
|
||
result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"]) | ||
print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n") | ||
|
||
hello_sparse.delete(expr) | ||
|
||
result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"]) | ||
print(f"query after delete by expr=`{expr}` -> result: {result}\n") | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
# drop collection | ||
print(fmt.format("Drop collection `hello_sparse`")) | ||
utility.drop_collection("hello_sparse") |
@@ -0,0 +1,88 @@
from pymilvus import (
    MilvusClient,
    FieldSchema, CollectionSchema, DataType,
)

import random
# generate a random sparse vector represented as a dict of {dimension_index: value}
def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict:
    indices = random.sample(range(dimension), non_zero_count)
    values = [random.random() for _ in range(non_zero_count)]
    sparse_vector = {index: value for index, value in zip(indices, values)}
    return sparse_vector

fmt = "\n=== {:30} ===\n"
dim = 100
non_zero_count = 20
collection_name = "hello_sparse"
milvus_client = MilvusClient("http://localhost:19530")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
    milvus_client.drop_collection(collection_name)
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR,
                is_primary=True, auto_id=True, max_length=100),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(
    fields, "demo for using sparse float vector with milvus client")
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="embeddings", index_name="sparse_inverted_index",
                       index_type="SPARSE_INVERTED_INDEX", metric_type="IP", params={"drop_ratio_build": 0.2})
milvus_client.create_collection(collection_name, schema=schema,
                                index_params=index_params, timeout=5, consistency_level="Strong")

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

N = 6
rows = [{"random": i, "embeddings": generate_sparse_vector(
    dim, non_zero_count)} for i in range(N)]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
print(fmt.format("Inserting entities done"))
print(insert_result)

print(fmt.format("Start vector anns search."))
vectors_to_search = [generate_sparse_vector(dim, non_zero_count)]
search_params = {
    "metric_type": "IP",
    "params": {
        "drop_ratio_search": 0.2,
    }
}
# no need to specify anns_field for collections with only 1 vector field
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=[
                              "pk", "random", "embeddings"], search_params=search_params)
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

print(fmt.format("Start query by specifying filtering expression"))
query_results = milvus_client.query(collection_name, filter="random < 3")
pks = [ret['pk'] for ret in query_results]
for ret in query_results:
    print(ret)

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(
    collection_name, filter=f"pk == '{pks[0]}'")
print(query_results[0])
print(f"start to delete by specifying filter in collection {collection_name}") | ||
delete_result = milvus_client.delete(collection_name, ids=pks[:1])
print(delete_result)

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(
    collection_name, filter=f"pk == '{pks[0]}'")
print(f'query result should be empty: {query_results}')

milvus_client.drop_collection(collection_name)
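The commit message also mentions "various sparse matrix representations". A minimal sketch of what that could look like for a row-based insert, assuming a 1 x dim scipy.sparse matrix is also accepted as a per-entity value (verify against your pymilvus version); such a call would have to run before the final drop_collection above:

# Hedged sketch, not part of the committed demo: a single row whose sparse
# field is a 1 x dim scipy.sparse matrix instead of a dict (assumed to be
# accepted; confirm with your pymilvus version).
from scipy.sparse import csr_matrix

row_with_scipy_vector = {
    "random": 100.0,
    "embeddings": csr_matrix(([0.5, 0.25], ([0, 0], [3, 42])), shape=(1, dim)),
}
# milvus_client.insert(collection_name, [row_with_scipy_vector])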