From df5d9a195325b38ff2dafcf44de6a8273968239e Mon Sep 17 00:00:00 2001 From: jzonthemtn Date: Wed, 11 Dec 2024 16:37:14 -0500 Subject: [PATCH] Adding tests for dcg and ndcgs. --- .../build.gradle | 9 ++++++ .../scripts/create-query-set-no-sampling.sh | 2 +- .../scripts/run-query-set.sh | 4 +-- .../eval/metrics/DcgSearchMetric.java | 17 ++++++----- .../eval/metrics/NdcgSearchMetric.java | 7 +++-- .../eval/metrics/DcgSearchMetricTest.java | 29 +++++++++++++++++++ .../eval/metrics/NdcgSearchMetricTest.java | 29 +++++++++++++++++++ 7 files changed, 84 insertions(+), 13 deletions(-) create mode 100644 opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java create mode 100644 opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java diff --git a/opensearch-search-quality-evaluation-plugin/build.gradle b/opensearch-search-quality-evaluation-plugin/build.gradle index 5c6c101..dcfa2da 100644 --- a/opensearch-search-quality-evaluation-plugin/build.gradle +++ b/opensearch-search-quality-evaluation-plugin/build.gradle @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +apply plugin: 'java' +apply plugin: 'idea' apply plugin: 'opensearch.opensearchplugin' apply plugin: 'opensearch.yaml-rest-test' @@ -19,6 +21,13 @@ ext { noticeFile = rootProject.file('NOTICE.txt') } +test { + include "**/Test*.class" + include "**/*Test.class" + include "**/*Test.class" + include "**/*TestCase.class" +} + group = 'org.opensearch' version = "${evalVersion}" diff --git a/opensearch-search-quality-evaluation-plugin/scripts/create-query-set-no-sampling.sh b/opensearch-search-quality-evaluation-plugin/scripts/create-query-set-no-sampling.sh index 7a67cd3..ace0404 100755 --- a/opensearch-search-quality-evaluation-plugin/scripts/create-query-set-no-sampling.sh +++ b/opensearch-search-quality-evaluation-plugin/scripts/create-query-set-no-sampling.sh @@ -2,4 +2,4 @@ curl -s -X DELETE "http://localhost:9200/search_quality_eval_query_sets" -curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&query_set_size=50" +curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&query_set_size=10" diff --git a/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh b/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh index 6b943ab..820e9e1 100755 --- a/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh +++ b/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh @@ -1,7 +1,7 @@ #!/bin/bash -e -QUERY_SET_ID="ca3c8091-ad48-4978-a16f-58e2cc5698b3" -JUDGMENTS_ID="97021d5d-d8c6-4147-a2f0-bbdacfe89b8a" +QUERY_SET_ID="dcbf3db4-56ea-47cd-87ea-3d13d067ae7a" +JUDGMENTS_ID="78f0e4e4-1cbf-47b4-9737-5feef65dad4d" INDEX="ecommerce" ID_FIELD="asin" K="10" diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java index 4be76a9..446696f 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java @@ -42,20 +42,23 @@ protected double calculateDcg(final List relevanceScores) { // k should equal the size of relevanceScores. double dcg = 0.0; - for(int i = 1; i <= k && i <= relevanceScores.size(); i++) { - final double relevanceScore = relevanceScores.get(i - 1); - final double numerator = Math.pow(2, relevanceScore) - 1.0; - final double denominator = Math.log(i) / Math.log(i + 2); + for (int i = 0; i < relevanceScores.size(); i++) { - if (denominator != 0) { - dcg += (numerator / denominator); + double d = log2(i + 2); + double n = Math.pow(2, relevanceScores.get(i)) - 1; + + if(d != 0) { + dcg += (n / d); } } - return dcg; } + private double log2(int N) { + return Math.log(N) / Math.log(2); + } + } diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java index 2bc6f75..a392732 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java @@ -8,6 +8,7 @@ */ package org.opensearch.eval.metrics; +import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -43,10 +44,10 @@ public double calculate() { } else { - // Make the ideal relevance scores by sorting the relevance scores largest to smallest. - relevanceScores.sort(Comparator.reverseOrder()); + final List idealRelevanceScores = new ArrayList<>(relevanceScores); + idealRelevanceScores.sort(Collections.reverseOrder()); - double idcg = super.calculateDcg(relevanceScores); + double idcg = super.calculateDcg(idealRelevanceScores); if(idcg == 0) { return 0; diff --git a/opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java b/opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java new file mode 100644 index 0000000..bed5b2d --- /dev/null +++ b/opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/DcgSearchMetricTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class DcgSearchMetricTest extends OpenSearchTestCase { + + public void testCalculate() { + + final int k = 5; + final List relevanceScores = List.of(1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 0.0); + + final DcgSearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores); + final double dcg = dcgSearchMetric.calculate(); + + assertEquals(13.864412483585935, dcg, 0.0); + + } + +} diff --git a/opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java b/opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java new file mode 100644 index 0000000..437a6f1 --- /dev/null +++ b/opensearch-search-quality-evaluation-plugin/src/test/java/org/opensearch/eval/metrics/NdcgSearchMetricTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.eval.metrics; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class NdcgSearchMetricTest extends OpenSearchTestCase { + + public void testCalculate() { + + final int k = 5; + final List relevanceScores = List.of(1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 0.0); + + final NdcgSearchMetric ndcgSearchMetric = new NdcgSearchMetric(k, relevanceScores); + final double dcg = ndcgSearchMetric.calculate(); + + assertEquals(0.7151195094457645, dcg, 0.0); + + } + +}