diff --git a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala index 1a7dffaf1257..f1f0d8c06567 100644 --- a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala +++ b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala @@ -18,12 +18,22 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} +import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} class MergeIntoPrimaryKeyBucketedTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with PaimonPrimaryKeyBucketedTableTest {} class MergeIntoPrimaryKeyNonBucketTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with PaimonPrimaryKeyNonBucketTableTest {} + +class MergeIntoAppendBucketedTableTest + extends MergeIntoTableTestBase + with PaimonAppendBucketedTableTest {} + +class MergeIntoAppendNonBucketedTableTest + extends MergeIntoTableTestBase + with PaimonAppendNonBucketTableTest {} diff --git a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala index 1a7dffaf1257..f1f0d8c06567 100644 --- a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala +++ b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala @@ -18,12 +18,22 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} +import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} class MergeIntoPrimaryKeyBucketedTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with PaimonPrimaryKeyBucketedTableTest {} class MergeIntoPrimaryKeyNonBucketTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with PaimonPrimaryKeyNonBucketTableTest {} + +class MergeIntoAppendBucketedTableTest + extends MergeIntoTableTestBase + with PaimonAppendBucketedTableTest {} + +class MergeIntoAppendNonBucketedTableTest + extends MergeIntoTableTestBase + with PaimonAppendNonBucketTableTest {} diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala index 13b79e744e8f..e1cfe3a3960f 100644 --- a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala +++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala @@ -18,14 +18,26 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} +import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} class MergeIntoPrimaryKeyBucketedTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with 
MergeIntoNotMatchedBySourceTest with PaimonPrimaryKeyBucketedTableTest {} class MergeIntoPrimaryKeyNonBucketTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with MergeIntoNotMatchedBySourceTest with PaimonPrimaryKeyNonBucketTableTest {} + +class MergeIntoAppendBucketedTableTest + extends MergeIntoTableTestBase + with MergeIntoNotMatchedBySourceTest + with PaimonAppendBucketedTableTest {} + +class MergeIntoAppendNonBucketedTableTest + extends MergeIntoTableTestBase + with MergeIntoNotMatchedBySourceTest + with PaimonAppendNonBucketTableTest {} diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala index 13b79e744e8f..e1cfe3a3960f 100644 --- a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala +++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala @@ -18,14 +18,26 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} +import org.apache.paimon.spark.{PaimonAppendBucketedTableTest, PaimonAppendNonBucketTableTest, PaimonPrimaryKeyBucketedTableTest, PaimonPrimaryKeyNonBucketTableTest} class MergeIntoPrimaryKeyBucketedTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with MergeIntoNotMatchedBySourceTest with PaimonPrimaryKeyBucketedTableTest {} class MergeIntoPrimaryKeyNonBucketTableTest extends MergeIntoTableTestBase + with MergeIntoPrimaryKeyTableTest with MergeIntoNotMatchedBySourceTest with PaimonPrimaryKeyNonBucketTableTest {} + +class MergeIntoAppendBucketedTableTest + extends MergeIntoTableTestBase + with MergeIntoNotMatchedBySourceTest + with PaimonAppendBucketedTableTest {} + +class MergeIntoAppendNonBucketedTableTest + extends MergeIntoTableTestBase + with MergeIntoNotMatchedBySourceTest + with PaimonAppendNonBucketTableTest {} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala index c07b58399883..ba6108395a7c 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonMergeIntoBase.scala @@ -18,7 +18,6 @@ package org.apache.paimon.spark.catalyst.analysis -import org.apache.paimon.CoreOptions import org.apache.paimon.spark.SparkTable import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper import org.apache.paimon.spark.commands.MergeIntoPaimonTable @@ -28,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeS import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import scala.collection.JavaConverters._ + trait PaimonMergeIntoBase extends Rule[LogicalPlan] with RowLevelHelper @@ -52,13 +53,14 @@ trait PaimonMergeIntoBase merge.notMatchedActions.flatMap(_.condition).foreach(checkCondition) val updateActions = merge.matchedActions.collect { case a: UpdateAction => a } - val primaryKeys = v2Table.properties().get(CoreOptions.PRIMARY_KEY.key).split(",") - checkUpdateActionValidity( - AttributeSet(targetOutput), - merge.mergeCondition, - 
updateActions, - primaryKeys) - + val primaryKeys = v2Table.getTable.primaryKeys().asScala + if (primaryKeys.nonEmpty) { + checkUpdateActionValidity( + AttributeSet(targetOutput), + merge.mergeCondition, + updateActions, + primaryKeys) + } val alignedMatchedActions = merge.matchedActions.map(checkAndAlignActionAssignment(_, targetOutput)) val alignedNotMatchedActions = diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala index 41881b7b7398..3e1e2b52d296 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/RowLevelOp.scala @@ -70,6 +70,6 @@ case object MergeInto extends RowLevelOp { override val supportedMergeEngine: Seq[MergeEngine] = Seq(MergeEngine.DEDUPLICATE, MergeEngine.PARTIAL_UPDATE) - override val supportAppendOnlyTable: Boolean = false + override val supportAppendOnlyTable: Boolean = true } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala index 2aef8e576410..cc440dd5c16b 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/DeleteFromPaimonTableCommand.scala @@ -18,15 +18,12 @@ package org.apache.paimon.spark.commands -import org.apache.paimon.CoreOptions import org.apache.paimon.CoreOptions.MergeEngine -import org.apache.paimon.spark.PaimonSplitScan -import org.apache.paimon.spark.catalyst.Compatibility import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand import org.apache.paimon.spark.schema.SparkSystemColumns.ROW_KIND_COL import org.apache.paimon.spark.util.SQLHelper -import org.apache.paimon.table.{BucketMode, FileStoreTable} +import org.apache.paimon.table.FileStoreTable import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage} import org.apache.paimon.types.RowKind import org.apache.paimon.utils.InternalRowPartitionComputer @@ -144,20 +141,11 @@ case class DeleteFromPaimonTableCommand( findTouchedFiles(candidateDataSplits, condition, relation, sparkSession) // Step3: the smallest range of data files that need to be rewritten. - val touchedFiles = touchedFilePaths.map { - file => - dataFilePathToMeta.getOrElse(file, throw new RuntimeException(s"Missing file: $file")) - } + val (touchedFiles, newRelation) = + createNewRelation(touchedFilePaths, dataFilePathToMeta, relation) // Step4: build a dataframe that contains the unchanged data, and write out them. 
- val touchedDataSplits = - SparkDataFileMeta.convertToDataSplits(touchedFiles, rawConvertible = true, pathFactory) - val toRewriteScanRelation = Filter( - Not(condition), - Compatibility.createDataSourceV2ScanRelation( - relation, - PaimonSplitScan(table, touchedDataSplits), - relation.output)) + val toRewriteScanRelation = Filter(Not(condition), newRelation) val data = createDataset(sparkSession, toRewriteScanRelation) // only write new files, should have no compaction diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala index a06bc437dfcd..5fec8b99751f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala @@ -18,12 +18,13 @@ package org.apache.paimon.spark.commands -import org.apache.paimon.options.Options -import org.apache.paimon.spark.{InsertInto, SparkTable} +import org.apache.paimon.spark.SparkTable +import org.apache.paimon.spark.catalyst.analysis.PaimonRelation import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand import org.apache.paimon.spark.schema.SparkSystemColumns -import org.apache.paimon.spark.util.EncoderUtils +import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils} import org.apache.paimon.table.FileStoreTable +import org.apache.paimon.table.sink.CommitMessage import org.apache.paimon.types.RowKind import org.apache.spark.sql.{Column, Dataset, Row, SparkSession} @@ -33,10 +34,12 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, Literal, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, Filter, InsertAction, LogicalPlan, MergeAction, UpdateAction} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, sum} import org.apache.spark.sql.types.{ByteType, StructField, StructType} +import scala.collection.mutable + /** Command for Merge Into. */ case class MergeIntoPaimonTable( v2Table: SparkTable, @@ -55,7 +58,9 @@ case class MergeIntoPaimonTable( lazy val tableSchema: StructType = v2Table.schema - lazy val filteredTargetPlan: LogicalPlan = { + private lazy val writer = PaimonSparkWriter(table) + + private lazy val filteredTargetPlan: LogicalPlan = { val filtersOnlyTarget = getExpressionOnlyRelated(mergeCondition, targetTable) filtersOnlyTarget .map(Filter.apply(_, targetTable)) @@ -63,25 +68,75 @@ case class MergeIntoPaimonTable( } override def run(sparkSession: SparkSession): Seq[Row] = { - // Avoid that more than one source rows match the same target row. 
checkMatchRationality(sparkSession) + val commitMessages = if (withPrimaryKeys) { + performMergeForPkTable(sparkSession) + } else { + performMergeForNonPkTable(sparkSession) + } + writer.commit(commitMessages) + Seq.empty[Row] + } - val changed = constructChangedRows(sparkSession) + private def performMergeForPkTable(sparkSession: SparkSession): Seq[CommitMessage] = { + writer.write( + constructChangedRows(sparkSession, createDataset(sparkSession, filteredTargetPlan))) + } - WriteIntoPaimonTable( - table, - InsertInto, - changed, - new Options() - ).run(sparkSession) + private def performMergeForNonPkTable(sparkSession: SparkSession): Seq[CommitMessage] = { + val targetDS = createDataset(sparkSession, filteredTargetPlan) + val sourceDS = createDataset(sparkSession, sourceTable) - Seq.empty[Row] + val targetFilePaths: Array[String] = findTouchedFiles(targetDS, sparkSession) + + val touchedFilePathsSet = mutable.Set.empty[String] + def hasUpdate(actions: Seq[MergeAction]): Boolean = { + actions.exists { + case _: UpdateAction | _: DeleteAction => true + case _ => false + } + } + if (hasUpdate(matchedActions)) { + touchedFilePathsSet ++= findTouchedFiles( + targetDS.join(sourceDS, new Column(mergeCondition), "inner"), + sparkSession) + } + if (hasUpdate(notMatchedBySourceActions)) { + touchedFilePathsSet ++= findTouchedFiles( + targetDS.join(sourceDS, new Column(mergeCondition), "left_anti"), + sparkSession) + } + + val touchedFilePaths: Array[String] = touchedFilePathsSet.toArray + val unTouchedFilePaths = targetFilePaths.filterNot(touchedFilePaths.contains) + + val relation = PaimonRelation.getPaimonRelation(targetTable) + val dataFilePathToMeta = candidateFileMap(findCandidateDataSplits(TrueLiteral, relation.output)) + val (touchedFiles, touchedFileRelation) = + createNewRelation(touchedFilePaths, dataFilePathToMeta, relation) + val (_, unTouchedFileRelation) = + createNewRelation(unTouchedFilePaths, dataFilePathToMeta, relation) + + // Add FILE_TOUCHED_COL to mark the row as coming from the touched file, if the row has not been + // modified and was from touched file, it should be kept too. + val targetDSWithFileTouchedCol = createDataset(sparkSession, touchedFileRelation) + .withColumn(FILE_TOUCHED_COL, lit(true)) + .union( + createDataset(sparkSession, unTouchedFileRelation).withColumn(FILE_TOUCHED_COL, lit(false))) + + val addCommitMessage = + writer.write(constructChangedRows(sparkSession, targetDSWithFileTouchedCol)) + val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles) + + addCommitMessage ++ deletedCommitMessage } /** Get a Dataset where each of Row has an additional column called _row_kind_. 
*/ - private def constructChangedRows(sparkSession: SparkSession): Dataset[Row] = { - val targetDS = createDataset(sparkSession, filteredTargetPlan) + private def constructChangedRows( + sparkSession: SparkSession, + targetDataset: Dataset[Row]): Dataset[Row] = { + val targetDS = targetDataset .withColumn(TARGET_ROW_COL, lit(true)) val sourceDS = createDataset(sparkSession, sourceTable) @@ -100,29 +155,30 @@ case class MergeIntoPaimonTable( val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral)) val notMatchedExprs = notMatchedActions.map(_.condition.getOrElse(TrueLiteral)) val notMatchedBySourceExprs = notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral)) - val matchedOutputs = matchedActions.map { - case UpdateAction(_, assignments) => - assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue) - case DeleteAction(_) => - targetOutput :+ Literal(RowKind.DELETE.toByteValue) - case _ => - throw new RuntimeException("should not be here.") - } - val notMatchedBySourceOutputs = notMatchedBySourceActions.map { - case UpdateAction(_, assignments) => - assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue) - case DeleteAction(_) => - targetOutput :+ Literal(RowKind.DELETE.toByteValue) - case _ => - throw new RuntimeException("should not be here.") - } - val notMatchedOutputs = notMatchedActions.map { - case InsertAction(_, assignments) => - assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue) - case _ => - throw new RuntimeException("should not be here.") - } val noopOutput = targetOutput :+ Alias(Literal(NOOP_ROW_KIND_VALUE), ROW_KIND_COL)() + val keepOutput = targetOutput :+ Alias(Literal(RowKind.INSERT.toByteValue), ROW_KIND_COL)() + + def processMergeActions(actions: Seq[MergeAction], applyOnTargetTable: Boolean) = { + actions.map { + case UpdateAction(_, assignments) if applyOnTargetTable => + assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue) + case DeleteAction(_) if applyOnTargetTable => + if (withPrimaryKeys) { + targetOutput :+ Literal(RowKind.DELETE.toByteValue) + } else { + noopOutput + } + case InsertAction(_, assignments) if !applyOnTargetTable => + assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue) + case _ => + throw new RuntimeException("should not be here.") + } + } + + val matchedOutputs = processMergeActions(matchedActions, applyOnTargetTable = true) + val notMatchedBySourceOutputs = + processMergeActions(notMatchedBySourceActions, applyOnTargetTable = true) + val notMatchedOutputs = processMergeActions(notMatchedActions, applyOnTargetTable = false) val outputSchema = StructType(tableSchema.fields :+ StructField(ROW_KIND_COL, ByteType)) val joinedRowEncoder = EncoderUtils.encode(joinedPlan.schema) @@ -139,10 +195,11 @@ case class MergeIntoPaimonTable( notMatchedExprs, notMatchedOutputs, noopOutput, + keepOutput, joinedRowEncoder, outputEncoder ) - joinedDS.mapPartitions(processor.processPartition)(outputEncoder) + joinedDS.mapPartitions(processor.processPartition)(outputEncoder).toDF() } private def checkMatchRationality(sparkSession: SparkSession): Unit = { @@ -159,21 +216,23 @@ case class MergeIntoPaimonTable( .count() if (count > 0) { throw new RuntimeException( - "Can't execute this MergeInto when there are some target rows that each of them match more then one source rows. It may lead to an unexpected result.") + "Can't execute this MergeInto when there are some target rows that each of " + + "them match more then one source rows. 
It may lead to an unexpected result.") } } } } object MergeIntoPaimonTable { - val ROW_ID_COL = "_row_id_" - val SOURCE_ROW_COL = "_source_row_" - val TARGET_ROW_COL = "_target_row_" + private val ROW_ID_COL = "_row_id_" + private val SOURCE_ROW_COL = "_source_row_" + private val TARGET_ROW_COL = "_target_row_" + private val FILE_TOUCHED_COL = "_file_touched_col_" // +I, +U, -U, -D - val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL - val NOOP_ROW_KIND_VALUE: Byte = "-1".toByte + private val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL + private val NOOP_ROW_KIND_VALUE: Byte = "-1".toByte - case class MergeIntoProcessor( + private case class MergeIntoProcessor( joinedAttributes: Seq[Attribute], targetRowHasNoMatch: Expression, sourceRowHasNoMatch: Expression, @@ -184,10 +243,15 @@ object MergeIntoPaimonTable { notMatchedConditions: Seq[Expression], notMatchedOutputs: Seq[Seq[Expression]], noopCopyOutput: Seq[Expression], + keepOutput: Seq[Expression], joinedRowEncoder: ExpressionEncoder[Row], outputRowEncoder: ExpressionEncoder[Row] ) extends Serializable { + private val file_touched_col_index: Int = + SparkRowUtils.getFieldIndex(joinedRowEncoder.schema, FILE_TOUCHED_COL) + private val row_kind_col_index: Int = outputRowEncoder.schema.fieldIndex(ROW_KIND_COL) + private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { UnsafeProjection.create(exprs, joinedAttributes) } @@ -196,8 +260,12 @@ object MergeIntoPaimonTable { GeneratePredicate.generate(expr, joinedAttributes) } + private def fromTouchedFile(row: InternalRow): Boolean = { + file_touched_col_index != -1 && row.getBoolean(file_touched_col_index) + } + private def unusedRow(row: InternalRow): Boolean = { - row.getByte(outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)) == NOOP_ROW_KIND_VALUE + row.getByte(row_kind_col_index) == NOOP_ROW_KIND_VALUE } def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = { @@ -210,38 +278,29 @@ object MergeIntoPaimonTable { val notMatchedPreds = notMatchedConditions.map(generatePredicate) val notMatchedProjs = notMatchedOutputs.map(generateProjection) val noopCopyProj = generateProjection(noopCopyOutput) + val keepProj = generateProjection(keepOutput) val outputProj = UnsafeProjection.create(outputRowEncoder.schema) def processRow(inputRow: InternalRow): InternalRow = { - if (targetRowHasNoMatchPred.eval(inputRow)) { - val pair = notMatchedBySourcePreds.zip(notMatchedBySourceProjs).find { - case (predicate, _) => predicate.eval(inputRow) + def applyPreds(preds: Seq[BasePredicate], projs: Seq[UnsafeProjection]): InternalRow = { + preds.zip(projs).find { case (predicate, _) => predicate.eval(inputRow) } match { + case Some((_, projections)) => projections.apply(inputRow) + case None => + // keep the row if it is from touched file and not be matched + if (fromTouchedFile(inputRow)) { + keepProj.apply(inputRow) + } else { + noopCopyProj.apply(inputRow) + } } + } - pair match { - case Some((_, projections)) => - projections.apply(inputRow) - case None => noopCopyProj.apply(inputRow) - } + if (targetRowHasNoMatchPred.eval(inputRow)) { + applyPreds(notMatchedBySourcePreds, notMatchedBySourceProjs) } else if (sourceRowHasNoMatchPred.eval(inputRow)) { - val pair = notMatchedPreds.zip(notMatchedProjs).find { - case (predicate, _) => predicate.eval(inputRow) - } - - pair match { - case Some((_, projections)) => - projections.apply(inputRow) - case None => noopCopyProj.apply(inputRow) - } + applyPreds(notMatchedPreds, notMatchedProjs) } else { - val pair = - 
matchedPreds.zip(matchedProjs).find { case (predicate, _) => predicate.eval(inputRow) } - - pair match { - case Some((_, projections)) => - projections.apply(inputRow) - case None => noopCopyProj.apply(inputRow) - } + applyPreds(matchedPreds, matchedProjs) } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala index 4a42e4f46077..8e341a657ff8 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala @@ -33,12 +33,12 @@ import org.apache.paimon.table.source.DataSplit import org.apache.paimon.types.RowType import org.apache.paimon.utils.SerializationUtils -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.PaimonUtils.createDataset import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Filter => FilterLogicalNode, LogicalPlan, Project} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter} import java.net.URI @@ -101,7 +101,7 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper { output: Seq[Attribute]): Seq[DataSplit] = { // low level snapshot reader, it can not be affected by 'scan.mode' val snapshotReader = table.newSnapshotReader() - if (condition == TrueLiteral) { + if (condition != TrueLiteral) { val filter = convertConditionToPaimonPredicate(condition, output, rowType, ignoreFailure = true) filter.foreach(snapshotReader.withFilter) @@ -115,8 +115,6 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper { condition: Expression, relation: DataSourceV2Relation, sparkSession: SparkSession): Array[String] = { - import sparkSession.implicits._ - for (split <- candidateDataSplits) { if (!split.rawConvertible()) { throw new IllegalArgumentException( @@ -126,7 +124,14 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper { val metadataCols = Seq(FILE_PATH) val filteredRelation = createNewScanPlan(candidateDataSplits, condition, relation, metadataCols) - createDataset(sparkSession, filteredRelation) + findTouchedFiles(createDataset(sparkSession, filteredRelation), sparkSession) + } + + protected def findTouchedFiles( + dataset: Dataset[Row], + sparkSession: SparkSession): Array[String] = { + import sparkSession.implicits._ + dataset .select(FILE_PATH_COLUMN) .distinct() .as[String] @@ -134,6 +139,21 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper { .map(relativePath) } + protected def createNewRelation( + filePaths: Array[String], + filePathToMeta: Map[String, SparkDataFileMeta], + relation: DataSourceV2Relation): (Array[SparkDataFileMeta], DataSourceV2ScanRelation) = { + val files = filePaths.map( + file => filePathToMeta.getOrElse(file, throw new RuntimeException(s"Missing file: $file"))) + val touchedDataSplits = + SparkDataFileMeta.convertToDataSplits(files, rawConvertible = true, fileStore.pathFactory()) + val newRelation = Compatibility.createDataSourceV2ScanRelation( + 
relation, + PaimonSplitScan(table, touchedDataSplits), + relation.output) + (files, newRelation) + } + /** Notice that, the key is a relative path, not just the file name. */ protected def candidateFileMap( candidateDataSplits: Seq[DataSplit]): Map[String, SparkDataFileMeta] = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala index 6c7d07bf5e26..dd88f388cb63 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.PaimonUtils.createDataset import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, If} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Filter, Project, SupportsSubquery} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.functions.lit case class UpdatePaimonTableCommand( @@ -116,15 +116,11 @@ case class UpdatePaimonTableCommand( findTouchedFiles(candidateDataSplits, condition, relation, sparkSession) // Step3: the smallest range of data files that need to be rewritten. - val touchedFiles = touchedFilePaths.map { - file => - dataFilePathToMeta.getOrElse(file, throw new RuntimeException(s"Missing file: $file")) - } + val (touchedFiles, touchedFileRelation) = + createNewRelation(touchedFilePaths, dataFilePathToMeta, relation) // Step4: build a dataframe that contains the unchanged and updated data, and write out them. - val touchedDataSplits = - SparkDataFileMeta.convertToDataSplits(touchedFiles, rawConvertible = true, pathFactory) - val addCommitMessage = writeUpdatedAndUnchangedData(sparkSession, touchedDataSplits) + val addCommitMessage = writeUpdatedAndUnchangedData(sparkSession, touchedFileRelation) // Step5: convert the deleted files that need to be wrote to commit message. 
val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles) @@ -157,7 +153,7 @@ case class UpdatePaimonTableCommand( private def writeUpdatedAndUnchangedData( sparkSession: SparkSession, - touchedDataSplits: Array[DataSplit]): Seq[CommitMessage] = { + toUpdateScanRelation: DataSourceV2ScanRelation): Seq[CommitMessage] = { val updateColumns = updateExpressions.zip(relation.output).map { case (update, origin) => val updated = if (condition == TrueLiteral) { @@ -168,10 +164,6 @@ case class UpdatePaimonTableCommand( new Column(updated).as(origin.name, origin.metadata) } - val toUpdateScanRelation = Compatibility.createDataSourceV2ScanRelation( - relation, - PaimonSplitScan(table, touchedDataSplits), - relation.output) val data = createDataset(sparkSession, toUpdateScanRelation).select(updateColumns: _*) writer.write(data) } diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala index 53f41833f7a7..0477bcbafaa6 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonTableTest.scala @@ -20,24 +20,39 @@ package org.apache.paimon.spark import org.apache.spark.sql.test.SharedSparkSession -import scala.collection.mutable - trait PaimonTableTest extends SharedSparkSession { val bucket: Int - def appendPrimaryKey(primaryKeys: Seq[String], props: mutable.Map[String, String]): Unit - + def initProps(primaryOrBucketKeys: Seq[String], partitionKeys: Seq[String]): Map[String, String] + + /** + * Create a table configured by the given parameters. + * + * @param tableName + * table name + * @param columns + * columns string, e.g. "a INT, b INT, c STRING" + * @param primaryOrBucketKeys + * for [[PaimonPrimaryKeyTable]] they are `primary-key`, if you want to specify additional + * `bucket-key`, you can specify that in extraProps. 
for [[PaimonAppendTable]] they are + * `bucket-key` + * @param partitionKeys + * partition keys seq + * @param extraProps + * extra properties map + */ def createTable( tableName: String, columns: String, - primaryKeys: Seq[String], + primaryOrBucketKeys: Seq[String], partitionKeys: Seq[String] = Seq.empty, - props: Map[String, String] = Map.empty): Unit = { - val newProps: mutable.Map[String, String] = - mutable.Map.empty[String, String] ++ Map("bucket" -> bucket.toString) ++ props - appendPrimaryKey(primaryKeys, newProps) - createTable0(tableName, columns, partitionKeys, newProps.toMap) + extraProps: Map[String, String] = Map.empty): Unit = { + createTable0( + tableName, + columns, + partitionKeys, + initProps(primaryOrBucketKeys, partitionKeys) ++ extraProps) } private def createTable0( @@ -72,35 +87,35 @@ trait PaimonNonBucketedTable { val bucket: Int = -1 } -trait PaimonPrimaryKeyTable { - def appendPrimaryKey(primaryKeys: Seq[String], props: mutable.Map[String, String]): Unit = { - assert(primaryKeys.nonEmpty) - props += ("primary-key" -> primaryKeys.mkString(",")) +trait PaimonPrimaryKeyTable extends PaimonTableTest { + def initProps( + primaryOrBucketKeys: Seq[String], + partitionKeys: Seq[String]): Map[String, String] = { + assert(primaryOrBucketKeys.nonEmpty) + Map("primary-key" -> primaryOrBucketKeys.mkString(","), "bucket" -> bucket.toString) } } -trait PaimonAppendTable { - def appendPrimaryKey(primaryKeys: Seq[String], props: mutable.Map[String, String]): Unit = { - // nothing to do +trait PaimonAppendTable extends PaimonTableTest { + def initProps( + primaryOrBucketKeys: Seq[String], + partitionKeys: Seq[String]): Map[String, String] = { + if (bucket == -1) { + // Ignore bucket keys for unaware bucket table + Map("bucket" -> bucket.toString) + } else { + // Filter partition keys in bucket keys for fixed bucket table + val bucketKeys = primaryOrBucketKeys.filterNot(partitionKeys.contains(_)) + assert(bucketKeys.nonEmpty) + Map("bucket-key" -> bucketKeys.mkString(","), "bucket" -> bucket.toString) + } } } -trait PaimonPrimaryKeyBucketedTableTest - extends PaimonTableTest - with PaimonPrimaryKeyTable - with PaimonBucketedTable +trait PaimonPrimaryKeyBucketedTableTest extends PaimonPrimaryKeyTable with PaimonBucketedTable -trait PaimonPrimaryKeyNonBucketTableTest - extends PaimonTableTest - with PaimonPrimaryKeyTable - with PaimonNonBucketedTable +trait PaimonPrimaryKeyNonBucketTableTest extends PaimonPrimaryKeyTable with PaimonNonBucketedTable -trait PaimonAppendBucketedTableTest - extends PaimonTableTest - with PaimonAppendTable - with PaimonBucketedTable +trait PaimonAppendBucketedTableTest extends PaimonAppendTable with PaimonBucketedTable -trait PaimonAppendNonBucketTableTest - extends PaimonTableTest - with PaimonAppendTable - with PaimonNonBucketedTable +trait PaimonAppendNonBucketTableTest extends PaimonAppendTable with PaimonNonBucketedTable diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala index 65670ebd8db3..1a4eae51d007 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala @@ -18,7 +18,7 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonSparkTestBase, PaimonTableTest} +import 
org.apache.paimon.spark.{PaimonPrimaryKeyTable, PaimonSparkTestBase, PaimonTableTest} import org.apache.spark.sql.Row @@ -497,6 +497,30 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab Row(1, 10, Row("x1", "y")) :: Row(2, 20, Row("x", "y")) :: Nil) } } + test(s"Paimon MergeInto: update on source eq target condition") { + withTable("source", "target") { + Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source") + + createTable("target", "a INT, b INT, c STRING", Seq("a")) + sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')") + + sql(s""" + |MERGE INTO target + |USING source + |ON source.a = target.a + |WHEN MATCHED THEN + |UPDATE SET a = source.a, b = source.b, c = source.c + |""".stripMargin) + + checkAnswer( + sql("SELECT * FROM target ORDER BY a, b"), + Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil) + } + } +} + +trait MergeIntoPrimaryKeyTableTest extends PaimonSparkTestBase with PaimonPrimaryKeyTable { + import testImplicits._ test("Paimon MergeInto: fail in case that maybe update primary key column") { withTable("source", "target") { @@ -535,50 +559,4 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab Row(1, 10, "c111") :: Row(2, 20, "c2") :: Row(103, 30, "c333") :: Nil) } } - - test("Paimon MergeInto: not support in table without primary keys") { - withTable("source", "target") { - - Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source") - - spark.sql(s""" - |CREATE TABLE target (a INT, b INT, c STRING) - |""".stripMargin) - spark.sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')") - - val error = intercept[RuntimeException] { - spark.sql(s""" - |MERGE INTO target - |USING source - |ON target.a = source.a - |WHEN MATCHED THEN - |UPDATE SET a = source.a, b = source.b, c = source.c - |WHEN NOT MATCHED - |THEN INSERT (a, b, c) values (a, b, c) - |""".stripMargin) - }.getMessage - assert(error.contains("Only support to MergeInto table with primary keys.")) - } - } - - test(s"Paimon MergeInto: update on source eq target condition") { - withTable("source", "target") { - Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source") - - createTable("target", "a INT, b INT, c STRING", Seq("a")) - sql("INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2')") - - sql(s""" - |MERGE INTO target - |USING source - |ON source.a = target.a - |WHEN MATCHED THEN - |UPDATE SET a = source.a, b = source.b, c = source.c - |""".stripMargin) - - checkAnswer( - sql("SELECT * FROM target ORDER BY a, b"), - Row(1, 100, "c11") :: Row(2, 20, "c2") :: Nil) - } - } }
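A minimal end-to-end sketch of what this patch enables, mirroring the test setup above. Assumptions are marked in the comments: a SparkSession named `spark` with the Paimon Spark catalog configured as the current catalog; the table name, temp view, and property values are illustrative only. Before this change, MERGE INTO on a table without primary keys failed with "Only support to MergeInto table with primary keys."; with it, the command rewrites the touched data files of the append-only table rather than emitting -U/+U/-D row kinds.

    // Sketch only: assumes `spark` uses the Paimon catalog; names/properties are illustrative.
    import org.apache.spark.sql.SparkSession

    object MergeIntoAppendTableExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("merge-into-append-table")
          .getOrCreate()
        import spark.implicits._

        // An append-only (non-primary-key) table: only a bucket key is declared.
        spark.sql(
          """CREATE TABLE target (a INT, b INT, c STRING)
            |TBLPROPERTIES ('bucket-key' = 'a', 'bucket' = '2')""".stripMargin)
        spark.sql("INSERT INTO target VALUES (1, 10, 'c1'), (2, 20, 'c2')")

        // Source rows registered as a temp view, as in the tests.
        Seq((1, 100, "c11"), (3, 300, "c33")).toDF("a", "b", "c").createOrReplaceTempView("source")

        // MERGE INTO now works on the append table: matched rows are updated by
        // rewriting the touched files, unmatched source rows are inserted.
        spark.sql(
          """MERGE INTO target
            |USING source
            |ON target.a = source.a
            |WHEN MATCHED THEN UPDATE SET a = source.a, b = source.b, c = source.c
            |WHEN NOT MATCHED THEN INSERT (a, b, c) VALUES (a, b, c)
            |""".stripMargin)

        spark.sql("SELECT * FROM target ORDER BY a, b").show()
        // Expected, as asserted in MergeIntoTableTestBase:
        // (1, 100, c11), (2, 20, c2), (3, 300, c33)
        spark.stop()
      }
    }

Design note: for the non-primary-key path, MergeIntoPaimonTable first narrows the rewrite to the files actually touched by the matched / not-matched-by-source actions, tags their rows with FILE_TOUCHED_COL so unmodified rows from those files are kept, writes the rewritten data, and commits the removal of the old files via buildDeletedCommitMessage.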