Skip to content

Commit

Permalink
add sparse float vector support: insert, search, query, delete; suppo…
Browse files Browse the repository at this point in the history
…rts various sparse matrix representations (#1902)

also supported row based insertion so sparse is also supported in milvus
client

Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian authored Mar 14, 2024
1 parent bf21bf3 commit 818f290
Show file tree
Hide file tree
Showing 17 changed files with 516 additions and 60 deletions.
154 changes: 154 additions & 0 deletions examples/hello_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# hello_sparse.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus,
# while operating on sparse float vectors.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from scipy.sparse import rand
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)

# Banner template for section headings; the message is padded to at least 30 characters.
fmt = "=== {:30} ==="
# Template for reporting an elapsed time in seconds with four decimal places.
search_latency_fmt = "search latency = {:.4f}s"
# Rows, columns and fill density of the random sparse matrix generated for insertion.
num_entities, dim, density = 1000, 3000, 0.005

def log(msg):
    """Print ``msg`` to stdout prefixed with the local time (``YYYY-MM-DD HH:MM:SS``).

    ``msg`` must be a ``str``; it is joined to the timestamp with a single space.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(" ".join((timestamp, msg)))

# -----------------------------------------------------------------------------
# connect to Milvus
log(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

has = utility.has_collection("hello_sparse")
log(f"Does collection hello_sparse exist in Milvus: {has}")

# -----------------------------------------------------------------------------
# create collection with a sparse float vector column
if has:
    # The collection already exists: open it by name.
    hello_sparse = Collection("hello_sparse")
else:
    # Schema: auto-generated VARCHAR primary key, a DOUBLE scalar column and
    # the sparse float vector column.
    pk_field = FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100)
    random_field = FieldSchema(name="random", dtype=DataType.DOUBLE)
    embeddings_field = FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR)
    schema = CollectionSchema(
        [pk_field, random_field, embeddings_field],
        "hello_sparse is the simplest demo to introduce sparse float vector usage",
    )
    log(fmt.format("Create collection `hello_sparse`"))
    hello_sparse = Collection("hello_sparse", schema, consistency_level="Strong")

log(f"hello_sparse has {hello_sparse.num_entities} entities({hello_sparse.num_entities/1000000}M), indexed {hello_sparse.has_index()}")

# -----------------------------------------------------------------------------
# insert
log(fmt.format("Start creating entities to insert"))
rng = np.random.default_rng(seed=19530)
# NOTE: generating the random sparse matrix is the slow step of this example.
# The matrix is drawn from the seeded generator as well, so that every run
# inserts the same data (previously only the scalar column was reproducible).
matrix_csr = rand(num_entities, dim, density=density, format="csr", random_state=rng)
# Column-based insert: one entry per field that is not auto-generated, i.e.
# the "random" scalars followed by the "embeddings" sparse matrix.
entities = [
    rng.random(num_entities).tolist(),
    matrix_csr,
]

log(fmt.format("Start inserting entities"))
insert_result = hello_sparse.insert(entities)

# -----------------------------------------------------------------------------
# create index
needs_index = not hello_sparse.has_index()
if needs_index:
    log(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
    build_params = {"drop_ratio_build": 0.2}
    hello_sparse.create_index(
        "embeddings",
        {
            "index_type": "SPARSE_INVERTED_INDEX",
            "metric_type": "IP",
            "params": build_params,
        },
    )

# Load the collection before running the searches and queries below.
log(fmt.format("Start loading"))
hello_sparse.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
log(fmt.format("Start searching based on vector similarity"))
# Query with the last row of the inserted sparse matrix (a one-row slice).
vectors_to_search = entities[-1][-1:]
search_params = {
    "metric_type": "IP",
    "params": {
        # A numeric ratio (float), like drop_ratio_build used for the index;
        # a string such as "0.2" does not pass the client-side float check.
        "drop_ratio_search": 0.2,
    },
}

start_time = time.time()
result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "random", "embeddings"])
end_time = time.time()

# One list of hits is returned per query vector.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")
log(search_latency_fmt.format(end_time - start_time))
# -----------------------------------------------------------------------------
# query based on scalar filtering(boolean, int, etc.)
print(fmt.format("Start querying with `random > 0.5`"))
scalar_filter = "random > 0.5"

start_time = time.time()
result = hello_sparse.query(expr=scalar_filter, output_fields=["random", "embeddings"])
end_time = time.time()

# Only the first matching row is shown.
print(f"query result:\n-{result[0]}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# pagination
# Both pages are fetched first, then printed, using the same filter.
r1 = hello_sparse.query(
    expr=scalar_filter,
    limit=4,
    output_fields=["random"],
)
r2 = hello_sparse.query(
    expr=scalar_filter,
    offset=1,
    limit=3,
    output_fields=["random"],
)
print(f"query pagination(limit=4):\n\t{r1}")
print(f"query pagination(offset=1, limit=3):\n\t{r2}")


# -----------------------------------------------------------------------------
# hybrid search
# Same vector search as before, restricted by a scalar filter expression.
print(fmt.format("Start hybrid searching with `random > 0.5`"))

start_time = time.time()
result = hello_sparse.search(
    vectors_to_search,
    "embeddings",
    search_params,
    limit=3,
    expr="random > 0.5",
    output_fields=["random"],
)
end_time = time.time()

# Walk every hit of every per-query hit list, in order.
for hit in (single_hit for hits in result for single_hit in hits):
    print(f"hit: {hit}, random field: {hit.entity.get('random')}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# delete entities by PK
# You can delete entities by their PK values using boolean expressions.
ids = insert_result.primary_keys
first_pk, second_pk = ids[0], ids[1]

# The primary key is a VARCHAR, so each value is quoted inside the expression.
expr = f'pk in ["{first_pk}" , "{second_pk}"]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

# Show the two rows before deleting them.
result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n")

hello_sparse.delete(expr)

# Run the same query again after the delete and print whatever it returns.
result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


# -----------------------------------------------------------------------------
# drop collection
print(fmt.format("Drop collection `hello_sparse`"))
utility.drop_collection("hello_sparse")
88 changes: 88 additions & 0 deletions examples/milvus_client/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pymilvus import (
MilvusClient,
FieldSchema, CollectionSchema, DataType,
)

import random

def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict:
    """Build a random sparse vector represented as an ``{index: value}`` dict.

    ``non_zero_count`` distinct indices are sampled from ``range(dimension)``
    and each one is paired with a value drawn from ``random.random()``.
    """
    positions = random.sample(range(dimension), non_zero_count)
    weights = [random.random() for _ in positions]
    return dict(zip(positions, weights))


# Banner template for section headings, surrounded by blank lines.
fmt = "\n=== {:30} ===\n"
# Dimension of the generated sparse vectors and number of entries in each.
dim, non_zero_count = 100, 20
collection_name = "hello_sparse"
milvus_client = MilvusClient("http://localhost:19530")

# Start from a clean slate: drop the collection if it already exists.
if milvus_client.has_collection(collection_name, timeout=5):
    milvus_client.drop_collection(collection_name)

fields = [
    FieldSchema(
        name="pk",
        dtype=DataType.VARCHAR,
        is_primary=True,
        auto_id=True,
        max_length=100,
    ),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(fields, "demo for using sparse float vector with milvus client")

# The index definition is passed to create_collection together with the schema.
index_params = milvus_client.prepare_index_params()
index_params.add_index(
    field_name="embeddings",
    index_name="sparse_inverted_index",
    index_type="SPARSE_INVERTED_INDEX",
    metric_type="IP",
    params={"drop_ratio_build": 0.2},
)
milvus_client.create_collection(
    collection_name,
    schema=schema,
    index_params=index_params,
    timeout=5,
    consistency_level="Strong",
)

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

# Row-based insert: each row is a dict keyed by field name ("pk" is auto-generated).
N = 6
rows = []
for i in range(N):
    rows.append({
        "random": i,
        "embeddings": generate_sparse_vector(dim, non_zero_count),
    })

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
print(fmt.format("Inserting entities done"))
print(insert_result)

# Plain string literal: there is nothing to interpolate, so no f-prefix.
print(fmt.format("Start vector anns search."))
# A single freshly generated sparse vector is used as the query.
vectors_to_search = [generate_sparse_vector(dim, non_zero_count)]
search_params = {
    "metric_type": "IP",
    "params": {
        "drop_ratio_search": 0.2,
    },
}
# no need to specify anns_field for collections with only 1 vector field
result = milvus_client.search(
    collection_name,
    vectors_to_search,
    limit=3,
    output_fields=["pk", "random", "embeddings"],
    search_params=search_params,
)
# One list of hits is returned per query vector.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

print(fmt.format("Start query by specifying filtering expression"))
query_results = milvus_client.query(collection_name, filter="random < 3")
# Remember the primary keys of the matching rows for the steps below.
pks = [ret['pk'] for ret in query_results]
for ret in query_results:
    print(ret)

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(
    collection_name, filter=f"pk == '{pks[0]}'")
print(query_results[0])

# The delete below is by primary key (ids=...), not by filter expression,
# so the message says so.
print(f"start to delete by specifying primary keys in collection {collection_name}")
delete_result = milvus_client.delete(collection_name, ids=pks[:1])
print(delete_result)

# Query the deleted primary key again to confirm it is gone.
print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(
    collection_name, filter=f"pk == '{pks[0]}'")
print(f'query result should be empty: {query_results}')

milvus_client.drop_collection(collection_name)
18 changes: 11 additions & 7 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pymilvus.grpc_gen import schema_pb2
from pymilvus.settings import Config

from . import entity_helper
from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
from .types import DataType

Expand All @@ -15,7 +16,6 @@ class FieldSchema:
def __init__(self, raw: Any):
self._raw = raw

#
self.field_id = 0
self.name = None
self.is_primary = False
Expand All @@ -29,7 +29,6 @@ def __init__(self, raw: Any):
# For array field
self.element_type = None
self.is_clustering_key = False
##
self.__pack(self._raw)

def __pack(self, raw: Any):
Expand Down Expand Up @@ -106,7 +105,6 @@ class CollectionSchema:
def __init__(self, raw: Any):
self._raw = raw

#
self.collection_name = None
self.description = None
self.params = {}
Expand All @@ -121,7 +119,6 @@ def __init__(self, raw: Any):
self.num_partitions = 0
self.enable_dynamic_field = False

#
if self._raw:
self.__pack(self._raw)

Expand Down Expand Up @@ -330,7 +327,7 @@ def dict(self):
class AnnSearchRequest:
def __init__(
self,
data: List,
data: Union[List, entity_helper.SparseMatrixInputType],
anns_field: str,
param: Dict,
limit: int,
Expand Down Expand Up @@ -472,6 +469,13 @@ def get_fields_by_range(
field_meta,
)
continue
# TODO(SPARSE): do we want to allow the user to specify the return format?
if dtype == DataType.SPARSE_FLOAT_VECTOR:
field2data[name] = (
entity_helper.sparse_proto_to_rows(vectors.sparse_float_vector, start, end),
field_meta,
)
continue

if dtype == DataType.BFLOAT16_VECTOR:
field2data[name] = (
Expand Down Expand Up @@ -527,7 +531,7 @@ def __init__(
for fname, (data, field_meta) in fields.items():
if len(data) <= i:
curr_field[fname] = None
# Get vectors
# Get dense vectors
if field_meta.type in (
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
Expand All @@ -552,7 +556,7 @@ def __init__(
curr_field.update(data[i])
continue

# normal fields
# sparse float vector and other fields
curr_field[fname] = data[i]

hits.append(Hit(pks[i], distances[i], curr_field))
Expand Down
28 changes: 10 additions & 18 deletions pymilvus/client/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pymilvus.exceptions import ParamError
from pymilvus.grpc_gen import milvus_pb2 as milvus_types

from . import entity_helper
from .singleton_utils import Singleton


Expand Down Expand Up @@ -40,24 +41,6 @@ def is_legal_port(port: Any) -> bool:
return False


def is_legal_vector(array: Any) -> bool:
    """Return True only when ``array`` is a non-empty ``list``."""
    # Truthiness is evaluated first, then the type, then the explicit length.
    return bool(array) and isinstance(array, list) and len(array) != 0


def is_legal_bin_vector(array: Any) -> bool:
    """Return True only when ``array`` is a non-empty ``bytes`` object."""
    # Truthiness is evaluated first, then the type, then the explicit length.
    return bool(array) and isinstance(array, bytes) and len(array) != 0


def is_legal_numpy_array(array: Any) -> bool:
    """Return True when ``array`` is not ``None`` and its ``size`` is non-zero."""
    return array is not None and array.size != 0


def int_or_str(item: Union[int, str]) -> str:
if isinstance(item, int):
return str(item)
Expand Down Expand Up @@ -149,6 +132,10 @@ def is_legal_max_iterations(max_iterations: Any) -> bool:
return isinstance(max_iterations, int)


def is_legal_drop_ratio(drop_ratio: Any) -> bool:
    """Return True when ``drop_ratio`` is a real number in the range ``[0, 1)``.

    Both ``float`` and ``int`` are accepted so that a plain ``0`` is treated
    the same as ``0.0``. ``bool`` is rejected explicitly because it is a
    subclass of ``int``. Strings, ``None`` and NaN are not legal.
    """
    if isinstance(drop_ratio, bool) or not isinstance(drop_ratio, (int, float)):
        return False
    # NaN fails both comparisons, so it is rejected here as well.
    return 0 <= drop_ratio < 1


def is_legal_team_size(team_size: Any) -> bool:
    """Return True when ``team_size`` is an ``int`` instance (no range check is applied)."""
    size_is_int = isinstance(team_size, int)
    return size_is_int

Expand Down Expand Up @@ -197,6 +184,9 @@ def is_legal_anns_field(field: Any) -> bool:
def is_legal_search_data(data: Any) -> bool:
import numpy as np

if entity_helper.entity_is_sparse_matrix(data):
return True

if not isinstance(data, (list, np.ndarray)):
return False

Expand Down Expand Up @@ -331,6 +321,8 @@ def __init__(self) -> None:
"team_size": is_legal_team_size,
"index_name": is_legal_index_name,
"timeout": is_legal_timeout,
"drop_ratio_build": is_legal_drop_ratio,
"drop_ratio_search": is_legal_drop_ratio,
}

def check(self, key: str, value: Callable):
Expand Down
Loading

0 comments on commit 818f290

Please sign in to comment.