From 86c326389d8206b247249983c9a070fede1563b1 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 19 Dec 2024 09:38:16 -0800 Subject: [PATCH] Add case for null/NaN scores and minor refactoring Signed-off-by: Martin Gaievski --- .../ExplanationResponseProcessor.java | 3 +- .../RRFNormalizationTechnique.java | 2 +- ...=> ExplanationResponseProcessorTests.java} | 116 +++++++++++++++++- .../RRFNormalizationTechniqueTests.java | 7 +- .../query/HybridQueryExplainIT.java | 12 +- 5 files changed, 130 insertions(+), 10 deletions(-) rename src/test/java/org/opensearch/neuralsearch/processor/{ExplanationPayloadProcessorTests.java => ExplanationResponseProcessorTests.java} (76%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 01cdfcb0d..7a61519f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -111,8 +111,9 @@ public SearchResponse processResponse( ); } // Create and set final explanation combining all components + Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore(); Explanation finalExplanation = Explanation.match( - searchHit.getScore(), + finalScore, // combination level explanation is always a single detail combinationExplanation.getScoreDetails().get(0).getValue(), normalizedExplanation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index 4cc773592..80fc65eb3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -71,7 +71,7 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { @Override public String describe() { - return String.format(Locale.ROOT, "%s, rank_constant %s", TECHNIQUE_NAME, rankConstant); + return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java similarity index 76% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java index e47ea43d2..530753a96 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -37,9 +37,10 @@ import java.util.TreeMap; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -192,6 +193,119 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces assertOnExplanationResults(processedResponse, maxScore); } + @SneakyThrows + public void testProcessResponse_whenNullSearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = getSearchResponse(null); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenEmptySearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits emptyHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponse searchResponse = getSearchResponse(emptyHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenNullExplanation_thenSkipProcessing() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + for (SearchHit hit : searchHits.getHits()) { + hit.explanation(null); + } + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenInvalidExplanationPayload_thenHandleGracefully() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Set invalid payload + Map invalidPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + "invalid payload" + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(invalidPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenZeroScore_thenProcessCorrectly() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(0.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + assertEquals(0.0f, processedResponse.getHits().getMaxScore(), DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testProcessResponse_whenScoreIsNaN_thenExplanationUsesZero() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + // Create SearchHits with NaN score + SearchHits searchHits = getSearchHits(Float.NaN); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Setup explanation payload + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + // Process response + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + + // Verify results + assertNotNull(processedResponse); + SearchHit[] hits = processedResponse.getHits().getHits(); + assertNotNull(hits); + assertTrue(hits.length > 0); + + // Verify that the explanation uses 0.0f when input score was NaN + Explanation explanation = hits[0].getExplanation(); + assertNotNull(explanation); + assertEquals(0.0f, (float) explanation.getValue(), DELTA_FOR_FLOATS_ASSERTION); + } + private static SearchHits getSearchHits(float maxScore) { int numResponses = 1; int numIndices = 2; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java index 273d3d25f..da6d37bd7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -30,8 +30,13 @@ public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testDescribe() { + // verify with default values for parameters RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); - assertEquals("rrf, rank_constant 60", normalizationTechnique.describe()); + assertEquals("rrf, rank_constant [60]", normalizationTechnique.describe()); + + // verify when parameter values are set + normalizationTechnique = new RRFNormalizationTechnique(Map.of("rank_constant", 25), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [25]", normalizationTechnique.describe()); } public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index 35ad2aac5..c6eaa21ff 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -633,7 +633,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { // two sub-queries meaning we do have two detail objects with separate query level details Map hit1DetailsForHit1 = hit1Details.get(0); assertTrue((double) hit1DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); - assertEquals("rrf, rank_constant 60 normalization of:", hit1DetailsForHit1.get("description")); + assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit1.get("description")); assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); @@ -643,7 +643,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { Map hit1DetailsForHit2 = hit1Details.get(1); assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f); - assertEquals("rrf, rank_constant 60 normalization of:", hit1DetailsForHit2.get("description")); + assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit2.get("description")); assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size()); Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); @@ -663,12 +663,12 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { Map hit2DetailsForHit1 = hit2Details.get(0); assertTrue((double) hit2DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); - assertEquals("rrf, rank_constant 60 normalization of:", hit2DetailsForHit1.get("description")); + assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit1.get("description")); assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size()); Map hit2DetailsForHit2 = hit2Details.get(1); assertTrue((double) hit2DetailsForHit2.get("value") > DELTA_FOR_SCORE_ASSERTION); - assertEquals("rrf, rank_constant 60 normalization of:", hit2DetailsForHit2.get("description")); + assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit2.get("description")); assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size()); // hit 3 @@ -683,7 +683,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { Map hit3DetailsForHit1 = hit3Details.get(0); assertTrue((double) hit3DetailsForHit1.get("value") > .0f); - assertEquals("rrf, rank_constant 60 normalization of:", hit3DetailsForHit1.get("description")); + assertEquals("rrf, rank_constant [60] normalization of:", hit3DetailsForHit1.get("description")); assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size()); Map explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0); @@ -703,7 +703,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { Map hit4DetailsForHit1 = hit4Details.get(0); assertTrue((double) hit4DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); - assertEquals("rrf, rank_constant 60 normalization of:", hit4DetailsForHit1.get("description")); + assertEquals("rrf, rank_constant [60] normalization of:", hit4DetailsForHit1.get("description")); assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size()); Map explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0);