diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index ea61f6f7d9e8..2ea2e3c45347 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -166,6 +166,11 @@ public Object convertLiteral(String field, Object value) {
         return convertLiteral(fieldIndex(field), value);
     }
 
+    public String convertString(String field, Object value) {
+        Object literal = convertLiteral(field, value);
+        return literal == null ? null : literal.toString();
+    }
+
     private int fieldIndex(String field) {
         int index = rowType.getFieldIndex(field);
         // TODO: support nested field
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
index 96f75ab10fea..74a474b8c354 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
@@ -21,18 +21,70 @@ package org.apache.paimon.spark
 import org.apache.paimon.options.Options
 import org.apache.paimon.table.FileStoreTable
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.connector.write.{SupportsOverwrite, WriteBuilder}
-import org.apache.spark.sql.sources.{And, Filter}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, Not, Or}
+
+import scala.collection.JavaConverters._
 
 private class SparkWriteBuilder(table: FileStoreTable, options: Options)
   extends WriteBuilder
-  with SupportsOverwrite {
+  with SupportsOverwrite
+  with SQLConfHelper {
 
   private var saveMode: SaveMode = InsertInto
 
   override def build = new SparkWrite(table, saveMode, options)
 
+  private def failWithReason(filter: Filter): Unit = {
+    throw new RuntimeException(
+      s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter")
+  }
+
+  private def validateFilter(filter: Filter): Unit = filter match {
+    case And(left, right) =>
+      validateFilter(left)
+      validateFilter(right)
+    case _: Or => failWithReason(filter)
+    case _: Not => failWithReason(filter)
+    case e: EqualTo if e.references.length == 1 && !e.value.isInstanceOf[Filter] =>
+    case e: EqualNullSafe if e.references.length == 1 && !e.value.isInstanceOf[Filter] =>
+    case _: AlwaysTrue | _: AlwaysFalse =>
+    case _ => failWithReason(filter)
+  }
+
+  // `SupportsOverwrite#canOverwrite` was added in Spark 3.4.0.
+  // We do this check ourselves to keep working with earlier Spark versions.
+  private def failIfCanNotOverwrite(filters: Array[Filter]): Unit = {
+    // For now, we only support overwrite in two cases:
+    // - overwrite with partition columns to be compatible with v1 insert overwrite.
+    //   See [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveInsertInto#staticDeleteExpression]].
+    // - truncate-like overwrite, where the filter is always true.
+    //
+    // Fail fast for other custom filters that come through the v2 write interface, e.g.,
+    // `dataframe.writeTo(T).overwrite(...)`
+    val partitionRowType = table.schema.logicalPartitionType()
+    val partitionNames = partitionRowType.getFieldNames.asScala
+    val allReferences = filters.flatMap(_.references)
+    val containsDataColumn = allReferences.exists {
+      reference => !partitionNames.exists(conf.resolver.apply(reference, _))
+    }
+    if (containsDataColumn) {
+      throw new RuntimeException(
+        s"Only support Overwrite filters on partition column ${partitionNames.mkString(
+            ", ")}, but got ${filters.mkString(", ")}.")
+    }
+    if (allReferences.distinct.length < allReferences.length) {
+      // fail with `part = 1 and part = 2`
+      throw new RuntimeException(
+        s"Only support Overwrite with one filter for each partition column, but got ${filters.mkString(", ")}.")
+    }
+    filters.foreach(validateFilter)
+  }
+
   override def overwrite(filters: Array[Filter]): WriteBuilder = {
+    failIfCanNotOverwrite(filters)
+
     val conjunctiveFilters = if (filters.nonEmpty) {
       Some(filters.reduce((l, r) => And(l, r)))
     } else {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index 307cab734c56..98d3c03aacbb 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -119,8 +119,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
           .mkString(", ")
         // There are seme unknown column names
         throw new RuntimeException(
-          s"Cannot write incompatible data for the table `${table.name}`, due to unknown column names: ${extraCols
-              .mkString(", ")}.")
+          s"Cannot write incompatible data for the table `${table.name}`, due to unknown column names: $extraCols.")
       }
       Project(reorderedCols, query)
     }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index aad4b82bd5b6..e8caea3cdd34 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -18,7 +18,6 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.data.BinaryRow
 import org.apache.paimon.deletionvectors.BitmapDeletionVector
 import org.apache.paimon.fs.Path
 import org.apache.paimon.index.IndexFileMeta
@@ -38,9 +37,9 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils.createDataset
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
-import org.apache.spark.sql.catalyst.plans.logical.{Filter => FilterLogicalNode, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter => FilterLogicalNode, LogicalPlan}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
-import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter}
+import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, EqualTo, Filter}
 
 import java.net.URI
 import java.util.Collections
@@ -59,23 +58,20 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
     filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue]
   }
 
-  /**
-   * For the 'INSERT OVERWRITE T PARTITION (partitionVal, ...)' semantics of SQL, Spark will
-   * transform `partitionVal`s to EqualNullSafe Filters.
-   */
-  def convertFilterToMap(filter: Filter, partitionRowType: RowType): Map[String, String] = {
+  /** See [[org.apache.paimon.spark.SparkWriteBuilder#failIfCanNotOverwrite]] */
+  def convertPartitionFilterToMap(
+      filter: Filter,
+      partitionRowType: RowType): Map[String, String] = {
     val converter = new SparkFilterConverter(partitionRowType)
     splitConjunctiveFilters(filter).map {
       case EqualNullSafe(attribute, value) =>
-        if (isNestedFilterInValue(value)) {
-          throw new RuntimeException(
-            s"Not support the complex partition value in EqualNullSafe when run `INSERT OVERWRITE`.")
-        } else {
-          (attribute, converter.convertLiteral(attribute, value).toString)
-        }
+        (attribute, converter.convertString(attribute, value))
+      case EqualTo(attribute, value) =>
+        (attribute, converter.convertString(attribute, value))
       case _ =>
+        // Should not happen
         throw new RuntimeException(
-          s"Only EqualNullSafe should be used when run `INSERT OVERWRITE`.")
+          s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter")
     }.toMap
   }
 
@@ -87,10 +83,6 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
     }
   }
 
-  private def isNestedFilterInValue(value: Any): Boolean = {
-    value.isInstanceOf[Filter]
-  }
-
   /** Gets a relative path against the table path. */
   protected def relativePath(absolutePath: String): String = {
     val location = table.location().toUri
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index 905c9cdfb7ff..fe740ea8ca11 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -76,7 +76,7 @@ case class WriteIntoPaimonTable(
       } else if (isTruncate(filter.get)) {
         Map.empty[String, String]
       } else {
-        convertFilterToMap(filter.get, table.schema.logicalPartitionType())
+        convertPartitionFilterToMap(filter.get, table.schema.logicalPartitionType())
       }
     case DynamicOverWrite =>
       dynamicPartitionOverwriteMode = true
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index f50483d9f72f..ca3ba8797be6 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -26,11 +26,9 @@ import org.junit.jupiter.api.Assertions
 import java.sql.{Date, Timestamp}
 
 class DataFrameWriteTest extends PaimonSparkTestBase {
+  import testImplicits._
 
   test("Paimon: DataFrameWrite.saveAsTable") {
-
-    import testImplicits._
-
     Seq((1L, "x1"), (2L, "x2"))
       .toDF("a", "b")
       .write
@@ -139,9 +137,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
         bucket =>
           test(s"Write data into Paimon directly: has-pk: $hasPk, bucket: $bucket") {
 
-            val _spark = spark
-            import _spark.implicits._
-
             val prop = if (hasPk) {
s"'primary-key'='a', 'bucket' = '$bucket' " } else if (bucket != -1) { @@ -278,8 +273,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase { bucket => test( s"Schema evolution: write data into Paimon with allowExplicitCast = true: $hasPk, bucket: $bucket") { - val _spark = spark - import _spark.implicits._ val prop = if (hasPk) { s"'primary-key'='a', 'bucket' = '$bucket' " @@ -380,4 +373,98 @@ class DataFrameWriteTest extends PaimonSparkTestBase { } } + withPk.foreach { + hasPk => + test(s"Support v2 write with overwrite, hasPk: $hasPk") { + withTable("t") { + val prop = if (hasPk) { + "'primary-key'='c1'" + } else { + "'write-only'='true'" + } + spark.sql(s""" + |CREATE TABLE t (c1 INT, c2 STRING) PARTITIONED BY(p1 String, p2 string) + |TBLPROPERTIES ($prop) + |""".stripMargin) + + spark + .range(3) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite($"p1" === "a") + checkAnswer( + spark.sql("SELECT * FROM t ORDER BY c1"), + Row(0, "0", "a", "0") :: Row(1, "1", "a", "1") :: Row(2, "2", "a", "2") :: Nil + ) + + spark + .range(7, 10) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite($"p1" === "a") + checkAnswer( + spark.sql("SELECT * FROM t ORDER BY c1"), + Row(7, "7", "a", "7") :: Row(8, "8", "a", "8") :: Row(9, "9", "a", "9") :: Nil + ) + + spark + .range(2) + .selectExpr("id as c1", "id as c2", "'a' as p1", "9 as p2") + .writeTo("t") + .overwrite(($"p1" <=> "a").and($"p2" === "9")) + checkAnswer( + spark.sql("SELECT * FROM t ORDER BY c1"), + Row(0, "0", "a", "9") :: Row(1, "1", "a", "9") :: Row(7, "7", "a", "7") :: + Row(8, "8", "a", "8") :: Nil + ) + + // bad case + val msg1 = intercept[Exception] { + spark + .range(2) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite($"p1" =!= "a") + }.getMessage + assert(msg1.contains("Only support Overwrite filters with Equal and EqualNullSafe")) + + val msg2 = intercept[Exception] { + spark + .range(2) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite($"p1" === $"c2") + }.getMessage + assert(msg2.contains("Table does not support overwrite by expression")) + + val msg3 = intercept[Exception] { + spark + .range(2) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite($"c1" === ($"c2" + 1)) + }.getMessage + assert(msg3.contains("cannot translate expression to source filter")) + + val msg4 = intercept[Exception] { + spark + .range(2) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite(($"p1" === "a").and($"p1" === "b")) + }.getMessage + assert(msg4.contains("Only support Overwrite with one filter for each partition column")) + + // Overwrite a partition which is not the specified + val msg5 = intercept[Exception] { + spark + .range(2) + .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2") + .writeTo("t") + .overwrite($"p1" === "b") + }.getMessage + assert(msg5.contains("does not belong to this partition")) + } + } + } }