Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RRF tests #992

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ jobs:
run: |
./gradlew check

- name: Upload Coverage Report
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}

Precommit-neural-search-linux:
needs: Get-CI-Image-Tag
strategy:
Expand Down Expand Up @@ -131,8 +136,7 @@ jobs:
su `id -un 1000` -c "./gradlew precommit --parallel"

- name: Upload Coverage Report
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}

Expand Down Expand Up @@ -164,3 +168,7 @@ jobs:
run: |
./gradlew precommit --parallel

- name: Upload Coverage Report
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

Compatible with OpenSearch 2.18.0

### Features
- Introduces ByFieldRerankProcessor for second-level reranking on documents ([#932](https://github.com/opensearch-project/neural-search/pull/932))
### Bug Fixes
- Fixed incorrect document order for nested aggregations in hybrid query ([#956](https://github.com/opensearch-project/neural-search/pull/956))
### Enhancements
- Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907))
- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Optional;

import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -98,7 +99,8 @@ public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
@VisibleForTesting
<Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
return true;
}
Expand All @@ -111,7 +113,8 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
@VisibleForTesting
boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
Expand All @@ -120,17 +123,17 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
@VisibleForTesting
<Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
@VisibleForTesting
<Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.junit.After;
import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RRFProcessorTests extends OpenSearchTestCase {

@Mock
private ScoreNormalizationTechnique mockNormalizationTechnique;
@Mock
private ScoreCombinationTechnique mockCombinationTechnique;
@Mock
private NormalizationProcessorWorkflow mockNormalizationWorkflow;
@Mock
private SearchPhaseResults<SearchPhaseResult> mockSearchPhaseResults;
@Mock
private SearchPhaseContext mockSearchPhaseContext;
@Mock
private QueryPhaseResultConsumer mockQueryPhaseResultConsumer;

private RRFProcessor rrfProcessor;

@Before
@SneakyThrows
public void setUp() {
super.setUp();
MockitoAnnotations.openMocks(this);
rrfProcessor = new RRFProcessor(
"tag",
"description",
mockNormalizationTechnique,
mockCombinationTechnique,
mockNormalizationWorkflow
);
}

@SneakyThrows
public void testGetType() {
assertEquals("score-ranker-processor", rrfProcessor.getType());
}

@SneakyThrows
public void testGetBeforePhase() {
assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase());
}

@SneakyThrows
public void testGetAfterPhase() {
assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase());
}

@SneakyThrows
public void testIsIgnoreFailure() {
assertFalse(rrfProcessor.isIgnoreFailure());
}

@SneakyThrows
public void testProcessWithNullSearchPhaseResult() {
rrfProcessor.process(null, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcessWithNonQueryPhaseResultConsumer() {
rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcessWithValidHybridInput() {
QuerySearchResult result = createQuerySearchResult(true);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testProcessWithValidNonHybridInput() {
QuerySearchResult result = createQuerySearchResult(false);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow, never()).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testGetTag() {
assertEquals("tag", rrfProcessor.getTag());
}

@SneakyThrows
public void testGetDescription() {
assertEquals("description", rrfProcessor.getDescription());
}

@SneakyThrows
public void testShouldSkipProcessor() {
assertTrue(rrfProcessor.shouldSkipProcessor(null));
assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults));

AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));

atomicArray.set(0, createQuerySearchResult(true));
assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));
}

@SneakyThrows
public void testGetQueryPhaseSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(2);
atomicArray.set(0, createQuerySearchResult(true));
atomicArray.set(1, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

List<QuerySearchResult> results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer);
assertEquals(2, results.size());
assertNotNull(results.get(0));
assertNotNull(results.get(1));
}

@SneakyThrows
public void testGetFetchSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(true));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

Optional<FetchSearchResult> result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer);
assertFalse(result.isPresent());
}

private QuerySearchResult createQuerySearchResult(boolean isHybrid) {
ShardId shardId = new ShardId("index", "uuid", 0);
OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed());
SearchRequest searchRequest = new SearchRequest("index");
searchRequest.source(new SearchSourceBuilder());
searchRequest.allowPartialSearchResults(true);

int numberOfShards = 1;
AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY);
float indexBoost = 1.0f;
long nowInMillis = System.currentTimeMillis();
String clusterAlias = null;
String[] indexRoutings = Strings.EMPTY_ARRAY;

ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
originalIndices,
searchRequest,
shardId,
numberOfShards,
aliasFilter,
indexBoost,
nowInMillis,
clusterAlias,
indexRoutings
);

QuerySearchResult result = new QuerySearchResult(
new ShardSearchContextId("test", 1),
new SearchShardTarget("node1", shardId, clusterAlias, originalIndices),
shardSearchRequest
);
result.from(0).size(10);

ScoreDoc[] scoreDocs;
if (isHybrid) {
scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) };
} else {
scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) };
}

TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f);
result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]);

return result;
}
}
Loading
Loading