Support push down aggregate with group by partition column
ulysses-you committed Sep 29, 2024
1 parent 66cc214 commit abc1746
Showing 4 changed files with 133 additions and 13 deletions.
SparkInternalRow.java
@@ -59,6 +59,8 @@
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;

import java.util.Objects;

import static org.apache.paimon.utils.InternalRowUtils.copyInternalRow;

/** Spark {@link org.apache.spark.sql.catalyst.InternalRow} to wrap {@link InternalRow}. */
@@ -245,6 +247,25 @@ public Object get(int ordinal, org.apache.spark.sql.types.DataType dataType) {
throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SparkInternalRow that = (SparkInternalRow) o;
return Objects.equals(rowType, that.rowType) && Objects.equals(row, that.row);
}

@Override
public int hashCode() {
return Objects.hash(rowType, row);
}

// ================== static methods =========================================

public static Object fromPaimon(Object o, DataType type) {
if (o == null) {
return null;
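The `equals`/`hashCode` pair added to `SparkInternalRow` above is what lets the aggregator later in this commit key a hash map by group-by row. A minimal, self-contained sketch (not Paimon code; `KeyWrapper` is a hypothetical stand-in for the wrapped row) of why value-based equality matters for per-group state:

```scala
import scala.collection.mutable

object GroupKeyDemo extends App {
  // KeyWrapper stands in for SparkInternalRow: a case class gives value-based
  // equals/hashCode, which is exactly what the HashMap lookup below relies on.
  final case class KeyWrapper(values: Seq[Any])

  val counts = mutable.HashMap.empty[KeyWrapper, Long]

  def update(key: KeyWrapper): Unit =
    counts.update(key, counts.getOrElse(key, 0L) + 1L)

  update(KeyWrapper(Seq("x", 1)))
  update(KeyWrapper(Seq("x", 1))) // equal by value, so it lands on the same entry
  assert(counts(KeyWrapper(Seq("x", 1))) == 2L)
}
```

Without value-based `equals`/`hashCode`, each update would create a fresh map entry and the per-group results would never accumulate.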
SparkTypeUtils.java
@@ -60,13 +60,17 @@ public class SparkTypeUtils {

private SparkTypeUtils() {}

public static StructType toSparkPartitionType(Table table) {
public static RowType toPartitionType(Table table) {
int[] projections = table.rowType().getFieldIndices(table.partitionKeys());
List<DataField> partitionTypes = new ArrayList<>();
for (int i : projections) {
partitionTypes.add(table.rowType().getFields().get(i));
}
return (StructType) SparkTypeUtils.fromPaimonType(new RowType(false, partitionTypes));
return new RowType(false, partitionTypes);
}

public static StructType toSparkPartitionType(Table table) {
return (StructType) SparkTypeUtils.fromPaimonType(toPartitionType(table));
}

public static StructType fromPaimonRowType(RowType type) {
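The refactor above splits the partition-type logic in two: `toPartitionType` now returns the Paimon `RowType` containing only the partition columns, while `toSparkPartitionType` converts that row type to a Spark `StructType`. A rough sketch of the projection idea, using simplified hypothetical stand-in types rather than Paimon's `RowType`/`DataField`:

```scala
object PartitionTypeSketch extends App {
  // Simplified stand-ins for Paimon's DataField and RowType.
  final case class Field(name: String, tpe: String)
  final case class Row(fields: Seq[Field])

  // Keep only the fields named as partition keys, in the order of the partition
  // keys, mirroring what toPartitionType does via getFieldIndices.
  def partitionType(rowType: Row, partitionKeys: Seq[String]): Row =
    Row(partitionKeys.map(k => rowType.fields.find(_.name == k).get))

  val schema = Row(Seq(Field("c1", "INT"), Field("day", "STRING"), Field("hour", "INT")))
  assert(partitionType(schema, Seq("day", "hour")) ==
    Row(Seq(Field("day", "STRING"), Field("hour", "INT"))))
}
```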
LocalAggregator.scala
@@ -18,21 +18,46 @@

package org.apache.paimon.spark.aggregate

import org.apache.paimon.data.BinaryRow
import org.apache.paimon.manifest.PartitionEntry
import org.apache.paimon.spark.{SparkInternalRow, SparkTypeUtils}
import org.apache.paimon.table.{DataTable, Table}
import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

import scala.collection.mutable

class LocalAggregator(table: Table) {
private var aggFuncEvaluator: Seq[AggFuncEvaluator[_]] = _
private val partitionType = SparkTypeUtils.toPartitionType(table)
private val groupByEvaluatorMap = new mutable.HashMap[InternalRow, Seq[AggFuncEvaluator[_]]]()
private var requiredGroupByType: Seq[DataType] = _
private var requiredGroupByIndexMapping: Seq[Int] = _
private var aggFuncEvaluatorGetter: () => Seq[AggFuncEvaluator[_]] = _
private var isInitialized = false

private def initialize(aggregation: Aggregation): Unit = {
aggFuncEvaluator = aggregation.aggregateExpressions().map {
case _: CountStar => new CountStarEvaluator()
case _ => throw new UnsupportedOperationException()
aggFuncEvaluatorGetter = () =>
aggregation.aggregateExpressions().map {
case _: CountStar => new CountStarEvaluator()
case _ => throw new UnsupportedOperationException()
}

requiredGroupByType = aggregation.groupByExpressions().map {
case r: NamedReference =>
SparkTypeUtils.fromPaimonType(partitionType.getField(r.fieldNames().head).`type`())
}

requiredGroupByIndexMapping = aggregation.groupByExpressions().map {
case r: NamedReference =>
partitionType.getFieldIndex(r.fieldNames().head)
}

isInitialized = true
}

private def supportAggregateFunction(func: AggregateFunc): Boolean = {
@@ -42,6 +67,15 @@ class LocalAggregator(table: Table) {
}
}

private def supportGroupByExpressions(exprs: Array[Expression]): Boolean = {
// Support empty group by keys or group by partition column
exprs.forall {
case r: NamedReference =>
r.fieldNames.length == 1 && table.partitionKeys().contains(r.fieldNames().head)
case _ => false
}
}

def pushAggregation(aggregation: Aggregation): Boolean = {
if (
!table.isInstanceOf[DataTable] ||
@@ -54,7 +88,7 @@
}

if (
aggregation.groupByExpressions().nonEmpty ||
!supportGroupByExpressions(aggregation.groupByExpressions()) ||
aggregation.aggregateExpressions().isEmpty ||
aggregation.aggregateExpressions().exists(!supportAggregateFunction(_))
) {
@@ -65,25 +99,49 @@
true
}

private def requiredGroupByRow(partitionRow: BinaryRow): InternalRow = {
val projectedRow =
ProjectedRow.from(requiredGroupByIndexMapping.toArray).replaceRow(partitionRow)
// `ProjectedRow` does not support `hashCode`, so do a deep copy
val genericRow = InternalRowUtils.copyInternalRow(projectedRow, partitionType)
new SparkInternalRow(partitionType).replace(genericRow)
}

def update(partitionEntry: PartitionEntry): Unit = {
assert(aggFuncEvaluator != null)
assert(isInitialized)
val groupByRow = requiredGroupByRow(partitionEntry.partition())
val aggFuncEvaluator =
groupByEvaluatorMap.getOrElseUpdate(groupByRow, aggFuncEvaluatorGetter())
aggFuncEvaluator.foreach(_.update(partitionEntry))
}

def result(): Array[InternalRow] = {
assert(aggFuncEvaluator != null)
Array(InternalRow.fromSeq(aggFuncEvaluator.map(_.result())))
assert(isInitialized)
if (groupByEvaluatorMap.isEmpty && requiredGroupByType.isEmpty) {
// Always return one row for global aggregate
Array(InternalRow.fromSeq(aggFuncEvaluatorGetter().map(_.result())))
} else {
groupByEvaluatorMap.map {
case (partitionRow, aggFuncEvaluator) =>
new JoinedRow(partitionRow, InternalRow.fromSeq(aggFuncEvaluator.map(_.result())))
}.toArray
}
}

def resultSchema(): StructType = {
assert(aggFuncEvaluator != null)
val fields = aggFuncEvaluator.zipWithIndex.map {
assert(isInitialized)
// Always put the group by keys before the aggregate function result
val groupByFields = requiredGroupByType.zipWithIndex.map {
case (dt, i) =>
StructField(s"groupby_$i", dt)
}
val aggResultFields = aggFuncEvaluatorGetter().zipWithIndex.map {
case (evaluator, i) =>
// Note that Spark will re-assign the attribute name back to the original name,
// so here we just return an arbitrary name
StructField(s"${evaluator.prettyName}_$i", evaluator.resultType)
}
StructType.apply(fields)
StructType.apply(groupByFields ++ aggResultFields)
}
}

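Putting the `LocalAggregator` changes together: each `PartitionEntry` carries the statistics of one partition, its partition row is projected down to the requested group-by columns and copied so it can act as a hash key, and one set of aggregate evaluators is kept per distinct group-by row. A condensed, hedged sketch of that flow with simplified stand-in types (the real code works on Paimon/Spark internal rows and `PartitionEntry`, not plain sequences):

```scala
import scala.collection.mutable

object LocalAggregatorSketch extends App {
  // Stand-in for the copied group-by row used as the map key.
  type GroupKey = Seq[Any]

  // Stand-in for CountStarEvaluator: COUNT(*) is the sum of partition row counts.
  final class CountStarEval { var count = 0L; def update(rowCount: Long): Unit = count += rowCount }

  val groups = mutable.HashMap.empty[GroupKey, CountStarEval]

  // Project the partition values down to the group-by columns, then fold the
  // partition's row count into that group's evaluator.
  def update(partitionValues: Seq[Any], groupByIdx: Seq[Int], rowCount: Long): Unit = {
    val key = groupByIdx.map(partitionValues)
    groups.getOrElseUpdate(key, new CountStarEval).update(rowCount)
  }

  // One result row per group: group-by columns first, then the aggregate result,
  // matching the field order produced by resultSchema().
  def result(): Seq[Seq[Any]] = groups.toSeq.map { case (key, eval) => key :+ eval.count }

  // Partitions of a table partitioned by (day, hour), grouped by day (index 0).
  update(Seq("x", 1), Seq(0), 4L)
  update(Seq("x", 2), Seq(0), 1L)
  update(Seq("y", 3), Seq(0), 2L)
  assert(result().toSet == Set(Seq("x", 5L), Seq("y", 2L)))
}
```

For the global-aggregate case (no group-by columns) the real `result()` keeps the previous behavior and always emits exactly one row, even when no partition was seen.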
PushDownAggregatesTest.scala
@@ -78,6 +78,43 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanHelper
}
}

test("Push down aggregate - group by partition column") {
withTable("T") {
spark.sql("CREATE TABLE T (c1 INT) PARTITIONED BY(day STRING, hour INT)")

runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Nil, 0)
runAndCheckAggregate("SELECT day, COUNT(*) as c FROM T GROUP BY day, hour", Nil, 0)
runAndCheckAggregate("SELECT day, COUNT(*), hour FROM T GROUP BY day, hour", Nil, 0)
runAndCheckAggregate(
"SELECT day, COUNT(*), hour FROM T WHERE day='x' GROUP BY day, hour",
Nil,
0)
// This query does not contain an aggregate because AQE optimizes it to an empty relation.
runAndCheckAggregate("SELECT day, COUNT(*) FROM T GROUP BY c1, day", Nil, 0)

spark.sql(
"INSERT INTO T VALUES(1, 'x', 1), (2, 'x', 1), (3, 'x', 2), (3, 'x', 3), (null, 'y', null)")

runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY day", Row(1) :: Row(4) :: Nil, 0)
runAndCheckAggregate(
"SELECT day, COUNT(*) as c FROM T GROUP BY day, hour",
Row("x", 1) :: Row("x", 1) :: Row("x", 2) :: Row("y", 1) :: Nil,
0)
runAndCheckAggregate(
"SELECT day, COUNT(*), hour FROM T GROUP BY day, hour",
Row("x", 1, 2) :: Row("y", 1, null) :: Row("x", 2, 1) :: Row("x", 1, 3) :: Nil,
0)
runAndCheckAggregate(
"SELECT day, COUNT(*), hour FROM T WHERE day='x' GROUP BY day, hour",
Row("x", 1, 2) :: Row("x", 1, 3) :: Row("x", 2, 1) :: Nil,
0)
runAndCheckAggregate(
"SELECT day, COUNT(*) FROM T GROUP BY c1, day",
Row("x", 1) :: Row("x", 1) :: Row("x", 2) :: Row("y", 1) :: Nil,
2)
}
}

test("Push down aggregate - primary table") {
withTable("T") {
spark.sql("CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES ('primary-key' = 'c1')")
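As a sanity check of the expected rows in the new test, the counts can be recomputed by hand from the five inserted rows. The short self-contained sketch below does so with plain Scala collections (not Spark or Paimon code; the object and value names are illustrative only), and shows why the row with a NULL `c1` still contributes to `COUNT(*)`:

```scala
object ExpectedCountsCheck extends App {
  // The rows inserted by the test: (c1, day, hour).
  val rows: Seq[(Option[Int], String, Option[Int])] = Seq(
    (Some(1), "x", Some(1)),
    (Some(2), "x", Some(1)),
    (Some(3), "x", Some(2)),
    (Some(3), "x", Some(3)),
    (None, "y", None))

  // GROUP BY day: COUNT(*) counts rows, so the NULL c1 still counts.
  val byDay = rows.groupBy(_._2).map { case (day, rs) => day -> rs.size }
  assert(byDay == Map("x" -> 4, "y" -> 1)) // matches Row(4) and Row(1)

  // GROUP BY day, hour.
  val byDayHour = rows.groupBy(r => (r._2, r._3)).map { case (k, rs) => k -> rs.size }
  assert(byDayHour == Map(
    ("x", Some(1)) -> 2, ("x", Some(2)) -> 1, ("x", Some(3)) -> 1, ("y", None) -> 1))
}
```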
