From 0b00f3c64a29e2142b343b81bc9454a34b488118 Mon Sep 17 00:00:00 2001
From: Zouxxyy
Date: Thu, 26 Dec 2024 22:45:58 +0800
Subject: [PATCH] [spark] Fix writing null struct col (#4787)

---
 .../catalyst/analysis/PaimonAnalysis.scala    | 24 ++++++++++++-------
 .../paimon/spark/PaimonSparkTestBase.scala    | 11 +++++++--
 .../sql/InsertOverwriteTableTestBase.scala    | 22 +++++++++++++++++
 .../paimon/spark/sql/WithTableOptions.scala   |  1 +
 4 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index f567d925ea57..790983866845 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -26,7 +26,7 @@ import org.apache.paimon.table.FileStoreTable
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedTable}
-import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateNamedStruct, CreateStruct, Expression, GetArrayItem, GetStructField, If, IsNull, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -206,10 +206,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
         val sourceField = source(sourceIndex)
         castStructField(parent, sourceIndex, sourceField.name, targetField)
     }
-    Alias(CreateStruct(fields), parent.name)(
-      parent.exprId,
-      parent.qualifier,
-      Option(parent.metadata))
+    structAlias(fields, parent)
   }
 
   private def addCastToStructByPosition(
@@ -234,10 +231,19 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
         val sourceField = source(i)
         castStructField(parent, i, sourceField.name, targetField)
     }
-    Alias(CreateStruct(fields), parent.name)(
-      parent.exprId,
-      parent.qualifier,
-      Option(parent.metadata))
+    structAlias(fields, parent)
+  }
+
+  private def structAlias(
+      fields: Seq[NamedExpression],
+      parent: NamedExpression): NamedExpression = {
+    val struct = CreateStruct(fields)
+    val res = if (parent.nullable) {
+      If(IsNull(parent), Literal(null, struct.dataType), struct)
+    } else {
+      struct
+    }
+    Alias(res, parent.name)(parent.exprId, parent.qualifier, Option(parent.metadata))
   }
 
   private def castStructField(
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
index 867b3e5e3337..9a6719010e36 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
@@ -19,6 +19,8 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.catalog.{Catalog, Identifier}
+import org.apache.paimon.fs.FileIO
+import org.apache.paimon.fs.local.LocalFileIO
 import org.apache.paimon.spark.catalog.WithPaimonCatalog
 import org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions
 import org.apache.paimon.spark.sql.{SparkVersionSupport, WithTableOptions}
@@ -46,6 +48,8 @@ class PaimonSparkTestBase
     with WithTableOptions
     with SparkVersionSupport {
 
+  protected lazy val fileIO: FileIO = LocalFileIO.create
+
   protected lazy val tempDBDir: File = Utils.createTempDir
 
   protected def paimonCatalog: Catalog = {
@@ -64,6 +68,7 @@ class PaimonSparkTestBase
         "org.apache.spark.serializer.JavaSerializer"
       }
     super.sparkConf
+      .set("spark.sql.warehouse.dir", tempDBDir.getCanonicalPath)
       .set("spark.sql.catalog.paimon", classOf[SparkCatalog].getName)
       .set("spark.sql.catalog.paimon.warehouse", tempDBDir.getCanonicalPath)
       .set("spark.sql.extensions", classOf[PaimonSparkSessionExtensions].getName)
@@ -152,8 +157,10 @@ class PaimonSparkTestBase
 
   override def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
       pos: Position): Unit = {
-    println(testName)
-    super.test(testName, testTags: _*)(testFun)(pos)
+    super.test(testName, testTags: _*) {
+      println(testName)
+      testFun
+    }(pos)
   }
 
   def loadTable(tableName: String): FileStoreTable = {
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
index 977b74707069..38cca371f042 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
@@ -560,4 +560,26 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
     }
     checkAnswer(sql("SELECT * FROM T ORDER BY name"), Row("g", null, "Shanghai"))
   }
+
+  test("Paimon Insert: read and write struct with null") {
+    fileFormats {
+      format =>
+        withTable("t") {
+          sql(
+            s"CREATE TABLE t (i INT, s STRUCT<f1: INT, f2: INT>) TBLPROPERTIES ('file.format' = '$format')")
+          sql(
+            "INSERT INTO t VALUES (1, STRUCT(1, 1)), (2, null), (3, STRUCT(1, null)), (4, STRUCT(null, null))")
+          if (format.equals("parquet")) {
+            // todo: fix it, see https://github.com/apache/paimon/issues/4785
+            checkAnswer(
+              sql("SELECT * FROM t ORDER BY i"),
+              Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), Row(4, null)))
+          } else {
+            checkAnswer(
+              sql("SELECT * FROM t ORDER BY i"),
+              Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), Row(4, Row(null, null))))
+          }
+        }
+    }
+  }
 }
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
index e390058bafab..d5866a31b165 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
@@ -27,4 +27,5 @@ trait WithTableOptions {
 
   protected val withPk: Seq[Boolean] = Seq(true, false)
 
+  protected def fileFormats(fn: String => Unit): Unit = Seq("parquet", "orc", "avro").foreach(fn)
 }
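
Note (trailing comment, not applied by git am): the root cause is that addCastToStructByName
and addCastToStructByPosition wrapped the recast fields in a bare CreateStruct, which always
evaluates to a non-null row, so a null source struct was written as a row of nulls. The new
structAlias helper preserves null-ness by emitting If(IsNull(parent), Literal(null,
struct.dataType), struct) when the parent is nullable. A minimal sketch of the observable
behavior in a Paimon-enabled spark-shell, mirroring the new test (table and struct field
names here are illustrative):

    // Struct schema mirrors the test above; `spark` is the shell's SparkSession.
    spark.sql("CREATE TABLE t (i INT, s STRUCT<f1: INT, f2: INT>)")
    spark.sql("INSERT INTO t VALUES (2, null)")
    spark.sql("SELECT * FROM t").show()
    // before this patch: [2, {null, null}]  -- the null struct became a row of nulls
    // after this patch:  [2, null]

The STRUCT(null, null) case still reads back as null under parquet; that reader-side issue is
tracked separately at https://github.com/apache/paimon/issues/4785, as the test's parquet
branch notes.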