diff --git a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java
new file mode 100644
index 0000000000000..0b45ea1fe59ac
--- /dev/null
+++ b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java
@@ -0,0 +1,193 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.benchmark.search.aggregations;
+
+import org.openjdk.jmh.annotations.*;
+
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.IOUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.index.mapper.NumberFieldMapper;
+import org.opensearch.index.mapper.NumericPointEncoder;
+import org.opensearch.search.optimization.filterrewrite.Ranges;
+
+import java.util.*;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+
+import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse;
+
+/**
+ * JMH benchmark for {@code TreeTraversal.multiRangesTraverse} on the point tree of a
+ * single-segment index. Each trial indexes {@code treeSize} random integers drawn from
+ * {@code [0, valMax)} and traverses the tree with {@code buckets} contiguous ranges of
+ * equal width.
+ */
+@Warmup(iterations = 10)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@State(Scope.Thread)
+@Fork(value = 1)
+public class BKDTreeMultiRangesTraverseBenchmark {
+    /**
+     * Benchmark fixture: the on-disk index, its point tree and the ranges to traverse.
+     */
+    @State(Scope.Benchmark)
+    public static class treeState {
+        @Param({ "10000", "10000000" })
+        int treeSize;
+
+        @Param({ "10000", "10000000" })
+        int valMax;
+
+        @Param({ "10", "100" })
+        int buckets;
+
+        @Param({ "12345" })
+        int seed;
+
+        private Random random;
+
+        Path tmpDir;
+        Directory directory;
+        IndexWriter writer;
+        IndexReader reader;
+
+        // multiRangesTraverse params
+        PointValues.PointTree pointTree;
+        Ranges ranges;
+        BiConsumer<Integer, List<Integer>> collectRangeIDs;
+        int maxNumNonZeroRanges = Integer.MAX_VALUE;
+
+        /**
+         * Builds the index, merges it down to one segment and prepares the bucket ranges.
+         *
+         * @throws IOException if the temporary index cannot be created, written or read
+         */
+        @Setup
+        public void setup() throws IOException {
+            random = new Random(seed);
+            tmpDir = Files.createTempDirectory("tree-test");
+            directory = FSDirectory.open(tmpDir);
+            writer = new IndexWriter(directory, new IndexWriterConfig());
+
+            for (int i = 0; i < treeSize; i++) {
+                writer.addDocument(List.of(new IntField("val", random.nextInt(valMax), Field.Store.NO)));
+            }
+
+            // The writer may have flushed several segments; merge them so that one
+            // point tree holds every indexed value.
+            writer.forceMerge(1);
+            reader = DirectoryReader.open(writer);
+
+            List<LeafReaderContext> leaves = reader.leaves();
+            if (leaves.size() != 1) {
+                throw new IllegalStateException("expected a single segment but found " + leaves.size());
+            }
+            pointTree = leaves.get(0).reader().getPointValues("val").getPointTree();
+
+            MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("val", NumberFieldMapper.NumberType.INTEGER);
+            NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType;
+
+            // Bucket i covers [i * bucketWidth, (i + 1) * bucketWidth): lower bounds are
+            // inclusive and upper bounds exclusive, so the two must differ for a range to match.
+            int bucketWidth = valMax / buckets;
+            byte[][] lowers = new byte[buckets][];
+            byte[][] uppers = new byte[buckets][];
+            for (int i = 0; i < buckets; i++) {
+                lowers[i] = numericPointEncoder.encodePoint(i * bucketWidth);
+                uppers[i] = numericPointEncoder.encodePoint((i + 1) * bucketWidth);
+            }
+
+            ranges = new Ranges(lowers, uppers);
+        }
+
+        /**
+         * Releases the index resources and removes the temporary directory.
+         *
+         * @throws IOException if a resource cannot be closed or a file cannot be deleted
+         */
+        @TearDown
+        public void tearDown() throws IOException {
+            // Close before deleting so no open handles remain on the index files.
+            IOUtils.close(reader, writer, directory);
+            for (String indexFile : FSDirectory.listAll(tmpDir)) {
+                Files.deleteIfExists(tmpDir.resolve(indexFile));
+            }
+            Files.deleteIfExists(tmpDir);
+        }
+    }
+
+    /**
+     * Traverses the prepared point tree and gathers the doc IDs reported to the collector.
+     *
+     * @param state the fixture holding the point tree and ranges
+     * @return the collected doc IDs, keyed by the index passed to the collector
+     * @throws Exception if the traversal fails
+     */
+    @Benchmark
+    public Map<Integer, List<Integer>> multiRangeTraverseTree(treeState state) throws Exception {
+        Map<Integer, List<Integer>> mockIDCollect = new HashMap<>();
+
+        // The first list seen for an index is stored as-is; later lists are appended to it.
+        BiConsumer<Integer, List<Integer>> collectRangeIDs = (activeIndex, docIDs) -> {
+            if (mockIDCollect.containsKey(activeIndex)) {
+                mockIDCollect.get(activeIndex).addAll(docIDs);
+            } else {
+                mockIDCollect.put(activeIndex, docIDs);
+            }
+        };
+
+        multiRangesTraverse(state.pointTree, state.ranges, collectRangeIDs, state.maxNumNonZeroRanges);
+
+        return mockIDCollect;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java
index ebf4b5c9b2b9c..9cb6cc6b52ff3 100644
--- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java
@@ -13,14 +13,14 @@
 /**
  * Internal ranges representation for the filter rewrite optimization
  */
-final class Ranges {
+public final class Ranges {
     byte[][] lowers; // inclusive
     byte[][] uppers; // exclusive
     int size;
     int byteLen;
     static ArrayUtil.ByteArrayComparator comparator;
 
-    Ranges(byte[][] lowers, byte[][] uppers) {
+    public Ranges(byte[][] lowers, byte[][] uppers) {
         this.lowers = lowers;
         this.uppers = uppers;
         assert lowers.length == uppers.length;
diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java
index 2f57e981baf7f..6ee3b7831b36e 100644
--- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java
@@ -34,7 +34,7 @@
  * PointValues.IntersectVisitor} implementation is responsible for the actual visitation and
  * document count collection.
  */
-final class TreeTraversal {
+public final class TreeTraversal {
     private TreeTraversal() {}
 
     private static final Logger logger = LogManager.getLogger(loggerName);
@@ -48,7 +48,7 @@ private TreeTraversal() {}
      * @param maxNumNonZeroRanges the maximum number of non-zero ranges to collect
      * @return a {@link OptimizationContext.DebugInfo} object containing debug information about the traversal
      */
-    static OptimizationContext.DebugInfo multiRangesTraverse(
+    public static OptimizationContext.DebugInfo multiRangesTraverse(
         final PointValues.PointTree tree,
         final Ranges ranges,
         final BiConsumer<Integer, List<Integer>> collectRangeIDs,