Skip to content

Commit

Permalink
Support clustering key in create collection (#1969)
Browse files Browse the repository at this point in the history
Signed-off-by: wayblink <[email protected]>
  • Loading branch information
wayblink authored Mar 14, 2024
1 parent 87691ab commit bf21bf3
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, raw: Any):
self.is_dynamic = False
# For array field
self.element_type = None
self.is_clustering_key = False
##
self.__pack(self._raw)

Expand All @@ -40,6 +41,7 @@ def __pack(self, raw: Any):
self.type = DataType(raw.data_type)
self.is_partition_key = raw.is_partition_key
self.element_type = DataType(raw.element_type)
self.is_clustering_key = raw.is_clustering_key
try:
self.is_dynamic = raw.is_dynamic
except Exception:
Expand Down Expand Up @@ -95,6 +97,8 @@ def dict(self):
_dict["auto_id"] = True
if self.is_primary:
_dict["is_primary"] = self.is_primary
if self.is_clustering_key:
_dict["is_clustering_key"] = True
return _dict


Expand Down
2 changes: 2 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def get_schema_from_collection_schema(
is_partition_key=f.is_partition_key,
is_dynamic=f.is_dynamic,
element_type=f.element_type,
is_clustering_key=f.is_clustering_key,
)
for k, v in f.params.items():
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
Expand Down Expand Up @@ -175,6 +176,7 @@ def get_field_schema(
is_primary_key=is_primary,
autoID=auto_id,
is_partition_key=field.get("is_partition_key", False),
is_clustering_key=field.get("is_clustering_key", False),
)

type_params = field.get("params", {})
Expand Down
16 changes: 15 additions & 1 deletion pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ class PartitionKeyException(MilvusException):
"""Raise when partitionkey are invalid"""


class ClusteringKeyException(MilvusException):
"""Raise when clusteringkey are invalid"""


class FieldsTypeException(MilvusException):
"""Raise when fields is invalid"""

Expand Down Expand Up @@ -185,7 +189,7 @@ class ExceptionsMessage:
PartitionKeyOnlyOne = "Expected only one partition key field, got [%s, %s, ...]."
PrimaryKeyType = "Primary key type must be DataType.INT64 or DataType.VARCHAR."
PartitionKeyType = "Partition key field type must be DataType.INT64 or DataType.VARCHAR."
PartitionKeyNotPrimary = "Primary key filed should not be primary field"
PartitionKeyNotPrimary = "Partition key field should not be primary field"
IsPrimaryType = "Param is_primary must be bool type."
PrimaryFieldType = "Param primary_field must be int or str type."
PartitionKeyFieldType = "Param partition_key_field must be str type."
Expand Down Expand Up @@ -219,3 +223,13 @@ class ExceptionsMessage:
"Ambiguous parameter, either ids or filter should be specified, cannot support both."
)
JSONKeyMustBeStr = "JSON key must be str."
ClusteringKeyNotPrimary = "Clustering key field should not be primary field"
ClusteringKeyType = (
"Clustering key field type must be DataType.INT8, DataType.INT16, "
"DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE, "
"DataType.VARCHAR, DataType.FLOAT_VECTOR."
)
ClusteringKeyFieldNotExist = "the specified clustering key field {%s} not exist"
ClusteringKeyOnlyOne = "Expected only one clustering key field, got [%s, %s, ...]."
IsClusteringKeyType = "Param is_clustering_key must be bool type."
ClusteringKeyFieldType = "Param clustering_key_field must be str type."
54 changes: 54 additions & 0 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pymilvus.exceptions import (
AutoIDException,
CannotInferSchemaException,
ClusteringKeyException,
DataNotMatchException,
DataTypeNotSupportException,
ExceptionsMessage,
Expand Down Expand Up @@ -62,6 +63,29 @@ def validate_partition_key(
)


def validate_clustering_key(
clustering_key_field_name: Any, clustering_key_field: Any, primary_field_name: Any
):
if clustering_key_field is not None:
if clustering_key_field.name == primary_field_name:
raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyNotPrimary)
if clustering_key_field.dtype not in [
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.FLOAT,
DataType.DOUBLE,
DataType.VARCHAR,
DataType.FLOAT_VECTOR,
]:
raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyType)
elif clustering_key_field_name is not None:
raise ClusteringKeyException(
message=ExceptionsMessage.PartitionKeyFieldNotExist % clustering_key_field_name
)


class CollectionSchema:
def __init__(self, fields: List, description: str = "", **kwargs):
self._kwargs = copy.deepcopy(kwargs)
Expand All @@ -71,6 +95,7 @@ def __init__(self, fields: List, description: str = "", **kwargs):
self._enable_dynamic_field = self._kwargs.get("enable_dynamic_field", None)
self._primary_field = None
self._partition_key_field = None
self._clustering_key_field = None

if not isinstance(fields, list):
raise FieldsTypeException(message=ExceptionsMessage.FieldsType)
Expand All @@ -83,10 +108,13 @@ def __init__(self, fields: List, description: str = "", **kwargs):
def _check_kwargs(self):
primary_field_name = self._kwargs.get("primary_field", None)
partition_key_field_name = self._kwargs.get("partition_key_field", None)
clustering_key_field_name = self._kwargs.get("clustering_key_field_name", None)
if primary_field_name is not None and not isinstance(primary_field_name, str):
raise PrimaryKeyException(message=ExceptionsMessage.PrimaryFieldType)
if partition_key_field_name is not None and not isinstance(partition_key_field_name, str):
raise PartitionKeyException(message=ExceptionsMessage.PartitionKeyFieldType)
if clustering_key_field_name is not None and not isinstance(clustering_key_field_name, str):
raise ClusteringKeyException(message=ExceptionsMessage.ClusteringKeyFieldType)

for field in self._fields:
if not isinstance(field, FieldSchema):
Expand All @@ -98,6 +126,7 @@ def _check_kwargs(self):
def _check_fields(self):
primary_field_name = self._kwargs.get("primary_field", None)
partition_key_field_name = self._kwargs.get("partition_key_field", None)
clustering_key_field_name = self._kwargs.get("clustering_key_field", None)
for field in self._fields:
if primary_field_name and primary_field_name == field.name:
field.is_primary = True
Expand All @@ -122,10 +151,29 @@ def _check_fields(self):
self._partition_key_field = field
partition_key_field_name = field.name

if clustering_key_field_name and clustering_key_field_name == field.name:
field.is_clustering_key = True

if field.is_clustering_key:
if (
clustering_key_field_name is not None
and clustering_key_field_name != field.name
):
msg = ExceptionsMessage.ClusteringKeyOnlyOne % (
clustering_key_field_name,
field.name,
)
raise ClusteringKeyException(message=msg)
self._clustering_key_field = field
clustering_key_field_name = field.name

validate_primary_key(self._primary_field)
validate_partition_key(
partition_key_field_name, self._partition_key_field, self._primary_field.name
)
validate_clustering_key(
clustering_key_field_name, self._clustering_key_field, self._primary_field.name
)

auto_id = self._kwargs.get("auto_id", False)
if auto_id:
Expand Down Expand Up @@ -272,7 +320,10 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)

if not isinstance(kwargs.get("is_partition_key", False), bool):
raise PartitionKeyException(message=ExceptionsMessage.IsPartitionKeyType)
if not isinstance(kwargs.get("is_clustering_key", False), bool):
raise ClusteringKeyException(message=ExceptionsMessage.IsClusteringKeyType)
self.is_partition_key = kwargs.get("is_partition_key", False)
self.is_clustering_key = kwargs.get("is_clustering_key", False)
self.element_type = kwargs.get("element_type", None)
self._parse_type_params()

Expand Down Expand Up @@ -313,6 +364,7 @@ def construct_from_dict(cls, raw: Dict):
if raw.get("auto_id") is not None:
kwargs["auto_id"] = raw.get("auto_id")
kwargs["is_partition_key"] = raw.get("is_partition_key", False)
kwargs["is_clustering_key"] = raw.get("is_clustering_key", False)
kwargs["is_dynamic"] = raw.get("is_dynamic", False)
kwargs["element_type"] = raw.get("element_type")
return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs)
Expand All @@ -334,6 +386,8 @@ def to_dict(self):
_dict["is_dynamic"] = self.is_dynamic
if self.dtype == DataType.ARRAY and self.element_type:
_dict["element_type"] = self.element_type
if self.is_clustering_key:
_dict["is_clustering_key"] = True
return _dict

def __getattr__(self, item: str):
Expand Down

0 comments on commit bf21bf3

Please sign in to comment.