From 7ec68cf242111706afb926f44df78311a5267333 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 9 Aug 2024 17:48:06 +0800 Subject: [PATCH] Support auto disable bucketed scan --- .../org/apache/paimon/spark/PaimonScan.scala | 3 +- .../org/apache/paimon/spark/PaimonScan.scala | 3 +- .../paimon/spark/PaimonInputPartition.scala | 4 + .../org/apache/paimon/spark/PaimonScan.scala | 38 ++- ...DisableUnnecessaryPaimonBucketedScan.scala | 169 +++++++++++ .../PaimonSparkSessionExtensions.scala | 4 + ...leUnnecessaryPaimonBucketedScanSuite.scala | 268 ++++++++++++++++++ pom.xml | 3 + 8 files changed, 480 insertions(+), 12 deletions(-) create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/adaptive/DisableUnnecessaryPaimonBucketedScan.scala create mode 100644 paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DisableUnnecessaryPaimonBucketedScanSuite.scala diff --git a/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index 24dfb342abf3e..254c63679bcad 100644 --- a/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -29,5 +29,6 @@ case class PaimonScan( requiredSchema: StructType, filters: Seq[Predicate], reservedFilters: Seq[Filter], - pushDownLimit: Option[Int]) + pushDownLimit: Option[Int], + disableBucketedScan: Boolean = false) extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit) diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index f0476cf70732a..361b1e7a77d67 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -34,7 +34,8 @@ case class PaimonScan( requiredSchema: StructType, filters: Seq[Predicate], reservedFilters: Seq[Filter], - pushDownLimit: Option[Int]) + pushDownLimit: Option[Int], + disableBucketedScan: Boolean = false) extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit) with SupportsRuntimeFiltering { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala index a7c33b21de5b4..7e3dbf893b226 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala @@ -30,6 +30,9 @@ trait PaimonInputPartition extends InputPartition { def rowCount(): Long = { splits.map(_.rowCount()).sum } + + // Used to avoid checking [[PaimonBucketedInputPartition]] to workaround for multi Spark version + def bucketed = false } case class SimplePaimonInputPartition(splits: Seq[Split]) extends PaimonInputPartition @@ -48,4 +51,5 @@ case class PaimonBucketedInputPartition(splits: Seq[Split], bucket: Int) extends PaimonInputPartition with HasPartitionKey { override def partitionKey(): InternalRow = new GenericInternalRow(Array(bucket.asInstanceOf[Any])) + override def bucketed: Boolean = true } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index f34a24991d2a5..296f0569efeba 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -23,7 +23,7 @@ import org.apache.paimon.table.{BucketMode, FileStoreTable, Table} import org.apache.paimon.table.source.{DataSplit, Split} import org.apache.spark.sql.PaimonUtils.fieldReference -import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference} +import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, Transform} import org.apache.spark.sql.connector.read.{SupportsReportPartitioning, SupportsRuntimeFiltering} import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.sources.{Filter, In} @@ -36,35 +36,53 @@ case class PaimonScan( requiredSchema: StructType, filters: Seq[Predicate], reservedFilters: Seq[Filter], - pushDownLimit: Option[Int]) + pushDownLimit: Option[Int], + disableBucketedScan: Boolean = false) extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit) with SupportsRuntimeFiltering with SupportsReportPartitioning { - override def outputPartitioning(): Partitioning = { + def withDisabledBucketedScan(): PaimonScan = { + copy(disableBucketedScan = true) + } + + @transient + private lazy val extractBucketTransform: Option[Transform] = { table match { case fileStoreTable: FileStoreTable => val bucketSpec = fileStoreTable.bucketSpec() if (bucketSpec.getBucketMode != BucketMode.HASH_FIXED) { - new UnknownPartitioning(0) + None } else if (bucketSpec.getBucketKeys.size() > 1) { - new UnknownPartitioning(0) + None } else { // Spark does not support bucket with several input attributes, // so we only support one bucket key case. assert(bucketSpec.getNumBuckets > 0) assert(bucketSpec.getBucketKeys.size() == 1) - val key = Expressions.bucket(bucketSpec.getNumBuckets, bucketSpec.getBucketKeys.get(0)) - new KeyGroupedPartitioning(Array(key), lazyInputPartitions.size) + val bucketKey = bucketSpec.getBucketKeys.get(0) + if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) { + Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey)) + } else { + None + } } - case _ => - new UnknownPartitioning(0) + case _ => None } } + override def outputPartitioning: Partitioning = { + extractBucketTransform + .map(bucket => new KeyGroupedPartitioning(Array(bucket), lazyInputPartitions.size)) + .getOrElse(new UnknownPartitioning(0)) + } + override def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = { - if (!conf.v2BucketingEnabled || splits.exists(!_.isInstanceOf[DataSplit])) { + if ( + disableBucketedScan || !conf.v2BucketingEnabled || extractBucketTransform.isEmpty || + splits.exists(!_.isInstanceOf[DataSplit]) + ) { return super.getInputPartitions(splits) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/adaptive/DisableUnnecessaryPaimonBucketedScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/adaptive/DisableUnnecessaryPaimonBucketedScan.scala new file mode 100644 index 0000000000000..382f96dfe5119 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/adaptive/DisableUnnecessaryPaimonBucketedScan.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.execution.adaptive + +import org.apache.paimon.spark.PaimonScan + +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeLike} + +// spotless:off +/** + * This rule is inspired from Spark [[DisableUnnecessaryBucketedScan]] but work for v2 scan. + * + * Disable unnecessary bucketed table scan based on actual physical query plan. + * NOTE: this rule is designed to be applied right after [[EnsureRequirements]], + * where all [[ShuffleExchangeLike]] and [[SortExec]] have been added to plan properly. + * + * When BUCKETING_ENABLED and AUTO_BUCKETED_SCAN_ENABLED are set to true, go through + * query plan to check where bucketed table scan is unnecessary, and disable bucketed table + * scan if: + * + * 1. The sub-plan from root to bucketed table scan, does not contain + * [[hasInterestingPartition]] operator. + * + * 2. The sub-plan from the nearest downstream [[hasInterestingPartition]] operator + * to the bucketed table scan and at least one [[ShuffleExchangeLike]]. + * + * Examples: + * 1. no [[hasInterestingPartition]] operator: + * Project + * | + * Filter + * | + * Scan(t1: i, j) + * (bucketed on column j, DISABLE bucketed scan) + * + * 2. join: + * SortMergeJoin(t1.i = t2.j) + * / \ + * Sort(i) Sort(j) + * / \ + * Shuffle(i) Scan(t2: i, j) + * / (bucketed on column j, enable bucketed scan) + * Scan(t1: i, j) + * (bucketed on column j, DISABLE bucketed scan) + * + * 3. aggregate: + * HashAggregate(i, ..., Final) + * | + * Shuffle(i) + * | + * HashAggregate(i, ..., Partial) + * | + * Filter + * | + * Scan(t1: i, j) + * (bucketed on column j, DISABLE bucketed scan) + * + * The idea of [[hasInterestingPartition]] is inspired from "interesting order" in + * the paper "Access Path Selection in a Relational Database Management System" + * (https://dl.acm.org/doi/10.1145/582095.582099). + */ +// spotless:on +object DisableUnnecessaryPaimonBucketedScan extends Rule[SparkPlan] { + + /** + * Disable bucketed table scan with pre-order traversal of plan. + * + * @param hashInterestingPartition + * The traversed plan has operator with interesting partition. + * @param hasExchange + * The traversed plan has [[Exchange]] operator. + */ + private def disableBucketScan( + plan: SparkPlan, + hashInterestingPartition: Boolean, + hasExchange: Boolean): SparkPlan = { + plan match { + case p if hasInterestingPartition(p) => + // Operator with interesting partition, propagates `hashInterestingPartition` as true + // to its children, and resets `hasExchange`. + p.mapChildren(disableBucketScan(_, hashInterestingPartition = true, hasExchange = false)) + case exchange: ShuffleExchangeLike => + // Exchange operator propagates `hasExchange` as true to its child. + exchange.mapChildren(disableBucketScan(_, hashInterestingPartition, hasExchange = true)) + case batch: BatchScanExec => + val paimonBucketedScan = extractPaimonBucketedScan(batch) + if (paimonBucketedScan.isDefined && (!hashInterestingPartition || hasExchange)) { + val (batch, paimonScan) = paimonBucketedScan.get + val newBatch = batch.copy(scan = paimonScan.withDisabledBucketedScan()) + newBatch.copyTagsFrom(batch) + newBatch + } else { + batch + } + case p if canPassThrough(p) => + p.mapChildren(disableBucketScan(_, hashInterestingPartition, hasExchange)) + case other => + other.mapChildren( + disableBucketScan(_, hashInterestingPartition = false, hasExchange = false)) + } + } + + private def hasInterestingPartition(plan: SparkPlan): Boolean = { + plan.requiredChildDistribution.exists { + case _: ClusteredDistribution | AllTuples => true + case _ => false + } + } + + /** + * Check if the operator is allowed single-child operator. We may revisit this method later as we + * probably can remove this restriction to allow arbitrary operator between bucketed table scan + * and operator with interesting partition. + */ + private def canPassThrough(plan: SparkPlan): Boolean = { + plan match { + case _: ProjectExec | _: FilterExec => true + case s: SortExec if !s.global => true + case partialAgg: BaseAggregateExec => + partialAgg.requiredChildDistributionExpressions.isEmpty + case _ => false + } + } + + def extractPaimonBucketedScan(plan: SparkPlan): Option[(BatchScanExec, PaimonScan)] = + plan match { + case batch: BatchScanExec => + batch.scan match { + case scan: PaimonScan if scan.lazyInputPartitions.forall(_.bucketed) => + Some((batch, scan)) + case _ => None + } + case _ => None + } + + def apply(plan: SparkPlan): SparkPlan = { + lazy val hasBucketedScan = plan.exists { + case p if extractPaimonBucketedScan(p).isDefined => true + case _ => false + } + + if (!conf.v2BucketingEnabled || !conf.autoBucketedScanEnabled || !hasBucketedScan) { + plan + } else { + disableBucketScan(plan, hashInterestingPartition = false, hasExchange = false) + } + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala index 3cd2783221c9b..4fe217ee09bd8 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala @@ -22,6 +22,7 @@ import org.apache.paimon.spark.catalyst.analysis.{PaimonAnalysis, PaimonDeleteTa import org.apache.paimon.spark.catalyst.optimizer.{EvalSubqueriesForDeleteTable, MergePaimonScalarSubqueries} import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions import org.apache.paimon.spark.execution.PaimonStrategy +import org.apache.paimon.spark.execution.adaptive.DisableUnnecessaryPaimonBucketedScan import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.parser.extensions.PaimonSparkSqlExtensionsParser @@ -58,5 +59,8 @@ class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // planner extensions extensions.injectPlannerStrategy(spark => PaimonStrategy(spark)) + + // query stage preparation + extensions.injectQueryStagePrepRule(_ => DisableUnnecessaryPaimonBucketedScan) } } diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DisableUnnecessaryPaimonBucketedScanSuite.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DisableUnnecessaryPaimonBucketedScanSuite.scala new file mode 100644 index 0000000000000..3112cdf22c094 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DisableUnnecessaryPaimonBucketedScanSuite.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.paimon.spark.execution.adaptive.DisableUnnecessaryPaimonBucketedScan + +import org.apache.spark.SparkConf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf + +class DisableUnnecessaryPaimonBucketedScanSuite + extends PaimonSparkTestBase + with AdaptiveSparkPlanHelper { + + override protected def sparkConf: SparkConf = { + // Make non-shuffle query work with stage preparation rule + super.sparkConf + .set("spark.sql.adaptive.forceApply", "true") + } + + private def checkDisableBucketedScan( + query: String, + expectedNumScanWithAutoScanEnabled: Int, + expectedNumScanWithAutoScanDisabled: Int): Unit = { + + def checkNumBucketedScan(df: DataFrame, expectedNumBucketedScan: Int): Unit = { + val plan = df.queryExecution.executedPlan + val bucketedScan = collect(plan) { + case p if DisableUnnecessaryPaimonBucketedScan.extractPaimonBucketedScan(p).isDefined => p + } + assert(bucketedScan.length == expectedNumBucketedScan, query) + } + + withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") { + withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "true") { + val df = sql(query) + val result = df.collect() + checkNumBucketedScan(df, expectedNumScanWithAutoScanEnabled) + + withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") { + val expected = sql(query) + checkAnswer(expected, result) + checkNumBucketedScan(expected, expectedNumScanWithAutoScanDisabled) + } + } + } + } + + private def initializeTable(): Unit = { + spark.sql( + "CREATE TABLE t1 (i INT, j INT, k STRING) TBLPROPERTIES ('primary-key' = 'i', 'bucket'='10')") + spark.sql( + "INSERT INTO t1 VALUES (1, 1, 'x1'), (2, 2, 'x3'), (3, 3, 'x3'), (4, 4, 'x4'), (5, 5, 'x5')") + spark.sql( + "CREATE TABLE t2 (i INT, j INT, k STRING) TBLPROPERTIES ('primary-key' = 'i', 'bucket'='10')") + spark.sql( + "INSERT INTO t2 VALUES (1, 1, 'x1'), (2, 2, 'x3'), (3, 3, 'x3'), (4, 4, 'x4'), (5, 5, 'x5')") + spark.sql( + "CREATE TABLE t3 (i INT, j INT, k STRING) TBLPROPERTIES ('primary-key' = 'i', 'bucket'='2')") + spark.sql( + "INSERT INTO t3 VALUES (1, 1, 'x1'), (2, 2, 'x3'), (3, 3, 'x3'), (4, 4, 'x4'), (5, 5, 'x5')") + } + + test("Disable unnecessary bucketed table scan - basic test") { + assume(gteqSpark3_3) + + withTable("t1", "t2", "t3") { + initializeTable() + + Seq( + // Read bucketed table + ("SELECT * FROM t1", 0, 1), + ("SELECT i FROM t1", 0, 1), + ("SELECT j FROM t1", 0, 0), + // Filter on bucketed column + ("SELECT * FROM t1 WHERE i = 1", 0, 1), + // Filter on non-bucketed column + ("SELECT * FROM t1 WHERE j = 1", 0, 1), + // Join with same buckets + ("SELECT /*+ broadcast(t1)*/ * FROM t1 JOIN t2 ON t1.i = t2.i", 0, 2), + ("SELECT /*+ shuffle_hash(t1)*/ * FROM t1 JOIN t2 ON t1.i = t2.i", 2, 2), + ("SELECT /*+ merge(t1)*/ * FROM t1 JOIN t2 ON t1.i = t2.i", 2, 2), + // Join with different buckets + ("SELECT /*+ broadcast(t1)*/ * FROM t1 JOIN t3 ON t1.i = t3.i", 0, 2), + ("SELECT /*+ shuffle_hash(t1)*/ * FROM t1 JOIN t3 ON t1.i = t3.i", 0, 2), + ("SELECT /*+ merge(t1)*/ * FROM t1 JOIN t3 ON t1.i = t3.i", 0, 2), + // Join on non-bucketed column + ("SELECT /*+ broadcast(t1)*/ * FROM t1 JOIN t2 ON t1.i = t2.j", 0, 2), + ("SELECT /*+ shuffle_hash(t1)*/ * FROM t1 JOIN t2 ON t1.i = t2.j", 0, 2), + ("SELECT /*+ merge(t1)*/ * FROM t1 JOIN t2 ON t1.i = t2.j", 0, 2), + ("SELECT /*+ broadcast(t1)*/ * FROM t1 JOIN t2 ON t1.j = t2.j", 0, 2), + ("SELECT /*+ shuffle_hash(t1)*/ * FROM t1 JOIN t2 ON t1.j = t2.j", 0, 2), + ("SELECT /*+ merge(t1)*/ * FROM t1 JOIN t2 ON t1.j = t2.j", 0, 2), + // Aggregate on bucketed column + ("SELECT SUM(i) FROM t1 GROUP BY i", 1, 1), + // Aggregate on non-bucketed column + ("SELECT SUM(i) FROM t1 GROUP BY j", 0, 1), + ("SELECT j, SUM(i), COUNT(j) FROM t1 GROUP BY j", 0, 1) + ).foreach { + case (query, numScanWithAutoScanEnabled, numScanWithAutoScanDisabled) => + checkDisableBucketedScan(query, numScanWithAutoScanEnabled, numScanWithAutoScanDisabled) + } + } + } + + test("Disable unnecessary bucketed table scan - multiple joins test") { + assume(gteqSpark3_3) + + withTable("t1", "t2", "t3") { + initializeTable() + + Seq( + // Multiple joins on bucketed columns + ( + """ + SELECT /*+ broadcast(t1, t3)*/ * FROM t1 JOIN t2 JOIN t3 + ON t1.i = t2.i AND t2.i = t3.i + """.stripMargin, + 0, + 3), + ( + """ + SELECT /*+ broadcast(t1) merge(t3)*/ * FROM t1 JOIN t2 JOIN t3 + ON t1.i = t2.i AND t2.i = t3.i + """.stripMargin, + 0, + 3), + ( + """ + SELECT /*+ merge(t1) broadcast(t3)*/ * FROM t1 JOIN t2 JOIN t3 + ON t1.i = t2.i AND t2.i = t3.i + """.stripMargin, + 2, + 3), + ( + """ + SELECT /*+ merge(t1, t3)*/ * FROM t1 LEFT JOIN t2 LEFT JOIN t3 + ON t1.i = t2.i AND t2.i = t3.i + """.stripMargin, + 0, + 3), + // Multiple joins on non-bucketed columns + ( + """ + SELECT /*+ broadcast(t1, t3)*/ * FROM t1 JOIN t2 JOIN t3 + ON t1.i = t2.j AND t2.j = t3.i + """.stripMargin, + 0, + 3), + ( + """ + SELECT /*+ merge(t1, t3)*/ * FROM t1 JOIN t2 JOIN t3 + ON t1.i = t2.j AND t2.j = t3.i + """.stripMargin, + 0, + 3), + ( + """ + SELECT /*+ merge(t1, t3)*/ * FROM t1 JOIN t2 JOIN t3 + ON t1.j = t2.j AND t2.j = t3.j + """.stripMargin, + 0, + 3) + ).foreach { + case (query, numScanWithAutoScanEnabled, numScanWithAutoScanDisabled) => + checkDisableBucketedScan(query, numScanWithAutoScanEnabled, numScanWithAutoScanDisabled) + } + } + } + + test("Disable unnecessary bucketed table scan - other operators test") { + assume(gteqSpark3_3) + + withTable("t1", "t2", "t3") { + initializeTable() + + Seq( + // Operator with interesting partition not in sub-plan + ( + """ + SELECT t1.i FROM t1 + UNION ALL + (SELECT t2.i FROM t2 GROUP BY t2.i) + """.stripMargin, + 1, + 2), + // Non-allowed operator in sub-plan + ( + """ + SELECT COUNT(*) + FROM (SELECT t1.i FROM t1 UNION ALL SELECT t2.i FROM t2) + GROUP BY i + """.stripMargin, + 0, + 2), + // Multiple [[Exchange]] in sub-plan + ( + """ + SELECT j, SUM(i), COUNT(*) FROM t1 GROUP BY j + DISTRIBUTE BY j + """.stripMargin, + 0, + 1), + ( + """ + SELECT j, COUNT(*) + FROM (SELECT i, j FROM t1 DISTRIBUTE BY i, j) + GROUP BY j + """.stripMargin, + 0, + 1), + // No bucketed table scan in plan + ( + """ + SELECT j, COUNT(*) + FROM (SELECT t1.j FROM t1 JOIN t3 ON t1.j = t3.j) + GROUP BY j + """.stripMargin, + 0, + 0) + ).foreach { + case (query, numScanWithAutoScanEnabled, numScanWithAutoScanDisabled) => + checkDisableBucketedScan(query, numScanWithAutoScanEnabled, numScanWithAutoScanDisabled) + } + } + } + + test("Aggregates with no groupby over tables having 1 BUCKET, return multiple rows") { + assume(gteqSpark3_3) + + withTable("t1") { + sql(""" + |CREATE TABLE t1 (`id` BIGINT, `event_date` DATE) + |TBLPROPERTIES ('primary-key' = 'id', 'bucket'='1') + |""".stripMargin) + sql(""" + |INSERT INTO TABLE t1 VALUES(1.23, cast("2021-07-07" as date)) + |""".stripMargin) + sql(""" + |INSERT INTO TABLE t1 VALUES(2.28, cast("2021-08-08" as date)) + |""".stripMargin) + val df = spark.sql("select sum(id) from t1 where id is not null") + assert(df.count() == 1) + checkDisableBucketedScan( + query = "SELECT SUM(id) FROM t1 WHERE id is not null", + expectedNumScanWithAutoScanEnabled = 1, + expectedNumScanWithAutoScanDisabled = 1) + } + } +} diff --git a/pom.xml b/pom.xml index 75a7cd098a95a..b34ac3f876239 100644 --- a/pom.xml +++ b/pom.xml @@ -834,6 +834,9 @@ under the License. + + + 3.4.3 ${maven.multiModuleProjectDirectory}/.scalafmt.conf