diff --git a/opensearch-search-quality-evaluation-framework/.gitignore b/opensearch-search-quality-evaluation-framework/.gitignore new file mode 100644 index 0000000..6c884e1 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/.gitignore @@ -0,0 +1,14 @@ +# Ignore Gradle project-specific cache directory +.gradle + +# Ignore Gradle build output directory +build + +# intellij files +.idea/ +*.iml +*.ipr +*.iws +build-idea/ +out/ + diff --git a/opensearch-search-quality-evaluation-framework/Dockerfile b/opensearch-search-quality-evaluation-framework/Dockerfile new file mode 100644 index 0000000..02f56c8 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/Dockerfile @@ -0,0 +1,6 @@ +FROM opensearchproject/opensearch:2.18.0 + +RUN /usr/share/opensearch/bin/opensearch-plugin install --batch https://github.com/opensearch-project/user-behavior-insights/releases/download/2.18.0.2/opensearch-ubi-2.18.0.2.zip + +ADD ./build/distributions/search-quality-evaluation-plugin-0.0.1.zip /tmp/search-quality-evaluation-plugin.zip +RUN /usr/share/opensearch/bin/opensearch-plugin install --batch file:/tmp/search-quality-evaluation-plugin.zip diff --git a/opensearch-search-quality-evaluation-framework/LICENSE.txt b/opensearch-search-quality-evaluation-framework/LICENSE.txt new file mode 100644 index 0000000..67db858 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/LICENSE.txt @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/opensearch-search-quality-evaluation-framework/NOTICE.txt b/opensearch-search-quality-evaluation-framework/NOTICE.txt new file mode 100644 index 0000000..be5c6b3 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/NOTICE.txt @@ -0,0 +1 @@ +Copyright Open Source Connections or its affiliates. All Rights Reserved. diff --git a/opensearch-search-quality-evaluation-framework/README.md b/opensearch-search-quality-evaluation-framework/README.md new file mode 100644 index 0000000..215ccce --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/README.md @@ -0,0 +1,72 @@ +# OpenSearch Evaluation Framework + +This is an OpenSearch plugin built on the OpenSearch job scheduler plugin. + +## API Endpoints + +| Method | Endpoint | Description | +|--------|-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------| +| `POST` | `/_plugins/search_quality_eval/queryset` | Create a query set by sampling from the `ubi_queries` index. The `name`, `description`, and `sampling` method parameters are required. | +| `POST` | `/_plugins/search_quality_eval/run` | Initiate a run of a query set. The `name` of the query set is a required parameter. | +| `POST` | `/_plugins/search_quality_eval/judgments` | Generate implicit judgments from UBI events and queries now. | +| `POST` | `/_plugins/search_quality_eval/schedule` | Create a scheduled job to generate implicit judgments. | + + +## Building + +Build the project from the top-level directory to build all projects. + +``` +cd .. +./gradlew build +``` + +## Running in Docker + +From this directory: + +``` +docker compose build && docker compose up +``` + +Verify the plugin is installed: + +``` +curl http://localhost:9200/_cat/plugins +``` + +In the list returned you should see: + +``` +opensearch search-quality-evaluation-plugin 2.17.1.0-SNAPSHOT +``` + +To create a schedule to generate implicit judgments: + +``` +curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/schedule?id=1&click_model=coec&job_name=test&interval=60" | jq +``` + +See the created job: + +``` +curl -s http://localhost:9200/search_quality_eval_scheduled_jobs/_search | jq +``` + +To run an on-demand job without scheduling: + +``` +curl -X POST "http://localhost:9200/_plugins/search_quality_eval/judgments?click_model=coec&max_rank=20" | jq +``` + +To see the job runs: + +``` +curl -X POST "http://localhost:9200/search_quality_eval_completed_jobs/_search" | jq +``` + +See the first 10 judgments: + +``` +curl -s http://localhost:9200/judgments/_search | jq +``` \ No newline at end of file diff --git a/opensearch-search-quality-evaluation-framework/aggs.sh b/opensearch-search-quality-evaluation-framework/aggs.sh new file mode 100755 index 0000000..5cf0e24 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/aggs.sh @@ -0,0 +1,20 @@ +#!/bin/bash -e + +curl -X GET http://localhost:9200/ubi_events/_search -H "Content-Type: application/json" -d' +{ + "size": 0, + "aggs": { + "By_Action": { + "terms": { + "field": "action_name" + }, + "aggs": { + "By_Position": { + "terms": { + "field": "event_attributes.position.ordinal" + } + } + } + } + } +}' | jq \ No newline at end of file diff --git a/opensearch-search-quality-evaluation-framework/build.gradle b/opensearch-search-quality-evaluation-framework/build.gradle new file mode 100644 index 0000000..bdd0586 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/build.gradle @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +apply plugin: 'java' +apply plugin: 'idea' + +ext { + projectSubstitutions = [:] + licenseFile = rootProject.file('LICENSE.txt') + noticeFile = rootProject.file('NOTICE.txt') +} + +test { + include "**/Test*.class" + include "**/*Test.class" + include "**/*Test.class" + include "**/*TestCase.class" +} + +group = 'org.opensearch' +version = "${evalVersion}" + +buildscript { + repositories { + mavenLocal() + maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + mavenCentral() + maven { url "https://plugins.gradle.org/m2/" } + } + + dependencies { + classpath "org.opensearch.gradle:build-tools:${opensearchVersion}" + } +} + +repositories { + mavenLocal() + mavenCentral() + maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } +} + +dependencies { + implementation 'org.apache.logging.log4j:log4j-core:2.24.3' + implementation 'com.fasterxml.jackson.core:jackson-annotations:2.18.2' + implementation 'com.fasterxml.jackson.core:jackson-databind:2.18.2' + implementation 'org.apache.httpcomponents.core5:httpcore5:5.3.1' + implementation 'org.apache.httpcomponents.client5:httpclient5:5.4.1' + implementation 'commons-logging:commons-logging:1.3.4' + implementation 'com.google.code.gson:gson:2.11.0' +} diff --git a/opensearch-search-quality-evaluation-framework/gradle.properties b/opensearch-search-quality-evaluation-framework/gradle.properties new file mode 100644 index 0000000..2659a68 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/gradle.properties @@ -0,0 +1,2 @@ +opensearchVersion = 2.18.0 +evalVersion = 0.0.1 diff --git a/opensearch-search-quality-evaluation-framework/gradle/wrapper/gradle-wrapper.jar b/opensearch-search-quality-evaluation-framework/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000..e708b1c Binary files /dev/null and b/opensearch-search-quality-evaluation-framework/gradle/wrapper/gradle-wrapper.jar differ diff --git a/opensearch-search-quality-evaluation-framework/gradle/wrapper/gradle-wrapper.properties b/opensearch-search-quality-evaluation-framework/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..2bbac7d --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists \ No newline at end of file diff --git a/opensearch-search-quality-evaluation-framework/settings.gradle b/opensearch-search-quality-evaluation-framework/settings.gradle new file mode 100644 index 0000000..ef059e1 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'search-evaluation-framework' diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationFramework.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationFramework.java new file mode 100644 index 0000000..245ee64 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationFramework.java @@ -0,0 +1,10 @@ +package org.opensearch.eval; + +public class SearchQualityEvaluationFramework { + + public void main(String[] args) { + + + } + +} \ No newline at end of file diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationJobParameter.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationJobParameter.java new file mode 100644 index 0000000..2ea5379 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationJobParameter.java @@ -0,0 +1,248 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval; + +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.Schedule; + +import java.io.IOException; +import java.time.Instant; + +public class SearchQualityEvaluationJobParameter implements ScheduledJobParameter { + + /** + * The name of the parameter for providing a name for the scheduled job. + */ + public static final String NAME_FIELD = "name"; + + /** + * The name of the parameter for creating a job as enabled or disabled. + */ + public static final String ENABLED_FILED = "enabled"; + + /** + * The name of the parameter for specifying when the job was last updated. + */ + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + + /** + * The name of the parameter for specifying a readable time for when the job was last updated. + */ + public static final String LAST_UPDATE_TIME_FIELD_READABLE = "last_update_time_field"; + public static final String SCHEDULE_FIELD = "schedule"; + public static final String ENABLED_TIME_FILED = "enabled_time"; + public static final String ENABLED_TIME_FILED_READABLE = "enabled_time_field"; + public static final String LOCK_DURATION_SECONDS = "lock_duration_seconds"; + public static final String JITTER = "jitter"; + + /** + * The name of the parameter that allows for specifying the type of click model to use. + */ + public static final String CLICK_MODEL = "click_model"; + + /** + * The name of the parameter that allows for setting a max rank value to use during judgment generation. + */ + public static final String MAX_RANK = "max_rank"; + + // Properties from ScheduledJobParameter. + private String jobName; + private Instant lastUpdateTime; + private Instant enabledTime; + private boolean enabled; + private Schedule schedule; + private Long lockDurationSeconds; + private Double jitter; + + // Custom properties. + private String clickModel; + private int maxRank; + + public SearchQualityEvaluationJobParameter() { + + } + + public SearchQualityEvaluationJobParameter(final String name, final Schedule schedule, + final Long lockDurationSeconds, final Double jitter, + final String clickModel, final int maxRank) { + this.jobName = name; + this.schedule = schedule; + this.enabled = true; + this.lockDurationSeconds = lockDurationSeconds; + this.jitter = jitter; + + final Instant now = Instant.now(); + this.enabledTime = now; + this.lastUpdateTime = now; + + // Custom properties. + this.clickModel = clickModel; + this.maxRank = maxRank; + + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + + builder.startObject(); + + builder + .field(NAME_FIELD, this.jobName) + .field(ENABLED_FILED, this.enabled) + .field(SCHEDULE_FIELD, this.schedule) + .field(CLICK_MODEL, this.clickModel) + .field(MAX_RANK, this.maxRank); + + if (this.enabledTime != null) { + builder.timeField(ENABLED_TIME_FILED, ENABLED_TIME_FILED_READABLE, this.enabledTime.toEpochMilli()); + } + + if (this.lastUpdateTime != null) { + builder.timeField(LAST_UPDATE_TIME_FIELD, LAST_UPDATE_TIME_FIELD_READABLE, this.lastUpdateTime.toEpochMilli()); + } + + if (this.lockDurationSeconds != null) { + builder.field(LOCK_DURATION_SECONDS, this.lockDurationSeconds); + } + + if (this.jitter != null) { + builder.field(JITTER, this.jitter); + } + + builder.endObject(); + + return builder; + + } + + @Override + public String getName() { + return this.jobName; + } + + @Override + public Instant getLastUpdateTime() { + return this.lastUpdateTime; + } + + @Override + public Instant getEnabledTime() { + return this.enabledTime; + } + + @Override + public Schedule getSchedule() { + return this.schedule; + } + + @Override + public boolean isEnabled() { + return this.enabled; + } + + @Override + public Long getLockDurationSeconds() { + return this.lockDurationSeconds; + } + + @Override + public Double getJitter() { + return jitter; + } + + /** + * Sets the name of the job. + * @param jobName The name of the job. + */ + public void setJobName(String jobName) { + this.jobName = jobName; + } + + /** + * Sets when the job was last updated. + * @param lastUpdateTime An {@link Instant} of when the job was last updated. + */ + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + + /** + * Sets when the job was enabled. + * @param enabledTime An {@link Instant} of when the job was enabled. + */ + public void setEnabledTime(Instant enabledTime) { + this.enabledTime = enabledTime; + } + + /** + * Sets whether the job is enabled. + * @param enabled A boolean representing whether the job is enabled. + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + /** + * Sets the schedule for the job. + * @param schedule A {@link Schedule} for the job. + */ + public void setSchedule(Schedule schedule) { + this.schedule = schedule; + } + + /** + * Sets the lock duration for the cluster when running the job. + * @param lockDurationSeconds The lock duration in seconds. + */ + public void setLockDurationSeconds(Long lockDurationSeconds) { + this.lockDurationSeconds = lockDurationSeconds; + } + + /** + * Sets the jitter for the job. + * @param jitter The jitter for the job. + */ + public void setJitter(Double jitter) { + this.jitter = jitter; + } + + /** + * Gets the type of click model to use for implicit judgment generation. + * @return The type of click model to use for implicit judgment generation. + */ + public String getClickModel() { + return clickModel; + } + + /** + * Sets the click model type to use for implicit judgment generation. + * @param clickModel The click model type to use for implicit judgment generation. + */ + public void setClickModel(String clickModel) { + this.clickModel = clickModel; + } + + /** + * Gets the max rank to use when generating implicit judgments. + * @return The max rank to use when generating implicit judgments. + */ + public int getMaxRank() { + return maxRank; + } + + /** + * Sets the max rank to use when generating implicit judgments. + * @param maxRank The max rank to use when generating implicit judgments. + */ + public void setMaxRank(int maxRank) { + this.maxRank = maxRank; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationJobRunner.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationJobRunner.java new file mode 100644 index 0000000..442ae4c --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationJobRunner.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel; +import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.threadpool.ThreadPool; + +import java.util.HashMap; +import java.util.Map; + +/** + * Job runner for scheduled implicit judgments jobs. + */ +public class SearchQualityEvaluationJobRunner implements ScheduledJobRunner { + + private static final Logger LOGGER = LogManager.getLogger(SearchQualityEvaluationJobRunner.class); + + private static SearchQualityEvaluationJobRunner INSTANCE; + + /** + * Gets a singleton instance of this class. + * @return A {@link SearchQualityEvaluationJobRunner}. + */ + public static SearchQualityEvaluationJobRunner getJobRunnerInstance() { + + LOGGER.info("Getting job runner instance"); + + if (INSTANCE != null) { + return INSTANCE; + } + + synchronized (SearchQualityEvaluationJobRunner.class) { + if (INSTANCE == null) { + INSTANCE = new SearchQualityEvaluationJobRunner(); + } + return INSTANCE; + } + + } + + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + + private SearchQualityEvaluationJobRunner() { + + } + + public void setClusterService(ClusterService clusterService) { + this.clusterService = clusterService; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + public void setClient(Client client) { + this.client = client; + } + + @Override + public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionContext context) { + + if(!(jobParameter instanceof SearchQualityEvaluationJobParameter)) { + throw new IllegalStateException( + "Job parameter is not instance of SampleJobParameter, type: " + jobParameter.getClass().getCanonicalName() + ); + } + + if(this.clusterService == null) { + throw new IllegalStateException("ClusterService is not initialized."); + } + + if(this.threadPool == null) { + throw new IllegalStateException("ThreadPool is not initialized."); + } + + final LockService lockService = context.getLockService(); + + final Runnable runnable = () -> { + + if (jobParameter.getLockDurationSeconds() != null) { + + lockService.acquireLock(jobParameter, context, ActionListener.wrap(lock -> { + + if (lock == null) { + return; + } + + final SearchQualityEvaluationJobParameter searchQualityEvaluationJobParameter = (SearchQualityEvaluationJobParameter) jobParameter; + + final long startTime = System.currentTimeMillis(); + final String judgmentsId; + + if("coec".equalsIgnoreCase(searchQualityEvaluationJobParameter.getClickModel())) { + + LOGGER.info("Beginning implicit judgment generation using clicks-over-expected-clicks."); + final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(searchQualityEvaluationJobParameter.getMaxRank()); + final CoecClickModel coecClickModel = new CoecClickModel(client, coecClickModelParameters); + + judgmentsId = coecClickModel.calculateJudgments(); + + } else { + + // Invalid click model. + throw new IllegalArgumentException("Invalid click model: " + searchQualityEvaluationJobParameter.getClickModel()); + + } + + final long elapsedTime = System.currentTimeMillis() - startTime; + LOGGER.info("Implicit judgment generation completed in {} ms", elapsedTime); + + final Map job = new HashMap<>(); + job.put("name", searchQualityEvaluationJobParameter.getName()); + job.put("click_model", searchQualityEvaluationJobParameter.getClickModel()); + job.put("started", startTime); + job.put("duration", elapsedTime); + job.put("judgments", judgmentsId); + job.put("invocation", "scheduled"); + job.put("max_rank", searchQualityEvaluationJobParameter.getMaxRank()); + + final IndexRequest indexRequest = new IndexRequest() + .index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME) + .id(judgmentsId) + .source(job) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, new ActionListener<>() { + @Override + public void onResponse(IndexResponse indexResponse) { + LOGGER.info("Successfully indexed implicit judgments {}", judgmentsId); + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Unable to index implicit judgments", ex); + } + }); + + }, exception -> { throw new IllegalStateException("Failed to acquire lock."); })); + } + + }; + + threadPool.generic().submit(runnable); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationPlugin.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationPlugin.java new file mode 100644 index 0000000..6a7b581 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationPlugin.java @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.watcher.ResourceWatcherService; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; + +/** + * Main class for the Search Quality Evaluation plugin. + */ +public class SearchQualityEvaluationPlugin extends Plugin implements ActionPlugin, JobSchedulerExtension { + + private static final Logger LOGGER = LogManager.getLogger(SearchQualityEvaluationPlugin.class); + + /** + * The name of the UBI index containing the queries. This should not be changed. + */ + public static final String UBI_QUERIES_INDEX_NAME = "ubi_queries"; + + /** + * The name of the UBI index containing the events. This should not be changed. + */ + public static final String UBI_EVENTS_INDEX_NAME = "ubi_events"; + + /** + * The name of the index to store the scheduled jobs to create implicit judgments. + */ + public static final String SCHEDULED_JOBS_INDEX_NAME = "search_quality_eval_scheduled_jobs"; + + /** + * The name of the index to store the completed jobs to create implicit judgments. + */ + public static final String COMPLETED_JOBS_INDEX_NAME = "search_quality_eval_completed_jobs"; + + /** + * The name of the index that stores the query sets. + */ + public static final String QUERY_SETS_INDEX_NAME = "search_quality_eval_query_sets"; + + /** + * The name of the index that stores the metrics for the dashboard. + */ + public static final String DASHBOARD_METRICS_INDEX_NAME = "sqe_metrics_sample_data"; + + /** + * The name of the index that stores the implicit judgments. + */ + public static final String JUDGMENTS_INDEX_NAME = "judgments"; + + @Override + public Collection createComponents( + final Client client, + final ClusterService clusterService, + final ThreadPool threadPool, + final ResourceWatcherService resourceWatcherService, + final ScriptService scriptService, + final NamedXContentRegistry xContentRegistry, + final Environment environment, + final NodeEnvironment nodeEnvironment, + final NamedWriteableRegistry namedWriteableRegistry, + final IndexNameExpressionResolver indexNameExpressionResolver, + final Supplier repositoriesServiceSupplier + ) { + + LOGGER.info("Creating search evaluation framework components"); + final SearchQualityEvaluationJobRunner jobRunner = SearchQualityEvaluationJobRunner.getJobRunnerInstance(); + jobRunner.setClusterService(clusterService); + jobRunner.setThreadPool(threadPool); + jobRunner.setClient(client); + + return Collections.emptyList(); + + } + + @Override + public String getJobType() { + return "scheduler_search_quality_eval"; + } + + @Override + public String getJobIndex() { + LOGGER.info("Getting job index name"); + return SCHEDULED_JOBS_INDEX_NAME; + } + + @Override + public ScheduledJobRunner getJobRunner() { + LOGGER.info("Creating job runner"); + return SearchQualityEvaluationJobRunner.getJobRunnerInstance(); + } + + @Override + public ScheduledJobParser getJobParser() { + + return (parser, id, jobDocVersion) -> { + + final SearchQualityEvaluationJobParameter jobParameter = new SearchQualityEvaluationJobParameter(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + while (!parser.nextToken().equals(XContentParser.Token.END_OBJECT)) { + + final String fieldName = parser.currentName(); + + parser.nextToken(); + + switch (fieldName) { + case SearchQualityEvaluationJobParameter.NAME_FIELD: + jobParameter.setJobName(parser.text()); + break; + case SearchQualityEvaluationJobParameter.ENABLED_FILED: + jobParameter.setEnabled(parser.booleanValue()); + break; + case SearchQualityEvaluationJobParameter.ENABLED_TIME_FILED: + jobParameter.setEnabledTime(parseInstantValue(parser)); + break; + case SearchQualityEvaluationJobParameter.LAST_UPDATE_TIME_FIELD: + jobParameter.setLastUpdateTime(parseInstantValue(parser)); + break; + case SearchQualityEvaluationJobParameter.SCHEDULE_FIELD: + jobParameter.setSchedule(ScheduleParser.parse(parser)); + break; + case SearchQualityEvaluationJobParameter.LOCK_DURATION_SECONDS: + jobParameter.setLockDurationSeconds(parser.longValue()); + break; + case SearchQualityEvaluationJobParameter.JITTER: + jobParameter.setJitter(parser.doubleValue()); + break; + case SearchQualityEvaluationJobParameter.CLICK_MODEL: + jobParameter.setClickModel(parser.text()); + break; + case SearchQualityEvaluationJobParameter.MAX_RANK: + jobParameter.setMaxRank(parser.intValue()); + break; + default: + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + } + + } + + return jobParameter; + + }; + + } + + private Instant parseInstantValue(final XContentParser parser) throws IOException { + + if (XContentParser.Token.VALUE_NULL.equals(parser.currentToken())) { + return null; + } + + if (parser.currentToken().isValue()) { + return Instant.ofEpochMilli(parser.longValue()); + } + + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + return null; + + } + + @Override + public List getRestHandlers( + final Settings settings, + final RestController restController, + final ClusterSettings clusterSettings, + final IndexScopedSettings indexScopedSettings, + final SettingsFilter settingsFilter, + final IndexNameExpressionResolver indexNameExpressionResolver, + final Supplier nodesInCluster + ) { + return Collections.singletonList(new SearchQualityEvaluationRestHandler()); + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java new file mode 100644 index 0000000..ba56f04 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java @@ -0,0 +1,417 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.exists.indices.IndicesExistsRequest; +import org.opensearch.action.admin.indices.exists.indices.IndicesExistsResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel; +import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters; +import org.opensearch.eval.runners.OpenSearchQuerySetRunner; +import org.opensearch.eval.runners.QuerySetRunResult; +import org.opensearch.eval.samplers.AllQueriesQuerySampler; +import org.opensearch.eval.samplers.AllQueriesQuerySamplerParameters; +import org.opensearch.eval.samplers.ProbabilityProportionalToSizeAbstractQuerySampler; +import org.opensearch.eval.samplers.ProbabilityProportionalToSizeParameters; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import static org.opensearch.eval.SearchQualityEvaluationPlugin.JUDGMENTS_INDEX_NAME; + +public class SearchQualityEvaluationRestHandler extends BaseRestHandler { + + private static final Logger LOGGER = LogManager.getLogger(SearchQualityEvaluationRestHandler.class); + + /** + * URL for the implicit judgment scheduling. + */ + public static final String SCHEDULING_URL = "/_plugins/search_quality_eval/schedule"; + + /** + * URL for on-demand implicit judgment generation. + */ + public static final String IMPLICIT_JUDGMENTS_URL = "/_plugins/search_quality_eval/judgments"; + + /** + * URL for managing query sets. + */ + public static final String QUERYSET_MANAGEMENT_URL = "/_plugins/search_quality_eval/queryset"; + + /** + * URL for initiating query sets to run on-demand. + */ + public static final String QUERYSET_RUN_URL = "/_plugins/search_quality_eval/run"; + + /** + * The placeholder in the query that gets replaced by the query term when running a query set. + */ + public static final String QUERY_PLACEHOLDER = "#$query##"; + + @Override + public String getName() { + return "Search Quality Evaluation Framework"; + } + + @Override + public List routes() { + return List.of( + new Route(RestRequest.Method.POST, IMPLICIT_JUDGMENTS_URL), + new Route(RestRequest.Method.POST, SCHEDULING_URL), + new Route(RestRequest.Method.DELETE, SCHEDULING_URL), + new Route(RestRequest.Method.POST, QUERYSET_MANAGEMENT_URL), + new Route(RestRequest.Method.POST, QUERYSET_RUN_URL)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + // Handle managing query sets. + if(QUERYSET_MANAGEMENT_URL.equalsIgnoreCase(request.path())) { + + // Creating a new query set by sampling the UBI queries. + if (request.method().equals(RestRequest.Method.POST)) { + + final String name = request.param("name"); + final String description = request.param("description"); + final String sampling = request.param("sampling", "pptss"); + final int querySetSize = Integer.parseInt(request.param("query_set_size", "1000")); + + // Create a query set by finding all the unique user_query terms. + if (AllQueriesQuerySampler.NAME.equalsIgnoreCase(sampling)) { + + // If we are not sampling queries, the query sets should just be directly + // indexed into OpenSearch using the `ubi_queries` index directly. + + try { + + final AllQueriesQuerySamplerParameters parameters = new AllQueriesQuerySamplerParameters(name, description, sampling, querySetSize); + final AllQueriesQuerySampler sampler = new AllQueriesQuerySampler(client, parameters); + + // Sample and index the queries. + final String querySetId = sampler.sample(); + + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}")); + + } catch(Exception ex) { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, "{\"error\": \"" + ex.getMessage() + "\"}")); + } + + + // Create a query set by using PPTSS sampling. + } else if (ProbabilityProportionalToSizeAbstractQuerySampler.NAME.equalsIgnoreCase(sampling)) { + + LOGGER.info("Creating query set using PPTSS"); + + final ProbabilityProportionalToSizeParameters parameters = new ProbabilityProportionalToSizeParameters(name, description, sampling, querySetSize); + final ProbabilityProportionalToSizeAbstractQuerySampler sampler = new ProbabilityProportionalToSizeAbstractQuerySampler(client, parameters); + + try { + + // Sample and index the queries. + final String querySetId = sampler.sample(); + + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}")); + + } catch(Exception ex) { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, "{\"error\": \"" + ex.getMessage() + "\"}")); + } + + } else { + // An Invalid sampling method was provided in the request. + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Invalid sampling method: " + sampling + "\"}")); + } + + } else { + // Invalid HTTP method for this endpoint. + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "{\"error\": \"" + request.method() + " is not allowed.\"}")); + } + + // Handle running query sets. + } else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) { + + final String querySetId = request.param("id"); + final String judgmentsId = request.param("judgments_id"); + final String index = request.param("index"); + final String searchPipeline = request.param("search_pipeline", null); + final String idField = request.param("id_field", "_id"); + final int k = Integer.parseInt(request.param("k", "10")); + final double threshold = Double.parseDouble(request.param("threshold", "1.0")); + + if(querySetId == null || querySetId.isEmpty() || judgmentsId == null || judgmentsId.isEmpty() || index == null || index.isEmpty()) { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing required parameters.\"}")); + } + + if(k < 1) { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"k must be a positive integer.\"}")); + } + + if(!request.hasContent()) { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing query in body.\"}")); + } + + // Get the query JSON from the content. + final String query = new String(BytesReference.toBytes(request.content()), Charset.defaultCharset()); + + // Validate the query has a QUERY_PLACEHOLDER. + if(!query.contains(QUERY_PLACEHOLDER)) { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing query placeholder in query.\"}")); + } + + try { + + final OpenSearchQuerySetRunner openSearchQuerySetRunner = new OpenSearchQuerySetRunner(client); + final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, searchPipeline, idField, query, k, threshold); + openSearchQuerySetRunner.save(querySetRunResult); + + } catch (Exception ex) { + LOGGER.error("Unable to run query set. Verify query set and judgments exist.", ex); + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, ex.getMessage())); + } + + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Run initiated for query set " + querySetId + "\"}")); + + // Handle the on-demand creation of implicit judgments. + } else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) { + + if (request.method().equals(RestRequest.Method.POST)) { + + //final long startTime = System.currentTimeMillis(); + final String clickModel = request.param("click_model", "coec"); + final int maxRank = Integer.parseInt(request.param("max_rank", "20")); + + if (CoecClickModel.CLICK_MODEL_NAME.equalsIgnoreCase(clickModel)) { + + final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(maxRank); + final CoecClickModel coecClickModel = new CoecClickModel(client, coecClickModelParameters); + + final String judgmentsId; + + // TODO: Run this in a separate thread. + try { + + // Create the judgments index. + createJudgmentsIndex(client); + + judgmentsId = coecClickModel.calculateJudgments(); + + // judgmentsId will be null if no judgments were created (and indexed). + if(judgmentsId == null) { + // TODO: Is Bad Request the appropriate error? Perhaps Conflict is more appropriate? + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"No judgments were created. Check the queries and events data.\"}")); + } + +// final long elapsedTime = System.currentTimeMillis() - startTime; +// +// final Map job = new HashMap<>(); +// job.put("name", "manual_generation"); +// job.put("click_model", clickModel); +// job.put("started", startTime); +// job.put("duration", elapsedTime); +// job.put("invocation", "on_demand"); +// job.put("judgments_id", judgmentsId); +// job.put("max_rank", maxRank); +// +// final String jobId = UUID.randomUUID().toString(); +// +// final IndexRequest indexRequest = new IndexRequest() +// .index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME) +// .id(jobId) +// .source(job) +// .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); +// +// client.index(indexRequest, new ActionListener<>() { +// @Override +// public void onResponse(final IndexResponse indexResponse) { +// LOGGER.debug("Click model job completed successfully: {}", jobId); +// } +// +// @Override +// public void onFailure(final Exception ex) { +// LOGGER.error("Unable to run job with ID {}", jobId, ex); +// throw new RuntimeException("Unable to run job", ex); +// } +// }); + + } catch (Exception ex) { + throw new RuntimeException("Unable to generate judgments.", ex); + } + + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"judgments_id\": \"" + judgmentsId + "\"}")); + + } else { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Invalid click model.\"}")); + } + + } else { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "{\"error\": \"" + request.method() + " is not allowed.\"}")); + } + + // Handle the scheduling of creating implicit judgments. + } else if(SCHEDULING_URL.equalsIgnoreCase(request.path())) { + + if (request.method().equals(RestRequest.Method.POST)) { + + // Get the job parameters from the request. + final String id = request.param("id"); + final String jobName = request.param("job_name", UUID.randomUUID().toString()); + final String lockDurationSecondsString = request.param("lock_duration_seconds", "600"); + final Long lockDurationSeconds = lockDurationSecondsString != null ? Long.parseLong(lockDurationSecondsString) : null; + final String jitterString = request.param("jitter"); + final Double jitter = jitterString != null ? Double.parseDouble(jitterString) : null; + final String clickModel = request.param("click_model"); + final int maxRank = Integer.parseInt(request.param("max_rank", "20")); + + // Validate the request parameters. + if (id == null || clickModel == null) { + throw new IllegalArgumentException("The id and click_model parameters must be provided."); + } + + // Read the start_time. + final Instant startTime; + if (request.param("start_time") == null) { + startTime = Instant.now(); + } else { + startTime = Instant.ofEpochMilli(Long.parseLong(request.param("start_time"))); + } + + // Read the interval. + final int interval; + if (request.param("interval") == null) { + // Default to every 24 hours. + interval = 1440; + } else { + interval = Integer.parseInt(request.param("interval")); + } + + final SearchQualityEvaluationJobParameter jobParameter = new SearchQualityEvaluationJobParameter( + jobName, new IntervalSchedule(startTime, interval, ChronoUnit.MINUTES), lockDurationSeconds, + jitter, clickModel, maxRank + ); + + final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.SCHEDULED_JOBS_INDEX_NAME) + .id(id) + .source(jobParameter.toXContent(JsonXContent.contentBuilder(), null)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + return restChannel -> { + + // index the job parameter + client.index(indexRequest, new ActionListener<>() { + + @Override + public void onResponse(final IndexResponse indexResponse) { + + try { + + final RestResponse restResponse = new BytesRestResponse( + RestStatus.OK, + indexResponse.toXContent(JsonXContent.contentBuilder(), null) + ); + LOGGER.info("Created implicit judgments schedule for click-model {}: Job name {}, running every {} minutes starting {}", clickModel, jobName, interval, startTime); + + restChannel.sendResponse(restResponse); + + } catch (IOException e) { + restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } + + } + + @Override + public void onFailure(Exception e) { + restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } + }); + + }; + + // Delete a scheduled job to make implicit judgments. + } else if (request.method().equals(RestRequest.Method.DELETE)) { + + final String id = request.param("id"); + final DeleteRequest deleteRequest = new DeleteRequest().index(SearchQualityEvaluationPlugin.SCHEDULED_JOBS_INDEX_NAME).id(id); + + return restChannel -> client.delete(deleteRequest, new ActionListener<>() { + @Override + public void onResponse(final DeleteResponse deleteResponse) { + restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Scheduled job deleted.\"}")); + } + + @Override + public void onFailure(Exception e) { + restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } + }); + + } else { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "{\"error\": \"" + request.method() + " is not allowed.\"}")); + } + + } else { + return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.NOT_FOUND, "{\"error\": \"" + request.path() + " was not found.\"}")); + } + + } + + private void createJudgmentsIndex(final NodeClient client) throws Exception { + + // If the judgments index does not exist we need to create it. + final IndicesExistsRequest indicesExistsRequest = new IndicesExistsRequest(JUDGMENTS_INDEX_NAME); + + final IndicesExistsResponse indicesExistsResponse = client.admin().indices().exists(indicesExistsRequest).get(); + + if(!indicesExistsResponse.isExists()) { + + // TODO: Read this mapping from a resource file instead. + final String mapping = "{\n" + + " \"properties\": {\n" + + " \"judgments_id\": { \"type\": \"keyword\" },\n" + + " \"query_id\": { \"type\": \"keyword\" },\n" + + " \"query\": { \"type\": \"keyword\" },\n" + + " \"document_id\": { \"type\": \"keyword\" },\n" + + " \"judgment\": { \"type\": \"double\" },\n" + + " \"timestamp\": { \"type\": \"date\", \"format\": \"strict_date_time\" }\n" + + " }\n" + + " }"; + + // Create the judgments index. + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(JUDGMENTS_INDEX_NAME).mapping(mapping); + + // TODO: Don't use .get() + client.admin().indices().create(createIndexRequest).get(); + + } + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/ClickModel.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/ClickModel.java new file mode 100644 index 0000000..ea83a87 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/ClickModel.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.clickmodel; + +/** + * Base class for creating click models. + */ +public abstract class ClickModel { + + /** + * Calculate implicit judgments. + * @return The judgments ID. + * @throws Exception Thrown if the judgments cannot be created. + */ + public abstract String calculateJudgments() throws Exception; + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/ClickModelParameters.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/ClickModelParameters.java new file mode 100644 index 0000000..8c42550 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/ClickModelParameters.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.clickmodel; + +public abstract class ClickModelParameters { + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/coec/CoecClickModel.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/coec/CoecClickModel.java new file mode 100644 index 0000000..f2e8aa8 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/coec/CoecClickModel.java @@ -0,0 +1,422 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.clickmodel.coec; + +import com.google.gson.Gson; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.client.Client; +import org.opensearch.client.Requests; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.eval.SearchQualityEvaluationPlugin; +import org.opensearch.eval.judgments.clickmodel.ClickModel; +import org.opensearch.eval.judgments.model.ClickthroughRate; +import org.opensearch.eval.judgments.model.Judgment; +import org.opensearch.eval.judgments.model.ubi.event.UbiEvent; +import org.opensearch.eval.judgments.opensearch.OpenSearchHelper; +import org.opensearch.eval.judgments.queryhash.IncrementalUserQueryHash; +import org.opensearch.eval.utils.MathUtils; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.search.Scroll; +import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.Map; +import java.util.Set; + +public class CoecClickModel extends ClickModel { + + public static final String CLICK_MODEL_NAME = "coec"; + + // OpenSearch indexes for COEC data. + public static final String INDEX_RANK_AGGREGATED_CTR = "rank_aggregated_ctr"; + public static final String INDEX_QUERY_DOC_CTR = "click_through_rates"; + + // UBI event names. + public static final String EVENT_CLICK = "click"; + public static final String EVENT_IMPRESSION = "impression"; + + private final CoecClickModelParameters parameters; + + private final OpenSearchHelper openSearchHelper; + + private final IncrementalUserQueryHash incrementalUserQueryHash = new IncrementalUserQueryHash(); + private final Gson gson = new Gson(); + private final Client client; + + private static final Logger LOGGER = LogManager.getLogger(CoecClickModel.class.getName()); + + public CoecClickModel(final Client client, final CoecClickModelParameters parameters) { + + this.parameters = parameters; + this.openSearchHelper = new OpenSearchHelper(client); + this.client = client; + + } + + @Override + public String calculateJudgments() throws Exception { + + final int maxRank = parameters.getMaxRank(); + + // Calculate and index the rank-aggregated click-through. + LOGGER.info("Beginning calculation of rank-aggregated click-through."); + final Map rankAggregatedClickThrough = getRankAggregatedClickThrough(); + LOGGER.info("Rank-aggregated clickthrough positions: {}", rankAggregatedClickThrough.size()); + showRankAggregatedClickThrough(rankAggregatedClickThrough); + + // Calculate and index the click-through rate for query/doc pairs. + LOGGER.info("Beginning calculation of clickthrough rates."); + final Map> clickthroughRates = getClickthroughRate(); + LOGGER.info("Clickthrough rates for number of queries: {}", clickthroughRates.size()); + showClickthroughRates(clickthroughRates); + + // Generate and index the implicit judgments. + LOGGER.info("Beginning calculation of implicit judgments."); + return calculateCoec(rankAggregatedClickThrough, clickthroughRates); + + } + + public String calculateCoec(final Map rankAggregatedClickThrough, + final Map> clickthroughRates) throws Exception { + + // Calculate the COEC. + // Numerator is the total number of clicks received by a query/result pair. + // Denominator is the expected clicks (EC) that an average result would receive after being impressed i times at rank r, + // and CTR is the average CTR for each position in the results page (up to R) computed over all queries and results. + + // Format: query_id, query, document, judgment + final Collection judgments = new LinkedList<>(); + + LOGGER.info("Count of queries: {}", clickthroughRates.size()); + + for(final String userQuery : clickthroughRates.keySet()) { + + // The clickthrough rates for this one query. + // A ClickthroughRate is a document with counts of impressions and clicks. + final Collection ctrs = clickthroughRates.get(userQuery); + + // Go through each clickthrough rate for this query. + for(final ClickthroughRate ctr : ctrs) { + + double denominatorSum = 0; + + for(int rank = 0; rank < parameters.getMaxRank(); rank++) { + + // The document's mean CTR at the rank. + final double meanCtrAtRank = rankAggregatedClickThrough.getOrDefault(rank, 0.0); + + // The number of times this document was shown as this rank. + final long countOfTimesShownAtRank = openSearchHelper.getCountOfQueriesForUserQueryHavingResultInRankR(userQuery, ctr.getObjectId(), rank); + + denominatorSum += (meanCtrAtRank * countOfTimesShownAtRank); + + } + + // Numerator is sum of clicks at all ranks up to the maxRank. + final int totalNumberClicksForQueryResult = ctr.getClicks(); + + // Divide the numerator by the denominator (value). + final double judgmentValue; + + if(denominatorSum == 0) { + judgmentValue = 0.0; + } else { + judgmentValue = totalNumberClicksForQueryResult / denominatorSum; + } + + // Hash the user query to get a query ID. + final int queryId = incrementalUserQueryHash.getHash(userQuery); + + // Add the judgment to the list. + // TODO: What to do for query ID when the values are per user_query instead? + final Judgment judgment = new Judgment(String.valueOf(queryId), userQuery, ctr.getObjectId(), judgmentValue); + judgments.add(judgment); + + } + + } + + LOGGER.info("Count of user queries: {}", clickthroughRates.size()); + LOGGER.info("Count of judgments: {}", judgments.size()); + + showJudgments(judgments); + + if(!(judgments.isEmpty())) { + return openSearchHelper.indexJudgments(judgments); + } else { + return null; + } + + } + + /** + * Gets the clickthrough rates for each query and its results. + * @return A map of user_query to the clickthrough rate for each query result. + * @throws IOException Thrown when a problem accessing OpenSearch. + */ + private Map> getClickthroughRate() throws Exception { + + // For each query: + // - Get each document returned in that query (in the QueryResponse object). + // - Calculate the click-through rate for the document. (clicks/impressions) + + // TODO: Allow for a time period and for a specific application. + + final String query = "{\n" + + " \"bool\": {\n" + + " \"should\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"action_name\": \"click\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"term\": {\n" + + " \"action_name\": \"impression\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"must\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"event_attributes.position.ordinal\": {\n" + + " \"lte\": " + parameters.getMaxRank() + "\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }"; + + final BoolQueryBuilder queryBuilder = new BoolQueryBuilder().must(new WrapperQueryBuilder(query)); + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder).size(1000); + final Scroll scroll = new Scroll(TimeValue.timeValueMinutes(10L)); + + final SearchRequest searchRequest = Requests + .searchRequest(SearchQualityEvaluationPlugin.UBI_EVENTS_INDEX_NAME) + .source(searchSourceBuilder) + .scroll(scroll); + + // TODO Don't use .get() + SearchResponse searchResponse = client.search(searchRequest).get(); + + String scrollId = searchResponse.getScrollId(); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + + final Map> queriesToClickthroughRates = new HashMap<>(); + + while (searchHits != null && searchHits.length > 0) { + + for (final SearchHit hit : searchHits) { + + final UbiEvent ubiEvent = AccessController.doPrivileged((PrivilegedAction) () -> gson.fromJson(hit.getSourceAsString(), UbiEvent.class)); + + // We need to the hash of the query_id because two users can both search + // for "computer" and those searches will have different query IDs, but they are the same search. + final String userQuery = openSearchHelper.getUserQuery(ubiEvent.getQueryId()); + + // userQuery will be null if there is not a query for this event in ubi_queries. + if(userQuery != null) { + + // Get the clicks for this queryId from the map, or an empty list if this is a new query. + final Set clickthroughRates = queriesToClickthroughRates.getOrDefault(userQuery, new LinkedHashSet<>()); + + // Get the ClickthroughRate object for the object that was interacted with. + final ClickthroughRate clickthroughRate = clickthroughRates.stream().filter(p -> p.getObjectId().equals(ubiEvent.getEventAttributes().getObject().getObjectId())).findFirst().orElse(new ClickthroughRate(ubiEvent.getEventAttributes().getObject().getObjectId())); + + if (EVENT_CLICK.equalsIgnoreCase(ubiEvent.getActionName())) { + //LOGGER.info("Logging a CLICK on " + ubiEvent.getEventAttributes().getObject().getObjectId()); + clickthroughRate.logClick(); + } else if (EVENT_IMPRESSION.equalsIgnoreCase(ubiEvent.getActionName())) { + //LOGGER.info("Logging an IMPRESSION on " + ubiEvent.getEventAttributes().getObject().getObjectId()); + clickthroughRate.logImpression(); + } else { + LOGGER.warn("Invalid event action name: {}", ubiEvent.getActionName()); + } + + clickthroughRates.add(clickthroughRate); + queriesToClickthroughRates.put(userQuery, clickthroughRates); + // LOGGER.debug("clickthroughRate = {}", queriesToClickthroughRates.size()); + + } + + } + + final SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId); + scrollRequest.scroll(scroll); + + //LOGGER.info("Doing scroll to next results"); + // TODO: Getting a warning in the log that "QueryGroup _id can't be null, It should be set before accessing it. This is abnormal behaviour" + // I don't remember seeing this prior to 2.18.0 but it's possible I just didn't see it. + // https://github.com/opensearch-project/OpenSearch/blob/f105e4eb2ede1556b5dd3c743bea1ab9686ebccf/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java#L73 + searchResponse = client.searchScroll(scrollRequest).get(); + //LOGGER.info("Scroll complete."); + + scrollId = searchResponse.getScrollId(); + + searchHits = searchResponse.getHits().getHits(); + + } + + openSearchHelper.indexClickthroughRates(queriesToClickthroughRates); + + return queriesToClickthroughRates; + + } + + /** + * Calculate the rank-aggregated click through from the UBI events. + * @return A map of positions to clickthrough rates. + * @throws IOException Thrown when a problem accessing OpenSearch. + */ + public Map getRankAggregatedClickThrough() throws Exception { + + final Map rankAggregatedClickThrough = new HashMap<>(); + + // TODO: Allow for a time period and for a specific application. + + final QueryBuilder findRangeNumber = QueryBuilders.rangeQuery("event_attributes.position.ordinal").lte(parameters.getMaxRank()); + final QueryBuilder queryBuilder = new BoolQueryBuilder().must(findRangeNumber); + + // Order the aggregations by key and not by value. + final BucketOrder bucketOrder = BucketOrder.key(true); + + final TermsAggregationBuilder positionsAggregator = AggregationBuilders.terms("By_Position").field("event_attributes.position.ordinal").order(bucketOrder).size(parameters.getMaxRank()); + final TermsAggregationBuilder actionNameAggregation = AggregationBuilders.terms("By_Action").field("action_name").subAggregation(positionsAggregator).order(bucketOrder).size(parameters.getMaxRank()); + + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(queryBuilder) + .aggregation(actionNameAggregation) + .from(0) + .size(0); + + final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_EVENTS_INDEX_NAME).source(searchSourceBuilder); + final SearchResponse searchResponse = client.search(searchRequest).get(); + + final Map clickCounts = new HashMap<>(); + final Map impressionCounts = new HashMap<>(); + + final Terms actionTerms = searchResponse.getAggregations().get("By_Action"); + final Collection actionBuckets = actionTerms.getBuckets(); + + LOGGER.debug("Aggregation query: {}", searchSourceBuilder.toString()); + + for(final Terms.Bucket actionBucket : actionBuckets) { + + // Handle the "impression" bucket. + if(EVENT_IMPRESSION.equalsIgnoreCase(actionBucket.getKey().toString())) { + + final Terms positionTerms = actionBucket.getAggregations().get("By_Position"); + final Collection positionBuckets = positionTerms.getBuckets(); + + for(final Terms.Bucket positionBucket : positionBuckets) { + LOGGER.debug("Inserting impression event from position {} with click count {}", positionBucket.getKey(), (double) positionBucket.getDocCount()); + impressionCounts.put(Integer.valueOf(positionBucket.getKey().toString()), (double) positionBucket.getDocCount()); + } + + } + + // Handle the "click" bucket. + if(EVENT_CLICK.equalsIgnoreCase(actionBucket.getKey().toString())) { + + final Terms positionTerms = actionBucket.getAggregations().get("By_Position"); + final Collection positionBuckets = positionTerms.getBuckets(); + + for(final Terms.Bucket positionBucket : positionBuckets) { + LOGGER.debug("Inserting client event from position {} with click count {}", positionBucket.getKey(), (double) positionBucket.getDocCount()); + clickCounts.put(Integer.valueOf(positionBucket.getKey().toString()), (double) positionBucket.getDocCount()); + } + + } + + } + + for(int rank = 0; rank < parameters.getMaxRank(); rank++) { + + if(impressionCounts.containsKey(rank)) { + + if(clickCounts.containsKey(rank)) { + + // Calculate the CTR by dividing the number of clicks by the number of impressions. + LOGGER.info("Position = {}, Impression Counts = {}, Click Count = {}", rank, impressionCounts.get(rank), clickCounts.get(rank)); + rankAggregatedClickThrough.put(rank, clickCounts.get(rank) / impressionCounts.get(rank)); + + } else { + + // This document has impressions but no clicks, so it's CTR is zero. + LOGGER.info("Position = {}, Impression Counts = {}, Impressions but no clicks so CTR is 0", rank, clickCounts.get(rank)); + rankAggregatedClickThrough.put(rank, 0.0); + + } + + } else { + + // No impressions so the clickthrough rate is 0. + LOGGER.info("No impressions for rank {}, so using CTR of 0", rank); + rankAggregatedClickThrough.put(rank, (double) 0); + + } + + } + + openSearchHelper.indexRankAggregatedClickthrough(rankAggregatedClickThrough); + + return rankAggregatedClickThrough; + + } + + private void showJudgments(final Collection judgments) { + + for(final Judgment judgment : judgments) { + LOGGER.info(judgment.toJudgmentString()); + } + + } + + private void showClickthroughRates(final Map> clickthroughRates) { + + for(final String userQuery : clickthroughRates.keySet()) { + + LOGGER.debug("user_query: {}", userQuery); + + for(final ClickthroughRate clickthroughRate : clickthroughRates.get(userQuery)) { + LOGGER.debug("\t - {}", clickthroughRate.toString()); + } + + } + + } + + private void showRankAggregatedClickThrough(final Map rankAggregatedClickThrough) { + + for(final int position : rankAggregatedClickThrough.keySet()) { + LOGGER.info("Position: {}, # ctr: {}", position, MathUtils.round(rankAggregatedClickThrough.get(position), parameters.getRoundingDigits())); + } + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/coec/CoecClickModelParameters.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/coec/CoecClickModelParameters.java new file mode 100644 index 0000000..36df03e --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/clickmodel/coec/CoecClickModelParameters.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.clickmodel.coec; + +import org.opensearch.eval.judgments.clickmodel.ClickModelParameters; + +/** + * The parameters for the {@link CoecClickModel}. + */ +public class CoecClickModelParameters extends ClickModelParameters { + + private final int maxRank; + private int roundingDigits = 3; + + /** + * Creates new parameters. + * @param maxRank The max rank to use when calculating the judgments. + */ + public CoecClickModelParameters(final int maxRank) { + this.maxRank = maxRank; + } + + /** + * Creates new parameters. + * @param maxRank The max rank to use when calculating the judgments. + * @param roundingDigits The number of decimal places to round calculated values to. + */ + public CoecClickModelParameters(final int maxRank, final int roundingDigits) { + this.maxRank = maxRank; + this.roundingDigits = roundingDigits; + } + + /** + * Gets the max rank for the implicit judgments calculation. + * @return The max rank for the implicit judgments calculation. + */ + public int getMaxRank() { + return maxRank; + } + + /** + * Gets the number of rounding digits to use for judgments. + * @return The number of rounding digits to use for judgments. + */ + public int getRoundingDigits() { + return roundingDigits; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ClickthroughRate.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ClickthroughRate.java new file mode 100644 index 0000000..cef1f1f --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ClickthroughRate.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model; + +import org.opensearch.eval.utils.MathUtils; + +/** + * A query result and its number of clicks and total events. + */ +public class ClickthroughRate { + + private final String objectId; + private int clicks; + private int impressions; + + /** + * Creates a new clickthrough rate for an object. + * @param objectId The ID of the object. + */ + public ClickthroughRate(final String objectId) { + this.objectId = objectId; + this.clicks = 0; + this.impressions = 0; + } + + /** + * Creates a new clickthrough rate for an object given counts of clicks and events. + * @param objectId The object ID. + * @param clicks The count of clicks. + * @param impressions The count of events. + */ + public ClickthroughRate(final String objectId, final int clicks, final int impressions) { + this.objectId = objectId; + this.clicks = clicks; + this.impressions = impressions; + } + + @Override + public String toString() { + return "object_id: " + objectId + ", clicks: " + clicks + ", events: " + impressions + ", ctr: " + MathUtils.round(getClickthroughRate()); + } + + /** + * Log a click to this object. + * This increments clicks and events. + */ + public void logClick() { + clicks++; + } + + /** + * Log an impression to this object. + */ + public void logImpression() { + impressions++; + } + + /** + * Calculate the clickthrough rate. + * @return The clickthrough rate as clicks divided by events. + */ + public double getClickthroughRate() { + return (double) clicks / impressions; + } + + /** + * Gets the count of clicks. + * @return The count of clicks. + */ + public int getClicks() { + return clicks; + } + + /** + * Gets the count of events. + * @return The count of events. + */ + public int getImpressions() { + return impressions; + } + + /** + * Gets the object ID. + * @return The object ID. + */ + public String getObjectId() { + return objectId; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/Judgment.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/Judgment.java new file mode 100644 index 0000000..bc9955f --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/Judgment.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.eval.utils.MathUtils; + +import java.util.HashMap; +import java.util.Map; + +/** + * A judgment of a search result's quality for a given query. + */ +public class Judgment { + + private static final Logger LOGGER = LogManager.getLogger(Judgment.class.getName()); + + private final String queryId; + private final String query; + private final String document; + private final double judgment; + + /** + * Creates a new judgment. + * @param queryId The query ID for the judgment. + * @param query The query for the judgment. + * @param document The document in the jdugment. + * @param judgment The judgment value. + */ + public Judgment(final String queryId, final String query, final String document, final double judgment) { + this.queryId = queryId; + this.query = query; + this.document = document; + this.judgment = judgment; + } + + public String toJudgmentString() { + return queryId + ", " + query + ", " + document + ", " + MathUtils.round(judgment); + } + + public Map getJudgmentAsMap() { + + final Map judgmentMap = new HashMap<>(); + judgmentMap.put("query_id", queryId); + judgmentMap.put("query", query); + judgmentMap.put("document_id", document); + judgmentMap.put("judgment", judgment); + + return judgmentMap; + + } + + @Override + public String toString() { + return "query_id: " + queryId + ", query: " + query + ", document: " + document + ", judgment: " + MathUtils.round(judgment); + } + + /** + * Gets the judgment's query ID. + * @return The judgment's query ID. + */ + public String getQueryId() { + return queryId; + } + + /** + * Gets the judgment's query. + * @return The judgment's query. + */ + public String getQuery() { + return query; + } + + /** + * Gets the judgment's document. + * @return The judgment's document. + */ + public String getDocument() { + return document; + } + + /** + * Gets the judgment's value. + * @return The judgment's value. + */ + public double getJudgment() { + return judgment; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java new file mode 100644 index 0000000..2244df4 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model; + +public class QuerySetQuery { + + private final String query; + private final long frequency; + + public QuerySetQuery(final String query, final long frequency) { + this.query = query; + this.frequency = frequency; + } + + public String getQuery() { + return query; + } + + public long getFrequency() { + return frequency; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/EventAttributes.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/EventAttributes.java new file mode 100644 index 0000000..cf09444 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/EventAttributes.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +/** + * Attributes on an UBI event. + */ +public class EventAttributes { + + @SerializedName("object") + private EventObject object; + + @SerializedName("session_id") + private String sessionId; + + @SerializedName("position") + private Position position; + + /** + * Creates a new object. + */ + public EventAttributes() { + + } + + /** + * Gets the {@link EventObject} for the event. + * @return A {@link EventObject}. + */ + public EventObject getObject() { + return object; + } + + /** + * Sets the {@link EventObject} for the event. + * @param object A {@link EventObject}. + */ + public void setObject(EventObject object) { + this.object = object; + } + + /** + * Gets the session ID for the event. + * @return The session ID for the event. + */ + public String getSessionId() { + return sessionId; + } + + /** + * Sets the session ID for the event. + * @param sessionId The session ID for the evnet. + */ + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + /** + * Gets the {@link Position} associated with the event. + * @return The {@link Position} associated with the event. + */ + public Position getPosition() { + return position; + } + + /** + * Sets the {@link Position} associated with the event. + * @param position The {@link Position} associated with the event. + */ + public void setPosition(Position position) { + this.position = position; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/EventObject.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/EventObject.java new file mode 100644 index 0000000..55595ba --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/EventObject.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +public class EventObject { + + @SerializedName("object_id_field") + private String objectIdField; + + @SerializedName("object_id") + private String objectId; + + @Override + public String toString() { + return "[" + objectIdField + ", " + objectId + "]"; + } + + /** + * Gets the object ID. + * @return The object ID. + */ + public String getObjectId() { + return objectId; + } + + /** + * Sets the object ID. + * @param objectId The object ID. + */ + public void setObjectId(String objectId) { + this.objectId = objectId; + } + + /** + * Gets the object ID field. + * @return The object ID field. + */ + public String getObjectIdField() { + return objectIdField; + } + + /** + * Sets the object ID field. + * @param objectIdField The object ID field. + */ + public void setObjectIdField(String objectIdField) { + this.objectIdField = objectIdField; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/Position.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/Position.java new file mode 100644 index 0000000..e3ebaad --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/Position.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +/** + * A position represents the location of a search result in an event. + */ +public class Position { + + @SerializedName("ordinal") + private int ordinal; + + @Override + public String toString() { + return String.valueOf(ordinal); + } + + /** + * Gets the ordinal of the position. + * @return The ordinal of the position. + */ + public int getOrdinal() { + return ordinal; + } + + /** + * Sets the ordinal of the position. + * @param ordinal The ordinal of the position. + */ + public void setOrdinal(int ordinal) { + this.ordinal = ordinal; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/UbiEvent.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/UbiEvent.java new file mode 100644 index 0000000..61c0f8b --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/event/UbiEvent.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +/** + * Creates a representation of a UBI event. + */ +public class UbiEvent { + + @SerializedName("action_name") + private String actionName; + + @SerializedName("client_id") + private String clientId; + + @SerializedName("query_id") + private String queryId; + + @SerializedName("event_attributes") + private EventAttributes eventAttributes; + + /** + * Creates a new representation of an UBI event. + */ + public UbiEvent() { + + } + + @Override + public String toString() { + return actionName + ", " + clientId + ", " + queryId + ", " + eventAttributes.getObject().toString() + ", " + eventAttributes.getPosition().getOrdinal(); + } + + /** + * Gets the name of the action. + * @return The name of the action. + */ + public String getActionName() { + return actionName; + } + + /** + * Gets the client ID. + * @return The client ID. + */ + public String getClientId() { + return clientId; + } + + /** + * Gets the query ID. + * @return The query ID. + */ + public String getQueryId() { + return queryId; + } + + /** + * Gets the event attributes. + * @return The {@link EventAttributes}. + */ + public EventAttributes getEventAttributes() { + return eventAttributes; + } + + /** + * Sets the event attributes. + * @param eventAttributes The {@link EventAttributes}. + */ + public void setEventAttributes(EventAttributes eventAttributes) { + this.eventAttributes = eventAttributes; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/query/QueryResponse.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/query/QueryResponse.java new file mode 100644 index 0000000..5d45ee0 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/query/QueryResponse.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model.ubi.query; + +import java.util.List; + +/** + * A query response for a {@link UbiQuery query}. + */ +public class QueryResponse { + + private final String queryId; + private final String queryResponseId; + private final List queryResponseHitIds; + + /** + * Creates a query response. + * @param queryId The ID of the query. + * @param queryResponseId The ID of the query response. + * @param queryResponseHitIds A list of IDs for the hits in the query. + */ + public QueryResponse(final String queryId, final String queryResponseId, final List queryResponseHitIds) { + this.queryId = queryId; + this.queryResponseId = queryResponseId; + this.queryResponseHitIds = queryResponseHitIds; + } + + /** + * Gets the query ID. + * @return The query ID. + */ + public String getQueryId() { + return queryId; + } + + /** + * Gets the query response ID. + * @return The query response ID. + */ + public String getQueryResponseId() { + return queryResponseId; + } + + /** + * Gets the list of query response hit IDs. + * @return A list of query response hit IDs. + */ + public List getQueryResponseHitIds() { + return queryResponseHitIds; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/query/UbiQuery.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/query/UbiQuery.java new file mode 100644 index 0000000..0b7ca0b --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/model/ubi/query/UbiQuery.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.model.ubi.query; + +import com.google.gson.annotations.SerializedName; + +import java.util.Map; + +/** + * Represents a UBI query. + */ +public class UbiQuery { + + @SerializedName("timestamp") + private String timestamp; + + @SerializedName("query_id") + private String queryId; + + @SerializedName("client_id") + private String clientId; + + @SerializedName("user_query") + private String userQuery; + + @SerializedName("query") + private String query; + + @SerializedName("query_attributes") + private Map queryAttributes; + + @SerializedName("query_response") + private QueryResponse queryResponse; + + /** + * Creates a new UBI query object. + */ + public UbiQuery() { + + } + + /** + * Gets the timestamp for the query. + * @return The timestamp for the query. + */ + public String getTimestamp() { + return timestamp; + } + + /** + * Sets the timestamp for the query. + * @param timestamp The timestamp for the query. + */ + public void setTimestamp(String timestamp) { + this.timestamp = timestamp; + } + + /** + * Gets the query ID. + * @return The query ID. + */ + public String getQueryId() { + return queryId; + } + + /** + * Sets the query ID. + * @param queryId The query ID. + */ + public void setQueryId(String queryId) { + this.queryId = queryId; + } + + /** + * Sets the client ID. + * @param clientId The client ID. + */ + public void setClientId(String clientId) { + this.clientId = clientId; + } + + /** + * Gets the client ID. + * @return The client ID. + */ + public String getClientId() { + return clientId; + } + + /** + * Gets the user query. + * @return The user query. + */ + public String getUserQuery() { + return userQuery; + } + + /** + * Sets the user query. + * @param userQuery The user query. + */ + public void setUserQuery(String userQuery) { + this.userQuery = userQuery; + } + + /** + * Gets the query. + * @return The query. + */ + public String getQuery() { + return query; + } + + /** + * Sets the query. + * @param query The query. + */ + public void setQuery(String query) { + this.query = query; + } + + /** + * Sets the query attributes. + * @return The query attributes. + */ + public Map getQueryAttributes() { + return queryAttributes; + } + + /** + * Sets the query attributes. + * @param queryAttributes The query attributes. + */ + public void setQueryAttributes(Map queryAttributes) { + this.queryAttributes = queryAttributes; + } + + /** + * Gets the query responses. + * @return The query responses. + */ + public QueryResponse getQueryResponse() { + return queryResponse; + } + + /** + * Sets the query responses. + * @param queryResponse The query responses. + */ + public void setQueryResponse(QueryResponse queryResponse) { + this.queryResponse = queryResponse; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/opensearch/OpenSearchHelper.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/opensearch/OpenSearchHelper.java new file mode 100644 index 0000000..3c391b3 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/opensearch/OpenSearchHelper.java @@ -0,0 +1,342 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.opensearch; + +import com.google.gson.Gson; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.eval.judgments.model.ClickthroughRate; +import org.opensearch.eval.judgments.model.Judgment; +import org.opensearch.eval.judgments.model.ubi.query.UbiQuery; +import org.opensearch.eval.utils.TimeUtils; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +import static org.opensearch.eval.SearchQualityEvaluationPlugin.JUDGMENTS_INDEX_NAME; +import static org.opensearch.eval.SearchQualityEvaluationPlugin.UBI_EVENTS_INDEX_NAME; +import static org.opensearch.eval.SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME; +import static org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel.INDEX_QUERY_DOC_CTR; +import static org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel.INDEX_RANK_AGGREGATED_CTR; + +/** + * Functionality for interacting with OpenSearch. + * TODO: Move these functions out of this class. + */ +public class OpenSearchHelper { + + private static final Logger LOGGER = LogManager.getLogger(OpenSearchHelper.class.getName()); + + private final Client client; + private final Gson gson = new Gson(); + + // Used to cache the query ID->user_query to avoid unnecessary lookups to OpenSearch. + private static final Map userQueryCache = new HashMap<>(); + + public OpenSearchHelper(final Client client) { + this.client = client; + } + + /** + * Gets the user query for a given query ID. + * @param queryId The query ID. + * @return The user query. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public String getUserQuery(final String queryId) throws Exception { + + // If it's in the cache just get it and return it. + if(userQueryCache.containsKey(queryId)) { + return userQueryCache.get(queryId); + } + + // Cache it and return it. + final UbiQuery ubiQuery = getQueryFromQueryId(queryId); + + // ubiQuery will be null if the query does not exist. + if(ubiQuery != null) { + + userQueryCache.put(queryId, ubiQuery.getUserQuery()); + return ubiQuery.getUserQuery(); + + } else { + + return null; + + } + + } + + /** + * Gets the query object for a given query ID. + * @param queryId The query ID. + * @return A {@link UbiQuery} object for the given query ID. + * @throws Exception Thrown if the query cannot be retrieved. + */ + public UbiQuery getQueryFromQueryId(final String queryId) throws Exception { + + LOGGER.debug("Getting query from query ID {}", queryId); + + final String query = "{\"match\": {\"query_id\": \"" + queryId + "\" }}"; + final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query); + + // The query_id should be unique anyway, but we are limiting it to a single result anyway. + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(qb); + searchSourceBuilder.from(0); + searchSourceBuilder.size(1); + + final String[] indexes = {UBI_QUERIES_INDEX_NAME}; + + final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder); + final SearchResponse response = client.search(searchRequest).get(); + + // If this does not return a query then we cannot calculate the judgments. Each even should have a query associated with it. + if(response.getHits().getHits() != null & response.getHits().getHits().length > 0) { + + final SearchHit hit = response.getHits().getHits()[0]; + return AccessController.doPrivileged((PrivilegedAction) () -> gson.fromJson(hit.getSourceAsString(), UbiQuery.class)); + + } else { + + LOGGER.warn("No query exists for query ID {} to calculate judgments.", queryId); + return null; + + } + + } + + private Collection getQueryIdsHavingUserQuery(final String userQuery) throws Exception { + + final String query = "{\"match\": {\"user_query\": \"" + userQuery + "\" }}"; + final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query); + + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(qb); + + final String[] indexes = {UBI_QUERIES_INDEX_NAME}; + + final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder); + final SearchResponse response = client.search(searchRequest).get(); + + final Collection queryIds = new ArrayList<>(); + + for(final SearchHit hit : response.getHits().getHits()) { + final String queryId = hit.getSourceAsMap().get("query_id").toString(); + queryIds.add(queryId); + } + + return queryIds; + + } + + public long getCountOfQueriesForUserQueryHavingResultInRankR(final String userQuery, final String objectId, final int rank) throws Exception { + + long countOfTimesShownAtRank = 0; + + // Get all query IDs matching this user query. + final Collection queryIds = getQueryIdsHavingUserQuery(userQuery); + + // For each query ID, get the events with action_name = "impression" having a match on objectId and rank (position). + for(final String queryId : queryIds) { + + final String query = "{\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"query_id\": \"" + queryId + "\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"term\": {\n" + + " \"action_name\": \"impression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"term\": {\n" + + " \"event_attributes.position.ordinal\": \"" + rank + "\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"term\": {\n" + + " \"event_attributes.object.object_id\": \"" + objectId + "\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }"; + + final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query); + + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(qb); + searchSourceBuilder.trackTotalHits(true); + searchSourceBuilder.size(0); + + final String[] indexes = {UBI_EVENTS_INDEX_NAME}; + + final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder); + final SearchResponse response = client.search(searchRequest).get(); + + // Won't be null as long as trackTotalHits is true. + if(response.getHits().getTotalHits() != null) { + countOfTimesShownAtRank += response.getHits().getTotalHits().value; + } + + } + + LOGGER.debug("Count of {} having {} at rank {} = {}", userQuery, objectId, rank, countOfTimesShownAtRank); + + if(countOfTimesShownAtRank > 0) { + LOGGER.debug("Count of {} having {} at rank {} = {}", userQuery, objectId, rank, countOfTimesShownAtRank); + } + + return countOfTimesShownAtRank; + + } + + /** + * Index the rank-aggregated clickthrough values. + * @param rankAggregatedClickThrough A map of position to clickthrough values. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public void indexRankAggregatedClickthrough(final Map rankAggregatedClickThrough) throws Exception { + + if(!rankAggregatedClickThrough.isEmpty()) { + + // TODO: Split this into multiple bulk insert requests. + + final BulkRequest request = new BulkRequest(); + + for (final int position : rankAggregatedClickThrough.keySet()) { + + final Map jsonMap = new HashMap<>(); + jsonMap.put("position", position); + jsonMap.put("ctr", rankAggregatedClickThrough.get(position)); + + final IndexRequest indexRequest = new IndexRequest(INDEX_RANK_AGGREGATED_CTR).id(UUID.randomUUID().toString()).source(jsonMap); + + request.add(indexRequest); + + } + + client.bulk(request).get(); + + } + + } + + /** + * Index the clickthrough rates. + * @param clickthroughRates A map of query IDs to a collection of {@link ClickthroughRate} objects. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public void indexClickthroughRates(final Map> clickthroughRates) throws Exception { + + if(!clickthroughRates.isEmpty()) { + + final BulkRequest request = new BulkRequest(); + + for(final String userQuery : clickthroughRates.keySet()) { + + for(final ClickthroughRate clickthroughRate : clickthroughRates.get(userQuery)) { + + final Map jsonMap = new HashMap<>(); + jsonMap.put("user_query", userQuery); + jsonMap.put("clicks", clickthroughRate.getClicks()); + jsonMap.put("events", clickthroughRate.getImpressions()); + jsonMap.put("ctr", clickthroughRate.getClickthroughRate()); + jsonMap.put("object_id", clickthroughRate.getObjectId()); + + final IndexRequest indexRequest = new IndexRequest(INDEX_QUERY_DOC_CTR) + .id(UUID.randomUUID().toString()) + .source(jsonMap); + + request.add(indexRequest); + + } + + } + + client.bulk(request, new ActionListener<>() { + + @Override + public void onResponse(BulkResponse bulkItemResponses) { + if(bulkItemResponses.hasFailures()) { + LOGGER.error("Clickthrough rates were not all successfully indexed: {}", bulkItemResponses.buildFailureMessage()); + } else { + LOGGER.debug("Clickthrough rates has been successfully indexed."); + } + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Indexing the clickthrough rates failed.", ex); + } + + }); + + } + + } + + /** + * Index the judgments. + * @param judgments A collection of {@link Judgment judgments}. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + * @return The ID of the indexed judgments. + */ + public String indexJudgments(final Collection judgments) throws Exception { + + final String judgmentsId = UUID.randomUUID().toString(); + final String timestamp = TimeUtils.getTimestamp(); + + final BulkRequest bulkRequest = new BulkRequest(); + + for(final Judgment judgment : judgments) { + + final Map j = judgment.getJudgmentAsMap(); + j.put("judgments_id", judgmentsId); + j.put("timestamp", timestamp); + + final IndexRequest indexRequest = new IndexRequest(JUDGMENTS_INDEX_NAME) + .id(UUID.randomUUID().toString()) + .source(j); + + bulkRequest.add(indexRequest); + + } + + // TODO: Don't use .get() + client.bulk(bulkRequest).get(); + + return judgmentsId; + + } + +} \ No newline at end of file diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/queryhash/IncrementalUserQueryHash.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/queryhash/IncrementalUserQueryHash.java new file mode 100644 index 0000000..b893f43 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/queryhash/IncrementalUserQueryHash.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.queryhash; + +import java.util.HashMap; +import java.util.Map; + +/** + * Facilitates the hashing of user queries. + */ +public class IncrementalUserQueryHash implements UserQueryHash { + + private final Map userQueries; + private int count = 1; + + /** + * Creates a new instance of this class. + */ + public IncrementalUserQueryHash() { + this.userQueries = new HashMap<>(); + } + + @Override + public int getHash(String userQuery) { + + final int hash; + + if(userQueries.containsKey(userQuery)) { + + return userQueries.get(userQuery); + + } else { + + userQueries.put(userQuery, count); + hash = count; + count++; + + + } + + return hash; + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/queryhash/UserQueryHash.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/queryhash/UserQueryHash.java new file mode 100644 index 0000000..714f85a --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/judgments/queryhash/UserQueryHash.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.judgments.queryhash; + +/** + * In interface for creating hashes of user queries. + */ +public interface UserQueryHash { + + /** + * Creates a unique integer given a user query. + * @param userQuery The user query. + * @return A unique integer representing the user query. + */ + int getHash(String userQuery); + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java new file mode 100644 index 0000000..446696f --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import java.util.List; + +/** + * Subclass of {@link SearchMetric} that calculates Discounted Cumulative Gain @ k. + */ +public class DcgSearchMetric extends SearchMetric { + + protected final List relevanceScores; + + /** + * Creates new DCG metrics. + * @param k The k value. + * @param relevanceScores A list of relevance scores. + */ + public DcgSearchMetric(final int k, final List relevanceScores) { + super(k); + this.relevanceScores = relevanceScores; + } + + @Override + public String getName() { + return "dcg_at_" + k; + } + + @Override + public double calculate() { + return calculateDcg(relevanceScores); + } + + protected double calculateDcg(final List relevanceScores) { + + // k should equal the size of relevanceScores. + + double dcg = 0.0; + + for (int i = 0; i < relevanceScores.size(); i++) { + + double d = log2(i + 2); + double n = Math.pow(2, relevanceScores.get(i)) - 1; + + if(d != 0) { + dcg += (n / d); + } + + } + return dcg; + + } + + private double log2(int N) { + return Math.log(N) / Math.log(2); + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java new file mode 100644 index 0000000..a392732 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +/** + * Subclass of {@link SearchMetric} that calculates Normalized Discounted Cumulative Gain @ k. + */ +public class NdcgSearchMetric extends DcgSearchMetric { + + /** + * Creates new NDCG metrics. + * @param k The k value. + * @param relevanceScores A list of relevancy scores. + */ + public NdcgSearchMetric(final int k, final List relevanceScores) { + super(k, relevanceScores); + } + + @Override + public String getName() { + return "ndcg_at_" + k; + } + + @Override + public double calculate() { + + double dcg = super.calculate(); + + if(dcg == 0) { + + // The ndcg is 0. No need to continue. + return 0; + + } else { + + final List idealRelevanceScores = new ArrayList<>(relevanceScores); + idealRelevanceScores.sort(Collections.reverseOrder()); + + double idcg = super.calculateDcg(idealRelevanceScores); + + if(idcg == 0) { + return 0; + } else { + return dcg / idcg; + } + + } + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java new file mode 100644 index 0000000..a2ac50b --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import java.util.List; + +/** + * Subclass of {@link SearchMetric} that calculates Precision @ k. + */ +public class PrecisionSearchMetric extends SearchMetric { + + private final double threshold; + private final List relevanceScores; + + /** + * Creates new precision metrics. + * @param k The k value. + * @param threshold The threshold for assigning binary relevancy scores to non-binary scores. + * Scores greater than or equal to this value will be assigned a relevancy score of 1 (relevant). + * Scores less than this value will be assigned a relevancy score of 0 (not relevant). + * @param relevanceScores A list of relevance scores. + */ + public PrecisionSearchMetric(final int k, final double threshold, final List relevanceScores) { + super(k); + this.threshold = threshold; + this.relevanceScores = relevanceScores; + } + + @Override + public String getName() { + return "precision_at_" + k; + } + + @Override + public double calculate() { + + double numberOfRelevantItems = 0; + + for(final double relevanceScore : relevanceScores) { + if(relevanceScore >= threshold) { + numberOfRelevantItems++; + } + } + + return numberOfRelevantItems / (double) k; + + } + + /** + * Gets the threshold value. + * @return The threshold value. + */ + public double threshold() { + return threshold; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/SearchMetric.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/SearchMetric.java new file mode 100644 index 0000000..acd580a --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/metrics/SearchMetric.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Base class for search metrics. + */ +public abstract class SearchMetric { + + protected static final Logger LOGGER = LogManager.getLogger(SearchMetric.class); + + protected int k; + + /** + * Gets the name of the metric, i.e. ndcg. + * @return The name of the metric. + */ + public abstract String getName(); + + /** + * Calculates the metric. + * @return The value of the metric. + */ + public abstract double calculate(); + + private Double value = Double.NaN; + + /** + * Creates the metric. + * @param k The k value. + */ + public SearchMetric(final int k) { + this.k = k; + } + + /** + * Gets the k value. + * @return The k value. + */ + public int getK() { + return k; + } + + /** + * Gets the value of the metric. If the metric has not yet been calculated, + * the metric will first be calculated by calling calculate. This + * function should be used in cases where repeated access to the metrics value is + * needed without recalculating the metrics value. + * @return The value of the metric. + */ + public double getValue() { + + if(Double.isNaN(value)) { + this.value = calculate(); + } + + return value; + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java new file mode 100644 index 0000000..7ca0ad6 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java @@ -0,0 +1,208 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.runners; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.eval.SearchQualityEvaluationPlugin; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +/** + * Base class for query set runners. Classes that extend this class + * should be specific to a search engine. See the {@link OpenSearchQuerySetRunner} for an example. + */ +public abstract class AbstractQuerySetRunner { + + private static final Logger LOGGER = LogManager.getLogger(AbstractQuerySetRunner.class); + + protected final Client client; + + public AbstractQuerySetRunner(final Client client) { + this.client = client; + } + + /** + * Runs the query set. + * @param querySetId The ID of the query set to run. + * @param judgmentsId The ID of the judgments set to use for search metric calculation. + * @param index The name of the index to run the query sets against. + * @param searchPipeline The name of the search pipeline to use, or null to not use a search pipeline. + * @param idField The field in the index that is used to uniquely identify a document. + * @param query The query that will be used to run the query set. + * @param k The k used for metrics calculation, i.e. DCG@k. + * @param threshold The cutoff for binary judgments. A judgment score greater than or equal + * to this value will be assigned a binary judgment value of 1. A judgment score + * less than this value will be assigned a binary judgment value of 0. + * @return The query set {@link QuerySetRunResult results} and calculated metrics. + */ + abstract QuerySetRunResult run(String querySetId, final String judgmentsId, final String index, final String searchPipeline, + final String idField, final String query, final int k, + final double threshold) throws Exception; + + /** + * Saves the query set results to a persistent store, which may be the search engine itself. + * @param result The {@link QuerySetRunResult results}. + */ + abstract void save(QuerySetRunResult result) throws Exception; + + /** + * Gets a query set from the index. + * @param querySetId The ID of the query set to get. + * @return The query set as a collection of maps of query to frequency + * @throws Exception Thrown if the query set cannot be retrieved. + */ + public final Collection> getQuerySet(final String querySetId) throws Exception { + + // Get the query set. + final SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchQuery("_id", querySetId)); + + // Will be at most one match. + sourceBuilder.from(0); + sourceBuilder.size(1); + + final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME).source(sourceBuilder); + + // TODO: Don't use .get() + final SearchResponse searchResponse = client.search(searchRequest).get(); + + if(searchResponse.getHits().getHits().length > 0) { + + // The queries from the query set that will be run. + return (Collection>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries"); + + } else { + + LOGGER.error("Unable to get query set with ID {}", querySetId); + + // The query set was not found. + throw new RuntimeException("The query set with ID " + querySetId + " was not found."); + + } + + } + + /** + * Get a judgment from the index. + * @param judgmentsId The ID of the judgments to find. + * @param query The user query. + * @param documentId The document ID. + * @return The value of the judgment, or NaN if the judgment cannot be found. + */ + public Double getJudgmentValue(final String judgmentsId, final String query, final String documentId) throws Exception { + + // Find a judgment that matches the judgments_id, query_id, and document_id fields in the index. + + final BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + boolQueryBuilder.must(QueryBuilders.termQuery("judgments_id", judgmentsId)); + boolQueryBuilder.must(QueryBuilders.termQuery("query", query)); + boolQueryBuilder.must(QueryBuilders.termQuery("document_id", documentId)); + + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQueryBuilder); + + // Will be a max of 1 result since we are getting the judgments by ID. + searchSourceBuilder.from(0); + searchSourceBuilder.size(1); + + // Only include the judgment field in the response. + final String[] includeFields = new String[] {"judgment"}; + final String[] excludeFields = new String[] {}; + searchSourceBuilder.fetchSource(includeFields, excludeFields); + + final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.JUDGMENTS_INDEX_NAME).source(searchSourceBuilder); + + Double judgment = Double.NaN; + + final SearchResponse searchResponse = client.search(searchRequest).get(); + + if (searchResponse.getHits().getHits().length > 0) { + + final Map j = searchResponse.getHits().getAt(0).getSourceAsMap(); + + // LOGGER.debug("Judgment contains a value: {}", j.get("judgment")); + + // TODO: Why does this not exist in some cases? + if(j.containsKey("judgment")) { + judgment = (Double) j.get("judgment"); + } + + } else { + + // No judgment for this query/doc pair exists. + judgment = Double.NaN; + + } + + return judgment; + + } + + /** + * Gets the judgments for a query / document pairs. + * @param judgmentsId The judgments collection for which the judgment to retrieve belongs. + * @param query The user query. + * @param orderedDocumentIds A list of document IDs returned for the user query. + * @param k The k used for metrics calculation, i.e. DCG@k. + * @return An ordered list of relevance scores for the query / document pairs. + * @throws Exception Thrown if a judgment cannot be retrieved. + */ + protected RelevanceScores getRelevanceScores(final String judgmentsId, final String query, final List orderedDocumentIds, final int k) throws Exception { + + // Ordered list of scores. + final List scores = new ArrayList<>(); + + // Count the number of documents without judgments. + int documentsWithoutJudgmentsCount = 0; + + // For each document (up to k), get the judgment for the document. + for (int i = 0; i < k && i < orderedDocumentIds.size(); i++) { + + final String documentId = orderedDocumentIds.get(i); + + // Find the judgment value for this combination of query and documentId from the index. + final Double judgmentValue = getJudgmentValue(judgmentsId, query, documentId); + + // If a judgment for this query/doc pair is not found, Double.NaN will be returned. + if(!Double.isNaN(judgmentValue)) { + LOGGER.info("Score found for document ID {} with judgments {} and query {} = {}", documentId, judgmentsId, query, judgmentValue); + scores.add(judgmentValue); + } else { + //LOGGER.info("No score found for document ID {} with judgments {} and query {}", documentId, judgmentsId, query); + documentsWithoutJudgmentsCount++; + } + + } + + double frogs = ((double) documentsWithoutJudgmentsCount) / orderedDocumentIds.size(); + + if(Double.isNaN(frogs)) { + frogs = 1.0; + } + + // Multiply by 100 to be a percentage. + frogs *= 100; + + LOGGER.info("frogs for query {} = {} ------- {} / {}", query, frogs, documentsWithoutJudgmentsCount, orderedDocumentIds.size()); + + return new RelevanceScores(scores, frogs); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java new file mode 100644 index 0000000..a1f0c4f --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java @@ -0,0 +1,290 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.runners; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.exists.indices.IndicesExistsRequest; +import org.opensearch.action.admin.indices.exists.indices.IndicesExistsResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.eval.SearchQualityEvaluationPlugin; +import org.opensearch.eval.metrics.DcgSearchMetric; +import org.opensearch.eval.metrics.NdcgSearchMetric; +import org.opensearch.eval.metrics.PrecisionSearchMetric; +import org.opensearch.eval.metrics.SearchMetric; +import org.opensearch.eval.utils.TimeUtils; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.opensearch.eval.SearchQualityEvaluationRestHandler.QUERY_PLACEHOLDER; + +/** + * A {@link AbstractQuerySetRunner} for Amazon OpenSearch. + */ +public class OpenSearchQuerySetRunner extends AbstractQuerySetRunner { + + private static final Logger LOGGER = LogManager.getLogger(OpenSearchQuerySetRunner.class); + + /** + * Creates a new query set runner + * + * @param client An OpenSearch {@link Client}. + */ + public OpenSearchQuerySetRunner(final Client client) { + super(client); + } + + @Override + public QuerySetRunResult run(final String querySetId, final String judgmentsId, final String index, + final String searchPipeline, final String idField, final String query, + final int k, final double threshold) throws Exception { + + final Collection> querySet = getQuerySet(querySetId); + LOGGER.info("Found {} queries in query set {}", querySet.size(), querySetId); + + try { + + // The results of each query. + final List queryResults = new ArrayList<>(); + + for (Map queryMap : querySet) { + + // Loop over each query in the map and run each one. + for (final String userQuery : queryMap.keySet()) { + + // Replace the query placeholder with the user query. + final String parsedQuery = query.replace(QUERY_PLACEHOLDER, userQuery); + + // Build the query from the one that was passed in. + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + + searchSourceBuilder.query(QueryBuilders.wrapperQuery(parsedQuery)); + searchSourceBuilder.from(0); + searchSourceBuilder.size(k); + + final String[] includeFields = new String[]{idField}; + final String[] excludeFields = new String[]{}; + searchSourceBuilder.fetchSource(includeFields, excludeFields); + + // LOGGER.info(searchSourceBuilder.toString()); + + final SearchRequest searchRequest = new SearchRequest(index); + searchRequest.source(searchSourceBuilder); + + if(searchPipeline != null) { + searchSourceBuilder.pipeline(searchPipeline); + searchRequest.pipeline(searchPipeline); + } + + // This is to keep OpenSearch from rejecting queries. + // TODO: Look at using the Workload Management in 2.18.0. + Thread.sleep(50); + + client.search(searchRequest, new ActionListener<>() { + + @Override + public void onResponse(final SearchResponse searchResponse) { + + final List orderedDocumentIds = new ArrayList<>(); + + for (final SearchHit hit : searchResponse.getHits().getHits()) { + + final String documentId; + + if("_id".equals(idField)) { + documentId = hit.getId(); + } else { + // TODO: Need to check this field actually exists. + documentId = hit.getSourceAsMap().get(idField).toString(); + } + + orderedDocumentIds.add(documentId); + + } + + try { + + final RelevanceScores relevanceScores = getRelevanceScores(judgmentsId, userQuery, orderedDocumentIds, k); + + // Calculate the metrics for this query. + final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores.getRelevanceScores()); + final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores.getRelevanceScores()); + final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, threshold, relevanceScores.getRelevanceScores()); + + final Collection searchMetrics = List.of(dcgSearchMetric, ndcgSearchmetric, precisionSearchMetric); + + queryResults.add(new QueryResult(userQuery, orderedDocumentIds, k, searchMetrics, relevanceScores.getFrogs())); + + } catch (Exception ex) { + LOGGER.error("Unable to get relevance scores for judgments {} and user query {}.", judgmentsId, userQuery, ex); + } + + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Unable to search using query: {}", searchSourceBuilder.toString(), ex); + } + }); + + } + + } + + // Calculate the search metrics for the entire query set given the individual query set metrics. + // Sum up the metrics for each query per metric type. + final int querySetSize = queryResults.size(); + final Map sumOfMetrics = new HashMap<>(); + for(final QueryResult queryResult : queryResults) { + for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) { + //LOGGER.info("Summing: {} - {}", searchMetric.getName(), searchMetric.getValue()); + sumOfMetrics.merge(searchMetric.getName(), searchMetric.getValue(), Double::sum); + } + } + + // Now divide by the number of queries. + final Map querySetMetrics = new HashMap<>(); + for(final String metric : sumOfMetrics.keySet()) { + //LOGGER.info("Dividing by the query set size: {} / {}", sumOfMetrics.get(metric), querySetSize); + querySetMetrics.put(metric, sumOfMetrics.get(metric) / querySetSize); + } + + final String querySetRunId = UUID.randomUUID().toString(); + final QuerySetRunResult querySetRunResult = new QuerySetRunResult(querySetRunId, querySetId, queryResults, querySetMetrics); + + LOGGER.info("Query set run complete: {}", querySetRunId); + + return querySetRunResult; + + } catch (Exception ex) { + throw new RuntimeException("Unable to run query set.", ex); + } + + } + + @Override + public void save(final QuerySetRunResult result) throws Exception { + + // Now, index the metrics as expected by the dashboards. + + // See https://github.com/o19s/opensearch-search-quality-evaluation/blob/main/opensearch-dashboard-prototyping/METRICS_SCHEMA.md + // See https://github.com/o19s/opensearch-search-quality-evaluation/blob/main/opensearch-dashboard-prototyping/sample_data.ndjson + + final IndicesExistsRequest indicesExistsRequest = new IndicesExistsRequest(SearchQualityEvaluationPlugin.DASHBOARD_METRICS_INDEX_NAME); + + client.admin().indices().exists(indicesExistsRequest, new ActionListener<>() { + + @Override + public void onResponse(IndicesExistsResponse indicesExistsResponse) { + + if(!indicesExistsResponse.isExists()) { + + // Create the index. + // TODO: Read this mapping from a resource file instead. + final String mapping = "{\n" + + " \"properties\": {\n" + + " \"datetime\": { \"type\": \"date\", \"format\": \"strict_date_time\" },\n" + + " \"search_config\": { \"type\": \"keyword\" },\n" + + " \"query_set_id\": { \"type\": \"keyword\" },\n" + + " \"query\": { \"type\": \"keyword\" },\n" + + " \"metric\": { \"type\": \"keyword\" },\n" + + " \"value\": { \"type\": \"double\" },\n" + + " \"application\": { \"type\": \"keyword\" },\n" + + " \"evaluation_id\": { \"type\": \"keyword\" },\n" + + " \"frogs_percent\": { \"type\": \"double\" }\n" + + " }\n" + + " }"; + + // Create the judgments index. + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(SearchQualityEvaluationPlugin.DASHBOARD_METRICS_INDEX_NAME).mapping(mapping); + + client.admin().indices().create(createIndexRequest, new ActionListener<>() { + + @Override + public void onResponse(CreateIndexResponse createIndexResponse) { + LOGGER.info("{} index created.", SearchQualityEvaluationPlugin.DASHBOARD_METRICS_INDEX_NAME); + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Unable to create the {} index.", SearchQualityEvaluationPlugin.DASHBOARD_METRICS_INDEX_NAME, ex); + } + + }); + + } + + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Unable to determine if {} index exists.", SearchQualityEvaluationPlugin.DASHBOARD_METRICS_INDEX_NAME, ex); + } + + }); + + final BulkRequest bulkRequest = new BulkRequest(); + final String timestamp = TimeUtils.getTimestamp(); + + for(final QueryResult queryResult : result.getQueryResults()) { + + for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) { + + // TODO: Make sure all of these items have values. + final Map metrics = new HashMap<>(); + metrics.put("datetime", timestamp); + metrics.put("search_config", "research_1"); + metrics.put("query_set_id", result.getQuerySetId()); + metrics.put("query", queryResult.getQuery()); + metrics.put("metric", searchMetric.getName()); + metrics.put("value", searchMetric.getValue()); + metrics.put("application", "sample_data"); + metrics.put("evaluation_id", result.getRunId()); + metrics.put("frogs_percent", queryResult.getFrogs()); + + // TODO: This is using the index name from the sample data. + bulkRequest.add(new IndexRequest("sqe_metrics_sample_data").source(metrics)); + + } + + } + + client.bulk(bulkRequest, new ActionListener<>() { + + @Override + public void onResponse(BulkResponse bulkItemResponses) { + LOGGER.info("Successfully indexed {} metrics.", bulkItemResponses.getItems().length); + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Unable to bulk index metrics.", ex); + } + + }); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/QueryResult.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/QueryResult.java new file mode 100644 index 0000000..cc2b118 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/QueryResult.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.runners; + +import org.opensearch.eval.metrics.SearchMetric; + +import java.util.Collection; +import java.util.List; + +/** + * Contains the search results for a single query. + */ +public class QueryResult { + + private final String query; + private final List orderedDocumentIds; + private final int k; + private final Collection searchMetrics; + private final double frogs; + + /** + * Creates the search results. + * @param query The query used to generate this result. + * @param orderedDocumentIds A list of ordered document IDs in the same order as they appeared + * in the query. + * @param k The k used for metrics calculation, i.e. DCG@k. + * @param searchMetrics A collection of {@link SearchMetric} for this query. + * @param frogs The percentage of documents not having a judgment. + */ + public QueryResult(final String query, final List orderedDocumentIds, final int k, final Collection searchMetrics, final double frogs) { + this.query = query; + this.orderedDocumentIds = orderedDocumentIds; + this.k = k; + this.searchMetrics = searchMetrics; + this.frogs = frogs; + } + + /** + * Gets the query used to generate this result. + * @return The query used to generate this result. + */ + public String getQuery() { + return query; + } + + /** + * Gets the list of ordered document IDs. + * @return A list of ordered documented IDs. + */ + public List getOrderedDocumentIds() { + return orderedDocumentIds; + } + + public int getK() { + return k; + } + + public Collection getSearchMetrics() { + return searchMetrics; + } + + public double getFrogs() { + return frogs; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java new file mode 100644 index 0000000..280ba9c --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.runners; + +import org.opensearch.eval.metrics.SearchMetric; +import org.opensearch.eval.utils.TimeUtils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * The results of a query set run. + */ +public class QuerySetRunResult { + + private final String runId; + private final String querySetId; + private final List queryResults; + private final Map metrics; + private final String timestamp; + + /** + * Creates a new query set run result. A random UUID is generated as the run ID. + * @param runId A unique identifier for this query set run. + * @param querySetId A unique identifier for the query set. + * @param queryResults A collection of {@link QueryResult} that contains the queries and search results. + * @param metrics A map of metric name to value. + */ + public QuerySetRunResult(final String runId, final String querySetId, final List queryResults, final Map metrics) { + this.runId = runId; + this.querySetId = querySetId; + this.queryResults = queryResults; + this.metrics = metrics; + this.timestamp = TimeUtils.getTimestamp(); + } + + /** + * Get the run's ID. + * @return The run's ID. + */ + public String getRunId() { + return runId; + } + + /** + * Gets the query set ID. + * @return The query set ID. + */ + public String getQuerySetId() { + return querySetId; + } + + /** + * Gets the search metrics. + * @return The search metrics. + */ + public Map getSearchMetrics() { + return metrics; + } + + /** + * Gets the results of the query set run. + * @return A collection of {@link QueryResult results}. + */ + public Collection getQueryResults() { + return queryResults; + } + + public String getTimestamp() { + return timestamp; + } + + public Collection> getQueryResultsAsMap() { + + final Collection> qs = new ArrayList<>(); + + for(final QueryResult queryResult : queryResults) { + + final Map q = new HashMap<>(); + + q.put("query", queryResult.getQuery()); + q.put("document_ids", queryResult.getOrderedDocumentIds()); + q.put("frogs", queryResult.getFrogs()); + + // Calculate and add each metric to the map. + for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) { + q.put(searchMetric.getName(), searchMetric.calculate()); + } + + qs.add(q); + + } + + return qs; + + } + + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/RelevanceScores.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/RelevanceScores.java new file mode 100644 index 0000000..d57de40 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/runners/RelevanceScores.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.runners; + +import java.util.List; + +public class RelevanceScores { + + private List relevanceScores; + private double frogs; + + public RelevanceScores(final List relevanceScores, final double frogs) { + this.relevanceScores = relevanceScores; + this.frogs = frogs; + } + + public List getRelevanceScores() { + return relevanceScores; + } + + + public double getFrogs() { + return frogs; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AbstractQuerySampler.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AbstractQuerySampler.java new file mode 100644 index 0000000..3c70f0a --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AbstractQuerySampler.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.samplers; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.eval.SearchQualityEvaluationPlugin; +import org.opensearch.eval.utils.TimeUtils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +/** + * An interface for sampling UBI queries. + */ +public abstract class AbstractQuerySampler { + + private static final Logger LOGGER = LogManager.getLogger(AbstractQuerySampler.class); + + /** + * Gets the name of the sampler. + * @return The name of the sampler. + */ + public abstract String getName(); + + /** + * Samples the queries and inserts the query set into an index. + * @return A query set ID. + */ + public abstract String sample() throws Exception; + + /** + * Index the query set. + */ + protected String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Map queries) throws Exception { + + LOGGER.info("Indexing {} queries for query set {}", queries.size(), name); + + final Collection> querySetQueries = new ArrayList<>(); + + // Convert the queries map to an object. + for(final String query : queries.keySet()) { + + // Map of the query itself to the frequency of the query. + final Map querySetQuery = new HashMap<>(); + querySetQuery.put(query, queries.get(query)); + + querySetQueries.add(querySetQuery); + + } + + final Map querySet = new HashMap<>(); + querySet.put("name", name); + querySet.put("description", description); + querySet.put("sampling", sampling); + querySet.put("queries", querySetQueries); + querySet.put("timestamp", TimeUtils.getTimestamp()); + + final String querySetId = UUID.randomUUID().toString(); + + // TODO: Create a mapping for the query set index. + final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME) + .id(querySetId) + .source(querySet) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, new ActionListener<>() { + + @Override + public void onResponse(IndexResponse indexResponse) { + LOGGER.info("Indexed query set {} having name {}", querySetId, name); + } + + @Override + public void onFailure(Exception ex) { + LOGGER.error("Unable to index query set {}", querySetId, ex); + } + }); + + return querySetId; + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AbstractSamplerParameters.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AbstractSamplerParameters.java new file mode 100644 index 0000000..c8d731a --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AbstractSamplerParameters.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.samplers; + +public class AbstractSamplerParameters { + + private final String name; + private final String description; + private final String sampling; + private final int querySetSize; + + public AbstractSamplerParameters(final String name, final String description, final String sampling, final int querySetSize) { + this.name = name; + this.description = description; + this.sampling = sampling; + this.querySetSize = querySetSize; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public String getSampling() { + return sampling; + } + + public int getQuerySetSize() { + return querySetSize; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AllQueriesQuerySampler.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AllQueriesQuerySampler.java new file mode 100644 index 0000000..263d70a --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AllQueriesQuerySampler.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.samplers; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.eval.SearchQualityEvaluationPlugin; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.util.HashMap; +import java.util.Map; + +/** + * An implementation of {@link AbstractQuerySampler} that uses all UBI queries without any sampling. + */ +public class AllQueriesQuerySampler extends AbstractQuerySampler { + + public static final String NAME = "none"; + + private final NodeClient client; + private final AllQueriesQuerySamplerParameters parameters; + + /** + * Creates a new sampler. + * @param client The OpenSearch {@link NodeClient client}. + */ + public AllQueriesQuerySampler(final NodeClient client, final AllQueriesQuerySamplerParameters parameters) { + this.client = client; + this.parameters = parameters; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public String sample() throws Exception { + + // Get queries from the UBI queries index. + // TODO: This needs to use scroll or something else. + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + searchSourceBuilder.from(0); + searchSourceBuilder.size(parameters.getQuerySetSize()); + + final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME).source(searchSourceBuilder); + + // TODO: Don't use .get() + final SearchResponse searchResponse = client.search(searchRequest).get(); + + final Map queries = new HashMap<>(); + + for(final SearchHit hit : searchResponse.getHits().getHits()) { + + final Map fields = hit.getSourceAsMap(); + queries.merge(fields.get("user_query").toString(), 1L, Long::sum); + + // Will be useful for paging once implemented. + if(queries.size() > parameters.getQuerySetSize()) { + break; + } + + } + + return indexQuerySet(client, parameters.getName(), parameters.getDescription(), parameters.getSampling(), queries); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AllQueriesQuerySamplerParameters.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AllQueriesQuerySamplerParameters.java new file mode 100644 index 0000000..3149668 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/AllQueriesQuerySamplerParameters.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.samplers; + +public class AllQueriesQuerySamplerParameters extends AbstractSamplerParameters { + + public AllQueriesQuerySamplerParameters(final String name, final String description, final String sampling, final int querySetSize) { + super(name, description, sampling, querySetSize); + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/ProbabilityProportionalToSizeAbstractQuerySampler.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/ProbabilityProportionalToSizeAbstractQuerySampler.java new file mode 100644 index 0000000..79f2c7c --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/ProbabilityProportionalToSizeAbstractQuerySampler.java @@ -0,0 +1,176 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.samplers; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.eval.SearchQualityEvaluationPlugin; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.Scroll; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * An implementation of {@link AbstractQuerySampler} that uses PPTSS sampling. + * See https://opensourceconnections.com/blog/2022/10/13/how-to-succeed-with-explicit-relevance-evaluation-using-probability-proportional-to-size-sampling/ + * for more information on PPTSS. + */ +public class ProbabilityProportionalToSizeAbstractQuerySampler extends AbstractQuerySampler { + + public static final String NAME = "pptss"; + + private static final Logger LOGGER = LogManager.getLogger(ProbabilityProportionalToSizeAbstractQuerySampler.class); + + private final NodeClient client; + private final ProbabilityProportionalToSizeParameters parameters; + + /** + * Creates a new PPTSS sampler. + * @param client The OpenSearch {@link NodeClient client}. + * @param parameters The {@link ProbabilityProportionalToSizeParameters parameters} for the sampling. + */ + public ProbabilityProportionalToSizeAbstractQuerySampler(final NodeClient client, final ProbabilityProportionalToSizeParameters parameters) { + this.client = client; + this.parameters = parameters; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public String sample() throws Exception { + + // TODO: Can this be changed to an aggregation? + // An aggregation is limited (?) to 10,000 which could miss some queries. + + // Get queries from the UBI queries index. + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + searchSourceBuilder.size(10000); + final Scroll scroll = new Scroll(TimeValue.timeValueMinutes(10L)); + + final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME); + searchRequest.scroll(scroll); + searchRequest.source(searchSourceBuilder); + + // TODO: Don't use .get() + SearchResponse searchResponse = client.search(searchRequest).get(); + + String scrollId = searchResponse.getScrollId(); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + + final Collection userQueries = new ArrayList<>(); + + while (searchHits != null && searchHits.length > 0) { + + for(final SearchHit hit : searchHits) { + final Map fields = hit.getSourceAsMap(); + userQueries.add(fields.get("user_query").toString()); + // LOGGER.info("user queries count: {} user query: {}", userQueries.size(), fields.get("user_query").toString()); + } + + final SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId); + scrollRequest.scroll(scroll); + + // TODO: Don't use .get() + searchResponse = client.searchScroll(scrollRequest).get(); + + scrollId = searchResponse.getScrollId(); + searchHits = searchResponse.getHits().getHits(); + + } + + // LOGGER.info("User queries found: {}", userQueries); + + final Map weights = new HashMap<>(); + + // Increment the weight for each user query. + for(final String userQuery : userQueries) { + weights.merge(userQuery, 1L, Long::sum); + } + + // The total number of queries will be used to normalize the weights. + final long countOfQueries = userQueries.size(); + + // Calculate the normalized weights by dividing by the total number of queries. + final Map normalizedWeights = new HashMap<>(); + for(final String userQuery : weights.keySet()) { + normalizedWeights.put(userQuery, weights.get(userQuery) / (double) countOfQueries); + //LOGGER.info("{}: {}/{} = {}", userQuery, weights.get(userQuery), countOfQueries, normalizedWeights.get(userQuery)); + } + + // Ensure all normalized weights sum to 1. + final double sumOfNormalizedWeights = normalizedWeights.values().stream().reduce(0.0, Double::sum); + if(!compare(1.0, sumOfNormalizedWeights)) { + throw new RuntimeException("Summed normalized weights do not equal 1.0: Actual value: " + sumOfNormalizedWeights); + } else { + LOGGER.info("Summed normalized weights sum to {}", sumOfNormalizedWeights); + } + + final Map querySet = new HashMap<>(); + final Set randomNumbers = new HashSet<>(); + + // Generate random numbers between 0 and 1 for the size of the query set. + // Do this until our query set has reached the requested maximum size. + // This may require generating more random numbers than what was requested + // because removing duplicate user queries will require randomly picking more queries. + int count = 1; + + // TODO: How to short-circuit this such that if the same query gets picked over and over, the loop will never end. + final int max = 5000; + while(querySet.size() < parameters.getQuerySetSize() && count < max) { + + // Make a random number not yet used. + double random; + do { + random = Math.random(); + } while (randomNumbers.contains(random)); + randomNumbers.add(random); + + // Find the weight closest to the random weight in the map of deltas. + double smallestDelta = Integer.MAX_VALUE; + String closestQuery = null; + for(final String query : normalizedWeights.keySet()) { + final double delta = Math.abs(normalizedWeights.get(query) - random); + if(delta < smallestDelta) { + smallestDelta = delta; + closestQuery = query; + } + } + + querySet.put(closestQuery, weights.get(closestQuery)); + count++; + + //LOGGER.info("Generated random value: {}; Smallest delta = {}; Closest query = {}", random, smallestDelta, closestQuery); + + } + + return indexQuerySet(client, parameters.getName(), parameters.getDescription(), parameters.getSampling(), querySet); + + } + + public static boolean compare(double a, double b) { + return Math.abs(a - b) < 0.00001; + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/ProbabilityProportionalToSizeParameters.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/ProbabilityProportionalToSizeParameters.java new file mode 100644 index 0000000..d5e4311 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/samplers/ProbabilityProportionalToSizeParameters.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.samplers; + +public class ProbabilityProportionalToSizeParameters extends AbstractSamplerParameters { + + public ProbabilityProportionalToSizeParameters(final String name, final String description, final String sampling, final int querySetSize) { + super(name, description, sampling, querySetSize); + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/utils/MathUtils.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/utils/MathUtils.java new file mode 100644 index 0000000..d83adcd --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/utils/MathUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.utils; + +public class MathUtils { + + private MathUtils() { + + } + + public static String round(final double value, final int decimalPlaces) { + double factor = Math.pow(10, decimalPlaces); + return String.valueOf(Math.round(value * factor) / factor); + } + + public static String round(final double value) { + return round(value, 3); + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/utils/TimeUtils.java b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/utils/TimeUtils.java new file mode 100644 index 0000000..1948b60 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/main/java/org/opensearch/eval/utils/TimeUtils.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.utils; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; +import java.util.TimeZone; + +/** + * This is a utility class. + */ +public class TimeUtils { + + /** + * Generate a timestamp in the yyyy-MM-ddTHH:mm:ss.SSSZ format. + * @return A timestamp in the yyyy-MM-ddTHH:mm:ss.SSSZ format. + */ + public static String getTimestamp() { + + final SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'", Locale.getDefault()); + formatter.setTimeZone(TimeZone.getTimeZone("UTC")); + + final Date date = new Date(); + return formatter.format(date); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java b/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java new file mode 100644 index 0000000..f3755f3 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class DcgSearchMetricTest extends OpenSearchTestCase { + + public void testCalculate() { + + final int k = 10; + final List relevanceScores = List.of(1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 0.0); + + final DcgSearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores); + final double dcg = dcgSearchMetric.calculate(); + + assertEquals(13.864412483585935, dcg, 0.0); + + } + + public void testCalculateAllZeros() { + + final int k = 10; + final List relevanceScores = List.of(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + + final DcgSearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores); + final double dcg = dcgSearchMetric.calculate(); + + assertEquals(0.0, dcg, 0.0); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java b/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java new file mode 100644 index 0000000..08795f8 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class NdcgSearchMetricTest extends OpenSearchTestCase { + + public void testCalculate() { + + final int k = 10; + final List relevanceScores = List.of(1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 0.0); + + final NdcgSearchMetric ndcgSearchMetric = new NdcgSearchMetric(k, relevanceScores); + final double ndcg = ndcgSearchMetric.calculate(); + + assertEquals(0.7151195094457645, ndcg, 0.0); + + } + + public void testCalculateAllZeros() { + + final int k = 10; + final List relevanceScores = List.of(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + + final NdcgSearchMetric ndcgSearchMetric = new NdcgSearchMetric(k, relevanceScores); + final double ndcg = ndcgSearchMetric.calculate(); + + assertEquals(0.0, ndcg, 0.0); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/PrecisionSearchMetricTest.java b/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/PrecisionSearchMetricTest.java new file mode 100644 index 0000000..b6c260f --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/src/test/java/org/opensearch/eval/metrics/PrecisionSearchMetricTest.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class PrecisionSearchMetricTest extends OpenSearchTestCase { + + public void testCalculate() { + + final int k = 10; + final double threshold = 1.0; + final List relevanceScores = List.of(1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 0.0); + + final PrecisionSearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, threshold, relevanceScores); + final double precision = precisionSearchMetric.calculate(); + + assertEquals(0.9, precision, 0.0); + + } + +} diff --git a/opensearch-search-quality-evaluation-framework/useful_queries.txt b/opensearch-search-quality-evaluation-framework/useful_queries.txt new file mode 100644 index 0000000..35c8335 --- /dev/null +++ b/opensearch-search-quality-evaluation-framework/useful_queries.txt @@ -0,0 +1,151 @@ +DELETE ubi_events +DELETE ubi_queries + +GET ubi_events/_mapping +GET ubi_events/_search + +GET ubi_queries/_mapping +GET ubi_queries/_search + +DELETE judgments +GET judgments/_search + + +PUT ubi_queries +{ + "mappings": { + "properties": { + "timestamp": { "type": "date", "format": "strict_date_time" }, + "query_id": { "type": "keyword", "ignore_above": 100 }, + "query": { "type": "text" }, + "query_response_id": { "type": "keyword", "ignore_above": 100 }, + "query_response_hit_ids": { "type": "keyword" }, + "user_query": { "type": "keyword", "ignore_above": 256 }, + "query_attributes": { "type": "flat_object" }, + "client_id": { "type": "keyword", "ignore_above": 100 }, + "application": { "type": "keyword", "ignore_above": 100 } + } + } +} + +PUT ubi_events +{ +"mappings": { + "properties": { + "application": { "type": "keyword", "ignore_above": 256 }, + "action_name": { "type": "keyword", "ignore_above": 100 }, + "client_id": { "type": "keyword", "ignore_above": 100 }, + "query_id": { "type": "keyword", "ignore_above": 100 }, + "message": { "type": "keyword", "ignore_above": 1024 }, + "message_type": { "type": "keyword", "ignore_above": 100 }, + "timestamp": { + "type": "date", + "format":"strict_date_time", + "ignore_malformed": true, + "doc_values": true + }, + "event_attributes": { + "dynamic": true, + "properties": { + "position": { + "properties": { + "ordinal": { "type": "integer" }, + "x": { "type": "integer" }, + "y": { "type": "integer" }, + "page_depth": { "type": "integer" }, + "scroll_depth": { "type": "integer" }, + "trail": { "type": "text", + "fields": { "keyword": { "type": "keyword", "ignore_above": 256 } + } + } + } + }, + "object": { + "properties": { + "internal_id": { "type": "keyword" }, + "object_id": { "type": "keyword", "ignore_above": 256 }, + "object_id_field": { "type": "keyword", "ignore_above": 100 }, + "name": { "type": "keyword", "ignore_above": 256 }, + "description": { "type": "text", + "fields": { "keyword": { "type": "keyword", "ignore_above": 256 } } + }, + "object_detail": { "type": "object" } + } + } + } + } + } + } +} + +GET ubi_events/_search +{ + "query": { + "range": { + "event_attributes.position.ordinal": { + "lte": 20 + } + } + } +} + +GET ubi_queries/_search +{ + "query": { + "term": { + "user_query": "batteries" + } + } +} + +GET ubi_events/_search +{ + "query": { + "bool": { + "must": [ + { + "term": { + "query_id": "cdc01f67-0b24-4c96-bb56-a89234f4fb0c" + } + }, + { + "term": { + "action_name": "click" + } + }, + { + "term": { + "event_attributes.position.ordinal": "0" + } + }, + { + "term": { + "event_attributes.object.object_id": "B0797J3DWK" + } + } + ] + } + } + } +} + +GET ubi_events/_search +{ + "size": 0, + "aggs": { + "By_Action": { + "terms": { + "field": "action_name", + "size": 20 + }, + "aggs": { + "By_Position": { + "terms": { + "field": "event_attributes.position.ordinal", + "size": 20 + } + } + } + } + } +} \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 7f2d692..b6e6b20 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,3 +1,3 @@ rootProject.name = 'search-evaluation-framework' include 'opensearch-search-quality-evaluation-plugin' -include 'opensearch-search-quality-implicit-judgments' \ No newline at end of file +include 'opensearch-search-quality-evaluation-framework' \ No newline at end of file