
first draft of implementation
Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Jan 2, 2024
1 parent d3eead8 commit debfa27
Showing 6 changed files with 87 additions and 41 deletions.
6 changes: 5 additions & 1 deletion .idea/runConfigurations/Debug_OpenSearch.xml

Some generated files are not rendered by default.

FastFilterRewriteHelper.java
@@ -36,9 +36,14 @@
import java.util.function.Supplier;

/**
- * Help rewrite and optimize aggregations using range filter queries
- * Currently supported types of aggregations are: DateHistogramAggregator, AutoDateHistogramAggregator, CompositeAggregator
- *
+ * Helps rewrite aggregations into filters.
+ * Instead of the aggregation collecting documents one by one, a filter can count all matching documents in one pass.
+ * <p>
+ * Currently supported rewrite:
+ * <ul>
+ * <li> date histogram -> date range filter.
+ * Applied to: DateHistogramAggregator, AutoDateHistogramAggregator, CompositeAggregator </li>
+ * </ul>
* @opensearch.internal
*/
public class FastFilterRewriteHelper {
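
To make the "one pass" idea from the javadoc concrete, here is a minimal sketch of counting per-filter matches over one segment. It mirrors the counting loop in tryFastFilterAggregation further down; the class and method names here are invented for illustration, and Lucene's Weight#count can sometimes return the same total without iterating at all.

```java
import java.io.IOException;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;

final class FilterCountSketch {
    // One pass per filter over the segment: each matching doc bumps a bucket count,
    // instead of the aggregator visiting every document and bucketing it by value.
    static int[] countPerFilter(Weight[] filters, LeafReaderContext ctx) throws IOException {
        final int[] counts = new int[filters.length];
        for (int i = 0; i < filters.length; i++) {
            final Scorer scorer = filters[i].scorer(ctx); // null when nothing in this segment matches
            if (scorer == null) {
                continue;
            }
            final DocIdSetIterator it = scorer.iterator();
            while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
                counts[i]++;
            }
        }
        return counts;
    }
}
```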
@@ -88,6 +93,10 @@ private static long[] getIndexBoundsFromLeaves(final SearchContext context, fina
return new long[] { min, max };
}

+ /**
+ * This method also acts as a pre-condition check for the optimization;
+ * it returns null if the optimization cannot be applied.
+ */
public static long[] getAggregationBounds(final SearchContext context, final String fieldName) throws IOException {
final Query cq = unwrapIntoConcreteQuery(context.query());
final long[] indexBounds = getIndexBoundsFromLeaves(context, fieldName);
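
The hunk cuts off before the checks themselves, but the old pre-condition list later in this class ("The query with aggregation has to be PointRangeQuery on the same date field") indicates what is verified. A rough, hypothetical sketch of such a check, with invented names, returning null to disable the rewrite:

```java
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;

final class BoundsPreconditionSketch {
    // Illustrative only: accept match-all and same-field point-range queries,
    // otherwise return null so the caller falls back to normal collection.
    static long[] boundsOrNull(Query unwrapped, String fieldName, long[] indexBounds) {
        if (unwrapped == null || unwrapped instanceof MatchAllDocsQuery) {
            return indexBounds;
        }
        if (unwrapped instanceof PointRangeQuery && ((PointRangeQuery) unwrapped).getField().equals(fieldName)) {
            return indexBounds; // the real helper would also intersect with the query's own range
        }
        return null; // pre-condition failed: the optimization cannot be applied
    }
}
```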
@@ -109,30 +118,18 @@ public static long[] getAggregationBounds(final SearchContext context, final Str
}

/**
- * Creates the range query filters for aggregations using the interval, min/max
+ * Creates the date range filters for aggregations using the interval, min/max
* bounds and the rounding values
*/
private static Weight[] createFilterForAggregations(
final SearchContext context,
- final Rounding rounding,
+ final long interval,
final Rounding.Prepared preparedRounding,
final String field,
final DateFieldMapper.DateFieldType fieldType,
long low,
- final long high,
- final long afterKey
+ final long high
) throws IOException {
- final OptionalLong intervalOpt = Rounding.getInterval(rounding);
- if (intervalOpt.isEmpty()) {
- return null;
- }
-
- final long interval = intervalOpt.getAsLong();
- // afterKey is the last bucket key in previous response, while the bucket key
- // is the start of the bucket values, so add the interval
- if (afterKey != -1) {
- low = afterKey + interval;
- }
// Calculate the number of buckets using range and interval
long roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
long prevRounded = roundedLow;
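
For intuition about what createFilterForAggregations produces, here is a simplified sketch that builds one point-range query per bucket. It assumes fixed-width buckets for brevity; the method above instead derives each boundary with preparedRounding, and the queries are turned into Weight objects before use.

```java
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.search.Query;

final class BucketFilterSketch {
    // One range filter per bucket: [lower, lower + interval - 1], both ends inclusive.
    static Query[] bucketFilters(String field, long roundedLow, long interval, int numBuckets) {
        final Query[] filters = new Query[numBuckets];
        long lower = roundedLow;
        for (int i = 0; i < numBuckets; i++) {
            filters[i] = LongPoint.newRangeQuery(field, lower, lower + interval - 1);
            lower += interval;
        }
        return filters;
    }
}
```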
@@ -179,10 +176,11 @@ protected String toString(int dimension, byte[] value) {

/**
* @param computeBounds get the lower and upper bound of the field in a shard search
- * @param roundingFunction produce Rounding that will provide the interval
- * @param preparedRoundingSupplier produce PreparedRounding that will do the rounding
+ * @param roundingFunction produces the Rounding that contains the interval of the date range;
+ * for AutoDateHistogram, the Rounding is computed dynamically using the bounds
+ * @param preparedRoundingSupplier produces the PreparedRounding that rounds values at call time
*/
- public static void buildFastFilterContext(
+ public static void buildFastFilter(
SearchContext context,
Function<long[], Rounding> roundingFunction,
Supplier<Rounding.Prepared> preparedRoundingSupplier,
@@ -191,19 +189,29 @@ public static void buildFastFilterContext(
) throws IOException {
assert fastFilterContext.fieldType instanceof DateFieldMapper.DateFieldType;
DateFieldMapper.DateFieldType fieldType = (DateFieldMapper.DateFieldType) fastFilterContext.fieldType;
- final String fieldName = fieldType.name();
final long[] bounds = computeBounds.apply(fastFilterContext);
if (bounds != null) {
final Rounding rounding = roundingFunction.apply(bounds);
+ final OptionalLong intervalOpt = Rounding.getInterval(rounding);
+ if (intervalOpt.isEmpty()) {
+ return;
+ }
+ final long interval = intervalOpt.getAsLong();
+
+ // afterKey is the last bucket key of the previous response; a bucket key
+ // marks the start of its bucket's values, so advance by one interval
+ if (fastFilterContext.afterKey != -1) {
+ bounds[0] = fastFilterContext.afterKey + interval;
+ }
+
final Weight[] filters = FastFilterRewriteHelper.createFilterForAggregations(
context,
- rounding,
+ interval,
preparedRoundingSupplier.get(),
- fieldName,
+ fieldType.name(),
fieldType,
bounds[0],
- bounds[1],
- fastFilterContext.afterKey
+ bounds[1]
);
fastFilterContext.setFilters(filters);
}
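
A worked example of the afterKey adjustment above, with invented values: composite aggregation pages by the last returned bucket key, and since a bucket key marks the start of its bucket, the next page begins one interval later.

```java
public class AfterKeyExample {
    public static void main(String[] args) {
        long interval = 86_400_000L;        // one day in millis
        long afterKey = 1_704_153_600_000L; // 2024-01-02T00:00:00Z, last key of the previous page
        long low = afterKey + interval;     // 1_704_240_000_000L -> 2024-01-03T00:00:00Z
        System.out.println(low);
    }
}
```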
@@ -221,15 +229,18 @@ public static class FastFilterContext {
private int size = Integer.MAX_VALUE; // only used by composite aggregation for pagination
private Weight[] filters = null;

+ private Type type = Type.UNKEYED;
+
/**
* @param fieldType null if the field doesn't exist
*/
public FastFilterContext(MappedFieldType fieldType) {
this.fieldType = fieldType;
}

- public MappedFieldType getFieldType() {
- return fieldType;
+ public DateFieldMapper.DateFieldType getFieldType() {
+ assert fieldType instanceof DateFieldMapper.DateFieldType;
+ return (DateFieldMapper.DateFieldType) fieldType;
}

public void setSize(int size) {
@@ -251,17 +262,26 @@ public void setMissingAndHasScript(boolean missing, boolean hasScript) {

/**
* The pre-conditions to initiate fast filter optimization on aggregations are:
- * 1. The query with aggregation has to be PointRangeQuery on the same date field
- * 2. No parent/sub aggregations
- * 3. No missing value/bucket
- * 4. No script
+ * <ul>
+ * <li>No parent/sub aggregations</li>
+ * <li>No missing value/bucket or script</li>
+ * <li>Field type is date</li>
+ * </ul>
*/
public boolean isRewriteable(Object parent, int subAggLength) {
if (parent == null && subAggLength == 0 && !missing && !hasScript) {
return fieldType != null && fieldType instanceof DateFieldMapper.DateFieldType;
}
return false;
}

+ public void setType(Type type) {
+ this.type = type;
+ }
+
+ public enum Type {
+ KEYED, UNKEYED
+ }
}

public static long getBucketOrd(long bucketOrd) {
@@ -274,17 +294,20 @@ public static long getBucketOrd(long bucketOrd) {

/**
* This should be executed for each segment
+ *
+ * @param incrementDocCount takes in the bucket key value and the bucket count
*/
public static boolean tryFastFilterAggregation(
final LeafReaderContext ctx,
FastFilterContext fastFilterContext,
final BiConsumer<Long, Integer> incrementDocCount
+ // TODO b can I have a function that calculates the bucket ord, so
) throws IOException {
if (fastFilterContext == null) return false;
if (fastFilterContext.filters == null) return false;

final Weight[] filters = fastFilterContext.filters;
- final DateFieldMapper.DateFieldType fieldType = (DateFieldMapper.DateFieldType) fastFilterContext.fieldType;
+ // TODO b refactor the type conversion to the context
final int[] counts = new int[filters.length];
int i;
for (i = 0; i < filters.length; i++) {
@@ -299,10 +322,16 @@ public static boolean tryFastFilterAggregation(
int s = 0;
for (i = 0; i < filters.length; i++) {
if (counts[i] > 0) {
- incrementDocCount.accept(
- fieldType.convertNanosToMillis(
+ long key = i;
+ if (fastFilterContext.type == FastFilterContext.Type.UNKEYED) {
+ final DateFieldMapper.DateFieldType fieldType = (DateFieldMapper.DateFieldType) fastFilterContext.fieldType;
+ key = fieldType.convertNanosToMillis(
NumericUtils.sortableBytesToLong(((PointRangeQuery) filters[i].getQuery()).getLowerPoint(), 0)
- ),
+ );
+ }
+ incrementDocCount.accept(
+ // TODO b this is what should be the bucket key showing out
+ key,
counts[i]
);
s++;
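
The UNKEYED branch above recovers the bucket key from the filter itself: the lower point of a PointRangeQuery is the encoded start of the bucket's range. A small self-contained sketch (the query here is built ad hoc; in the helper the filters come from createFilterForAggregations, and convertNanosToMillis additionally applies for date_nanos fields):

```java
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.util.NumericUtils;

final class BucketKeySketch {
    static long lowerBoundKey(PointRangeQuery query) {
        // Decode the sortable bytes of the lower point back into the long it encodes.
        return NumericUtils.sortableBytesToLong(query.getLowerPoint(), 0);
    }

    public static void main(String[] args) {
        PointRangeQuery q = (PointRangeQuery) LongPoint.newRangeQuery("timestamp", 1_704_067_200_000L, 1_704_153_599_999L);
        System.out.println(lowerBoundKey(q)); // 1704067200000, the bucket's start
    }
}
```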
CompositeAggregator.java
@@ -184,7 +184,7 @@ final class CompositeAggregator extends BucketsAggregator {
}
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
fastFilterContext.setSize(size);
- FastFilterRewriteHelper.buildFastFilterContext(
+ FastFilterRewriteHelper.buildFastFilter(
context,
x -> dateHistogramSource.getRounding(),
() -> preparedRounding,
FiltersAggregator.java
@@ -33,6 +33,7 @@
package org.opensearch.search.aggregations.bucket.filter;

import org.apache.lucene.index.LeafReaderContext;
+ import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lucene.Lucene;
@@ -145,6 +146,8 @@ public boolean equals(Object obj) {
private final String otherBucketKey;
private final int totalNumKeys;

+ private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
+
public FiltersAggregator(
String name,
AggregatorFactories factories,
@@ -168,12 +171,21 @@ public FiltersAggregator(
} else {
this.totalNumKeys = keys.length;
}

+ fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(null);
+ fastFilterContext.setType(FastFilterRewriteHelper.FastFilterContext.Type.KEYED);
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
// no need to provide deleted docs to the filter
Weight[] filters = this.filters.get();
+ fastFilterContext.setFilters(filters);
+ boolean optimized = FastFilterRewriteHelper.tryFastFilterAggregation(ctx, fastFilterContext, (key, count) -> {
+ incrementBucketDocCount(bucketOrd(0, key.intValue()), count); // TODO b this key should be the index of filter
+ });
+ if (optimized) throw new CollectionTerminatedException();
+
final Bits[] bits = new Bits[filters.length];
for (int i = 0; i < filters.length; ++i) {
bits[i] = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filters[i].scorerSupplier(ctx));
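
The optimized branch above relies on a standard Lucene contract: throwing CollectionTerminatedException tells the search loop to stop collecting the current segment, and IndexSearcher catches it per leaf. A hypothetical collector sketch of the same pattern:

```java
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;

final class PrecomputedCountsCollector extends SimpleCollector {
    private final boolean segmentAlreadyAggregated; // would be set by a fast-path check

    PrecomputedCountsCollector(boolean segmentAlreadyAggregated) {
        this.segmentAlreadyAggregated = segmentAlreadyAggregated;
    }

    @Override
    protected void doSetNextReader(LeafReaderContext context) {
        if (segmentAlreadyAggregated) {
            throw new CollectionTerminatedException(); // skip doc-by-doc collection for this segment
        }
    }

    @Override
    public void collect(int doc) {
        // per-document path, reached only when the fast path did not apply
    }

    @Override
    public ScoreMode scoreMode() {
        return ScoreMode.COMPLETE_NO_SCORES;
    }
}
```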
Expand All @@ -183,7 +195,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
public void collect(int doc, long bucket) throws IOException {
boolean matched = false;
for (int i = 0; i < bits.length; i++) {
- if (bits[i].get(doc)) {
+ if (bits[i].get(doc)) { // TODO b this shows bit can tell if doc matches
collectBucket(sub, doc, bucketOrd(bucket, i));
matched = true;
}
Expand Down Expand Up @@ -227,6 +239,7 @@ public InternalAggregation buildEmptyAggregation() {
return new InternalFilters(name, buckets, keyed, metadata());
}

+ // TODO b the way to produce the bucketOrd
final long bucketOrd(long owningBucketOrdinal, int filterOrd) {
return owningBucketOrdinal * totalNumKeys + filterOrd;
}
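
A worked check of the bucketOrd scheme above, with invented numbers: each owning bucket gets a contiguous block of totalNumKeys ordinals, so the fast path's per-filter counts can be written straight into the right slot.

```java
public class BucketOrdExample {
    static long bucketOrd(long owningBucketOrdinal, int filterOrd, int totalNumKeys) {
        return owningBucketOrdinal * totalNumKeys + filterOrd;
    }

    public static void main(String[] args) {
        int totalNumKeys = 3; // e.g. two filters plus the "other" bucket
        System.out.println(bucketOrd(0, 2, totalNumKeys)); // 2
        System.out.println(bucketOrd(1, 0, totalNumKeys)); // 3, start of the next owning bucket's block
    }
}
```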
AutoDateHistogramAggregator.java
@@ -164,7 +164,7 @@ private AutoDateHistogramAggregator(
valuesSourceConfig.script() != null
);
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
- FastFilterRewriteHelper.buildFastFilterContext(
+ FastFilterRewriteHelper.buildFastFilter(
context,
b -> getMinimumRounding(b[0], b[1]),
// Passing prepared rounding as supplier to ensure the correct prepared
DateHistogramAggregator.java
@@ -123,7 +123,7 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
valuesSourceConfig.script() != null
);
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
- FastFilterRewriteHelper.buildFastFilterContext(
+ FastFilterRewriteHelper.buildFastFilter(
context,
x -> rounding,
() -> preparedRounding,
