Skip to content

Commit

Permalink
Shared aggregations in StarTree (apache#12164)
Browse files Browse the repository at this point in the history
  • Loading branch information
davecromberge authored Jan 26, 2024
1 parent 7978d29 commit 5a382f2
Show file tree
Hide file tree
Showing 17 changed files with 439 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,23 @@ private AggregationFunctionUtils() {
}

/**
* (For Star-Tree) Creates an {@link AggregationFunctionColumnPair} from the {@link AggregationFunction}. Returns
* {@code null} if the {@link AggregationFunction} cannot be represented as an {@link AggregationFunctionColumnPair}
* (e.g. has multiple arguments, argument is not column etc.).
* (For Star-Tree) Creates an {@link AggregationFunctionColumnPair} from the {@link AggregationFunction}, with the
* function type resolved to its stored type. Returns {@code null} if the {@link AggregationFunction} cannot be
* represented as an {@link AggregationFunctionColumnPair} (e.g. it has multiple arguments, or its argument is not a
* column).
* TODO: Allow multiple arguments for aggregation functions, e.g. percentileEst
*/
@Nullable
public static AggregationFunctionColumnPair getAggregationFunctionColumnPair(
AggregationFunction aggregationFunction) {
AggregationFunctionType aggregationFunctionType = aggregationFunction.getType();
if (aggregationFunctionType == AggregationFunctionType.COUNT) {
public static AggregationFunctionColumnPair getStoredFunctionColumnPair(AggregationFunction aggregationFunction) {
AggregationFunctionType functionType = aggregationFunction.getType();
if (functionType == AggregationFunctionType.COUNT) {
return AggregationFunctionColumnPair.COUNT_STAR;
}
List<ExpressionContext> inputExpressions = aggregationFunction.getInputExpressions();
if (inputExpressions.size() == 1) {
ExpressionContext inputExpression = inputExpressions.get(0);
if (inputExpression.getType() == ExpressionContext.Type.IDENTIFIER) {
return new AggregationFunctionColumnPair(aggregationFunctionType, inputExpression.getIdentifier());
return new AggregationFunctionColumnPair(AggregationFunctionColumnPair.getStoredType(functionType),
inputExpression.getIdentifier());
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static AggregationFunctionColumnPair[] extractAggregationFunctionPairs(
new AggregationFunctionColumnPair[numAggregationFunctions];
for (int i = 0; i < numAggregationFunctions; i++) {
AggregationFunctionColumnPair aggregationFunctionColumnPair =
AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
AggregationFunctionUtils.getStoredFunctionColumnPair(aggregationFunctions[i]);
if (aggregationFunctionColumnPair != null) {
aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public StarTreeAggregationExecutor(AggregationFunction[] aggregationFunctions) {
_aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
for (int i = 0; i < numAggregationFunctions; i++) {
_aggregationFunctionColumnPairs[i] =
AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
AggregationFunctionUtils.getStoredFunctionColumnPair(aggregationFunctions[i]);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public StarTreeGroupByExecutor(QueryContext queryContext, AggregationFunction[]
_aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
for (int i = 0; i < numAggregationFunctions; i++) {
_aggregationFunctionColumnPairs[i] =
AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
AggregationFunctionUtils.getStoredFunctionColumnPair(aggregationFunctions[i]);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,7 @@ public void setUp()
throws Exception {
_valueAggregator = getValueAggregator();
_aggregatedValueType = _valueAggregator.getAggregatedValueType();
AggregationFunctionType aggregationType = _valueAggregator.getAggregationType();
if (aggregationType == AggregationFunctionType.COUNT) {
_aggregation = "COUNT(*)";
} else if (aggregationType == AggregationFunctionType.PERCENTILEEST
|| aggregationType == AggregationFunctionType.PERCENTILETDIGEST) {
// Append a percentile number for percentile functions
_aggregation = String.format("%s(%s, 50)", aggregationType.getName(), METRIC);
} else {
_aggregation = String.format("%s(%s)", aggregationType.getName(), METRIC);
}
_aggregation = getAggregation(_valueAggregator.getAggregationType());

Schema.SchemaBuilder schemaBuilder = new Schema.SchemaBuilder().addSingleValueDimension(DIMENSION_D1, DataType.INT)
.addSingleValueDimension(DIMENSION_D2, DataType.INT);
Expand Down Expand Up @@ -185,6 +176,18 @@ public void setUp()
_starTreeV2 = _indexSegment.getStarTrees().get(0);
}

/**
 * Builds the aggregation expression string used in the test queries for the given aggregation function type.
 * <p>{@code COUNT} maps to {@code COUNT(*)}; {@code PERCENTILEEST} and {@code PERCENTILETDIGEST} get a fixed
 * percentile argument of 50 appended; every other type is rendered as {@code <name>(<metric>)}.
 * Subclasses may override this when the queried function differs from the value aggregator's type.
 *
 * @param aggregationType aggregation function type to render
 * @return the aggregation expression to embed in a query
 */
String getAggregation(AggregationFunctionType aggregationType) {
  // COUNT is the only function that is not applied to the metric column
  if (aggregationType == AggregationFunctionType.COUNT) {
    return "COUNT(*)";
  }
  boolean percentileFunction = aggregationType == AggregationFunctionType.PERCENTILEEST
      || aggregationType == AggregationFunctionType.PERCENTILETDIGEST;
  // Append a percentile number for percentile functions
  String template = percentileFunction ? "%s(%s, 50)" : "%s(%s)";
  return String.format(template, aggregationType.getName(), METRIC);
}

@Test
public void testUnsupportedFilters() {
String query = String.format("SELECT %s FROM %s", _aggregation, TABLE_NAME);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.startree.v2;

import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import java.util.Collections;
import java.util.Random;
import org.apache.pinot.segment.local.aggregator.DistinctCountHLLValueAggregator;
import org.apache.pinot.segment.local.aggregator.ValueAggregator;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec.DataType;

import static org.testng.Assert.assertEquals;


/**
 * Star-tree V2 test that queries {@code distinctCountRawHLL(m)} while the star-tree values are produced by a
 * {@link DistinctCountHLLValueAggregator}.
 */
public class DistinctCountRawHLLStarTreeV2Test extends BaseStarTreeV2Test<Object, HyperLogLog> {
  // Aggregation expression issued by the test queries
  private static final String AGGREGATION = "distinctCountRawHLL(m)";
  // Exclusive upper bound for the generated raw values
  private static final int RAW_VALUE_UPPER_BOUND = 100;

  /**
   * Supplies the value aggregator under test, constructed with an empty argument list.
   */
  @Override
  ValueAggregator<Object, HyperLogLog> getValueAggregator() {
    return new DistinctCountHLLValueAggregator(Collections.emptyList());
  }

  /**
   * Returns the fixed query aggregation expression; the aggregator's own function type is intentionally ignored.
   */
  @Override
  String getAggregation(AggregationFunctionType aggregationType) {
    return AGGREGATION;
  }

  /**
   * Raw metric values are ints.
   */
  @Override
  DataType getRawValueType() {
    return DataType.INT;
  }

  /**
   * Generates a random raw value in the range {@code [0, 100)}.
   */
  @Override
  Object getRandomRawValue(Random random) {
    int rawValue = random.nextInt(RAW_VALUE_UPPER_BOUND);
    return rawValue;
  }

  /**
   * Compares the two results by their cardinality estimates rather than by the HyperLogLog objects themselves.
   */
  @Override
  void assertAggregatedValue(HyperLogLog starTreeResult, HyperLogLog nonStarTreeResult) {
    long starTreeCardinality = starTreeResult.cardinality();
    long nonStarTreeCardinality = nonStarTreeResult.cardinality();
    assertEquals(starTreeCardinality, nonStarTreeCardinality);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public DistinctCountHLLPlusValueAggregator(List<ExpressionContext> arguments) {

@Override
public AggregationFunctionType getAggregationType() {
return AggregationFunctionType.DISTINCTCOUNTHLL;
return AggregationFunctionType.DISTINCTCOUNTHLLPLUS;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -56,7 +57,7 @@ public class StarTreeIndexReader implements Closeable {
private final int _numStarTrees;

// StarTree index can contain multiple index instances, identified by ids like 0, 1, etc.
private final Map<Integer, Map<IndexKey, StarTreeIndexEntry>> _indexColumnEntries;
private final List<Map<IndexKey, StarTreeIndexEntry>> _indexColumnEntries;
private PinotDataBuffer _dataBuffer;

/**
Expand All @@ -78,7 +79,7 @@ public StarTreeIndexReader(File segmentDirectory, SegmentMetadataImpl segmentMet
_readMode = readMode;
_numStarTrees = _segmentMetadata.getStarTreeV2MetadataList().size();
_indexFile = new File(_segmentDirectory, StarTreeV2Constants.INDEX_FILE_NAME);
_indexColumnEntries = new HashMap<>(_numStarTrees);
_indexColumnEntries = new ArrayList<>(_numStarTrees);
load();
}

Expand All @@ -104,40 +105,37 @@ private void load()

private void mapBufferEntries(int starTreeId,
Map<StarTreeIndexMapUtils.IndexKey, StarTreeIndexMapUtils.IndexValue> indexMap) {
Map<IndexKey, StarTreeIndexEntry> columnEntries =
_indexColumnEntries.computeIfAbsent(starTreeId, k -> new HashMap<>());
Map<IndexKey, StarTreeIndexEntry> columnEntries = new HashMap<>();
_indexColumnEntries.add(columnEntries);
// Load star-tree index. The index tree doesn't have corresponding column name or column index type to create an
// IndexKey. As it's a kind of inverted index, we uniquely identify it with index id and inverted index type.
columnEntries.computeIfAbsent(new IndexKey(String.valueOf(starTreeId), StandardIndexes.inverted()),
k -> new StarTreeIndexEntry(indexMap.get(StarTreeIndexMapUtils.STAR_TREE_INDEX_KEY), _dataBuffer,
columnEntries.put(new IndexKey(String.valueOf(starTreeId), StandardIndexes.inverted()),
new StarTreeIndexEntry(indexMap.get(StarTreeIndexMapUtils.STAR_TREE_INDEX_KEY), _dataBuffer,
ByteOrder.LITTLE_ENDIAN));
List<StarTreeV2Metadata> starTreeMetadataList = _segmentMetadata.getStarTreeV2MetadataList();
StarTreeV2Metadata starTreeMetadata = starTreeMetadataList.get(starTreeId);
// Load dimension forward indexes
for (String dimension : starTreeMetadata.getDimensionsSplitOrder()) {
IndexKey indexKey = new IndexKey(dimension, StandardIndexes.forward());
columnEntries.computeIfAbsent(indexKey, k -> new StarTreeIndexEntry(
columnEntries.put(new IndexKey(dimension, StandardIndexes.forward()), new StarTreeIndexEntry(
indexMap.get(new StarTreeIndexMapUtils.IndexKey(StarTreeIndexMapUtils.IndexType.FORWARD_INDEX, dimension)),
_dataBuffer, ByteOrder.BIG_ENDIAN));
}
// Load metric (function-column pair) forward indexes
for (AggregationFunctionColumnPair functionColumnPair : starTreeMetadata.getFunctionColumnPairs()) {
String metric = functionColumnPair.toColumnName();
IndexKey indexKey = new IndexKey(metric, StandardIndexes.forward());
columnEntries.computeIfAbsent(indexKey, k -> new StarTreeIndexEntry(
columnEntries.put(new IndexKey(metric, StandardIndexes.forward()), new StarTreeIndexEntry(
indexMap.get(new StarTreeIndexMapUtils.IndexKey(StarTreeIndexMapUtils.IndexType.FORWARD_INDEX, metric)),
_dataBuffer, ByteOrder.BIG_ENDIAN));
}
}

public PinotDataBuffer getBuffer(int starTreeId, String column, IndexType<?, ?, ?> type)
throws IOException {
Map<IndexKey, StarTreeIndexEntry> columnEntries = _indexColumnEntries.get(starTreeId);
if (columnEntries == null) {
if (_indexColumnEntries.size() <= starTreeId) {
throw new RuntimeException(
String.format("Could not find StarTree index: %s in segment: %s", starTreeId, _segmentDirectory.toString()));
}
StarTreeIndexEntry entry = columnEntries.get(new IndexKey(column, type));
StarTreeIndexEntry entry = _indexColumnEntries.get(starTreeId).get(new IndexKey(column, type));
if (entry != null && entry._buffer != null) {
return entry._buffer;
}
Expand All @@ -147,11 +145,10 @@ public PinotDataBuffer getBuffer(int starTreeId, String column, IndexType<?, ?,
}

public boolean hasIndexFor(int starTreeId, String column, IndexType<?, ?, ?> type) {
Map<IndexKey, StarTreeIndexEntry> columnEntries = _indexColumnEntries.get(starTreeId);
if (columnEntries == null) {
if (_indexColumnEntries.size() <= starTreeId) {
return false;
}
return columnEntries.containsKey(new IndexKey(column, type));
return _indexColumnEntries.get(starTreeId).containsKey(new IndexKey(column, type));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,24 @@ public static StarTreeV2BuilderConfig fromIndexConfig(StarTreeIndexConfig indexC
for (String functionColumnPair : indexConfig.getFunctionColumnPairs()) {
AggregationFunctionColumnPair aggregationFunctionColumnPair =
AggregationFunctionColumnPair.fromColumnName(functionColumnPair);
aggregationSpecs.put(aggregationFunctionColumnPair, AggregationSpec.DEFAULT);
AggregationFunctionColumnPair storedType =
AggregationFunctionColumnPair.resolveToStoredType(aggregationFunctionColumnPair);
// If there is already an equivalent functionColumnPair in the map, do not load another.
// This prevents the duplication of the aggregation when the StarTree is constructed.
aggregationSpecs.putIfAbsent(storedType, AggregationSpec.DEFAULT);
}
}
if (indexConfig.getAggregationConfigs() != null) {
for (StarTreeAggregationConfig aggregationConfig : indexConfig.getAggregationConfigs()) {
AggregationFunctionColumnPair aggregationFunctionColumnPair =
AggregationFunctionColumnPair.fromAggregationConfig(aggregationConfig);
AggregationFunctionColumnPair storedType =
AggregationFunctionColumnPair.resolveToStoredType(aggregationFunctionColumnPair);
ChunkCompressionType compressionType =
ChunkCompressionType.valueOf(aggregationConfig.getCompressionCodec().name());
aggregationSpecs.put(aggregationFunctionColumnPair, new AggregationSpec(compressionType));
// If there is already an equivalent functionColumnPair in the map, do not load another.
// This prevents the duplication of the aggregation when the StarTree is constructed.
aggregationSpecs.putIfAbsent(storedType, new AggregationSpec(compressionType));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import org.apache.commons.configuration2.ex.ConfigurationException;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
import org.apache.pinot.spi.env.CommonsConfigurationUtils;


/**
* The {@code StarTreeIndexMapUtils} class is a utility class to store/load star-tree index map to/from file.
* <p>
Expand Down Expand Up @@ -182,24 +184,29 @@ public static List<Map<IndexKey, IndexValue>> loadFromInputStream(InputStream in
int starTreeId = Integer.parseInt(split[0]);
Map<IndexKey, IndexValue> indexMap = indexMaps.get(starTreeId);

// Handle the case of column name containing '.'
String column;
int columnSplitEndIndex = split.length - 2;
if (columnSplitEndIndex == 2) {
column = split[1];
} else {
column = StringUtils.join(split, KEY_SEPARATOR, 1, columnSplitEndIndex);
}

IndexType indexType = IndexType.valueOf(split[columnSplitEndIndex]);
IndexKey indexKey;
if (indexType == IndexType.STAR_TREE) {
indexKey = STAR_TREE_INDEX_KEY;
} else {
// Handle the case of column name containing '.'
String column;
if (columnSplitEndIndex == 2) {
column = split[1];
} else {
column = StringUtils.join(split, KEY_SEPARATOR, 1, columnSplitEndIndex);
}
// Convert metric (function-column pair) to stored name for backward-compatibility
if (column.contains(AggregationFunctionColumnPair.DELIMITER)) {
AggregationFunctionColumnPair functionColumnPair = AggregationFunctionColumnPair.fromColumnName(column);
column = AggregationFunctionColumnPair.resolveToStoredType(functionColumnPair).toColumnName();
}
indexKey = new IndexKey(IndexType.FORWARD_INDEX, column);
}
IndexValue indexValue = indexMap.computeIfAbsent(indexKey, (k) -> new IndexValue());

long value = configuration.getLong(key);
IndexValue indexValue = indexMap.computeIfAbsent(indexKey, k -> new IndexValue());
if (split[columnSplitEndIndex + 1].equals(OFFSET_SUFFIX)) {
indexValue._offset = value;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ private static void validateIndexingConfig(IndexingConfig indexingConfig, @Nulla
}

List<StarTreeIndexConfig> starTreeIndexConfigList = indexingConfig.getStarTreeIndexConfigs();
Set<AggregationFunctionColumnPair> storedTypes = new HashSet<>();
if (starTreeIndexConfigList != null) {
for (StarTreeIndexConfig starTreeIndexConfig : starTreeIndexConfigList) {
// Dimension split order cannot be null
Expand All @@ -1049,6 +1050,11 @@ private static void validateIndexingConfig(IndexingConfig indexingConfig, @Nulla
throw new IllegalStateException("Invalid StarTreeIndex config: " + functionColumnPair + ". Must be"
+ "in the form <Aggregation function>__<Column name>");
}
AggregationFunctionColumnPair storedType = AggregationFunctionColumnPair.resolveToStoredType(columnPair);
if (!storedTypes.add(storedType)) {
LOGGER.warn("StarTreeIndex config duplication: {} already matches existing function column pair: {}. ",
columnPair, storedType);
}
String columnName = columnPair.getColumn();
if (!columnName.equals(AggregationFunctionColumnPair.STAR)) {
columnNameToConfigMap.put(columnName, STAR_TREE_CONFIG_NAME);
Expand All @@ -1064,6 +1070,11 @@ private static void validateIndexingConfig(IndexingConfig indexingConfig, @Nulla
} catch (Exception e) {
throw new IllegalStateException("Invalid StarTreeIndex config: " + aggregationConfig);
}
AggregationFunctionColumnPair storedType = AggregationFunctionColumnPair.resolveToStoredType(columnPair);
if (!storedTypes.add(storedType)) {
LOGGER.warn("StarTreeIndex config duplication: {} already matches existing function column pair: {}. ",
columnPair, storedType);
}
String columnName = columnPair.getColumn();
if (!columnName.equals(AggregationFunctionColumnPair.STAR)) {
columnNameToConfigMap.put(columnName, STAR_TREE_CONFIG_NAME);
Expand Down
Loading

0 comments on commit 5a382f2

Please sign in to comment.