[spark] Support auto disable bucketed scan (apache#3928)
ulysses-you authored Aug 16, 2024
1 parent ad9be43 commit a7e7bf6
Showing 8 changed files with 482 additions and 12 deletions.
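In short: PaimonScan gains a flag to fall back to a non-bucketed scan, and a new query-stage preparation rule, DisableUnnecessaryPaimonBucketedScan, flips that flag when the bucketed layout cannot help the plan. The rule is gated by Spark's v2 bucketing and auto bucketed scan settings; a minimal, hypothetical spark-shell sketch to enable them (the session itself is assumed, not part of this commit):

// Hypothetical setup: with both flags on, the rule added in this commit can decide
// per query whether a bucketed Paimon scan is worth keeping.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
// conf.autoBucketedScanEnabled in the rule below reads this key in recent Spark releases.
spark.conf.set("spark.sql.sources.bucketing.autoBucketedScan.enabled", "true")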
@@ -29,5 +29,6 @@ case class PaimonScan(
     requiredSchema: StructType,
     filters: Seq[Predicate],
     reservedFilters: Seq[Filter],
-    pushDownLimit: Option[Int])
+    pushDownLimit: Option[Int],
+    disableBucketedScan: Boolean = false)
   extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
@@ -34,7 +34,8 @@ case class PaimonScan(
     requiredSchema: StructType,
     filters: Seq[Predicate],
     reservedFilters: Seq[Filter],
-    pushDownLimit: Option[Int])
+    pushDownLimit: Option[Int],
+    disableBucketedScan: Boolean = false)
   extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
   with SupportsRuntimeFiltering {
 
@@ -30,6 +30,9 @@ trait PaimonInputPartition extends InputPartition {
   def rowCount(): Long = {
     splits.map(_.rowCount()).sum
   }
+
+  // Used to avoid checking [[PaimonBucketedInputPartition]]; a workaround for supporting multiple Spark versions.
+  def bucketed: Boolean = false
 }
 
 case class SimplePaimonInputPartition(splits: Seq[Split]) extends PaimonInputPartition
@@ -48,4 +51,5 @@ case class PaimonBucketedInputPartition(splits: Seq[Split], bucket: Int)
   extends PaimonInputPartition
   with HasPartitionKey {
   override def partitionKey(): InternalRow = new GenericInternalRow(Array(bucket.asInstanceOf[Any]))
+  override def bucketed: Boolean = true
 }
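The new `bucketed` flag lets callers test whether every input partition is bucket-aligned without pattern matching on PaimonBucketedInputPartition, which keeps the check uniform across the supported Spark versions. An illustrative helper (not part of the commit) showing the intended use; the rule added later in this commit performs essentially this check via lazyInputPartitions:

// Illustrative sketch only: true when every partition carries a bucket key.
def allPartitionsBucketed(partitions: Seq[PaimonInputPartition]): Boolean =
  partitions.forall(_.bucketed)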
@@ -23,7 +23,7 @@ import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
 import org.apache.paimon.table.source.{DataSplit, Split}
 
 import org.apache.spark.sql.PaimonUtils.fieldReference
-import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, Transform}
 import org.apache.spark.sql.connector.read.{SupportsReportPartitioning, SupportsRuntimeFiltering}
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.sources.{Filter, In}
@@ -36,35 +36,53 @@ case class PaimonScan(
     requiredSchema: StructType,
     filters: Seq[Predicate],
     reservedFilters: Seq[Filter],
-    pushDownLimit: Option[Int])
+    pushDownLimit: Option[Int],
+    bucketedScanDisabled: Boolean = false)
   extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
   with SupportsRuntimeFiltering
   with SupportsReportPartitioning {
 
-  override def outputPartitioning(): Partitioning = {
+  def disableBucketedScan(): PaimonScan = {
+    copy(bucketedScanDisabled = true)
+  }
+
+  @transient
+  private lazy val extractBucketTransform: Option[Transform] = {
     table match {
       case fileStoreTable: FileStoreTable =>
         val bucketSpec = fileStoreTable.bucketSpec()
         if (bucketSpec.getBucketMode != BucketMode.HASH_FIXED) {
-          new UnknownPartitioning(0)
+          None
         } else if (bucketSpec.getBucketKeys.size() > 1) {
-          new UnknownPartitioning(0)
+          None
         } else {
           // Spark does not support bucket with several input attributes,
           // so we only support one bucket key case.
           assert(bucketSpec.getNumBuckets > 0)
           assert(bucketSpec.getBucketKeys.size() == 1)
-          val key = Expressions.bucket(bucketSpec.getNumBuckets, bucketSpec.getBucketKeys.get(0))
-          new KeyGroupedPartitioning(Array(key), lazyInputPartitions.size)
+          val bucketKey = bucketSpec.getBucketKeys.get(0)
+          if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
+            Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey))
+          } else {
+            None
+          }
         }
 
-      case _ =>
-        new UnknownPartitioning(0)
+      case _ => None
     }
   }
 
+  override def outputPartitioning: Partitioning = {
+    extractBucketTransform
+      .map(bucket => new KeyGroupedPartitioning(Array(bucket), lazyInputPartitions.size))
+      .getOrElse(new UnknownPartitioning(0))
+  }
+
   override def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = {
-    if (!conf.v2BucketingEnabled || splits.exists(!_.isInstanceOf[DataSplit])) {
+    if (
+      bucketedScanDisabled || !conf.v2BucketingEnabled || extractBucketTransform.isEmpty ||
+      splits.exists(!_.isInstanceOf[DataSplit])
+    ) {
       return super.getInputPartitions(splits)
     }
 
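With this change, PaimonScan reports a KeyGroupedPartitioning only when the table uses fixed hash bucketing on a single bucket key and that key appears in the required schema; otherwise it reports UnknownPartitioning, and getInputPartitions also falls back once the scan has been marked disabled. A rough sketch of the kind of query where reporting the bucket transform is expected to help (table and column names are hypothetical, and the actual shuffle elimination depends on Spark's v2 bucketing support):

// Hypothetical: t1 and t2 are Paimon tables with fixed bucketing on column k.
// If both scans report KeyGroupedPartitioning over bucket(k), Spark may plan the join
// without shuffling either side; a query that never selects k cannot report the transform.
spark.sql("SELECT t1.k, t1.v, t2.v FROM t1 JOIN t2 ON t1.k = t2.k").explain()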
@@ -0,0 +1,172 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.execution.adaptive

import org.apache.paimon.spark.PaimonScan

import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeLike}

// spotless:off
/**
 * This rule is inspired by Spark's [[DisableUnnecessaryBucketedScan]] but works for v2 scans.
 *
 * Disable unnecessary bucketed table scan based on the actual physical query plan.
 * NOTE: this rule is designed to be applied right after [[EnsureRequirements]],
 * where all [[ShuffleExchangeLike]] and [[SortExec]] have been added to the plan properly.
 *
 * When BUCKETING_ENABLED and AUTO_BUCKETED_SCAN_ENABLED are set to true, go through the
 * query plan to check where bucketed table scan is unnecessary, and disable bucketed table
 * scan if:
 *
 * 1. The sub-plan from root to bucketed table scan does not contain a
 *    [[hasInterestingPartition]] operator.
 *
 * 2. The sub-plan from the nearest downstream [[hasInterestingPartition]] operator
 *    to the bucketed table scan contains at least one [[ShuffleExchangeLike]].
 *
 * Examples:
 * 1. no [[hasInterestingPartition]] operator:
 *                Project
 *                   |
 *                 Filter
 *                   |
 *             Scan(t1: i, j)
 *  (bucketed on column j, DISABLE bucketed scan)
 *
 * 2. join:
 *         SortMergeJoin(t1.i = t2.j)
 *              /            \
 *          Sort(i)       Sort(j)
 *            /               \
 *      Shuffle(i)       Scan(t2: i, j)
 *          /       (bucketed on column j, enable bucketed scan)
 *   Scan(t1: i, j)
 * (bucketed on column j, DISABLE bucketed scan)
 *
 * 3. aggregate:
 *         HashAggregate(i, ..., Final)
 *                      |
 *                  Shuffle(i)
 *                      |
 *         HashAggregate(i, ..., Partial)
 *                      |
 *                    Filter
 *                      |
 *                  Scan(t1: i, j)
 *  (bucketed on column j, DISABLE bucketed scan)
 *
 * The idea of [[hasInterestingPartition]] is inspired by "interesting order" in
 * the paper "Access Path Selection in a Relational Database Management System"
 * (https://dl.acm.org/doi/10.1145/582095.582099).
 */
// spotless:on
object DisableUnnecessaryPaimonBucketedScan extends Rule[SparkPlan] {

  /**
   * Disable bucketed table scan with a pre-order traversal of the plan.
   *
   * @param hashInterestingPartition
   *   The traversed plan has an operator with an interesting partition.
   * @param hasExchange
   *   The traversed plan has an [[Exchange]] operator.
   */
  private def disableBucketScan(
      plan: SparkPlan,
      hashInterestingPartition: Boolean,
      hasExchange: Boolean): SparkPlan = {
    plan match {
      case p if hasInterestingPartition(p) =>
        // An operator with an interesting partition propagates `hashInterestingPartition`
        // as true to its children and resets `hasExchange`.
        p.mapChildren(disableBucketScan(_, hashInterestingPartition = true, hasExchange = false))
      case exchange: ShuffleExchangeLike =>
        // An exchange operator propagates `hasExchange` as true to its child.
        exchange.mapChildren(disableBucketScan(_, hashInterestingPartition, hasExchange = true))
      case batch: BatchScanExec =>
        val paimonBucketedScan = extractPaimonBucketedScan(batch)
        if (paimonBucketedScan.isDefined && (!hashInterestingPartition || hasExchange)) {
          val (batch, paimonScan) = paimonBucketedScan.get
          val newBatch = batch.copy(scan = paimonScan.disableBucketedScan())
          newBatch.copyTagsFrom(batch)
          newBatch
        } else {
          batch
        }
      case p if canPassThrough(p) =>
        p.mapChildren(disableBucketScan(_, hashInterestingPartition, hasExchange))
      case other =>
        other.mapChildren(
          disableBucketScan(_, hashInterestingPartition = false, hasExchange = false))
    }
  }

  private def hasInterestingPartition(plan: SparkPlan): Boolean = {
    plan.requiredChildDistribution.exists {
      case _: ClusteredDistribution | AllTuples => true
      case _ => false
    }
  }

  /**
   * Check if the operator is an allowed single-child operator. We may revisit this method later,
   * as we can probably remove this restriction and allow arbitrary operators between a bucketed
   * table scan and an operator with an interesting partition.
   */
  private def canPassThrough(plan: SparkPlan): Boolean = {
    plan match {
      case _: ProjectExec | _: FilterExec => true
      case s: SortExec if !s.global => true
      case partialAgg: BaseAggregateExec =>
        partialAgg.requiredChildDistributionExpressions.isEmpty
      case _ => false
    }
  }

  def extractPaimonBucketedScan(plan: SparkPlan): Option[(BatchScanExec, PaimonScan)] =
    plan match {
      case batch: BatchScanExec =>
        batch.scan match {
          case scan: PaimonScan if scan.lazyInputPartitions.forall(_.bucketed) =>
            Some((batch, scan))
          case _ => None
        }
      case _ => None
    }

  def apply(plan: SparkPlan): SparkPlan = {
    lazy val hasBucketedScan = plan.exists {
      case p if extractPaimonBucketedScan(p).isDefined => true
      case _ => false
    }

    // TODO: replace this with `conf.v2BucketingEnabled` after dropping Spark 3.1 support
    val v2BucketingEnabled =
      conf.getConfString("spark.sql.sources.v2.bucketing.enabled", "false").toBoolean
    if (!v2BucketingEnabled || !conf.autoBucketedScanEnabled || !hasBucketedScan) {
      plan
    } else {
      disableBucketScan(plan, hashInterestingPartition = false, hasExchange = false)
    }
  }
}
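To observe the rule, one can inspect a plan where no operator needs the bucket partitioning, mirroring example 1 in the comment above. A hedged sketch, assuming a fixed-bucket Paimon table t bucketed on column j and the settings shown earlier:

// Expected: the BatchScan wraps a PaimonScan with the bucketed scan disabled, since neither
// Project nor Filter requires a ClusteredDistribution on the bucket key.
spark.sql("SELECT * FROM t WHERE j > 0").explain()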
@@ -22,6 +22,7 @@ import org.apache.paimon.spark.catalyst.analysis.{PaimonAnalysis, PaimonDeleteTa
 import org.apache.paimon.spark.catalyst.optimizer.{EvalSubqueriesForDeleteTable, MergePaimonScalarSubqueries}
 import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions
 import org.apache.paimon.spark.execution.PaimonStrategy
+import org.apache.paimon.spark.execution.adaptive.DisableUnnecessaryPaimonBucketedScan
 
 import org.apache.spark.sql.SparkSessionExtensions
 import org.apache.spark.sql.catalyst.parser.extensions.PaimonSparkSqlExtensionsParser
@@ -58,5 +59,8 @@ class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
 
     // planner extensions
     extensions.injectPlannerStrategy(spark => PaimonStrategy(spark))
+
+    // query stage preparation
+    extensions.injectQueryStagePrepRule(_ => DisableUnnecessaryPaimonBucketedScan)
   }
 }
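Because the rule is injected via injectQueryStagePrepRule, it runs as part of adaptive execution's plan preparation, after Spark has placed the required exchanges, which lines up with the NOTE in the rule's comment. For completeness, a hedged sketch of a session that loads the extensions (the fully qualified extension class name is assumed from Paimon's documentation rather than shown in this diff):

import org.apache.spark.sql.SparkSession

// Assumed names; only the injectQueryStagePrepRule call itself is part of this commit.
val spark = SparkSession.builder()
  .appName("paimon-auto-disable-bucketed-scan")
  .config("spark.sql.extensions", "org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions")
  .config("spark.sql.sources.v2.bucketing.enabled", "true")
  .config("spark.sql.sources.bucketing.autoBucketedScan.enabled", "true")
  .getOrCreate()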