Skip to content

Commit

Permalink
Working on standalone implementation.
Browse files Browse the repository at this point in the history
Signed-off-by: jzonthemtn <[email protected]>
  • Loading branch information
jzonthemtn committed Dec 17, 2024
1 parent e937df1 commit 871a9bf
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.admin.indices.exists.indices.IndicesExistsRequest;
import org.opensearch.action.admin.indices.exists.indices.IndicesExistsResponse;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
import org.opensearch.eval.runners.OpenSearchQuerySetRunner;
Expand All @@ -32,23 +18,11 @@
import org.opensearch.eval.samplers.AllQueriesQuerySamplerParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeAbstractQuerySampler;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeParameters;
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;

import java.io.IOException;
import java.nio.charset.Charset;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

import static org.opensearch.eval.SearchQualityEvaluationPlugin.JUDGMENTS_INDEX_NAME;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
public class SearchQualityEvaluationRestHandler {

private static final Logger LOGGER = LogManager.getLogger(SearchQualityEvaluationRestHandler.class);

Expand Down Expand Up @@ -77,21 +51,6 @@ public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
*/
public static final String QUERY_PLACEHOLDER = "#$query##";

@Override
public String getName() {
return "Search Quality Evaluation Framework";
}

@Override
public List<Route> routes() {
return List.of(
new Route(RestRequest.Method.POST, IMPLICIT_JUDGMENTS_URL),
new Route(RestRequest.Method.POST, SCHEDULING_URL),
new Route(RestRequest.Method.DELETE, SCHEDULING_URL),
new Route(RestRequest.Method.POST, QUERYSET_MANAGEMENT_URL),
new Route(RestRequest.Method.POST, QUERYSET_RUN_URL));
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

Expand Down Expand Up @@ -276,107 +235,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "{\"error\": \"" + request.method() + " is not allowed.\"}"));
}

// Handle the scheduling of creating implicit judgments.
} else if(SCHEDULING_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

// Get the job parameters from the request.
final String id = request.param("id");
final String jobName = request.param("job_name", UUID.randomUUID().toString());
final String lockDurationSecondsString = request.param("lock_duration_seconds", "600");
final Long lockDurationSeconds = lockDurationSecondsString != null ? Long.parseLong(lockDurationSecondsString) : null;
final String jitterString = request.param("jitter");
final Double jitter = jitterString != null ? Double.parseDouble(jitterString) : null;
final String clickModel = request.param("click_model");
final int maxRank = Integer.parseInt(request.param("max_rank", "20"));

// Validate the request parameters.
if (id == null || clickModel == null) {
throw new IllegalArgumentException("The id and click_model parameters must be provided.");
}

// Read the start_time.
final Instant startTime;
if (request.param("start_time") == null) {
startTime = Instant.now();
} else {
startTime = Instant.ofEpochMilli(Long.parseLong(request.param("start_time")));
}

// Read the interval.
final int interval;
if (request.param("interval") == null) {
// Default to every 24 hours.
interval = 1440;
} else {
interval = Integer.parseInt(request.param("interval"));
}

final SearchQualityEvaluationJobParameter jobParameter = new SearchQualityEvaluationJobParameter(
jobName, new IntervalSchedule(startTime, interval, ChronoUnit.MINUTES), lockDurationSeconds,
jitter, clickModel, maxRank
);

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.SCHEDULED_JOBS_INDEX_NAME)
.id(id)
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

return restChannel -> {

// index the job parameter
client.index(indexRequest, new ActionListener<>() {

@Override
public void onResponse(final IndexResponse indexResponse) {

try {

final RestResponse restResponse = new BytesRestResponse(
RestStatus.OK,
indexResponse.toXContent(JsonXContent.contentBuilder(), null)
);
LOGGER.info("Created implicit judgments schedule for click-model {}: Job name {}, running every {} minutes starting {}", clickModel, jobName, interval, startTime);

restChannel.sendResponse(restResponse);

} catch (IOException e) {
restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage()));
}

}

@Override
public void onFailure(Exception e) {
restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage()));
}
});

};

// Delete a scheduled job to make implicit judgments.
} else if (request.method().equals(RestRequest.Method.DELETE)) {

final String id = request.param("id");
final DeleteRequest deleteRequest = new DeleteRequest().index(SearchQualityEvaluationPlugin.SCHEDULED_JOBS_INDEX_NAME).id(id);

return restChannel -> client.delete(deleteRequest, new ActionListener<>() {
@Override
public void onResponse(final DeleteResponse deleteResponse) {
restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Scheduled job deleted.\"}"));
}

@Override
public void onFailure(Exception e) {
restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage()));
}
});

} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "{\"error\": \"" + request.method() + " is not allowed.\"}"));
}

} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.NOT_FOUND, "{\"error\": \"" + request.path() + " was not found.\"}"));
}
Expand All @@ -386,7 +244,7 @@ public void onFailure(Exception e) {
private void createJudgmentsIndex(final NodeClient client) throws Exception {

// If the judgments index does not exist we need to create it.
final IndicesExistsRequest indicesExistsRequest = new IndicesExistsRequest(JUDGMENTS_INDEX_NAME);
final IndicesExistsRequest indicesExistsRequest = new IndicesExistsRequest(Constants.JUDGMENTS_INDEX_NAME);

final IndicesExistsResponse indicesExistsResponse = client.admin().indices().exists(indicesExistsRequest).get();

Expand All @@ -405,7 +263,7 @@ private void createJudgmentsIndex(final NodeClient client) throws Exception {
" }";

// Create the judgments index.
final CreateIndexRequest createIndexRequest = new CreateIndexRequest(JUDGMENTS_INDEX_NAME).mapping(mapping);
final CreateIndexRequest createIndexRequest = new CreateIndexRequest(Constants.JUDGMENTS_INDEX_NAME).mapping(mapping);

// TODO: Don't use .get()
client.admin().indices().create(createIndexRequest).get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;

import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch.core.IndexRequest;
import org.opensearch.client.opensearch.indices.CreateIndexRequest;
import org.opensearch.eval.Constants;
import org.opensearch.eval.utils.TimeUtils;

Expand Down Expand Up @@ -46,7 +45,7 @@ public abstract class AbstractQuerySampler {
/**
* Index the query set.
*/
protected String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Map<String, Long> queries) throws Exception {
protected String indexQuerySet(final OpenSearchClient client, final String name, final String description, final String sampling, Map<String, Long> queries) throws Exception {

LOGGER.info("Indexing {} queries for query set {}", queries.size(), name);

Expand All @@ -73,23 +72,33 @@ protected String indexQuerySet(final NodeClient client, final String name, final
final String querySetId = UUID.randomUUID().toString();

// TODO: Create a mapping for the query set index.
final IndexRequest indexRequest = new IndexRequest().index(Constants.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest, new ActionListener<>() {
final IndexData indexData = new IndexData("Document 1", "Text for document 1");

@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.info("Indexed query set {} having name {}", querySetId, name);
}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to index query set {}", querySetId, ex);
}
});
final IndexRequest indexRequest = new IndexRequest.Builder<IndexData>().index(Constants.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.document(indexData)
.source(querySet);

client.index(indexRequest);
//
// final IndexRequest indexRequest = new IndexRequest().index(Constants.QUERY_SETS_INDEX_NAME)
// .id(querySetId)
// .source(querySet)
// .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
//
// client.index(indexRequest, new ActionListener<>() {
//
// @Override
// public void onResponse(IndexResponse indexResponse) {
// LOGGER.info("Indexed query set {} having name {}", querySetId, name);
// }
//
// @Override
// public void onFailure(Exception ex) {
// LOGGER.error("Unable to index query set {}", querySetId, ex);
// }
// });

return querySetId;

Expand Down

0 comments on commit 871a9bf

Please sign in to comment.