Skip to content

Commit

Permalink
Add stored fields for knn_vector type (#1640)
Browse files Browse the repository at this point in the history
Fixes bug where we were not creating stored field type for knn_vector
even when the mapping parameter is passed. Along with this, clean up
the field mapper implementations.

Add relevant uTs and iTs to ensure functionality is working as expected.

(cherry picked from commit 699510d)

Signed-off-by: John Mazanec <[email protected]>
(cherry picked from commit ec9ddb4)
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Apr 24, 2024
1 parent 328f501 commit 17ced5e
Show file tree
Hide file tree
Showing 10 changed files with 309 additions and 51 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;
Expand All @@ -56,7 +56,7 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
return vectorSerializer.byteToFloatArray(byteStream);
Expand All @@ -81,12 +81,12 @@ public float[] getVectorFromDocValues(BytesRef binaryValue) {
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
* Deserializes float vector from doc values binary value.
* Deserializes float vector from BytesRef.
*
* @param binaryValue Binary Value of DocValues
* @param binaryValue Binary Value
* @return float vector deserialized from binary value
*/
public abstract float[] getVectorFromDocValues(BytesRef binaryValue);
public abstract float[] getVectorFromBytesRef(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import lombok.Getter;
Expand All @@ -20,6 +21,7 @@
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -52,16 +54,6 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
Expand All @@ -74,19 +66,17 @@
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;

/**
* Field Mapper for KNN vector type.
*
* Extends ParametrizedFieldMapper in order to easily configure mapping parameters.
*
* Implementations of this class define what needs to be stored in Lucene's fieldType. This allows us to have
* alternative mappings for the same field type.
* Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType.
* This allows us to have alternative mappings for the same field type.
*/
@Log4j2
public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper {
Expand All @@ -109,8 +99,8 @@ private static KNNVectorFieldMapper toType(FieldMapper in) {
public static class Builder extends ParametrizedFieldMapper.Builder {
protected Boolean ignoreMalformed;

protected final Parameter<Boolean> stored = Parameter.boolParam("store", false, m -> toType(m).stored, false);
protected final Parameter<Boolean> hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true);
protected final Parameter<Boolean> stored = Parameter.storeParam(m -> toType(m).stored, false);
protected final Parameter<Boolean> hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true);
protected final Parameter<Integer> dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> {
if (o == null) {
throw new IllegalArgumentException("Dimension cannot be null");
Expand Down Expand Up @@ -483,6 +473,11 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S
failIfNoDocValues();
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType);
}

@Override
public Object valueForDisplay(Object value) {
return deserializeStoredVector((BytesRef) value, vectorDataType);
}
}

protected Explicit<Boolean> ignoreMalformed;
Expand Down Expand Up @@ -561,7 +556,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

Expand All @@ -572,7 +569,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down Expand Up @@ -735,11 +734,6 @@ Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, Meth
return Optional.of(array);
}

@Override
protected boolean docValuesByDefault() {
return true;
}

@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand All @@ -43,7 +45,6 @@ public class KNNVectorFieldMapperUtil {
*/
public static void validateFP16VectorValue(float value) {
validateFloatVectorValue(value);

if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -135,14 +136,39 @@ public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
return field;
}

public static void addStoredFieldForVectorField(
ParseContext context,
FieldType fieldType,
String mapperName,
String vectorFieldAsString
) {
if (fieldType.stored()) {
context.doc().add(new StoredField(mapperName, vectorFieldAsString));
/**
* Creates a stored field for a byte vector
*
* @param name field name
* @param vector vector to be added to stored field
*/
public static StoredField createStoredFieldForByteVector(String name, byte[] vector) {
return new StoredField(name, vector);
}

/**
* Creates a stored field for a float vector
*
* @param name field name
* @param vector vector to be added to stored field
*/
public static StoredField createStoredFieldForFloatVector(String name, float[] vector) {
return new StoredField(name, KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(vector));
}

/**
* @param storedVector Vector representation in bytes
* @param vectorDataType type of vector
* @return either int[] or float[] of corresponding vector
*/
public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
if (VectorDataType.BYTE == vectorDataType) {
byte[] bytes = storedVector.bytes;
int[] byteAsIntArray = new int[bytes.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i]);
return byteAsIntArray;
}

return vectorDataType.getVectorFromBytesRef(storedVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import org.opensearch.knn.index.util.KNNEngine;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;

/**
Expand Down Expand Up @@ -92,7 +93,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
Expand All @@ -108,7 +111,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ public class AdvancedFilteringUseCasesIT extends KNNRestTestCase {

private static final String FIELD_NAME_VECTOR = "test_vector";

private static final String PROPERTIES_FIELD = "properties";

private static final String FILTER_FIELD = "filter";

private static final String TERM_FIELD = "term";
Expand Down
Loading

0 comments on commit 17ced5e

Please sign in to comment.