From 7a9bcdaa3b556adb2695740d44fd109dc80eb87a Mon Sep 17 00:00:00 2001 From: Yann Byron Date: Thu, 1 Aug 2024 23:08:41 +0800 Subject: [PATCH] [spark] dataframe.write and insert sql syntax in byName mode (#3871) --- .../catalyst/analysis/PaimonAnalysis.scala | 87 +++++++++++++++---- .../paimon/spark/sql/DataFrameWriteTest.scala | 29 +++++++ .../spark/sql/InsertOverwriteTableTest.scala | 51 +++++++++++ 3 files changed, 149 insertions(+), 18 deletions(-) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala index 3dc0e40c9eff..d115fe3fd13c 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolvedTable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -39,11 +40,17 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown { case a @ PaimonV2WriteCommand(table, paimonTable) - if !schemaCompatible( - a.query.output.toStructType, - table.output.toStructType, - paimonTable.partitionKeys().asScala) => - val newQuery = resolveQueryColumns(a.query, table.output) + if a.isByName && needsSchemaAdjustmentByName(a.query, table.output, paimonTable) => + val newQuery = resolveQueryColumnsByName(a.query, table.output) + if (newQuery != a.query) { + Compatibility.withNewQuery(a, newQuery) + } else { + a + } + + case a @ PaimonV2WriteCommand(table, paimonTable) + if !a.isByName && needsSchemaAdjustmentByPosition(a.query, table.output, paimonTable) => + val newQuery = resolveQueryColumnsByPosition(a.query, table.output) if (newQuery != a.query) { Compatibility.withNewQuery(a, newQuery) } else { @@ -57,6 +64,62 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] { PaimonMergeIntoResolver(merge, session) } + private def needsSchemaAdjustmentByName( + query: LogicalPlan, + targetAttrs: Seq[Attribute], + paimonTable: FileStoreTable): Boolean = { + val userSpecifiedNames = if (session.sessionState.conf.caseSensitiveAnalysis) { + query.output.map(a => (a.name, a)).toMap + } else { + CaseInsensitiveMap(query.output.map(a => (a.name, a)).toMap) + } + val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name)) + !schemaCompatible( + specifiedTargetAttrs.toStructType, + query.output.toStructType, + paimonTable.partitionKeys().asScala) + } + + private def resolveQueryColumnsByName( + query: LogicalPlan, + targetAttrs: Seq[Attribute]): LogicalPlan = { + val output = query.output + val project = targetAttrs.map { + attr => + val outputAttr = output + .find(t => session.sessionState.conf.resolver(t.name, attr.name)) + .getOrElse { + throw new RuntimeException( + s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}") + } + addCastToColumn(outputAttr, attr) + } + 
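+    // Descriptive note (not part of the original patch): the projection built above
+    // maps each target attribute to the matching source column by name, casting to
+    // the target type; the Project node below applies it on top of the source query.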
Project(project, query) + } + + private def needsSchemaAdjustmentByPosition( + query: LogicalPlan, + targetAttrs: Seq[Attribute], + paimonTable: FileStoreTable): Boolean = { + val output = query.output + targetAttrs.map(_.name) != output.map(_.name) || + !schemaCompatible( + targetAttrs.toStructType, + output.toStructType, + paimonTable.partitionKeys().asScala) + } + + private def resolveQueryColumnsByPosition( + query: LogicalPlan, + tableAttributes: Seq[Attribute]): LogicalPlan = { + val project = query.output.zipWithIndex.map { + case (attr, i) => + val targetAttr = tableAttributes(i) + addCastToColumn(attr, targetAttr) + } + Project(project, query) + } + private def schemaCompatible( dataSchema: StructType, tableSchema: StructType, @@ -83,22 +146,10 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] { } dataSchema.zip(tableSchema).forall { - case (f1, f2) => - f1.name == f2.name && dataTypeCompatible(f1.name, f1.dataType, f2.dataType) + case (f1, f2) => dataTypeCompatible(f1.name, f1.dataType, f2.dataType) } } - private def resolveQueryColumns( - query: LogicalPlan, - tableAttributes: Seq[Attribute]): LogicalPlan = { - val project = query.output.zipWithIndex.map { - case (attr, i) => - val targetAttr = tableAttributes(i) - addCastToColumn(attr, targetAttr) - } - Project(project, query) - } - private def addCastToColumn(attr: Attribute, targetAttr: Attribute): NamedExpression = { val expr = (attr.dataType, targetAttr.dataType) match { case (s, t) if s == t => diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala index a4b618318f73..a2509871f44a 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala @@ -56,6 +56,35 @@ class DataFrameWriteTest extends PaimonSparkTestBase { Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema.explicit-cast")) } + fileFormats.foreach { + fileFormat => + test(s"Paimon: DataFrameWrite.saveAsTable in ByName mode, file.format: $fileFormat") { + withTable("t1", "t2") { + spark.sql(s""" + |CREATE TABLE t1 (col1 STRING, col2 INT, col3 DOUBLE) + |TBLPROPERTIES ('file.format' = '$fileFormat') + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING) + |TBLPROPERTIES ('file.format' = '$fileFormat') + |""".stripMargin) + + sql(s""" + |INSERT INTO TABLE t1 VALUES + |("Hello", 1, 1.1), + |("World", 2, 2.2), + |("Paimon", 3, 3.3); + |""".stripMargin) + + spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2") + checkAnswer( + sql("SELECT * FROM t2 ORDER BY col2"), + Row(1, 1.1d, "Hello") :: Row(2, 2.2d, "World") :: Row(3, 3.3d, "Paimon") :: Nil) + } + } + } + withPk.foreach { hasPk => bucketModes.foreach { diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala index 9ad1f4523884..528df32e6a2b 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala @@ -27,6 +27,57 @@ import java.sql.{Date, Timestamp} 
abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase { + fileFormats.foreach { + fileFormat => + Seq(true, false).foreach { + isPartitioned => + test( + s"Paimon: insert into/overwrite in ByName mode, file.format: $fileFormat, isPartitioned: $isPartitioned") { + withTable("t1", "t2") { + val partitionedSQL = if (isPartitioned) { + "PARTITIONED BY (col4)" + } else { + "" + } + spark.sql(s""" + |CREATE TABLE t1 (col1 STRING, col2 INT, col3 DOUBLE, col4 STRING) + |$partitionedSQL + |TBLPROPERTIES ('file.format' = '$fileFormat') + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING, col4 STRING) + |$partitionedSQL + |TBLPROPERTIES ('file.format' = '$fileFormat') + |""".stripMargin) + + sql(s""" + |INSERT INTO TABLE t1 VALUES + |("Hello", 1, 1.1, "pt1"), + |("Paimon", 3, 3.3, "pt2"); + |""".stripMargin) + + sql("INSERT INTO t2 (col1, col2, col3, col4) SELECT * FROM t1") + checkAnswer( + sql("SELECT * FROM t2 ORDER BY col2"), + Row(1, 1.1d, "Hello", "pt1") :: Row(3, 3.3d, "Paimon", "pt2") :: Nil) + + sql(s""" + |INSERT INTO TABLE t1 VALUES ("World", 2, 2.2, "pt1"); + |""".stripMargin) + sql("INSERT OVERWRITE t2 (col1, col2, col3, col4) SELECT * FROM t1") + checkAnswer( + sql("SELECT * FROM t2 ORDER BY col2"), + Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") :: Row( + 3, + 3.3d, + "Paimon", + "pt2") :: Nil) + } + } + } + } + withPk.foreach { hasPk => bucketModes.foreach {