[spark] support to read multi splits in a spark input partition (#3612)
YannByron authored Jun 28, 2024
1 parent 5bd7ee9 commit ad8248f
Showing 14 changed files with 179 additions and 159 deletions.

This file was deleted.

@@ -64,7 +64,7 @@ abstract class PaimonBaseScan(

   protected var runtimeFilters: Array[Filter] = Array.empty

-  protected var splits: Array[Split] = _
+  protected var inputPartitions: Seq[PaimonInputPartition] = _

   override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())
@@ -93,11 +93,11 @@ abstract class PaimonBaseScan(
     readBuilder.newScan().plan().splits().asScala.toArray
   }

-  def getSplits: Array[Split] = {
-    if (splits == null) {
-      splits = reshuffleSplits(getOriginSplits)
+  def getInputPartitions: Seq[PaimonInputPartition] = {
+    if (inputPartitions == null) {
+      inputPartitions = getInputPartitions(getOriginSplits)
     }
-    splits
+    inputPartitions
   }

   override def readSchema(): StructType = {
@@ -106,7 +106,7 @@ abstract class PaimonBaseScan(

   override def toBatch: Batch = {
     val metadataColumns = metadataFields.map(field => PaimonMetadataColumn.get(field.name))
-    PaimonBatch(getSplits, readBuilder, metadataColumns)
+    PaimonBatch(getInputPartitions, readBuilder, metadataColumns)
   }

   override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
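The scan now caches a Seq[PaimonInputPartition] instead of an Array[Split], but the grouping performed by getInputPartitions(getOriginSplits) is collapsed out of view in this diff. A minimal sketch of what such a grouping could look like; the groupSplits name and the maxSplitsPerPartition knob are hypothetical, and the real implementation may size partitions differently (for example by parallelism or data size):

```scala
import org.apache.paimon.table.source.Split

// Hypothetical sketch: pack splits into PaimonInputPartitions of bounded size.
def groupSplits(
    splits: Array[Split],
    maxSplitsPerPartition: Int = 4): Seq[PaimonInputPartition] = {
  splits.toSeq
    .grouped(maxSplitsPerPartition)            // Iterator[Seq[Split]]
    .map(group => PaimonInputPartition(group)) // wrap each group in one input partition
    .toSeq
}
```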
@@ -19,37 +19,23 @@
 package org.apache.paimon.spark

 import org.apache.paimon.spark.schema.PaimonMetadataColumn
-import org.apache.paimon.table.source.{ReadBuilder, Split}
+import org.apache.paimon.table.source.ReadBuilder

 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory}
 import org.apache.spark.sql.types.StructType

-import java.util.Objects
-
 /** A Spark [[Batch]] for paimon. */
 case class PaimonBatch(
-    splits: Array[Split],
+    inputPartitions: Seq[PaimonInputPartition],
     readBuilder: ReadBuilder,
     metadataColumns: Seq[PaimonMetadataColumn] = Seq.empty)
   extends Batch {

   override def planInputPartitions(): Array[InputPartition] =
-    splits.map(new SparkInputPartition(_).asInstanceOf[InputPartition])
+    inputPartitions.map(_.asInstanceOf[InputPartition]).toArray

   override def createReaderFactory(): PartitionReaderFactory =
     PaimonPartitionReaderFactory(readBuilder, metadataColumns)
-
-  override def equals(obj: Any): Boolean = {
-    obj match {
-      case other: PaimonBatch =>
-        this.splits.sameElements(other.splits) &&
-          readBuilder.equals(other.readBuilder)
-
-      case _ => false
-    }
-  }
-
-  override def hashCode(): Int = {
-    Objects.hashCode(splits.toSeq, readBuilder)
-  }
 }
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark

import org.apache.paimon.table.source.Split

import org.apache.spark.sql.connector.read.InputPartition

case class PaimonInputPartition(splits: Seq[Split]) extends InputPartition {}

object PaimonInputPartition {
  def apply(split: Split): PaimonInputPartition = {
    PaimonInputPartition(Seq(split))
  }
}
@@ -19,7 +19,7 @@
 package org.apache.paimon.spark

 import org.apache.paimon.data.{InternalRow => PaimonInternalRow}
-import org.apache.paimon.reader.{RecordReader, RecordReaderIterator}
+import org.apache.paimon.reader.RecordReader
 import org.apache.paimon.spark.schema.PaimonMetadataColumn
 import org.apache.paimon.table.source.{DataSplit, Split}
@@ -33,50 +33,88 @@ import scala.collection.JavaConverters._

 case class PaimonPartitionReader(
     readFunc: Split => RecordReader[PaimonInternalRow],
-    partition: SparkInputPartition,
+    partition: PaimonInputPartition,
     row: SparkInternalRow,
     metadataColumns: Seq[PaimonMetadataColumn]
 ) extends PartitionReader[InternalRow] {

-  private lazy val split: Split = partition.split
-
-  private lazy val iterator = {
-    val reader = readFunc(split)
-    PaimonRecordReaderIterator(reader, metadataColumns)
-  }
+  private val splits: Iterator[Split] = partition.splits.toIterator
+  private var currentRecordReader: PaimonRecordReaderIterator = readSplit()
+  private var advanced = false
+  private var currentRow: PaimonInternalRow = _

   override def next(): Boolean = {
-    if (iterator.hasNext) {
-      row.replace(iterator.next())
-      true
-    } else {
+    if (currentRecordReader == null) {
       false
+    } else {
+      advanceIfNeeded()
+      currentRow != null
     }
   }

   override def get(): InternalRow = {
-    row
+    if (!next) {
+      null
+    } else {
+      advanced = false
+      row.replace(currentRow)
+    }
   }

-  override def currentMetricsValues(): Array[CustomTaskMetric] = {
-    val paimonMetricsValues: Array[CustomTaskMetric] = split match {
-      case dataSplit: DataSplit =>
-        val splitSize = dataSplit.dataFiles().asScala.map(_.fileSize).sum
-        Array(
-          PaimonNumSplitsTaskMetric(1L),
-          PaimonSplitSizeTaskMetric(splitSize),
-          PaimonAvgSplitSizeTaskMetric(splitSize)
-        )
-
-      case _ =>
-        Array.empty[CustomTaskMetric]
+  private def advanceIfNeeded(): Unit = {
+    if (!advanced) {
+      advanced = true
+      var stop = false
+      while (!stop) {
+        if (currentRecordReader.hasNext) {
+          currentRow = currentRecordReader.next()
+        } else {
+          currentRow = null
+        }
+
+        if (currentRow != null) {
+          stop = true
+        } else {
+          currentRecordReader.close()
+          currentRecordReader = readSplit()
+          if (currentRecordReader == null) {
+            stop = true
+          }
+        }
+      }
+    }
+  }
+
+  private def readSplit(): PaimonRecordReaderIterator = {
+    if (splits.hasNext) {
+      val reader = readFunc(splits.next())
+      PaimonRecordReaderIterator(reader, metadataColumns)
+    } else {
+      null
+    }
+  }
+
+  override def currentMetricsValues(): Array[CustomTaskMetric] = {
+    val dataSplits = partition.splits.collect { case ds: DataSplit => ds }
+    val numSplits = dataSplits.length
+    val paimonMetricsValues: Array[CustomTaskMetric] = if (dataSplits.nonEmpty) {
+      val splitSize = dataSplits.map(_.dataFiles().asScala.map(_.fileSize).sum).sum
+      Array(
+        PaimonNumSplitsTaskMetric(numSplits),
+        PaimonSplitSizeTaskMetric(splitSize),
+        PaimonAvgSplitSizeTaskMetric(splitSize / numSplits)
+      )
+    } else {
+      Array.empty[CustomTaskMetric]
     }
     super.currentMetricsValues() ++ paimonMetricsValues
   }

   override def close(): Unit = {
     try {
-      iterator.close()
+      if (currentRecordReader != null) {
+        currentRecordReader.close()
+      }
     } catch {
       case e: Exception =>
         throw new IOException(e)
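The reader above now walks the partition's splits one at a time, opening a PaimonRecordReaderIterator per split and closing it before moving to the next. A minimal, self-contained sketch of that chaining pattern; CloseableIter and ChainedReader are hypothetical names standing in for the commit's PaimonRecordReaderIterator and PaimonPartitionReader:

```scala
// Hypothetical stand-in for PaimonRecordReaderIterator: an iterator that must be closed.
trait CloseableIter[T] extends AutoCloseable {
  def hasNext: Boolean
  def next(): T
}

// Chains several per-split readers: exhaust one, close it, lazily open the next.
final class ChainedReader[T](openers: Iterator[() => CloseableIter[T]]) extends AutoCloseable {
  private var current: CloseableIter[T] = openNext()

  private def openNext(): CloseableIter[T] =
    if (openers.hasNext) openers.next()() else null

  // Returns the next record across split boundaries, or None when every reader is exhausted.
  def nextRecord(): Option[T] = {
    while (current != null) {
      if (current.hasNext) {
        return Some(current.next())
      } else {
        current.close()      // done with this split's reader
        current = openNext() // move on to the next split, if any
      }
    }
    None
  }

  override def close(): Unit = if (current != null) current.close()
}
```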
@@ -50,7 +50,7 @@ case class PaimonPartitionReaderFactory(

   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     partition match {
-      case paimonInputPartition: SparkInputPartition =>
+      case paimonInputPartition: PaimonInputPartition =>
         val readFunc: Split => RecordReader[data.InternalRow] =
           (split: Split) => readBuilder.newRead().withIOManager(ioManager).createReader(split)
         PaimonPartitionReader(readFunc, paimonInputPartition, row, metadataColumns)
@@ -58,8 +58,8 @@ case class PaimonScan(
     if (partitionFilter.nonEmpty) {
       this.runtimeFilters = filters
       readBuilder.withFilter(partitionFilter.head)
-      // set splits null to trigger to get the new splits.
-      splits = null
+      // set inputPartitions null to trigger to get the new splits.
+      inputPartitions = null
     }
   }

@@ -40,7 +40,7 @@ case class PaimonSplitScan(

   override def toBatch: Batch = {
     PaimonBatch(
-      reshuffleSplits(dataSplits.asInstanceOf[Array[Split]]),
+      getInputPartitions(dataSplits.asInstanceOf[Array[Split]]),
       table.newReadBuilder,
       metadataColumns)
   }
@@ -34,7 +34,7 @@ import scala.collection.JavaConverters._

 case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {

-  private lazy val rowCount: Long = scan.getSplits.map(_.rowCount).sum
+  private lazy val rowCount: Long = scan.getOriginSplits.map(_.rowCount).sum

   private lazy val scannedTotalSize: Long = rowCount * scan.readSchema().defaultSize
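Statistics now sum row counts over the original, ungrouped splits, so the estimate does not depend on how splits are packed into input partitions. A purely illustrative calculation with made-up figures (not taken from the commit):

```scala
// Hypothetical figures: two origin splits reporting 1000 and 3000 rows,
// and a read schema whose defaultSize is 32 bytes per row.
val splitRowCounts = Seq(1000L, 3000L)
val defaultSize = 32L

val rowCount = splitRowCounts.sum                 // 4000 rows
val scannedTotalSize = rowCount * defaultSize     // 128000 bytes
```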