Skip to content

Commit

Permalink
Add user info in top queries
Browse files Browse the repository at this point in the history
Signed-off-by: Chenyang Ji <[email protected]>
  • Loading branch information
ansjcy committed Mar 4, 2024
1 parent bd8af6a commit bfdefdd
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 10 deletions.
19 changes: 19 additions & 0 deletions plugins/query-insights/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,24 @@ opensearchplugin {
classname 'org.opensearch.plugin.insights.QueryInsightsPlugin'
}

dependencyLicenses.enabled = false

ext {
opensearch_version = versions.opensearch
// 3.0.0-SNAPSHOT -> 3.0.0.0
version_tokens = opensearch_version.tokenize('-')
common_utils_version = version_tokens[0] + '.0'
if (version_tokens.size() > 1) {
common_utils_version = common_utils_version + "-" + version_tokens[1]
}
}

repositories {
mavenLocal()
mavenCentral()
maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" }
}

dependencies {
implementation "org.opensearch:common-utils:${common_utils_version}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public Collection<Object> createComponents(
) {
// create top n queries service
final QueryInsightsService queryInsightsService = new QueryInsightsService(threadPool);
return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService));
return List.of(queryInsightsService, new QueryInsightsListener(clusterService, queryInsightsService, threadPool));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
import org.opensearch.action.search.SearchRequestOperationsListener;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
import org.opensearch.plugin.insights.rules.model.Attribute;
import org.opensearch.plugin.insights.rules.model.MetricType;
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
import org.opensearch.plugin.insights.settings.QueryInsightsSettings;
import org.opensearch.threadpool.ThreadPool;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -45,16 +48,23 @@ public final class QueryInsightsListener extends SearchRequestOperationsListener
private static final Logger log = LogManager.getLogger(QueryInsightsListener.class);

private final QueryInsightsService queryInsightsService;
private final ThreadPool threadPool;

/**
* Constructor for QueryInsightsListener
*
* @param clusterService The Node's cluster service.
* @param queryInsightsService The topQueriesByLatencyService associated with this listener
* @param threadPool The OpenSearch thread pool to run async tasks
*/
@Inject
public QueryInsightsListener(final ClusterService clusterService, final QueryInsightsService queryInsightsService) {
public QueryInsightsListener(
final ClusterService clusterService,
final QueryInsightsService queryInsightsService,
final ThreadPool threadPool
) {
this.queryInsightsService = queryInsightsService;
this.threadPool = threadPool;
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(TOP_N_LATENCY_QUERIES_ENABLED, v -> this.setEnableTopQueries(MetricType.LATENCY, v));
clusterService.getClusterSettings()
Expand Down Expand Up @@ -138,6 +148,16 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo
attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards());
attributes.put(Attribute.INDICES, request.indices());
attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap());
// add user related information
Object userInfo = threadPool.getThreadContext().getTransient(QueryInsightsSettings.REQUEST_HEADER_USER_INFO);
if (userInfo != null) {
attributes.put(Attribute.USER, User.parse(userInfo.toString()));
}
// add remote ip address
Object remoteAddress = threadPool.getThreadContext().getTransient(QueryInsightsSettings.REQUEST_HEADER_REMOTE_ADDRESS);
if (remoteAddress != null) {
attributes.put(Attribute.REMOTE_ADDRESS, remoteAddress.toString());
}
SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes);
queryInsightsService.addRecord(record);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ public enum Attribute {
/**
* The node id for this request
*/
NODE_ID;
NODE_ID,
/**
* The remote address of this request
*/
REMOTE_ADDRESS,
/**
* Information related to the user who sent this request
*/
USER;

/**
* Read an Attribute from a StreamInput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
* @opensearch.experimental
*/
public class QueryInsightsSettings {
/**
* Constant setting for user info header key that are injected during authentication
*/
public static final String REQUEST_HEADER_USER_INFO = "_opendistro_security_user_info";
/**
* Constant setting for remote address info header key that are injected during authentication
*/
public static final String REQUEST_HEADER_REMOTE_ADDRESS = "_opendistro_security_remote_address";

/**
* Executors settings
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,32 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.InjectSecurity;
import org.opensearch.commons.authuser.User;
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
import org.opensearch.plugin.insights.core.service.TopQueriesService;
import org.opensearch.plugin.insights.rules.model.Attribute;
import org.opensearch.plugin.insights.rules.model.MetricType;
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
import org.opensearch.plugin.insights.settings.QueryInsightsSettings;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.junit.Before;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Phaser;

import org.mockito.ArgumentCaptor;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -47,23 +56,34 @@ public class QueryInsightsListenerTests extends OpenSearchTestCase {
private final SearchRequest searchRequest = mock(SearchRequest.class);
private final QueryInsightsService queryInsightsService = mock(QueryInsightsService.class);
private final TopQueriesService topQueriesService = mock(TopQueriesService.class);
private final ThreadPool threadPool = mock(ThreadPool.class);
private final Settings.Builder settingsBuilder = Settings.builder();
private final Settings settings = settingsBuilder.build();
private final String remoteAddress = "1.2.3.4";
private User user;
private ClusterService clusterService;

@Before
public void setup() {
Settings.Builder settingsBuilder = Settings.builder();
Settings settings = settingsBuilder.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
clusterSettings.registerSetting(QueryInsightsSettings.TOP_N_LATENCY_QUERIES_ENABLED);
clusterSettings.registerSetting(QueryInsightsSettings.TOP_N_LATENCY_QUERIES_SIZE);
clusterSettings.registerSetting(QueryInsightsSettings.TOP_N_LATENCY_QUERIES_WINDOW_SIZE);
clusterService = new ClusterService(settings, clusterSettings, null);
when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true);
when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService);

// inject user info
ThreadContext threadContext = new ThreadContext(settings);
user = new User("user-1", List.of("role1", "role2"), List.of("role3", "role4"), List.of());
InjectSecurity injector = new InjectSecurity("id", settings, threadContext);
injector.injectUserInfo(user);
threadContext.putTransient(QueryInsightsSettings.REQUEST_HEADER_REMOTE_ADDRESS, remoteAddress);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testOnRequestEnd() throws InterruptedException {
Long timestamp = System.currentTimeMillis() - 100L;
public void testOnRequestEnd() {
long timestamp = System.currentTimeMillis() - 100L;
SearchType searchType = SearchType.QUERY_THEN_FETCH;

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
Expand All @@ -79,7 +99,7 @@ public void testOnRequestEnd() throws InterruptedException {

int numberOfShards = 10;

QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService, threadPool);

when(searchRequest.getOrCreateAbsoluteStartMillis()).thenReturn(timestamp);
when(searchRequest.searchType()).thenReturn(searchType);
Expand All @@ -92,6 +112,16 @@ public void testOnRequestEnd() throws InterruptedException {
queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext);

verify(queryInsightsService, times(1)).addRecord(any());
ArgumentCaptor<SearchQueryRecord> argumentCaptor = ArgumentCaptor.forClass(SearchQueryRecord.class);
verify(queryInsightsService).addRecord(argumentCaptor.capture());
assertEquals(timestamp, argumentCaptor.getValue().getTimestamp());
Map<Attribute, Object> attrs = argumentCaptor.getValue().getAttributes();
assertEquals(searchType.toString().toLowerCase(Locale.ROOT), attrs.get(Attribute.SEARCH_TYPE));
assertEquals(numberOfShards, attrs.get(Attribute.TOTAL_SHARDS));
assertEquals(indices, attrs.get(Attribute.INDICES));
assertEquals(phaseLatencyMap, attrs.get(Attribute.PHASE_LATENCY_MAP));
assertEquals(user, attrs.get(Attribute.USER));
assertEquals(remoteAddress, attrs.get(Attribute.REMOTE_ADDRESS));
}

public void testConcurrentOnRequestEnd() throws InterruptedException {
Expand Down Expand Up @@ -127,7 +157,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(numRequests);

for (int i = 0; i < numRequests; i++) {
searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService));
searchListenersList.add(new QueryInsightsListener(clusterService, queryInsightsService, threadPool));
}

for (int i = 0; i < numRequests; i++) {
Expand All @@ -148,7 +178,7 @@ public void testConcurrentOnRequestEnd() throws InterruptedException {

public void testSetEnabled() {
when(queryInsightsService.isCollectionEnabled(MetricType.LATENCY)).thenReturn(true);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService);
QueryInsightsListener queryInsightsListener = new QueryInsightsListener(clusterService, queryInsightsService, threadPool);
queryInsightsListener.setEnableTopQueries(MetricType.LATENCY, true);
assertTrue(queryInsightsListener.isEnabled());

Expand Down

0 comments on commit bfdefdd

Please sign in to comment.