From dae2025ac97a5ea7d255cf985f8a4e1545aff855 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Fri, 5 Apr 2024 11:59:12 -0700 Subject: [PATCH] Update KNN80BinaryDocValues reader count live docs and use live docs as initial capacity to initialize vector address Signed-off-by: Navneet Verma --- .../org/opensearch/knn/index/KNNSettings.java | 22 +++- .../KNN80Codec/KNN80BinaryDocValues.java | 18 ++- .../KNN80Codec/KNN80DocValuesReader.java | 45 +++++++- .../knn/index/codec/util/KNNCodecUtil.java | 46 ++++++-- .../org/opensearch/knn/index/FaissIT.java | 104 ++++++++++++++++++ .../knn/index/codec/KNNCodecServiceTests.java | 1 + 6 files changed, 222 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 04e50ed9b..88a396f44 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -68,6 +68,7 @@ public class KNNSettings { public static final String KNN_ALGO_PARAM_INDEX_THREAD_QTY = "knn.algo_param.index_thread_qty"; public static final String KNN_MEMORY_CIRCUIT_BREAKER_ENABLED = "knn.memory.circuit_breaker.enabled"; public static final String KNN_MEMORY_CIRCUIT_BREAKER_LIMIT = "knn.memory.circuit_breaker.limit"; + public static final String KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB = "knn.vector_streaming_memory.limit"; public static final String KNN_CIRCUIT_BREAKER_TRIGGERED = "knn.circuit_breaker.triggered"; public static final String KNN_CACHE_ITEM_EXPIRY_ENABLED = "knn.cache.item.expiry.enabled"; public static final String KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES = "knn.cache.item.expiry.minutes"; @@ -93,6 +94,7 @@ public class KNNSettings { public static final Integer KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // By default, set aside 10% of the JVM for the limit public static final Integer KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 25; // Model cache limit cannot exceed 25% of the JVM heap public static final String KNN_DEFAULT_MEMORY_CIRCUIT_BREAKER_LIMIT = "50%"; + public static final String KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT = "1%"; public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1; @@ -100,6 +102,15 @@ public class KNNSettings { * Settings Definition */ + // This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default + // 1% of the JVM heap + public static final Setting KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting( + KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB, + KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + public static final Setting INDEX_KNN_SPACE_TYPE = Setting.simpleString( KNN_SPACE_TYPE, INDEX_KNN_DEFAULT_SPACE_TYPE, @@ -354,6 +365,10 @@ private Setting getSetting(String key) { return KNN_FAISS_AVX2_DISABLED_SETTING; } + if (KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB.equals(key)) { + return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -371,7 +386,8 @@ public List> getSettings() { MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, MODEL_CACHE_SIZE_LIMIT_SETTING, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, - KNN_FAISS_AVX2_DISABLED_SETTING + KNN_FAISS_AVX2_DISABLED_SETTING, + KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING ); return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); } @@ -475,6 +491,10 @@ public void onFailure(Exception e) { }); } + public static ByteSizeValue getVectorStreamingMemoryLimit() { + return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB); + } + /** * * @param index Name of the index diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java index 832737a6d..df26766b3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; +import lombok.Getter; import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; @@ -15,10 +16,13 @@ /** * A per-document kNN numeric value. */ -class KNN80BinaryDocValues extends BinaryDocValues { +public class KNN80BinaryDocValues extends BinaryDocValues { private DocIDMerger docIDMerger; + @Getter + private long totalLiveDocs; + KNN80BinaryDocValues(DocIDMerger docIdMerger) { this.docIDMerger = docIdMerger; } @@ -61,4 +65,14 @@ public long cost() { public BytesRef binaryValue() throws IOException { return current.getValues().binaryValue(); } -}; + + /** + * Builder pattern like setter for setting totalLiveDocs. We can use setter also. But this way the code is clean. + * @param totalLiveDocs int + * @return {@link KNN80BinaryDocValues} + */ + public KNN80BinaryDocValues setTotalLiveDocs(long totalLiveDocs) { + this.totalLiveDocs = totalLiveDocs; + return this; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java index ccfaa68fc..16380c5d9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java @@ -5,6 +5,10 @@ package org.opensearch.knn.index.codec.KNN80Codec; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.Bits; +import org.opensearch.common.StopWatch; import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; @@ -14,12 +18,14 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; +import java.io.IOException; import java.util.ArrayList; import java.util.List; /** * Reader for KNNDocValues from the segments */ +@Log4j2 class KNN80DocValuesReader extends EmptyDocValuesProducer { private final MergeState mergeState; @@ -30,6 +36,7 @@ class KNN80DocValuesReader extends EmptyDocValuesProducer { @Override public BinaryDocValues getBinary(FieldInfo field) { + long totalLiveDocs = 0; try { List subs = new ArrayList<>(this.mergeState.docValuesProducers.length); for (int i = 0; i < this.mergeState.docValuesProducers.length; i++) { @@ -41,13 +48,49 @@ public BinaryDocValues getBinary(FieldInfo field) { values = docValuesProducer.getBinary(readerFieldInfo); } if (values != null) { + totalLiveDocs = totalLiveDocs + getLiveDocsCount(values, this.mergeState.liveDocs[i]); + // docValues will be consumed when liveDocs are not null, hence resetting the docsValues + // pointer. + values = this.mergeState.liveDocs[i] != null ? docValuesProducer.getBinary(readerFieldInfo) : values; + subs.add(new BinaryDocValuesSub(mergeState.docMaps[i], values)); } } } - return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort)); + return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort)).setTotalLiveDocs(totalLiveDocs); } catch (Exception e) { throw new RuntimeException(e); } } + + /** + * This function return the liveDocs count present in the BinaryDocValues. If the liveDocsBits is null, then we + * can use {@link BinaryDocValues#cost()} function to get max docIds. But if LiveDocsBits is not null, then we + * iterate over the BinaryDocValues and validate if the docId is present in the live docs bits or not. + * + * @param binaryDocValues {@link BinaryDocValues} + * @param liveDocsBits {@link Bits} + * @return total number of liveDocs. + * @throws IOException + */ + private long getLiveDocsCount(final BinaryDocValues binaryDocValues, final Bits liveDocsBits) throws IOException { + long liveDocs = 0; + if (liveDocsBits != null) { + int docId; + // This is not the right way to log the time. I create a github issue for adding an annotation to track + // the time. https://github.com/opensearch-project/k-NN/issues/1594 + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + for (docId = binaryDocValues.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = binaryDocValues.nextDoc()) { + if (liveDocsBits.get(docId)) { + liveDocs++; + } + } + stopWatch.stop(); + log.debug("Time taken to iterate over binary doc values: {} ms", stopWatch.totalTime().millis()); + } else { + liveDocs = binaryDocValues.cost(); + } + return liveDocs; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index eecec6b50..c5ae469e0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -11,16 +11,16 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.jni.JNICommons; import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.List; public class KNNCodecUtil { - - public static final String HNSW_EXTENSION = ".hnsw"; - public static final String HNSW_COMPOUND_EXTENSION = ".hnswc"; // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; // References to objects are 4 bytes in size @@ -44,11 +44,16 @@ public static final class Pair { } public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { - ArrayList vectorList = new ArrayList<>(); - ArrayList docIdList = new ArrayList<>(); + List vectorList = new ArrayList<>(); + List docIdList = new ArrayList<>(); long vectorAddress = 0; int dimension = 0; SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; + + long totalLiveDocs = getTotalLiveDocsCount(values); + long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes(); + long vectorsPerTransfer = Integer.MIN_VALUE; + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { @@ -56,16 +61,27 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); dimension = vector.length; + + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + } + if (vectorList.size() == vectorsPerTransfer) { + vectorAddress = JNICommons.storeVectorData( + vectorAddress, + vectorList.toArray(new float[][] {}), + totalLiveDocs * dimension + ); + // We should probably come up with a better way to reuse the vectorList memory which we have + // created. Problem here is doing like this can lead to a lot of list memory which is of no use and + // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. + vectorList = new ArrayList<>(); + } vectorList.add(vector); } docIdList.add(doc); } if (vectorList.isEmpty() == false) { - vectorAddress = JNICommons.storeVectorData( - vectorAddress, - vectorList.toArray(new float[][] {}), - (long) vectorList.size() * dimension - ); + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); } return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorAddress, dimension, serializationMode); } @@ -105,4 +121,14 @@ public static String buildEngineFilePrefix(String segmentName) { public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } + + private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { + long totalLiveDocs; + if (binaryDocValues instanceof KNN80BinaryDocValues) { + totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); + } else { + totalLiveDocs = binaryDocValues.cost(); + } + return totalLiveDocs; + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 0cec3810e..85dd3f169 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -35,10 +35,12 @@ import java.io.IOException; import java.net.URL; import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; @@ -173,6 +175,108 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final Set docIdsToBeDeleted = new HashSet<>(); + while (docIdsToBeDeleted.size() < 10) { + docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length)); + } + + for (Integer id : docIdsToBeDeleted) { + deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); + } + refreshAllNonSystemIndices(); + forceMergeKnnIndex(indexName, 3); + + assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); + + int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java index 233b9adf7..dfe4e7f22 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java @@ -36,6 +36,7 @@ public void setUp() throws Exception { super.setUp(); IndexMetadata indexMetadata = mock(IndexMetadata.class); when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString())); + when(indexMetadata.getCustomData(IndexMetadata.REMOTE_STORE_CUSTOM_KEY)).thenReturn(null); when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY); Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(NUM_OF_SHARDS)).build(); indexSettings = new IndexSettings(indexMetadata, settings);