[spark] dataframe.write and insert sql syntax in byName mode (apache#…
YannByron authored Aug 1, 2024
1 parent b9539f8 commit 7a9bcda
Showing 3 changed files with 149 additions and 18 deletions.
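
The change teaches Paimon's Spark analyzer to tell name-based writes (DataFrameWriter.saveAsTable, INSERT with an explicit column list) apart from position-based writes (plain INSERT ... SELECT) and to adjust the incoming query's schema accordingly. A minimal sketch of the two styles, with a hypothetical table name and assuming a Spark session whose catalog is backed by Paimon:

// Hypothetical demo table; its column order deliberately differs from the data.
spark.sql("CREATE TABLE demo (col2 INT, col3 DOUBLE, col1 STRING) USING paimon")

// By position: without a column list, the query output maps to
// (col2, col3, col1) left to right, with casts added where the types differ.
spark.sql("INSERT INTO demo SELECT 2, 2.2, 'Hello'")

// By name: an explicit column list matches query columns to table columns
// by name, so the order no longer matters.
spark.sql("INSERT INTO demo (col1, col2, col3) SELECT 'World', 3, 3.3")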

PaimonAnalysis.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolvedTable
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

@@ -39,11 +40,17 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
 
     case a @ PaimonV2WriteCommand(table, paimonTable)
-        if !schemaCompatible(
-          a.query.output.toStructType,
-          table.output.toStructType,
-          paimonTable.partitionKeys().asScala) =>
-      val newQuery = resolveQueryColumns(a.query, table.output)
+        if a.isByName && needsSchemaAdjustmentByName(a.query, table.output, paimonTable) =>
+      val newQuery = resolveQueryColumnsByName(a.query, table.output)
       if (newQuery != a.query) {
         Compatibility.withNewQuery(a, newQuery)
       } else {
+        a
+      }
+
+    case a @ PaimonV2WriteCommand(table, paimonTable)
+        if !a.isByName && needsSchemaAdjustmentByPosition(a.query, table.output, paimonTable) =>
+      val newQuery = resolveQueryColumnsByPosition(a.query, table.output)
+      if (newQuery != a.query) {
+        Compatibility.withNewQuery(a, newQuery)
+      } else {
@@ -57,6 +64,62 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
       PaimonMergeIntoResolver(merge, session)
   }
 
+  private def needsSchemaAdjustmentByName(
+      query: LogicalPlan,
+      targetAttrs: Seq[Attribute],
+      paimonTable: FileStoreTable): Boolean = {
+    val userSpecifiedNames = if (session.sessionState.conf.caseSensitiveAnalysis) {
+      query.output.map(a => (a.name, a)).toMap
+    } else {
+      CaseInsensitiveMap(query.output.map(a => (a.name, a)).toMap)
+    }
+    val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
+    !schemaCompatible(
+      specifiedTargetAttrs.toStructType,
+      query.output.toStructType,
+      paimonTable.partitionKeys().asScala)
+  }
+
+  private def resolveQueryColumnsByName(
+      query: LogicalPlan,
+      targetAttrs: Seq[Attribute]): LogicalPlan = {
+    val output = query.output
+    val project = targetAttrs.map {
+      attr =>
+        val outputAttr = output
+          .find(t => session.sessionState.conf.resolver(t.name, attr.name))
+          .getOrElse {
+            throw new RuntimeException(
+              s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}")
+          }
+        addCastToColumn(outputAttr, attr)
+    }
+    Project(project, query)
+  }
+
+  private def needsSchemaAdjustmentByPosition(
+      query: LogicalPlan,
+      targetAttrs: Seq[Attribute],
+      paimonTable: FileStoreTable): Boolean = {
+    val output = query.output
+    targetAttrs.map(_.name) != output.map(_.name) ||
+    !schemaCompatible(
+      targetAttrs.toStructType,
+      output.toStructType,
+      paimonTable.partitionKeys().asScala)
+  }
+
+  private def resolveQueryColumnsByPosition(
+      query: LogicalPlan,
+      tableAttributes: Seq[Attribute]): LogicalPlan = {
+    val project = query.output.zipWithIndex.map {
+      case (attr, i) =>
+        val targetAttr = tableAttributes(i)
+        addCastToColumn(attr, targetAttr)
+    }
+    Project(project, query)
+  }
+
   private def schemaCompatible(
       dataSchema: StructType,
       tableSchema: StructType,
@@ -83,22 +146,10 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
     }
 
     dataSchema.zip(tableSchema).forall {
-      case (f1, f2) =>
-        f1.name == f2.name && dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
+      case (f1, f2) => dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
     }
   }
 
-  private def resolveQueryColumns(
-      query: LogicalPlan,
-      tableAttributes: Seq[Attribute]): LogicalPlan = {
-    val project = query.output.zipWithIndex.map {
-      case (attr, i) =>
-        val targetAttr = tableAttributes(i)
-        addCastToColumn(attr, targetAttr)
-    }
-    Project(project, query)
-  }
-
   private def addCastToColumn(attr: Attribute, targetAttr: Attribute): NamedExpression = {
     val expr = (attr.dataType, targetAttr.dataType) match {
       case (s, t) if s == t =>
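
A self-contained sketch (plain Scala, no Spark dependency; column names made up) of the matching idea behind needsSchemaAdjustmentByName and resolveQueryColumnsByName: honor case sensitivity through a resolver, reorder the query columns to the target order, and fail loudly when a name is missing.

// Standalone sketch of case-aware, by-name column matching; mirrors the shape
// of resolveQueryColumnsByName but uses plain strings instead of Attributes.
object ByNameResolutionSketch extends App {
  val caseSensitive = false // stand-in for spark.sql.caseSensitive
  val resolver: (String, String) => Boolean =
    if (caseSensitive) _ == _ else _.equalsIgnoreCase(_)

  val targetCols = Seq("col2", "col3", "col1") // table schema order
  val queryCols = Seq("COL1", "col2", "col3")  // query output order

  // For each target column, find the matching query column by name, or fail
  // the way resolveQueryColumnsByName does when a column is missing.
  val projected = targetCols.map {
    target =>
      queryCols
        .find(resolver(_, target))
        .getOrElse(throw new RuntimeException(
          s"Cannot find $target in data columns: ${queryCols.mkString(", ")}"))
  }

  println(projected.mkString(", ")) // prints: col2, col3, COL1
}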

DataFrameWriteTest.scala
@@ -56,6 +56,35 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
     Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema.explicit-cast"))
   }
 
+  fileFormats.foreach {
+    fileFormat =>
+      test(s"Paimon: DataFrameWrite.saveAsTable in ByName mode, file.format: $fileFormat") {
+        withTable("t1", "t2") {
+          spark.sql(s"""
+                       |CREATE TABLE t1 (col1 STRING, col2 INT, col3 DOUBLE)
+                       |TBLPROPERTIES ('file.format' = '$fileFormat')
+                       |""".stripMargin)
+
+          spark.sql(s"""
+                       |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING)
+                       |TBLPROPERTIES ('file.format' = '$fileFormat')
+                       |""".stripMargin)
+
+          sql(s"""
+                 |INSERT INTO TABLE t1 VALUES
+                 |("Hello", 1, 1.1),
+                 |("World", 2, 2.2),
+                 |("Paimon", 3, 3.3);
+                 |""".stripMargin)
+
+          spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")
+          checkAnswer(
+            sql("SELECT * FROM t2 ORDER BY col2"),
+            Row(1, 1.1d, "Hello") :: Row(2, 2.2d, "World") :: Row(3, 3.3d, "Paimon") :: Nil)
+        }
+      }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {
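
One detail the test relies on: in Spark, DataFrameWriter.saveAsTable in append mode resolves columns by name, while DataFrameWriter.insertInto resolves them by position. A short contrast, reusing the test's t1/t2 (a sketch, not part of the commit):

// By name: t1 (col1, col2, col3) lands correctly in t2 (col2, col3, col1).
spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")

// By position: the same write via insertInto would map t1.col1 (STRING) onto
// t2.col2 (INT), so the query has to be reordered explicitly first.
spark.table("t1").select("col2", "col3", "col1").write.insertInto("t2")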

InsertOverwriteTableTestBase.scala
@@ -27,6 +27,57 @@ import java.sql.{Date, Timestamp}
 
 abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
 
+  fileFormats.foreach {
+    fileFormat =>
+      Seq(true, false).foreach {
+        isPartitioned =>
+          test(
+            s"Paimon: insert into/overwrite in ByName mode, file.format: $fileFormat, isPartitioned: $isPartitioned") {
+            withTable("t1", "t2") {
+              val partitionedSQL = if (isPartitioned) {
+                "PARTITIONED BY (col4)"
+              } else {
+                ""
+              }
+              spark.sql(s"""
+                           |CREATE TABLE t1 (col1 STRING, col2 INT, col3 DOUBLE, col4 STRING)
+                           |$partitionedSQL
+                           |TBLPROPERTIES ('file.format' = '$fileFormat')
+                           |""".stripMargin)
+
+              spark.sql(s"""
+                           |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING, col4 STRING)
+                           |$partitionedSQL
+                           |TBLPROPERTIES ('file.format' = '$fileFormat')
+                           |""".stripMargin)
+
+              sql(s"""
+                     |INSERT INTO TABLE t1 VALUES
+                     |("Hello", 1, 1.1, "pt1"),
+                     |("Paimon", 3, 3.3, "pt2");
+                     |""".stripMargin)
+
+              sql("INSERT INTO t2 (col1, col2, col3, col4) SELECT * FROM t1")
+              checkAnswer(
+                sql("SELECT * FROM t2 ORDER BY col2"),
+                Row(1, 1.1d, "Hello", "pt1") :: Row(3, 3.3d, "Paimon", "pt2") :: Nil)
+
+              sql(s"""
+                     |INSERT INTO TABLE t1 VALUES ("World", 2, 2.2, "pt1");
+                     |""".stripMargin)
+              sql("INSERT OVERWRITE t2 (col1, col2, col3, col4) SELECT * FROM t1")
+              checkAnswer(
+                sql("SELECT * FROM t2 ORDER BY col2"),
+                Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") :: Row(
+                  3,
+                  3.3d,
+                  "Paimon",
+                  "pt2") :: Nil)
+            }
+          }
+      }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {
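
In the SQL path, the explicit column list is what makes the insert by-name; without one the insert stays positional. A small contrast against the test's tables (a sketch of the semantics the test asserts, not code from the commit):

// With a column list: by-name. t1's columns are matched to t2's
// (col2, col3, col1, col4) by name, so SELECT * works despite the
// different declaration order.
sql("INSERT INTO t2 (col1, col2, col3, col4) SELECT * FROM t1")

// Without a column list: by-position. SELECT * would push t1.col1 (STRING)
// into t2.col2 (INT), so the projection must be reordered by hand.
sql("INSERT INTO t2 SELECT col2, col3, col1, col4 FROM t1")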
