Commit

refactor
Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Jun 5, 2024
1 parent d590081 commit 58e5281
Showing 3 changed files with 132 additions and 53 deletions.
@@ -195,7 +195,6 @@ public static class FastFilterContext {
private AggregationType aggregationType;
private final SearchContext context;

private String fieldName;
private MappedFieldType fieldType;
private Ranges ranges;

@@ -228,7 +227,6 @@ public boolean isRewriteable(final Object parent, final int subAggLength) {

public void buildRanges(MappedFieldType fieldType) throws IOException {
assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
this.fieldName = fieldType.name();
this.fieldType = fieldType;
this.ranges = this.aggregationType.buildRanges(context, fieldType);
if (ranges != null) {
@@ -249,6 +247,9 @@ private Ranges buildRanges(LeafReaderContext leaf) throws IOException {
* Try to populate the bucket doc counts for aggregation
* <p>
* Usage: invoked at segment level — in getLeafCollector of aggregator
*
* @param bucketOrd bucket ordinal producer
* @param incrementDocCount consumes the doc_count result for a given bucket ordinal
*/
public boolean tryFastFilterAggregation(
final LeafReaderContext ctx,
@@ -262,7 +263,7 @@ public boolean tryFastFilterAggregation(

if (ctx.reader().hasDeletions()) return false;

PointValues values = ctx.reader().getPointValues(this.fieldName);
PointValues values = ctx.reader().getPointValues(this.fieldType.name());
if (values == null) return false;
// only proceed if every document corresponds to exactly one point
if (values.getDocCount() != values.size()) return false;
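
As an aside, the three checks in this hunk can be read as a standalone eligibility test. A minimal sketch, assuming only the Lucene calls visible above; the helper name and standalone framing are illustrative and not part of the commit:

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import java.io.IOException;

// Sketch: segment-level preconditions for attempting the fast-filter path.
static boolean segmentEligibleForFastFilter(LeafReaderContext ctx, String fieldName) throws IOException {
    if (ctx.reader().hasDeletions()) return false;              // deleted docs would skew point-based counts
    PointValues values = ctx.reader().getPointValues(fieldName);
    if (values == null) return false;                            // field has no point index in this segment
    return values.getDocCount() == values.size();                // exactly one point per document
}
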
@@ -458,13 +459,11 @@ public DebugInfo tryFastFilterAggregation(
*/
public static class RangeAggregationType implements AggregationType {

private final ValuesSource.Numeric source;
private final ValuesSourceConfig config;
private final Range[] ranges;
private FieldTypeEnum fieldTypeEnum;

public RangeAggregationType(ValuesSourceConfig config, Range[] ranges) {
this.source = (ValuesSource.Numeric) config.getValuesSource();
this.config = config;
this.ranges = ranges;
}
@@ -482,7 +481,7 @@ public boolean isRewriteable(Object parent, int subAggLength) {
return false;
}

if (source instanceof ValuesSource.Numeric.FieldData) {
if (config.getValuesSource() instanceof ValuesSource.Numeric.FieldData) {
// ranges are already sorted by from and then to
// we want ranges not overlapping with each other
double prevTo = ranges[0].getTo();
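
The remainder of this loop is collapsed in the diff; a hedged sketch of the non-overlap check the comments describe, using plain {from, to} pairs in place of the Range class (the loop body is assumed, not copied from the commit):

// Sketch: with ranges sorted by from (then to), they are non-overlapping only if
// each range starts at or after the point where the previous range ends.
static boolean rangesDoNotOverlap(double[][] ranges) {   // each entry is {from, to}
    double prevTo = ranges[0][1];
    for (int i = 1; i < ranges.length; i++) {
        if (ranges[i][0] < prevTo) {
            return false;                                 // starts before the previous range ends
        }
        prevTo = ranges[i][1];
    }
    return true;
}
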
@@ -499,7 +498,7 @@ public boolean isRewriteable(Object parent, int subAggLength) {
}

@Override
public Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) throws IOException {
public Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) {
int byteLen = this.fieldTypeEnum.getByteLen();
String pointType = this.fieldTypeEnum.getPointType();

@@ -604,26 +603,8 @@ static FieldTypeEnum fromTypeName(String typeName) {
}
}

public static BigInteger convertDoubleToBigInteger(double value) {
// we use big integer to represent unsigned long
BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE);

if (Double.isNaN(value)) {
return BigInteger.ZERO;
} else if (Double.isInfinite(value)) {
if (value > 0) {
return maxUnsignedLong;
} else {
return BigInteger.ZERO;
}
} else {
BigDecimal bigDecimal = BigDecimal.valueOf(value);
return bigDecimal.toBigInteger();
}
}

@Override
public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) throws IOException {
public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) {
throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level");
}

@@ -645,6 +626,24 @@ public DebugInfo tryFastFilterAggregation(
}
}

public static BigInteger convertDoubleToBigInteger(double value) {
// we use big integer to represent unsigned long
BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE);

if (Double.isNaN(value)) {
return BigInteger.ZERO;
} else if (Double.isInfinite(value)) {
if (value > 0) {
return maxUnsignedLong;
} else {
return BigInteger.ZERO;
}
} else {
BigDecimal bigDecimal = BigDecimal.valueOf(value);
return bigDecimal.toBigInteger();
}
}
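
For reference, a small self-contained illustration of why the unsigned_long case goes through BigInteger rather than long; it uses only the standard library and is not part of the commit:

import java.math.BigDecimal;
import java.math.BigInteger;

public class UnsignedLongDemo {
    public static void main(String[] args) {
        // 2^64 - 1, the largest unsigned long, does not fit in Java's signed long.
        BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE);
        System.out.println(maxUnsignedLong);   // 18446744073709551615
        System.out.println(Long.MAX_VALUE);    // 9223372036854775807
        // Finite doubles are truncated toward zero, matching convertDoubleToBigInteger above.
        System.out.println(BigDecimal.valueOf(123.456).toBigInteger());   // 123
    }
}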

public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) {
return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource;
}
@@ -776,13 +775,6 @@ public int firstRangeIndex(byte[] globalMin, byte[] globalMax) {
int i = 0;
while (compareByteValue(uppers[i], globalMin) <= 0) {
i++;
// special case
// lower and upper may be same for the last range
// if (i == size - 1) {
// if (compareByteValue(lowers[i], globalMin) >= 0) {
// return i;
// }
// }
if (i >= size) {
return -1;
}
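
A standalone sketch of the scan in this hunk, with int bounds standing in for the byte[]-encoded bounds; the code after the loop is collapsed in the diff, so the final return is an assumption:

// Sketch: index of the first range whose upper bound lies above the segment's
// minimum value, or -1 when every range ends at or below that minimum.
static int firstRangeIndexSketch(int[] uppers, int globalMin) {
    int i = 0;
    while (uppers[i] <= globalMin) {
        i++;
        if (i >= uppers.length) {
            return -1;
        }
    }
    return i;
}
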
@@ -957,27 +949,18 @@ private boolean withinLowerBound(byte[] value) {
}

private boolean withinUpperBound(byte[] value) {
// special case
// lower and upper may be same for the last range
// if (activeIndex == ranges.size - 1) {
// return Ranges.compareByteValue(value, activeRange[1]) <= 0;
// }
return Ranges.withinUpperBound(value, activeRange[1]);
}

private boolean withinRange(byte[] value) {
return withinLowerBound(value) && withinUpperBound(value);
}

private boolean cellCross(byte[] min, byte[] max) {
return Ranges.compareByteValue(activeRange[0], min) > 0 || withinUpperBound(max);
}
}

/**
* Contains debug info of BKD traversal to show in profile
*/
public static class DebugInfo {
private static class DebugInfo {
private int leaf = 0; // leaf node visited
private int inner = 0; // inner node visited

@@ -296,7 +296,6 @@ public ScoreMode scoreMode() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {

boolean optimized = fastFilterContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
@@ -50,22 +50,32 @@
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.KeywordFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregatorTestCase;
import org.opensearch.search.aggregations.CardinalityUpperBound;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.aggregations.support.AggregationInspectionHelper;

import java.io.IOException;
import java.math.BigInteger;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import static java.util.Collections.singleton;
import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS;
import static org.hamcrest.Matchers.equalTo;

public class RangeAggregatorTests extends AggregatorTestCase {
@@ -74,6 +84,10 @@ public class RangeAggregatorTests extends AggregatorTestCase {
private static final String DATE_FIELD_NAME = "date";

private static final String DOUBLE_FIELD_NAME = "double";
private static final String FLOAT_FIELD_NAME = "float";
private static final String HALF_FLOAT_FIELD_NAME = "half_float";
private static final String UNSIGNED_LONG_FIELD_NAME = "unsigned_long";
private static final String SCALED_FLOAT_FIELD_NAME = "scaled_float";

public void testNoMatchingField() throws IOException {
testCase(new MatchAllDocsQuery(), iw -> {
@@ -313,15 +327,38 @@ public void testSubAggCollectsFromManyBucketsIfManyRanges() throws IOException {
});
}

public void testDoubleType() throws IOException {
public void testOverlappingRanges() throws IOException {
RangeAggregationBuilder aggregationBuilder = new RangeAggregationBuilder("range").field(DOUBLE_FIELD_NAME)
.addRange(1, 2)
.addRange(2, 3);
.addRange(1, 1.5)
.addRange(0, 0.5);

testRewriteOptimizationCase(aggregationBuilder, DoublePoint.newRangeQuery(DOUBLE_FIELD_NAME, 0, 5), indexWriter -> {
indexWriter.addDocument(singleton(new DoubleField(DOUBLE_FIELD_NAME, 0.1, Field.Store.NO)));
indexWriter.addDocument(singleton(new DoubleField(DOUBLE_FIELD_NAME, 1.1, Field.Store.NO)));
indexWriter.addDocument(singleton(new DoubleField(DOUBLE_FIELD_NAME, 2.1, Field.Store.NO)));
}, range -> {
List<? extends InternalRange.Bucket> ranges = range.getBuckets();
assertEquals(3, ranges.size());
assertEquals("0.0-0.5", ranges.get(0).getKeyAsString());
assertEquals(1, ranges.get(0).getDocCount());
assertEquals("1.0-1.5", ranges.get(1).getKeyAsString());
assertEquals(1, ranges.get(1).getDocCount());
assertEquals("1.0-2.0", ranges.get(2).getKeyAsString());
assertEquals(1, ranges.get(2).getDocCount());
assertTrue(AggregationInspectionHelper.hasValue(range));
}, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD_NAME, NumberFieldMapper.NumberType.DOUBLE), false);
}

public void testDoubleType() throws IOException {
RangeAggregationBuilder aggregationBuilder = new RangeAggregationBuilder("range").field(DOUBLE_FIELD_NAME)
.addRange(1, 2)
.addRange(2, 3);

testRewriteOptimizationCase(aggregationBuilder, new MatchAllDocsQuery(), indexWriter -> {
indexWriter.addDocument(NumberFieldMapper.NumberType.DOUBLE.createFields(DOUBLE_FIELD_NAME, 0.1, true, true, false));
indexWriter.addDocument(NumberFieldMapper.NumberType.DOUBLE.createFields(DOUBLE_FIELD_NAME, 1.1, true, true, false));
indexWriter.addDocument(NumberFieldMapper.NumberType.DOUBLE.createFields(DOUBLE_FIELD_NAME, 2.1, true, true, false));
}, range -> {
List<? extends InternalRange.Bucket> ranges = range.getBuckets();
assertEquals(2, ranges.size());
@@ -330,7 +367,26 @@ public void testDoubleType() throws IOException {
assertEquals("2.0-3.0", ranges.get(1).getKeyAsString());
assertEquals(1, ranges.get(1).getDocCount());
assertTrue(AggregationInspectionHelper.hasValue(range));
}, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD_NAME, NumberFieldMapper.NumberType.DOUBLE));
}, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD_NAME, NumberFieldMapper.NumberType.DOUBLE), true);
}

public void testConvertDoubleToBigInteger() {
double value = Double.NaN;
BigInteger result = FastFilterRewriteHelper.convertDoubleToBigInteger(value);
assertEquals(BigInteger.ZERO, result);

value = Double.POSITIVE_INFINITY;
result = FastFilterRewriteHelper.convertDoubleToBigInteger(value);
BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE);
assertEquals(maxUnsignedLong, result);

value = Double.NEGATIVE_INFINITY;
result = FastFilterRewriteHelper.convertDoubleToBigInteger(value);
assertEquals(BigInteger.ZERO, result);

value = 123.456;
result = FastFilterRewriteHelper.convertDoubleToBigInteger(value);
assertEquals("123", result.toString());
}

private void testCase(
@@ -391,7 +447,8 @@ private void testRewriteOptimizationCase(
Query query,
CheckedConsumer<IndexWriter, IOException> buildIndex,
Consumer<InternalRange<? extends InternalRange.Bucket, ? extends InternalRange>> verify,
MappedFieldType fieldType
MappedFieldType fieldType,
boolean optimized
) throws IOException {
try (Directory directory = newDirectory()) {
try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()))) {
@@ -401,14 +458,54 @@
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

InternalRange<? extends InternalRange.Bucket, ? extends InternalRange> agg = searchAndReduce(
indexSearcher,
query,
aggregationBuilder,
fieldType
CountingAggregator aggregator = createCountingAggregator(query, aggregationBuilder, indexSearcher, fieldType);
aggregator.preCollection();
indexSearcher.search(query, aggregator);
aggregator.postCollection();

MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer(
Integer.MAX_VALUE,
new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
);
InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
aggregator.context().bigArrays(),
getMockScriptService(),
reduceBucketConsumer,
PipelineAggregator.PipelineTree.EMPTY
);
InternalRange topLevel = (InternalRange) aggregator.buildTopLevel();
InternalRange agg = (InternalRange) topLevel.reduce(Collections.singletonList(topLevel), context);
doAssertReducedMultiBucketConsumer(agg, reduceBucketConsumer);

verify.accept(agg);

if (optimized) {
assertEquals(0, aggregator.getCollectCount().get());
} else {
assertTrue(aggregator.getCollectCount().get() > 0);
}
}
}
}

protected CountingAggregator createCountingAggregator(
Query query,
AggregationBuilder builder,
IndexSearcher searcher,
MappedFieldType... fieldTypes
) throws IOException {
return new CountingAggregator(
new AtomicInteger(),
createAggregator(
query,
builder,
searcher,
new MultiBucketConsumerService.MultiBucketConsumer(
DEFAULT_MAX_BUCKETS,
new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
),
fieldTypes
)
);
}
}
