diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
index 5fec8b99751f..f4185806777a 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala
@@ -21,7 +21,8 @@ package org.apache.paimon.spark.commands
 import org.apache.paimon.spark.SparkTable
 import org.apache.paimon.spark.catalyst.analysis.PaimonRelation
 import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand
-import org.apache.paimon.spark.schema.SparkSystemColumns
+import org.apache.paimon.spark.schema.{PaimonMetadataColumn, SparkSystemColumns}
+import org.apache.paimon.spark.schema.PaimonMetadataColumn.{FILE_PATH, FILE_PATH_COLUMN, ROW_INDEX, ROW_INDEX_COLUMN}
 import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils}
 import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.sink.CommitMessage
@@ -31,10 +32,11 @@ import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, Literal, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, EqualTo, Expression, Literal, Or, UnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, sum}
 import org.apache.spark.sql.types.{ByteType, StructField, StructType}
 
@@ -56,15 +58,19 @@ case class MergeIntoPaimonTable(
 
   override val table: FileStoreTable = v2Table.getTable.asInstanceOf[FileStoreTable]
 
+  lazy val relation: DataSourceV2Relation = PaimonRelation.getPaimonRelation(targetTable)
+
   lazy val tableSchema: StructType = v2Table.schema
 
   private lazy val writer = PaimonSparkWriter(table)
 
-  private lazy val filteredTargetPlan: LogicalPlan = {
+  private lazy val (targetOnlyCondition, filteredTargetPlan): (Option[Expression], LogicalPlan) = {
     val filtersOnlyTarget = getExpressionOnlyRelated(mergeCondition, targetTable)
-    filtersOnlyTarget
-      .map(Filter.apply(_, targetTable))
-      .getOrElse(targetTable)
+    (
+      filtersOnlyTarget,
+      filtersOnlyTarget
+        .map(Filter.apply(_, targetTable))
+        .getOrElse(targetTable))
   }
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -81,61 +87,112 @@ case class MergeIntoPaimonTable(
 
   private def performMergeForPkTable(sparkSession: SparkSession): Seq[CommitMessage] = {
     writer.write(
-      constructChangedRows(sparkSession, createDataset(sparkSession, filteredTargetPlan)))
+      constructChangedRows(
+        sparkSession,
+        createDataset(sparkSession, filteredTargetPlan),
+        remainDeletedRow = true))
   }
 
   private def performMergeForNonPkTable(sparkSession: SparkSession): Seq[CommitMessage] = {
     val targetDS = createDataset(sparkSession, filteredTargetPlan)
     val sourceDS = createDataset(sparkSession, sourceTable)
 
-    val targetFilePaths: Array[String] = findTouchedFiles(targetDS, sparkSession)
-
-    val touchedFilePathsSet = mutable.Set.empty[String]
-    def hasUpdate(actions: Seq[MergeAction]): Boolean = {
-      actions.exists {
-        case _: UpdateAction | _: DeleteAction => true
-        case _ => false
+    // Step1: get the candidate data splits which are filtered by Paimon Predicate.
+    val candidateDataSplits =
+      findCandidateDataSplits(targetOnlyCondition.getOrElse(TrueLiteral), relation.output)
+    val dataFilePathToMeta = candidateFileMap(candidateDataSplits)
+
+    if (deletionVectorsEnabled) {
+      // Step2: generate dataset that should contains ROW_KIND, FILE_PATH, ROW_INDEX columns
+      val metadataCols = Seq(FILE_PATH, ROW_INDEX)
+      val filteredRelation = createDataset(
+        sparkSession,
+        createNewScanPlan(
+          candidateDataSplits,
+          targetOnlyCondition.getOrElse(TrueLiteral),
+          relation,
+          metadataCols))
+      val ds = constructChangedRows(
+        sparkSession,
+        filteredRelation,
+        remainDeletedRow = true,
+        metadataCols = metadataCols)
+
+      ds.cache()
+      try {
+        val rowKindAttribute = ds.queryExecution.analyzed.output
+          .find(attr => sparkSession.sessionState.conf.resolver(attr.name, ROW_KIND_COL))
+          .getOrElse(throw new RuntimeException("Can not find _row_kind_ column."))
+
+        // Step3: filter rows that should be marked as DELETED in Deletion Vector mode.
+        val dvDS = ds.where(
+          s"$ROW_KIND_COL = ${RowKind.DELETE.toByteValue} or $ROW_KIND_COL = ${RowKind.UPDATE_AFTER.toByteValue}")
+        val deletionVectors = collectDeletionVectors(dataFilePathToMeta, dvDS, sparkSession)
+        val indexCommitMsg = writer.persistDeletionVectors(deletionVectors)
+
+        // Step4: filter rows that should be written as the inserted/updated data.
+        val toWriteDS = ds
+          .where(
+            s"$ROW_KIND_COL = ${RowKind.INSERT.toByteValue} or $ROW_KIND_COL = ${RowKind.UPDATE_AFTER.toByteValue}")
+          .drop(FILE_PATH_COLUMN, ROW_INDEX_COLUMN)
+        val addCommitMessage = writer.write(toWriteDS)
+
+        // Step5: commit index and data commit messages
+        addCommitMessage ++ indexCommitMsg
+      } finally {
+        ds.unpersist()
+      }
+    } else {
+      val touchedFilePathsSet = mutable.Set.empty[String]
+      def hasUpdate(actions: Seq[MergeAction]): Boolean = {
+        actions.exists {
+          case _: UpdateAction | _: DeleteAction => true
+          case _ => false
+        }
+      }
+      if (hasUpdate(matchedActions)) {
+        touchedFilePathsSet ++= findTouchedFiles(
+          targetDS.join(sourceDS, new Column(mergeCondition), "inner"),
+          sparkSession)
+      }
+      if (hasUpdate(notMatchedBySourceActions)) {
+        touchedFilePathsSet ++= findTouchedFiles(
+          targetDS.join(sourceDS, new Column(mergeCondition), "left_anti"),
+          sparkSession)
       }
-    }
-    if (hasUpdate(matchedActions)) {
-      touchedFilePathsSet ++= findTouchedFiles(
-        targetDS.join(sourceDS, new Column(mergeCondition), "inner"),
-        sparkSession)
-    }
-    if (hasUpdate(notMatchedBySourceActions)) {
-      touchedFilePathsSet ++= findTouchedFiles(
-        targetDS.join(sourceDS, new Column(mergeCondition), "left_anti"),
-        sparkSession)
-    }
 
-    val touchedFilePaths: Array[String] = touchedFilePathsSet.toArray
-    val unTouchedFilePaths = targetFilePaths.filterNot(touchedFilePaths.contains)
+      val targetFilePaths: Array[String] = findTouchedFiles(targetDS, sparkSession)
+      val touchedFilePaths: Array[String] = touchedFilePathsSet.toArray
+      val unTouchedFilePaths = targetFilePaths.filterNot(touchedFilePaths.contains)
 
-    val relation = PaimonRelation.getPaimonRelation(targetTable)
-    val dataFilePathToMeta = candidateFileMap(findCandidateDataSplits(TrueLiteral, relation.output))
-    val (touchedFiles, touchedFileRelation) =
-      createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
-    val (_, unTouchedFileRelation) =
-      createNewRelation(unTouchedFilePaths, dataFilePathToMeta, relation)
+      val (touchedFiles, touchedFileRelation) =
+        createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
+      val (_, unTouchedFileRelation) =
+        createNewRelation(unTouchedFilePaths, dataFilePathToMeta, relation)
 
-    // Add FILE_TOUCHED_COL to mark the row as coming from the touched file, if the row has not been
-    // modified and was from touched file, it should be kept too.
-    val targetDSWithFileTouchedCol = createDataset(sparkSession, touchedFileRelation)
-      .withColumn(FILE_TOUCHED_COL, lit(true))
-      .union(
-        createDataset(sparkSession, unTouchedFileRelation).withColumn(FILE_TOUCHED_COL, lit(false)))
+      // Add FILE_TOUCHED_COL to mark the row as coming from the touched file, if the row has not been
+      // modified and was from touched file, it should be kept too.
+      val targetDSWithFileTouchedCol = createDataset(sparkSession, touchedFileRelation)
+        .withColumn(FILE_TOUCHED_COL, lit(true))
+        .union(createDataset(sparkSession, unTouchedFileRelation)
+          .withColumn(FILE_TOUCHED_COL, lit(false)))
 
-    val addCommitMessage =
-      writer.write(constructChangedRows(sparkSession, targetDSWithFileTouchedCol))
-    val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles)
+      val toWriteDS =
+        constructChangedRows(sparkSession, targetDSWithFileTouchedCol).drop(ROW_KIND_COL)
+      val addCommitMessage = writer.write(toWriteDS)
+      val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles)
 
-    addCommitMessage ++ deletedCommitMessage
+      addCommitMessage ++ deletedCommitMessage
+    }
   }
 
   /** Get a Dataset where each of Row has an additional column called _row_kind_. */
   private def constructChangedRows(
       sparkSession: SparkSession,
-      targetDataset: Dataset[Row]): Dataset[Row] = {
+      targetDataset: Dataset[Row],
+      remainDeletedRow: Boolean = false,
+      deletionVectorEnabled: Boolean = false,
+      metadataCols: Seq[PaimonMetadataColumn] = Seq.empty): Dataset[Row] = {
     val targetDS = targetDataset
       .withColumn(TARGET_ROW_COL, lit(true))
 
@@ -158,28 +215,35 @@ case class MergeIntoPaimonTable(
     val noopOutput = targetOutput :+ Alias(Literal(NOOP_ROW_KIND_VALUE), ROW_KIND_COL)()
     val keepOutput = targetOutput :+ Alias(Literal(RowKind.INSERT.toByteValue), ROW_KIND_COL)()
 
-    def processMergeActions(actions: Seq[MergeAction], applyOnTargetTable: Boolean) = {
-      actions.map {
-        case UpdateAction(_, assignments) if applyOnTargetTable =>
+    val resolver = sparkSession.sessionState.conf.resolver
+    val metadataAttributes = metadataCols.flatMap {
+      metadataCol => joinedPlan.output.find(attr => resolver(metadataCol.name, attr.name))
+    }
+    def processMergeActions(actions: Seq[MergeAction]): Seq[Seq[Expression]] = {
+      val columnExprs = actions.map {
+        case UpdateAction(_, assignments) =>
           assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
-        case DeleteAction(_) if applyOnTargetTable =>
-          if (withPrimaryKeys) {
+        case DeleteAction(_) =>
+          if (remainDeletedRow || deletionVectorEnabled) {
             targetOutput :+ Literal(RowKind.DELETE.toByteValue)
           } else {
+            // If RowKind = NOOP_ROW_KIND_VALUE, then these rows will be dropped in MergeIntoProcessor.processPartition by default.
+            // If these rows still need to be remained, set MergeIntoProcessor.remainNoopRow true.
             noopOutput
           }
-        case InsertAction(_, assignments) if !applyOnTargetTable =>
+        case InsertAction(_, assignments) =>
          assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
-        case _ =>
-          throw new RuntimeException("should not be here.")
       }
+      columnExprs.map(exprs => exprs ++ metadataAttributes)
     }
 
-    val matchedOutputs = processMergeActions(matchedActions, applyOnTargetTable = true)
-    val notMatchedBySourceOutputs =
-      processMergeActions(notMatchedBySourceActions, applyOnTargetTable = true)
-    val notMatchedOutputs = processMergeActions(notMatchedActions, applyOnTargetTable = false)
-    val outputSchema = StructType(tableSchema.fields :+ StructField(ROW_KIND_COL, ByteType))
+    val matchedOutputs = processMergeActions(matchedActions)
+    val notMatchedBySourceOutputs = processMergeActions(notMatchedBySourceActions)
+    val notMatchedOutputs = processMergeActions(notMatchedActions)
+    val outputFields = mutable.ArrayBuffer(tableSchema.fields: _*)
+    outputFields += StructField(ROW_KIND_COL, ByteType)
+    outputFields ++= metadataCols.map(_.toStructField)
+    val outputSchema = StructType(outputFields)
 
     val joinedRowEncoder = EncoderUtils.encode(joinedPlan.schema)
     val outputEncoder = EncoderUtils.encode(outputSchema).resolveAndBind()
 
@@ -248,9 +312,10 @@ object MergeIntoPaimonTable {
       outputRowEncoder: ExpressionEncoder[Row]
   ) extends Serializable {
 
-    private val file_touched_col_index: Int =
+    private val rowKindColumnIndex: Int = outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)
+
+    private val fileTouchedColumnIndex: Int =
       SparkRowUtils.getFieldIndex(joinedRowEncoder.schema, FILE_TOUCHED_COL)
-    private val row_kind_col_index: Int = outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)
 
     private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = {
       UnsafeProjection.create(exprs, joinedAttributes)
@@ -261,11 +326,11 @@ object MergeIntoPaimonTable {
     }
 
     private def fromTouchedFile(row: InternalRow): Boolean = {
-      file_touched_col_index != -1 && row.getBoolean(file_touched_col_index)
+      fileTouchedColumnIndex != -1 && row.getBoolean(fileTouchedColumnIndex)
     }
 
     private def unusedRow(row: InternalRow): Boolean = {
-      row.getByte(row_kind_col_index) == NOOP_ROW_KIND_VALUE
+      row.getByte(rowKindColumnIndex) == NOOP_ROW_KIND_VALUE
    }
 
     def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
@@ -284,7 +349,8 @@ object MergeIntoPaimonTable {
     def processRow(inputRow: InternalRow): InternalRow = {
       def applyPreds(preds: Seq[BasePredicate], projs: Seq[UnsafeProjection]): InternalRow = {
         preds.zip(projs).find { case (predicate, _) => predicate.eval(inputRow) } match {
-          case Some((_, projections)) => projections.apply(inputRow)
+          case Some((_, projections)) =>
+            projections.apply(inputRow)
           case None =>
             // keep the row if it is from touched file and not be matched
             if (fromTouchedFile(inputRow)) {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index 8e341a657ff8..aad4b82bd5b6 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.spark.commands
 
+import org.apache.paimon.data.BinaryRow
 import org.apache.paimon.deletionvectors.BitmapDeletionVector
 import org.apache.paimon.fs.Path
 import org.apache.paimon.index.IndexFileMeta
@@ -172,16 +173,33 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
       condition: Expression,
       relation: DataSourceV2Relation,
       sparkSession: SparkSession): Dataset[SparkDeletionVectors] = {
+    val metadataCols = Seq(FILE_PATH, ROW_INDEX)
+    val filteredRelation = createNewScanPlan(candidateDataSplits, condition, relation, metadataCols)
+    val dataWithMetadataColumns = createDataset(sparkSession, filteredRelation)
+    collectDeletionVectors(dataFilePathToMeta, dataWithMetadataColumns, sparkSession)
+  }
+
+  protected def collectDeletionVectors(
+      dataFilePathToMeta: Map[String, SparkDataFileMeta],
+      dataWithMetadataColumns: Dataset[Row],
+      sparkSession: SparkSession): Dataset[SparkDeletionVectors] = {
     import sparkSession.implicits._
+
+    val resolver = sparkSession.sessionState.conf.resolver
+    Seq(FILE_PATH_COLUMN, ROW_INDEX_COLUMN).foreach {
+      metadata =>
+        dataWithMetadataColumns.schema
+          .find(field => resolver(field.name, metadata))
+          .orElse(throw new RuntimeException(
+            "This input dataset doesn't contains the required metadata columns: __paimon_file_path and __paimon_row_index."))
+    }
+
     val dataFileToPartitionAndBucket =
       dataFilePathToMeta.mapValues(meta => (meta.partition, meta.bucket)).toArray
 
-    val metadataCols = Seq(FILE_PATH, ROW_INDEX)
-    val filteredRelation = createNewScanPlan(candidateDataSplits, condition, relation, metadataCols)
     val my_table = table
     val location = my_table.location
-    createDataset(sparkSession, filteredRelation)
+    dataWithMetadataColumns
       .select(FILE_PATH_COLUMN, ROW_INDEX_COLUMN)
       .as[(String, Long)]
       .groupByKey(_._1)
@@ -208,7 +226,7 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
     }
   }
 
-  private def createNewScanPlan(
+  protected def createNewScanPlan(
       candidateDataSplits: Seq[DataSplit],
       condition: Expression,
       relation: DataSourceV2Relation,
@@ -216,11 +234,9 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
     val metadataProj = metadataCols.map(_.toAttribute)
     val newRelation = relation.copy(output = relation.output ++ metadataProj)
     val scan = PaimonSplitScan(table, candidateDataSplits.toArray, metadataCols)
-    Project(
-      metadataProj,
-      FilterLogicalNode(
-        condition,
-        Compatibility.createDataSourceV2ScanRelation(newRelation, scan, newRelation.output)))
+    FilterLogicalNode(
+      condition,
+      Compatibility.createDataSourceV2ScanRelation(newRelation, scan, newRelation.output))
   }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
index a82454d98dcc..b539a1351dff 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.types.DataField
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.connector.catalog.MetadataColumn
-import org.apache.spark.sql.types.{DataType, LongType, StringType}
+import org.apache.spark.sql.types.{DataType, LongType, StringType, StructField}
 
 case class PaimonMetadataColumn(id: Int, override val name: String, override val dataType: DataType)
   extends MetadataColumn {
@@ -32,6 +32,10 @@ case class PaimonMetadataColumn(id: Int, override val name: String, override val
     new DataField(id, name, SparkTypeUtils.toPaimonType(dataType));
   }
 
+  def toStructField: StructField = {
+    StructField(name, dataType);
+  }
+
   def toAttribute: AttributeReference = {
     AttributeReference(name, dataType)()
   }
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeletionVectorTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeletionVectorTest.scala
index 719117bcc2de..3adccaa7c157 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeletionVectorTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeletionVectorTest.scala
@@ -35,6 +35,59 @@ class DeletionVectorTest extends PaimonSparkTestBase {
 
   import testImplicits._
 
+  bucketModes.foreach {
+    bucket =>
+      test(s"Paimon DeletionVector: merge into with bucket = $bucket") {
+        withTable("source", "target") {
+          val bucketKey = if (bucket > 1) {
+            ", 'bucket-key' = 'a'"
+          } else {
+            ""
+          }
+          Seq((1, 100, "c11"), (3, 300, "c33"), (5, 500, "c55"), (7, 700, "c77"), (9, 900, "c99"))
+            .toDF("a", "b", "c")
+            .createOrReplaceTempView("source")
+
+          spark.sql(
+            s"""
+               |CREATE TABLE target (a INT, b INT, c STRING)
+               |TBLPROPERTIES ('deletion-vectors.enabled' = 'true', 'bucket' = '$bucket' $bucketKey)
+               |""".stripMargin)
+          spark.sql(
+            "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+          val table = loadTable("target")
+          val dvMaintainerFactory =
+            new DeletionVectorsMaintainer.Factory(table.store().newIndexFileHandler())
+          spark.sql(s"""
+                       |MERGE INTO target
+                       |USING source
+                       |ON target.a = source.a
+                       |WHEN MATCHED AND target.a = 5 THEN
+                       |UPDATE SET b = source.b + target.b
+                       |WHEN MATCHED AND source.c > 'c2' THEN
+                       |UPDATE SET *
+                       |WHEN MATCHED THEN
+                       |DELETE
+                       |WHEN NOT MATCHED AND c > 'c9' THEN
+                       |INSERT (a, b, c) VALUES (a, b * 1.1, c)
+                       |WHEN NOT MATCHED THEN
+                       |INSERT *
+                       |""".stripMargin)
+
+          checkAnswer(
+            spark.sql("SELECT * FROM target ORDER BY a, b"),
+            Row(2, 20, "c2") :: Row(3, 300, "c33") :: Row(4, 40, "c4") :: Row(5, 550, "c5") :: Row(
+              7,
+              700,
+              "c77") :: Row(9, 990, "c99") :: Nil
+          )
+          val deletionVectors = getAllLatestDeletionVectors(table, dvMaintainerFactory)
+          Assertions.assertTrue(deletionVectors.nonEmpty)
+        }
+      }
+  }
+
   bucketModes.foreach {
     bucket =>
       test(