diff --git a/server/src/main/java/org/opensearch/index/codec/freshstartree/builder/BaseSingleTreeBuilder.java b/server/src/main/java/org/opensearch/index/codec/freshstartree/builder/BaseSingleTreeBuilder.java index 7cf7106588c19..93b52fc583592 100644 --- a/server/src/main/java/org/opensearch/index/codec/freshstartree/builder/BaseSingleTreeBuilder.java +++ b/server/src/main/java/org/opensearch/index/codec/freshstartree/builder/BaseSingleTreeBuilder.java @@ -155,7 +155,7 @@ public abstract class BaseSingleTreeBuilder { } // TODO : Removing hardcoding - _maxLeafRecords = 1000; // builderConfig.getMaxLeafRecords(); + _maxLeafRecords = 100; // builderConfig.getMaxLeafRecords(); } private void constructStarTree(StarTreeBuilderUtils.TreeNode node, int startDocId, int endDocId) throws IOException { diff --git a/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeFilter.java b/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeFilter.java index 120b196ff8ef5..db0446717ddb9 100644 --- a/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeFilter.java +++ b/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeFilter.java @@ -96,10 +96,7 @@ public StarTreeFilter( // 1706268600 / (60*60*1000) * (60*60*1000) public DocIdSetIterator getStarTreeResult() throws IOException { - long startTime = System.nanoTime(); StarTreeResult starTreeResult = traverseStarTree(); - logger.info("Star tree traversal took : {}", System.nanoTime() - startTime); - startTime = System.nanoTime(); List andIterators = new ArrayList<>(); andIterators.add(starTreeResult._matchedDocIds.build().iterator()); DocIdSetIterator docIdSetIterator = andIterators.get(0); @@ -109,31 +106,33 @@ public DocIdSetIterator getStarTreeResult() throws IOException { DocIdSetBuilder builder = new DocIdSetBuilder(starTreeResult.numOfMatchedDocs); List> compositePredicateEvaluators = _predicateEvaluators.get(remainingPredicateColumn); SortedNumericDocValues ndv = this.dimValueMap.get(remainingPredicateColumn); - long ndvStartTime1 = System.nanoTime(); + List docIds = new ArrayList<>(); while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) { docCount++; - long ndvStartTime = System.nanoTime(); int docID = docIdSetIterator.docID(); - ndv.advanceExact(docID); - long value = ndv.nextValue(); - logger.info("Advancing took : {}", System.nanoTime() - ndvStartTime); - ndvStartTime = System.nanoTime(); - for (Predicate compositePredicateEvaluator : compositePredicateEvaluators) { - // TODO : this might be expensive as its done against all doc values docs - if (compositePredicateEvaluator.test(value)) { - builder.grow(1).add(docID); - break; + if(ndv.advanceExact(docID)) { + final int valuesCount = ndv.docValueCount(); + long value = ndv.nextValue(); + for (Predicate compositePredicateEvaluator : compositePredicateEvaluators) { + // TODO : this might be expensive as its done against all doc values docs + if (compositePredicateEvaluator.test(value)) { + docIds.add(docID); + for (int i = 0; i < valuesCount - 1; i++) { + while(docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + docIds.add(docIdSetIterator.docID()); + } + } + break; + } } } - logger.info("Predicate took : {}", System.nanoTime() - ndvStartTime); } - logger.info("Overall ndv took : {}", System.nanoTime() - ndvStartTime1); - long buildTime = System.nanoTime(); + DocIdSetBuilder.BulkAdder adder = builder.grow(docIds.size()); + for(int docID : docIds) { + adder.add(docID); + } docIdSetIterator = builder.build().iterator(); - logger.info("Builder took : {}", System.nanoTime() - buildTime); } - logger.info("Doc value num : {}" , docCount); - logger.info("Rest of tree traversal took : {}", System.nanoTime() - startTime); return docIdSetIterator; } @@ -168,6 +167,7 @@ private StarTreeResult traverseStarTree() throws IOException { } StarTreeNode starTreeNode; + List docIds = new ArrayList<>(); while ((starTreeNode = queue.poll()) != null) { int dimensionId = starTreeNode.getDimensionId(); if (dimensionId > currentDimensionId) { @@ -183,9 +183,8 @@ private StarTreeResult traverseStarTree() throws IOException { // If all predicate columns and group-by columns are matched, we can use aggregated document if (remainingPredicateColumns.isEmpty() && remainingGroupByColumns.isEmpty()) { - adder = docsWithField.grow(1); int docId = starTreeNode.getAggregatedDocId(); - adder.add(docId); + docIds.add(docId); docNum = docId > docNum ? docId : docNum; continue; } @@ -197,8 +196,7 @@ private StarTreeResult traverseStarTree() throws IOException { // remaining predicate columns for this node if (starTreeNode.isLeaf()) { for (long i = starTreeNode.getStartDocId(); i < starTreeNode.getEndDocId(); i++) { - adder = docsWithField.grow(1); - adder.add((int) i); + docIds.add((int)i); docNum = (int)i > docNum ? (int)i : docNum; } continue; @@ -289,6 +287,10 @@ private StarTreeResult traverseStarTree() throws IOException { } } + adder = docsWithField.grow(docIds.size()); + for(int id : docIds) { + adder.add(id); + } return new StarTreeResult( docsWithField, globalRemainingPredicateColumns != null ? globalRemainingPredicateColumns : Collections.emptySet(), diff --git a/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeQueryBuilder.java b/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeQueryBuilder.java index 540d8a9e2e0bc..fc2c019557d6f 100644 --- a/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/codec/freshstartree/query/StarTreeQueryBuilder.java @@ -144,7 +144,7 @@ public static StarTreeQueryBuilder fromXContent(XContentParser parser) { protected Query doToQuery(QueryShardContext context) { // TODO : star tree supports either group by or filter if (predicateMap.size() > 0) { - logger.info("Predicates: {} ", this.groupBy.toString() ); + //logger.info("Predicates: {} ", this.groupBy.toString() ); return new StarTreeQuery(predicateMap, new HashSet<>()); } logger.info("Group by : {} ", this.groupBy.toString() ); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/startree/StarTreeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/startree/StarTreeAggregator.java index 9258c786200af..f2fd5a45799e2 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/startree/StarTreeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/startree/StarTreeAggregator.java @@ -8,6 +8,9 @@ package org.opensearch.search.aggregations.bucket.startree; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; @@ -172,25 +175,30 @@ public InternalAggregation buildEmptyAggregation() { @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { StarTreeAggregatedValues values = (StarTreeAggregatedValues) ctx.reader().getAggregatedDocValues(); + final AtomicReference aggrVal = new AtomicReference<>(null); return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { - StarTreeAggregatedValues aggrVals = (StarTreeAggregatedValues) ctx.reader().getAggregatedDocValues(); - - Map fieldColToDocValuesMap = new HashMap<>(); + if(aggrVal.get() == null) { + aggrVal.set((StarTreeAggregatedValues) ctx.reader().getAggregatedDocValues()); + } + StarTreeAggregatedValues aggrVals = aggrVal.get(); + List fieldColToDocValuesMap = new ArrayList<>(); // TODO : validations for (String field : fieldCols) { - fieldColToDocValuesMap.put(field, aggrVals.dimensionValues.get(field)); + fieldColToDocValuesMap.add(aggrVals.dimensionValues.get(field)); } // Another hardcoding SortedNumericDocValues dv = aggrVals.metricValues.get(metrics.get(0)); if (dv.advanceExact(doc)) { - + long val1 = dv.nextValue(); String key = getKey(fieldColToDocValuesMap, doc); - + if(key.equals("") ) { + return; + } if (indexMap.containsKey(key)) { - sumMap.put(key, sumMap.getOrDefault(key, 0l) + dv.nextValue()); + sumMap.put(key, sumMap.getOrDefault(key, 0l) + val1); } else { indexMap.put(key, indexMap.size()); sumMap.put(key, dv.nextValue()); @@ -202,11 +210,11 @@ public void collect(int doc, long bucket) throws IOException { } - private String getKey(Map fieldColsMap, int doc) throws IOException { + private String getKey(List colsList, int doc) throws IOException { StringJoiner sj = new StringJoiner("-"); - for (Map.Entry fieldEntry : fieldColsMap.entrySet()) { - fieldEntry.getValue().advanceExact(doc); - long val = fieldEntry.getValue().nextValue(); + for (SortedNumericDocValues col : colsList) { + col.advanceExact(doc); + long val = col.nextValue(); //System.out.println("Key field : " + fieldEntry.getKey() + " Value : " + val); sj.add("" + val); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java index 4b8e882cd69bc..feca5627dcff0 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java @@ -31,6 +31,8 @@ package org.opensearch.search.aggregations.metrics; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ScoreMode; import org.opensearch.common.lease.Releasables; @@ -42,6 +44,7 @@ import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.LeafBucketCollectorBase; +import org.opensearch.search.aggregations.bucket.startree.StarTreeAggregator; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; @@ -61,6 +64,8 @@ public class SumAggregator extends NumericMetricsAggregator.SingleValue { private DoubleArray sums; private DoubleArray compensations; + private static final Logger logger = LogManager.getLogger(SumAggregator.class); + SumAggregator( String name, @@ -95,11 +100,13 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { + //logger.info("collecting doc : {}", doc); sums = bigArrays.grow(sums, bucket + 1); compensations = bigArrays.grow(compensations, bucket + 1); if (values.advanceExact(doc)) { final int valuesCount = values.docValueCount(); + //logger.info("values count : {}" , valuesCount); // Compute the sum of double values with Kahan summation algorithm which is more // accurate than naive summation. double sum = sums.get(bucket);