Skip to content

Commit

Permalink
Update KNN80BinaryDocValues reader count live docs and use live docs …
Browse files Browse the repository at this point in the history
…as initial capacity to initialize vector address

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Apr 8, 2024
1 parent badbb1d commit dae2025
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 14 deletions.
22 changes: 21 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public class KNNSettings {
public static final String KNN_ALGO_PARAM_INDEX_THREAD_QTY = "knn.algo_param.index_thread_qty";
public static final String KNN_MEMORY_CIRCUIT_BREAKER_ENABLED = "knn.memory.circuit_breaker.enabled";
public static final String KNN_MEMORY_CIRCUIT_BREAKER_LIMIT = "knn.memory.circuit_breaker.limit";
public static final String KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB = "knn.vector_streaming_memory.limit";
public static final String KNN_CIRCUIT_BREAKER_TRIGGERED = "knn.circuit_breaker.triggered";
public static final String KNN_CACHE_ITEM_EXPIRY_ENABLED = "knn.cache.item.expiry.enabled";
public static final String KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES = "knn.cache.item.expiry.minutes";
Expand All @@ -93,13 +94,23 @@ public class KNNSettings {
public static final Integer KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // By default, set aside 10% of the JVM for the limit
public static final Integer KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 25; // Model cache limit cannot exceed 25% of the JVM heap
public static final String KNN_DEFAULT_MEMORY_CIRCUIT_BREAKER_LIMIT = "50%";
public static final String KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT = "1%";

public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1;

/**
* Settings Definition
*/

// This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default
// 1% of the JVM heap
public static final Setting<ByteSizeValue> KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting(
KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB,
KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT,
Setting.Property.Dynamic,
Setting.Property.NodeScope
);

public static final Setting<String> INDEX_KNN_SPACE_TYPE = Setting.simpleString(
KNN_SPACE_TYPE,
INDEX_KNN_DEFAULT_SPACE_TYPE,
Expand Down Expand Up @@ -354,6 +365,10 @@ private Setting<?> getSetting(String key) {
return KNN_FAISS_AVX2_DISABLED_SETTING;
}

if (KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB.equals(key)) {
return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -371,7 +386,8 @@ public List<Setting<?>> getSettings() {
MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING,
MODEL_CACHE_SIZE_LIMIT_SETTING,
ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING,
KNN_FAISS_AVX2_DISABLED_SETTING
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING
);
return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList());
}
Expand Down Expand Up @@ -475,6 +491,10 @@ public void onFailure(Exception e) {
});
}

public static ByteSizeValue getVectorStreamingMemoryLimit() {
return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB);
}

/**
*
* @param index Name of the index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.Getter;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -15,10 +16,13 @@
/**
* A per-document kNN numeric value.
*/
class KNN80BinaryDocValues extends BinaryDocValues {
public class KNN80BinaryDocValues extends BinaryDocValues {

private DocIDMerger<BinaryDocValuesSub> docIDMerger;

@Getter
private long totalLiveDocs;

KNN80BinaryDocValues(DocIDMerger<BinaryDocValuesSub> docIdMerger) {
this.docIDMerger = docIdMerger;
}
Expand Down Expand Up @@ -61,4 +65,14 @@ public long cost() {
public BytesRef binaryValue() throws IOException {
return current.getValues().binaryValue();
}
};

/**
* Builder pattern like setter for setting totalLiveDocs. We can use setter also. But this way the code is clean.
* @param totalLiveDocs int
* @return {@link KNN80BinaryDocValues}
*/
public KNN80BinaryDocValues setTotalLiveDocs(long totalLiveDocs) {
this.totalLiveDocs = totalLiveDocs;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
Expand All @@ -14,12 +18,14 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
* Reader for KNNDocValues from the segments
*/
@Log4j2
class KNN80DocValuesReader extends EmptyDocValuesProducer {

private final MergeState mergeState;
Expand All @@ -30,6 +36,7 @@ class KNN80DocValuesReader extends EmptyDocValuesProducer {

@Override
public BinaryDocValues getBinary(FieldInfo field) {
long totalLiveDocs = 0;
try {
List<BinaryDocValuesSub> subs = new ArrayList<>(this.mergeState.docValuesProducers.length);
for (int i = 0; i < this.mergeState.docValuesProducers.length; i++) {
Expand All @@ -41,13 +48,49 @@ public BinaryDocValues getBinary(FieldInfo field) {
values = docValuesProducer.getBinary(readerFieldInfo);
}
if (values != null) {
totalLiveDocs = totalLiveDocs + getLiveDocsCount(values, this.mergeState.liveDocs[i]);
// docValues will be consumed when liveDocs are not null, hence resetting the docsValues
// pointer.
values = this.mergeState.liveDocs[i] != null ? docValuesProducer.getBinary(readerFieldInfo) : values;

subs.add(new BinaryDocValuesSub(mergeState.docMaps[i], values));
}
}
}
return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort));
return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort)).setTotalLiveDocs(totalLiveDocs);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

/**
* This function return the liveDocs count present in the BinaryDocValues. If the liveDocsBits is null, then we
* can use {@link BinaryDocValues#cost()} function to get max docIds. But if LiveDocsBits is not null, then we
* iterate over the BinaryDocValues and validate if the docId is present in the live docs bits or not.
*
* @param binaryDocValues {@link BinaryDocValues}
* @param liveDocsBits {@link Bits}
* @return total number of liveDocs.
* @throws IOException
*/
private long getLiveDocsCount(final BinaryDocValues binaryDocValues, final Bits liveDocsBits) throws IOException {
long liveDocs = 0;
if (liveDocsBits != null) {
int docId;
// This is not the right way to log the time. I create a github issue for adding an annotation to track
// the time. https://github.com/opensearch-project/k-NN/issues/1594
StopWatch stopWatch = new StopWatch();
stopWatch.start();
for (docId = binaryDocValues.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = binaryDocValues.nextDoc()) {
if (liveDocsBits.get(docId)) {
liveDocs++;
}
}
stopWatch.stop();
log.debug("Time taken to iterate over binary doc values: {} ms", stopWatch.totalTime().millis());
} else {
liveDocs = binaryDocValues.cost();
}
return liveDocs;
}
}
46 changes: 36 additions & 10 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KNNCodecUtil {

public static final String HNSW_EXTENSION = ".hnsw";
public static final String HNSW_COMPOUND_EXTENSION = ".hnswc";
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
Expand All @@ -44,28 +44,44 @@ public static final class Pair {
}

public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
ArrayList<float[]> vectorList = new ArrayList<>();
ArrayList<Integer> docIdList = new ArrayList<>();
List<float[]> vectorList = new ArrayList<>();
List<Integer> docIdList = new ArrayList<>();
long vectorAddress = 0;
int dimension = 0;
SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS;

long totalLiveDocs = getTotalLiveDocsCount(values);
long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
long vectorsPerTransfer = Integer.MIN_VALUE;

for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
dimension = vector.length;

if (vectorsPerTransfer == Integer.MIN_VALUE) {
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
}
if (vectorList.size() == vectorsPerTransfer) {
vectorAddress = JNICommons.storeVectorData(
vectorAddress,
vectorList.toArray(new float[][] {}),
totalLiveDocs * dimension
);
// We should probably come up with a better way to reuse the vectorList memory which we have
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
vectorList = new ArrayList<>();
}
vectorList.add(vector);
}
docIdList.add(doc);
}
if (vectorList.isEmpty() == false) {
vectorAddress = JNICommons.storeVectorData(
vectorAddress,
vectorList.toArray(new float[][] {}),
(long) vectorList.size() * dimension
);
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
}
return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorAddress, dimension, serializationMode);
}
Expand Down Expand Up @@ -105,4 +121,14 @@ public static String buildEngineFilePrefix(String segmentName) {
public static String buildEngineFileSuffix(String fieldName, String extension) {
return String.format("_%s%s", fieldName, extension);
}

private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) {
long totalLiveDocs;
if (binaryDocValues instanceof KNN80BinaryDocValues) {
totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs();
} else {
totalLiveDocs = binaryDocValues.cost();
}
return totalLiveDocs;
}
}
104 changes: 104 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -173,6 +175,108 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {
fail("Graphs are not getting evicted");
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() {
String indexName = "test-index-1";
String fieldName = "test-field-1";

KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
SpaceType spaceType = SpaceType.L2;

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);

Integer dimension = testData.indexData.vectors[0].length;

// Create an index
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field("dimension", dimension)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(KNNConstants.PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(indexName, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));

// Index the test data
for (int i = 0; i < testData.indexData.docs.length; i++) {
addKnnDoc(
indexName,
Integer.toString(testData.indexData.docs[i]),
fieldName,
Floats.asList(testData.indexData.vectors[i]).toArray()
);
}

// Assert we have the right number of documents in the index
refreshAllNonSystemIndices();
assertEquals(testData.indexData.docs.length, getDocCount(indexName));

final Set<Integer> docIdsToBeDeleted = new HashSet<>();
while (docIdsToBeDeleted.size() < 10) {
docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length));
}

for (Integer id : docIdsToBeDeleted) {
deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id]));
}
refreshAllNonSystemIndices();
forceMergeKnnIndex(indexName, 3);

assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName));

int k = 10;
for (int i = 0; i < testData.queries.length; i++) {
Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
0.0001
);
}
}

// Delete index
deleteKNNIndex(indexName);

// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
for (int i = 0; i < intervals; i++) {
if (getTotalGraphsInCache() == 0) {
return;
}

Thread.sleep(5 * 1000);
}

fail("Graphs are not getting evicted");
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
Expand Down
Loading

0 comments on commit dae2025

Please sign in to comment.