From 39ca57d72ec8334a349298cc3806ed238319620f Mon Sep 17 00:00:00 2001
From: Yann Byron
Date: Fri, 2 Aug 2024 20:48:48 +0800
Subject: [PATCH] [spark] support complex data type in byName mode (#3878)

---
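Usage sketch (illustrative, not part of the patch): after this change, a
by-name write such as DataFrameWriter.saveAsTable into an existing table can
reorder nested struct fields by field name instead of failing on the type
mismatch. The snippet below is a minimal, self-contained example; the catalog
configuration, warehouse path, and table/field names are assumptions for
demonstration, not taken from this patch.

    import org.apache.spark.sql.SparkSession

    object ByNameWriteSketch {
      def main(args: Array[String]): Unit = {
        // Assumed local setup: Paimon catalog registered under the name
        // "paimon" with a throwaway warehouse directory.
        val spark = SparkSession
          .builder()
          .master("local[2]")
          .config("spark.sql.catalog.paimon", "org.apache.paimon.spark.SparkCatalog")
          .config("spark.sql.catalog.paimon.warehouse", "file:/tmp/paimon-warehouse")
          .getOrCreate()

        spark.sql("USE paimon.default")
        spark.sql("CREATE TABLE src (id INT, s STRUCT<x DOUBLE, y BIGINT>)")
        // Same fields as src, but both the columns and the nested struct
        // fields are declared in a different order.
        spark.sql("CREATE TABLE dst (s STRUCT<y BIGINT, x DOUBLE>, id INT)")
        spark.sql("INSERT INTO src VALUES (1, struct(1.5, 10))")

        // saveAsTable resolves top-level columns by name; with this patch the
        // nested fields s.x / s.y are matched by name as well.
        spark.table("src").write.format("paimon").mode("append").saveAsTable("dst")
        spark.sql("SELECT * FROM dst").show(false)

        spark.stop()
      }
    }
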
 .../catalyst/analysis/PaimonAnalysis.scala    | 108 +++++++++++++++++-
 .../paimon/spark/sql/DataFrameWriteTest.scala |  50 +++++++-
 2 files changed, 152 insertions(+), 6 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index d115fe3fd13c..7ed90283da5a 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -26,12 +26,12 @@ import org.apache.paimon.table.FileStoreTable
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.ResolvedTable
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, NamedExpression, NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}
 
 import scala.collection.JavaConverters._
 
@@ -92,7 +92,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
             throw new RuntimeException(
               s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}")
           }
-          addCastToColumn(outputAttr, attr)
+          addCastToColumn(outputAttr, attr, isByName = true)
       }
       Project(project, query)
     }
@@ -115,7 +115,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
     val project = query.output.zipWithIndex.map {
       case (attr, i) =>
         val targetAttr = tableAttributes(i)
-        addCastToColumn(attr, targetAttr)
+        addCastToColumn(attr, targetAttr, isByName = false)
     }
     Project(project, query)
   }
@@ -150,16 +150,114 @@
     }
   }
 
-  private def addCastToColumn(attr: Attribute, targetAttr: Attribute): NamedExpression = {
+  private def addCastToColumn(
+      attr: Attribute,
+      targetAttr: Attribute,
+      isByName: Boolean): NamedExpression = {
     val expr = (attr.dataType, targetAttr.dataType) match {
       case (s, t) if s == t =>
         attr
+      case (s: StructType, t: StructType) if s != t =>
+        if (isByName) {
+          addCastToStructByName(attr, s, t)
+        } else {
+          addCastToStructByPosition(attr, s, t)
+        }
+      case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, _: Boolean))
+          if s != t =>
+        val castToStructFunc = if (isByName) {
+          addCastToStructByName _
+        } else {
+          addCastToStructByPosition _
+        }
+        castToArrayStruct(attr, s, t, sNull, castToStructFunc)
       case _ =>
         cast(attr, targetAttr.dataType)
     }
     Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
   }
 
+  private def addCastToStructByName(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType): NamedExpression = {
+    val fields = target.map {
+      case targetField @ StructField(name, nested: StructType, _, _) =>
+        val sourceIndex = source.fieldIndex(name)
+        val sourceField = source(sourceIndex)
+        sourceField.dataType match {
+          case s: StructType =>
+            val subField = castStructField(parent, sourceIndex, sourceField.name, targetField)
+            addCastToStructByName(subField, s, nested)
+          case o =>
+            throw new RuntimeException(s"Can not support to cast $o to StructType.")
+        }
+      case targetField =>
+        val sourceIndex = source.fieldIndex(targetField.name)
+        val sourceField = source(sourceIndex)
+        castStructField(parent, sourceIndex, sourceField.name, targetField)
+    }
+    Alias(CreateStruct(fields), parent.name)(
+      parent.exprId,
+      parent.qualifier,
+      Option(parent.metadata))
+  }
+
+  private def addCastToStructByPosition(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType): NamedExpression = {
+    if (source.length != target.length) {
+      throw new RuntimeException("The number of fields in source and target is not same.")
+    }
+
+    val fields = target.zipWithIndex.map {
+      case (targetField @ StructField(_, nested: StructType, _, _), i) =>
+        val sourceField = source(i)
+        sourceField.dataType match {
+          case s: StructType =>
+            val subField = castStructField(parent, i, sourceField.name, targetField)
+            addCastToStructByPosition(subField, s, nested)
+          case o =>
+            throw new RuntimeException(s"Can not support to cast $o to StructType.")
+        }
+      case (targetField, i) =>
+        val sourceField = source(i)
+        castStructField(parent, i, sourceField.name, targetField)
+    }
+    Alias(CreateStruct(fields), parent.name)(
+      parent.exprId,
+      parent.qualifier,
+      Option(parent.metadata))
+  }
+
+  private def castStructField(
+      parent: NamedExpression,
+      i: Int,
+      sourceFieldName: String,
+      targetField: StructField): NamedExpression = {
+    Alias(
+      cast(GetStructField(parent, i, Option(sourceFieldName)), targetField.dataType),
+      targetField.name
+    )(explicitMetadata = Option(targetField.metadata))
+  }
+
+  private def castToArrayStruct(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType,
+      sourceNullable: Boolean,
+      castToStructFunc: (NamedExpression, StructType, StructType) => NamedExpression
+  ): Expression = {
+    val structConverter: (Expression, Expression) => Expression = (_, i) =>
+      castToStructFunc(Alias(GetArrayItem(parent, i), i.toString)(), source, target)
+    val transformLambdaFunc = {
+      val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
+      val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
+      LambdaFunction(structConverter(elementVar, indexVar), Seq(elementVar, indexVar))
+    }
+    ArrayTransform(parent, transformLambdaFunc)
+  }
+
   private def cast(expr: Expression, dataType: DataType): Expression = {
     val cast = Compatibility.cast(expr, dataType, Option(conf.sessionLocalTimeZone))
     cast.setTagValue(Compatibility.castByTableInsertionTag, ())
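Side note on castToArrayStruct above (illustrative, not part of the patch):
the expression it assembles, an ArrayTransform over a LambdaFunction with an
element variable and an index variable, is the same shape that SQL's
higher-order transform function parses to. A plain-Spark sketch of that
shape, with made-up literals and field names (no Paimon required):

    import org.apache.spark.sql.SparkSession

    object ArrayTransformSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").getOrCreate()

        // (e, i) -> ... becomes LambdaFunction(body, Seq(elementVar, indexVar));
        // the body rebuilds each array element with its struct fields reordered
        // by name, which is what the by-name cast does for ARRAY<STRUCT<...>>.
        spark
          .sql("""
            SELECT transform(
              array(named_struct('d1', 1, 'd2', 'x'), named_struct('d1', 2, 'd2', 'y')),
              (e, i) -> named_struct('d2', e.d2, 'd1', e.d1)) AS d
          """)
          .show(false)

        spark.stop()
      }
    }
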
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index a2509871f44a..f50483d9f72f 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
 import org.junit.jupiter.api.Assertions
 
-import java.sql.Date
+import java.sql.{Date, Timestamp}
 
 class DataFrameWriteTest extends PaimonSparkTestBase {
@@ -85,6 +85,54 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
     }
   }
 
+  fileFormats.foreach {
+    fileFormat =>
+      test(
+        s"Paimon: DataFrameWrite.saveAsTable with complex data type in ByName mode, file.format: $fileFormat") {
+        withTable("t1", "t2") {
+          spark.sql(
+            s"""
+               |CREATE TABLE t1 (a STRING, b INT, c STRUCT<c1 DOUBLE, c2 BIGINT>, d ARRAY<STRUCT<d1 TIMESTAMP, d2 MAP<STRING, STRING>>>, e ARRAY<INT>)
+               |TBLPROPERTIES ('file.format' = '$fileFormat')
+               |""".stripMargin)
+
+          spark.sql(
+            s"""
+               |CREATE TABLE t2 (b INT, c STRUCT<c2 BIGINT, c1 DOUBLE>, d ARRAY<STRUCT<d2 MAP<STRING, STRING>, d1 TIMESTAMP>>, e ARRAY<INT>, a STRING)
+               |TBLPROPERTIES ('file.format' = '$fileFormat')
+               |""".stripMargin)
+
+          sql(s"""
+                 |INSERT INTO TABLE t1 VALUES
+                 |("Hello", 1, struct(1.1, 1000), array(struct(timestamp'2024-01-01 00:00:00', map("k1", "v1")), struct(timestamp'2024-08-01 00:00:00', map("k1", "v11"))), array(123, 345)),
+                 |("World", 2, struct(2.2, 2000), array(struct(timestamp'2024-02-01 00:00:00', map("k2", "v2"))), array(234, 456)),
+                 |("Paimon", 3, struct(3.3, 3000), null, array(345, 567));
+                 |""".stripMargin)
+
+          spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")
+          checkAnswer(
+            sql("SELECT * FROM t2 ORDER BY b"),
+            Row(
+              1,
+              Row(1000L, 1.1d),
+              Array(
+                Row(Map("k1" -> "v1"), Timestamp.valueOf("2024-01-01 00:00:00")),
+                Row(Map("k1" -> "v11"), Timestamp.valueOf("2024-08-01 00:00:00"))),
+              Array(123, 345),
+              "Hello"
+            )
+              :: Row(
+                2,
+                Row(2000L, 2.2d),
+                Array(Row(Map("k2" -> "v2"), Timestamp.valueOf("2024-02-01 00:00:00"))),
+                Array(234, 456),
+                "World")
+              :: Row(3, Row(3000L, 3.3d), null, Array(345, 567), "Paimon") :: Nil
+          )
+        }
+      }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {