Skip to content

Commit

Permalink
Multi vector support for Faiss HNSW
Browse files Browse the repository at this point in the history
Apply the parentId filter to the Faiss HNSW search method. This ensures that documents are deduplicated based on their parentId, and the method returns k results for documents with nested fields.

Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jan 4, 2024
1 parent df6d1fa commit 4d2e517
Show file tree
Hide file tree
Showing 16 changed files with 1,337 additions and 79 deletions.
15 changes: 9 additions & 6 deletions jni/include/knn_extension/faiss/utils/BitSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@ struct BitSet {
* bitmap: 10001000 00000100
*
* for next_set_bit call with 4
* 1. it looks for bitmap[0]
* 2. bitmap[0] >> 4
* 1. it looks for words[0]
* 2. words[0] >> 4
* 3. count trailing zero of the result from step 2 which is 3
* 4. return 4(current index) + 3(result from step 3)
*/
struct FixedBitSet : public BitSet {
// Length of bitmap
size_t numBits;
// The number of bits in use
idx_t num_bits;

// Pointer to an array of uint64_t
// The exact number of longs needed to hold num_bits
size_t num_words;

// Array of uint64_t holding the bits
// Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h
uint64_t* bitmap;
uint64_t* words;

FixedBitSet(const int* int_array, const int length);
idx_t next_set_bit(idx_t index) const;
Expand Down
4 changes: 2 additions & 2 deletions jni/include/knn_extension/faiss/utils/Heap.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ inline void maxheap_push(
std::unordered_map<int64_t, size_t>* parent_id_to_index,
int64_t parent_id) {

assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap");
assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap");

up_heap<faiss::CMax<T, int64_t>>(
bh_val,
Expand Down Expand Up @@ -189,7 +189,7 @@ inline void maxheap_replace_top(
std::unordered_map<int64_t, size_t>* parent_id_to_index,
int64_t parent_id) {

assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap");
assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap");

parent_id_to_id->erase(bh_ids[0]);
parent_id_to_index->erase(bh_ids[0]);
Expand Down
21 changes: 12 additions & 9 deletions jni/src/knn_extension/faiss/utils/BitSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,28 @@
FixedBitSet::FixedBitSet(const int* int_array, const int length){
assert(int_array && "int_array should not be null");
const int* maxValue = std::max_element(int_array, int_array + length);
this->numBits = (*maxValue >> 6) + 1; // div by 64
this->bitmap = new uint64_t[this->numBits]();
this->num_bits = *maxValue + 1;
this->num_words = (num_bits >> 6) + 1; // div by 64
this->words = new uint64_t[this->num_words]();
for(int i = 0 ; i < length ; i ++) {
int value = int_array[i];
int bitsetArrayIndex = value >> 6;
this->bitmap[bitsetArrayIndex] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64)
int bitset_array_index = value >> 6;
this->words[bitset_array_index] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64)
}
}

idx_t FixedBitSet::next_set_bit(idx_t index) const {
assert(index >= 0 && index < this->num_bits && "index is out of the bound");
idx_t i = index >> 6; // div by 64
uint64_t word = this->bitmap[i] >> (index & 63); // Equivalent of bitmap[i] >> (index % 64)

uint64_t word = this->words[i] >> (index & 63); // Equivalent of words[i] >> (index % 64)
// word is non zero after right shift, it means, next set bit is in current word
// The index of set bit is "given index" + "trailing zero in the right shifted word"
if (word != 0) {
return index + __builtin_ctzll(word);
}

while (++i < this->numBits) {
word = this->bitmap[i];
while (++i < this->num_words) {
word = this->words[i];
if (word != 0) {
return (i << 6) + __builtin_ctzll(word);
}
Expand All @@ -38,5 +41,5 @@ idx_t FixedBitSet::next_set_bit(idx_t index) const {
}

FixedBitSet::~FixedBitSet() {
delete this->bitmap;
delete this->words;
}
9 changes: 7 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.knn.index.KNNSettings;

import java.io.IOException;
Expand All @@ -33,20 +34,24 @@ public class KNNQuery extends Query {
@Getter
@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;

public KNNQuery(String field, float[] queryVector, int k, String indexName) {
public KNNQuery(String field, float[] queryVector, int k, String indexName, final BitSetProducer parentsFilter) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.parentsFilter = parentsFilter;
}

public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) {
public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery, BitSetProducer parentsFilter) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.filterQuery = filterQuery;
this.parentsFilter = parentsFilter;
}

public String getField() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
final Query filterQuery = getFilterQuery(createQueryRequest);

BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) {
log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k);
return new KNNQuery(fieldName, vector, k, indexName, filterQuery);
return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter);
}
log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KNNQuery(fieldName, vector, k, indexName);
return new KNNQuery(fieldName, vector, k, indexName, parentFilter);
}

log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
} else if (VectorDataType.FLOAT == vectorDataType) {
Expand Down Expand Up @@ -187,9 +187,6 @@ static class CreateQueryRequest {
private VectorDataType vectorDataType;
@Getter
private int k;
// can be null in cases filter not passed with the knn query
@Getter
private BitSetProducer parentFilter;
private QueryBuilder filter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;
Expand Down
54 changes: 34 additions & 20 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -117,9 +118,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
* This improves the recall.
*/
if (filterWeight != null && canDoExactSearch(filterIdsArray.length)) {
docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray));
docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray, knnQuery.getParentsFilter()));
} else {
Map<Integer, Float> annResults = doANNSearch(context, filterIdsArray);
Map<Integer, Float> annResults = doANNSearch(context, filterIdsArray, knnQuery.getParentsFilter());
if (annResults == null) {
return null;
}
Expand All @@ -131,7 +132,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
annResults.size(),
filterIdsArray.length
);
annResults = doExactSearch(context, filterIdsArray);
annResults = doExactSearch(context, filterIdsArray, knnQuery.getParentsFilter());

Check warning on line 135 in src/main/java/org/opensearch/knn/index/query/KNNWeight.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNWeight.java#L135

Added line #L135 was not covered by tests
}
docIdsToScoreMap.putAll(annResults);
}
Expand Down Expand Up @@ -172,23 +173,31 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (docId < filteredDocsBitSet.length()) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
docId++;
return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight));
}

private int[] getParentIdsArray(final LeafReaderContext context, final BitSetProducer parentFilter) throws IOException {
if (parentFilter == null) {
return null;
}
return filteredIds;
return bitSetToIntArray(parentFilter.getBitSet(context));
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException {
private int[] bitSetToIntArray(final BitSet bitSet) {
final int[] intArray = new int[bitSet.cardinality()];
final BitSetIterator bitSetIterator = new BitSetIterator(bitSet, bitSet.cardinality());
int index = 0;
int docId = bitSetIterator.nextDoc();
while (docId != DocIdSetIterator.NO_MORE_DOCS) {
assert index < intArray.length;
intArray[index++] = docId;
docId = bitSetIterator.nextDoc();
}
return intArray;
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray, final BitSetProducer parentFilter)
throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

Expand Down Expand Up @@ -265,13 +274,14 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

int[] parentIds = getParentIdsArray(context, parentFilter);
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine.getName(),
filterIdsArray
filterIdsArray,
parentIds
);

} catch (Exception e) {
Expand All @@ -296,7 +306,11 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) {
private Map<Integer, Float> doExactSearch(
final LeafReaderContext leafReaderContext,
final int[] filterIdsArray,
final BitSetProducer parentFilter
) throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,14 @@ public static long loadIndex(String indexPath, Map<String, Object> parameters, S
* @param filteredIds array of ints on which should be used for search.
* @return KNNQueryResult array of k neighbors
*/
public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) {
public static KNNQueryResult[] queryIndex(
long indexPointer,
float[] queryVector,
int k,
String engineName,
int[] filteredIds,
int[] parentIds
) {
if (KNNEngine.NMSLIB.getName().equals(engineName)) {
return NmslibService.queryIndex(indexPointer, queryVector, k);
}
Expand All @@ -112,9 +119,9 @@ public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector
// filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine
// normally.
if (ArrayUtils.isNotEmpty(filteredIds)) {
return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, null);
return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds);
}
return FaissService.queryIndex(indexPointer, queryVector, k, null);
return FaissService.queryIndex(indexPointer, queryVector, k, parentIds);
}
throw new IllegalArgumentException("QueryIndex not supported for provided engine");
}
Expand Down
26 changes: 25 additions & 1 deletion src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public final void cleanUp() {
}

@SneakyThrows
public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.LUCENE.getName());

String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
Expand All @@ -78,6 +78,30 @@ public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
assertEquals(2, hits.size());
}

@SneakyThrows
public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.FAISS.getName());

String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
.build();
addNestedKnnDoc(INDEX_NAME, "1", doc1);

String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
.build();
addNestedKnnDoc(INDEX_NAME, "2", doc2);

Float[] queryVector = { 1f, 1f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);

List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
MediaTypeRegistry.getDefaultMediaType().xContent(),
EntityUtils.toString(response.getEntity())
).map().get("hits")).get("hits");
assertEquals(2, hits.size());
}

/**
* {
* "properties": {
Expand Down
17 changes: 12 additions & 5 deletions src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.index.mapper.MapperService;
Expand Down Expand Up @@ -162,14 +163,20 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception {

// query to verify distance for each of the field
IndexSearcher searcher = new IndexSearcher(reader);
float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score;
float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score;
float score = searcher.search(
new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null),
10
).scoreDocs[0].score;
float score1 = searcher.search(
new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null),
10
).scoreDocs[0].score;
assertEquals(1.0f / (1 + 25), score, 0.01f);
assertEquals(1.0f / (1 + 169), score1, 0.01f);

// query to determine the hits
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null)));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null)));

reader.close();
dir.close();
Expand Down Expand Up @@ -254,7 +261,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio
NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService);
float[] query = { 10.0f, 10.0f, 10.0f };
IndexSearcher searcher = new IndexSearcher(reader);
TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10);
TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10);

assertEquals(3, topDocs.scoreDocs[0].doc);
assertEquals(2, topDocs.scoreDocs[1].doc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ public static void assertLoadableByEngine(
);
int k = 2;
float[] queryVector = new float[dimension];
KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null);
KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, null);
assertTrue(results.length > 0);
JNIService.free(indexPtr, knnEngine.getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException {
// Confirm that the file was loaded by querying
float[] query = new float[dimension];
Arrays.fill(query, numVectors + 1);
KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null);
KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, null);
assertTrue(results.length > 0);
}

Expand Down
Loading

0 comments on commit 4d2e517

Please sign in to comment.