Commit 0b00f3c

[spark] Fix writing null struct col (apache#4787)
Zouxxyy authored Dec 26, 2024
1 parent 2e57c59 commit 0b00f3c
Showing 4 changed files with 47 additions and 11 deletions.
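In short: when Paimon's analysis rule rewrites a write plan, it rebuilds each struct column field by field and reassembles the fields with CreateStruct. For a NULL struct every extracted field is NULL, so the rebuilt value came back as a non-NULL row of NULLs rather than NULL. The new structAlias helper (first file below) wraps the rebuilt struct in an IsNull guard whenever the parent expression is nullable. A minimal repro in the style of the test added below (table name t is illustrative):

    sql("CREATE TABLE t (i INT, s STRUCT<f1: INT, f2: INT>)")
    sql("INSERT INTO t VALUES (2, null)")
    // Before this fix, s could be read back as Row(null, null); expected: null.
    checkAnswer(sql("SELECT * FROM t"), Row(2, null))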
PaimonAnalysis.scala

@@ -26,7 +26,7 @@ import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedTable}
-import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateNamedStruct, CreateStruct, Expression, GetArrayItem, GetStructField, If, IsNull, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -206,10 +206,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
        val sourceField = source(sourceIndex)
        castStructField(parent, sourceIndex, sourceField.name, targetField)
    }
-    Alias(CreateStruct(fields), parent.name)(
-      parent.exprId,
-      parent.qualifier,
-      Option(parent.metadata))
+    structAlias(fields, parent)
  }

  private def addCastToStructByPosition(
@@ -234,10 +231,19 @@
        val sourceField = source(i)
        castStructField(parent, i, sourceField.name, targetField)
    }
-    Alias(CreateStruct(fields), parent.name)(
-      parent.exprId,
-      parent.qualifier,
-      Option(parent.metadata))
+    structAlias(fields, parent)
  }

+  private def structAlias(
+      fields: Seq[NamedExpression],
+      parent: NamedExpression): NamedExpression = {
+    val struct = CreateStruct(fields)
+    val res = if (parent.nullable) {
+      If(IsNull(parent), Literal(null, struct.dataType), struct)
+    } else {
+      struct
+    }
+    Alias(res, parent.name)(parent.exprId, parent.qualifier, Option(parent.metadata))
+  }

  private def castStructField(
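Why the guard matters: the per-field extractors built by castStructField (GetStructField) each evaluate to NULL when the parent struct is NULL, and CreateStruct then reassembles those NULLs into a non-NULL row, silently turning a NULL struct into Row(null, null). The If(IsNull(parent), Literal(null, struct.dataType), struct) branch keeps the NULL intact. The same effect is visible in plain Spark without Paimon; a minimal sketch (column names illustrative):

    import org.apache.spark.sql.functions.{col, lit, struct, when}

    val df = spark.sql("SELECT CAST(NULL AS STRUCT<f1: INT, f2: INT>) AS s")

    // Rebuilding field by field loses the top-level NULL:
    df.select(struct(col("s.f1"), col("s.f2")).as("rebuilt")).show()
    // -> {NULL, NULL}, a non-NULL struct of NULLs

    // Guarding on isNull preserves it, analogous to structAlias:
    df.select(
      when(col("s").isNull, lit(null))
        .otherwise(struct(col("s.f1"), col("s.f2")))
        .as("guarded")).show()
    // -> NULL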
PaimonSparkTestBase.scala

@@ -19,6 +19,8 @@
package org.apache.paimon.spark

import org.apache.paimon.catalog.{Catalog, Identifier}
+import org.apache.paimon.fs.FileIO
+import org.apache.paimon.fs.local.LocalFileIO
import org.apache.paimon.spark.catalog.WithPaimonCatalog
import org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions
import org.apache.paimon.spark.sql.{SparkVersionSupport, WithTableOptions}
@@ -46,6 +48,8 @@ class PaimonSparkTestBase
  with WithTableOptions
  with SparkVersionSupport {

+  protected lazy val fileIO: FileIO = LocalFileIO.create

  protected lazy val tempDBDir: File = Utils.createTempDir

  protected def paimonCatalog: Catalog = {
@@ -64,6 +68,7 @@
"org.apache.spark.serializer.JavaSerializer"
}
super.sparkConf
.set("spark.sql.warehouse.dir", tempDBDir.getCanonicalPath)
.set("spark.sql.catalog.paimon", classOf[SparkCatalog].getName)
.set("spark.sql.catalog.paimon.warehouse", tempDBDir.getCanonicalPath)
.set("spark.sql.extensions", classOf[PaimonSparkSessionExtensions].getName)
@@ -152,8 +157,10 @@

  override def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
      pos: Position): Unit = {
-    println(testName)
-    super.test(testName, testTags: _*)(testFun)(pos)
+    super.test(testName, testTags: _*) {
+      println(testName)
+      testFun
+    }(pos)
  }

  def loadTable(tableName: String): FileStoreTable = {
InsertOverwriteTableTestBase.scala

@@ -560,4 +560,26 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
    }
    checkAnswer(sql("SELECT * FROM T ORDER BY name"), Row("g", null, "Shanghai"))
  }

test("Paimon Insert: read and write struct with null") {
fileFormats {
format =>
withTable("t") {
sql(
s"CREATE TABLE t (i INT, s STRUCT<f1: INT, f2: INT>) TBLPROPERTIES ('file.format' = '$format')")
sql(
"INSERT INTO t VALUES (1, STRUCT(1, 1)), (2, null), (3, STRUCT(1, null)), (4, STRUCT(null, null))")
if (format.equals("parquet")) {
// todo: fix it, see https://github.com/apache/paimon/issues/4785
checkAnswer(
sql("SELECT * FROM t ORDER BY i"),
Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), Row(4, null)))
} else {
checkAnswer(
sql("SELECT * FROM t ORDER BY i"),
Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), Row(4, Row(null, null))))
}
}
}
}
}
WithTableOptions.scala

@@ -27,4 +27,5 @@ trait WithTableOptions {

  protected val withPk: Seq[Boolean] = Seq(true, false)

+  protected def fileFormats(fn: String => Unit): Unit = Seq("parquet", "orc", "avro").foreach(fn)
}
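For reference, the fileFormats helper simply runs a test body once per built-in file format, as the new test above does; a trivial usage sketch (table name is illustrative):

    fileFormats { format =>
      // Runs three times, with format = "parquet", "orc", then "avro".
      sql(s"CREATE TABLE t_$format (i INT) TBLPROPERTIES ('file.format' = '$format')")
    }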
