Skip to content

Commit

Permalink
Add query shape hash method
Browse files Browse the repository at this point in the history
Signed-off-by: David Zane <[email protected]>
  • Loading branch information
dzane17 committed Aug 6, 2024
1 parent b55d760 commit 425ce8c
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.hash.MurmurHash3;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.index.query.AbstractGeometryQueryBuilder;
import org.opensearch.index.query.CommonTermsQueryBuilder;
Expand Down Expand Up @@ -76,6 +78,19 @@ public class QueryShapeGenerator {
static final Map<Class<?>, List<Function<Object, String>>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap();
static final Map<Class<?>, List<Function<Object, String>>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap();

/**
* Method to get query shape hash code given a source
* @param source search request source
* @param showFields whether to include field data in query shape
* @return Hash code of query shape as long (64-bit)
*/
public static long getShapeHashCode(SearchSourceBuilder source, Boolean showFields) {
String shape = buildShape(source, showFields);

final BytesRef shapeBytes = new BytesRef(shape);
return MurmurHash3.hash128(shapeBytes.bytes, 0, shapeBytes.length, 0, new MurmurHash3.Hash128()).h1;
}

/**
* Method to build search query shape given a source
* @param source search request source
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.insights;

import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.RegexpQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

public class SearchSourceBuilderUtils {

public static SearchSourceBuilder createDefaultSearchSourceBuilder() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.size(0);
// build query
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
sourceBuilder.query(
new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
);
// build aggregation
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
.field("type")
.subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
);
sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
.field("key")
.subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
.subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
.subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
);
sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));
// build sort
sourceBuilder.sort("color", SortOrder.DESC);
sourceBuilder.sort("vendor", SortOrder.DESC);
sourceBuilder.sort("price", SortOrder.ASC);
sourceBuilder.sort("album", SortOrder.ASC);

return sourceBuilder;
}

public static SearchSourceBuilder createQuerySearchSourceBuilder() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.size(0);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
sourceBuilder.query(
new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
);
return sourceBuilder;
}

public static SearchSourceBuilder createAggregationSearchSourceBuilder() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();

sourceBuilder.aggregation(
new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
.field("type")
.subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
);
sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
.field("key")
.subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
.subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
.subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
);
sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));

return sourceBuilder;
}

public static SearchSourceBuilder createSortSearchSourceBuilder() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.sort("color", SortOrder.DESC);
sourceBuilder.sort("vendor", SortOrder.DESC);
sourceBuilder.sort("price", SortOrder.ASC);
sourceBuilder.sort("album", SortOrder.ASC);
return sourceBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,15 @@

package org.opensearch.plugin.insights.core.service.categorizor;

import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.RegexpQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.plugin.insights.SearchSourceBuilderUtils;
import org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator;
import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.test.OpenSearchTestCase;

public final class QueryShapeGeneratorTests extends OpenSearchTestCase {

public void testComplexSearch() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.size(0);
// build query
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
sourceBuilder.query(
new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
);
// build agg
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
.field("type")
.subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
);
sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
.field("key")
.subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
.subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
.subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
);
sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));
// build sort
sourceBuilder.sort("color", SortOrder.DESC);
sourceBuilder.sort("vendor", SortOrder.DESC);
sourceBuilder.sort("price", SortOrder.ASC);
sourceBuilder.sort("album", SortOrder.ASC);
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createDefaultSearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "bool []\n"
Expand Down Expand Up @@ -136,15 +90,7 @@ public void testComplexSearch() {
}

public void testQueryShape() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.size(0);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
sourceBuilder.query(
new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
);
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createQuerySearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "bool []\n"
Expand All @@ -170,26 +116,7 @@ public void testQueryShape() {
}

public void testAggregationShape() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
.field("type")
.subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
);
sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
.field("key")
.subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
.subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
.subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
);
sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createAggregationSearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "aggregation:\n"
Expand Down Expand Up @@ -237,11 +164,7 @@ public void testAggregationShape() {
}

public void testSortShape() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.sort("color", SortOrder.DESC);
sourceBuilder.sort("vendor", SortOrder.DESC);
sourceBuilder.sort("price", SortOrder.ASC);
sourceBuilder.sort("album", SortOrder.ASC);
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createSortSearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "sort:\n" + " asc [album]\n" + " asc [price]\n" + " desc [color]\n" + " desc [vendor]\n";
Expand All @@ -251,4 +174,32 @@ public void testSortShape() {
String expectedShowFieldsFalse = "sort:\n" + " asc\n" + " asc\n" + " desc\n" + " desc\n";
assertEquals(expectedShowFieldsFalse, shapeShowFieldsFalse);
}

public void testHashCode() {
// Create test source builders
SearchSourceBuilder defaultSourceBuilder = SearchSourceBuilderUtils.createDefaultSearchSourceBuilder();
SearchSourceBuilder querySourceBuilder = SearchSourceBuilderUtils.createQuerySearchSourceBuilder();

// showFields true
long defaultHashTrue = QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, true);
long queryHashTrue = QueryShapeGenerator.getShapeHashCode(querySourceBuilder, true);
assertEquals(defaultHashTrue, QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, true));
assertEquals(queryHashTrue, QueryShapeGenerator.getShapeHashCode(querySourceBuilder, true));
assertEquals(-3113426516628802209L, defaultHashTrue);
assertEquals(-3836794442240421775L, queryHashTrue);
assertNotEquals(defaultHashTrue, queryHashTrue);

// showFields false
long defaultHashFalse = QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, false);
long queryHashFalse = QueryShapeGenerator.getShapeHashCode(querySourceBuilder, false);
assertEquals(defaultHashFalse, QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, false));
assertEquals(queryHashFalse, QueryShapeGenerator.getShapeHashCode(querySourceBuilder, false));
assertEquals(-7879433503859751764L, defaultHashFalse);
assertEquals(-7008121725161549992L, queryHashFalse);
assertNotEquals(defaultHashFalse, queryHashFalse);

// Compare field data on vs off
assertNotEquals(defaultHashTrue, defaultHashFalse);
assertNotEquals(queryHashTrue, queryHashFalse);
}
}

0 comments on commit 425ce8c

Please sign in to comment.