From b2d6c38ac9ab86aa21de913cd12dd7d07cc195f3 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 5 Oct 2023 22:41:52 -0700 Subject: [PATCH] Use list of original doc ids for fetch results Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../NormalizationProcessorWorkflow.java | 38 +++++- .../NormalizationProcessorWorkflowTests.java | 113 ++++++++++++++++++ 3 files changed, 146 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61dedd0d5..6042ebfb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements ### Bug Fixes +Fixed exception in Hybrid Query for one shard and multiple node ([#396](https://github.com/opensearch-project/neural-search/pull/396)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 23fbac002..e504eb5ef 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -6,12 +6,12 @@ package org.opensearch.neuralsearch.processor; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.function.Function; import java.util.stream.Collectors; import lombok.AllArgsConstructor; @@ -52,6 +52,9 @@ public void execute( final ScoreNormalizationTechnique normalizationTechnique, final ScoreCombinationTechnique combinationTechnique ) { + // save original state + List unprocessedDocIds = unprocessedDocIds(querySearchResults); + // pre-process data log.debug("Pre-process query results"); List queryTopDocs = getQueryTopDocs(querySearchResults); @@ -67,7 +70,7 @@ public void execute( // post-process data log.debug("Post-process query results after score normalization and combination"); updateOriginalQueryResults(querySearchResults, queryTopDocs); - updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional); + updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds); } /** @@ -123,7 +126,8 @@ private void updateOriginalQueryResults(final List querySearc */ private void updateOriginalFetchResults( final List querySearchResults, - final Optional fetchSearchResultOptional + final Optional fetchSearchResultOptional, + final List docIds ) { if (fetchSearchResultOptional.isEmpty()) { return; @@ -135,14 +139,17 @@ private void updateOriginalFetchResults( // 3. update original scores to normalized and combined values // 4. order scores based on normalized and combined values FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get(); - SearchHits searchHits = fetchSearchResult.hits(); + SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult); // create map of docId to index of search hits. This solves (2), duplicates are from // delimiter and start/stop elements, they all have same valid doc_id. For this map // we use doc_id as a key, and all those special elements are collapsed into a single // key-value pair. - Map docIdToSearchHit = Arrays.stream(searchHits.getHits()) - .collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1)); + Map docIdToSearchHit = new HashMap<>(); + for (int i = 0; i < searchHitArray.length; i++) { + int originalDocId = docIds.get(i); + docIdToSearchHit.put(originalDocId, searchHitArray[i]); + } QuerySearchResult querySearchResult = querySearchResults.get(0); TopDocs topDocs = querySearchResult.topDocs().topDocs; @@ -161,4 +168,23 @@ private void updateOriginalFetchResults( ); fetchSearchResult.hits(updatedSearchHits); } + + private SearchHit[] getSearchHits(List docIds, FetchSearchResult fetchSearchResult) { + SearchHits searchHits = fetchSearchResult.hits(); + SearchHit[] searchHitArray = searchHits.getHits(); + // validate the both collections are of the same size + if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) { + throw new IllegalStateException("Score normalization processor cannot produce final query result"); + } + return searchHitArray; + } + + private List unprocessedDocIds(List querySearchResults) { + List docIds = querySearchResults.isEmpty() + ? List.of() + : Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs) + .map(scoreDoc -> scoreDoc.doc) + .collect(Collectors.toList()); + return docIds; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index a74fb53f2..95c2ba0c2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -179,4 +179,117 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + + public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + new DocValueFormat[0] + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + querySearchResults.add(querySearchResult); + SearchHit[] searchHitArray = new SearchHit[] { + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "1", Map.of(), Map.of()), + new SearchHit(-1, "2", Map.of(), Map.of()), + new SearchHit(-1, "3", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), }; + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); + fetchSearchResult.hits(searchHits); + + normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ); + + TestUtils.assertQueryResultScores(querySearchResults); + TestUtils.assertFetchResultScores(fetchSearchResult, 4); + } + + public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + new DocValueFormat[0] + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + querySearchResults.add(querySearchResult); + SearchHit[] searchHitArray = new SearchHit[] { + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "1", Map.of(), Map.of()), + new SearchHit(-1, "2", Map.of(), Map.of()), + new SearchHit(-1, "3", Map.of(), Map.of()) }; + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); + fetchSearchResult.hits(searchHits); + + expectThrows( + IllegalStateException.class, + () -> normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ) + ); + } }