diff --git a/docs/content/engines/spark3.md b/docs/content/engines/spark3.md index 414679bb7e8e..07ffd1ffcee6 100644 --- a/docs/content/engines/spark3.md +++ b/docs/content/engines/spark3.md @@ -221,18 +221,18 @@ dataset.show() ## Update Table -For now, Paimon does not support `UPDATE` syntax. But we can use `INSERT INTO` syntax instead for changelog tables. - -```sql -INSERT INTO my_table VALUES (1, 'Hi Again'), (3, 'Test'); +{{< hint info >}} +Important table properties setting: +1. Only [primary key table]({{< ref "concepts/primary-key-table" >}}) supports this feature. +2. [MergeEngine]({{< ref "concepts/primary-key-table#merge-engines" >}}) needs to be [deduplicate]({{< ref "concepts/primary-key-table#deduplicate" >}}) or [partial-update]({{< ref "concepts/primary-key-table#partial-update" >}}) to support this feature. + {{< /hint >}} -SELECT * FROM my_table; +{{< hint warning >}} +Warning: we do not support updating primary keys. +{{< /hint >}} -/* -1 Hi Again -2 Hello -3 Test -*/ +```sql +UPDATE my_table SET v = 'new_value' WHERE id = 1; ``` ## Streaming Write diff --git a/docs/content/how-to/writing-tables.md b/docs/content/how-to/writing-tables.md index a2b56dc764b9..adbfc2c76937 100644 --- a/docs/content/how-to/writing-tables.md +++ b/docs/content/how-to/writing-tables.md @@ -329,8 +329,6 @@ For more information of drop-partition, see ## Updating tables -Currently, Paimon supports updating records by using `UPDATE` in Flink 1.17 and later versions. You can perform `UPDATE` in Flink's `batch` mode. - {{< hint info >}} Important table properties setting: 1. Only [primary key table]({{< ref "concepts/primary-key-table" >}}) supports this feature. @@ -345,6 +343,8 @@ Warning: we do not support updating primary keys. {{< tab "Flink" >}} +Currently, Paimon supports updating records by using `UPDATE` in Flink 1.17 and later versions. You can perform `UPDATE` in Flink's `batch` mode. + ```sql -- Syntax UPDATE table_identifier SET column1 = value1, column2 = value2, ... WHERE condition; @@ -366,6 +366,37 @@ UPDATE MyTable SET b = 1, c = 2 WHERE a = 'myTable'; {{< /tab >}} +{{< tab "Spark" >}} + +To enable update needs these configs below: + +```text +--conf spark.sql.catalog.spark_catalog=org.apache.paimon.spark.SparkGenericCatalog +--conf spark.sql.extensions=org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions +``` + +spark supports update PrimitiveType and StructType, for example: + +```sql +-- Syntax +UPDATE table_identifier SET column1 = value1, column2 = value2, ... WHERE condition; + +CREATE TABLE T ( + id INT, + s STRUCT, + name STRING) +TBLPROPERTIES ( + 'primary-key' = 'id', + 'merge-engine' = 'deduplicate' +); + +-- you can use +UPDATE T SET name = 'a_new' WHERE id = 1; +UPDATE T SET s.c2 = 'a_new' WHERE s.c1 = 1; +``` + +{{< /tab >}} + {{< /tabs >}} ## Deleting from table diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala index 3c9f7a76b5a7..d5006ea645eb 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala @@ -18,8 +18,7 @@ package org.apache.paimon.spark.extensions import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.analysis.{CoerceArguments, PaimonAnalysis, ResolveProcedures} -import org.apache.spark.sql.catalyst.optimizer.RewriteRowLeverCommands +import org.apache.spark.sql.catalyst.analysis.{CoerceArguments, PaimonAnalysis, ResolveProcedures, RewriteRowLevelCommands} import org.apache.spark.sql.catalyst.parser.extensions.PaimonSparkSqlExtensionsParser import org.apache.spark.sql.catalyst.plans.logical.PaimonTableValuedFunctions import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy @@ -35,9 +34,7 @@ class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) { extensions.injectResolutionRule(sparkSession => new PaimonAnalysis(sparkSession)) extensions.injectResolutionRule(spark => ResolveProcedures(spark)) extensions.injectResolutionRule(_ => CoerceArguments) - - // optimizer extensions - extensions.injectOptimizerRule(_ => RewriteRowLeverCommands) + extensions.injectPostHocResolutionRule(_ => RewriteRowLevelCommands) // table function extensions PaimonTableValuedFunctions.supportedFnNames.foreach { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/SparkSystemColumns.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/SparkSystemColumns.scala index bae79127c9b0..3b42dba1bb99 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/SparkSystemColumns.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/SparkSystemColumns.scala @@ -25,7 +25,7 @@ object SparkSystemColumns { // for assigning bucket when writing val BUCKET_COL = "_bucket_" - // for row lever operation + // for row level operation val ROW_KIND_COL = "_row_kind_" val SPARK_SYSTEM_COLUMNS_NAME: Seq[String] = Seq(BUCKET_COL, ROW_KIND_COL) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/RowLevelOp.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/RowLevelOp.scala new file mode 100644 index 000000000000..620ab814aeb4 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/RowLevelOp.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.paimon.CoreOptions.MergeEngine + +sealed trait RowLevelOp { + val supportedMergeEngine: Seq[MergeEngine] +} + +case object Delete extends RowLevelOp { + override def toString: String = "delete" + + override val supportedMergeEngine: Seq[MergeEngine] = Seq(MergeEngine.DEDUPLICATE) +} + +case object Update extends RowLevelOp { + override def toString: String = "update" + + override val supportedMergeEngine: Seq[MergeEngine] = + Seq(MergeEngine.DEDUPLICATE, MergeEngine.PARTIAL_UPDATE) +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala new file mode 100644 index 000000000000..33bec58da20f --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CreateNamedStruct, Expression, GetStructField, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.types.StructType + +trait AssignmentAlignmentHelper extends SQLConfHelper { + + private lazy val resolver = conf.resolver + + /** + * @param ref + * attribute reference seq, e.g. a => Seq["a"], s.c1 => Seq["s", "c1"] + * @param expr + * update expression + */ + private case class AttrUpdate(ref: Seq[String], expr: Expression) + + /** + * Align update assignments to the given attrs, only supports PrimitiveType and StructType. For + * example, if attrs are [a int, b int, s struct(c1 int, c2 int)] and update assignments are [a = + * 1, s.c1 = 2], will return [1, b, struct(2, c2)]. + * @param attrs + * target attrs + * @param assignments + * update assignments + * @return + * aligned update expressions + */ + protected def alignUpdateAssignments( + attrs: Seq[Attribute], + assignments: Seq[Assignment]): Seq[Expression] = { + val attrUpdates = assignments.map(a => AttrUpdate(toRefSeq(a.key), a.value)) + recursiveAlignUpdates(attrs, attrUpdates) + } + + def toRefSeq(expr: Expression): Seq[String] = expr match { + case attr: Attribute => + Seq(attr.name) + case GetStructField(child, _, Some(name)) => + toRefSeq(child) :+ name + case other => + throw new UnsupportedOperationException( + s"Unsupported update expression: $other, only support update with PrimitiveType and StructType.") + } + + private def recursiveAlignUpdates( + targetAttrs: Seq[NamedExpression], + updates: Seq[AttrUpdate], + namePrefix: Seq[String] = Nil): Seq[Expression] = { + + // build aligned updated expression for each target attr + targetAttrs.map { + targetAttr => + val headMatchedUpdates = updates.filter(u => resolver(u.ref.head, targetAttr.name)) + if (headMatchedUpdates.isEmpty) { + // when no matched update, return the attr as is + targetAttr + } else { + val exactMatchedUpdate = headMatchedUpdates.find(_.ref.size == 1) + if (exactMatchedUpdate.isDefined) { + if (headMatchedUpdates.size == 1) { + // when an exact match (no nested fields) occurs, it must be the only match, then return it's expr + exactMatchedUpdate.get.expr + } else { + // otherwise, there must be conflicting updates, for example: + // - update the same attr multiple times + // - update a struct attr and its fields at the same time (e.g. s and s.c1) + val conflictingAttrNames = + headMatchedUpdates.map(u => (namePrefix ++ u.ref).mkString(".")).distinct + throw new UnsupportedOperationException( + s"Conflicting updates on attrs: ${conflictingAttrNames.mkString(", ")}" + ) + } + } else { + targetAttr.dataType match { + case StructType(fields) => + val fieldExprs = fields.zipWithIndex.map { + case (field, ordinal) => + Alias(GetStructField(targetAttr, ordinal, Some(field.name)), field.name)() + } + val newUpdates = updates.map(u => u.copy(ref = u.ref.tail)) + // process StructType's nested fields recursively + val updatedFieldExprs = + recursiveAlignUpdates(fieldExprs, newUpdates, namePrefix :+ targetAttr.name) + + // build updated struct expression + CreateNamedStruct(fields.zip(updatedFieldExprs).flatMap { + case (field, expr) => + Seq(Literal(field.name), expr) + }) + case _ => + // can't reach here + throw new UnsupportedOperationException("") + } + } + } + } + } + +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommands.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommands.scala new file mode 100644 index 000000000000..b5cf04c49bca --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommands.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.paimon.CoreOptions.MERGE_ENGINE +import org.apache.paimon.options.Options +import org.apache.paimon.spark.SparkTable +import org.apache.paimon.table.Table + +import org.apache.spark.sql.{AnalysisException, Delete, RowLevelOp, Update} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.SupportsDelete +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +import java.util + +object RewriteRowLevelCommands extends Rule[LogicalPlan] with PredicateHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case d @ DeleteFromTable(ResolvesToPaimonTable(table), condition) => + validateRowLevelOp(Delete, table.getTable, Option.empty) + if (canDeleteWhere(d, table, condition)) { + d + } else { + DeleteFromPaimonTableCommand(d) + } + + case u @ UpdateTable(ResolvesToPaimonTable(table), assignments, _) => + validateRowLevelOp(Update, table.getTable, Option.apply(assignments)) + UpdatePaimonTableCommand(u) + } + + private object ResolvesToPaimonTable { + def unapply(plan: LogicalPlan): Option[SparkTable] = + EliminateSubqueryAliases(plan) match { + case DataSourceV2Relation(table: SparkTable, _, _, _, _) => Some(table) + case _ => None + } + } + + private def validateRowLevelOp( + op: RowLevelOp, + table: Table, + assignments: Option[Seq[Assignment]]): Unit = { + val options = Options.fromMap(table.options) + val primaryKeys = table.primaryKeys() + if (primaryKeys.isEmpty) { + throw new UnsupportedOperationException( + s"table ${table.getClass.getName} can not support $op, because there is no primary key.") + } + + if (op.equals(Update) && isPrimaryKeyUpdate(primaryKeys, assignments.get)) { + throw new UnsupportedOperationException(s"$op to primary keys is not supported.") + } + + val mergeEngine = options.get(MERGE_ENGINE) + if (!op.supportedMergeEngine.contains(mergeEngine)) { + throw new UnsupportedOperationException( + s"merge engine $mergeEngine can not support $op, currently only ${op.supportedMergeEngine + .mkString(", ")} can support $op.") + } + } + + private def canDeleteWhere( + d: DeleteFromTable, + table: SparkTable, + condition: Expression): Boolean = { + table match { + case t: SupportsDelete if !SubqueryExpression.hasSubquery(condition) => + // fail if any filter cannot be converted. + // correctness depends on removing all matching data. + val filters = DataSourceStrategy + .normalizeExprs(Seq(condition), d.output) + .flatMap(splitConjunctivePredicates(_).map { + f => + DataSourceStrategy + .translateFilter(f, supportNestedPredicatePushdown = true) + .getOrElse(throw new AnalysisException(s"Exec update failed:" + + s" cannot translate expression to source filter: $f")) + }) + .toArray + t.canDeleteWhere(filters) + case _ => false + } + } + + private def isPrimaryKeyUpdate( + primaryKeys: util.List[String], + assignments: Seq[Assignment]): Boolean = { + assignments.exists( + a => { + a.key match { + case attr: Attribute => primaryKeys.contains(attr.name) + case _ => false + } + }) + } + +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteRowLeverCommands.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteRowLeverCommands.scala deleted file mode 100644 index 807b851e8fe8..000000000000 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteRowLeverCommands.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.paimon.CoreOptions.{MERGE_ENGINE, MergeEngine} -import org.apache.paimon.options.Options -import org.apache.paimon.spark.SparkTable -import org.apache.paimon.table.Table - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Expression, PredicateHelper, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromPaimonTableCommand, DeleteFromTable, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.SupportsDelete -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation - -object RewriteRowLeverCommands extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { - case d @ DeleteFromTable(r: DataSourceV2Relation, condition) => - validateDeletable(r.table.asInstanceOf[SparkTable].getTable) - if (canDeleteWhere(r, condition)) { - d - } else { - DeleteFromPaimonTableCommand(r, condition) - } - } - - private def validateDeletable(table: Table): Boolean = { - val options = Options.fromMap(table.options) - if (table.primaryKeys().isEmpty) { - throw new UnsupportedOperationException( - String.format( - "table '%s' can not support delete, because there is no primary key.", - table.getClass.getName)) - } - if (!options.get(MERGE_ENGINE).equals(MergeEngine.DEDUPLICATE)) { - throw new UnsupportedOperationException( - String.format( - "merge engine '%s' can not support delete, currently only %s can support delete.", - options.get(MERGE_ENGINE), - MergeEngine.DEDUPLICATE)) - } - true - } - - private def canDeleteWhere(relation: DataSourceV2Relation, condition: Expression): Boolean = { - relation.table match { - case t: SupportsDelete if !SubqueryExpression.hasSubquery(condition) => - // fail if any filter cannot be converted. - // correctness depends on removing all matching data. - val filters = DataSourceStrategy - .normalizeExprs(Seq(condition), relation.output) - .flatMap(splitConjunctivePredicates(_).map { - f => - DataSourceStrategy - .translateFilter(f, supportNestedPredicatePushdown = true) - .getOrElse(throw new AnalysisException(s"Exec update failed:" + - s" cannot translate expression to source filter: $f")) - }) - .toArray - t.canDeleteWhere(filters) - case _ => false - } - } - -} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromPaimonTableCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromPaimonTableCommand.scala index f380533a883a..92833cf1a667 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromPaimonTableCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DeleteFromPaimonTableCommand.scala @@ -25,20 +25,23 @@ import org.apache.paimon.table.FileStoreTable import org.apache.paimon.types.RowKind import org.apache.spark.sql.{Dataset, Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.functions.lit -case class DeleteFromPaimonTableCommand(relation: DataSourceV2Relation, condition: Expression) - extends LeafRunnableCommand { +case class DeleteFromPaimonTableCommand(d: DeleteFromTable) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + val relation = EliminateSubqueryAliases(d.table).asInstanceOf[DataSourceV2Relation] + val condition = d.condition + val filteredPlan = if (condition != null) { Filter(condition, relation) } else { relation } + val df = Dataset .ofRows(sparkSession, filteredPlan) .withColumn(ROW_KIND_COL, lit(RowKind.DELETE.toByteValue)) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdatePaimonTableCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdatePaimonTableCommand.scala new file mode 100644 index 000000000000..c9ec48a16ee4 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdatePaimonTableCommand.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.paimon.options.Options +import org.apache.paimon.spark.{InsertInto, SparkTable} +import org.apache.paimon.spark.commands.WriteIntoPaimonTable +import org.apache.paimon.spark.schema.SparkSystemColumns.ROW_KIND_COL +import org.apache.paimon.table.FileStoreTable +import org.apache.paimon.types.RowKind + +import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.{AssignmentAlignmentHelper, EliminateSubqueryAliases} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, If} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.functions.lit + +case class UpdatePaimonTableCommand(u: UpdateTable) + extends LeafRunnableCommand + with AssignmentAlignmentHelper { + + override def run(sparkSession: SparkSession): Seq[Row] = { + + val relation = EliminateSubqueryAliases(u.table).asInstanceOf[DataSourceV2Relation] + + val updatedExprs: Seq[Alias] = + alignUpdateAssignments(relation.output, u.assignments).zip(relation.output).map { + case (expr, attr) => Alias(expr, attr.name)() + } + + val updatedPlan = Project(updatedExprs, Filter(u.condition.getOrElse(TrueLiteral), relation)) + + val df = Dataset + .ofRows(sparkSession, updatedPlan) + .withColumn(ROW_KIND_COL, lit(RowKind.UPDATE_AFTER.toByteValue)) + + WriteIntoPaimonTable( + relation.table.asInstanceOf[SparkTable].getTable.asInstanceOf[FileStoreTable], + InsertInto, + df, + Options.fromMap(relation.options)).run(sparkSession) + + Seq.empty[Row] + } +} diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeleteFromTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeleteFromTableTest.scala index 2008d777599e..7623e08ab26a 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeleteFromTableTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DeleteFromTableTest.scala @@ -20,6 +20,7 @@ package org.apache.paimon.spark.sql import org.apache.paimon.CoreOptions import org.apache.paimon.spark.PaimonSparkTestBase +import org.apache.spark.sql.Delete import org.assertj.core.api.Assertions.{assertThat, assertThatThrownBy} class DeleteFromTableTest extends PaimonSparkTestBase { @@ -39,7 +40,6 @@ class DeleteFromTableTest extends PaimonSparkTestBase { mergeEngine => { test(s"test delete with merge engine $mergeEngine") { - val supportUpdateEngines = Seq("deduplicate") val options = if ("first-row".equals(mergeEngine.toString)) { s"'primary-key' = 'id', 'merge-engine' = '$mergeEngine', 'changelog-producer' = 'lookup'" } else { @@ -52,7 +52,7 @@ class DeleteFromTableTest extends PaimonSparkTestBase { spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22')") - if (supportUpdateEngines.contains(mergeEngine.toString)) { + if (Delete.supportedMergeEngine.contains(mergeEngine)) { spark.sql("DELETE FROM T WHERE name = 'a'") } else assertThatThrownBy(() => spark.sql("DELETE FROM T WHERE name = 'a'")) diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala new file mode 100644 index 000000000000..52f83d7572d9 --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.paimon.spark.sql + +import org.apache.paimon.CoreOptions +import org.apache.paimon.spark.PaimonSparkTestBase + +import org.apache.spark.sql.Update +import org.assertj.core.api.Assertions.{assertThat, assertThatThrownBy} + +class UpdateTableTest extends PaimonSparkTestBase { + + test(s"test update append only table") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22')") + + assertThatThrownBy(() => spark.sql("UPDATE T SET name = 'a_new' WHERE id = 1")) + .hasMessageContaining("can not support update, because there is no primary key") + } + + CoreOptions.MergeEngine.values().foreach { + mergeEngine => + { + test(s"test update with merge engine $mergeEngine") { + val options = if ("first-row".equals(mergeEngine.toString)) { + s"'primary-key' = 'id', 'merge-engine' = '$mergeEngine', 'changelog-producer' = 'lookup'" + } else { + s"'primary-key' = 'id', 'merge-engine' = '$mergeEngine'" + } + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ($options) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22')") + + if (Update.supportedMergeEngine.contains(mergeEngine)) { + spark.sql("UPDATE T SET name = 'a_new' WHERE id = 1") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,a_new,11], [2,b,22]]") + } else + assertThatThrownBy(() => spark.sql("UPDATE T SET name = 'a_new' WHERE id = 1")) + .isInstanceOf(classOf[UnsupportedOperationException]) + } + } + } + + test(s"test update with primary key") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id', 'merge-engine' = 'deduplicate') + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22'), (3, 'c', '33')") + + assertThatThrownBy(() => spark.sql("UPDATE T SET id = 11 WHERE name = 'a'")) + .hasMessageContaining("update to primary keys is not supported") + } + + test(s"test update with no where") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (id) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22'), (3, 'c', '33')") + + spark.sql("UPDATE T SET name = 'a_new'") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,a_new,11], [2,a_new,22], [3,a_new,33]]") + } + + test(s"test update with alias") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (id) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22'), (3, 'c', '33')") + + spark.sql("UPDATE T AS t SET t.name = 'a_new' where id = 1") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,a_new,11], [2,b,22], [3,c,33]]") + } + + test(s"test update with alias assignment") { + spark.sql(s""" + |CREATE TABLE T (id INT, c1 INT, c2 INT) + |TBLPROPERTIES ('primary-key' = 'id', 'merge-engine' = 'deduplicate') + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 1, 11), (2, 2, 22), (3, 3, 33)") + + spark.sql("UPDATE T set c1 = c1 + 1, c2 = c2 + 1 where id = 1") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,2,12], [2,2,22], [3,3,33]]") + } + + test(s"test update with in condition and not in condition") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22'), (3, 'c', '33')") + + spark.sql("UPDATE T set name = 'in_new' WHERE id IN (1)") + val rows1 = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows1.toString).isEqualTo("[[1,in_new,11], [2,b,22], [3,c,33]]") + + spark.sql("UPDATE T set name = 'not_in_new' WHERE id NOT IN (2)") + val rows2 = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows2.toString).isEqualTo("[[1,not_in_new,11], [2,b,22], [3,not_in_new,33]]") + } + + test(s"test update with in subquery") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22'), (3, 'c', '33')") + + import testImplicits._ + val df = Seq(1, 2).toDF("id") + df.createOrReplaceTempView("updated_ids") + spark.sql("UPDATE T set name = 'in_new' WHERE id IN (SELECT * FROM updated_ids)") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,in_new,11], [2,in_new,22], [3,c,33]]") + } + + test(s"test update with self subquery") { + spark.sql(s""" + |CREATE TABLE T (id INT, name STRING, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 'a', '11'), (2, 'b', '22'), (3, 'c', '33')") + + spark.sql("UPDATE T set name = 'in_new' WHERE id IN (SELECT id + 1 FROM T)") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,a,11], [2,in_new,22], [3,in_new,33]]") + } + + test(s"test update with various column references") { + spark.sql(s""" + |CREATE TABLE T (id INT, c1 INT, c2 INT, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql("INSERT INTO T VALUES (1, 1, 10, '11'), (2, 2, 20, '22'), (3, 3, 300, '33')") + + spark.sql("UPDATE T SET c1 = c2 + 1, c2 = 1000") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,11,1000,11], [2,21,1000,22], [3,301,1000,33]]") + } + + test(s"test update with struct column") { + spark.sql(s""" + |CREATE TABLE T (id INT, s STRUCT, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql( + "INSERT INTO T VALUES (1, struct(1, 'a'), '11'), (2, struct(2, 'b'), '22'), (3, struct(3, 'c'), '33')") + + spark.sql("UPDATE T SET s.c2 = 'a_new' WHERE s.c1 = 1") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo("[[1,[1,a_new],11], [2,[2,b],22], [3,[3,c],33]]") + } + + test(s"test update with map column") { + spark.sql(s""" + |CREATE TABLE T (id INT, m MAP, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql( + "INSERT INTO T VALUES (1, map(1, 'a'), '11'), (2, map(2, 'b'), '22'), (3, map(3, 'c'), '33')") + + assertThatThrownBy(() => spark.sql("UPDATE T SET m.key = 11 WHERE id = 1")) + .hasMessageContaining("Unsupported update expression") + + spark.sql("UPDATE T SET m = map(11, 'a_new') WHERE id = 1") + val rows = spark.sql("SELECT * FROM T ORDER BY id").collectAsList() + assertThat(rows.toString).isEqualTo( + "[[1,Map(11 -> a_new),11], [2,Map(2 -> b),22], [3,Map(3 -> c),33]]") + } + + test(s"test update with conflicted column") { + spark.sql(s""" + |CREATE TABLE T (id INT, s STRUCT, dt STRING) + |TBLPROPERTIES ('primary-key' = 'id, dt', 'merge-engine' = 'deduplicate') + |PARTITIONED BY (dt) + |""".stripMargin) + + spark.sql( + "INSERT INTO T VALUES (1, struct(1, 'a'), '11'), (2, struct(2, 'b'), '22'), (3, struct(3, 'c'), '33')") + + assertThatThrownBy( + () => spark.sql("UPDATE T SET s.c2 = 'a_new', s = struct(11, 'a_new') WHERE s.c1 = 1")) + .hasMessageContaining("Conflicting updates on attrs: s.c2, s") + } +}