[spark] Support report scan ordering (#4026)

ulysses-you authored Aug 28, 2024
1 parent b8639d4 commit 3efd2f3

Showing 4 changed files with 159 additions and 40 deletions.
SupportsReportOrdering.java
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.sql.connector.expressions.SortOrder;

/** Exists only to make the compiler happy; never actually used. */
public interface SupportsReportOrdering extends Scan {
SortOrder[] outputOrdering();
}
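This interface mirrors Spark 3.4's org.apache.spark.sql.connector.read.SupportsReportOrdering, presumably so the code still compiles against Spark versions that lack it. As a hedged, illustrative sketch (the class and column names below are made up, not part of this commit), a scan implementing the real interface reports its ordering like this:

import org.apache.spark.sql.connector.expressions.{Expressions, SortDirection, SortOrder}
import org.apache.spark.sql.connector.read.SupportsReportOrdering
import org.apache.spark.sql.types.StructType

// Hypothetical scan that claims its rows are produced sorted by an "id" column.
class ExampleOrderedScan(schema: StructType) extends SupportsReportOrdering {
  override def readSchema(): StructType = schema

  // Spark compares this ordering with each operator's required ordering and can
  // skip a SortExec when the reported ordering already satisfies it.
  override def outputOrdering(): Array[SortOrder] =
    Array(Expressions.sort(Expressions.identity("id"), SortDirection.ASCENDING))
}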
PaimonScan.scala
@@ -23,8 +23,8 @@ import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
import org.apache.paimon.table.source.{DataSplit, Split}

import org.apache.spark.sql.PaimonUtils.fieldReference
import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, Transform}
import org.apache.spark.sql.connector.read.{SupportsReportPartitioning, SupportsRuntimeFiltering}
import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning, SupportsRuntimeFiltering}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.sources.{Filter, In}
import org.apache.spark.sql.types.StructType
@@ -40,7 +40,8 @@ case class PaimonScan(
bucketedScanDisabled: Boolean = false)
extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
with SupportsRuntimeFiltering
with SupportsReportPartitioning {
with SupportsReportPartitioning
with SupportsReportOrdering {

def disableBucketedScan(): PaimonScan = {
copy(bucketedScanDisabled = true)
@@ -72,17 +73,65 @@
}
}

private def shouldDoBucketedScan: Boolean = {
!bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined
}

// Since Spark 3.3
override def outputPartitioning: Partitioning = {
extractBucketTransform
.map(bucket => new KeyGroupedPartitioning(Array(bucket), lazyInputPartitions.size))
.getOrElse(new UnknownPartitioning(0))
}

override def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = {
// Since Spark 3.4
override def outputOrdering(): Array[SortOrder] = {
if (
bucketedScanDisabled || !conf.v2BucketingEnabled || extractBucketTransform.isEmpty ||
splits.exists(!_.isInstanceOf[DataSplit])
!shouldDoBucketedScan || lazyInputPartitions.exists(
!_.isInstanceOf[PaimonBucketedInputPartition])
) {
return Array.empty
}

val primaryKeys = table match {
case fileStoreTable: FileStoreTable => fileStoreTable.primaryKeys().asScala
case _ => Seq.empty
}
if (primaryKeys.isEmpty) {
return Array.empty
}

val allSplitsKeepOrdering = lazyInputPartitions.toSeq
.map(_.asInstanceOf[PaimonBucketedInputPartition])
.map(_.splits.asInstanceOf[Seq[DataSplit]])
.forall {
splits =>
// Only report ordering if all of the following hold:
// - one `Split` per InputPartition (TODO: re-construct splits using minKey/maxKey)
// - each `Split` either is not rawConvertible, so that the merge read can happen,
//   or contains only one data file, so it is always sorted even without a merge read
splits.size < 2 && splits.forall {
split => !split.rawConvertible() || split.dataFiles().size() < 2
}
}
if (!allSplitsKeepOrdering) {
return Array.empty
}

// Multiple primary keys are fine:
// `Array(a, b)` satisfies the required ordering `Array(a)`
primaryKeys
.map(Expressions.identity)
.map {
sortExpr =>
// Primary keys cannot be null, so the null ordering does not matter.
Expressions.sort(sortExpr, SortDirection.ASCENDING)
}
.toArray
}

override def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = {
if (!shouldDoBucketedScan || splits.exists(!_.isInstanceOf[DataSplit])) {
return super.getInputPartitions(splits)
}

@@ -96,6 +145,7 @@ case class PaimonScan(
.toSeq
}

// Since Spark 3.2
override def filterAttributes(): Array[NamedReference] = {
val requiredFields = readBuilder.readType().getFieldNames.asScala
table
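A note on the comment in `outputOrdering` above ("`Array(a, b)` satisfies the required ordering `Array(a)`"): ordering satisfaction is a prefix check. The standalone sketch below is illustrative only (not Paimon or Spark source) and models that rule on plain column names.

// Hedged model of ordering satisfaction: a reported ordering satisfies a required
// ordering when the required keys form a non-empty prefix of the reported keys.
object OrderingPrefixCheck {
  def satisfies(reported: Seq[String], required: Seq[String]): Boolean =
    required.nonEmpty && reported.take(required.length) == required

  def main(args: Array[String]): Unit = {
    println(satisfies(Seq("a", "b"), Seq("a"))) // true: Array(a, b) satisfies Array(a)
    println(satisfies(Seq("a", "b"), Seq("b"))) // false: rows are not ordered by b alone
  }
}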
DisableUnnecessaryPaimonBucketedScan.scala
@@ -40,13 +40,13 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeLike}
* scan if:
*
* 1. The sub-plan from root to bucketed table scan, does not contain
* [[hasInterestingPartition]] operator.
* [[hasInterestingPartitionOrOrder]] operator.
*
* 2. The sub-plan from the nearest downstream [[hasInterestingPartition]] operator
* 2. The sub-plan from the nearest downstream [[hasInterestingPartitionOrOrder]] operator
* to the bucketed table scan contains at least one [[ShuffleExchangeLike]].
*
* Examples:
* 1. no [[hasInterestingPartition]] operator:
* 1. no [[hasInterestingPartitionOrOrder]] operator:
* Project
* |
* Filter
@@ -76,7 +76,7 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeLike}
* Scan(t1: i, j)
* (bucketed on column j, DISABLE bucketed scan)
*
* The idea of [[hasInterestingPartition]] is inspired from "interesting order" in
* The idea of [[hasInterestingPartitionOrOrder]] is inspired from "interesting order" in
* the paper "Access Path Selection in a Relational Database Management System"
* (https://dl.acm.org/doi/10.1145/582095.582099).
*/
@@ -86,26 +86,28 @@ object DisableUnnecessaryPaimonBucketedScan extends Rule[SparkPlan] {
/**
* Disable bucketed table scan with pre-order traversal of plan.
*
* @param hashInterestingPartition
* The traversed plan has operator with interesting partition.
* @param hashInterestingPartitionOrOrder
* The traversed plan has an operator with an interesting partition or ordering.
* @param hasExchange
* The traversed plan has [[Exchange]] operator.
*/
private def disableBucketScan(
plan: SparkPlan,
hashInterestingPartition: Boolean,
hashInterestingPartitionOrOrder: Boolean,
hasExchange: Boolean): SparkPlan = {
plan match {
case p if hasInterestingPartition(p) =>
// Operator with interesting partition, propagates `hashInterestingPartition` as true
case p if hasInterestingPartitionOrOrder(p) =>
// Operator with an interesting partition or ordering propagates `hashInterestingPartitionOrOrder` as true
// to its children, and resets `hasExchange`.
p.mapChildren(disableBucketScan(_, hashInterestingPartition = true, hasExchange = false))
p.mapChildren(
disableBucketScan(_, hashInterestingPartitionOrOrder = true, hasExchange = false))
case exchange: ShuffleExchangeLike =>
// Exchange operator propagates `hasExchange` as true to its child.
exchange.mapChildren(disableBucketScan(_, hashInterestingPartition, hasExchange = true))
exchange.mapChildren(
disableBucketScan(_, hashInterestingPartitionOrOrder, hasExchange = true))
case batch: BatchScanExec =>
val paimonBucketedScan = extractPaimonBucketedScan(batch)
if (paimonBucketedScan.isDefined && (!hashInterestingPartition || hasExchange)) {
if (paimonBucketedScan.isDefined && (!hashInterestingPartitionOrOrder || hasExchange)) {
val (batch, paimonScan) = paimonBucketedScan.get
val newBatch = batch.copy(scan = paimonScan.disableBucketedScan())
newBatch.copyTagsFrom(batch)
@@ -114,18 +116,22 @@ object DisableUnnecessaryPaimonBucketedScan extends Rule[SparkPlan] {
batch
}
case p if canPassThrough(p) =>
p.mapChildren(disableBucketScan(_, hashInterestingPartition, hasExchange))
p.mapChildren(disableBucketScan(_, hashInterestingPartitionOrOrder, hasExchange))
case other =>
other.mapChildren(
disableBucketScan(_, hashInterestingPartition = false, hasExchange = false))
disableBucketScan(_, hashInterestingPartitionOrOrder = false, hasExchange = false))
}
}

private def hasInterestingPartition(plan: SparkPlan): Boolean = {
plan.requiredChildDistribution.exists {
private def hasInterestingPartitionOrOrder(plan: SparkPlan): Boolean = {
val hashPartition = plan.requiredChildDistribution.exists {
case _: ClusteredDistribution | AllTuples => true
case _ => false
}
// Some operators may only require local sort without distribution,
// so we do not disable bucketed scan for these queries.
val hashOrder = plan.requiredChildOrdering.exists(_.nonEmpty)
hashPartition || hashOrder
}

/**
@@ -166,7 +172,7 @@ object DisableUnnecessaryPaimonBucketedScan extends Rule[SparkPlan] {
if (!v2BucketingEnabled || !conf.autoBucketedScanEnabled || !hasBucketedScan) {
plan
} else {
disableBucketScan(plan, hashInterestingPartition = false, hasExchange = false)
disableBucketScan(plan, hashInterestingPartitionOrOrder = false, hasExchange = false)
}
}
}
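Restating the decision this rule makes: a bucketed scan stays enabled only when, walking up from the scan, an operator with an interesting partition or ordering is reached without crossing a shuffle first. The toy model below is illustrative only; it works on plain strings rather than SparkPlan nodes.

// Hedged toy model of the traversal described in the Scaladoc above, not Paimon code.
object BucketedScanDecision {
  // `scanToRoot` lists the operators from the scan upward, tagged as
  // "interesting" (requires a partition/order), "shuffle", or anything else.
  def keepBucketedScan(scanToRoot: List[String]): Boolean = {
    val belowInteresting = scanToRoot.takeWhile(_ != "interesting")
    scanToRoot.contains("interesting") && !belowInteresting.contains("shuffle")
  }

  def main(args: Array[String]): Unit = {
    println(keepBucketedScan(List("filter", "interesting")))            // true: bucketing is useful
    println(keepBucketedScan(List("filter", "shuffle", "interesting"))) // false: a shuffle destroys it anyway
    println(keepBucketedScan(List("filter", "project")))                // false: no operator benefits
  }
}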
BucketedTableQueryTest.scala
@@ -21,11 +21,12 @@ package org.apache.paimon.spark.sql
import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanHelper {
private def checkAnswerAndShuffle(query: String, numShuffle: Int): Unit = {
private def checkAnswerAndShuffleSorts(query: String, numShuffles: Int, numSorts: Int): Unit = {
var expectedResult: Array[Row] = null
// avoid config default value change in future, so specify it manually
withSQLConf(
@@ -40,14 +41,19 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
checkAnswer(df, expectedResult.toSeq)
assert(collect(df.queryExecution.executedPlan) {
case shuffle: ShuffleExchangeLike => shuffle
}.size == numShuffle)
}.size == numShuffles)
if (gteqSpark3_4) {
assert(collect(df.queryExecution.executedPlan) {
case sort: SortExec => sort
}.size == numSorts)
}
}
}

test("Query on a bucketed table - join - positive case") {
assume(gteqSpark3_3)

withTable("t1", "t2", "t3", "t4") {
withTable("t1", "t2", "t3", "t4", "t5") {
spark.sql(
"CREATE TABLE t1 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t1 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
@@ -56,19 +62,26 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
spark.sql(
"CREATE TABLE t2 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t2 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t2 on t1.id = t2.id", 0)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t2 on t1.id = t2.id", 0, 0)

// different primary-key name but does not matter
spark.sql(
"CREATE TABLE t3 (id2 INT, c STRING) TBLPROPERTIES ('primary-key' = 'id2', 'bucket'='10')")
spark.sql("INSERT INTO t3 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t3 on t1.id = t3.id2", 0)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t3 on t1.id = t3.id2", 0, 0)

// one primary-key table and one bucketed table
spark.sql(
"CREATE TABLE t4 (id INT, c STRING) TBLPROPERTIES ('bucket-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t4 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t4 on t1.id = t4.id", 0)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t4 on t1.id = t4.id", 0, 1)

// one primary-key table and
// another primary-key table with two primary keys and one bucket column
spark.sql(
"CREATE TABLE t5 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id,c', 'bucket-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t5 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t5 on t1.id = t5.id", 0, 0)
}
}

@@ -83,32 +96,32 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
// dynamic bucket number
spark.sql("CREATE TABLE t2 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id')")
spark.sql("INSERT INTO t2 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t2 on t1.id = t2.id", 2)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t2 on t1.id = t2.id", 2, 2)

// different bucket number
spark.sql(
"CREATE TABLE t3 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='2')")
spark.sql("INSERT INTO t3 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t3 on t1.id = t3.id", 2)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t3 on t1.id = t3.id", 2, 2)

// different primary-key data type
spark.sql(
"CREATE TABLE t4 (id STRING, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t4 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t4 on t1.id = t4.id", 2)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t4 on t1.id = t4.id", 2, 2)

// different input partition number
spark.sql(
"CREATE TABLE t5 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t5 VALUES (1, 'x1')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t5 on t1.id = t5.id", 2)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t5 on t1.id = t5.id", 2, 2)

// one more bucket keys
spark.sql(
"CREATE TABLE t6 (id1 INT, id2 INT, c STRING) TBLPROPERTIES ('bucket-key' = 'id1,id2', 'bucket'='10')")
spark.sql(
"INSERT INTO t6 VALUES (1, 1, 'x1'), (2, 2, 'x3'), (3, 3, 'x3'), (4, 4, 'x4'), (5, 5, 'x5')")
checkAnswerAndShuffle("SELECT * FROM t1 JOIN t6 on t1.id = t6.id1", 2)
checkAnswerAndShuffleSorts("SELECT * FROM t1 JOIN t6 on t1.id = t6.id1", 2, 2)
}
}

@@ -120,17 +133,41 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
"CREATE TABLE t1 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='10')")
spark.sql("INSERT INTO t1 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 'x4'), (5, 'x5')")

checkAnswerAndShuffle("SELECT id, count(*) FROM t1 GROUP BY id", 0)
checkAnswerAndShuffle("SELECT c, count(*) FROM t1 GROUP BY c", 1)
checkAnswerAndShuffle("select sum(c) OVER (PARTITION BY id ORDER BY c) from t1", 0)
checkAnswerAndShuffle("select sum(id) OVER (PARTITION BY c ORDER BY id) from t1", 1)
checkAnswerAndShuffleSorts("SELECT id, count(*) FROM t1 GROUP BY id", 0, 0)
checkAnswerAndShuffleSorts("SELECT id, max(c) FROM t1 GROUP BY id", 0, 0)
checkAnswerAndShuffleSorts("SELECT c, count(*) FROM t1 GROUP BY c", 1, 0)
checkAnswerAndShuffleSorts("SELECT c, max(c) FROM t1 GROUP BY c", 1, 2)
checkAnswerAndShuffleSorts("select sum(c) OVER (PARTITION BY id ORDER BY c) from t1", 0, 1)
// TODO: this is a Spark issue: `WindowExec` requires the partition-by + order-by columns
// without de-duplicating them.
checkAnswerAndShuffleSorts("select sum(c) OVER (PARTITION BY id ORDER BY id) from t1", 0, 1)
checkAnswerAndShuffleSorts("select sum(id) OVER (PARTITION BY c ORDER BY id) from t1", 1, 1)

withSQLConf("spark.sql.requireAllClusterKeysForDistribution" -> "false") {
checkAnswerAndShuffle("SELECT id, c, count(*) FROM t1 GROUP BY id, c", 0)
checkAnswerAndShuffleSorts("SELECT id, c, count(*) FROM t1 GROUP BY id, c", 0, 0)
}
withSQLConf("spark.sql.requireAllClusterKeysForDistribution" -> "true") {
checkAnswerAndShuffle("SELECT id, c, count(*) FROM t1 GROUP BY id, c", 1)
checkAnswerAndShuffleSorts("SELECT id, c, count(*) FROM t1 GROUP BY id, c", 1, 0)
}
}
}

test("Report scan output ordering - rawConvertible") {
assume(gteqSpark3_3)

withTable("t") {
spark.sql(
"CREATE TABLE t (id INT, c STRING) TBLPROPERTIES ('primary-key' = 'id', 'bucket'='2', 'deletion-vectors.enabled'='true')")

// one file case
spark.sql(s"INSERT INTO t VALUES (1, 'x1'), (2, 'x3')")
checkAnswerAndShuffleSorts("SELECT id, max(c) FROM t GROUP BY id", 0, 0)

// generate some files
(1.to(20)).foreach {
i => spark.sql(s"INSERT INTO t VALUES ($i, 'x1'), ($i, 'x3'), ($i, 'x3')")
}
checkAnswerAndShuffleSorts("SELECT id, max(c) FROM t GROUP BY id", 0, 1)
}
}
}
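For an end-to-end view of what these tests assert, the hedged sketch below runs a bucket-to-bucket join in a Spark shell; the table names and session setup are assumptions, not part of this commit. On Spark 3.4+ with spark.sql.sources.v2.bucketing.enabled set to true, the join's physical plan should contain neither a shuffle nor a sort, since partitioning comes from the matching buckets and ordering from the reported primary-key order.

// Hedged usage sketch; assumes `spark` is a SparkSession configured with a Paimon catalog.
spark.sql(
  "CREATE TABLE orders (id INT, c STRING) " +
    "TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '10')")
spark.sql(
  "CREATE TABLE users (id INT, c STRING) " +
    "TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '10')")
spark.sql("INSERT INTO orders VALUES (1, 'x1'), (2, 'x2')")
spark.sql("INSERT INTO users VALUES (1, 'y1'), (2, 'y2')")

// With matching bucket numbers, the sort-merge join can reuse both the storage
// partitioning and the reported per-bucket primary-key ordering.
spark.sql("SELECT * FROM orders o JOIN users u ON o.id = u.id").explain()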
