Skip to content

Commit

Permalink
[spark] Support push down aggregate (#4259)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you authored Sep 26, 2024
1 parent 21bfe42 commit 3120722
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ abstract class PaimonBaseScanBuilder(table: Table)

protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType())

protected var pushed: Array[(Filter, Predicate)] = Array.empty
protected var pushedPredicates: Array[(Filter, Predicate)] = Array.empty

protected var reservedFilters: Array[Filter] = Array.empty
protected var partitionFilters: Array[Filter] = Array.empty

protected var postScanFilters: Array[Filter] = Array.empty

protected var pushDownLimit: Option[Int] = None

/**
 * Creates the scan from the state collected by this builder: the table, the required schema,
 * the Paimon predicates of the pushed filters, the partition filters and the optional limit.
 *
 * The previous body also evaluated a second `PaimonScan(...)` expression built from the
 * superseded `pushed` / `reservedFilters` fields and then discarded it; only the scan built
 * from `pushedPredicates` / `partitionFilters` was ever returned, so just that one is kept.
 */
override def build(): Scan = {
  PaimonScan(table, requiredSchema, pushedPredicates.map(_._2), partitionFilters, pushDownLimit)
}

/**
Expand All @@ -54,7 +56,7 @@ abstract class PaimonBaseScanBuilder(table: Table)
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)]
val postScan = mutable.ArrayBuffer.empty[Filter]
val reserved = mutable.ArrayBuffer.empty[Filter]
val partitionFilter = mutable.ArrayBuffer.empty[Filter]

val converter = new SparkFilterConverter(table.rowType)
val visitor = new PartitionPredicateVisitor(table.partitionKeys())
Expand All @@ -66,24 +68,27 @@ abstract class PaimonBaseScanBuilder(table: Table)
} else {
pushable.append((filter, predicate))
if (predicate.visit(visitor)) {
reserved.append(filter)
partitionFilter.append(filter)
} else {
postScan.append(filter)
}
}
}

if (pushable.nonEmpty) {
this.pushed = pushable.toArray
this.pushedPredicates = pushable.toArray
}
if (partitionFilter.nonEmpty) {
this.partitionFilters = partitionFilter.toArray
}
if (reserved.nonEmpty) {
this.reservedFilters = reserved.toArray
if (postScan.nonEmpty) {
this.postScanFilters = postScan.toArray
}
postScan.toArray
}

/**
 * Returns the Spark filters recorded in `pushedPredicates` (the first element of each pair).
 *
 * The previous body first evaluated `pushed.map(_._1)` on the superseded field and discarded
 * the value; that dead statement is removed.
 */
override def pushedFilters(): Array[Filter] = {
  pushedPredicates.map(_._1)
}

override def pruneColumns(requiredSchema: StructType): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark

import org.apache.paimon.table.Table

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.LocalScan
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

/**
 * A [[LocalScan]] that carries its result rows directly, so no Spark RDD is required to
 * execute it.
 *
 * @param rows
 *   the rows carried by this scan
 * @param readSchema
 *   the schema of `rows`
 * @param table
 *   the table the scan was created for; only its name is used here, for the description
 * @param filters
 *   the filters listed as pushed filters in the description
 */
case class PaimonLocalScan(
    rows: Array[InternalRow],
    readSchema: StructType,
    table: Table,
    filters: Array[Filter])
  extends LocalScan {

  /** Describes the scan as the table name, followed by the pushed filters when there are any. */
  override def description(): String = {
    val pushedFiltersStr =
      if (filters.isEmpty) {
        ""
      } else {
        // mkString(start, sep, end) renders the same text as concatenating the three pieces.
        filters.mkString(", PushedFilters: [", ",", "]")
      }
    s"PaimonLocalScan: [${table.name}]" + pushedFiltersStr
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@

package org.apache.paimon.spark

import org.apache.paimon.predicate.PredicateBuilder
import org.apache.paimon.spark.aggregate.LocalAggregator
import org.apache.paimon.table.Table

import org.apache.spark.sql.connector.read.SupportsPushDownLimit
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit}

import scala.collection.JavaConverters._

class PaimonScanBuilder(table: Table)
extends PaimonBaseScanBuilder(table)
with SupportsPushDownLimit {
with SupportsPushDownLimit
with SupportsPushDownAggregates {
private var localScan: Option[Scan] = None

override def pushLimit(limit: Int): Boolean = {
if (table.primaryKeys().isEmpty) {
Expand All @@ -33,4 +40,49 @@ class PaimonScanBuilder(table: Table)
// just make a best effort to push down limit
false
}

// Partial push down is not implemented for now: an aggregation is either pushed completely or
// not at all, so the answer is exactly what `pushAggregation` decides.
override def supportCompletePushDown(aggregation: Aggregation): Boolean =
  pushAggregation(aggregation)

// Tries to answer the aggregation locally from partition entries and, on success, remembers
// the resulting local scan so that `build()` can return it.
// Note: Spark does not support push down aggregation for streaming scan.
override def pushAggregation(aggregation: Aggregation): Boolean = {
  if (localScan.isDefined) {
    // The aggregation has already been pushed by an earlier call; nothing to recompute.
    true
  } else if (postScanFilters.nonEmpty) {
    // Push down is only supported together with pushed partition filters.
    false
  } else {
    val localAggregator = new LocalAggregator(table)
    val accepted = localAggregator.pushAggregation(aggregation)
    if (accepted) {
      val builder = table.newReadBuilder
      if (pushedPredicates.nonEmpty) {
        // Combine every pushed Paimon predicate into a single conjunction.
        val conjunction = PredicateBuilder.and(pushedPredicates.map(_._2): _*)
        builder.withFilter(conjunction)
      }
      val partitionEntries = builder.newScan().listPartitionEntries.asScala
      partitionEntries.foreach(localAggregator.update)
      val scan = PaimonLocalScan(
        localAggregator.result(),
        localAggregator.resultSchema(),
        table,
        pushedPredicates.map(_._1))
      localScan = Some(scan)
    }
    accepted
  }
}

// Prefer the local scan produced by a successful aggregate push down; otherwise fall back to
// the scan built by the parent builder.
override def build(): Scan = {
  localScan match {
    case Some(scan) => scan
    case None => super.build()
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType

/**
 * Scan builder for a [[KnownSplitsTable]]: the scan is created from the splits the table
 * already carries, the required schema and the Paimon predicates of the pushed filters.
 *
 * The previous `build()` body also evaluated a `PaimonSplitScan(...)` expression based on the
 * superseded `pushed` field and discarded it; only the expression based on `pushedPredicates`
 * was returned, so the dead statement is removed.
 */
class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonBaseScanBuilder(table) {
  override def build(): Scan = {
    PaimonSplitScan(table, table.splits(), requiredSchema, pushedPredicates.map(_._2))
  }
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.aggregate

import org.apache.paimon.manifest.PartitionEntry
import org.apache.paimon.table.{DataTable, Table}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

/**
 * Evaluates a pushed-down aggregation from [[PartitionEntry]] values instead of reading rows.
 *
 * Usage: call [[pushAggregation]] first. Only when it returns `true` may [[update]],
 * [[result]] and [[resultSchema]] be used; [[update]] is then called once per partition entry
 * before the result is read.
 */
class LocalAggregator(table: Table) {

  // Stays null until `pushAggregation` has accepted an aggregation.
  private var aggFuncEvaluator: Seq[AggFuncEvaluator[_]] = _

  /**
   * Single source of truth for the supported aggregate functions: returns the evaluator for
   * `func`, or `None` when the function cannot be evaluated locally. Keeping the "is it
   * supported" check and the evaluator creation in one match prevents them from drifting apart.
   */
  private def newEvaluator(func: AggregateFunc): Option[AggFuncEvaluator[_]] = {
    func match {
      case _: CountStar => Some(new CountStarEvaluator())
      case _ => None
    }
  }

  /**
   * The table must be a [[DataTable]] without primary keys and with deletion vectors disabled.
   * NOTE(review): presumably because partition-level record counts are not exact row counts
   * for the other kinds of table -- confirm.
   */
  private def tableSupportsPushDown: Boolean = {
    table match {
      case dataTable: DataTable =>
        table.primaryKeys.isEmpty && !dataTable.coreOptions.deletionVectorsEnabled
      case _ => false
    }
  }

  /** The evaluators set up by a successful `pushAggregation` call. */
  private def evaluators: Seq[AggFuncEvaluator[_]] = {
    assert(
      aggFuncEvaluator != null,
      "pushAggregation must have returned true before the aggregator is used")
    aggFuncEvaluator
  }

  /**
   * Decides whether `aggregation` can be evaluated locally and, if so, prepares the evaluators.
   *
   * @return
   *   `true` only when the table qualifies, there is no GROUP BY expression, and there is at
   *   least one aggregate expression, all of them supported; `false` otherwise, in which case
   *   no state is changed
   */
  def pushAggregation(aggregation: Aggregation): Boolean = {
    if (!tableSupportsPushDown || aggregation.groupByExpressions().nonEmpty) {
      false
    } else {
      // Explicit `toSeq` instead of relying on the implicit Array-to-Seq conversion.
      val candidates = aggregation.aggregateExpressions().toSeq.map(newEvaluator)
      if (candidates.isEmpty || candidates.exists(_.isEmpty)) {
        false
      } else {
        aggFuncEvaluator = candidates.flatten
        true
      }
    }
  }

  /** Feeds one partition entry to every evaluator. */
  def update(partitionEntry: PartitionEntry): Unit = {
    evaluators.foreach(_.update(partitionEntry))
  }

  /** Returns a single row holding the current result of each evaluator, in expression order. */
  def result(): Array[InternalRow] = {
    Array(InternalRow.fromSeq(evaluators.map(_.result())))
  }

  /** Returns the schema of the row produced by [[result]]: one field per evaluator. */
  def resultSchema(): StructType = {
    val fields = evaluators.zipWithIndex.map {
      case (evaluator, i) =>
        // Note that, Spark will re-assign the attribute name to original name,
        // so here we just return an arbitrary name
        StructField(s"${evaluator.prettyName}_$i", evaluator.resultType)
    }
    StructType(fields)
  }
}

/**
 * Evaluator of one aggregate function over [[PartitionEntry]] values.
 *
 * @tparam T
 *   the type of the value returned by `result()`
 */
trait AggFuncEvaluator[T] {
/** Folds one partition entry into the evaluator. */
def update(partitionEntry: PartitionEntry): Unit
/** Returns the value computed from the entries seen so far. */
def result(): T
/** The Spark data type used for this evaluator's field in the result schema. */
def resultType: DataType
/** A short name used as the prefix of this evaluator's field name in the result schema. */
def prettyName: String
}

/** Evaluator for `COUNT(*)`: adds up the record counts reported by the partition entries. */
class CountStarEvaluator extends AggFuncEvaluator[Long] {

  // Running total of the record counts seen so far.
  private var total: Long = 0L

  override def update(partitionEntry: PartitionEntry): Unit = {
    total = total + partitionEntry.recordCount()
  }

  override def result(): Long = total

  override def resultType: DataType = LongType

  override def prettyName: String = "count_star"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec

/** Checks when aggregates are removed from the executed plan and replaced by a local scan. */
class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanHelper {

  /**
   * Runs `query`, checks its answer and the shape of the resulting plans.
   *
   * @param query
   *   the SQL text to run
   * @param expectedRows
   *   the expected answer
   * @param expectedNumAggregates
   *   the expected number of [[BaseAggregateExec]] nodes in the executed plan; when the plan
   *   contains none, it must contain exactly one [[LocalTableScanExec]]
   */
  private def runAndCheckAggregate(
      query: String,
      expectedRows: Seq[Row],
      expectedNumAggregates: Int): Unit = {
    val df = spark.sql(query)
    checkAnswer(df, expectedRows)

    // Output names of the executed plan must match the DataFrame schema.
    assert(df.schema.names.toSeq == df.queryExecution.executedPlan.output.map(_.name))
    // The analyzed plan must contain an aggregate, whatever happens to it afterwards.
    assert(df.queryExecution.analyzed.find(_.isInstanceOf[Aggregate]).isDefined)

    val aggregateNodes = collect(df.queryExecution.executedPlan) {
      case aggExec: BaseAggregateExec => aggExec
    }
    val numAggregates = aggregateNodes.size
    assert(numAggregates == expectedNumAggregates, query)

    if (numAggregates == 0) {
      val localScans = collect(df.queryExecution.executedPlan) {
        case localScan: LocalTableScanExec => localScan
      }
      assert(localScans.size == 1)
    }
  }

  test("Push down aggregate - append table") {
    withTable("T") {
      spark.sql("CREATE TABLE T (c1 INT, c2 STRING) PARTITIONED BY(day STRING)")

      runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(0)), 0)
      // No aggregate is left in this plan because AQE optimizes the query to an empty relation.
      runAndCheckAggregate("SELECT COUNT(*) FROM T GROUP BY c1", Seq.empty, 0)
      runAndCheckAggregate("SELECT COUNT(c1) FROM T", Seq(Row(0)), 2)
      runAndCheckAggregate("SELECT COUNT(*), COUNT(c1) FROM T", Seq(Row(0, 0)), 2)
      runAndCheckAggregate("SELECT COUNT(*), COUNT(*) + 1 FROM T", Seq(Row(0, 1)), 0)
      runAndCheckAggregate("SELECT COUNT(*) as c FROM T WHERE day='a'", Seq(Row(0)), 0)
      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1=1", Seq(Row(0)), 2)
      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE day='a' and c1=1", Seq(Row(0)), 2)

      spark.sql(
        "INSERT INTO T VALUES(1, 'x', 'a'), (2, 'x', 'a'), (3, 'x', 'b'), (3, 'x', 'c'), (null, 'x', 'a')")

      runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(5)), 0)
      runAndCheckAggregate(
        "SELECT COUNT(*) FROM T GROUP BY c1",
        Seq(Row(1), Row(1), Row(1), Row(2)),
        2)
      runAndCheckAggregate("SELECT COUNT(c1) FROM T", Seq(Row(4)), 2)
      runAndCheckAggregate("SELECT COUNT(*), COUNT(c1) FROM T", Seq(Row(5, 4)), 2)
      runAndCheckAggregate("SELECT COUNT(*), COUNT(*) + 1 FROM T", Seq(Row(5, 6)), 0)
      runAndCheckAggregate("SELECT COUNT(*) as c FROM T WHERE day='a'", Seq(Row(3)), 0)
      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE c1=1", Seq(Row(1)), 2)
      runAndCheckAggregate("SELECT COUNT(*) FROM T WHERE day='a' and c1=1", Seq(Row(1)), 2)
    }
  }

  test("Push down aggregate - primary table") {
    withTable("T") {
      spark.sql("CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES ('primary-key' = 'c1')")
      runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(0)), 2)
      spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')")
      runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(3)), 2)
    }
  }

  test("Push down aggregate - enable deletion vector") {
    withTable("T") {
      spark.sql(
        "CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES('deletion-vectors.enabled' = 'true')")
      runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(0)), 2)
      spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')")
      runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(4)), 2)
    }
  }
}

0 comments on commit 3120722

Please sign in to comment.