Skip to content

Commit

Permalink
Add Range Validation for SQFP16 (#1493)
Browse files Browse the repository at this point in the history
* Add Range Validation for SQFP16 Vector Data

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add index setting to clip vector data to FP16 range

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add CHANGELOG

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add an encoder parameter to clip fp16 range

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add BWC Tests

Signed-off-by: Naveen Tatikonda <[email protected]>

---------

Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda authored Mar 14, 2024
1 parent 7a144b8 commit 8cad13b
Show file tree
Hide file tree
Showing 9 changed files with 697 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Add Range Validation for Faiss SQFP16 [#1493](https://github.com/opensearch-project/k-NN/pull/1493)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
### Infrastructure
Expand Down
319 changes: 319 additions & 0 deletions qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.bwc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethod;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.NAME;

public class FaissSQIT extends AbstractRestartUpgradeTestCase {
private static final String TEST_FIELD = "test-field";
private static final String TRAIN_TEST_FIELD = "train-test-field";
private static final String TRAIN_INDEX = "train-index";
private static final String TEST_MODEL = "test-model";
private static final int DIMENSION = 128;
private static final int NUM_DOCS = 100;

public void testHNSWSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Exception {
if (!isRunningAgainstOldCluster()) {
KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT };
Random random = new Random();
SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)];

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);

// Create an index
/**
* "properties": {
* "test-field": {
* "type": "knn_vector",
* "dimension": 128,
* "method": {
* "name": "hnsw",
* "space_type": "l2",
* "engine": "faiss",
* "parameters": {
* "m": 16,
* "ef_construction": 128,
* "ef_search": 128,
* "encoder": {
* "name": "sq",
* "parameters": {
* "type": "fp16"
* }
* }
* }
* }
* }
* }
*/
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(TEST_FIELD)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(KNNConstants.PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
.field(
KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION,
efConstructionValues.get(random().nextInt(efConstructionValues.size()))
)
.field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(testIndex, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(testIndex)));
indexTestData(testIndex, TEST_FIELD, DIMENSION, NUM_DOCS);
queryTestData(testIndex, TEST_FIELD, DIMENSION, NUM_DOCS);
deleteKNNIndex(testIndex);
validateGraphEviction();
}
}

public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenSucceed() throws Exception {
if (!isRunningAgainstOldCluster()) {
KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
new Random();

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);

int dimension = 2;

// Create an index
/**
* "properties": {
* "test-field": {
* "type": "knn_vector",
* "dimension": 128,
* "method": {
* "name": "hnsw",
* "space_type": "l2",
* "engine": "faiss",
* "parameters": {
* "m": 16,
* "ef_construction": 128,
* "ef_search": 128,
* "encoder": {
* "name": "sq",
* "parameters": {
* "type": "fp16",
* "clip": true
* }
* }
* }
* }
* }
* }
*/
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(TEST_FIELD)
.field("type", "knn_vector")
.field("dimension", dimension)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
.field(
KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION,
efConstructionValues.get(random().nextInt(efConstructionValues.size()))
)
.field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
.field(FAISS_SQ_CLIP, true)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(testIndex, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(testIndex)));
Float[] vector1 = { -65523.76f, 65504.2f };
Float[] vector2 = { -270.85f, 65514.2f };
Float[] vector3 = { -150.9f, 65504.0f };
Float[] vector4 = { -20.89f, 100000000.0f };
addKnnDoc(testIndex, "1", TEST_FIELD, vector1);
addKnnDoc(testIndex, "2", TEST_FIELD, vector2);
addKnnDoc(testIndex, "3", TEST_FIELD, vector3);
addKnnDoc(testIndex, "4", TEST_FIELD, vector4);

float[] queryVector = { -10.5f, 25.48f };
int k = 4;
Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), TEST_FIELD);
assertEquals(k, results.size());
for (int i = 0; i < k; i++) {
assertEquals(k - i, Integer.parseInt(results.get(i).getDocId()));
}
deleteKNNIndex(testIndex);
validateGraphEviction();
}
}

public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Exception {
if (!isRunningAgainstOldCluster()) {

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION);
int trainingDataCount = 200;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, DIMENSION);

XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, METHOD_IVF)
.field(KNN_ENGINE, FAISS_NAME)
.field(METHOD_PARAMETER_SPACE_TYPE, "l2")
.startObject(PARAMETERS)
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
.endObject()
.endObject()
.endObject()
.endObject();
Map<String, Object> method = xContentBuilderToMap(builder);

trainModel(TEST_MODEL, TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION, method, "faiss ivf sqfp16 test description");

// Make sure training succeeds after 30 seconds
assertTrainingSucceeds(TEST_MODEL, 30, 1000);

// Create knn index from model
String indexMapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(TEST_FIELD)
.field("type", "knn_vector")
.field(MODEL_ID, TEST_MODEL)
.endObject()
.endObject()
.endObject()
.toString();

createKnnIndex(testIndex, getKNNDefaultIndexSettings(), indexMapping);

indexTestData(testIndex, TEST_FIELD, DIMENSION, NUM_DOCS);
queryTestData(testIndex, TEST_FIELD, DIMENSION, NUM_DOCS);
deleteKNNIndex(TRAIN_INDEX);
deleteKNNIndex(testIndex);
deleteModel(TEST_MODEL);
validateGraphEviction();
}
}

private void validateGraphEviction() throws Exception {
// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
for (int i = 0; i < intervals; i++) {
if (getTotalGraphsInCache() == 0) {
return;
}

Thread.sleep(5 * 1000);
}

fail("Graphs are not getting evicted");
}

private void queryTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) throws IOException,
ParseException {
float[] queryVector = new float[dimension];
Arrays.fill(queryVector, (float) numDocs);
int k = 10;

Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(k, results.size());
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}
}

private void indexTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) throws Exception {
for (int i = 0; i < numDocs; i++) {
float[] indexVector = new float[dimension];
Arrays.fill(indexVector, (float) i);
addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating", String.valueOf(i)));
}

// Assert that all docs are ingested
refreshAllNonSystemIndices();
assertEquals(numDocs, getDocCount(indexName));
}

}
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public class KNNConstants {
public static final String FAISS_SQ_TYPE = "type";
public static final String FAISS_SQ_ENCODER_FP16 = "fp16";
public static final List<String> FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16);
public static final String FAISS_SQ_CLIP = "clip";

// Parameter defaults/limits
public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1;
Expand All @@ -110,6 +111,9 @@ public class KNNConstants {
public static final Integer MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES = 30;
public static final Integer MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES = 30;

public static final Float FP16_MAX_VALUE = 65504.0f;
public static final Float FP16_MIN_VALUE = -65504.0f;

// Lib names
private static final String JNI_LIBRARY_PREFIX = "opensearchknn_";
public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME;
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ public T getDefaultValue() {
*/
public abstract ValidationException validate(Object value);

/**
* Boolean method parameter
*/
public static class BooleanParameter extends Parameter<Boolean> {
public BooleanParameter(String name, Boolean defaultValue, Predicate<Boolean> validator) {
super(name, defaultValue, validator);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
if (!(value instanceof Boolean)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
return validationException;
}

if (!validator.test((Boolean) value)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));
}
return validationException;
}
}

/**
* Integer method parameter
*/
Expand Down
Loading

0 comments on commit 8cad13b

Please sign in to comment.