From e90e5ab494979de3b6a8432510fed423c63e2a46 Mon Sep 17 00:00:00 2001
From: Yann Byron
Date: Sun, 17 Dec 2023 20:21:26 +0800
Subject: [PATCH] [spark] Dynamically adjust the parallelism of scan (#2482)

---
 .../apache/paimon/spark/PaimonBaseScan.scala  |  14 +-
 .../org/apache/paimon/spark/PaimonScan.scala  |   3 +-
 .../org/apache/paimon/spark/ScanHelper.scala  | 131 ++++++++++++++++++
 .../apache/paimon/spark/ScanHelperTest.scala  |  73 ++++++++++
 .../paimon/spark/sql/PaimonPushDownTest.scala |   6 +-
 5 files changed, 221 insertions(+), 6 deletions(-)
 create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala
 create mode 100644 paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index eb055ccf41f6..a822ea9db5d0 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.paimon.spark
 
+import org.apache.paimon.CoreOptions
 import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
 import org.apache.paimon.spark.sources.PaimonMicroBatchStream
 import org.apache.paimon.table.{DataTable, FileStoreTable, Table}
@@ -36,7 +37,8 @@ abstract class PaimonBaseScan(
     filters: Array[(Filter, Predicate)],
     pushDownLimit: Option[Int])
   extends Scan
-  with SupportsReportStatistics {
+  with SupportsReportStatistics
+  with ScanHelper {
 
   private val tableRowType = table.rowType
 
@@ -46,7 +48,9 @@ abstract class PaimonBaseScan(
 
   protected var splits: Array[Split] = _
 
-  protected lazy val readBuilder: ReadBuilder = {
+  override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())
+
+  lazy val readBuilder: ReadBuilder = {
     val _readBuilder = table.newReadBuilder()
     val projection =
       readSchema().fieldNames.map(field => tableRowType.getFieldNames.indexOf(field))
@@ -60,9 +64,13 @@ abstract class PaimonBaseScan(
     _readBuilder
   }
 
+  def getOriginSplits: Array[Split] = {
+    readBuilder.newScan().plan().splits().asScala.toArray
+  }
+
   def getSplits: Array[Split] = {
     if (splits == null) {
-      splits = readBuilder.newScan().plan().splits().asScala.toArray
+      splits = reshuffleSplits(getOriginSplits)
     }
     splits
   }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index c8ca0748405f..b05e9be35757 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -56,7 +56,8 @@ case class PaimonScan(
     if (partitionFilter.nonEmpty) {
       this.runtimeFilters = filters
       readBuilder.withFilter(partitionFilter.head)
-      splits = readBuilder.newScan().plan().splits().asScala.toArray
+      // Set splits to null to trigger re-planning of the splits with the new filter.
+      splits = null
     }
   }
 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala
new file mode 100644
index 000000000000..cdadb579028b
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark
+
+import org.apache.paimon.CoreOptions
+import org.apache.paimon.io.DataFileMeta
+import org.apache.paimon.table.source.{DataSplit, RawFile, Split}
+
+import org.apache.spark.sql.SparkSession
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+trait ScanHelper {
+
+  private val spark = SparkSession.active
+
+  val coreOptions: CoreOptions
+
+  private lazy val openCostInBytes: Long = coreOptions.splitOpenFileCost()
+
+  private lazy val leafNodeDefaultParallelism: Int = {
+    spark.conf
+      .get("spark.sql.leafNodeDefaultParallelism", spark.sparkContext.defaultParallelism.toString)
+      .toInt
+  }
+
+  def reshuffleSplits(splits: Array[Split]): Array[Split] = {
+    if (splits.length < leafNodeDefaultParallelism) {
+      val (toReshuffle, reserved) = splits.partition {
+        case split: DataSplit => split.beforeFiles().isEmpty && split.convertToRawFiles.isPresent
+        case _ => false
+      }
+      val reshuffled = reshuffleSplits0(toReshuffle.collect { case ds: DataSplit => ds })
+      reshuffled ++ reserved
+    } else {
+      splits
+    }
+  }
+
+  private def reshuffleSplits0(splits: Array[DataSplit]): Array[DataSplit] = {
+    val maxSplitBytes = computeMaxSplitBytes(splits)
+
+    val newSplits = new ArrayBuffer[DataSplit]
+
+    var currentSplit: Option[DataSplit] = None
+    val currentDataFiles = new ArrayBuffer[DataFileMeta]
+    val currentRawFiles = new ArrayBuffer[RawFile]
+    var currentSize = 0L
+
+    def closeDataSplit(): Unit = {
+      if (currentSplit.nonEmpty && currentDataFiles.nonEmpty) {
+        val newSplit = copyDataSplit(currentSplit.get, currentDataFiles, currentRawFiles)
+        newSplits += newSplit
+      }
+      currentDataFiles.clear()
+      currentRawFiles.clear()
+      currentSize = 0
+    }
+
+    splits.foreach {
+      split =>
+        currentSplit = Some(split)
+        val hasRawFiles = split.convertToRawFiles().isPresent
+
+        split.dataFiles().asScala.zipWithIndex.foreach {
+          case (file, idx) =>
+            if (currentSize + file.fileSize > maxSplitBytes) {
+              closeDataSplit()
+            }
+            currentSize += file.fileSize + openCostInBytes
+            currentDataFiles += file
+            if (hasRawFiles) {
+              currentRawFiles += split.convertToRawFiles().get().get(idx)
+            }
+        }
+        closeDataSplit()
+    }
+
+    newSplits.toArray
+  }
+
+  private def unpack(split: Split): Array[DataFileMeta] = {
+    split match {
+      case ds: DataSplit =>
+        ds.dataFiles().asScala.toArray
+      case _ => Array.empty
+    }
+  }
+
+  private def copyDataSplit(
+      split: DataSplit,
+      dataFiles: Seq[DataFileMeta],
+      rawFiles: Seq[RawFile]): DataSplit = {
+    val builder = DataSplit
+      .builder()
+      .withSnapshot(split.snapshotId())
+      .withPartition(split.partition())
+      .withBucket(split.bucket())
+      .withDataFiles(dataFiles.toList.asJava)
+      .rawFiles(rawFiles.toList.asJava)
+    builder.build()
+  }
+
+  private def computeMaxSplitBytes(dataSplits: Seq[DataSplit]): Long = {
+    val dataFiles = dataSplits.flatMap(unpack)
+    val defaultMaxSplitBytes = spark.sessionState.conf.filesMaxPartitionBytes
+    val minPartitionNum = spark.sessionState.conf.filesMinPartitionNum
+      .getOrElse(leafNodeDefaultParallelism)
+    val totalBytes = dataFiles.map(file => file.fileSize + openCostInBytes).sum
+    val bytesPerCore = totalBytes / minPartitionNum
+
+    Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
+  }
+
+}
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala
new file mode 100644
index 000000000000..90c5d1ba05ef
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark
+
+import org.apache.paimon.CoreOptions
+import org.apache.paimon.data.BinaryRow
+import org.apache.paimon.io.DataFileMeta
+import org.apache.paimon.table.source.{DataSplit, RawFile, Split}
+
+import org.junit.jupiter.api.Assertions
+
+import java.util.HashMap
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+class ScanHelperTest extends PaimonSparkTestBase {
+
+  test("Paimon: reshuffle splits") {
+    withSQLConf(("spark.sql.leafNodeDefaultParallelism", "20")) {
+      val splitNum = 5
+      val fileNum = 100
+
+      val files = scala.collection.mutable.ListBuffer.empty[DataFileMeta]
+      val rawFiles = scala.collection.mutable.ListBuffer.empty[RawFile]
+      0.until(fileNum).foreach {
+        i =>
+          val path = s"f$i.parquet"
+          files += DataFileMeta.forAppend(path, 750000, 30000, null, 0, 29999, 1)
+
+          rawFiles += new RawFile(s"/a/b/$path", 0, 75000, "parquet", 0, 30000)
+      }
+
+      val dataSplits = mutable.ArrayBuffer.empty[Split]
+      0.until(splitNum).foreach {
+        i =>
+          dataSplits += DataSplit
+            .builder()
+            .withSnapshot(1)
+            .withBucket(0)
+            .withPartition(new BinaryRow(0))
+            .withDataFiles(files.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toList.asJava)
+            .rawFiles(rawFiles.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toList.asJava)
+            .build()
+      }
+
+      val fakeScan = new FakeScan()
+      val reshuffled = fakeScan.reshuffleSplits(dataSplits.toArray)
+      Assertions.assertTrue(reshuffled.length > 5)
+    }
+  }
+
+  class FakeScan extends ScanHelper {
+    override val coreOptions: CoreOptions =
+      CoreOptions.fromMap(new HashMap[String, String]())
+  }
+
+}
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
index 4837bca792e0..89d79667a169 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.paimon.spark.sql
 
-import org.apache.paimon.spark.{PaimonSparkTestBase, SparkInputPartition, SparkTable}
+import org.apache.paimon.spark.{PaimonBatch, PaimonScan, PaimonSparkTestBase, SparkInputPartition, SparkTable}
 import org.apache.paimon.table.source.DataSplit
 
 import org.apache.spark.sql.Row
@@ -111,7 +111,9 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
     // It still return false even it can push down limit.
     Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1))
 
-    val partitions = scanBuilder.build().toBatch.planInputPartitions()
+    val paimonScan = scanBuilder.build().asInstanceOf[PaimonScan]
+    val partitions =
+      PaimonBatch(paimonScan.getOriginSplits, paimonScan.readBuilder).planInputPartitions()
     Assertions.assertEquals(1, partitions.length)
 
     Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count())
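
Reviewer note (not part of the patch): the new ScanHelper sizes the repacked splits with essentially the same formula as Spark's FilePartition.maxSplitBytes. Below is a minimal standalone sketch of that arithmetic, using the numbers from ScanHelperTest and assuming Paimon's split open-file cost defaults to 4 MB and Spark's spark.sql.files.maxPartitionBytes keeps its 128 MB default; both option values are assumptions for illustration, not taken from this diff.

object SplitSizingSketch {

  // Same heuristic as ScanHelper#computeMaxSplitBytes, extracted here for illustration only.
  def maxSplitBytes(
      fileSizes: Seq[Long],
      openCostInBytes: Long,
      defaultMaxSplitBytes: Long,
      minPartitionNum: Int): Long = {
    val totalBytes = fileSizes.map(_ + openCostInBytes).sum
    val bytesPerCore = totalBytes / minPartitionNum
    math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
  }

  def main(args: Array[String]): Unit = {
    // Test scenario: 100 files of 750,000 bytes spread over 5 splits, parallelism 20.
    val target = maxSplitBytes(
      fileSizes = Seq.fill(100)(750000L),
      openCostInBytes = 4L * 1024 * 1024, // assumed default split open-file cost
      defaultMaxSplitBytes = 128L * 1024 * 1024, // assumed Spark default
      minPartitionNum = 20)
    // totalBytes = 100 * (750000 + 4194304) = 494430400 bytes, bytesPerCore ~= 24721520,
    // so each repacked split holds about 5 files and the 5 original splits fan out to
    // roughly 20 splits, which is why the test only asserts reshuffled.length > 5.
    println(s"target split size = $target bytes")
  }
}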