From f4e81739492a21d373215c99dc2b758b344256c5 Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Wed, 24 Apr 2024 18:37:51 -0700 Subject: [PATCH 1/5] Support customized and rule based labeling for search queries Signed-off-by: Chenyang Ji --- CHANGELOG.md | 1 + .../core/listener/QueryInsightsListener.java | 2 + .../insights/rules/model/Attribute.java | 10 ++- .../action/search/SearchRequestContext.java | 4 + .../main/java/org/opensearch/node/Node.java | 8 +- .../search/builder/SearchSourceBuilder.java | 43 ++++++++++ .../labels/RuleBasedLabelingService.java | 52 ++++++++++++ .../labels/SearchRequestLabelingListener.java | 47 ++++++++++ .../search/labels/package-info.java | 10 +++ .../rules/DefaultUserInfoLabelingRule.java | 85 +++++++++++++++++++ .../opensearch/search/labels/rules/Rule.java | 27 ++++++ .../search/labels/rules/package-info.java | 10 +++ .../DefaultUserInfoLabelingRuleTests.java | 56 ++++++++++++ .../labels/RuleBasedLabelingServiceTests.java | 75 ++++++++++++++++ 14 files changed, 428 insertions(+), 2 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java create mode 100644 server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java create mode 100644 server/src/main/java/org/opensearch/search/labels/package-info.java create mode 100644 server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java create mode 100644 server/src/main/java/org/opensearch/search/labels/rules/Rule.java create mode 100644 server/src/main/java/org/opensearch/search/labels/rules/package-info.java create mode 100644 server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java create mode 100644 server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index fb465153512bf..8f215e01660c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add remote routing table for remote state publication with experimental feature flag ([#13304](https://github.com/opensearch-project/OpenSearch/pull/13304)) - [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027)) - [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982)) +- Support customized and rule-based labeling for search queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 9ec8673147c38..2f328c534a8f4 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -21,6 +21,7 @@ import org.opensearch.plugin.insights.rules.model.Attribute; import org.opensearch.plugin.insights.rules.model.MetricType; import org.opensearch.plugin.insights.rules.model.SearchQueryRecord; +import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; import java.util.Collections; import java.util.HashMap; @@ -138,6 +139,7 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards()); attributes.put(Attribute.INDICES, request.indices()); attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap()); + attributes.put(Attribute.USER_NAME, request.source().labels().get(DefaultUserInfoLabelingRule.USER_NAME)); SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); queryInsightsService.addRecord(record); } catch (Exception e) { diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java index c1d17edf9ff14..dc000a49e5d36 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java @@ -43,7 +43,15 @@ public enum Attribute { /** * The node id for this request */ - NODE_ID; + NODE_ID, + /** + * User associated with this request + */ + USER_NAME, + /** + * Custom tenant tags + */ + CUSTOMIZED_TAG; /** * Read an Attribute from a StreamInput diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java index b8bbde65ca6bc..5b133ba0554f4 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java @@ -107,6 +107,10 @@ String formattedShardStats() { ); } } + + public SearchRequest getRequest() { + return searchRequest; + } } enum ShardStatsFieldNames { diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index cb1f2caa082fc..1ceb3adb5314a 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -227,6 +227,8 @@ import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; import org.opensearch.search.fetch.FetchPhase; +import org.opensearch.search.labels.RuleBasedLabelingService; +import org.opensearch.search.labels.SearchRequestLabelingListener; import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QueryPhase; import org.opensearch.snapshots.InternalSnapshotsInfoService; @@ -962,11 +964,15 @@ protected Node( // Add the telemetryAwarePlugin components to the existing pluginComponents collection. pluginComponents.addAll(telemetryAwarePluginComponents); + final SearchRequestLabelingListener searchRequestLabelingListener = new SearchRequestLabelingListener( + threadPool, + new RuleBasedLabelingService(new ArrayList<>()) + ); // register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = new SearchRequestOperationsCompositeListenerFactory( Stream.concat( - Stream.of(searchRequestStats, searchRequestSlowLog), + Stream.of(searchRequestStats, searchRequestSlowLog, searchRequestLabelingListener), pluginComponents.stream() .filter(p -> p instanceof SearchRequestOperationsListener) .map(p -> (SearchRequestOperationsListener) p) diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 8a9704b04566f..1fd08089d60f0 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -79,6 +79,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -136,6 +137,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R public static final ParseField SLICE = new ParseField("slice"); public static final ParseField POINT_IN_TIME = new ParseField("pit"); public static final ParseField SEARCH_PIPELINE = new ParseField("search_pipeline"); + public static final ParseField LABELS = new ParseField("labels"); public static SearchSourceBuilder fromXContent(XContentParser parser) throws IOException { return fromXContent(parser, true); @@ -224,6 +226,8 @@ public static HighlightBuilder highlight() { private Map searchPipelineSource = null; + private Map labels = new HashMap<>(); + /** * Constructs a new search source builder. */ @@ -286,6 +290,11 @@ public SearchSourceBuilder(StreamInput in) throws IOException { searchPipelineSource = in.readMap(); } } + if (in.getVersion().onOrAfter(Version.V_2_15_0)) { + if (in.readBoolean()) { + labels = in.readMap(); + } + } if (in.getVersion().onOrAfter(Version.V_2_13_0)) { includeNamedQueriesScore = in.readOptionalBoolean(); } @@ -362,6 +371,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeMap(searchPipelineSource); } } + if (out.getVersion().onOrAfter(Version.V_2_15_0)) { + out.writeBoolean(labels != null); + if (labels != null) { + out.writeMap(labels); + } + } if (out.getVersion().onOrAfter(Version.V_2_13_0)) { out.writeOptionalBoolean(includeNamedQueriesScore); } @@ -1119,6 +1134,29 @@ public SearchSourceBuilder searchPipelineSource(Map searchPipeli return this; } + /** + * @return labels defined within the search request. + */ + public Map labels() { + return labels; + } + + /** + * Define labels within this search request. + */ + public SearchSourceBuilder labels(Map labels) { + this.labels = labels; + return this; + } + + /** + * Add labels within this search request. + */ + public SearchSourceBuilder addLabels(Map labels) { + this.labels.putAll(labels); + return this; + } + /** * Rewrites this search source builder into its primitive form. e.g. by * rewriting the QueryBuilder. If the builder did not change the identity @@ -1365,6 +1403,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th searchPipelineSource = parser.mapOrdered(); } else if (DERIVED_FIELDS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { derivedFieldsObject = parser.map(); + } else if (LABELS.match(currentFieldName, parser.getDeprecationHandler())) { + labels = parser.mapOrdered(); } else { throw new ParsingException( parser.getTokenLocation(), @@ -1597,6 +1637,9 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t if (searchPipelineSource != null) { builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipelineSource); } + if (labels != null) { + builder.field(LABELS.getPreferredName(), labels); + } if (derivedFieldsObject != null || derivedFields != null) { builder.startObject(DERIVED_FIELDS_FIELD.getPreferredName()); diff --git a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java new file mode 100644 index 0000000000000..942faebea2b12 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; +import org.opensearch.search.labels.rules.Rule; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Service to attach labels to a search request based on pre-defined rules + * + * In this POC, this service only handles search requests, but in theory it should be able to handle index as well. + */ +public class RuleBasedLabelingService { + private final List rules; + + public RuleBasedLabelingService(List rules) { + this.rules = rules; + // default rules + rules.add(new DefaultUserInfoLabelingRule()); + } + + public List getRules() { + return rules; + } + + public void addRule(Rule rule) { + this.rules.add(rule); + } + + /** + * Evaluate all rules and return labels + */ + public void applyAllRules(final ThreadContext threadContext, final SearchRequest searchRequest) { + Map labels = rules.stream() + .map(rule -> rule.evaluate(threadContext, searchRequest)) + .flatMap(m -> m.entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + searchRequest.source().addLabels(labels); + } +} diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java new file mode 100644 index 0000000000000..cb00fcba6accd --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels; + +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchRequestContext; +import org.opensearch.action.search.SearchRequestOperationsListener; +import org.opensearch.threadpool.ThreadPool; + +/** + * SearchRequestOperationsListener subscriber for labeling search requests + * + * @opensearch.internal + */ +public final class SearchRequestLabelingListener extends SearchRequestOperationsListener { + final private ThreadPool threadPool; + final private RuleBasedLabelingService ruleBasedLabelingService; + + public SearchRequestLabelingListener(final ThreadPool threadPool, final RuleBasedLabelingService ruleBasedLabelingService) { + this.threadPool = threadPool; + this.ruleBasedLabelingService = ruleBasedLabelingService; + } + + @Override + protected void onPhaseStart(SearchPhaseContext context) {} + + @Override + protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} + + @Override + protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {} + + @Override + public void onRequestStart(SearchRequestContext searchRequestContext) { + // add tags to search request + ruleBasedLabelingService.applyAllRules(threadPool.getThreadContext(), searchRequestContext.getRequest()); + } + + @Override + public void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} +} diff --git a/server/src/main/java/org/opensearch/search/labels/package-info.java b/server/src/main/java/org/opensearch/search/labels/package-info.java new file mode 100644 index 0000000000000..acb7b154cb3f2 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Search labeling service. */ +package org.opensearch.search.labels; diff --git a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java new file mode 100644 index 0000000000000..079f377439292 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels.rules; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.Strings; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Rules to get user info labels, specifically, the info is injected by the security plugin. + */ +public class DefaultUserInfoLabelingRule implements Rule { + /** + * Constant setting for user info header key that are injected during authentication + */ + private final String REQUEST_HEADER_USER_INFO = "_opendistro_security_user_info"; + /** + * Constant setting for remote address info header key that are injected during authentication + */ + private final String REQUEST_HEADER_REMOTE_ADDRESS = "_opendistro_security_remote_address"; + + public static final String REMOTE_ADDRESS = "remote_address"; + public static final String USER_NAME = "user_name"; + public static final String USER_BACKEND_ROLES = "user_backend_roles"; + public static final String USER_ROLES = "user_roles"; + public static final String USER_TENANT = "user_tenant"; + + /** + * @param threadContext + * @param searchRequest + * @return Map of User related info and the corresponding values + */ + @Override + public Map evaluate(ThreadContext threadContext, SearchRequest searchRequest) { + return getUserInfoFromThreadContext(threadContext); + } + + /** + * Get User info, specifically injected by the security plugin, from the thread context + * + * @param threadContext context of the thread + * @return Map of User related info and the corresponding values + */ + private Map getUserInfoFromThreadContext(ThreadContext threadContext) { + Map userInfoMap = new HashMap<>(); + if (threadContext == null) { + return userInfoMap; + } + Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO); + if (userInfoObj == null) { + return userInfoMap; + } + String userInfoStr = userInfoObj.toString(); + Object remoteAddressObj = threadContext.getTransient(REQUEST_HEADER_REMOTE_ADDRESS); + if (remoteAddressObj != null) { + userInfoMap.put(REMOTE_ADDRESS, remoteAddressObj.toString()); + } + + String[] userInfo = userInfoStr.split("\\|"); + if ((userInfo.length == 0) || (Strings.isNullOrEmpty(userInfo[0]))) { + return userInfoMap; + } + userInfoMap.put(USER_NAME, userInfo[0].trim()); + if ((userInfo.length > 1) && !Strings.isNullOrEmpty(userInfo[1])) { + userInfoMap.put(USER_BACKEND_ROLES, Arrays.asList(userInfo[1].split(","))); + } + if ((userInfo.length > 2) && !Strings.isNullOrEmpty(userInfo[2])) { + userInfoMap.put(USER_ROLES, Arrays.asList(userInfo[2].split(","))); + } + if ((userInfo.length > 3) && !Strings.isNullOrEmpty(userInfo[3])) { + userInfoMap.put(USER_TENANT, userInfo[3].trim()); + } + return userInfoMap; + } +} diff --git a/server/src/main/java/org/opensearch/search/labels/rules/Rule.java b/server/src/main/java/org/opensearch/search/labels/rules/Rule.java new file mode 100644 index 0000000000000..331ba92b1e70f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/rules/Rule.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels.rules; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.util.concurrent.ThreadContext; + +import java.util.Map; + +/** + * An interface to define a labeling rule + */ +public interface Rule { + /** + * Defines the rule to calculate labels from the context and request + * + * @return a Map of labels for POC + */ + public Map evaluate(final ThreadContext threadContext, final SearchRequest searchRequest); + +} diff --git a/server/src/main/java/org/opensearch/search/labels/rules/package-info.java b/server/src/main/java/org/opensearch/search/labels/rules/package-info.java new file mode 100644 index 0000000000000..8d16a48e3a57b --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/rules/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Search labeling rules. */ +package org.opensearch.search.labels.rules; diff --git a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java new file mode 100644 index 0000000000000..cbb91332760e4 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.Before; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +public class DefaultUserInfoLabelingRuleTests extends OpenSearchTestCase { + private DefaultUserInfoLabelingRule defaultUserInfoLabelingRule; + private ThreadContext threadContext; + private SearchRequest searchRequest; + + @Before + public void setUpVariables() { + defaultUserInfoLabelingRule = new DefaultUserInfoLabelingRule(); + threadContext = new ThreadContext(Settings.EMPTY); + searchRequest = new SearchRequest(); + } + + public void testGetUserInfoFromThreadContext() { + threadContext.putTransient("_opendistro_security_user_info", "user1|role1,role2|group1,group2|tenant1"); + threadContext.putTransient("_opendistro_security_remote_address", "127.0.0.1"); + Map expectedUserInfoMap = new HashMap<>(); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.REMOTE_ADDRESS, "127.0.0.1"); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_NAME, "user1"); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_BACKEND_ROLES, Arrays.asList("role1", "role2")); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_ROLES, Arrays.asList("group1", "group2")); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_TENANT, "tenant1"); + Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); + assertEquals(expectedUserInfoMap, actualUserInfoMap); + } + + public void testGetUserInfoFromThreadContext_EmptyUserInfo() { + Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); + assertTrue(actualUserInfoMap.isEmpty()); + } + + public void testGetUserInfoFromThreadContext_NullThreadContext() { + Map userInfoMap = defaultUserInfoLabelingRule.evaluate(null, searchRequest); + assertTrue(userInfoMap.isEmpty()); + } +} diff --git a/server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java b/server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java new file mode 100644 index 0000000000000..25f540e57b675 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; +import org.opensearch.search.labels.rules.Rule; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RuleBasedLabelingServiceTests extends OpenSearchTestCase { + private RuleBasedLabelingService ruleBasedLabelingService; + private ThreadContext threadContext; + private SearchRequest searchRequest; + private List rules; + + @Before + public void setUpVariables() { + rules = new ArrayList<>(); + ruleBasedLabelingService = new RuleBasedLabelingService(rules); + threadContext = new ThreadContext(Settings.EMPTY); + searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().addLabels(new HashMap<>())); + } + + public void testConstructorAddsDefaultRule() { + List rules = ruleBasedLabelingService.getRules(); + assertEquals(1, rules.size()); + assertEquals(DefaultUserInfoLabelingRule.class, rules.get(0).getClass()); + } + + public void testAddRule() { + Rule mockRule = mock(Rule.class); + ruleBasedLabelingService.addRule(mockRule); + List rules = ruleBasedLabelingService.getRules(); + assertEquals(2, rules.size()); + assertEquals(DefaultUserInfoLabelingRule.class, rules.get(0).getClass()); + assertEquals(mockRule, rules.get(1)); + } + + public void testApplyAllRules() { + Rule mockRule1 = mock(Rule.class); + Rule mockRule2 = mock(Rule.class); + Map labels1 = new HashMap<>(); + labels1.put("label1", "value1"); + Map labels2 = new HashMap<>(); + labels2.put("label2", "value2"); + when(mockRule1.evaluate(threadContext, searchRequest)).thenReturn(labels1); + when(mockRule2.evaluate(threadContext, searchRequest)).thenReturn(labels2); + ruleBasedLabelingService.addRule(mockRule1); + ruleBasedLabelingService.addRule(mockRule2); + ruleBasedLabelingService.applyAllRules(threadContext, searchRequest); + Map expectedLabels = new HashMap<>(); + expectedLabels.putAll(labels1); + expectedLabels.putAll(labels2); + assertEquals(expectedLabels, searchRequest.source().labels()); + } +} From 56246bf69472ae62fe3d0730a714fd97d90e01ab Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Tue, 4 Jun 2024 14:03:28 -0700 Subject: [PATCH 2/5] improve the overall logic and fix several bugs based on comments Signed-off-by: Chenyang Ji --- .../core/listener/QueryInsightsListener.java | 3 +-- .../insights/rules/model/Attribute.java | 8 ++------ .../SearchRequestOperationsListener.java | 6 +++--- .../search/builder/SearchSourceBuilder.java | 20 +++++++++---------- .../labels/RuleBasedLabelingService.java | 6 ++++++ .../labels/SearchRequestLabelingListener.java | 9 --------- .../rules/DefaultUserInfoLabelingRule.java | 12 +++++------ .../DefaultUserInfoLabelingRuleTests.java | 8 ++++++++ 8 files changed, 35 insertions(+), 37 deletions(-) diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 2f328c534a8f4..67a8ed275e0ba 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -21,7 +21,6 @@ import org.opensearch.plugin.insights.rules.model.Attribute; import org.opensearch.plugin.insights.rules.model.MetricType; import org.opensearch.plugin.insights.rules.model.SearchQueryRecord; -import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; import java.util.Collections; import java.util.HashMap; @@ -139,7 +138,7 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards()); attributes.put(Attribute.INDICES, request.indices()); attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap()); - attributes.put(Attribute.USER_NAME, request.source().labels().get(DefaultUserInfoLabelingRule.USER_NAME)); + attributes.put(Attribute.LABELS, request.source().labels()); SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); queryInsightsService.addRecord(record); } catch (Exception e) { diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java index dc000a49e5d36..7ee4883c54023 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java @@ -45,13 +45,9 @@ public enum Attribute { */ NODE_ID, /** - * User associated with this request + * Custom search request labels */ - USER_NAME, - /** - * Custom tenant tags - */ - CUSTOMIZED_TAG; + LABELS; /** * Read an Attribute from a StreamInput diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java index 53efade174502..b944572cef122 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java @@ -41,11 +41,11 @@ protected SearchRequestOperationsListener(final boolean enabled) { this.enabled = enabled; } - protected abstract void onPhaseStart(SearchPhaseContext context); + protected void onPhaseStart(SearchPhaseContext context) {}; - protected abstract void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext); + protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}; - protected abstract void onPhaseFailure(SearchPhaseContext context, Throwable cause); + protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {}; protected void onRequestStart(SearchRequestContext searchRequestContext) {} diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 1fd08089d60f0..ca6241787f057 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -226,7 +226,7 @@ public static HighlightBuilder highlight() { private Map searchPipelineSource = null; - private Map labels = new HashMap<>(); + private Map labels; /** * Constructs a new search source builder. @@ -1138,21 +1138,19 @@ public SearchSourceBuilder searchPipelineSource(Map searchPipeli * @return labels defined within the search request. */ public Map labels() { + if (this.labels == null) { + this.labels = new HashMap<>(); + } return labels; } - /** - * Define labels within this search request. - */ - public SearchSourceBuilder labels(Map labels) { - this.labels = labels; - return this; - } - /** * Add labels within this search request. */ public SearchSourceBuilder addLabels(Map labels) { + if (this.labels == null) { + this.labels = new HashMap<>(); + } this.labels.putAll(labels); return this; } @@ -1401,10 +1399,10 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th pointInTimeBuilder = PointInTimeBuilder.fromXContent(parser); } else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) { searchPipelineSource = parser.mapOrdered(); - } else if (DERIVED_FIELDS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - derivedFieldsObject = parser.map(); } else if (LABELS.match(currentFieldName, parser.getDeprecationHandler())) { labels = parser.mapOrdered(); + } else if (DERIVED_FIELDS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + derivedFieldsObject = parser.map(); } else { throw new ParsingException( parser.getTokenLocation(), diff --git a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java index 942faebea2b12..cee801560a72c 100644 --- a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java +++ b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java @@ -47,6 +47,12 @@ public void applyAllRules(final ThreadContext threadContext, final SearchRequest .map(rule -> rule.evaluate(threadContext, searchRequest)) .flatMap(m -> m.entrySet().stream()) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + // Handling potential spoofing by checking if any conflicts exist between user-supplied labels and the computed labels + for (String key : searchRequest.source().labels().keySet()) { + if (labels.containsKey(key)) { + throw new IllegalArgumentException("Unexpected label found: " + key); + } + } searchRequest.source().addLabels(labels); } } diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java index cb00fcba6accd..fa6ed0f04880c 100644 --- a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java +++ b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java @@ -27,15 +27,6 @@ public SearchRequestLabelingListener(final ThreadPool threadPool, final RuleBase this.ruleBasedLabelingService = ruleBasedLabelingService; } - @Override - protected void onPhaseStart(SearchPhaseContext context) {} - - @Override - protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} - - @Override - protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {} - @Override public void onRequestStart(SearchRequestContext searchRequestContext) { // add tags to search request diff --git a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java index 079f377439292..63fd95ea1d855 100644 --- a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java +++ b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java @@ -17,7 +17,7 @@ import java.util.Map; /** - * Rules to get user info labels, specifically, the info is injected by the security plugin. + * Rules to get user info labels, specifically, the info that is injected by the security plugin. */ public class DefaultUserInfoLabelingRule implements Rule { /** @@ -56,16 +56,16 @@ private Map getUserInfoFromThreadContext(ThreadContext threadCon if (threadContext == null) { return userInfoMap; } - Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO); - if (userInfoObj == null) { - return userInfoMap; - } - String userInfoStr = userInfoObj.toString(); Object remoteAddressObj = threadContext.getTransient(REQUEST_HEADER_REMOTE_ADDRESS); if (remoteAddressObj != null) { userInfoMap.put(REMOTE_ADDRESS, remoteAddressObj.toString()); } + Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO); + if (userInfoObj == null) { + return userInfoMap; + } + String userInfoStr = userInfoObj.toString(); String[] userInfo = userInfoStr.split("\\|"); if ((userInfo.length == 0) || (Strings.isNullOrEmpty(userInfo[0]))) { return userInfoMap; diff --git a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java index cbb91332760e4..dd220eae4f5a7 100644 --- a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java +++ b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java @@ -44,6 +44,14 @@ public void testGetUserInfoFromThreadContext() { assertEquals(expectedUserInfoMap, actualUserInfoMap); } + public void testGetPartialInfoFromThreadContext() { + threadContext.putTransient("_opendistro_security_remote_address", "127.0.0.1"); + Map expectedUserInfoMap = new HashMap<>(); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.REMOTE_ADDRESS, "127.0.0.1"); + Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); + assertEquals(expectedUserInfoMap, actualUserInfoMap); + } + public void testGetUserInfoFromThreadContext_EmptyUserInfo() { Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); assertTrue(actualUserInfoMap.isEmpty()); From fb6df675c75f454182d58765f2079fbf0e8402e9 Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Wed, 5 Jun 2024 14:09:57 -0700 Subject: [PATCH 3/5] remove label field from search source and use header for labeling Signed-off-by: Chenyang Ji --- CHANGELOG.md | 2 +- .../plugin/insights/QueryInsightsPlugin.java | 2 +- .../core/listener/QueryInsightsListener.java | 31 +++++- .../listener/QueryInsightsListenerTests.java | 18 +++- .../main/java/org/opensearch/node/Node.java | 9 +- .../search/builder/SearchSourceBuilder.java | 41 -------- .../search/labels/RequestLabelingService.java | 76 ++++++++++++++ .../labels/RuleBasedLabelingService.java | 58 ----------- .../labels/SearchRequestLabelingListener.java | 11 +-- .../rules/DefaultUserInfoLabelingRule.java | 85 ---------------- .../DefaultUserInfoLabelingRuleTests.java | 64 ------------ .../labels/RequestLabelingServiceTests.java | 99 +++++++++++++++++++ .../labels/RuleBasedLabelingServiceTests.java | 75 -------------- 13 files changed, 231 insertions(+), 340 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java delete mode 100644 server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java delete mode 100644 server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java delete mode 100644 server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java create mode 100644 server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java delete mode 100644 server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f215e01660c9..2dca88f7c2430 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add remote routing table for remote state publication with experimental feature flag ([#13304](https://github.com/opensearch-project/OpenSearch/pull/13304)) - [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027)) - [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982)) -- Support customized and rule-based labeling for search queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) +- Support rule-based labeling for search queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java index 22831c3e0f8ba..ee25da5d8b217 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java @@ -71,7 +71,7 @@ public Collection createComponents( ) { // create top n queries service final QueryInsightsService queryInsightsService = new QueryInsightsService(clusterService.getClusterSettings(), threadPool, client); - return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService)); + return List.of(queryInsightsService, new QueryInsightsListener(threadPool, clusterService, queryInsightsService)); } @Override diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 67a8ed275e0ba..0e6e768781970 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -16,11 +16,15 @@ import org.opensearch.action.search.SearchRequestOperationsListener; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugin.insights.core.service.QueryInsightsService; import org.opensearch.plugin.insights.rules.model.Attribute; import org.opensearch.plugin.insights.rules.model.MetricType; import org.opensearch.plugin.insights.rules.model.SearchQueryRecord; +import org.opensearch.search.labels.RequestLabelingService; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; import java.util.Collections; import java.util.HashMap; @@ -45,15 +49,21 @@ public final class QueryInsightsListener extends SearchRequestOperationsListener private static final Logger log = LogManager.getLogger(QueryInsightsListener.class); private final QueryInsightsService queryInsightsService; + private final ThreadPool threadPool; /** * Constructor for QueryInsightsListener * + * @param threadPool the OpenSearch internal threadPool * @param clusterService The Node's cluster service. * @param queryInsightsService The topQueriesByLatencyService associated with this listener */ @Inject - public QueryInsightsListener(final ClusterService clusterService, final QueryInsightsService queryInsightsService) { + public QueryInsightsListener( + final ThreadPool threadPool, + final ClusterService clusterService, + final QueryInsightsService queryInsightsService + ) { this.queryInsightsService = queryInsightsService; clusterService.getClusterSettings() .addSettingsUpdateConsumer(TOP_N_LATENCY_QUERIES_ENABLED, v -> this.setEnableTopQueries(MetricType.LATENCY, v)); @@ -74,6 +84,7 @@ public QueryInsightsListener(final ClusterService clusterService, final QueryIns .setTopNSize(clusterService.getClusterSettings().get(TOP_N_LATENCY_QUERIES_SIZE)); this.queryInsightsService.getTopQueriesService(MetricType.LATENCY) .setWindowSize(clusterService.getClusterSettings().get(TOP_N_LATENCY_QUERIES_WINDOW_SIZE)); + this.threadPool = threadPool; } /** @@ -138,7 +149,23 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards()); attributes.put(Attribute.INDICES, request.indices()); attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap()); - attributes.put(Attribute.LABELS, request.source().labels()); + + // Get internal computed and user provided labels + Map labels = new HashMap<>(); + // Retrieve user provided label if exists + ThreadContext threadContext = threadPool.getThreadContext(); + String userProvidedLabel = threadContext.getRequestHeadersOnly().get(Task.X_OPAQUE_ID); + if (userProvidedLabel != null) { + labels.put(Task.X_OPAQUE_ID, userProvidedLabel); + } + // Retrieve computed labels if exists + Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + if (computedLabels != null) { + labels.putAll(computedLabels); + } + attributes.put(Attribute.LABELS, labels); + + // construct SearchQueryRecord from attributes and measurements SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); queryInsightsService.addRecord(record); } catch (Exception e) { diff --git a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java index 328ed0cd2ed15..a4d4ca5736af0 100644 --- a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java +++ b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java @@ -13,8 +13,10 @@ import org.opensearch.action.search.SearchRequestContext; import org.opensearch.action.search.SearchType; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.plugin.insights.core.service.QueryInsightsService; import org.opensearch.plugin.insights.core.service.TopQueriesService; import org.opensearch.plugin.insights.rules.model.MetricType; @@ -22,11 +24,15 @@ import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.support.ValueType; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.labels.RequestLabelingService; +import org.opensearch.tasks.Task; import org.opensearch.test.ClusterServiceUtils; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import org.junit.Before; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -48,6 +54,7 @@ public class QueryInsightsListenerTests extends OpenSearchTestCase { private final SearchRequest searchRequest = mock(SearchRequest.class); private final QueryInsightsService queryInsightsService = mock(QueryInsightsService.class); private final TopQueriesService topQueriesService = mock(TopQueriesService.class); + private final ThreadPool threadPool = mock(ThreadPool.class); private ClusterService clusterService; @Before @@ -61,6 +68,11 @@ public void setup() { clusterService = ClusterServiceUtils.createClusterService(settings, clusterSettings, null); when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true); when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService); + + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "test"), new HashMap<>())); + threadContext.putTransient(RequestLabelingService.COMPUTED_LABELS, Map.of("a", "b")); + when(threadPool.getThreadContext()).thenReturn(threadContext); } public void testOnRequestEnd() throws InterruptedException { @@ -80,7 +92,7 @@ public void testOnRequestEnd() throws InterruptedException { int numberOfShards = 10; - QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService); + QueryInsightsListener queryInsightsListener = new QueryInsightsListener(threadPool, clusterService, queryInsightsService); when(searchRequest.getOrCreateAbsoluteStartMillis()).thenReturn(timestamp); when(searchRequest.searchType()).thenReturn(searchType); @@ -128,7 +140,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { CountDownLatch countDownLatch = new CountDownLatch(numRequests); for (int i = 0; i < numRequests; i++) { - searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService)); + searchListenersList.add(new QueryInsightsListener(threadPool, clusterService, queryInsightsService)); } for (int i = 0; i < numRequests; i++) { @@ -149,7 +161,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { public void testSetEnabled() { when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true); - QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService); + QueryInsightsListener queryInsightsListener = new QueryInsightsListener(threadPool, clusterService, queryInsightsService); queryInsightsListener.setEnableTopQueries(MetricType.LATENCY, true); assertTrue(queryInsightsListener.isEnabled()); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 1ceb3adb5314a..4d3f526e9a448 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -227,8 +227,9 @@ import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; import org.opensearch.search.fetch.FetchPhase; -import org.opensearch.search.labels.RuleBasedLabelingService; +import org.opensearch.search.labels.RequestLabelingService; import org.opensearch.search.labels.SearchRequestLabelingListener; +import org.opensearch.search.labels.rules.Rule; import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QueryPhase; import org.opensearch.snapshots.InternalSnapshotsInfoService; @@ -965,8 +966,10 @@ protected Node( pluginComponents.addAll(telemetryAwarePluginComponents); final SearchRequestLabelingListener searchRequestLabelingListener = new SearchRequestLabelingListener( - threadPool, - new RuleBasedLabelingService(new ArrayList<>()) + new RequestLabelingService( + threadPool, + pluginComponents.stream().filter(p -> p instanceof Rule).map(p -> (Rule) p).collect(toList()) + ) ); // register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index ca6241787f057..8a9704b04566f 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -79,7 +79,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -137,7 +136,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R public static final ParseField SLICE = new ParseField("slice"); public static final ParseField POINT_IN_TIME = new ParseField("pit"); public static final ParseField SEARCH_PIPELINE = new ParseField("search_pipeline"); - public static final ParseField LABELS = new ParseField("labels"); public static SearchSourceBuilder fromXContent(XContentParser parser) throws IOException { return fromXContent(parser, true); @@ -226,8 +224,6 @@ public static HighlightBuilder highlight() { private Map searchPipelineSource = null; - private Map labels; - /** * Constructs a new search source builder. */ @@ -290,11 +286,6 @@ public SearchSourceBuilder(StreamInput in) throws IOException { searchPipelineSource = in.readMap(); } } - if (in.getVersion().onOrAfter(Version.V_2_15_0)) { - if (in.readBoolean()) { - labels = in.readMap(); - } - } if (in.getVersion().onOrAfter(Version.V_2_13_0)) { includeNamedQueriesScore = in.readOptionalBoolean(); } @@ -371,12 +362,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeMap(searchPipelineSource); } } - if (out.getVersion().onOrAfter(Version.V_2_15_0)) { - out.writeBoolean(labels != null); - if (labels != null) { - out.writeMap(labels); - } - } if (out.getVersion().onOrAfter(Version.V_2_13_0)) { out.writeOptionalBoolean(includeNamedQueriesScore); } @@ -1134,27 +1119,6 @@ public SearchSourceBuilder searchPipelineSource(Map searchPipeli return this; } - /** - * @return labels defined within the search request. - */ - public Map labels() { - if (this.labels == null) { - this.labels = new HashMap<>(); - } - return labels; - } - - /** - * Add labels within this search request. - */ - public SearchSourceBuilder addLabels(Map labels) { - if (this.labels == null) { - this.labels = new HashMap<>(); - } - this.labels.putAll(labels); - return this; - } - /** * Rewrites this search source builder into its primitive form. e.g. by * rewriting the QueryBuilder. If the builder did not change the identity @@ -1399,8 +1363,6 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th pointInTimeBuilder = PointInTimeBuilder.fromXContent(parser); } else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) { searchPipelineSource = parser.mapOrdered(); - } else if (LABELS.match(currentFieldName, parser.getDeprecationHandler())) { - labels = parser.mapOrdered(); } else if (DERIVED_FIELDS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { derivedFieldsObject = parser.map(); } else { @@ -1635,9 +1597,6 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t if (searchPipelineSource != null) { builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipelineSource); } - if (labels != null) { - builder.field(LABELS.getPreferredName(), labels); - } if (derivedFieldsObject != null || derivedFields != null) { builder.startObject(DERIVED_FIELDS_FIELD.getPreferredName()); diff --git a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java new file mode 100644 index 0000000000000..6e0f9dfc14355 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.labels.rules.Rule; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Service to attach labels to a search request based on pre-defined rules + * It evaluate all available rules and generate labels into the thread context. + */ +public class RequestLabelingService { + /** + * Field name for computed labels + */ + public static final String COMPUTED_LABELS = "computed_labels"; + private final ThreadPool threadPool; + private final List rules; + + public RequestLabelingService(final ThreadPool threadPool, final List rules) { + this.threadPool = threadPool; + this.rules = rules; + } + + /** + * Get all the existing rules + * + * @return list of existing rules + */ + public List getRules() { + return rules; + } + + /** + * Add a labeling rule to the service + * + * @param rule {@link Rule} + */ + public void addRule(final Rule rule) { + this.rules.add(rule); + } + + /** + * Get the user provided tag from the X-Opaque-Id header + * + * @return user provided tag + */ + public String getUserProvidedTag() { + return threadPool.getThreadContext().getRequestHeadersOnly().getOrDefault(Task.X_OPAQUE_ID, null); + } + + /** + * Evaluate all labeling rules and store the computed rules into thread context + * + * @param searchRequest {@link SearchRequest} + */ + public void applyAllRules(final SearchRequest searchRequest) { + Map labels = rules.stream() + .map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest)) + .flatMap(m -> m.entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement)); + threadPool.getThreadContext().putTransient(COMPUTED_LABELS, labels); + } +} diff --git a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java deleted file mode 100644 index cee801560a72c..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; -import org.opensearch.search.labels.rules.Rule; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Service to attach labels to a search request based on pre-defined rules - * - * In this POC, this service only handles search requests, but in theory it should be able to handle index as well. - */ -public class RuleBasedLabelingService { - private final List rules; - - public RuleBasedLabelingService(List rules) { - this.rules = rules; - // default rules - rules.add(new DefaultUserInfoLabelingRule()); - } - - public List getRules() { - return rules; - } - - public void addRule(Rule rule) { - this.rules.add(rule); - } - - /** - * Evaluate all rules and return labels - */ - public void applyAllRules(final ThreadContext threadContext, final SearchRequest searchRequest) { - Map labels = rules.stream() - .map(rule -> rule.evaluate(threadContext, searchRequest)) - .flatMap(m -> m.entrySet().stream()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - // Handling potential spoofing by checking if any conflicts exist between user-supplied labels and the computed labels - for (String key : searchRequest.source().labels().keySet()) { - if (labels.containsKey(key)) { - throw new IllegalArgumentException("Unexpected label found: " + key); - } - } - searchRequest.source().addLabels(labels); - } -} diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java index fa6ed0f04880c..d672bb199404f 100644 --- a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java +++ b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java @@ -11,7 +11,6 @@ import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchRequestContext; import org.opensearch.action.search.SearchRequestOperationsListener; -import org.opensearch.threadpool.ThreadPool; /** * SearchRequestOperationsListener subscriber for labeling search requests @@ -19,18 +18,16 @@ * @opensearch.internal */ public final class SearchRequestLabelingListener extends SearchRequestOperationsListener { - final private ThreadPool threadPool; - final private RuleBasedLabelingService ruleBasedLabelingService; + final private RequestLabelingService requestLabelingService; - public SearchRequestLabelingListener(final ThreadPool threadPool, final RuleBasedLabelingService ruleBasedLabelingService) { - this.threadPool = threadPool; - this.ruleBasedLabelingService = ruleBasedLabelingService; + public SearchRequestLabelingListener(final RequestLabelingService requestLabelingService) { + this.requestLabelingService = requestLabelingService; } @Override public void onRequestStart(SearchRequestContext searchRequestContext) { // add tags to search request - ruleBasedLabelingService.applyAllRules(threadPool.getThreadContext(), searchRequestContext.getRequest()); + requestLabelingService.applyAllRules(searchRequestContext.getRequest()); } @Override diff --git a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java deleted file mode 100644 index 63fd95ea1d855..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels.rules; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.common.Strings; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -/** - * Rules to get user info labels, specifically, the info that is injected by the security plugin. - */ -public class DefaultUserInfoLabelingRule implements Rule { - /** - * Constant setting for user info header key that are injected during authentication - */ - private final String REQUEST_HEADER_USER_INFO = "_opendistro_security_user_info"; - /** - * Constant setting for remote address info header key that are injected during authentication - */ - private final String REQUEST_HEADER_REMOTE_ADDRESS = "_opendistro_security_remote_address"; - - public static final String REMOTE_ADDRESS = "remote_address"; - public static final String USER_NAME = "user_name"; - public static final String USER_BACKEND_ROLES = "user_backend_roles"; - public static final String USER_ROLES = "user_roles"; - public static final String USER_TENANT = "user_tenant"; - - /** - * @param threadContext - * @param searchRequest - * @return Map of User related info and the corresponding values - */ - @Override - public Map evaluate(ThreadContext threadContext, SearchRequest searchRequest) { - return getUserInfoFromThreadContext(threadContext); - } - - /** - * Get User info, specifically injected by the security plugin, from the thread context - * - * @param threadContext context of the thread - * @return Map of User related info and the corresponding values - */ - private Map getUserInfoFromThreadContext(ThreadContext threadContext) { - Map userInfoMap = new HashMap<>(); - if (threadContext == null) { - return userInfoMap; - } - Object remoteAddressObj = threadContext.getTransient(REQUEST_HEADER_REMOTE_ADDRESS); - if (remoteAddressObj != null) { - userInfoMap.put(REMOTE_ADDRESS, remoteAddressObj.toString()); - } - - Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO); - if (userInfoObj == null) { - return userInfoMap; - } - String userInfoStr = userInfoObj.toString(); - String[] userInfo = userInfoStr.split("\\|"); - if ((userInfo.length == 0) || (Strings.isNullOrEmpty(userInfo[0]))) { - return userInfoMap; - } - userInfoMap.put(USER_NAME, userInfo[0].trim()); - if ((userInfo.length > 1) && !Strings.isNullOrEmpty(userInfo[1])) { - userInfoMap.put(USER_BACKEND_ROLES, Arrays.asList(userInfo[1].split(","))); - } - if ((userInfo.length > 2) && !Strings.isNullOrEmpty(userInfo[2])) { - userInfoMap.put(USER_ROLES, Arrays.asList(userInfo[2].split(","))); - } - if ((userInfo.length > 3) && !Strings.isNullOrEmpty(userInfo[3])) { - userInfoMap.put(USER_TENANT, userInfo[3].trim()); - } - return userInfoMap; - } -} diff --git a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java deleted file mode 100644 index dd220eae4f5a7..0000000000000 --- a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; -import org.opensearch.test.OpenSearchTestCase; -import org.junit.Before; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -public class DefaultUserInfoLabelingRuleTests extends OpenSearchTestCase { - private DefaultUserInfoLabelingRule defaultUserInfoLabelingRule; - private ThreadContext threadContext; - private SearchRequest searchRequest; - - @Before - public void setUpVariables() { - defaultUserInfoLabelingRule = new DefaultUserInfoLabelingRule(); - threadContext = new ThreadContext(Settings.EMPTY); - searchRequest = new SearchRequest(); - } - - public void testGetUserInfoFromThreadContext() { - threadContext.putTransient("_opendistro_security_user_info", "user1|role1,role2|group1,group2|tenant1"); - threadContext.putTransient("_opendistro_security_remote_address", "127.0.0.1"); - Map expectedUserInfoMap = new HashMap<>(); - expectedUserInfoMap.put(DefaultUserInfoLabelingRule.REMOTE_ADDRESS, "127.0.0.1"); - expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_NAME, "user1"); - expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_BACKEND_ROLES, Arrays.asList("role1", "role2")); - expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_ROLES, Arrays.asList("group1", "group2")); - expectedUserInfoMap.put(DefaultUserInfoLabelingRule.USER_TENANT, "tenant1"); - Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); - assertEquals(expectedUserInfoMap, actualUserInfoMap); - } - - public void testGetPartialInfoFromThreadContext() { - threadContext.putTransient("_opendistro_security_remote_address", "127.0.0.1"); - Map expectedUserInfoMap = new HashMap<>(); - expectedUserInfoMap.put(DefaultUserInfoLabelingRule.REMOTE_ADDRESS, "127.0.0.1"); - Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); - assertEquals(expectedUserInfoMap, actualUserInfoMap); - } - - public void testGetUserInfoFromThreadContext_EmptyUserInfo() { - Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); - assertTrue(actualUserInfoMap.isEmpty()); - } - - public void testGetUserInfoFromThreadContext_NullThreadContext() { - Map userInfoMap = defaultUserInfoLabelingRule.evaluate(null, searchRequest); - assertTrue(userInfoMap.isEmpty()); - } -} diff --git a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java new file mode 100644 index 0000000000000..fe7f899d9c45e --- /dev/null +++ b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.labels; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.search.labels.rules.Rule; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RequestLabelingServiceTests extends OpenSearchTestCase { + private RequestLabelingService requestLabelingService; + private ThreadContext threadContext; + private final ThreadPool threadPool = mock(ThreadPool.class); + private final Rule mockRule1 = mock(Rule.class); + private final Rule mockRule2 = mock(Rule.class); + private final List rules = new ArrayList<>(); + + @Before + public void setUpVariables() { + requestLabelingService = new RequestLabelingService(threadPool, rules); + threadContext = new ThreadContext(Settings.EMPTY); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testAddRule() { + Rule mockRule = mock(Rule.class); + requestLabelingService.addRule(mockRule); + List rules = requestLabelingService.getRules(); + assertEquals(1, rules.size()); + assertEquals(mockRule, rules.get(0)); + } + + public void testGetUserProvidedTag() { + String expectedTag = "test-tag"; + threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, expectedTag), new HashMap<>())); + String actualTag = requestLabelingService.getUserProvidedTag(); + assertEquals(expectedTag, actualTag); + } + + public void testBasicApplyAllRules() { + SearchRequest mockSearchRequest = mock(SearchRequest.class); + Map mockLabelMap = Collections.singletonMap("label1", "value1"); + when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap); + rules.add(mockRule1); + requestLabelingService.applyAllRules(mockSearchRequest); + Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + assertEquals(1, computedLabels.size()); + assertEquals("value1", computedLabels.get("label1")); + } + + public void testApplyAllRulesWithConflict() { + SearchRequest mockSearchRequest = mock(SearchRequest.class); + Map mockLabelMap1 = Collections.singletonMap("conflictingLabel", "value1"); + Map mockLabelMap2 = Collections.singletonMap("conflictingLabel", "value2"); + when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap1); + when(mockRule2.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap2); + rules.add(mockRule1); + rules.add(mockRule2); + requestLabelingService.applyAllRules(mockSearchRequest); + Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + assertEquals(1, computedLabels.size()); + assertEquals("value2", computedLabels.get("conflictingLabel")); + } + + public void testApplyAllRulesWithoutConflict() { + SearchRequest mockSearchRequest = mock(SearchRequest.class); + Map mockLabelMap1 = Collections.singletonMap("label1", "value1"); + Map mockLabelMap2 = Collections.singletonMap("label2", "value2"); + when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap1); + when(mockRule2.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap2); + rules.add(mockRule1); + rules.add(mockRule2); + requestLabelingService.applyAllRules(mockSearchRequest); + Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + assertEquals(2, computedLabels.size()); + assertEquals("value1", computedLabels.get("label1")); + assertEquals("value2", computedLabels.get("label2")); + } +} diff --git a/server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java b/server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java deleted file mode 100644 index 25f540e57b675..0000000000000 --- a/server/src/test/java/org/opensearch/search/labels/RuleBasedLabelingServiceTests.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.labels.rules.DefaultUserInfoLabelingRule; -import org.opensearch.search.labels.rules.Rule; -import org.opensearch.test.OpenSearchTestCase; -import org.junit.Before; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class RuleBasedLabelingServiceTests extends OpenSearchTestCase { - private RuleBasedLabelingService ruleBasedLabelingService; - private ThreadContext threadContext; - private SearchRequest searchRequest; - private List rules; - - @Before - public void setUpVariables() { - rules = new ArrayList<>(); - ruleBasedLabelingService = new RuleBasedLabelingService(rules); - threadContext = new ThreadContext(Settings.EMPTY); - searchRequest = new SearchRequest(); - searchRequest.source(new SearchSourceBuilder().addLabels(new HashMap<>())); - } - - public void testConstructorAddsDefaultRule() { - List rules = ruleBasedLabelingService.getRules(); - assertEquals(1, rules.size()); - assertEquals(DefaultUserInfoLabelingRule.class, rules.get(0).getClass()); - } - - public void testAddRule() { - Rule mockRule = mock(Rule.class); - ruleBasedLabelingService.addRule(mockRule); - List rules = ruleBasedLabelingService.getRules(); - assertEquals(2, rules.size()); - assertEquals(DefaultUserInfoLabelingRule.class, rules.get(0).getClass()); - assertEquals(mockRule, rules.get(1)); - } - - public void testApplyAllRules() { - Rule mockRule1 = mock(Rule.class); - Rule mockRule2 = mock(Rule.class); - Map labels1 = new HashMap<>(); - labels1.put("label1", "value1"); - Map labels2 = new HashMap<>(); - labels2.put("label2", "value2"); - when(mockRule1.evaluate(threadContext, searchRequest)).thenReturn(labels1); - when(mockRule2.evaluate(threadContext, searchRequest)).thenReturn(labels2); - ruleBasedLabelingService.addRule(mockRule1); - ruleBasedLabelingService.addRule(mockRule2); - ruleBasedLabelingService.applyAllRules(threadContext, searchRequest); - Map expectedLabels = new HashMap<>(); - expectedLabels.putAll(labels1); - expectedLabels.putAll(labels2); - assertEquals(expectedLabels, searchRequest.source().labels()); - } -} From 3b6fd30dbb70452368cee3501e6889b4d9cc347a Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Wed, 5 Jun 2024 19:01:23 -0700 Subject: [PATCH 4/5] refactor code based on comments Signed-off-by: Chenyang Ji --- .../core/listener/QueryInsightsListener.java | 7 +-- .../listener/QueryInsightsListenerTests.java | 21 +++++++-- .../search/labels/RequestLabelingService.java | 45 ++++++++----------- .../labels/SearchRequestLabelingListener.java | 4 -- .../labels/RequestLabelingServiceTests.java | 16 ++----- 5 files changed, 43 insertions(+), 50 deletions(-) diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 0e6e768781970..263c4d3a6f78d 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -16,7 +16,6 @@ import org.opensearch.action.search.SearchRequestOperationsListener; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugin.insights.core.service.QueryInsightsService; import org.opensearch.plugin.insights.rules.model.Attribute; @@ -153,18 +152,16 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo // Get internal computed and user provided labels Map labels = new HashMap<>(); // Retrieve user provided label if exists - ThreadContext threadContext = threadPool.getThreadContext(); - String userProvidedLabel = threadContext.getRequestHeadersOnly().get(Task.X_OPAQUE_ID); + String userProvidedLabel = RequestLabelingService.getUserProvidedTag(threadPool); if (userProvidedLabel != null) { labels.put(Task.X_OPAQUE_ID, userProvidedLabel); } // Retrieve computed labels if exists - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = RequestLabelingService.getRuleBasedLabels(threadPool); if (computedLabels != null) { labels.putAll(computedLabels); } attributes.put(Attribute.LABELS, labels); - // construct SearchQueryRecord from attributes and measurements SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); queryInsightsService.addRecord(record); diff --git a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java index a4d4ca5736af0..d944ed46778f6 100644 --- a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java +++ b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java @@ -19,7 +19,9 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.plugin.insights.core.service.QueryInsightsService; import org.opensearch.plugin.insights.core.service.TopQueriesService; +import org.opensearch.plugin.insights.rules.model.Attribute; import org.opensearch.plugin.insights.rules.model.MetricType; +import org.opensearch.plugin.insights.rules.model.SearchQueryRecord; import org.opensearch.plugin.insights.settings.QueryInsightsSettings; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.support.ValueType; @@ -35,10 +37,13 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; +import org.mockito.ArgumentCaptor; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -70,11 +75,12 @@ public void setup() { when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "test"), new HashMap<>())); - threadContext.putTransient(RequestLabelingService.COMPUTED_LABELS, Map.of("a", "b")); + threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel"), new HashMap<>())); + threadContext.putTransient(RequestLabelingService.RULE_BASED_LABELS, Map.of("labelKey", "labelValue")); when(threadPool.getThreadContext()).thenReturn(threadContext); } + @SuppressWarnings("unchecked") public void testOnRequestEnd() throws InterruptedException { Long timestamp = System.currentTimeMillis() - 100L; SearchType searchType = SearchType.QUERY_THEN_FETCH; @@ -101,10 +107,19 @@ public void testOnRequestEnd() throws InterruptedException { when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchQueryRecord.class); queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext); - verify(queryInsightsService, times(1)).addRecord(any()); + verify(queryInsightsService, times(1)).addRecord(captor.capture()); + SearchQueryRecord generatedRecord = captor.getValue(); + assertEquals(timestamp.longValue(), generatedRecord.getTimestamp()); + assertEquals(numberOfShards, generatedRecord.getAttributes().get(Attribute.TOTAL_SHARDS)); + assertEquals(searchType.toString().toLowerCase(Locale.ROOT), generatedRecord.getAttributes().get(Attribute.SEARCH_TYPE)); + assertEquals(searchSourceBuilder.toString(), generatedRecord.getAttributes().get(Attribute.SOURCE)); + Map labels = (Map) generatedRecord.getAttributes().get(Attribute.LABELS); + assertEquals("labelValue", labels.get("labelKey")); + assertEquals("userLabel", labels.get(Task.X_OPAQUE_ID)); } public void testConcurrentOnRequestEnd() throws InterruptedException { diff --git a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java index 6e0f9dfc14355..8a37a322428d5 100644 --- a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java +++ b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java @@ -14,6 +14,7 @@ import org.opensearch.threadpool.ThreadPool; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; @@ -25,7 +26,7 @@ public class RequestLabelingService { /** * Field name for computed labels */ - public static final String COMPUTED_LABELS = "computed_labels"; + public static final String RULE_BASED_LABELS = "rule_based_labels"; private final ThreadPool threadPool; private final List rules; @@ -35,21 +36,22 @@ public RequestLabelingService(final ThreadPool threadPool, final List rule } /** - * Get all the existing rules - * - * @return list of existing rules - */ - public List getRules() { - return rules; - } - - /** - * Add a labeling rule to the service + * Evaluate all labeling rules and store the computed rules into thread context * - * @param rule {@link Rule} + * @param searchRequest {@link SearchRequest} */ - public void addRule(final Rule rule) { - this.rules.add(rule); + public void applyAllRules(final SearchRequest searchRequest) { + Map labels = rules.stream() + .map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest)) + .flatMap(m -> m.entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement)); + String userProvidedTag = getUserProvidedTag(threadPool); + if (labels.containsKey(Task.X_OPAQUE_ID) && userProvidedTag.equals(labels.get(Task.X_OPAQUE_ID))) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unexpected label %s found: %s", Task.X_OPAQUE_ID, userProvidedTag) + ); + } + threadPool.getThreadContext().putTransient(RULE_BASED_LABELS, labels); } /** @@ -57,20 +59,11 @@ public void addRule(final Rule rule) { * * @return user provided tag */ - public String getUserProvidedTag() { + public static String getUserProvidedTag(ThreadPool threadPool) { return threadPool.getThreadContext().getRequestHeadersOnly().getOrDefault(Task.X_OPAQUE_ID, null); } - /** - * Evaluate all labeling rules and store the computed rules into thread context - * - * @param searchRequest {@link SearchRequest} - */ - public void applyAllRules(final SearchRequest searchRequest) { - Map labels = rules.stream() - .map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest)) - .flatMap(m -> m.entrySet().stream()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement)); - threadPool.getThreadContext().putTransient(COMPUTED_LABELS, labels); + public static Map getRuleBasedLabels(ThreadPool threadPool) { + return threadPool.getThreadContext().getTransient(RequestLabelingService.RULE_BASED_LABELS); } } diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java index d672bb199404f..2c191aa491b32 100644 --- a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java +++ b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java @@ -8,7 +8,6 @@ package org.opensearch.search.labels; -import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchRequestContext; import org.opensearch.action.search.SearchRequestOperationsListener; @@ -29,7 +28,4 @@ public void onRequestStart(SearchRequestContext searchRequestContext) { // add tags to search request requestLabelingService.applyAllRules(searchRequestContext.getRequest()); } - - @Override - public void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} } diff --git a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java index fe7f899d9c45e..2225002a3e6db 100644 --- a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java +++ b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java @@ -42,18 +42,10 @@ public void setUpVariables() { when(threadPool.getThreadContext()).thenReturn(threadContext); } - public void testAddRule() { - Rule mockRule = mock(Rule.class); - requestLabelingService.addRule(mockRule); - List rules = requestLabelingService.getRules(); - assertEquals(1, rules.size()); - assertEquals(mockRule, rules.get(0)); - } - public void testGetUserProvidedTag() { String expectedTag = "test-tag"; threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, expectedTag), new HashMap<>())); - String actualTag = requestLabelingService.getUserProvidedTag(); + String actualTag = RequestLabelingService.getUserProvidedTag(threadPool); assertEquals(expectedTag, actualTag); } @@ -63,7 +55,7 @@ public void testBasicApplyAllRules() { when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap); rules.add(mockRule1); requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); assertEquals(1, computedLabels.size()); assertEquals("value1", computedLabels.get("label1")); } @@ -77,7 +69,7 @@ public void testApplyAllRulesWithConflict() { rules.add(mockRule1); rules.add(mockRule2); requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); assertEquals(1, computedLabels.size()); assertEquals("value2", computedLabels.get("conflictingLabel")); } @@ -91,7 +83,7 @@ public void testApplyAllRulesWithoutConflict() { rules.add(mockRule1); rules.add(mockRule2); requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); assertEquals(2, computedLabels.size()); assertEquals("value1", computedLabels.get("label1")); assertEquals("value2", computedLabels.get("label2")); From 9e5b6214fbc480749c6fb85f5483169e6cc2e8ec Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Thu, 6 Jun 2024 12:24:05 -0700 Subject: [PATCH 5/5] p0 feature to support labeling in top queries Signed-off-by: Chenyang Ji --- CHANGELOG.md | 2 +- .../plugin/insights/QueryInsightsPlugin.java | 2 +- .../core/listener/QueryInsightsListener.java | 19 +--- .../listener/QueryInsightsListenerTests.java | 14 +-- .../main/java/org/opensearch/node/Node.java | 11 +-- .../search/labels/RequestLabelingService.java | 69 -------------- .../labels/SearchRequestLabelingListener.java | 31 ------- .../search/labels/package-info.java | 10 -- .../opensearch/search/labels/rules/Rule.java | 27 ------ .../search/labels/rules/package-info.java | 10 -- .../labels/RequestLabelingServiceTests.java | 91 ------------------- 11 files changed, 13 insertions(+), 273 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java delete mode 100644 server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java delete mode 100644 server/src/main/java/org/opensearch/search/labels/package-info.java delete mode 100644 server/src/main/java/org/opensearch/search/labels/rules/Rule.java delete mode 100644 server/src/main/java/org/opensearch/search/labels/rules/package-info.java delete mode 100644 server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 2dca88f7c2430..db0e26375cbfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add remote routing table for remote state publication with experimental feature flag ([#13304](https://github.com/opensearch-project/OpenSearch/pull/13304)) - [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027)) - [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982)) -- Support rule-based labeling for search queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) +- [Query Insights] Add X-Opaque-Id to search request metadata for top n queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java index ee25da5d8b217..22831c3e0f8ba 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/QueryInsightsPlugin.java @@ -71,7 +71,7 @@ public Collection createComponents( ) { // create top n queries service final QueryInsightsService queryInsightsService = new QueryInsightsService(clusterService.getClusterSettings(), threadPool, client); - return List.of(queryInsightsService, new QueryInsightsListener(threadPool, clusterService, queryInsightsService)); + return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService)); } @Override diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 263c4d3a6f78d..cad2fe374f1b6 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -21,9 +21,7 @@ import org.opensearch.plugin.insights.rules.model.Attribute; import org.opensearch.plugin.insights.rules.model.MetricType; import org.opensearch.plugin.insights.rules.model.SearchQueryRecord; -import org.opensearch.search.labels.RequestLabelingService; import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; import java.util.Collections; import java.util.HashMap; @@ -48,21 +46,15 @@ public final class QueryInsightsListener extends SearchRequestOperationsListener private static final Logger log = LogManager.getLogger(QueryInsightsListener.class); private final QueryInsightsService queryInsightsService; - private final ThreadPool threadPool; /** * Constructor for QueryInsightsListener * - * @param threadPool the OpenSearch internal threadPool * @param clusterService The Node's cluster service. * @param queryInsightsService The topQueriesByLatencyService associated with this listener */ @Inject - public QueryInsightsListener( - final ThreadPool threadPool, - final ClusterService clusterService, - final QueryInsightsService queryInsightsService - ) { + public QueryInsightsListener(final ClusterService clusterService, final QueryInsightsService queryInsightsService) { this.queryInsightsService = queryInsightsService; clusterService.getClusterSettings() .addSettingsUpdateConsumer(TOP_N_LATENCY_QUERIES_ENABLED, v -> this.setEnableTopQueries(MetricType.LATENCY, v)); @@ -83,7 +75,6 @@ public QueryInsightsListener( .setTopNSize(clusterService.getClusterSettings().get(TOP_N_LATENCY_QUERIES_SIZE)); this.queryInsightsService.getTopQueriesService(MetricType.LATENCY) .setWindowSize(clusterService.getClusterSettings().get(TOP_N_LATENCY_QUERIES_WINDOW_SIZE)); - this.threadPool = threadPool; } /** @@ -149,18 +140,12 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo attributes.put(Attribute.INDICES, request.indices()); attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap()); - // Get internal computed and user provided labels Map labels = new HashMap<>(); // Retrieve user provided label if exists - String userProvidedLabel = RequestLabelingService.getUserProvidedTag(threadPool); + String userProvidedLabel = context.getTask().getHeader(Task.X_OPAQUE_ID); if (userProvidedLabel != null) { labels.put(Task.X_OPAQUE_ID, userProvidedLabel); } - // Retrieve computed labels if exists - Map computedLabels = RequestLabelingService.getRuleBasedLabels(threadPool); - if (computedLabels != null) { - labels.putAll(computedLabels); - } attributes.put(Attribute.LABELS, labels); // construct SearchQueryRecord from attributes and measurements SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); diff --git a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java index d944ed46778f6..b794a2e4b8608 100644 --- a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java +++ b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java @@ -11,6 +11,7 @@ import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchRequestContext; +import org.opensearch.action.search.SearchTask; import org.opensearch.action.search.SearchType; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; @@ -26,7 +27,6 @@ import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.support.ValueType; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.labels.RequestLabelingService; import org.opensearch.tasks.Task; import org.opensearch.test.ClusterServiceUtils; import org.opensearch.test.OpenSearchTestCase; @@ -76,7 +76,6 @@ public void setup() { ThreadContext threadContext = new ThreadContext(Settings.EMPTY); threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel"), new HashMap<>())); - threadContext.putTransient(RequestLabelingService.RULE_BASED_LABELS, Map.of("labelKey", "labelValue")); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -88,6 +87,7 @@ public void testOnRequestEnd() throws InterruptedException { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.aggregation(new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING).field("type.keyword")); searchSourceBuilder.size(0); + SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel")); String[] indices = new String[] { "index-1", "index-2" }; @@ -98,7 +98,7 @@ public void testOnRequestEnd() throws InterruptedException { int numberOfShards = 10; - QueryInsightsListener queryInsightsListener = new QueryInsightsListener(threadPool, clusterService, queryInsightsService); + QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService); when(searchRequest.getOrCreateAbsoluteStartMillis()).thenReturn(timestamp); when(searchRequest.searchType()).thenReturn(searchType); @@ -107,6 +107,7 @@ public void testOnRequestEnd() throws InterruptedException { when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); + when(searchPhaseContext.getTask()).thenReturn(task); ArgumentCaptor captor = ArgumentCaptor.forClass(SearchQueryRecord.class); queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext); @@ -118,7 +119,6 @@ public void testOnRequestEnd() throws InterruptedException { assertEquals(searchType.toString().toLowerCase(Locale.ROOT), generatedRecord.getAttributes().get(Attribute.SEARCH_TYPE)); assertEquals(searchSourceBuilder.toString(), generatedRecord.getAttributes().get(Attribute.SOURCE)); Map labels = (Map) generatedRecord.getAttributes().get(Attribute.LABELS); - assertEquals("labelValue", labels.get("labelKey")); assertEquals("userLabel", labels.get(Task.X_OPAQUE_ID)); } @@ -129,6 +129,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.aggregation(new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING).field("type.keyword")); searchSourceBuilder.size(0); + SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel")); String[] indices = new String[] { "index-1", "index-2" }; @@ -148,6 +149,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); + when(searchPhaseContext.getTask()).thenReturn(task); int numRequests = 50; Thread[] threads = new Thread[numRequests]; @@ -155,7 +157,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { CountDownLatch countDownLatch = new CountDownLatch(numRequests); for (int i = 0; i < numRequests; i++) { - searchListenersList.add(new QueryInsightsListener(threadPool, clusterService, queryInsightsService)); + searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService)); } for (int i = 0; i < numRequests; i++) { @@ -176,7 +178,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { public void testSetEnabled() { when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true); - QueryInsightsListener queryInsightsListener = new QueryInsightsListener(threadPool, clusterService, queryInsightsService); + QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService); queryInsightsListener.setEnableTopQueries(MetricType.LATENCY, true); assertTrue(queryInsightsListener.isEnabled()); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 4d3f526e9a448..cb1f2caa082fc 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -227,9 +227,6 @@ import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; import org.opensearch.search.fetch.FetchPhase; -import org.opensearch.search.labels.RequestLabelingService; -import org.opensearch.search.labels.SearchRequestLabelingListener; -import org.opensearch.search.labels.rules.Rule; import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QueryPhase; import org.opensearch.snapshots.InternalSnapshotsInfoService; @@ -965,17 +962,11 @@ protected Node( // Add the telemetryAwarePlugin components to the existing pluginComponents collection. pluginComponents.addAll(telemetryAwarePluginComponents); - final SearchRequestLabelingListener searchRequestLabelingListener = new SearchRequestLabelingListener( - new RequestLabelingService( - threadPool, - pluginComponents.stream().filter(p -> p instanceof Rule).map(p -> (Rule) p).collect(toList()) - ) - ); // register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = new SearchRequestOperationsCompositeListenerFactory( Stream.concat( - Stream.of(searchRequestStats, searchRequestSlowLog, searchRequestLabelingListener), + Stream.of(searchRequestStats, searchRequestSlowLog), pluginComponents.stream() .filter(p -> p instanceof SearchRequestOperationsListener) .map(p -> (SearchRequestOperationsListener) p) diff --git a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java deleted file mode 100644 index 8a37a322428d5..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.search.labels.rules.Rule; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; - -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Service to attach labels to a search request based on pre-defined rules - * It evaluate all available rules and generate labels into the thread context. - */ -public class RequestLabelingService { - /** - * Field name for computed labels - */ - public static final String RULE_BASED_LABELS = "rule_based_labels"; - private final ThreadPool threadPool; - private final List rules; - - public RequestLabelingService(final ThreadPool threadPool, final List rules) { - this.threadPool = threadPool; - this.rules = rules; - } - - /** - * Evaluate all labeling rules and store the computed rules into thread context - * - * @param searchRequest {@link SearchRequest} - */ - public void applyAllRules(final SearchRequest searchRequest) { - Map labels = rules.stream() - .map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest)) - .flatMap(m -> m.entrySet().stream()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement)); - String userProvidedTag = getUserProvidedTag(threadPool); - if (labels.containsKey(Task.X_OPAQUE_ID) && userProvidedTag.equals(labels.get(Task.X_OPAQUE_ID))) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Unexpected label %s found: %s", Task.X_OPAQUE_ID, userProvidedTag) - ); - } - threadPool.getThreadContext().putTransient(RULE_BASED_LABELS, labels); - } - - /** - * Get the user provided tag from the X-Opaque-Id header - * - * @return user provided tag - */ - public static String getUserProvidedTag(ThreadPool threadPool) { - return threadPool.getThreadContext().getRequestHeadersOnly().getOrDefault(Task.X_OPAQUE_ID, null); - } - - public static Map getRuleBasedLabels(ThreadPool threadPool) { - return threadPool.getThreadContext().getTransient(RequestLabelingService.RULE_BASED_LABELS); - } -} diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java deleted file mode 100644 index 2c191aa491b32..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels; - -import org.opensearch.action.search.SearchRequestContext; -import org.opensearch.action.search.SearchRequestOperationsListener; - -/** - * SearchRequestOperationsListener subscriber for labeling search requests - * - * @opensearch.internal - */ -public final class SearchRequestLabelingListener extends SearchRequestOperationsListener { - final private RequestLabelingService requestLabelingService; - - public SearchRequestLabelingListener(final RequestLabelingService requestLabelingService) { - this.requestLabelingService = requestLabelingService; - } - - @Override - public void onRequestStart(SearchRequestContext searchRequestContext) { - // add tags to search request - requestLabelingService.applyAllRules(searchRequestContext.getRequest()); - } -} diff --git a/server/src/main/java/org/opensearch/search/labels/package-info.java b/server/src/main/java/org/opensearch/search/labels/package-info.java deleted file mode 100644 index acb7b154cb3f2..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/package-info.java +++ /dev/null @@ -1,10 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/** Search labeling service. */ -package org.opensearch.search.labels; diff --git a/server/src/main/java/org/opensearch/search/labels/rules/Rule.java b/server/src/main/java/org/opensearch/search/labels/rules/Rule.java deleted file mode 100644 index 331ba92b1e70f..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/rules/Rule.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels.rules; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.util.concurrent.ThreadContext; - -import java.util.Map; - -/** - * An interface to define a labeling rule - */ -public interface Rule { - /** - * Defines the rule to calculate labels from the context and request - * - * @return a Map of labels for POC - */ - public Map evaluate(final ThreadContext threadContext, final SearchRequest searchRequest); - -} diff --git a/server/src/main/java/org/opensearch/search/labels/rules/package-info.java b/server/src/main/java/org/opensearch/search/labels/rules/package-info.java deleted file mode 100644 index 8d16a48e3a57b..0000000000000 --- a/server/src/main/java/org/opensearch/search/labels/rules/package-info.java +++ /dev/null @@ -1,10 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/** Search labeling rules. */ -package org.opensearch.search.labels.rules; diff --git a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java deleted file mode 100644 index 2225002a3e6db..0000000000000 --- a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.labels; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.collect.Tuple; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.search.labels.rules.Rule; -import org.opensearch.tasks.Task; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; -import org.junit.Before; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class RequestLabelingServiceTests extends OpenSearchTestCase { - private RequestLabelingService requestLabelingService; - private ThreadContext threadContext; - private final ThreadPool threadPool = mock(ThreadPool.class); - private final Rule mockRule1 = mock(Rule.class); - private final Rule mockRule2 = mock(Rule.class); - private final List rules = new ArrayList<>(); - - @Before - public void setUpVariables() { - requestLabelingService = new RequestLabelingService(threadPool, rules); - threadContext = new ThreadContext(Settings.EMPTY); - when(threadPool.getThreadContext()).thenReturn(threadContext); - } - - public void testGetUserProvidedTag() { - String expectedTag = "test-tag"; - threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, expectedTag), new HashMap<>())); - String actualTag = RequestLabelingService.getUserProvidedTag(threadPool); - assertEquals(expectedTag, actualTag); - } - - public void testBasicApplyAllRules() { - SearchRequest mockSearchRequest = mock(SearchRequest.class); - Map mockLabelMap = Collections.singletonMap("label1", "value1"); - when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap); - rules.add(mockRule1); - requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); - assertEquals(1, computedLabels.size()); - assertEquals("value1", computedLabels.get("label1")); - } - - public void testApplyAllRulesWithConflict() { - SearchRequest mockSearchRequest = mock(SearchRequest.class); - Map mockLabelMap1 = Collections.singletonMap("conflictingLabel", "value1"); - Map mockLabelMap2 = Collections.singletonMap("conflictingLabel", "value2"); - when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap1); - when(mockRule2.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap2); - rules.add(mockRule1); - rules.add(mockRule2); - requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); - assertEquals(1, computedLabels.size()); - assertEquals("value2", computedLabels.get("conflictingLabel")); - } - - public void testApplyAllRulesWithoutConflict() { - SearchRequest mockSearchRequest = mock(SearchRequest.class); - Map mockLabelMap1 = Collections.singletonMap("label1", "value1"); - Map mockLabelMap2 = Collections.singletonMap("label2", "value2"); - when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap1); - when(mockRule2.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap2); - rules.add(mockRule1); - rules.add(mockRule2); - requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); - assertEquals(2, computedLabels.size()); - assertEquals("value1", computedLabels.get("label1")); - assertEquals("value2", computedLabels.get("label2")); - } -}