[SPARK-49839][SQL] SPJ: Skip shuffles if possible for sorts
### What changes were proposed in this pull request?

This is a proposal to skip shuffles for ORDER BY and other sort operations when the sort is on partition columns.
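As a minimal sketch of the target scenario (catalog, table, and column names are placeholders that mirror the new suite test below):

```scala
// A v2 table partitioned by the same columns the query sorts on.
spark.sql(
  "CREATE TABLE testcat.ns.items (id BIGINT, name STRING, price DOUBLE) " +
    "PARTITIONED BY (price, id)")

// Today this ORDER BY plans a range shuffle over every row; with this change,
// the scan's key-grouped partitions can instead be emitted in sorted
// partition-value order, so no exchange is needed.
spark.sql("SELECT price, id FROM testcat.ns.items ORDER BY price, id").explain()
```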

### Why are the changes needed?

This could potentially optimize many jobs: today all data is shuffled even when we already have all the partition values and can sort by them. This is a common scenario; for example, Iceberg often requests data to be sorted by partition before a write, to avoid the small-files issue.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a test in KeyGroupedPartitioningSuite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#48303 from szehon-ho/SPARK-49839.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
szehon-ho authored and sunchao committed Dec 15, 2024
1 parent 976192a · commit d2965ae
Showing 4 changed files with 97 additions and 1 deletion.

partitioning.scala
```diff
@@ -176,6 +176,13 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   override def createPartitioning(numPartitions: Int): Partitioning = {
     RangePartitioning(ordering, numPartitions)
   }
+
+  def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = {
+    expressions.length == ordering.length &&
+      expressions.zip(ordering).forall {
+        case (x, o) => x.semanticEquals(o.child)
+      }
+  }
 }
 
 /**
```
```diff
@@ -394,6 +401,9 @@ case class KeyGroupedPartitioning(
           }
         }
 
+      case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
+        o.areAllClusterKeysMatched(expressions)
+
       case _ =>
         false
     }
```
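To make the new matching rule concrete, here is a self-contained toy analogue of `areAllClusterKeysMatched` (plain case classes stand in for Catalyst's `Expression` and `SortOrder`, and `semanticEquals` is reduced to name equality, so this is illustrative rather than Spark API):

```scala
// Toy stand-ins for Catalyst expressions and sort orders.
case class Expr(name: String) {
  def semanticEquals(other: Expr): Boolean = name == other.name
}
case class Order(child: Expr)

// Same shape as the check above: every clustering key must line up
// positionally with the child of the sort order at the same index.
def areAllClusterKeysMatched(expressions: Seq[Expr], ordering: Seq[Order]): Boolean =
  expressions.length == ordering.length &&
    expressions.zip(ordering).forall { case (x, o) => x.semanticEquals(o.child) }

assert(areAllClusterKeysMatched(
  Seq(Expr("price"), Expr("id")),
  Seq(Order(Expr("price")), Order(Expr("id")))))
assert(!areAllClusterKeysMatched(
  Seq(Expr("id"), Expr("price")),
  Seq(Order(Expr("price")), Order(Expr("id")))))
```

The positional requirement is what lets the scan's existing key-grouped layout satisfy an `OrderedDistribution` directly.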

SQLConf.scala
```diff
@@ -1679,6 +1679,16 @@
       .booleanConf
       .createWithDefault(false)
 
+  val V2_BUCKETING_SORTING_ENABLED =
+    buildConf("spark.sql.sources.v2.bucketing.sorting.enabled")
+      .doc(s"When turned on, Spark will recognize the specific distribution reported by " +
+        s"a V2 data source through SupportsReportPartitioning, and will try to avoid a shuffle " +
+        s"if possible when sorting by those columns. This config requires " +
+        s"${V2_BUCKETING_ENABLED.key} to be enabled.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
```
```diff
@@ -5896,6 +5906,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingAllowCompatibleTransforms: Boolean =
     getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)
 
+  def v2BucketingAllowSorting: Boolean =
+    getConf(SQLConf.V2_BUCKETING_SORTING_ENABLED)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
```
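A sketch of enabling the optimization in a session (the two config keys come from the diff above; the query and table names are illustrative):

```scala
// The new flag documents a dependency on the base SPJ flag, so set both.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.sorting.enabled", "true")

// With both flags on, an ORDER BY over the reported partition keys should
// plan without a shuffle exchange (compare the physical plans).
spark.sql("SELECT price, id FROM testcat.ns.items ORDER BY price, id").explain()
```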

EnsureRequirements.scala
```diff
@@ -64,7 +64,7 @@ case class EnsureRequirements(
     // Ensure that the operator's children satisfy their output distribution requirements.
     var children = originalChildren.zip(requiredChildDistributions).map {
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
+        ensureOrdering(child, distribution)
       case (child, BroadcastDistribution(mode)) =>
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
```
```diff
@@ -290,6 +290,23 @@
       }
     }
 
+  private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = {
+    (plan.outputPartitioning, distribution) match {
+      case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
+          d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
+        val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
+        val partitionOrdering: Ordering[InternalRow] = {
+          RowOrdering.create(ordering, attrs)
+        }
+        // Sort 'commonPartitionValues' and use this mechanism to ensure BatchScan's
+        // output partitions are ordered
+        val sorted = partitionValues.sorted(partitionOrdering)
+        populateCommonPartitionInfo(plan, sorted.map((_, 1)),
+          None, None, applyPartialClustering = false, replicatePartitions = false)
+      case _ => plan
+    }
+  }
+
   /**
    * Recursively reorders the join keys based on partitioning. It starts reordering the
    * join keys to match HashPartitioning on either side, followed by PartitioningCollection.
```
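The mechanism above can be pictured with a toy model (plain Scala, not Spark API): if each input split carries a single partition value, reordering the splits by value, rather than shuffling rows, already yields globally sorted output.

```scala
// Hypothetical stand-in for a key-grouped input split and its rows.
case class Split(partValue: Double, rows: Seq[Double])

val splits = Seq(
  Split(41.0, Seq(41.0, 41.0)),
  Split(10.0, Seq(10.0)),
  Split(15.5, Seq(15.5)))

// Analogue of sorting 'commonPartitionValues': reorder the splits, not the rows.
val ordered = splits.sortBy(_.partValue)

// Concatenating the reordered splits gives globally sorted output.
assert(ordered.flatMap(_.rows) == splits.flatMap(_.rows).sorted)
```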

KeyGroupedPartitioningSuite.scala
```diff
@@ -370,6 +370,62 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
   }
 
+  test("SPARK-48655: order by on partition keys should not introduce additional shuffle") {
+    val items_partitions = Array(identity("price"), identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+      s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+      s"(null, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+      s"(3, 'cc', null, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach { sortingEnabled =>
+      withSQLConf(SQLConf.V2_BUCKETING_SORTING_ENABLED.key -> sortingEnabled.toString) {
+
+        def verifyShuffle(cmd: String, answer: Seq[Row]): Unit = {
+          val df = sql(cmd)
+          if (sortingEnabled) {
+            assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty,
+              "should contain no shuffle when sorting by partition values")
+          } else {
+            assert(collectAllShuffles(df.queryExecution.executedPlan).size == 1,
+              "should contain one shuffle when optimization is disabled")
+          }
+          checkAnswer(df, answer)
+        }: Unit
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items ORDER BY price ASC, id ASC",
+          Seq(Row(null, 3), Row(10.0, 2), Row(15.5, null),
+            Row(15.5, 3), Row(40.0, 1), Row(41.0, 1)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items " +
+            s"ORDER BY price ASC NULLS LAST, id ASC NULLS LAST",
+          Seq(Row(10.0, 2), Row(15.5, 3), Row(15.5, null),
+            Row(40.0, 1), Row(41.0, 1), Row(null, 3)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id ASC",
+          Seq(Row(41.0, 1), Row(40.0, 1), Row(15.5, null),
+            Row(15.5, 3), Row(10.0, 2), Row(null, 3)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id DESC",
+          Seq(Row(41.0, 1), Row(40.0, 1), Row(15.5, 3),
+            Row(15.5, null), Row(10.0, 2), Row(null, 3)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items " +
+            s"ORDER BY price DESC NULLS FIRST, id DESC NULLS FIRST",
+          Seq(Row(null, 3), Row(41.0, 1), Row(40.0, 1),
+            Row(15.5, null), Row(15.5, 3), Row(10.0, 2)));
+      }
+    }
+  }
+
   test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") {
     val cols = Array(
       Column.create("id", LongType),
```
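The test leans on a suite helper to count exchanges; a hedged sketch of its assumed shape (the real `collectAllShuffles` lives in Spark's test utilities and may differ):

```scala
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Assumed shape: walk the physical plan and collect every shuffle exchange,
// so an empty result means the sort was satisfied without repartitioning.
def collectAllShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] =
  plan.collect { case s: ShuffleExchangeExec => s }
```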
