Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds PPTSS sampling #46

Merged
merged 4 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions data/esci/ubi_queries_events.ndjson.bz2

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
logger.level: info
OPENSEARCH_INITIAL_ADMIN_PASSWORD: SuperSecretPassword_123
http.max_content_length: 500mb
OPENSEARCH_JAVA_OPTS: "-Xms8192m -Xmx8192m"
OPENSEARCH_JAVA_OPTS: "-Xms8g -Xmx8g"
ulimits:
memlock:
soft: -1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&max_queries=500"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&query_set_size=500"

#echo ${QUERY_SET}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&max_queries=500"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=5000"

#echo ${QUERY_SET}

#curl -s http://localhost:9200/search_quality_eval_query_sets/_search | jq
#curl -s -X GET http://localhost:9200/search_quality_eval_query_sets/_doc/${QUERY_SET} | jq

# Run the query set now.
#curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET}" | jq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
*/
package org.opensearch.eval;

import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.delete.DeleteRequest;
Expand All @@ -24,24 +23,25 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
import org.opensearch.eval.samplers.AllQueriesQuerySampler;
import org.opensearch.eval.samplers.AllQueriesQuerySamplerParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeAbstractQuerySampler;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeParameters;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
Expand Down Expand Up @@ -87,47 +87,30 @@ public List<Route> routes() {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

// Handle managing query sets.
if(StringUtils.equalsIgnoreCase(request.path(), QUERYSET_MANAGEMENT_URL)) {
if(QUERYSET_MANAGEMENT_URL.equalsIgnoreCase(request.path())) {

// Creating a new query set by sampling the UBI queries.
if (request.method().equals(RestRequest.Method.POST)) {

final String name = request.param("name");
final String description = request.param("description");
final String sampling = request.param("sampling", "pptss");
final int maxQueries = Integer.parseInt(request.param("max_queries", "1000"));
final int querySetSize = Integer.parseInt(request.param("query_set_size", "1000"));

// Create a query set by finding all the unique user_query terms.
if (StringUtils.equalsIgnoreCase(sampling, "none")) {
if ("none".equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query sets should just be directly
// indexed into OpenSearch using the `ubu_queries` index directly.

try {

// Get queries from the UBI queries index.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
searchSourceBuilder.from(0);
searchSourceBuilder.size(maxQueries);
final AllQueriesQuerySamplerParameters parameters = new AllQueriesQuerySamplerParameters(name, description, sampling, querySetSize);
final AllQueriesQuerySampler sampler = new AllQueriesQuerySampler(client, parameters);

final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME);
searchRequest.source(searchSourceBuilder);
// Sample and index the queries.
final String querySetId = sampler.sample();

final SearchResponse searchResponse = client.search(searchRequest).get();

LOGGER.info("Found {} user queries from the ubi_queries index.", searchResponse.getHits().getTotalHits().toString());

final Set<String> queries = new HashSet<>();
for(final SearchHit hit : searchResponse.getHits().getHits()) {
final Map<String, Object> fields = hit.getSourceAsMap();
queries.add(fields.get("user_query").toString());
}

LOGGER.info("Found {} user queries from the ubi_queries index.", queries.size());

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}"));

} catch(Exception ex) {
Expand All @@ -136,15 +119,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli


// Create a query set by using PPTSS sampling.
} else if (StringUtils.equalsIgnoreCase(sampling, "pptss")) {
} else if ("pptss".equalsIgnoreCase(sampling)) {

// TODO: Use the PPTSS sampling method - https://opensourceconnections.com/blog/2022/10/13/how-to-succeed-with-explicit-relevance-evaluation-using-probability-proportional-to-size-sampling/
final Collection<String> queries = List.of("computer", "desk", "table", "battery");
LOGGER.info("Creating query set using PPTSS");

final ProbabilityProportionalToSizeParameters parameters = new ProbabilityProportionalToSizeParameters(name, description, sampling, querySetSize);
final ProbabilityProportionalToSizeAbstractQuerySampler sampler = new ProbabilityProportionalToSizeAbstractQuerySampler(client, parameters);

try {

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
// Sample and index the queries.
final String querySetId = sampler.sample();

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}"));

} catch(Exception ex) {
Expand All @@ -162,7 +148,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

// Handle running query sets.
} else if(StringUtils.equalsIgnoreCase(request.path(), QUERYSET_RUN_URL)) {
} else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) {

final String id = request.param("id");

Expand Down Expand Up @@ -197,7 +183,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Query set " + id + " run initiated.\"}"));

// Handle the on-demand creation of implicit judgments.
} else if(StringUtils.equalsIgnoreCase(request.path(), IMPLICIT_JUDGMENTS_URL)) {
} else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

Expand All @@ -206,7 +192,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final int maxRank = Integer.parseInt(request.param("max_rank", "20"));
final long judgments;

if (StringUtils.equalsIgnoreCase(clickModel, "coec")) {
if ("coec".equalsIgnoreCase(clickModel)) {

final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(true, maxRank);
final CoecClickModel coecClickModel = new CoecClickModel(client, coecClickModelParameters);
Expand Down Expand Up @@ -249,7 +235,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

// Handle the scheduling of creating implicit judgments.
} else if(StringUtils.equalsIgnoreCase(request.path(), SCHEDULING_URL)) {
} else if(SCHEDULING_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

Expand All @@ -270,15 +256,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

// Read the start_time.
final Instant startTime;
if (StringUtils.isEmpty(request.param("start_time"))) {
if (request.param("start_time") == null) {
startTime = Instant.now();
} else {
startTime = Instant.ofEpochMilli(Long.parseLong(request.param("start_time")));
}

// Read the interval.
final int interval;
if (StringUtils.isEmpty(request.param("interval"))) {
if (request.param("interval") == null) {
// Default to every 24 hours.
interval = 1440;
} else {
Expand Down Expand Up @@ -355,29 +341,4 @@ public void onFailure(Exception e) {

}

/**
* Index the query set.
*/
private String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
querySet.put("sampling", sampling);
querySet.put("queries", queries);
querySet.put("created_at", Instant.now().toEpochMilli());

final String querySetId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();

return querySetId;

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import com.google.gson.annotations.SerializedName;

/**
* A UBI event.
* Creates a representation of a UBI event.
*/
public class UbiEvent {

Expand All @@ -27,6 +27,13 @@ public class UbiEvent {
@SerializedName("event_attributes")
private EventAttributes eventAttributes;

/**
* Creates a new representation of an UBI event.
*/
public UbiEvent() {

}

@Override
public String toString() {
return actionName + ", " + clientId + ", " + queryId + ", " + eventAttributes.getObject().toString() + ", " + eventAttributes.getPosition().getIndex();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.eval.SearchQualityEvaluationPlugin;

import java.time.Instant;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
* An interface for sampling UBI queries.
*/
public abstract class AbstractQuerySampler {

/**
* Gets the name of the sampler.
* @return The name of the sampler.
*/
abstract String getName();

/**
* Samples the queries and inserts the query set into an index.
* @return A query set ID.
*/
abstract String sample() throws Exception;

/**
* Index the query set.
*/
protected String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
querySet.put("sampling", sampling);
querySet.put("queries", queries);
querySet.put("created_at", Instant.now().toEpochMilli());

final String querySetId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();

return querySetId;

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

public class AbstractSamplerParameters {

private final String name;
private final String description;
private final String sampling;
private final int querySetSize;

public AbstractSamplerParameters(final String name, final String description, final String sampling, final int querySetSize) {
this.name = name;
this.description = description;
this.sampling = sampling;
this.querySetSize = querySetSize;
}

public String getName() {
return name;
}

public String getDescription() {
return description;
}

public String getSampling() {
return sampling;
}

public int getQuerySetSize() {
return querySetSize;
}

}
Loading
Loading