Skip to content

Commit

Permalink
Add stored fields for knn_vector type
Browse files Browse the repository at this point in the history
Fixes a bug where a stored field was not created for knn_vector fields
even when the store mapping parameter was passed. Along with this, clean up
the field mapper implementations.

Add relevant unit tests (UTs) and integration tests (ITs) to ensure the functionality works as expected.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Apr 18, 2024
1 parent dc0953a commit 2afd641
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import lombok.Getter;
Expand All @@ -20,6 +21,7 @@
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -52,16 +54,6 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
Expand All @@ -74,19 +66,16 @@
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;

/**
* Field Mapper for KNN vector type.
*
* Extends ParametrizedFieldMapper in order to easily configure mapping parameters.
*
* Implementations of this class define what needs to be stored in Lucene's fieldType. This allows us to have
* alternative mappings for the same field type.
* Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType.
* This allows us to have alternative mappings for the same field type.
*/
@Log4j2
public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper {
Expand All @@ -109,8 +98,8 @@ private static KNNVectorFieldMapper toType(FieldMapper in) {
public static class Builder extends ParametrizedFieldMapper.Builder {
protected Boolean ignoreMalformed;

protected final Parameter<Boolean> stored = Parameter.boolParam("store", false, m -> toType(m).stored, false);
protected final Parameter<Boolean> hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true);
protected final Parameter<Boolean> stored = Parameter.storeParam(m -> toType(m).stored, false);
protected final Parameter<Boolean> hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true);
protected final Parameter<Integer> dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> {
if (o == null) {
throw new IllegalArgumentException("Dimension cannot be null");
Expand Down Expand Up @@ -483,6 +472,11 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S
failIfNoDocValues();
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType);
}

/**
 * Converts the stored form of this field's value into the object handed back for display.
 *
 * @param value stored value; cast to {@link BytesRef} without a prior type check
 * @return whatever {@code deserializeStoredVector} produces for this field's vector data type
 */
@Override
public Object valueForDisplay(Object value) {
    final BytesRef storedBytes = (BytesRef) value;
    return deserializeStoredVector(storedBytes, vectorDataType);
}
}

protected Explicit<Boolean> ignoreMalformed;
Expand Down Expand Up @@ -561,7 +555,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForVectorField(name(), array));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

Expand All @@ -572,7 +568,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForVectorField(name(), array));
}
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down Expand Up @@ -735,11 +733,6 @@ Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, Meth
return Optional.of(array);
}

/**
 * Always returns {@code true}.
 *
 * NOTE(review): presumably tells the parent mapper that doc values are enabled unless the mapping
 * says otherwise - confirm against the overridden ParametrizedFieldMapper method.
 *
 * @return {@code true}, unconditionally
 */
@Override
protected boolean docValuesByDefault() {
return true;
}

@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand All @@ -44,7 +45,6 @@ public class KNNVectorFieldMapperUtil {
*/
public static void validateFP16VectorValue(float value) {
validateFloatVectorValue(value);

if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -136,9 +136,39 @@ public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
return field;
}

public static void addStoredFieldForVectorField(ParseContext context, FieldType fieldType, String mapperName, Field vectorField) {
if (fieldType.stored()) {
context.doc().add(new StoredField(mapperName, vectorField.toString()));
/**
 * Builds the stored field that holds a byte vector.
 *
 * @param name field name
 * @param vector byte vector, passed to the {@link StoredField} constructor unchanged
 * @return a new {@link StoredField} named {@code name} carrying {@code vector}
 */
public static StoredField createStoredFieldForVectorField(String name, byte[] vector) {
    final StoredField storedVectorField = new StoredField(name, vector);
    return storedVectorField;
}

/**
 * Builds the stored field that holds a float vector.
 *
 * The floats are first converted with the default serializer from {@link KNNVectorSerializerFactory};
 * the converted form, not the float array itself, is what the stored field carries.
 *
 * @param name field name
 * @param vector float vector to serialize and store
 * @return a new {@link StoredField} named {@code name} carrying the serialized vector
 */
public static StoredField createStoredFieldForVectorField(String name, float[] vector) {
    final byte[] serializedVector = KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(vector);
    return new StoredField(name, serializedVector);
}

/**
 * Converts a stored vector back into an array representation.
 *
 * For {@link VectorDataType#BYTE} each stored byte is widened to an {@code int} (sign preserved) and
 * the result is an {@code int[]}. Every other data type is delegated to
 * {@code vectorDataType.getVectorFromDocValues(storedVector)}.
 *
 * @param storedVector Vector representation in bytes
 * @param vectorDataType type of vector
 * @return an {@code int[]} for byte vectors; otherwise the value produced by
 *         {@code getVectorFromDocValues} (a {@code float[]} for float vectors)
 */
public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
    if (VectorDataType.BYTE == vectorDataType) {
        // A BytesRef is a slice of its backing array: only [offset, offset + length) belongs to this
        // value. Reading the whole backing array would return unrelated bytes and a wrong dimension
        // whenever the reference points into a larger or shared buffer.
        final byte[] backing = storedVector.bytes;
        final int start = storedVector.offset;
        final int[] byteAsIntArray = new int[storedVector.length];
        Arrays.setAll(byteAsIntArray, i -> backing[start + i]);
        return byteAsIntArray;
    }
    return vectorDataType.getVectorFromDocValues(storedVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.opensearch.knn.index.util.KNNEngine;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;

/**
Expand Down Expand Up @@ -92,7 +92,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForVectorField(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
Expand All @@ -108,7 +110,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForVectorField(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ public class AdvancedFilteringUseCasesIT extends KNNRestTestCase {

private static final String FIELD_NAME_VECTOR = "test_vector";

private static final String PROPERTIES_FIELD = "properties";

private static final String FILTER_FIELD = "filter";

private static final String TERM_FIELD = "term";
Expand Down
Loading

0 comments on commit 2afd641

Please sign in to comment.