Support dataframe v2 write with overwrite (apache#4082)
ulysses-you authored Sep 4, 2024
1 parent 538c9c7 commit cbda788
Showing 6 changed files with 167 additions and 32 deletions.
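As context for the diff below, here is a minimal usage sketch of the DataFrame v2 overwrite path this commit enables. The table `t` and its schema are hypothetical (they mirror the test case at the end of this diff); after this change, only Equal/EqualNullSafe filters on partition columns, or an always-true filter, are accepted by the overwrite path.

// Hypothetical partitioned table, mirroring the test case below.
spark.sql(
  """
    |CREATE TABLE t (c1 INT, c2 STRING) PARTITIONED BY (p1 STRING, p2 STRING)
    |""".stripMargin)

import spark.implicits._

// DataFrameWriterV2 overwrite with an Equal filter on a partition column;
// the filter is validated and converted to a static partition overwrite
// instead of being rejected.
spark
  .range(3)
  .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
  .writeTo("t")
  .overwrite($"p1" === "a")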

@@ -166,6 +166,11 @@ public Object convertLiteral(String field, Object value) {
return convertLiteral(fieldIndex(field), value);
}

public String convertString(String field, Object value) {
Object literal = convertLiteral(field, value);
return literal == null ? null : literal.toString();
}

private int fieldIndex(String field) {
int index = rowType.getFieldIndex(field);
// TODO: support nested field

@@ -21,18 +21,70 @@ package org.apache.paimon.spark
import org.apache.paimon.options.Options
import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.connector.write.{SupportsOverwrite, WriteBuilder}
import org.apache.spark.sql.sources.{And, Filter}
import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, Not, Or}

import scala.collection.JavaConverters._

private class SparkWriteBuilder(table: FileStoreTable, options: Options)
extends WriteBuilder
with SupportsOverwrite {
with SupportsOverwrite
with SQLConfHelper {

private var saveMode: SaveMode = InsertInto

override def build = new SparkWrite(table, saveMode, options)

private def failWithReason(filter: Filter): Unit = {
throw new RuntimeException(
s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter")
}

private def validateFilter(filter: Filter): Unit = filter match {
case And(left, right) =>
validateFilter(left)
validateFilter(right)
case _: Or => failWithReason(filter)
case _: Not => failWithReason(filter)
case e: EqualTo if e.references.length == 1 && !e.value.isInstanceOf[Filter] =>
case e: EqualNullSafe if e.references.length == 1 && !e.value.isInstanceOf[Filter] =>
case _: AlwaysTrue | _: AlwaysFalse =>
case _ => failWithReason(filter)
}

// `SupportsOverwrite#canOverwrite` was added in Spark 3.4.0.
// We do this check ourselves to stay compatible with earlier Spark versions.
private def failIfCanNotOverwrite(filters: Array[Filter]): Unit = {
// For now, we only support overwrite in two cases:
// - overwrite on partition columns, to stay compatible with v1 insert overwrite.
//   See [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveInsertInto#staticDeleteExpression]].
// - truncate-like overwrite, where the filter is always true.
//
// Fail fast for other custom filters that come through the v2 write interface, e.g.,
// `dataframe.writeTo(T).overwrite(...)`
val partitionRowType = table.schema.logicalPartitionType()
val partitionNames = partitionRowType.getFieldNames.asScala
val allReferences = filters.flatMap(_.references)
val containsDataColumn = allReferences.exists {
reference => !partitionNames.exists(conf.resolver.apply(reference, _))
}
if (containsDataColumn) {
throw new RuntimeException(
s"Only support Overwrite filters on partition column ${partitionNames.mkString(
", ")}, but got ${filters.mkString(", ")}.")
}
if (allReferences.distinct.length < allReferences.length) {
// fail with `part = 1 and part = 2`
throw new RuntimeException(
s"Only support Overwrite with one filter for each partition column, but got ${filters.mkString(", ")}.")
}
filters.foreach(validateFilter)
}

override def overwrite(filters: Array[Filter]): WriteBuilder = {
failIfCanNotOverwrite(filters)

val conjunctiveFilters = if (filters.nonEmpty) {
Some(filters.reduce((l, r) => And(l, r)))
} else {

@@ -119,8 +119,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
.mkString(", ")
// There are some unknown column names
throw new RuntimeException(
s"Cannot write incompatible data for the table `${table.name}`, due to unknown column names: ${extraCols
.mkString(", ")}.")
s"Cannot write incompatible data for the table `${table.name}`, due to unknown column names: $extraCols.")
}
Project(reorderedCols, query)
}

@@ -18,7 +18,6 @@

package org.apache.paimon.spark.commands

import org.apache.paimon.data.BinaryRow
import org.apache.paimon.deletionvectors.BitmapDeletionVector
import org.apache.paimon.fs.Path
import org.apache.paimon.index.IndexFileMeta
@@ -38,9 +37,9 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.PaimonUtils.createDataset
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.logical.{Filter => FilterLogicalNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Filter => FilterLogicalNode, LogicalPlan}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter}
import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, EqualTo, Filter}

import java.net.URI
import java.util.Collections
@@ -59,23 +58,20 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue]
}

/**
* For the 'INSERT OVERWRITE T PARTITION (partitionVal, ...)' semantics of SQL, Spark will
* transform `partitionVal`s to EqualNullSafe Filters.
*/
def convertFilterToMap(filter: Filter, partitionRowType: RowType): Map[String, String] = {
/** See [[org.apache.paimon.spark.SparkWriteBuilder#failIfCanNotOverwrite]] */
def convertPartitionFilterToMap(
filter: Filter,
partitionRowType: RowType): Map[String, String] = {
val converter = new SparkFilterConverter(partitionRowType)
splitConjunctiveFilters(filter).map {
case EqualNullSafe(attribute, value) =>
if (isNestedFilterInValue(value)) {
throw new RuntimeException(
s"Not support the complex partition value in EqualNullSafe when run `INSERT OVERWRITE`.")
} else {
(attribute, converter.convertLiteral(attribute, value).toString)
}
(attribute, converter.convertString(attribute, value))
case EqualTo(attribute, value) =>
(attribute, converter.convertString(attribute, value))
case _ =>
// Should not happen
throw new RuntimeException(
s"Only EqualNullSafe should be used when run `INSERT OVERWRITE`.")
s"Only support Overwrite filters with Equal and EqualNullSafe, but got: $filter")
}.toMap
}

@@ -87,10 +83,6 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
}
}

private def isNestedFilterInValue(value: Any): Boolean = {
value.isInstanceOf[Filter]
}

/** Gets a relative path against the table path. */
protected def relativePath(absolutePath: String): String = {
val location = table.location().toUri

@@ -76,7 +76,7 @@ case class WriteIntoPaimonTable(
} else if (isTruncate(filter.get)) {
Map.empty[String, String]
} else {
convertFilterToMap(filter.get, table.schema.logicalPartitionType())
convertPartitionFilterToMap(filter.get, table.schema.logicalPartitionType())
}
case DynamicOverWrite =>
dynamicPartitionOverwriteMode = true

@@ -26,11 +26,9 @@ import org.junit.jupiter.api.Assertions
import java.sql.{Date, Timestamp}

class DataFrameWriteTest extends PaimonSparkTestBase {
import testImplicits._

test("Paimon: DataFrameWrite.saveAsTable") {

import testImplicits._

Seq((1L, "x1"), (2L, "x2"))
.toDF("a", "b")
.write
@@ -139,9 +137,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
bucket =>
test(s"Write data into Paimon directly: has-pk: $hasPk, bucket: $bucket") {

val _spark = spark
import _spark.implicits._

val prop = if (hasPk) {
s"'primary-key'='a', 'bucket' = '$bucket' "
} else if (bucket != -1) {
@@ -278,8 +273,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
bucket =>
test(
s"Schema evolution: write data into Paimon with allowExplicitCast = true: $hasPk, bucket: $bucket") {
val _spark = spark
import _spark.implicits._

val prop = if (hasPk) {
s"'primary-key'='a', 'bucket' = '$bucket' "
@@ -380,4 +373,98 @@
}
}

withPk.foreach {
hasPk =>
test(s"Support v2 write with overwrite, hasPk: $hasPk") {
withTable("t") {
val prop = if (hasPk) {
"'primary-key'='c1'"
} else {
"'write-only'='true'"
}
spark.sql(s"""
|CREATE TABLE t (c1 INT, c2 STRING) PARTITIONED BY(p1 String, p2 string)
|TBLPROPERTIES ($prop)
|""".stripMargin)

spark
.range(3)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite($"p1" === "a")
checkAnswer(
spark.sql("SELECT * FROM t ORDER BY c1"),
Row(0, "0", "a", "0") :: Row(1, "1", "a", "1") :: Row(2, "2", "a", "2") :: Nil
)

spark
.range(7, 10)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite($"p1" === "a")
checkAnswer(
spark.sql("SELECT * FROM t ORDER BY c1"),
Row(7, "7", "a", "7") :: Row(8, "8", "a", "8") :: Row(9, "9", "a", "9") :: Nil
)

spark
.range(2)
.selectExpr("id as c1", "id as c2", "'a' as p1", "9 as p2")
.writeTo("t")
.overwrite(($"p1" <=> "a").and($"p2" === "9"))
checkAnswer(
spark.sql("SELECT * FROM t ORDER BY c1"),
Row(0, "0", "a", "9") :: Row(1, "1", "a", "9") :: Row(7, "7", "a", "7") ::
Row(8, "8", "a", "8") :: Nil
)

// bad case
val msg1 = intercept[Exception] {
spark
.range(2)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite($"p1" =!= "a")
}.getMessage
assert(msg1.contains("Only support Overwrite filters with Equal and EqualNullSafe"))

val msg2 = intercept[Exception] {
spark
.range(2)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite($"p1" === $"c2")
}.getMessage
assert(msg2.contains("Table does not support overwrite by expression"))

val msg3 = intercept[Exception] {
spark
.range(2)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite($"c1" === ($"c2" + 1))
}.getMessage
assert(msg3.contains("cannot translate expression to source filter"))

val msg4 = intercept[Exception] {
spark
.range(2)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite(($"p1" === "a").and($"p1" === "b"))
}.getMessage
assert(msg4.contains("Only support Overwrite with one filter for each partition column"))

// Overwrite a partition other than the specified one
val msg5 = intercept[Exception] {
spark
.range(2)
.selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
.writeTo("t")
.overwrite($"p1" === "b")
}.getMessage
assert(msg5.contains("does not belong to this partition"))
}
}
}
}
