Skip to content

Commit

Permalink
feat:add geospatial api support for py client
Browse files Browse the repository at this point in the history
fix: complete geospatial impl
  • Loading branch information
tasty-gumi committed Sep 5, 2024
1 parent 6cc2e55 commit 2376f6a
Show file tree
Hide file tree
Showing 17 changed files with 2,040 additions and 817 deletions.
42 changes: 42 additions & 0 deletions examples/genwkt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import random

def random_point() -> str:
    """Return a random WKT ``POINT`` literal.

    The first coordinate is drawn uniformly from [-90, 90] and the second
    from [-180, 180]; both are rendered with three decimal places.
    """
    first = random.uniform(-90, 90)
    second = random.uniform(-180, 180)
    return f"POINT ({first:.3f} {second:.3f})"

def random_linestring(num_points) -> str:
    """Return a random WKT ``LINESTRING`` literal with ``num_points`` vertices.

    Each vertex is a pair drawn uniformly from [-90, 90] and [-180, 180],
    rendered with three decimal places; vertices are separated by ", ".
    """
    vertices = []
    for _ in range(num_points):
        # Draw order matters for reproducibility under a seeded generator.
        first = random.uniform(-90, 90)
        second = random.uniform(-180, 180)
        vertices.append(f"{first:.3f} {second:.3f}")
    return f"LINESTRING ({', '.join(vertices)})"

def random_polygon(num_points: int) -> str:
    """Return a random WKT ``POLYGON`` literal with a single closed ring.

    ``num_points`` random vertices are generated (first coordinate in
    [-90, 90], second in [-180, 180], three decimal places) and the first
    vertex is repeated at the end to close the ring.  The vertices are not
    ordered in any way, so the ring may self-intersect.

    Raises:
        IndexError: if ``num_points`` is 0, because there is no first vertex
            to repeat.
    """
    ring = []
    for _ in range(num_points):
        first = random.uniform(-90, 90)
        second = random.uniform(-180, 180)
        ring.append(f"{first:.3f} {second:.3f}")
    # Close the polygon by appending the first vertex once more.
    closed_ring = [*ring, ring[0]]
    return f"POLYGON(({', '.join(closed_ring)}))"


def generate_data(num):
    """Return a list of ``num`` random WKT strings.

    Geometry kinds cycle by position: index 0, 3, 6, ... is a point,
    index 1, 4, 7, ... is a linestring with 2-9 vertices, and the rest are
    polygons built from 3-9 random vertices.
    """
    wkts = []
    for position in range(num):
        kind = position % 3
        if kind == 0:
            wkt = random_point()
        elif kind == 1:
            wkt = random_linestring(random.randint(2, 9))
        else:
            wkt = random_polygon(random.randint(3, 9))
        wkts.append(wkt)
    return wkts

def main():
    """Generate ten sample WKT geometries and print one per line."""
    sample_count = 10
    for wkt in generate_data(sample_count):
        print(wkt)


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion examples/hello_milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,4 @@
# 7. drop collection
# Finally, drop the hello_milvus collection
print(fmt.format("Drop collection `hello_milvus`"))
utility.drop_collection("hello_milvus")
utility.drop_collection("hello_milvus")
226 changes: 226 additions & 0 deletions examples/hello_milvus_geospatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# hello_milvus_geospatial.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus,
# on a collection that also contains a geospatial (WKT) field.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
from genwkt import generate_data

# Output templates shared by every section of this demo.
fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
# Number of bulk-inserted rows and dimension of the vector field.
num_entities = 3000
dim = 8

#################################################################################
# 1. connect to Milvus
# Add a new connection alias `default` for the Milvus server at `localhost:19530`.
# The "default" alias is actually built into PyMilvus.
# If the address of Milvus is the same as `localhost:19530`, you can omit all
# parameters and call the method as: `connections.connect()`.
#
# Note: the `using` parameter of the following methods defaults to "default".
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

has = utility.has_collection("hello_milvus")
print(f"Does collection hello_milvus exist in Milvus: {has}")

#################################################################################
# 2. create collection
# We're going to create a collection with 4 fields.
# +-+------------+------------+------------------+------------------------------+
# | | field name | field type | other attributes |       field description      |
# +-+------------+------------+------------------+------------------------------+
# |1|    "pk"    |   VarChar  |  is_primary=True |      "primary field"         |
# | |            |            |   auto_id=False  |                              |
# | |            |            |  max_length=100  |                              |
# +-+------------+------------+------------------+------------------------------+
# |2|  "random"  |   Double   |                  |      "a double field"        |
# +-+------------+------------+------------------+------------------------------+
# |3|"embeddings"| FloatVector|      dim=8       |  "float vector with dim 8"   |
# +-+------------+------------+------------------+------------------------------+
# |4|"geospatial"| Geospatial |                  |   "geometry given as WKT"    |
# +-+------------+------------+------------------+------------------------------+
fields = [
    FieldSchema(
        name="pk",
        dtype=DataType.VARCHAR,
        is_primary=True,
        auto_id=False,
        max_length=100,
    ),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="geospatial", dtype=DataType.GEOSPATIAL),
]

schema = CollectionSchema(
    fields,
    "hello_milvus is the simplest demo to introduce the APIs",
)

print(fmt.format("Create collection `hello_milvus`"))
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong")
print(f"The milvus describetion: {hello_milvus.describe()}")

################################################################################
# 3. insert data
# We are going to insert `num_entities` (3000) rows of column-organized data
# into `hello_milvus`, followed by three single rows with hand-written
# geometries that the GIS queries later in this script can match.
#
# The insert() method returns:
# - either automatically generated primary keys by Milvus if auto_id=True in the schema;
# - or the existing primary key field from the entities if auto_id=False in the schema.

print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)
entities = [
    # provide the pk field because `auto_id` is set to False
    [str(i) for i in range(num_entities)],
    rng.random(num_entities).tolist(),  # field random, only supports list
    rng.random((num_entities, dim), np.float32),  # field embeddings, supports numpy.ndarray and list
    generate_data(num_entities),  # field geospatial, list of WKT strings
]

insert_result = hello_milvus.insert(entities)

# Row-based inserts: one dict per entity, keyed by field name.
known_geometries = [
    ("19530", "POINT (-84.036 39.997)"),
    ("19531", "POLYGON ((0 0, 0 2, 2 2, 2 0, 0 0))"),
    ("19532", "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"),
]
for pk, wkt in known_geometries:
    row = {
        "pk": pk,
        "random": 0.5,
        "embeddings": rng.random((1, dim), np.float32)[0],
        "geospatial": wkt,
    }
    hello_milvus.insert(row)

hello_milvus.flush()
print(f"Number of entities in Milvus: {hello_milvus.num_entities}")  # check the num_entities

################################################################################
# 4. create index
# We are going to create an IVF_FLAT index for hello_milvus collection.
# create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
print(fmt.format("Start Creating index IVF_FLAT"))
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}

hello_milvus.create_index("embeddings", index)

################################################################################
# 5. search, query, and hybrid search
# After data were inserted into Milvus and indexed, you can perform:
# - search based on vector similarity
# - query based on scalar filtering(boolean, int, etc.)
# - hybrid search based on vector similarity and scalar filtering.
#

# Before conducting a search or a query, you need to load the data in `hello_milvus` into memory.
print(fmt.format("Start loading"))
hello_milvus.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
print(fmt.format("Start searching based on vector similarity"))
# `entities[-2]` is the embeddings column (the last column holds the WKT
# strings); reuse its last two vectors as the query vectors.
vectors_to_search = entities[-2][-2:]
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10},
}

start_time = time.time()
# Request every field that is printed below; a field that is not listed in
# `output_fields` is not part of the returned entity.
result = hello_milvus.search(
    vectors_to_search,
    "embeddings",
    search_params,
    limit=3,
    output_fields=["random", "geospatial"],
)
end_time = time.time()

for hits in result:
    for hit in hits:
        print(
            f"hit: {hit}, random field: {hit.entity.get('random')}, "
            f"geospatial field: {hit.entity.get('geospatial')}"
        )
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# query based on geospatial relationship functions applied to the `geospatial` field
print(fmt.format("Start querying with GIS FUNC"))

start_time = time.time()
result1 = hello_milvus.query(expr="geospatial_equals(geospatial,'POINT (-84.036 39.997)')", output_fields=["random", "geospatial"])
result2 = hello_milvus.query(expr="geospatial_touches(geospatial,'POLYGON ((0 0, -1 0, -1 -1, 0 -1, 0 0))')", output_fields=["random", "geospatial"])
result3 = hello_milvus.query(expr="geospatial_overlaps(geospatial,'POLYGON ((6 0, 6 5, 8 5, 8 0, 6 0))')", output_fields=["random", "geospatial"])
result4 = hello_milvus.query(expr="geospatial_crosses(geospatial,'POLYGON ((6 0, 6 5, 8 5, 8 0, 6 0))')", output_fields=["random", "geospatial"])
result5 = hello_milvus.query(expr="geospatial_contains(geospatial,'POLYGON ((6 0, 6 5, 8 5, 8 0, 6 0))')", output_fields=["random", "geospatial"])
result6 = hello_milvus.query(expr="geospatial_intersects(geospatial,'POLYGON ((6 0, 6 5, 8 5, 8 0, 6 0))')", output_fields=["random", "geospatial"])
# the within relationship operator selects the rows whose geo field lies within the wkt literal
result7 = hello_milvus.query(expr="geospatial_within(geospatial,'POLYGON ((0 0, 0 4, 4 4, 4 0, 0 0))')", output_fields=["random", "geospatial"])
end_time = time.time()

# Most rows hold randomly generated geometries, so a relationship may match
# nothing; guard the first-row access instead of raising IndexError.
gis_results = (
    ("equals", result1),
    ("touches", result2),
    ("overlaps", result3),
    ("crosses", result4),
    ("contains", result5),
    ("intersects", result6),
    ("within", result7),
)
for number, (relation, rows) in enumerate(gis_results, start=1):
    first_match = rows[0] if rows else "no matching entity"
    print(f"{relation} query result{number}:\n-{first_match}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# pagination
# Run the same scalar filter twice: once limited to 4 rows, once with
# offset=1 and limit=3.
scalar_filter = "random > 0.5"
r1 = hello_milvus.query(expr=scalar_filter, limit=4, output_fields=["random"])
r2 = hello_milvus.query(expr=scalar_filter, offset=1, limit=3, output_fields=["random"])
print(f"query pagination(limit=4):\n\t{r1}")
print(f"query pagination(offset=1, limit=3):\n\t{r2}")


# -----------------------------------------------------------------------------
# hybrid search
# Vector similarity search restricted by the scalar filter above.
print(fmt.format("Start hybrid searching with `random > 0.5`"))

start_time = time.time()
result = hello_milvus.search(
    vectors_to_search,
    "embeddings",
    search_params,
    limit=3,
    expr=scalar_filter,
    output_fields=["random"],
)
end_time = time.time()

# Walk every hit of every query vector, in order.
for hit in (single for hits in result for single in hits):
    print(f"hit: {hit}, random field: {hit.entity.get('random')}")
print(search_latency_fmt.format(end_time - start_time))

###############################################################################
# 6. delete entities by PK
# Entities are deleted with a boolean expression over their PK values; here
# the first two primary keys returned by the bulk insert are used.
ids = insert_result.primary_keys

expr = f'pk in ["{ids[0]}" , "{ids[1]}"]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

# Show the two rows before deleting them ...
fields_to_show = ["random", "geospatial"]
result = hello_milvus.query(expr=expr, output_fields=fields_to_show)
first_row, second_row = result[0], result[1]
print(f"query before delete by expr=`{expr}` -> result: \n-{first_row}\n-{second_row}\n")

hello_milvus.delete(expr)

# ... and run the same query again after the delete.
result = hello_milvus.query(expr=expr, output_fields=fields_to_show)
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


###############################################################################
# 7. drop collection
# Finally, drop the hello_milvus collection
print(fmt.format("Drop collection `hello_milvus`"))
utility.drop_collection("hello_milvus")
4 changes: 4 additions & 0 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,10 @@ def get_fields_by_range(
field2data[name] = json_dict_list, field_meta
continue

if dtype == DataType.GEOSPATIAL:
geospatial_data_list = [ data.decode(Config.EncodeProtocol) for data in scalars.geospatial_data.data[start:end] ]
field2data[name] = geospatial_data_list, field_meta

if dtype == DataType.ARRAY:
res = apply_valid_data(
scalars.array_data.data[start:end], field.valid_data, start, end
Expand Down
31 changes: 31 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ def convert_to_json_arr(objs: List[object]):
def entity_to_json_arr(entity: Dict):
return convert_to_json_arr(entity.get("values", []))

def convert_to_wkt_bytes(wktstr: str) -> bytes:
    """Serialize one WKT string into bytes.

    The value is dumped with ``ujson`` (non-ASCII characters are kept as-is
    rather than escaped) and the resulting text is encoded with
    ``Config.EncodeProtocol``.
    """
    serialized = ujson.dumps(wktstr, ensure_ascii=False)
    return serialized.encode(Config.EncodeProtocol)

def convert_to_wkt_bytes_arr(wktarray: List[str]) -> List[bytes]:
    """Serialize every WKT string in ``wktarray`` with ``convert_to_wkt_bytes``.

    The annotation uses ``typing.List`` like the other helpers in this module;
    the builtin ``list[str]`` form is evaluated when the function is defined
    and is not subscriptable on Python 3.8.

    Args:
        wktarray: WKT strings, in the order they should be serialized.

    Returns:
        A new list with one bytes value per input string, in the same order
        (empty when ``wktarray`` is empty).
    """
    return [convert_to_wkt_bytes(wkt) for wkt in wktarray]

def entity_to_wktbyte_arr(entity: Dict):
    """Serialize the WKT strings stored under ``entity["values"]``.

    A missing "values" key is treated as an empty list.
    """
    wkt_values = entity.get("values", [])
    return convert_to_wkt_bytes_arr(wkt_values)

def convert_to_array_arr(objs: List[Any], field_info: Any):
return [convert_to_array(obj, field_info) for obj in objs]
Expand Down Expand Up @@ -385,6 +396,14 @@ def pack_field_value_to_field_data(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "json", type(field_value))
) from e
elif field_type == DataType.GEOSPATIAL:
try:
field_data.scalars.geospatial_data.data.append(convert_to_wkt_bytes(field_value))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
%(field_name,"geospatial",type(field_value))
) from e
elif field_type == DataType.ARRAY:
try:
field_data.scalars.array_data.data.append(convert_to_array(field_value, field_info))
Expand Down Expand Up @@ -513,6 +532,14 @@ def entity_to_field_data(entity: Any, field_info: Any, num_rows: int):
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "json", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.GEOSPATIAL:
try:
field_data.scalars.geospatial_data.data.extend(entity_to_wktbyte_arr(entity))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
%(field_name,"geospatial",type(entity.get("values")[0]))
) from e
elif entity_type == DataType.ARRAY:
try:
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
Expand Down Expand Up @@ -670,6 +697,10 @@ def check_append(field_data: Any):

entity_row_data.update({k: v for k, v in json_dict.items() if k in dynamic_fields})
return
if field_data.type == DataType.GEOSPATIAL and len(field_data.scalars.geospatial_data.data)>=index:
entity_row_data[field_data.field_name] = field_data.scalars.geospatial_data.data[index].decode(Config.EncodeProtocol)
return

if field_data.type == DataType.ARRAY and len(field_data.scalars.array_data.data) >= index:
if len(field_data.valid_data) > 0 and field_data.valid_data[index] is False:
entity_row_data[field_data.field_name] = None
Expand Down
1 change: 1 addition & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class DataType(IntEnum):
VARCHAR = 21
ARRAY = 22
JSON = 23
GEOSPATIAL = 24

BINARY_VECTOR = 100
FLOAT_VECTOR = 101
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def len_of(field_data: Any) -> int:
if field_data.scalars.HasField("json_data"):
return len(field_data.scalars.json_data.data)

if field_data.scalars.HasField("geospatial_data"):
return len(field_data.scalars.geospatial_data.data)

if field_data.scalars.HasField("array_data"):
return len(field_data.scalars.array_data.data)

Expand Down
Loading

0 comments on commit 2376f6a

Please sign in to comment.