Skip to content

Commit

Permalink
[spark] Dynamically adjust the parallelism of scan (#2482)
Browse files Browse the repository at this point in the history
  • Loading branch information
YannByron authored Dec 17, 2023
1 parent 17063b0 commit e90e5ab
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.paimon.spark

import org.apache.paimon.CoreOptions
import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
import org.apache.paimon.spark.sources.PaimonMicroBatchStream
import org.apache.paimon.table.{DataTable, FileStoreTable, Table}
Expand All @@ -36,7 +37,8 @@ abstract class PaimonBaseScan(
filters: Array[(Filter, Predicate)],
pushDownLimit: Option[Int])
extends Scan
with SupportsReportStatistics {
with SupportsReportStatistics
with ScanHelper {

private val tableRowType = table.rowType

Expand All @@ -46,7 +48,9 @@ abstract class PaimonBaseScan(

protected var splits: Array[Split] = _

protected lazy val readBuilder: ReadBuilder = {
override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())

lazy val readBuilder: ReadBuilder = {
val _readBuilder = table.newReadBuilder()

val projection = readSchema().fieldNames.map(field => tableRowType.getFieldNames.indexOf(field))
Expand All @@ -60,9 +64,13 @@ abstract class PaimonBaseScan(
_readBuilder
}

/** Plans a new scan through `readBuilder` and returns the planned splits as an array. */
def getOriginSplits: Array[Split] = {
  val plan = readBuilder.newScan().plan()
  plan.splits().asScala.toArray
}

def getSplits: Array[Split] = {
if (splits == null) {
splits = readBuilder.newScan().plan().splits().asScala.toArray
splits = reshuffleSplits(getOriginSplits)
}
splits
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ case class PaimonScan(
if (partitionFilter.nonEmpty) {
this.runtimeFilters = filters
readBuilder.withFilter(partitionFilter.head)
splits = readBuilder.newScan().plan().splits().asScala.toArray
// Reset splits to null so that the next getSplits call re-plans them with the new filter.
splits = null
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.paimon.spark

import org.apache.paimon.CoreOptions
import org.apache.paimon.io.DataFileMeta
import org.apache.paimon.table.source.{DataSplit, RawFile, Split}

import org.apache.spark.sql.SparkSession

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/**
 * Helper for Paimon scans that re-cuts planned splits into smaller ones when a scan yields fewer
 * splits than the configured leaf-node parallelism.
 *
 * Implementors only have to provide [[coreOptions]]. The active [[SparkSession]] is captured when
 * the trait is initialized.
 */
trait ScanHelper {

  private val spark = SparkSession.active

  /** Table options supplied by the implementor; the per-file open cost is read from here. */
  val coreOptions: CoreOptions

  // Lazy because `coreOptions` is abstract here and is only assigned by the implementing class.
  private lazy val openCostInBytes: Long = coreOptions.splitOpenFileCost()

  /**
   * Value of `spark.sql.leafNodeDefaultParallelism`, falling back to the SparkContext's default
   * parallelism when the option is not set.
   */
  private lazy val leafNodeDefaultParallelism: Int = {
    spark.conf
      .get("spark.sql.leafNodeDefaultParallelism", spark.sparkContext.defaultParallelism.toString)
      .toInt
  }

  /**
   * Re-cuts `splits` when there are fewer of them than [[leafNodeDefaultParallelism]].
   *
   * Only [[DataSplit]]s that have no "before" files and that can be converted to raw files are
   * re-cut; every other split is kept untouched and appended after the re-cut ones. When there
   * are already at least as many splits as the parallelism, the input array is returned as is.
   *
   * @param splits
   *   the splits produced by scan planning
   * @return
   *   the re-cut splits followed by the untouched ones, or `splits` itself
   */
  def reshuffleSplits(splits: Array[Split]): Array[Split] = {
    if (splits.length < leafNodeDefaultParallelism) {
      val (toReshuffle, reserved) = splits.partition {
        case split: DataSplit => split.beforeFiles().isEmpty && split.convertToRawFiles.isPresent
        case _ => false
      }
      val reshuffled = reshuffleSplits0(toReshuffle.collect { case ds: DataSplit => ds })
      reshuffled ++ reserved
    } else {
      splits
    }
  }

  /**
   * Cuts every input split into one or more new splits of roughly `maxSplitBytes` each.
   *
   * Files are packed in their original order. Each file is accounted as its size plus the open
   * cost, and a new split is started as soon as adding the next file would exceed the limit.
   * Files of different input splits are never packed together.
   */
  private def reshuffleSplits0(splits: Array[DataSplit]): Array[DataSplit] = {
    val maxSplitBytes = computeMaxSplitBytes(splits)
    val newSplits = new ArrayBuffer[DataSplit]

    splits.foreach {
      split =>
        // Resolve the raw files once per split instead of once per data file.
        val converted = split.convertToRawFiles()
        val rawFiles = if (converted.isPresent) Some(converted.get()) else None

        // Packing state is local to the split, so nothing can leak across split boundaries.
        val currentDataFiles = new ArrayBuffer[DataFileMeta]
        val currentRawFiles = new ArrayBuffer[RawFile]
        var currentSize = 0L

        // Emits the files gathered so far (if any) as a new split and resets the packing state.
        def closeDataSplit(): Unit = {
          if (currentDataFiles.nonEmpty) {
            newSplits += copyDataSplit(split, currentDataFiles, currentRawFiles)
          }
          currentDataFiles.clear()
          currentRawFiles.clear()
          currentSize = 0L
        }

        split.dataFiles().asScala.zipWithIndex.foreach {
          case (file, idx) =>
            if (currentSize + file.fileSize > maxSplitBytes) {
              closeDataSplit()
            }
            currentSize += file.fileSize + openCostInBytes
            currentDataFiles += file
            // The raw file is looked up by the index of its data file.
            rawFiles.foreach(list => currentRawFiles += list.get(idx))
        }
        closeDataSplit()
    }

    newSplits.toArray
  }

  /**
   * Builds a new [[DataSplit]] that carries the snapshot id, partition and bucket of `split`
   * together with the given files. The file sequences are copied, so the caller may reuse its
   * buffers afterwards.
   *
   * NOTE(review): only these attributes are copied; any other attribute of `split` is left at
   * the builder's default - confirm that this is intended.
   */
  private def copyDataSplit(
      split: DataSplit,
      dataFiles: Seq[DataFileMeta],
      rawFiles: Seq[RawFile]): DataSplit = {
    val builder = DataSplit
      .builder()
      .withSnapshot(split.snapshotId())
      .withPartition(split.partition())
      .withBucket(split.bucket())
      .withDataFiles(dataFiles.toList.asJava)
      .rawFiles(rawFiles.toList.asJava)
    builder.build()
  }

  /**
   * Computes the target size of a re-cut split.
   *
   * The total cost of all data files (size plus open cost per file) is divided by the minimum
   * partition number, which is the session's `filesMinPartitionNum` or, when that is unset,
   * [[leafNodeDefaultParallelism]]. The result is clamped to at least the open cost and at most
   * the session's `filesMaxPartitionBytes`.
   */
  private def computeMaxSplitBytes(dataSplits: Seq[DataSplit]): Long = {
    val defaultMaxSplitBytes = spark.sessionState.conf.filesMaxPartitionBytes
    val minPartitionNum = spark.sessionState.conf.filesMinPartitionNum
      .getOrElse(leafNodeDefaultParallelism)
    val totalBytes = dataSplits.iterator
      .flatMap(_.dataFiles().asScala)
      .map(file => file.fileSize + openCostInBytes)
      .sum
    val bytesPerCore = totalBytes / minPartitionNum

    Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.paimon.spark

import org.apache.paimon.CoreOptions
import org.apache.paimon.data.BinaryRow
import org.apache.paimon.io.DataFileMeta
import org.apache.paimon.table.source.{DataSplit, RawFile, Split}

import org.junit.jupiter.api.Assertions

import java.util.HashMap

import scala.collection.JavaConverters._
import scala.collection.mutable

/** Tests `ScanHelper.reshuffleSplits` against hand-built data splits. */
class ScanHelperTest extends PaimonSparkTestBase {

  test("Paimon: reshuffle splits") {
    withSQLConf(("spark.sql.leafNodeDefaultParallelism", "20")) {
      val splitNum = 5
      val fileNum = 100

      // Build `fileNum` data files and one raw file per data file.
      val files = mutable.ListBuffer.empty[DataFileMeta]
      val rawFiles = mutable.ListBuffer.empty[RawFile]
      0.until(fileNum).foreach {
        i =>
          val path = s"f$i.parquet"
          files += DataFileMeta.forAppend(path, 750000, 30000, null, 0, 29999, 1)

          rawFiles += new RawFile(s"/a/b/$path", 0, 75000, "parquet", 0, 30000)
      }

      // Distribute the files round-robin over `splitNum` splits.
      val dataSplits = mutable.ArrayBuffer.empty[Split]
      0.until(splitNum).foreach {
        i =>
          dataSplits += DataSplit
            .builder()
            .withSnapshot(1)
            .withBucket(0)
            .withPartition(new BinaryRow(0))
            .withDataFiles(files.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toList.asJava)
            .rawFiles(rawFiles.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toList.asJava)
            .build()
      }

      val fakeScan = new FakeScan()
      val reshuffled = fakeScan.reshuffleSplits(dataSplits.toArray)

      // Fewer input splits than the parallelism, so they must have been cut into more splits.
      Assertions.assertTrue(reshuffled.length > splitNum)

      // Re-cutting must neither drop nor duplicate data files.
      val reshuffledFileNum = reshuffled.collect { case ds: DataSplit => ds.dataFiles().size() }.sum
      Assertions.assertEquals(fileNum, reshuffledFileNum)
    }
  }

  /** Minimal ScanHelper implementation backed by default (empty) table options. */
  class FakeScan extends ScanHelper {
    override val coreOptions: CoreOptions =
      CoreOptions.fromMap(new HashMap[String, String]())
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
package org.apache.paimon.spark.sql

import org.apache.paimon.spark.{PaimonSparkTestBase, SparkInputPartition, SparkTable}
import org.apache.paimon.spark.{PaimonBatch, PaimonScan, PaimonSparkTestBase, SparkInputPartition, SparkTable}
import org.apache.paimon.table.source.DataSplit

import org.apache.spark.sql.Row
Expand Down Expand Up @@ -111,7 +111,9 @@ class PaimonPushDownTest extends PaimonSparkTestBase {

// It still returns false even though it can push down the limit.
Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1))
val partitions = scanBuilder.build().toBatch.planInputPartitions()
val paimonScan = scanBuilder.build().asInstanceOf[PaimonScan]
val partitions =
PaimonBatch(paimonScan.getOriginSplits, paimonScan.readBuilder).planInputPartitions()
Assertions.assertEquals(1, partitions.length)

Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count())
Expand Down

0 comments on commit e90e5ab

Please sign in to comment.