[spark] support complex data type in byName mode (apache#3878)
YannByron authored Aug 2, 2024
1 parent 0ad2ae9 commit 39ca57d
Showing 2 changed files with 152 additions and 6 deletions.
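In by-name mode, the write path now recursively realigns struct fields and array-of-struct elements to the target schema instead of only handling top-level columns. A minimal sketch of the scenario this enables, assuming a Spark session configured with a Paimon catalog (table and column names here are illustrative; the full version is the test added in this commit):

spark.sql("CREATE TABLE src (a STRING, c STRUCT<c1: DOUBLE, c2: LONG>)")
spark.sql("CREATE TABLE dst (c STRUCT<c2: LONG, c1: DOUBLE>, a STRING)")
// Top-level columns and nested struct fields are matched by name, not position,
// so the reordered schema of dst is populated correctly.
spark.table("src").write.format("paimon").mode("append").saveAsTable("dst")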
PaimonAnalysis.scala
@@ -26,12 +26,12 @@ import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.ResolvedTable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}

import scala.collection.JavaConverters._

@@ -92,7 +92,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
throw new RuntimeException(
s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}")
}
addCastToColumn(outputAttr, attr)
addCastToColumn(outputAttr, attr, isByName = true)
}
Project(project, query)
}
@@ -115,7 +115,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
val project = query.output.zipWithIndex.map {
case (attr, i) =>
val targetAttr = tableAttributes(i)
addCastToColumn(attr, targetAttr)
addCastToColumn(attr, targetAttr, isByName = false)
}
Project(project, query)
}
@@ -150,16 +150,114 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}
}

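// Aligns a single source column to the target attribute's type: structs and arrays of
// structs are rewritten recursively; isByName decides whether nested struct fields are
// matched by name or by position.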
private def addCastToColumn(attr: Attribute, targetAttr: Attribute): NamedExpression = {
private def addCastToColumn(
attr: Attribute,
targetAttr: Attribute,
isByName: Boolean): NamedExpression = {
val expr = (attr.dataType, targetAttr.dataType) match {
case (s, t) if s == t =>
attr
case (s: StructType, t: StructType) if s != t =>
if (isByName) {
addCastToStructByName(attr, s, t)
} else {
addCastToStructByPosition(attr, s, t)
}
case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, _: Boolean))
if s != t =>
val castToStructFunc = if (isByName) {
addCastToStructByName _
} else {
addCastToStructByPosition _
}
castToArrayStruct(attr, s, t, sNull, castToStructFunc)
case _ =>
cast(attr, targetAttr.dataType)
}
Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
}

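// Rebuilds a struct by looking up each target field in the source struct by name,
// recursing into nested structs and casting leaf fields to the target type.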
private def addCastToStructByName(
parent: NamedExpression,
source: StructType,
target: StructType): NamedExpression = {
val fields = target.map {
case targetField @ StructField(name, nested: StructType, _, _) =>
val sourceIndex = source.fieldIndex(name)
val sourceField = source(sourceIndex)
sourceField.dataType match {
case s: StructType =>
val subField = castStructField(parent, sourceIndex, sourceField.name, targetField)
addCastToStructByName(subField, s, nested)
case o =>
throw new RuntimeException(s"Can not support to cast $o to StructType.")
}
case targetField =>
val sourceIndex = source.fieldIndex(targetField.name)
val sourceField = source(sourceIndex)
castStructField(parent, sourceIndex, sourceField.name, targetField)
}
Alias(CreateStruct(fields), parent.name)(
parent.exprId,
parent.qualifier,
Option(parent.metadata))
}

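// Rebuilds a struct by matching source and target fields by position, recursing into
// nested structs and casting leaf fields to the target type.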
private def addCastToStructByPosition(
parent: NamedExpression,
source: StructType,
target: StructType): NamedExpression = {
if (source.length != target.length) {
throw new RuntimeException("The number of fields in source and target is not same.")
}

val fields = target.zipWithIndex.map {
case (targetField @ StructField(_, nested: StructType, _, _), i) =>
val sourceField = source(i)
sourceField.dataType match {
case s: StructType =>
val subField = castStructField(parent, i, sourceField.name, targetField)
addCastToStructByPosition(subField, s, nested)
case o =>
throw new RuntimeException(s"Can not support to cast $o to StructType.")
}
case (targetField, i) =>
val sourceField = source(i)
castStructField(parent, i, sourceField.name, targetField)
}
Alias(CreateStruct(fields), parent.name)(
parent.exprId,
parent.qualifier,
Option(parent.metadata))
}

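// Extracts field i from the parent struct and casts it to the target field's type,
// keeping the target field's name and metadata.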
private def castStructField(
parent: NamedExpression,
i: Int,
sourceFieldName: String,
targetField: StructField): NamedExpression = {
Alias(
cast(GetStructField(parent, i, Option(sourceFieldName)), targetField.dataType),
targetField.name
)(explicitMetadata = Option(targetField.metadata))
}
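
// Casts an array-of-struct column by applying castToStructFunc to every element
// through a Catalyst ArrayTransform over an (element, index) lambda.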
private def castToArrayStruct(
parent: NamedExpression,
source: StructType,
target: StructType,
sourceNullable: Boolean,
castToStructFunc: (NamedExpression, StructType, StructType) => NamedExpression
): Expression = {
val structConverter: (Expression, Expression) => Expression = (_, i) =>
castToStructFunc(Alias(GetArrayItem(parent, i), i.toString)(), source, target)
val transformLambdaFunc = {
val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
LambdaFunction(structConverter(elementVar, indexVar), Seq(elementVar, indexVar))
}
ArrayTransform(parent, transformLambdaFunc)
}

private def cast(expr: Expression, dataType: DataType): Expression = {
val cast = Compatibility.cast(expr, dataType, Option(conf.sessionLocalTimeZone))
cast.setTagValue(Compatibility.castByTableInsertionTag, ())

DataFrameWriteTest.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.spark.sql.Row
import org.junit.jupiter.api.Assertions

import java.sql.Date
import java.sql.{Date, Timestamp}

class DataFrameWriteTest extends PaimonSparkTestBase {

@@ -85,6 +85,54 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
}
}

fileFormats.foreach {
fileFormat =>
test(
s"Paimon: DataFrameWrite.saveAsTable with complex data type in ByName mode, file.format: $fileFormat") {
withTable("t1", "t2") {
spark.sql(
s"""
|CREATE TABLE t1 (a STRING, b INT, c STRUCT<c1:DOUBLE, c2:LONG>, d ARRAY<STRUCT<d1 TIMESTAMP, d2 MAP<STRING, STRING>>>, e ARRAY<INT>)
|TBLPROPERTIES ('file.format' = '$fileFormat')
|""".stripMargin)

spark.sql(
s"""
|CREATE TABLE t2 (b INT, c STRUCT<c2:LONG, c1:DOUBLE>, d ARRAY<STRUCT<d2 MAP<STRING, STRING>, d1 TIMESTAMP>>, e ARRAY<INT>, a STRING)
|TBLPROPERTIES ('file.format' = '$fileFormat')
|""".stripMargin)

sql(s"""
|INSERT INTO TABLE t1 VALUES
|("Hello", 1, struct(1.1, 1000), array(struct(timestamp'2024-01-01 00:00:00', map("k1", "v1")), struct(timestamp'2024-08-01 00:00:00', map("k1", "v11"))), array(123, 345)),
|("World", 2, struct(2.2, 2000), array(struct(timestamp'2024-02-01 00:00:00', map("k2", "v2"))), array(234, 456)),
|("Paimon", 3, struct(3.3, 3000), null, array(345, 567));
|""".stripMargin)

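// t1 and t2 declare the same columns in different orders, both at the top level and
// inside the nested structs, so this append has to resolve columns by name.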
spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")
checkAnswer(
sql("SELECT * FROM t2 ORDER BY b"),
Row(
1,
Row(1000L, 1.1d),
Array(
Row(Map("k1" -> "v1"), Timestamp.valueOf("2024-01-01 00:00:00")),
Row(Map("k1" -> "v11"), Timestamp.valueOf("2024-08-01 00:00:00"))),
Array(123, 345),
"Hello"
)
:: Row(
2,
Row(2000L, 2.2d),
Array(Row(Map("k2" -> "v2"), Timestamp.valueOf("2024-02-01 00:00:00"))),
Array(234, 456),
"World")
:: Row(3, Row(3000L, 3.3d), null, Array(345, 567), "Paimon") :: Nil
)
}
}
}

withPk.foreach {
hasPk =>
bucketModes.foreach {