Skip to content

Commit

Permalink
With only GlobalAggregation in request causes unnecessary wrapping wi…
Browse files Browse the repository at this point in the history
…th MultiCollector (#8125)

Signed-off-by: Sorabh Hamirwasia <[email protected]>
  • Loading branch information
sohami authored Jun 17, 2023
1 parent 4228075 commit 90678c2
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Replaces ZipInputStream with ZipFile to fix Zip Slip vulnerability ([#7230](https://github.com/opensearch-project/OpenSearch/pull/7230))
- Add missing validation/parsing of SearchBackpressureMode of SearchBackpressureSettings ([#7541](https://github.com/opensearch-project/OpenSearch/pull/7541))
- Fix mapping char_filter when mapping a hashtag ([#7591](https://github.com/opensearch-project/OpenSearch/pull/7591))
- With only GlobalAggregation in request causes unnecessary wrapping with MultiCollector ([#8125](https://github.com/opensearch-project/OpenSearch/pull/8125))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,19 @@

package org.opensearch.search.profile.aggregation;

import org.hamcrest.core.IsNull;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.search.aggregations.Aggregator.SubAggCollectionMode;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.bucket.global.Global;
import org.opensearch.search.aggregations.bucket.sampler.DiversifiedOrdinalsSamplerAggregator;
import org.opensearch.search.aggregations.bucket.terms.GlobalOrdinalsStringTermsAggregator;
import org.opensearch.search.aggregations.metrics.Stats;
import org.opensearch.search.profile.ProfileResult;
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.query.QueryProfileShardResult;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.util.ArrayList;
Expand All @@ -48,11 +53,15 @@
import java.util.Set;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.opensearch.search.aggregations.AggregationBuilders.avg;
import static org.opensearch.search.aggregations.AggregationBuilders.diversifiedSampler;
import static org.opensearch.search.aggregations.AggregationBuilders.global;
import static org.opensearch.search.aggregations.AggregationBuilders.histogram;
import static org.opensearch.search.aggregations.AggregationBuilders.max;
import static org.opensearch.search.aggregations.AggregationBuilders.stats;
import static org.opensearch.search.aggregations.AggregationBuilders.terms;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse;
Expand Down Expand Up @@ -95,6 +104,7 @@ public class AggregationProfilerIT extends OpenSearchIntegTestCase {
private static final String NUMBER_FIELD = "number";
private static final String TAG_FIELD = "tag";
private static final String STRING_FIELD = "string_field";
private final int numDocs = 5;

@Override
protected int numberOfShards() {
Expand All @@ -118,7 +128,7 @@ protected void setupSuiteScopeCluster() throws Exception {
randomStrings[i] = randomAlphaOfLength(10);
}

for (int i = 0; i < 5; i++) {
for (int i = 0; i < numDocs; i++) {
builders.add(
client().prepareIndex("idx")
.setSource(
Expand Down Expand Up @@ -633,4 +643,68 @@ public void testNoProfile() {
assertThat(profileResults, notNullValue());
assertThat(profileResults.size(), equalTo(0));
}

public void testGlobalAggWithStatsSubAggregatorProfile() {
boolean profileEnabled = true;
SearchResponse response = client().prepareSearch("idx")
.addAggregation(global("global").subAggregation(stats("value_stats").field(NUMBER_FIELD)))
.setProfile(profileEnabled)
.get();

assertSearchResponse(response);

Global global = response.getAggregations().get("global");
assertThat(global, IsNull.notNullValue());
assertThat(global.getName(), equalTo("global"));
assertThat(global.getDocCount(), equalTo((long) numDocs));
assertThat((long) ((InternalAggregation) global).getProperty("_count"), equalTo((long) numDocs));
assertThat(global.getAggregations().asList().isEmpty(), is(false));

Stats stats = global.getAggregations().get("value_stats");
assertThat((Stats) ((InternalAggregation) global).getProperty("value_stats"), sameInstance(stats));
assertThat(stats, IsNull.notNullValue());
assertThat(stats.getName(), equalTo("value_stats"));

Map<String, ProfileShardResult> profileResults = response.getProfileResults();
assertThat(profileResults, notNullValue());
assertThat(profileResults.size(), equalTo(getNumShards("idx").numPrimaries));
for (ProfileShardResult profileShardResult : profileResults.values()) {
assertThat(profileShardResult, notNullValue());
List<QueryProfileShardResult> queryProfileShardResults = profileShardResult.getQueryProfileResults();
assertEquals(queryProfileShardResults.size(), 2);
// ensure there is no multi collector getting added with only global agg
for (QueryProfileShardResult queryProfileShardResult : queryProfileShardResults) {
assertEquals(queryProfileShardResult.getQueryResults().size(), 1);
if (queryProfileShardResult.getQueryResults().get(0).getQueryName().equals("MatchAllDocsQuery")) {
assertEquals(0, queryProfileShardResult.getQueryResults().get(0).getProfiledChildren().size());
assertEquals("search_top_hits", queryProfileShardResult.getCollectorResult().getReason());
assertEquals(0, queryProfileShardResult.getCollectorResult().getProfiledChildren().size());
} else if (queryProfileShardResult.getQueryResults().get(0).getQueryName().equals("ConstantScoreQuery")) {
assertEquals(1, queryProfileShardResult.getQueryResults().get(0).getProfiledChildren().size());
assertEquals("aggregation_global", queryProfileShardResult.getCollectorResult().getReason());
assertEquals(0, queryProfileShardResult.getCollectorResult().getProfiledChildren().size());
} else {
fail("unexpected profile shard result in the response");
}
}
AggregationProfileShardResult aggProfileResults = profileShardResult.getAggregationProfileResults();
assertThat(aggProfileResults, notNullValue());
List<ProfileResult> aggProfileResultsList = aggProfileResults.getProfileResults();
assertThat(aggProfileResultsList, notNullValue());
assertEquals(1, aggProfileResultsList.size());
ProfileResult globalAggResult = aggProfileResultsList.get(0);
assertThat(globalAggResult, notNullValue());
assertEquals("GlobalAggregator", globalAggResult.getQueryName());
assertEquals("global", globalAggResult.getLuceneDescription());
assertEquals(1, globalAggResult.getProfiledChildren().size());
assertThat(globalAggResult.getTime(), greaterThan(0L));
Map<String, Long> breakdown = globalAggResult.getTimeBreakdown();
assertThat(breakdown, notNullValue());
assertEquals(BREAKDOWN_KEYS, breakdown.keySet());
assertThat(breakdown.get(INITIALIZE), greaterThan(0L));
assertThat(breakdown.get(COLLECT), greaterThan(0L));
assertThat(breakdown.get(BUILD_AGGREGATION).longValue(), greaterThan(0L));
assertEquals(0, breakdown.get(REDUCE).intValue());
}
}
}
26 changes: 14 additions & 12 deletions server/src/main/java/org/opensearch/search/query/QueryPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
Expand Down Expand Up @@ -71,6 +72,7 @@

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -234,19 +236,19 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q
// this collector can filter documents during the collection
hasFilterCollector = true;
}
if (searchContext.queryCollectorManagers().isEmpty() == false) {
// plug in additional collectors, like aggregations except global aggregations
collectors.add(
createMultiCollectorContext(
searchContext.queryCollectorManagers()
.entrySet()
.stream()
.filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class)))
.map(Map.Entry::getValue)
.collect(Collectors.toList())
)
);

// plug in additional collectors, like aggregations except global aggregations
final List<CollectorManager<? extends Collector, ReduceableSearchResult>> managersExceptGlobalAgg = searchContext
.queryCollectorManagers()
.entrySet()
.stream()
.filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class)))
.map(Map.Entry::getValue)
.collect(Collectors.toList());
if (managersExceptGlobalAgg.isEmpty() == false) {
collectors.add(createMultiCollectorContext(managersExceptGlobalAgg));
}

if (searchContext.minimumScore() != null) {
// apply the minimum score after multi collector so we filter aggs as well
collectors.add(createMinScoreCollectorContext(searchContext.minimumScore()));
Expand Down

0 comments on commit 90678c2

Please sign in to comment.