[BLAZE-707][FOLLOWUP] NativePaimonTableScanExec should use shimed PartitionedFile and min partition number (#713)
SteNicholas authored Dec 21, 2024
1 parent d5bf5a0 commit 43e3621
Showing 3 changed files with 47 additions and 4 deletions.
@@ -18,6 +18,7 @@ package org.apache.spark.sql.blaze
import java.io.File
import java.util.UUID
import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.hadoop.fs.Path
import org.apache.spark.OneToOneDependency
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
@@ -33,6 +34,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.blaze.BlazeConverters.ForceNativeExecutionWrapperBase
import org.apache.spark.sql.blaze.NativeConverters.NativeExprWrapperBase
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -97,6 +99,7 @@ import org.apache.spark.sql.execution.blaze.plan._
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
@@ -818,6 +821,37 @@ class ShimsImpl extends Shims with Logging {
NativeExprWrapper(nativeExpr, dataType, nullable)
}

@enableIf(
Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
System.getProperty("blaze.shim")))
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile =
PartitionedFile(partitionValues, filePath, offset, size)

@enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile = {
import org.apache.spark.paths.SparkPath
PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size)
}

@enableIf(
Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
System.getProperty("blaze.shim")))
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.sparkContext.defaultParallelism)

@enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sparkContext.defaultParallelism

@enableIf(
Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
System.getProperty("blaze.shim")))
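The @enableIf branches above exist because PartitionedFile's constructor takes a plain String path up through Spark 3.3 but a SparkPath starting with Spark 3.4, and because spark.sql.files.minPartitionNum only exists from Spark 3.1, so the Spark 3.0 shim falls back to defaultParallelism. A minimal sketch of a version-agnostic caller that splits one data file into byte ranges through the shim (splitFile and its parameters are illustrative, not part of this patch):

import org.apache.spark.sql.blaze.Shims
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.PartitionedFile

// Split a single file into byte-range PartitionedFiles without touching the
// version-specific constructor; the shim builds the right variant per Spark version.
def splitFile(
    partitionValues: InternalRow,
    filePath: String, // plain string path; the shim wraps it in SparkPath on Spark 3.4+
    fileSize: Long,
    maxSplitBytes: Long): Seq[PartitionedFile] =
  (0L until fileSize by maxSplitBytes).map { offset =>
    val size = math.min(maxSplitBytes, fileSize - offset)
    Shims.get.getPartitionedFile(partitionValues, filePath, offset, size)
  }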
@@ -46,7 +46,9 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment

@@ -238,6 +240,14 @@ abstract class Shims {
dataType: DataType,
nullable: Boolean): Expression

def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile

def getMinPartitionNum(sparkSession: SparkSession): Int

def postTransform(plan: SparkPlan, sc: SparkContext): Unit = {}
}

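On the consumer side, the PartitionedFiles produced by getPartitionedFile are typically bin-packed into read tasks. A hedged sketch using Spark's own FilePartition.getFilePartitions helper (whether the Paimon scan goes through this exact helper is not visible in this diff; toFilePartitions is illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}

// Group PartitionedFiles into FilePartitions (one per read task), packing by
// maxSplitBytes the same way FileSourceScanExec does: largest files first.
def toFilePartitions(
    spark: SparkSession,
    files: Seq[PartitionedFile],
    maxSplitBytes: Long): Seq[FilePartition] = {
  val sorted = files.sortBy(f => -f.length)
  FilePartition.getFilePartitions(spark, sorted, maxSplitBytes)
}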
@@ -216,10 +216,10 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec)
(0L until dataFileMeta.fileSize() by maxSplitBytes).map { offset =>
val remaining = dataFileMeta.fileSize() - offset
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
PartitionedFile(partitionValues, filePath, offset, size)
Shims.get.getPartitionedFile(partitionValues, filePath, offset, size)
}
} else {
Seq(PartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize()))
Seq(Shims.get.getPartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize()))
}
}

@@ -229,8 +229,7 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec)
selectedSplits: Seq[DataSplit]): Long = {
val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.sparkContext.defaultParallelism)
val minPartitionNum = Shims.get.getMinPartitionNum(sparkSession)
val totalBytes = selectedSplits
.flatMap(_.dataFiles().asScala.map(_.fileSize() + openCostInBytes))
.sum
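The hunk cuts off before the final arithmetic, but the values computed above match Spark's standard FilePartition.maxSplitBytes recipe, with the shimmed minPartitionNum now feeding in. A sketch of that calculation as a standalone function, assuming the method completes the same way Spark does:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.blaze.Shims

// Standard Spark split-size calculation: cap splits at filesMaxPartitionBytes,
// never go below filesOpenCostInBytes, and target at least minPartitionNum
// partitions over the total (open-cost-padded) bytes.
def maxSplitBytes(spark: SparkSession, totalBytes: Long): Long = {
  val defaultMaxSplitBytes = spark.sessionState.conf.filesMaxPartitionBytes
  val openCostInBytes = spark.sessionState.conf.filesOpenCostInBytes
  val minPartitionNum = Shims.get.getMinPartitionNum(spark)
  val bytesPerCore = totalBytes / minPartitionNum
  Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
}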
