Skip to content

Commit

Permalink
Initial unit test implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed Nov 20, 2024
1 parent 2cebb0e commit 5690405
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Optional;

import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -98,7 +99,8 @@ public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
@VisibleForTesting
<Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
return true;
}
Expand All @@ -111,7 +113,8 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
@VisibleForTesting
boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
Expand All @@ -120,17 +123,17 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
@VisibleForTesting
<Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
@VisibleForTesting
<Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RRFProcessorTests extends OpenSearchTestCase {

@Mock
private ScoreNormalizationTechnique mockNormalizationTechnique;
@Mock
private ScoreCombinationTechnique mockCombinationTechnique;
@Mock
private NormalizationProcessorWorkflow mockNormalizationWorkflow;
@Mock
private SearchPhaseResults<SearchPhaseResult> mockSearchPhaseResults;
@Mock
private SearchPhaseContext mockSearchPhaseContext;
@Mock
private QueryPhaseResultConsumer mockQueryPhaseResultConsumer;

private RRFProcessor rrfProcessor;

@Before
@SneakyThrows
public void setUp() {
super.setUp();
MockitoAnnotations.openMocks(this);
rrfProcessor = new RRFProcessor(
"tag",
"description",
mockNormalizationTechnique,
mockCombinationTechnique,
mockNormalizationWorkflow
);
}

@SneakyThrows
public void testGetType() {
assertEquals("score-ranker-processor", rrfProcessor.getType());
}

@SneakyThrows
public void testGetBeforePhase() {
assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase());
}

@SneakyThrows
public void testGetAfterPhase() {
assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase());
}

@SneakyThrows
public void testIsIgnoreFailure() {
assertFalse(rrfProcessor.isIgnoreFailure());
}

@SneakyThrows
public void testProcessWithNullSearchPhaseResult() {
rrfProcessor.process(null, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcessWithNonQueryPhaseResultConsumer() {
rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcessWithValidHybridInput() {
QuerySearchResult result = createQuerySearchResult(true);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testProcessWithValidNonHybridInput() {
QuerySearchResult result = createQuerySearchResult(false);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow, never()).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testGetTag() {
assertEquals("tag", rrfProcessor.getTag());
}

@SneakyThrows
public void testGetDescription() {
assertEquals("description", rrfProcessor.getDescription());
}

@SneakyThrows
public void testShouldSkipProcessor() {
assertTrue(rrfProcessor.shouldSkipProcessor(null));
assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults));

AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));

atomicArray.set(0, createQuerySearchResult(true));
assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));
}

@SneakyThrows
public void testGetQueryPhaseSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(2);
atomicArray.set(0, createQuerySearchResult(true));
atomicArray.set(1, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

List<QuerySearchResult> results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer);
assertEquals(2, results.size());
assertNotNull(results.get(0));
assertNotNull(results.get(1));
}

@SneakyThrows
public void testGetFetchSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(true));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

Optional<FetchSearchResult> result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer);
assertFalse(result.isPresent());
}

private QuerySearchResult createQuerySearchResult(boolean isHybrid) {
ShardId shardId = new ShardId("index", "uuid", 0);
OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed());
SearchRequest searchRequest = new SearchRequest("index");
searchRequest.source(new SearchSourceBuilder());
searchRequest.allowPartialSearchResults(true);

int numberOfShards = 1;
AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY);
float indexBoost = 1.0f;
long nowInMillis = System.currentTimeMillis();
String clusterAlias = null;
String[] indexRoutings = Strings.EMPTY_ARRAY;

ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
originalIndices,
searchRequest,
shardId,
numberOfShards,
aliasFilter,
indexBoost,
nowInMillis,
clusterAlias,
indexRoutings
);

QuerySearchResult result = new QuerySearchResult(
new ShardSearchContextId("test", 1),
new SearchShardTarget("node1", shardId, clusterAlias, originalIndices),
shardSearchRequest
);
result.from(0).size(10);

ScoreDoc[] scoreDocs;
if (isHybrid) {
scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) };
} else {
scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) };
}

TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f);
result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]);

return result;
}
}

0 comments on commit 5690405

Please sign in to comment.