Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ability to run query sets and save results on OpenSearch #49

Merged
merged 11 commits into from
Dec 3, 2024
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=5000"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=100"

#echo ${QUERY_SET}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e

curl -s "http://localhost:9200/search_quality_eval_query_sets_run_results/_search" | jq
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e

curl -s "http://localhost:9200/search_quality_eval_query_sets/_search" | jq
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash -e

QUERY_SET_ID="${1}"
JUDGMENTS_ID="12345"
INDEX="ecommerce"
ID_FIELD="asin"
K="10"

curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET_ID}&judgments_id=${JUDGMENTS_ID}&index=${INDEX}&id_field=${ID_FIELD}&k=${K}" \
-H "Content-Type: application/json" \
--data-binary '{
"query": {
"match": {
"description": {
"query": "#$query##"
}
}
}
}'
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -139,10 +140,25 @@ public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionC
job.put("invocation", "scheduled");
job.put("max_rank", searchQualityEvaluationJobParameter.getMaxRank());

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(UUID.randomUUID().toString()).source(job).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();
final String judgmentsId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest()
.index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(judgmentsId)
.source(job)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.info("Successfully indexed implicit judgments {}", judgmentsId);
}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to index implicit judgments", ex);
}
});

}, exception -> { throw new IllegalStateException("Failed to acquire lock."); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ public class SearchQualityEvaluationPlugin extends Plugin implements ActionPlugi
*/
public static final String QUERY_SETS_INDEX_NAME = "search_quality_eval_query_sets";

/**
* The name of the index that stores the query set run results.
*/
public static final String QUERY_SETS_RUN_RESULTS = "search_quality_eval_query_sets_run_results";

@Override
public Collection<Object> createComponents(
final Client client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
Expand All @@ -40,6 +41,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {

Expand All @@ -65,6 +67,11 @@ public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
*/
public static final String QUERYSET_RUN_URL = "/_plugins/search_quality_eval/run";

/**
* The placeholder in the query that gets replaced by the query term when running a query set.
*/
public static final String QUERY_PLACEHOLDER = "#$query##";

@Override
public String getName() {
return "Search Quality Evaluation Framework";
Expand Down Expand Up @@ -98,7 +105,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (AllQueriesQuerySampler.NAME.equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query sets should just be directly
// indexed into OpenSearch using the `ubu_queries` index directly.
// indexed into OpenSearch using the `ubi_queries` index directly.

try {

Expand Down Expand Up @@ -148,20 +155,43 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
} else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) {

final String querySetId = request.param("id");
final String judgmentsId = request.param("judgments_id");
final String index = request.param("index");
final String idField = request.param("id_field", "_id");
final int k = Integer.parseInt(request.param("k", "10"));

if(querySetId == null || judgmentsId == null || index == null) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing required parameters.\"}"));
}

if(k < 1) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"k must be a positive integer.\"}"));
}

if(!request.hasContent()) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing query in body.\"}"));
}

// Get the query JSON from the content.
final String query = new String(BytesReference.toBytes(request.content()));

// Validate the query has a QUERY_PLACEHOLDER.
if(!query.contains(QUERY_PLACEHOLDER)) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing query placeholder in query.\"}"));
}

try {

final OpenSearchQuerySetRunner openSearchQuerySetRunner = new OpenSearchQuerySetRunner(client);
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId);

// TODO: Index the querySetRunResult.
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, idField, query, k);
openSearchQuerySetRunner.save(querySetRunResult);

} catch (Exception ex) {
LOGGER.error("Unable to retrieve query set with ID {}", querySetId);
LOGGER.error("Unable to run query set with ID {}: ", querySetId, ex);
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, ex.getMessage()));
}

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Query set " + querySetId + " run initiated.\"}"));
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Run initiated for query set " + querySetId + "\"}"));

// Handle the on-demand creation of implicit judgments.
} else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) {
Expand Down Expand Up @@ -196,16 +226,35 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
job.put("invocation", "on_demand");
job.put("max_rank", maxRank);

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(UUID.randomUUID().toString()).source(job).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
final String judgmentsId = UUID.randomUUID().toString();

try {
client.index(indexRequest).get();
} catch (Exception e) {
throw new RuntimeException(e);
}
final IndexRequest indexRequest = new IndexRequest()
.index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(judgmentsId)
.source(job)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

final AtomicBoolean success = new AtomicBoolean(false);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(final IndexResponse indexResponse) {
LOGGER.debug("Judgments indexed: {}", judgmentsId);
success.set(true);
}

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Implicit judgment generation initiated.\"}"));
@Override
public void onFailure(final Exception ex) {
LOGGER.error("Unable to index judgment with ID {}", judgmentsId, ex);
success.set(false);
}
});

if(success.get()) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"judgments_id\": \"" + judgmentsId + "\"}"));
} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,"Unable to index judgments."));
}

} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Invalid click model.\"}"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,50 @@
*/
package org.opensearch.eval.runners;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.eval.SearchQualityEvaluationPlugin;
import org.opensearch.eval.judgments.model.Judgment;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class OpenSearchQuerySetRunner extends QuerySetRunner {
import static org.opensearch.eval.SearchQualityEvaluationRestHandler.QUERY_PLACEHOLDER;

/**
* A {@link QuerySetRunner} for Amazon OpenSearch.
*/
public class OpenSearchQuerySetRunner implements QuerySetRunner {

private static final Logger LOGGER = LogManager.getLogger(OpenSearchQuerySetRunner.class);

final Client client;

/**
* Creates a new query set runner
* @param client An OpenSearch {@link Client}.
*/
public OpenSearchQuerySetRunner(final Client client) {
this.client = client;
}

@Override
public QuerySetRunResult run(String querySetId) {
public QuerySetRunResult run(final String querySetId, final String judgmentsId, final String index, final String idField, final String query, final int k) {

// TODO: Get the judgments we will use for metric calculation.
final List<Judgment> judgments = new ArrayList<>();

// Get the query set.
final SearchSourceBuilder getQuerySetSearchSourceBuilder = new SearchSourceBuilder();
Expand All @@ -40,43 +62,70 @@ public QuerySetRunResult run(String querySetId) {

try {

// TODO: Don't use .get()
final SearchResponse searchResponse = client.search(getQuerySetSearchRequest).get();

// The queries from the query set that will be run.
final Collection<String> queries = (Collection<String>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries");
final Collection<Map<String, Long>> queries = (Collection<Map<String, Long>>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries");

// The results of each query.
final Collection<QueryResult> queryResults = new ArrayList<>();
final List<QueryResult> queryResults = new ArrayList<>();

// TODO: Initiate the running of the query set.
for(final String query : queries) {
for(Map<String, Long> queryMap : queries) {

// TODO: What should this query be?
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchQuery("title", query));
// TODO: Just fetch the id ("asin") field and not all the unnecessary fields.
// Loop over each query in the map and run each one.
for (final String userQuery : queryMap.keySet()) {

// TODO: Allow for setting this index name.
final SearchRequest searchRequest = new SearchRequest("ecommerce");
getQuerySetSearchRequest.source(getQuerySetSearchSourceBuilder);
// Replace the query placeholder with the user query.
final String q = query.replace(QUERY_PLACEHOLDER, userQuery);

final SearchResponse sr = client.search(searchRequest).get();
// Build the query from the one that was passed in.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.wrapperQuery(q));
searchSourceBuilder.from(0);
// TODO: If k is > 10, we'll need to page through these.
searchSourceBuilder.size(k);

final List<String> orderedDocumentIds = new ArrayList<>();
String[] includeFields = new String[] {idField};
String[] excludeFields = new String[] {};
searchSourceBuilder.fetchSource(includeFields, excludeFields);

for(final SearchHit hit : sr.getHits().getHits()) {
// TODO: Allow for setting this index name.
final SearchRequest searchRequest = new SearchRequest(index);
getQuerySetSearchRequest.source(searchSourceBuilder);

// TODO: This field needs to be customizable.
orderedDocumentIds.add(hit.getFields().get("asin").toString());
client.search(searchRequest, new ActionListener<>() {

}
@Override
public void onResponse(final SearchResponse searchResponse) {

final List<String> orderedDocumentIds = new ArrayList<>();

for (final SearchHit hit : searchResponse.getHits().getHits()) {

queryResults.add(new QueryResult(orderedDocumentIds));
final Map<String, Object> sourceAsMap = hit.getSourceAsMap();
final String documentId = sourceAsMap.get(idField).toString();

orderedDocumentIds.add(documentId);

}

queryResults.add(new QueryResult(query, orderedDocumentIds, judgments, k));

}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to search for query: {}", query, ex);
}
});

}

}

// TODO: Calculate the search metrics given the results and the judgments.
final SearchMetrics searchMetrics = new SearchMetrics();
final SearchMetrics searchMetrics = new SearchMetrics(queryResults, judgments, k);

return new QuerySetRunResult(queryResults, searchMetrics);

Expand All @@ -86,4 +135,31 @@ public QuerySetRunResult run(String querySetId) {

}

@Override
public void save(final QuerySetRunResult result) throws Exception {

// Index the results into OpenSearch.

final Map<String, Object> results = new HashMap<>();

results.put("run_id", result.getRunId());
results.put("search_metrics", result.getSearchMetrics().getSearchMetricsAsMap());
results.put("query_results", result.getQueryResultsAsMap());

final IndexRequest indexRequest = new IndexRequest(SearchQualityEvaluationPlugin.QUERY_SETS_RUN_RESULTS);
indexRequest.source(results);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.debug("Query set results indexed.");
}

@Override
public void onFailure(Exception ex) {
throw new RuntimeException(ex);
}
});
}

}
Loading
Loading