From 2afd64105575308754407588002bf801d331d01c Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Thu, 18 Apr 2024 15:33:49 -0700 Subject: [PATCH] Add stored fields for knn_vector type Fixes bug where we were not creating stored field type for knn_vector even when the mapping parameter is passed. Along with this, clean up the field mapper implementations. Add relevant UTs and ITs to ensure functionality is working as expected. Signed-off-by: John Mazanec --- .../index/mapper/KNNVectorFieldMapper.java | 45 ++--- .../mapper/KNNVectorFieldMapperUtil.java | 42 +++- .../knn/index/mapper/LuceneFieldMapper.java | 10 +- .../index/AdvancedFilteringUseCasesIT.java | 2 - .../knn/index/KNNMapperSearcherIT.java | 182 +++++++++++++++++- .../mapper/KNNVectorFieldMapperUtilTests.java | 54 ++++++ .../org/opensearch/knn/KNNRestTestCase.java | 4 + 7 files changed, 299 insertions(+), 40 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 0fa026f343..332de28529 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,6 +11,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; import lombok.Getter; @@ -20,6 +21,7 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.Nullable; @@ -52,16 +54,6 @@ import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; 
-import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.function.Supplier; - import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; @@ -74,19 +66,16 @@ import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForVectorField; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; /** - * Field Mapper for KNN vector type. - * - * Extends ParametrizedFieldMapper in order to easily configure mapping parameters. - * - * Implementations of this class define what needs to be stored in Lucene's fieldType. This allows us to have - * alternative mappings for the same field type. + * Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType. + * This allows us to have alternative mappings for the same field type. 
*/ @Log4j2 public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper { @@ -109,8 +98,8 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; - protected final Parameter stored = Parameter.boolParam("store", false, m -> toType(m).stored, false); - protected final Parameter hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true); + protected final Parameter stored = Parameter.storeParam(m -> toType(m).stored, false); + protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); protected final Parameter dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> { if (o == null) { throw new IllegalArgumentException("Dimension cannot be null"); @@ -483,6 +472,11 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S failIfNoDocValues(); return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType); } + + @Override + public Object valueForDisplay(Object value) { + return deserializeStoredVector((BytesRef) value, vectorDataType); + } } protected Explicit ignoreMalformed; @@ -561,7 +555,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s VectorField point = new VectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point); + if (this.stored) { + context.doc().add(createStoredFieldForVectorField(name(), array)); + } } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); @@ -572,7 +568,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s spaceType.validateVector(array); VectorField point = new VectorField(name(), array, fieldType); 
context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point); + if (this.stored) { + context.doc().add(createStoredFieldForVectorField(name(), array)); + } } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) @@ -735,11 +733,6 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth return Optional.of(array); } - @Override - protected boolean docValuesByDefault() { - return true; - } - @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 074be03757..5673cf520b 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -13,15 +13,16 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.util.BytesRef; import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; +import java.util.Arrays; import java.util.Locale; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -44,7 +45,6 @@ public class KNNVectorFieldMapperUtil { */ public static void validateFP16VectorValue(float value) { validateFloatVectorValue(value); - if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) { throw new 
IllegalArgumentException( String.format( @@ -136,9 +136,39 @@ public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) { return field; } - public static void addStoredFieldForVectorField(ParseContext context, FieldType fieldType, String mapperName, Field vectorField) { - if (fieldType.stored()) { - context.doc().add(new StoredField(mapperName, vectorField.toString())); + /** + * Creates a stored field for a byte vector + * + * @param name field name + * @param vector vector to be added to stored field + */ + public static StoredField createStoredFieldForVectorField(String name, byte[] vector) { + return new StoredField(name, vector); + } + + /** + * Creates a stored field for a float vector + * + * @param name field name + * @param vector vector to be added to stored field + */ + public static StoredField createStoredFieldForVectorField(String name, float[] vector) { + return new StoredField(name, KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(vector)); + } + + /** + * @param storedVector Vector representation in bytes + * @param vectorDataType type of vector + * @return either int[] or float[] of corresponding vector + */ + public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) { + if (VectorDataType.BYTE == vectorDataType) { + byte[] bytes = storedVector.bytes; + int[] byteAsIntArray = new int[bytes.length]; + Arrays.setAll(byteAsIntArray, i -> bytes[i]); + return byteAsIntArray; + } else { + return vectorDataType.getVectorFromDocValues(storedVector); } } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index d61fa11503..223c96ee7a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -25,7 +25,7 @@ import org.opensearch.knn.index.util.KNNEngine; import static 
org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForVectorField; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; /** @@ -92,7 +92,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point); + if (this.stored) { + context.doc().add(createStoredFieldForVectorField(name(), array)); + } if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); @@ -108,7 +110,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s KnnVectorField point = new KnnVectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point); + if (this.stored) { + context.doc().add(createStoredFieldForVectorField(name(), array)); + } if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); diff --git a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java index 46b5590fc4..b786ee8731 100644 --- a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java +++ b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java @@ -54,8 +54,6 @@ public class AdvancedFilteringUseCasesIT extends KNNRestTestCase { private static final String FIELD_NAME_VECTOR = "test_vector"; - private static final String PROPERTIES_FIELD = "properties"; - private static final String FILTER_FIELD = "filter"; private static final String TERM_FIELD = "term"; diff --git 
a/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java b/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java index e56560c0de..3e95c6463e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java @@ -5,20 +5,34 @@ package org.opensearch.knn.index; +import lombok.SneakyThrows; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.client.Response; import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.util.KNNEngine; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.QUERY; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + public class KNNMapperSearcherIT extends KNNRestTestCase { - private static final Logger logger = LogManager.getLogger(KNNMapperSearcherIT.class); + + private static final String INDEX_NAME = "test_index"; + private static final String FIELD_NAME = "test_vector"; /** * Test Data set @@ -239,4 +253,166 @@ public void testLargeK() throws Exception { List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); 
assertEquals(results.size(), 4); } + + /** + * Request: + * { + * "stored_fields": ["test_vector"], + * "query": { + * "match_all": {} + * } + * } + * + * Example Response: + * { + * "took":248, + * "timed_out":false, + * "_shards":{ + * "total":1, + * "successful":1, + * "skipped":0, + * "failed":0 + * }, + * "hits":{ + * "total":{ + * "value":1, + * "relation":"eq" + * }, + * "max_score":1.0, + * "hits":[ + * { + * "_index":"test_index", + * "_id":"1", + * "_score":1.0, + * "fields":{"test_vector":[[-128,0,1,127]]} + * } + * ] + * } + * } + */ + @SneakyThrows + public void testStoredFields_whenByteDataType_thenSucceed() { + // Create index with stored field and confirm that we can properly retrieve it + byte[] testVector = new byte[] { -128, 0, 1, 127 }; + String expectedResponse = String.format("\"fields\":{\"%s\":[[-128,0,1,127]]}}", FIELD_NAME); + createKnnIndex( + INDEX_NAME, + createVectorMapping(testVector.length, KNNEngine.LUCENE.getName(), VectorDataType.BYTE.getValue(), true) + ); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, testVector); + + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(STORED_QUERY_FIELD, List.of(FIELD_NAME)); + builder.startObject(QUERY); + builder.startObject(MATCH_ALL_QUERY_FIELD); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + String response = EntityUtils.toString(performSearch(INDEX_NAME, builder.toString()).getEntity()); + assertTrue(response.contains(expectedResponse)); + + deleteKNNIndex(INDEX_NAME); + } + + /** + * Request: + * { + * "stored_fields": ["test_vector"], + * "query": { + * "match_all": {} + * } + * } + * + * Example Response: + * { + * "took":248, + * "timed_out":false, + * "_shards":{ + * "total":1, + * "successful":1, + * "skipped":0, + * "failed":0 + * }, + * "hits":{ + * "total":{ + * "value":1, + * "relation":"eq" + * }, + * "max_score":1.0, + * "hits":[ + * { + * "_index":"test_index", + * "_id":"1", + * "_score":1.0, + * 
"fields":{"test_vector":[[-100.0,100.0,0.0,1.0]]} + * } + * ] + * } + * } + */ + @SneakyThrows + public void testStoredFields_whenFloatDataType_thenSucceed() { + List enginesToTest = List.of(KNNEngine.NMSLIB, KNNEngine.FAISS, KNNEngine.LUCENE); + float[] testVector = new float[] { -100.0f, 100.0f, 0f, 1f }; + String expectedResponse = String.format("\"fields\":{\"%s\":[[-100.0,100.0,0.0,1.0]]}}", FIELD_NAME); + for (KNNEngine knnEngine : enginesToTest) { + createKnnIndex(INDEX_NAME, createVectorMapping(testVector.length, knnEngine.getName(), VectorDataType.FLOAT.getValue(), true)); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, testVector); + + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(STORED_QUERY_FIELD, List.of(FIELD_NAME)); + builder.startObject(QUERY); + builder.startObject(MATCH_ALL_QUERY_FIELD); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + String response = EntityUtils.toString(performSearch(INDEX_NAME, builder.toString()).getEntity()); + assertTrue(response.contains(expectedResponse)); + + deleteKNNIndex(INDEX_NAME); + } + } + + /** + * Mapping + * { + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": {dimension}, + * "data_type": "{type}", + * "store": {isStored}, + * "method": { + * "name": "hnsw", + * "engine": "{engine}" + * } + * } + * } + * } + */ + @SneakyThrows + private String createVectorMapping(final int dimension, final String engine, final String dataType, final boolean isStored) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .field(VECTOR_DATA_TYPE_FIELD, dataType) + .field(STORE_FIELD, isStored) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, engine) + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } + } diff 
--git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java new file mode 100644 index 0000000000..fe8ea8e368 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.StoredField; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; + +import java.io.ByteArrayInputStream; +import java.util.Arrays; + +public class KNNVectorFieldMapperUtilTests extends KNNTestCase { + + private static final String TEST_FIELD_NAME = "test_field_name"; + private static final byte[] TEST_BYTE_VECTOR = new byte[] { -128, 0, 1, 127 }; + private static final float[] TEST_FLOAT_VECTOR = new float[] { -100.0f, 100.0f, 0f, 1f }; + + public void testStoredFields_whenVectorIsByteType_thenSucceed() { + StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForVectorField(TEST_FIELD_NAME, TEST_BYTE_VECTOR); + assertEquals(TEST_FIELD_NAME, storedField.name()); + assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes); + Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BYTE); + assertTrue(vector instanceof int[]); + int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length]; + Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]); + assertArrayEquals(byteAsIntArray, (int[]) vector); + } + + public void testStoredFields_whenVectorIsFloatType_thenSucceed() { + 
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForVectorField(TEST_FIELD_NAME, TEST_FLOAT_VECTOR); + assertEquals(TEST_FIELD_NAME, storedField.name()); + byte[] bytes = storedField.binaryValue().bytes; + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes, 0, bytes.length); + assertArrayEquals( + TEST_FLOAT_VECTOR, + KNNVectorSerializerFactory.getDefaultSerializer().byteToFloatArray(byteArrayInputStream), + 0.001f + ); + + Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.FLOAT); + assertTrue(vector instanceof float[]); + assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 396c8ea646..5010ff6ee3 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -107,6 +107,10 @@ public class KNNRestTestCase extends ODFERestTestCase { public static final String INDEX_NAME = "test_index"; public static final String FIELD_NAME = "test_field"; + public static final String PROPERTIES_FIELD = "properties"; + public static final String STORE_FIELD = "store"; + public static final String STORED_QUERY_FIELD = "stored_fields"; + public static final String MATCH_ALL_QUERY_FIELD = "match_all"; private static final String DOCUMENT_FIELD_SOURCE = "_source"; private static final String DOCUMENT_FIELD_FOUND = "found"; protected static final int DELAY_MILLI_SEC = 1000;