[spark] Refactor Paimon write by name (apache#4058)
ulysses-you authored Aug 27, 2024
1 parent 25a4f99 commit 8f0a5f3
Showing 4 changed files with 149 additions and 96 deletions.
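In short, this commit reworks PaimonAnalysis so that by-name and by-position writes share one resolveQueryColumns entry point: for INSERT ... BY NAME, query columns are reordered to match the table schema, unspecified nullable columns are filled with NULL, and size mismatches, duplicate names, unknown names, and missing non-nullable columns all fail fast. Below is a minimal, self-contained sketch of that by-name resolution, not part of the commit itself: plain case classes stand in for Spark attributes, matching is hard-coded case-insensitive, casts are omitted, and all names are illustrative.

object ByNameResolutionSketch {
  final case class Col(name: String, nullable: Boolean)

  // Reorder query columns to match the table schema by name; unspecified nullable
  // columns become an explicit NULL, and anything else that does not line up is an error.
  def resolveByName(queryCols: Seq[Col], tableCols: Seq[Col], tableName: String): Seq[String] = {
    if (queryCols.size > tableCols.size) {
      throw new RuntimeException(
        s"Cannot write incompatible data for the table `$tableName`, " +
          "the number of data columns don't match with the table schema's.")
    }
    val matchedNames = scala.collection.mutable.HashSet.empty[String]
    val reordered = tableCols.map {
      expected =>
        queryCols.filter(_.name.equalsIgnoreCase(expected.name)) match {
          case Seq() if expected.nullable =>
            s"NULL AS ${expected.name}" // fill the unspecified nullable column with null
          case Seq() =>
            throw new RuntimeException(
              s"non-nullable column `${expected.name}` has no specified value")
          case Seq(single) =>
            matchedNames += single.name
            single.name // the real rule additionally casts to the expected type here
          case many =>
            throw new RuntimeException(
              s"column name conflicts: ${many.map(_.name).mkString(", ")}")
        }
    }
    val unknown = queryCols.filterNot(c => matchedNames.contains(c.name)).map(_.name)
    if (unknown.nonEmpty) {
      throw new RuntimeException(s"unknown column names: ${unknown.mkString(", ")}")
    }
    reordered
  }
}

For instance, with a table of (col1 nullable, col2 nullable) and a query producing only col2, the sketch returns Seq("NULL AS col1", "col2"), mirroring the BY NAME test added below where unspecified columns read back as null.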
@@ -25,32 +25,22 @@ import org.apache.paimon.spark.commands.{PaimonAnalyzeTableColumnCommand, Paimon
import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.ResolvedTable
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedTable}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}

import scala.collection.JavaConverters._
import scala.collection.mutable

class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {

case a @ PaimonV2WriteCommand(table, paimonTable)
if a.isByName && needsSchemaAdjustmentByName(a.query, table.output, paimonTable) =>
val newQuery = resolveQueryColumnsByName(a.query, table.output)
if (newQuery != a.query) {
Compatibility.withNewQuery(a, newQuery)
} else {
a
}

case a @ PaimonV2WriteCommand(table, paimonTable)
if !a.isByName && needsSchemaAdjustmentByPosition(a.query, table.output, paimonTable) =>
val newQuery = resolveQueryColumnsByPosition(a.query, table.output)
case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table) =>
val newQuery = resolveQueryColumns(a.query, table, a.isByName)
if (newQuery != a.query) {
Compatibility.withNewQuery(a, newQuery)
} else {
@@ -67,89 +57,107 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
PaimonShowColumnsCommand(table)
}

private def needsSchemaAdjustmentByName(
private def paimonWriteResolved(query: LogicalPlan, table: NamedRelation): Boolean = {
query.output.size == table.output.size &&
query.output.zip(table.output).forall {
case (inAttr, outAttr) =>
val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
inAttr.name == outAttr.name && schemaCompatible(inType, outType)
}
}

private def resolveQueryColumns(
query: LogicalPlan,
targetAttrs: Seq[Attribute],
paimonTable: FileStoreTable): Boolean = {
val userSpecifiedNames = if (session.sessionState.conf.caseSensitiveAnalysis) {
query.output.map(a => (a.name, a)).toMap
table: NamedRelation,
byName: Boolean): LogicalPlan = {
// More details see: `TableOutputResolver#resolveOutputColumns`
if (byName) {
resolveQueryColumnsByName(query, table)
} else {
CaseInsensitiveMap(query.output.map(a => (a.name, a)).toMap)
resolveQueryColumnsByPosition(query, table)
}
val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
!schemaCompatible(
specifiedTargetAttrs.toStructType,
query.output.toStructType,
paimonTable.partitionKeys().asScala)
}

private def resolveQueryColumnsByName(
query: LogicalPlan,
targetAttrs: Seq[Attribute]): LogicalPlan = {
val output = query.output
val project = targetAttrs.map {
attr =>
val outputAttr = output
.find(t => session.sessionState.conf.resolver(t.name, attr.name))
.getOrElse {
private def resolveQueryColumnsByName(query: LogicalPlan, table: NamedRelation): LogicalPlan = {
val inputCols = query.output
val expectedCols = table.output
if (inputCols.size > expectedCols.size) {
throw new RuntimeException(
s"Cannot write incompatible data for the table `${table.name}`, " +
"the number of data columns don't match with the table schema's.")
}

val matchedCols = mutable.HashSet.empty[String]
val reorderedCols = expectedCols.map {
expectedCol =>
val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name))
if (matched.isEmpty) {
// TODO: Support Spark default value framework if Paimon supports to change default values.
if (!expectedCol.nullable) {
throw new RuntimeException(
s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}")
s"Cannot write incompatible data for the table `${table.name}`, " +
s"due to non-nullable column `${expectedCol.name}` has no specified value.")
}
Alias(Literal(null, expectedCol.dataType), expectedCol.name)()
} else if (matched.length > 1) {
throw new RuntimeException(
s"Cannot write incompatible data for the table `${table.name}`, due to column name conflicts: ${matched
.mkString(", ")}.")
} else {
matchedCols += matched.head.name
val matchedCol = matched.head
val actualExpectedCol = expectedCol.withDataType {
CharVarcharUtils.getRawType(expectedCol.metadata).getOrElse(expectedCol.dataType)
}
addCastToColumn(outputAttr, attr, isByName = true)
addCastToColumn(matchedCol, actualExpectedCol, isByName = true)
}
}
Project(project, query)
}

private def needsSchemaAdjustmentByPosition(
query: LogicalPlan,
targetAttrs: Seq[Attribute],
paimonTable: FileStoreTable): Boolean = {
val output = query.output
targetAttrs.map(_.name) != output.map(_.name) ||
!schemaCompatible(
targetAttrs.toStructType,
output.toStructType,
paimonTable.partitionKeys().asScala)
assert(reorderedCols.length == expectedCols.length)
if (matchedCols.size < inputCols.length) {
val extraCols = inputCols
.filterNot(col => matchedCols.contains(col.name))
.map(col => s"${toSQLId(col.name)}")
.mkString(", ")
// There are some unknown column names
throw new RuntimeException(
s"Cannot write incompatible data for the table `${table.name}`, " +
s"due to unknown column names: $extraCols.")
}
Project(reorderedCols, query)
}

private def resolveQueryColumnsByPosition(
query: LogicalPlan,
tableAttributes: Seq[Attribute]): LogicalPlan = {
val project = query.output.zipWithIndex.map {
table: NamedRelation): LogicalPlan = {
val expectedCols = table.output
val queryCols = query.output
if (queryCols.size != expectedCols.size) {
throw new RuntimeException(
s"Cannot write incompatible data for the table `${table.name}`, " +
"the number of data columns don't match with the table schema's.")
}

val project = queryCols.zipWithIndex.map {
case (attr, i) =>
val targetAttr = tableAttributes(i)
val targetAttr = expectedCols(i)
addCastToColumn(attr, targetAttr, isByName = false)
}
Project(project, query)
}

private def schemaCompatible(
dataSchema: StructType,
tableSchema: StructType,
partitionCols: Seq[String],
parent: Array[String] = Array.empty): Boolean = {

if (tableSchema.size != dataSchema.size) {
throw new RuntimeException("the number of data columns don't match with the table schema's.")
}

def dataTypeCompatible(column: String, dt1: DataType, dt2: DataType): Boolean = {
(dt1, dt2) match {
case (s1: StructType, s2: StructType) =>
schemaCompatible(s1, s2, partitionCols, Array(column))
case (a1: ArrayType, a2: ArrayType) =>
dataTypeCompatible(column, a1.elementType, a2.elementType)
case (m1: MapType, m2: MapType) =>
dataTypeCompatible(column, m1.keyType, m2.keyType) && dataTypeCompatible(
column,
m1.valueType,
m2.valueType)
case (d1, d2) => d1 == d2
}
}

dataSchema.zip(tableSchema).forall {
case (f1, f2) => dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
private def schemaCompatible(dataSchema: DataType, tableSchema: DataType): Boolean = {
(dataSchema, tableSchema) match {
case (s1: StructType, s2: StructType) =>
s1.zip(s2).forall { case (d1, d2) => schemaCompatible(d1.dataType, d2.dataType) }
case (a1: ArrayType, a2: ArrayType) =>
a1.containsNull == a2.containsNull && schemaCompatible(a1.elementType, a2.elementType)
case (m1: MapType, m2: MapType) =>
m1.valueContainsNull == m2.valueContainsNull &&
schemaCompatible(m1.keyType, m2.keyType) &&
schemaCompatible(m1.valueType, m2.valueType)
case (d1, d2) => d1 == d2
}
}

@@ -244,6 +252,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
targetField.name
)(explicitMetadata = Option(targetField.metadata))
}

private def castToArrayStruct(
parent: NamedExpression,
source: StructType,
@@ -304,11 +313,10 @@ case class PaimonPostHocResolutionRules(session: SparkSession) extends Rule[Logi
}

object PaimonV2WriteCommand {
def unapply(o: V2WriteCommand): Option[(DataSourceV2Relation, FileStoreTable)] = {
def unapply(o: V2WriteCommand): Option[DataSourceV2Relation] = {
if (o.query.resolved) {
o.table match {
case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] =>
Some((r, r.table.asInstanceOf[SparkTable].getTable.asInstanceOf[FileStoreTable]))
case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] => Some(r)
case _ => None
}
} else {
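One side effect of the refactor, visible above, is that PaimonV2WriteCommand no longer needs to expose the FileStoreTable: schema compatibility is now a purely structural, recursive check over DataTypes rather than a name- and partition-aware comparison. A rough standalone sketch of that recursion, using a hypothetical mini type ADT in place of Spark's DataType hierarchy (for illustration only):

object SchemaCompatSketch {
  sealed trait DType
  final case class AtomicT(name: String) extends DType
  final case class StructT(fields: Seq[DType]) extends DType
  final case class ArrayT(element: DType, containsNull: Boolean) extends DType
  final case class MapT(key: DType, value: DType, valueContainsNull: Boolean) extends DType

  // Mirrors the shape of the new schemaCompatible: recurse into structs, arrays and maps,
  // require matching nullability flags on containers, and fall back to equality for leaf
  // types. (Top-level column counts are checked separately, in paimonWriteResolved.)
  def compatible(data: DType, table: DType): Boolean = (data, table) match {
    case (StructT(f1), StructT(f2)) =>
      f1.zip(f2).forall { case (d, t) => compatible(d, t) }
    case (ArrayT(e1, n1), ArrayT(e2, n2)) =>
      n1 == n2 && compatible(e1, e2)
    case (MapT(k1, v1, n1), MapT(k2, v2, n2)) =>
      n1 == n2 && compatible(k1, k2) && compatible(v1, v2)
    case (d, t) => d == t
  }
}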
@@ -18,7 +18,6 @@

package org.apache.paimon.spark

import org.apache.paimon.Snapshot
import org.apache.paimon.catalog.{Catalog, CatalogContext, CatalogFactory, Identifier}
import org.apache.paimon.options.{CatalogOptions, Options}
import org.apache.paimon.spark.catalog.Catalogs
@@ -37,7 +36,6 @@ import org.scalactic.source.Position
import org.scalatest.Tag

import java.io.File
import java.util
import java.util.{HashMap => JHashMap}
import java.util.TimeZone

@@ -45,11 +45,12 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
|TBLPROPERTIES ('file.format' = '$fileFormat')
|""".stripMargin)

spark.sql(s"""
|CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING, col4 STRING)
|$partitionedSQL
|TBLPROPERTIES ('file.format' = '$fileFormat')
|""".stripMargin)
spark.sql(
s"""
|CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING NOT NULL, col4 STRING)
|$partitionedSQL
|TBLPROPERTIES ('file.format' = '$fileFormat')
|""".stripMargin)

sql(s"""
|INSERT INTO TABLE t1 VALUES
@@ -68,11 +69,55 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
sql("INSERT OVERWRITE t2 (col1, col2, col3, col4) SELECT * FROM t1")
checkAnswer(
sql("SELECT * FROM t2 ORDER BY col2"),
Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") :: Row(
3,
3.3d,
"Paimon",
"pt2") :: Nil)
Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") ::
Row(3, 3.3d, "Paimon", "pt2") :: Nil
)

// BY NAME statement is supported since Spark 3.5
if (gteqSpark3_5) {
sql("INSERT OVERWRITE TABLE t1 BY NAME SELECT col3, col2, col4, col1 FROM t1")
// non-specified columns are filled with null
sql("INSERT OVERWRITE TABLE t2 BY NAME SELECT col1, col2 FROM t2 ")
checkAnswer(
sql("SELECT * FROM t2 ORDER BY col2"),
Row(1, null, "Hello", null) :: Row(2, null, "World", null) ::
Row(3, null, "Paimon", null) :: Nil
)

// by name bad case
// names conflict
val msg1 = intercept[Exception] {
sql("INSERT INTO TABLE t1 BY NAME SELECT col1, col2 as col1 FROM t1")
}
assert(msg1.getMessage.contains("due to column name conflicts"))
// name does not match
val msg2 = intercept[Exception] {
sql("INSERT INTO TABLE t1 BY NAME SELECT col1, col2 as colx FROM t1")
}
assert(msg2.getMessage.contains("due to unknown column names"))
// query column size bigger than table's
val msg3 = intercept[Exception] {
sql("INSERT INTO TABLE t1 BY NAME SELECT col1, col2, col3, col4, col4 as col5 FROM t1")
}
assert(
msg3.getMessage.contains(
"the number of data columns don't match with the table schema"))
// non-nullable column has no specified value
val msg4 = intercept[Exception] {
sql("INSERT INTO TABLE t2 BY NAME SELECT col2 FROM t2")
}
assert(
msg4.getMessage.contains("non-nullable column `col1` has no specified value"))

// by position
// column size does not match
val msg5 = intercept[Exception] {
sql("INSERT INTO TABLE t1 VALUES(1)")
}
assert(
msg5.getMessage.contains(
"the number of data columns don't match with the table schema"))
}
}
}
}
@@ -26,4 +26,6 @@ trait SparkVersionSupport {
lazy val gteqSpark3_3: Boolean = sparkVersion >= "3.3"

lazy val gteqSpark3_4: Boolean = sparkVersion >= "3.4"

lazy val gteqSpark3_5: Boolean = sparkVersion >= "3.5"
}
