From ad8248fc8645c3887ed39dea7fa93ed9ff9ffeb7 Mon Sep 17 00:00:00 2001 From: Yann Byron Date: Fri, 28 Jun 2024 23:36:43 +0800 Subject: [PATCH] [spark] support to read multi splits in a spark input partition (#3612) --- .../paimon/spark/SparkInputPartition.java | 60 ------------- .../apache/paimon/spark/PaimonBaseScan.scala | 12 +-- .../org/apache/paimon/spark/PaimonBatch.scala | 20 +---- .../paimon/spark/PaimonInputPartition.scala | 31 +++++++ .../paimon/spark/PaimonPartitionReader.scala | 88 +++++++++++++------ .../spark/PaimonPartitionReaderFactory.scala | 2 +- .../org/apache/paimon/spark/PaimonScan.scala | 4 +- .../apache/paimon/spark/PaimonSplitScan.scala | 2 +- .../paimon/spark/PaimonStatistics.scala | 2 +- .../org/apache/paimon/spark/ScanHelper.scala | 83 +++++++++++------ .../paimon/spark/commands/PaimonCommand.scala | 1 + .../sources/PaimonMicroBatchStream.scala | 4 +- .../apache/paimon/spark/ScanHelperTest.scala | 4 +- .../paimon/spark/sql/PaimonPushDownTest.scala | 25 ++---- 14 files changed, 179 insertions(+), 159 deletions(-) delete mode 100644 paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInputPartition.java create mode 100644 paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInputPartition.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInputPartition.java deleted file mode 100644 index 3b02c61a646c..000000000000 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInputPartition.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.paimon.spark; - -import org.apache.paimon.table.source.Split; - -import org.apache.spark.sql.connector.read.InputPartition; - -import java.util.Objects; - -/** A Spark {@link InputPartition} for paimon. 
*/ -public class SparkInputPartition implements InputPartition { - - private static final long serialVersionUID = 1L; - - private final Split split; - - public SparkInputPartition(Split split) { - this.split = split; - } - - public Split split() { - return split; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - - if (o == null || getClass() != o.getClass()) { - return false; - } - - SparkInputPartition that = (SparkInputPartition) o; - return this.split.equals(that.split); - } - - @Override - public int hashCode() { - return Objects.hash(split); - } -} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala index 73a3e68f03c6..86ac01327a36 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala @@ -64,7 +64,7 @@ abstract class PaimonBaseScan( protected var runtimeFilters: Array[Filter] = Array.empty - protected var splits: Array[Split] = _ + protected var inputPartitions: Seq[PaimonInputPartition] = _ override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options()) @@ -93,11 +93,11 @@ abstract class PaimonBaseScan( readBuilder.newScan().plan().splits().asScala.toArray } - def getSplits: Array[Split] = { - if (splits == null) { - splits = reshuffleSplits(getOriginSplits) + def getInputPartitions: Seq[PaimonInputPartition] = { + if (inputPartitions == null) { + inputPartitions = getInputPartitions(getOriginSplits) } - splits + inputPartitions } override def readSchema(): StructType = { @@ -106,7 +106,7 @@ abstract class PaimonBaseScan( override def toBatch: Batch = { val metadataColumns = metadataFields.map(field => PaimonMetadataColumn.get(field.name)) - PaimonBatch(getSplits, readBuilder, metadataColumns) + PaimonBatch(getInputPartitions, readBuilder, metadataColumns) } override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBatch.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBatch.scala index c7b782573a30..9969f7ebae4d 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBatch.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBatch.scala @@ -19,37 +19,23 @@ package org.apache.paimon.spark import org.apache.paimon.spark.schema.PaimonMetadataColumn -import org.apache.paimon.table.source.{ReadBuilder, Split} +import org.apache.paimon.table.source.ReadBuilder import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} -import org.apache.spark.sql.types.StructType import java.util.Objects /** A Spark [[Batch]] for paimon. 
*/ case class PaimonBatch( - splits: Array[Split], + inputPartitions: Seq[PaimonInputPartition], readBuilder: ReadBuilder, metadataColumns: Seq[PaimonMetadataColumn] = Seq.empty) extends Batch { override def planInputPartitions(): Array[InputPartition] = - splits.map(new SparkInputPartition(_).asInstanceOf[InputPartition]) + inputPartitions.map(_.asInstanceOf[InputPartition]).toArray override def createReaderFactory(): PartitionReaderFactory = PaimonPartitionReaderFactory(readBuilder, metadataColumns) - override def equals(obj: Any): Boolean = { - obj match { - case other: PaimonBatch => - this.splits.sameElements(other.splits) && - readBuilder.equals(other.readBuilder) - - case _ => false - } - } - - override def hashCode(): Int = { - Objects.hashCode(splits.toSeq, readBuilder) - } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala new file mode 100644 index 000000000000..b0a1d64999bc --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark + +import org.apache.paimon.table.source.Split + +import org.apache.spark.sql.connector.read.InputPartition + +case class PaimonInputPartition(splits: Seq[Split]) extends InputPartition {} + +object PaimonInputPartition { + def apply(split: Split): PaimonInputPartition = { + PaimonInputPartition(Seq(split)) + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala index 1509ea83c4c3..cea235120ca4 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala @@ -19,7 +19,7 @@ package org.apache.paimon.spark import org.apache.paimon.data.{InternalRow => PaimonInternalRow} -import org.apache.paimon.reader.{RecordReader, RecordReaderIterator} +import org.apache.paimon.reader.RecordReader import org.apache.paimon.spark.schema.PaimonMetadataColumn import org.apache.paimon.table.source.{DataSplit, Split} @@ -33,50 +33,88 @@ import scala.collection.JavaConverters._ case class PaimonPartitionReader( readFunc: Split => RecordReader[PaimonInternalRow], - partition: SparkInputPartition, + partition: PaimonInputPartition, row: SparkInternalRow, metadataColumns: Seq[PaimonMetadataColumn] ) extends PartitionReader[InternalRow] { - private lazy val split: Split = partition.split - - private lazy val iterator = { - val reader = readFunc(split) - PaimonRecordReaderIterator(reader, metadataColumns) - } + private val splits: Iterator[Split] = partition.splits.toIterator + private var currentRecordReader: PaimonRecordReaderIterator = readSplit() + private var advanced = false + private var currentRow: PaimonInternalRow = _ override def next(): Boolean = { - if (iterator.hasNext) { - row.replace(iterator.next()) - true - } else { + if (currentRecordReader == null) { false + } else { + advanceIfNeeded() + currentRow != null } } override def get(): InternalRow = { - row + if (!next) { + null + } else { + advanced = false + row.replace(currentRow) + } } - override def currentMetricsValues(): Array[CustomTaskMetric] = { - val paimonMetricsValues: Array[CustomTaskMetric] = split match { - case dataSplit: DataSplit => - val splitSize = dataSplit.dataFiles().asScala.map(_.fileSize).sum - Array( - PaimonNumSplitsTaskMetric(1L), - PaimonSplitSizeTaskMetric(splitSize), - PaimonAvgSplitSizeTaskMetric(splitSize) - ) + private def advanceIfNeeded(): Unit = { + if (!advanced) { + advanced = true + var stop = false + while (!stop) { + if (currentRecordReader.hasNext) { + currentRow = currentRecordReader.next() + } else { + currentRow = null + } + + if (currentRow != null) { + stop = true + } else { + currentRecordReader.close() + currentRecordReader = readSplit() + if (currentRecordReader == null) { + stop = true + } + } + } + } + } + + private def readSplit(): PaimonRecordReaderIterator = { + if (splits.hasNext) { + val reader = readFunc(splits.next()) + PaimonRecordReaderIterator(reader, metadataColumns) + } else { + null + } + } - case _ => - Array.empty[CustomTaskMetric] + override def currentMetricsValues(): Array[CustomTaskMetric] = { + val dataSplits = partition.splits.collect { case ds: DataSplit => ds } + val numSplits = dataSplits.length + val paimonMetricsValues: Array[CustomTaskMetric] = if (dataSplits.nonEmpty) { + val splitSize = 
dataSplits.map(_.dataFiles().asScala.map(_.fileSize).sum).sum + Array( + PaimonNumSplitsTaskMetric(numSplits), + PaimonSplitSizeTaskMetric(splitSize), + PaimonAvgSplitSizeTaskMetric(splitSize / numSplits) + ) + } else { + Array.empty[CustomTaskMetric] } super.currentMetricsValues() ++ paimonMetricsValues } override def close(): Unit = { try { - iterator.close() + if (currentRecordReader != null) { + currentRecordReader.close() + } } catch { case e: Exception => throw new IOException(e) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReaderFactory.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReaderFactory.scala index 3ac724fcc521..94de0bec3b50 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReaderFactory.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReaderFactory.scala @@ -50,7 +50,7 @@ case class PaimonPartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { partition match { - case paimonInputPartition: SparkInputPartition => + case paimonInputPartition: PaimonInputPartition => val readFunc: Split => RecordReader[data.InternalRow] = (split: Split) => readBuilder.newRead().withIOManager(ioManager).createReader(split) PaimonPartitionReader(readFunc, paimonInputPartition, row, metadataColumns) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index eb6bb10ee844..f0476cf70732 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -58,8 +58,8 @@ case class PaimonScan( if (partitionFilter.nonEmpty) { this.runtimeFilters = filters readBuilder.withFilter(partitionFilter.head) - // set splits null to trigger to get the new splits. - splits = null + // set inputPartitions null to trigger to get the new splits. 
+ inputPartitions = null } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala index b4e95f087b2d..58c2fd693c1f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala @@ -40,7 +40,7 @@ case class PaimonSplitScan( override def toBatch: Batch = { PaimonBatch( - reshuffleSplits(dataSplits.asInstanceOf[Array[Split]]), + getInputPartitions(dataSplits.asInstanceOf[Array[Split]]), table.newReadBuilder, metadataColumns) } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala index fcf386397251..edf8c01dd3ee 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConverters._ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics { - private lazy val rowCount: Long = scan.getSplits.map(_.rowCount).sum + private lazy val rowCount: Long = scan.getOriginSplits.map(_.rowCount).sum private lazy val scannedTotalSize: Long = rowCount * scan.readSchema().defaultSize diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala index 762ae014670e..4971de24c255 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ScanHelper.scala @@ -44,62 +44,80 @@ trait ScanHelper extends Logging { .toInt } - def reshuffleSplits(splits: Array[Split]): Array[Split] = { - if (splits.length < leafNodeDefaultParallelism) { - val beforeLength = splits.length - val (toReshuffle, reserved) = splits.partition { - case split: DataSplit => split.beforeFiles().isEmpty && split.rawConvertible() - case _ => false - } - val reshuffled = reshuffleSplits0(toReshuffle.collect { case ds: DataSplit => ds }) - val all = reshuffled ++ reserved - logInfo(s"Reshuffle splits from $beforeLength to ${all.length}") + def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = { + val (toReshuffle, reserved) = splits.partition { + case split: DataSplit => split.beforeFiles().isEmpty && split.rawConvertible() + case _ => false + } + if (toReshuffle.nonEmpty) { + val startTS = System.currentTimeMillis() + val reshuffled = getInputPartitions(toReshuffle.collect { case ds: DataSplit => ds }) + val all = reserved.map(PaimonInputPartition.apply) ++ reshuffled + val duration = System.currentTimeMillis() - startTS + logInfo( + s"Reshuffle splits from ${toReshuffle.length} to ${reshuffled.length} in $duration ms. 
Total number of splits is ${all.length}") all } else { - splits + splits.map(PaimonInputPartition.apply) } } - private def reshuffleSplits0(splits: Array[DataSplit]): Array[DataSplit] = { + private def getInputPartitions(splits: Array[DataSplit]): Array[PaimonInputPartition] = { val maxSplitBytes = computeMaxSplitBytes(splits) - val newSplits = new ArrayBuffer[DataSplit] + var currentSize = 0L + val currentSplits = new ArrayBuffer[DataSplit] + val partitions = new ArrayBuffer[PaimonInputPartition] var currentSplit: Option[DataSplit] = None val currentDataFiles = new ArrayBuffer[DataFileMeta] val currentDeletionFiles = new ArrayBuffer[DeletionFile] - var currentSize = 0L def closeDataSplit(): Unit = { if (currentSplit.nonEmpty && currentDataFiles.nonEmpty) { val newSplit = copyDataSplit(currentSplit.get, currentDataFiles, currentDeletionFiles) - newSplits += newSplit + currentSplits += newSplit } currentDataFiles.clear() currentDeletionFiles.clear() + } + + def closeInputPartition(): Unit = { + closeDataSplit() + if (currentSplit.nonEmpty) { + partitions += PaimonInputPartition(currentSplits.toArray) + } + currentSplits.clear() currentSize = 0 } splits.foreach { split => - currentSplit = Some(split) + if (!currentSplit.exists(withSamePartitionAndBucket(_, split))) { + // close and open another data split + closeDataSplit() + currentSplit = Some(split) + } - split.dataFiles().asScala.zipWithIndex.foreach { - case (file, idx) => - if (currentSize + file.fileSize > maxSplitBytes) { - closeDataSplit() + val ddFiles = dataFileAndDeletionFiles(split) + ddFiles.foreach { + case (dataFile, deletionFile) => + val size = dataFile + .fileSize() + openCostInBytes + Option(deletionFile).map(_.length()).getOrElse(0L) + if (currentSize + size > maxSplitBytes) { + closeInputPartition() } - currentSize += file.fileSize + openCostInBytes - currentDataFiles += file + currentDataFiles += dataFile if (deletionVectors) { - currentDeletionFiles += split.deletionFiles().get().get(idx) + currentDeletionFiles += deletionFile } + currentSize += size } - closeDataSplit() } + closeInputPartition() - newSplits.toArray + partitions.toArray } private def unpack(split: Split): Array[DataFileMeta] = { @@ -120,7 +138,7 @@ trait ScanHelper extends Logging { .withPartition(split.partition()) .withBucket(split.bucket()) .withDataFiles(dataFiles.toList.asJava) - .rawConvertible(split.rawConvertible()) + .rawConvertible(split.rawConvertible) .withBucketPath(split.bucketPath) if (deletionVectors) { builder.withDataDeletionFiles(deletionFiles.toList.asJava) @@ -128,6 +146,19 @@ trait ScanHelper extends Logging { builder.build() } + private def withSamePartitionAndBucket(split1: DataSplit, split2: DataSplit): Boolean = { + split1.partition().equals(split2.partition()) && split1.bucket() == split2.bucket() + } + + private def dataFileAndDeletionFiles(split: DataSplit): Array[(DataFileMeta, DeletionFile)] = { + if (deletionVectors && split.deletionFiles().isPresent) { + val deletionFiles = split.deletionFiles().get().asScala + split.dataFiles().asScala.zip(deletionFiles).toArray + } else { + split.dataFiles().asScala.map((_, null)).toArray + } + } + private def computeMaxSplitBytes(dataSplits: Seq[DataSplit]): Long = { val dataFiles = dataSplits.flatMap(unpack) val defaultMaxSplitBytes = spark.sessionState.conf.filesMaxPartitionBytes diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala index 8d89af2a14de..5e5f46f69e8b 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala @@ -180,6 +180,7 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper { .toMap dvIndexFileMaintainer.notifyDeletionFiles(touchedDataFileAndDeletionFiles.asJava) + dvIndexFileMaintainer.writeUnchangedDeletionVector().asScala } protected def collectDeletionVectors( diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonMicroBatchStream.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonMicroBatchStream.scala index 78e2e813248e..398a427673f0 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonMicroBatchStream.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/sources/PaimonMicroBatchStream.scala @@ -19,7 +19,7 @@ package org.apache.paimon.spark.sources import org.apache.paimon.options.Options -import org.apache.paimon.spark.{PaimonImplicits, PaimonPartitionReaderFactory, SparkConnectorOptions, SparkInputPartition} +import org.apache.paimon.spark.{PaimonImplicits, PaimonInputPartition, PaimonPartitionReaderFactory, SparkConnectorOptions} import org.apache.paimon.table.DataTable import org.apache.paimon.table.source.ReadBuilder @@ -122,7 +122,7 @@ class PaimonMicroBatchStream( val endOffset = PaimonSourceOffset(end) getBatch(startOffset, Some(endOffset), None) - .map(ids => new SparkInputPartition(ids.entry)) + .map(ids => PaimonInputPartition(ids.entry)) .toArray[InputPartition] } diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala index 7fae33953dd7..1a844d9983c2 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/ScanHelperTest.scala @@ -52,7 +52,7 @@ class ScanHelperTest extends PaimonSparkTestBase { .builder() .withSnapshot(1) .withBucket(0) - .withPartition(new BinaryRow(0)) + .withPartition(BinaryRow.EMPTY_ROW) .withDataFiles(files.zipWithIndex.filter(_._2 % splitNum == i).map(_._1).toList.asJava) .rawConvertible(true) .withBucketPath("no use") @@ -60,7 +60,7 @@ class ScanHelperTest extends PaimonSparkTestBase { } val fakeScan = new FakeScan() - val reshuffled = fakeScan.reshuffleSplits(dataSplits.toArray) + val reshuffled = fakeScan.getInputPartitions(dataSplits.toArray) Assertions.assertTrue(reshuffled.length > 5) } } diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala index d7d6ff763dd6..c55ed876d6b1 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -18,7 +18,7 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonBatch, PaimonScan, PaimonSparkTestBase, SparkInputPartition, SparkTable} +import 
org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonScan, PaimonSparkTestBase, SparkTable} import org.apache.paimon.table.source.DataSplit import org.apache.spark.sql.Row @@ -89,33 +89,26 @@ class PaimonPushDownTest extends PaimonSparkTestBase { test("Paimon pushDown: limit for append-only tables") { spark.sql(s""" |CREATE TABLE T (a INT, b STRING, c STRING) + |PARTITIONED BY (c) |""".stripMargin) - spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22')") - spark.sql("INSERT INTO T VALUES (3, 'c', '11'), (4, 'd', '22')") + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '11')") + spark.sql("INSERT INTO T VALUES (3, 'c', '22'), (4, 'd', '22')") checkAnswer( spark.sql("SELECT * FROM T ORDER BY a"), - Row(1, "a", "11") :: Row(2, "b", "22") :: Row(3, "c", "11") :: Row(4, "d", "22") :: Nil) + Row(1, "a", "11") :: Row(2, "b", "11") :: Row(3, "c", "22") :: Row(4, "d", "22") :: Nil) val scanBuilder = getScanBuilder() Assertions.assertTrue(scanBuilder.isInstanceOf[SupportsPushDownLimit]) - val dataFilesWithoutLimit = scanBuilder.build().toBatch.planInputPartitions().flatMap { - case partition: SparkInputPartition => - partition.split() match { - case dataSplit: DataSplit => dataSplit.dataFiles().asScala - case _ => Seq.empty - } - } - Assertions.assertTrue(dataFilesWithoutLimit.length >= 2) + val dataSplitsWithoutLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits + Assertions.assertTrue(dataSplitsWithoutLimit.length >= 2) // It still return false even it can push down limit. Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1)) - val paimonScan = scanBuilder.build().asInstanceOf[PaimonScan] - val partitions = - PaimonBatch(paimonScan.getOriginSplits, paimonScan.readBuilder).planInputPartitions() - Assertions.assertEquals(1, partitions.length) + val dataSplitsWithLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits + Assertions.assertEquals(1, dataSplitsWithLimit.length) Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count()) }
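
The patch above replaces the one-split-per-partition SparkInputPartition with PaimonInputPartition(splits: Seq[Split]), packs raw-convertible data splits into partitions sized against spark.sql.files.maxPartitionBytes, and lets the partition reader chain record readers across the splits it receives. The following Scala sketch is a minimal, self-contained illustration of that packing-and-chaining idea under stated assumptions: FileMeta, FakeSplit, PackedPartition, and the size constants are hypothetical stand-ins rather than Paimon or Spark APIs, and the real ScanHelper additionally repacks at the data-file level, adds deletion-file sizes to the cost, and only merges splits that share a partition and bucket.

object MultiSplitPackingSketch {

  // Hypothetical stand-ins for DataFileMeta, DataSplit and PaimonInputPartition.
  final case class FileMeta(fileSize: Long)
  final case class FakeSplit(partition: String, bucket: Int, files: Seq[FileMeta])
  final case class PackedPartition(splits: Seq[FakeSplit])

  // Assumed knobs, loosely mirroring spark.sql.files.maxPartitionBytes and openCostInBytes.
  val maxSplitBytes: Long = 128L * 1024 * 1024
  val openCostInBytes: Long = 4L * 1024 * 1024

  // Pack splits into partitions so that each partition stays close to maxSplitBytes.
  def pack(splits: Seq[FakeSplit]): Seq[PackedPartition] = {
    val partitions = Seq.newBuilder[PackedPartition]
    var current = Vector.empty[FakeSplit]
    var currentSize = 0L

    def close(): Unit = {
      if (current.nonEmpty) partitions += PackedPartition(current)
      current = Vector.empty
      currentSize = 0L
    }

    splits.foreach { split =>
      // Each data file contributes its size plus a fixed open cost, as in the patch.
      val size = split.files.map(_.fileSize + openCostInBytes).sum
      if (current.nonEmpty && currentSize + size > maxSplitBytes) close()
      current :+= split
      currentSize += size
    }
    close()
    partitions.result()
  }

  // Reader side: a partition now carries several splits, so the reader walks them one
  // after another, mirroring how PaimonPartitionReader advances to the next split.
  def readAll(partition: PackedPartition)(open: FakeSplit => Iterator[String]): Iterator[String] =
    partition.splits.iterator.flatMap(open)

  def main(args: Array[String]): Unit = {
    val splits = (1 to 10).map(_ => FakeSplit("pt=1", 0, Seq(FileMeta(40L * 1024 * 1024))))
    val packed = pack(splits)
    println(s"${splits.length} splits packed into ${packed.length} input partitions")

    // One chained read over the first packed partition, spanning all of its splits.
    readAll(packed.head)(s => Iterator(s"read a ${s.files.head.fileSize}-byte file")).foreach(println)
  }
}

With ten 40 MB splits and a 128 MB target, this sketch packs them into five input partitions instead of ten, which is the effect the patch aims for when a scan produces many small splits.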