Commit 48394ce

v1

Zouxxyy committed Aug 7, 2024
1 parent b194e19 commit 48394ce
Showing 8 changed files with 884 additions and 702 deletions.
@@ -18,7 +18,6 @@

package org.apache.paimon.spark.catalyst.analysis

import org.apache.paimon.CoreOptions
import org.apache.paimon.spark.SparkTable
import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper
import org.apache.paimon.spark.commands.MergeIntoPaimonTable
@@ -28,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeS
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule

import scala.collection.JavaConverters._

trait PaimonMergeIntoBase
extends Rule[LogicalPlan]
with RowLevelHelper
@@ -52,13 +53,14 @@ trait PaimonMergeIntoBase
merge.notMatchedActions.flatMap(_.condition).foreach(checkCondition)

val updateActions = merge.matchedActions.collect { case a: UpdateAction => a }
val primaryKeys = v2Table.properties().get(CoreOptions.PRIMARY_KEY.key).split(",")
checkUpdateActionValidity(
AttributeSet(targetOutput),
merge.mergeCondition,
updateActions,
primaryKeys)

val primaryKeys = v2Table.getTable.primaryKeys().asScala
if (primaryKeys.nonEmpty) {
checkUpdateActionValidity(
AttributeSet(targetOutput),
merge.mergeCondition,
updateActions,
primaryKeys)
}
val alignedMatchedActions =
merge.matchedActions.map(checkAndAlignActionAssignment(_, targetOutput))
val alignedNotMatchedActions =
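The hunk above stops reading primary keys from the "primary-key" table property and asks the table API instead, skipping the update-action validity check when no primary keys are declared. Below is a minimal, self-contained sketch of why that matters for append-only tables; the object name and the plain Java collections are illustrative, not part of this commit.

import java.util.Collections

import scala.collection.JavaConverters._

object PrimaryKeyLookupSketch {
  def main(args: Array[String]): Unit = {
    // Append-only tables declare no primary key, so the "primary-key" property is absent.
    val appendOnlyProps: java.util.Map[String, String] = Collections.emptyMap()
    println(s"property present: ${appendOnlyProps.containsKey("primary-key")}")
    // Old style lookup: appendOnlyProps.get("primary-key").split(",")
    // get(...) returns null here, so split would presumably throw a NullPointerException.

    // New style: ask the table for its primary keys; an empty list simply skips the check.
    val primaryKeys = Collections.emptyList[String]().asScala
    if (primaryKeys.nonEmpty) {
      println(s"validate update actions against keys: ${primaryKeys.mkString(",")}")
    } else {
      println("append-only table: no primary-key based validation")
    }
  }
}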
@@ -70,6 +70,6 @@ case object MergeInto extends RowLevelOp {
override val supportedMergeEngine: Seq[MergeEngine] =
Seq(MergeEngine.DEDUPLICATE, MergeEngine.PARTIAL_UPDATE)

override val supportAppendOnlyTable: Boolean = false
override val supportAppendOnlyTable: Boolean = true

}
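With supportAppendOnlyTable flipped to true, the MergeInto row-level-operation pre-check no longer rejects append-only (primary-key-less) Paimon tables. A hedged usage sketch follows; the table names, the catalog setup and the assumption that the Paimon Spark connector is on the classpath are illustrative, not taken from this commit.

import org.apache.spark.sql.SparkSession

object MergeIntoAppendOnlySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("merge-into-append-only-sketch")
      .getOrCreate()

    // Append-only target: note that no primary key is declared.
    spark.sql("CREATE TABLE IF NOT EXISTS target (id INT, name STRING) USING paimon")
    spark.sql("CREATE TABLE IF NOT EXISTS source (id INT, name STRING) USING paimon")

    // Previously rejected for append-only tables (supportAppendOnlyTable = false),
    // this statement should now pass the MergeInto pre-check.
    spark.sql(
      """MERGE INTO target t
        |USING source s
        |ON t.id = s.id
        |WHEN MATCHED THEN UPDATE SET t.name = s.name
        |WHEN NOT MATCHED THEN INSERT *
        |""".stripMargin)

    spark.stop()
  }
}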
@@ -18,15 +18,12 @@

package org.apache.paimon.spark.commands

import org.apache.paimon.CoreOptions
import org.apache.paimon.CoreOptions.MergeEngine
import org.apache.paimon.spark.PaimonSplitScan
import org.apache.paimon.spark.catalyst.Compatibility
import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper
import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand
import org.apache.paimon.spark.schema.SparkSystemColumns.ROW_KIND_COL
import org.apache.paimon.spark.util.SQLHelper
import org.apache.paimon.table.{BucketMode, FileStoreTable}
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage}
import org.apache.paimon.types.RowKind
import org.apache.paimon.utils.InternalRowPartitionComputer
@@ -144,20 +141,11 @@ case class DeleteFromPaimonTableCommand(
findTouchedFiles(candidateDataSplits, condition, relation, sparkSession)

// Step3: the smallest range of data files that need to be rewritten.
val touchedFiles = touchedFilePaths.map {
file =>
dataFilePathToMeta.getOrElse(file, throw new RuntimeException(s"Missing file: $file"))
}
val (touchedFiles, newRelation) =
createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)

// Step4: build a dataframe that contains the unchanged data, and write them out.
val touchedDataSplits =
SparkDataFileMeta.convertToDataSplits(touchedFiles, rawConvertible = true, pathFactory)
val toRewriteScanRelation = Filter(
Not(condition),
Compatibility.createDataSourceV2ScanRelation(
relation,
PaimonSplitScan(table, touchedDataSplits),
relation.output))
val toRewriteScanRelation = Filter(Not(condition), newRelation)
val data = createDataset(sparkSession, toRewriteScanRelation)

// only write new files, should have no compaction
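The inline path-to-meta lookup and PaimonSplitScan construction shown in the removed lines are now delegated to a createNewRelation helper shared with the MERGE INTO changes below. The toy sketch that follows only illustrates the copy-on-write idea behind Step3/Step4 using plain DataFrames; the column names and the local SparkSession are assumptions, not Paimon internals.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.not

object DeleteRewriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("delete-rewrite-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Pretend each row records the data file it lives in.
    val rows = Seq((1, "f1"), (2, "f1"), (3, "f2"), (4, "f3")).toDF("id", "file")
    val deleteCond = $"id" === 2

    // Only files touched by the delete predicate need to be rewritten ...
    val touchedFiles = rows.filter(deleteCond).select("file").distinct()
    // ... and what gets rewritten is the negation of the predicate over those files.
    val rewritten = rows.join(touchedFiles, "file").filter(not(deleteCond))

    rewritten.show() // f1 is rewritten without id = 2; f2 and f3 are left untouched
    spark.stop()
  }
}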
@@ -18,12 +18,13 @@

package org.apache.paimon.spark.commands

import org.apache.paimon.options.Options
import org.apache.paimon.spark.{InsertInto, SparkTable}
import org.apache.paimon.spark.SparkTable
import org.apache.paimon.spark.catalyst.analysis.PaimonRelation
import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand
import org.apache.paimon.spark.schema.SparkSystemColumns
import org.apache.paimon.spark.util.EncoderUtils
import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils}
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink.CommitMessage
import org.apache.paimon.types.RowKind

import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
@@ -33,10 +34,12 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, Literal, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, Filter, InsertAction, LogicalPlan, MergeAction, UpdateAction}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, sum}
import org.apache.spark.sql.types.{ByteType, StructField, StructType}

import scala.collection.mutable

/** Command for Merge Into. */
case class MergeIntoPaimonTable(
v2Table: SparkTable,
@@ -55,33 +58,83 @@ case class MergeIntoPaimonTable(

lazy val tableSchema: StructType = v2Table.schema

lazy val filteredTargetPlan: LogicalPlan = {
private lazy val writer = PaimonSparkWriter(table)

private lazy val filteredTargetPlan: LogicalPlan = {
val filtersOnlyTarget = getExpressionOnlyRelated(mergeCondition, targetTable)
filtersOnlyTarget
.map(Filter.apply(_, targetTable))
.getOrElse(targetTable)
}

override def run(sparkSession: SparkSession): Seq[Row] = {

// Avoid the case where more than one source row matches the same target row.
checkMatchRationality(sparkSession)
val commitMessages = if (withPrimaryKeys) {
performMergeForPkTable(sparkSession)
} else {
performMergeForNonPkTable(sparkSession)
}
writer.commit(commitMessages)
Seq.empty[Row]
}

val changed = constructChangedRows(sparkSession)
private def performMergeForPkTable(sparkSession: SparkSession): Seq[CommitMessage] = {
writer.write(
constructChangedRows(sparkSession, createDataset(sparkSession, filteredTargetPlan)))
}

WriteIntoPaimonTable(
table,
InsertInto,
changed,
new Options()
).run(sparkSession)
private def performMergeForNonPkTable(sparkSession: SparkSession): Seq[CommitMessage] = {
val targetDS = createDataset(sparkSession, filteredTargetPlan)
val sourceDS = createDataset(sparkSession, sourceTable)

Seq.empty[Row]
val targetFilePaths: Array[String] = findTouchedFiles(targetDS, sparkSession)

val touchedFilePathsSet = mutable.Set.empty[String]
def hasUpdate(actions: Seq[MergeAction]): Boolean = {
actions.exists {
case _: UpdateAction | _: DeleteAction => true
case _ => false
}
}
if (hasUpdate(matchedActions)) {
touchedFilePathsSet ++= findTouchedFiles(
targetDS.join(sourceDS, new Column(mergeCondition), "inner"),
sparkSession)
}
if (hasUpdate(notMatchedBySourceActions)) {
touchedFilePathsSet ++= findTouchedFiles(
targetDS.join(sourceDS, new Column(mergeCondition), "left_anti"),
sparkSession)
}

val touchedFilePaths: Array[String] = touchedFilePathsSet.toArray
val unTouchedFilePaths = targetFilePaths.filterNot(touchedFilePaths.contains)

val relation = PaimonRelation.getPaimonRelation(targetTable)
val dataFilePathToMeta = candidateFileMap(findCandidateDataSplits(TrueLiteral, relation.output))
val (touchedFiles, touchedFileRelation) =
createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
val (_, unTouchedFileRelation) =
createNewRelation(unTouchedFilePaths, dataFilePathToMeta, relation)

val targetDSWithFileTouchedCol = createDataset(sparkSession, touchedFileRelation)
.withColumn(FILE_TOUCHED_COL, lit(true))
.union(
createDataset(sparkSession, unTouchedFileRelation).withColumn(FILE_TOUCHED_COL, lit(false)))

val addCommitMessage =
writer.write(constructChangedRows(sparkSession, targetDSWithFileTouchedCol))
val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles)

addCommitMessage ++ deletedCommitMessage
}

/** Get a Dataset where each row has an additional column called _row_kind_. */
private def constructChangedRows(sparkSession: SparkSession): Dataset[Row] = {
val targetDS = createDataset(sparkSession, filteredTargetPlan)
private def constructChangedRows(
sparkSession: SparkSession,
targetDataset: Dataset[Row]): Dataset[Row] = {
val targetDS = targetDataset
.withColumn(TARGET_ROW_COL, lit(true))

val sourceDS = createDataset(sparkSession, sourceTable)
@@ -100,29 +153,30 @@ case class MergeIntoPaimonTable(
val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral))
val notMatchedExprs = notMatchedActions.map(_.condition.getOrElse(TrueLiteral))
val notMatchedBySourceExprs = notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral))
val matchedOutputs = matchedActions.map {
case UpdateAction(_, assignments) =>
assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
case DeleteAction(_) =>
targetOutput :+ Literal(RowKind.DELETE.toByteValue)
case _ =>
throw new RuntimeException("should not be here.")
}
val notMatchedBySourceOutputs = notMatchedBySourceActions.map {
case UpdateAction(_, assignments) =>
assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
case DeleteAction(_) =>
targetOutput :+ Literal(RowKind.DELETE.toByteValue)
case _ =>
throw new RuntimeException("should not be here.")
}
val notMatchedOutputs = notMatchedActions.map {
case InsertAction(_, assignments) =>
assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
case _ =>
throw new RuntimeException("should not be here.")
}
val noopOutput = targetOutput :+ Alias(Literal(NOOP_ROW_KIND_VALUE), ROW_KIND_COL)()
val keepOutput = targetOutput :+ Alias(Literal(RowKind.INSERT.toByteValue), ROW_KIND_COL)()

def processMergeActions(actions: Seq[MergeAction], applyOnTargetTable: Boolean) = {
actions.map {
case UpdateAction(_, assignments) if applyOnTargetTable =>
assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
case DeleteAction(_) if applyOnTargetTable =>
if (withPrimaryKeys) {
targetOutput :+ Literal(RowKind.DELETE.toByteValue)
} else {
noopOutput
}
case InsertAction(_, assignments) if !applyOnTargetTable =>
assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
case _ =>
throw new RuntimeException("should not be here.")
}
}

val matchedOutputs = processMergeActions(matchedActions, applyOnTargetTable = true)
val notMatchedBySourceOutputs =
processMergeActions(notMatchedBySourceActions, applyOnTargetTable = true)
val notMatchedOutputs = processMergeActions(notMatchedActions, applyOnTargetTable = false)
val outputSchema = StructType(tableSchema.fields :+ StructField(ROW_KIND_COL, ByteType))

val joinedRowEncoder = EncoderUtils.encode(joinedPlan.schema)
Expand All @@ -139,10 +193,11 @@ case class MergeIntoPaimonTable(
notMatchedExprs,
notMatchedOutputs,
noopOutput,
keepOutput,
joinedRowEncoder,
outputEncoder
)
joinedDS.mapPartitions(processor.processPartition)(outputEncoder)
joinedDS.mapPartitions(processor.processPartition)(outputEncoder).toDF()
}

private def checkMatchRationality(sparkSession: SparkSession): Unit = {
@@ -159,21 +214,23 @@
.count()
if (count > 0) {
throw new RuntimeException(
"Can't execute this MergeInto when there are some target rows that each of them match more then one source rows. It may lead to an unexpected result.")
"Can't execute this MergeInto when there are some target rows that each of " +
"them match more then one source rows. It may lead to an unexpected result.")
}
}
}
}

object MergeIntoPaimonTable {
val ROW_ID_COL = "_row_id_"
val SOURCE_ROW_COL = "_source_row_"
val TARGET_ROW_COL = "_target_row_"
private val ROW_ID_COL = "_row_id_"
private val SOURCE_ROW_COL = "_source_row_"
private val TARGET_ROW_COL = "_target_row_"
private val FILE_TOUCHED_COL = "_file_touched_col_"
// +I, +U, -U, -D
val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL
val NOOP_ROW_KIND_VALUE: Byte = "-1".toByte
private val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL
private val NOOP_ROW_KIND_VALUE: Byte = "-1".toByte

case class MergeIntoProcessor(
private case class MergeIntoProcessor(
joinedAttributes: Seq[Attribute],
targetRowHasNoMatch: Expression,
sourceRowHasNoMatch: Expression,
@@ -184,10 +241,15 @@ object MergeIntoPaimonTable {
notMatchedConditions: Seq[Expression],
notMatchedOutputs: Seq[Seq[Expression]],
noopCopyOutput: Seq[Expression],
keepOutput: Seq[Expression],
joinedRowEncoder: ExpressionEncoder[Row],
outputRowEncoder: ExpressionEncoder[Row]
) extends Serializable {

private val file_touched_col_index: Int =
SparkRowUtils.getFieldIndex(joinedRowEncoder.schema, FILE_TOUCHED_COL)
private val row_kind_col_index: Int = outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)

private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = {
UnsafeProjection.create(exprs, joinedAttributes)
}
@@ -196,8 +258,13 @@
GeneratePredicate.generate(expr, joinedAttributes)
}

// keep the row if it is from a touched file and is not matched by any clause
private def keepRow(row: InternalRow): Boolean = {
file_touched_col_index != -1 && row.getBoolean(file_touched_col_index)
}

private def unusedRow(row: InternalRow): Boolean = {
row.getByte(outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)) == NOOP_ROW_KIND_VALUE
row.getByte(row_kind_col_index) == NOOP_ROW_KIND_VALUE
}

def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
@@ -210,38 +277,28 @@ object MergeIntoPaimonTable {
val notMatchedPreds = notMatchedConditions.map(generatePredicate)
val notMatchedProjs = notMatchedOutputs.map(generateProjection)
val noopCopyProj = generateProjection(noopCopyOutput)
val keepProj = generateProjection(keepOutput)
val outputProj = UnsafeProjection.create(outputRowEncoder.schema)

def processRow(inputRow: InternalRow): InternalRow = {
if (targetRowHasNoMatchPred.eval(inputRow)) {
val pair = notMatchedBySourcePreds.zip(notMatchedBySourceProjs).find {
case (predicate, _) => predicate.eval(inputRow)
def applyPreds(preds: Seq[BasePredicate], projs: Seq[UnsafeProjection]): InternalRow = {
preds.zip(projs).find { case (predicate, _) => predicate.eval(inputRow) } match {
case Some((_, projections)) => projections.apply(inputRow)
case None =>
if (keepRow(inputRow)) {
keepProj.apply(inputRow)
} else {
noopCopyProj.apply(inputRow)
}
}
}

pair match {
case Some((_, projections)) =>
projections.apply(inputRow)
case None => noopCopyProj.apply(inputRow)
}
if (targetRowHasNoMatchPred.eval(inputRow)) {
applyPreds(notMatchedBySourcePreds, notMatchedBySourceProjs)
} else if (sourceRowHasNoMatchPred.eval(inputRow)) {
val pair = notMatchedPreds.zip(notMatchedProjs).find {
case (predicate, _) => predicate.eval(inputRow)
}

pair match {
case Some((_, projections)) =>
projections.apply(inputRow)
case None => noopCopyProj.apply(inputRow)
}
applyPreds(notMatchedPreds, notMatchedProjs)
} else {
val pair =
matchedPreds.zip(matchedProjs).find { case (predicate, _) => predicate.eval(inputRow) }

pair match {
case Some((_, projections)) =>
projections.apply(inputRow)
case None => noopCopyProj.apply(inputRow)
}
applyPreds(matchedPreds, matchedProjs)
}
}

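For a table without primary keys the writer cannot rely on upserts, so touched data files are rewritten wholesale: a row from a touched file that matches no clause is re-emitted via keepProj so it survives the rewrite, while a row from an untouched file falls back to the noop row kind and is filtered out later. The following simplified, self-contained restatement of that per-row dispatch uses assumed types and is not the actual implementation.

object RowDispatchSketch {
  sealed trait Out
  case object Update extends Out
  case object Delete extends Out
  case object Insert extends Out
  case object Keep extends Out // unchanged row from a touched file, re-inserted
  case object Noop extends Out // row from an untouched file, dropped afterwards

  final case class Row(id: Int, fromTouchedFile: Boolean, matched: Boolean)

  // Try the merge clauses in order; when none fires, keep or noop depending on the file.
  def dispatch(row: Row, clauses: Seq[Row => Option[Out]]): Out =
    clauses.iterator
      .map(_.apply(row))
      .collectFirst { case Some(out) => out }
      .getOrElse(if (row.fromTouchedFile) Keep else Noop)

  def main(args: Array[String]): Unit = {
    val matchedUpdate: Row => Option[Out] = r => if (r.matched) Some(Update) else None
    val rows = Seq(
      Row(1, fromTouchedFile = true, matched = true),
      Row(2, fromTouchedFile = true, matched = false),
      Row(3, fromTouchedFile = false, matched = false))
    rows.foreach(r => println(s"row ${r.id} -> ${dispatch(r, Seq(matchedUpdate))}"))
    // row 1 -> Update, row 2 -> Keep, row 3 -> Noop
  }
}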