diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index 816535d3b4bf..1b28fa9b63ca 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -25,32 +25,22 @@ import org.apache.paimon.spark.commands.{PaimonAnalyzeTableColumnCommand, Paimon
 import org.apache.paimon.table.FileStoreTable
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.analysis.ResolvedTable
-import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedTable}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}
 
-import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
-
-    case a @ PaimonV2WriteCommand(table, paimonTable)
-        if a.isByName && needsSchemaAdjustmentByName(a.query, table.output, paimonTable) =>
-      val newQuery = resolveQueryColumnsByName(a.query, table.output)
-      if (newQuery != a.query) {
-        Compatibility.withNewQuery(a, newQuery)
-      } else {
-        a
-      }
-
-    case a @ PaimonV2WriteCommand(table, paimonTable)
-        if !a.isByName && needsSchemaAdjustmentByPosition(a.query, table.output, paimonTable) =>
-      val newQuery = resolveQueryColumnsByPosition(a.query, table.output)
+    case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table) =>
+      val newQuery = resolveQueryColumns(a.query, table, a.isByName)
       if (newQuery != a.query) {
         Compatibility.withNewQuery(a, newQuery)
       } else {
         a
@@ -67,89 +57,107 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
       PaimonShowColumnsCommand(table)
   }
 
-  private def needsSchemaAdjustmentByName(
+  private def paimonWriteResolved(query: LogicalPlan, table: NamedRelation): Boolean = {
+    query.output.size == table.output.size &&
+    query.output.zip(table.output).forall {
+      case (inAttr, outAttr) =>
+        val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
+        val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+        inAttr.name == outAttr.name && schemaCompatible(inType, outType)
+    }
+  }
+
+  private def resolveQueryColumns(
       query: LogicalPlan,
-      targetAttrs: Seq[Attribute],
-      paimonTable: FileStoreTable): Boolean = {
-    val userSpecifiedNames = if (session.sessionState.conf.caseSensitiveAnalysis) {
-      query.output.map(a => (a.name, a)).toMap
+      table: NamedRelation,
+      byName: Boolean): LogicalPlan = {
+    // For more details, see `TableOutputResolver#resolveOutputColumns`.
+    if (byName) {
+      resolveQueryColumnsByName(query, table)
     } else {
-      CaseInsensitiveMap(query.output.map(a => (a.name, a)).toMap)
+      resolveQueryColumnsByPosition(query, table)
     }
-    val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
-    !schemaCompatible(
-      specifiedTargetAttrs.toStructType,
-      query.output.toStructType,
-      paimonTable.partitionKeys().asScala)
   }
 
-  private def resolveQueryColumnsByName(
-      query: LogicalPlan,
-      targetAttrs: Seq[Attribute]): LogicalPlan = {
-    val output = query.output
-    val project = targetAttrs.map {
-      attr =>
-        val outputAttr = output
-          .find(t => session.sessionState.conf.resolver(t.name, attr.name))
-          .getOrElse {
+  private def resolveQueryColumnsByName(query: LogicalPlan, table: NamedRelation): LogicalPlan = {
+    val inputCols = query.output
+    val expectedCols = table.output
+    if (inputCols.size > expectedCols.size) {
+      throw new RuntimeException(
+        s"Cannot write incompatible data for the table `${table.name}`, " +
+          "the number of data columns doesn't match the table schema.")
+    }
+
+    val matchedCols = mutable.HashSet.empty[String]
+    val reorderedCols = expectedCols.map {
+      expectedCol =>
+        val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name))
+        if (matched.isEmpty) {
+          // TODO: Support Spark's default value framework if Paimon supports changing default values.
+          if (!expectedCol.nullable) {
             throw new RuntimeException(
-              s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}")
+              s"Cannot write incompatible data for the table `${table.name}`, " +
+                s"because non-nullable column `${expectedCol.name}` has no specified value.")
+          }
+          Alias(Literal(null, expectedCol.dataType), expectedCol.name)()
+        } else if (matched.length > 1) {
+          throw new RuntimeException(
+            s"Cannot write incompatible data for the table `${table.name}`, due to column name conflicts: ${matched
+                .mkString(", ")}.")
+        } else {
+          matchedCols += matched.head.name
+          val matchedCol = matched.head
+          val actualExpectedCol = expectedCol.withDataType {
+            CharVarcharUtils.getRawType(expectedCol.metadata).getOrElse(expectedCol.dataType)
           }
-        addCastToColumn(outputAttr, attr, isByName = true)
+          addCastToColumn(matchedCol, actualExpectedCol, isByName = true)
+        }
     }
-    Project(project, query)
-  }
 
-  private def needsSchemaAdjustmentByPosition(
-      query: LogicalPlan,
-      targetAttrs: Seq[Attribute],
-      paimonTable: FileStoreTable): Boolean = {
-    val output = query.output
-    targetAttrs.map(_.name) != output.map(_.name) ||
-    !schemaCompatible(
-      targetAttrs.toStructType,
-      output.toStructType,
-      paimonTable.partitionKeys().asScala)
+    assert(reorderedCols.length == expectedCols.length)
+    if (matchedCols.size < inputCols.length) {
+      // There are some unknown column names in the query output.
+      val extraCols = inputCols
+        .filterNot(col => matchedCols.contains(col.name))
+        .map(col => toSQLId(col.name))
+        .mkString(", ")
+      throw new RuntimeException(
+        s"Cannot write incompatible data for the table `${table.name}`, " +
+          s"due to unknown column names: $extraCols.")
+    }
+    Project(reorderedCols, query)
   }
 
   private def resolveQueryColumnsByPosition(
       query: LogicalPlan,
-      tableAttributes: Seq[Attribute]): LogicalPlan = {
-    val project = query.output.zipWithIndex.map {
+      table: NamedRelation): LogicalPlan = {
+    val expectedCols = table.output
+    val queryCols = query.output
+    if (queryCols.size != expectedCols.size) {
+      throw new RuntimeException(
+        s"Cannot write incompatible data for the table `${table.name}`, " +
+          "the number of data columns doesn't match the table schema.")
+    }
+
+    val project = queryCols.zipWithIndex.map {
       case (attr, i) =>
-        val targetAttr = tableAttributes(i)
+        val targetAttr = expectedCols(i)
         addCastToColumn(attr, targetAttr, isByName = false)
     }
     Project(project, query)
   }
 
-  private def schemaCompatible(
-      dataSchema: StructType,
-      tableSchema: StructType,
-      partitionCols: Seq[String],
-      parent: Array[String] = Array.empty): Boolean = {
-
-    if (tableSchema.size != dataSchema.size) {
-      throw new RuntimeException("the number of data columns don't match with the table schema's.")
-    }
-
-    def dataTypeCompatible(column: String, dt1: DataType, dt2: DataType): Boolean = {
-      (dt1, dt2) match {
-        case (s1: StructType, s2: StructType) =>
-          schemaCompatible(s1, s2, partitionCols, Array(column))
-        case (a1: ArrayType, a2: ArrayType) =>
-          dataTypeCompatible(column, a1.elementType, a2.elementType)
-        case (m1: MapType, m2: MapType) =>
-          dataTypeCompatible(column, m1.keyType, m2.keyType) && dataTypeCompatible(
-            column,
-            m1.valueType,
-            m2.valueType)
-        case (d1, d2) => d1 == d2
-      }
-    }
-
-    dataSchema.zip(tableSchema).forall {
-      case (f1, f2) => dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
+  private def schemaCompatible(dataSchema: DataType, tableSchema: DataType): Boolean = {
+    (dataSchema, tableSchema) match {
+      case (s1: StructType, s2: StructType) =>
+        s1.zip(s2).forall { case (d1, d2) => schemaCompatible(d1.dataType, d2.dataType) }
+      case (a1: ArrayType, a2: ArrayType) =>
+        a1.containsNull == a2.containsNull && schemaCompatible(a1.elementType, a2.elementType)
+      case (m1: MapType, m2: MapType) =>
+        m1.valueContainsNull == m2.valueContainsNull &&
+        schemaCompatible(m1.keyType, m2.keyType) &&
+        schemaCompatible(m1.valueType, m2.valueType)
+      case (d1, d2) => d1 == d2
     }
   }
 
@@ -244,6 +252,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
       targetField.name
     )(explicitMetadata = Option(targetField.metadata))
   }
+
   private def castToArrayStruct(
       parent: NamedExpression,
       source: StructType,
@@ -304,11 +313,10 @@ case class PaimonPostHocResolutionRules(session: SparkSession) extends Rule[Logi
 }
 
 object PaimonV2WriteCommand {
-  def unapply(o: V2WriteCommand): Option[(DataSourceV2Relation, FileStoreTable)] = {
+  def unapply(o: V2WriteCommand): Option[DataSourceV2Relation] = {
     if (o.query.resolved) {
       o.table match {
-        case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] =>
-          Some((r, r.table.asInstanceOf[SparkTable].getTable.asInstanceOf[FileStoreTable]))
+        case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] => Some(r)
         case _ => None
       }
     } else {
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
index 19e711a600a9..3deb91cbcba7 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
@@ -18,7 +18,6 @@
 
 package org.apache.paimon.spark
 
-import org.apache.paimon.Snapshot
 import org.apache.paimon.catalog.{Catalog, CatalogContext, CatalogFactory, Identifier}
 import org.apache.paimon.options.{CatalogOptions, Options}
 import org.apache.paimon.spark.catalog.Catalogs
@@ -37,7 +36,6 @@ import org.scalactic.source.Position
 import org.scalatest.Tag
 
 import java.io.File
-import java.util
 import java.util.{HashMap => JHashMap}
 import java.util.TimeZone
 
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
similarity index 86%
rename from paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
rename to paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
index 528df32e6a2b..6d9ceb1873a2 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
@@ -45,11 +45,12 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
                      |TBLPROPERTIES ('file.format' = '$fileFormat')
                      |""".stripMargin)
 
-        spark.sql(s"""
-                     |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING, col4 STRING)
-                     |$partitionedSQL
-                     |TBLPROPERTIES ('file.format' = '$fileFormat')
-                     |""".stripMargin)
+        spark.sql(
+          s"""
+             |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING NOT NULL, col4 STRING)
+             |$partitionedSQL
+             |TBLPROPERTIES ('file.format' = '$fileFormat')
+             |""".stripMargin)
 
         sql(s"""
                |INSERT INTO TABLE t1 VALUES
@@ -68,11 +69,55 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
         sql("INSERT OVERWRITE t2 (col1, col2, col3, col4) SELECT * FROM t1")
         checkAnswer(
           sql("SELECT * FROM t2 ORDER BY col2"),
-          Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") :: Row(
-            3,
-            3.3d,
-            "Paimon",
-            "pt2") :: Nil)
+          Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") ::
+            Row(3, 3.3d, "Paimon", "pt2") :: Nil
+        )
+
+        // The BY NAME statement is supported since Spark 3.5.
+        if (gteqSpark3_5) {
+          sql("INSERT OVERWRITE TABLE t1 BY NAME SELECT col3, col2, col4, col1 FROM t1")
+          // Non-specified nullable columns are filled with null.
+          sql("INSERT OVERWRITE TABLE t2 BY NAME SELECT col1, col2 FROM t2")
+          checkAnswer(
+            sql("SELECT * FROM t2 ORDER BY col2"),
+            Row(1, null, "Hello", null) :: Row(2, null, "World", null) ::
+              Row(3, null, "Paimon", null) :: Nil
+          )
+
+          // by name bad cases
+          // column names conflict
+          val msg1 = intercept[Exception] {
+            sql("INSERT INTO TABLE t1 BY NAME SELECT col1, col2 as col1 FROM t1")
+          }
+          assert(msg1.getMessage.contains("due to column name conflicts"))
+          // column name does not match
+          val msg2 = intercept[Exception] {
+            sql("INSERT INTO TABLE t1 BY NAME SELECT col1, col2 as colx FROM t1")
+          }
+          assert(msg2.getMessage.contains("due to unknown column names"))
+          // the query has more columns than the table
+          val msg3 = intercept[Exception] {
+            sql("INSERT INTO TABLE t1 BY NAME SELECT col1, col2, col3, col4, col4 as col5 FROM t1")
+          }
+          assert(
+            msg3.getMessage.contains(
+              "the number of data columns doesn't match the table schema"))
+          // a non-nullable column has no specified value
+          val msg4 = intercept[Exception] {
+            sql("INSERT INTO TABLE t2 BY NAME SELECT col2 FROM t2")
+          }
+          assert(
+            msg4.getMessage.contains("non-nullable column `col1` has no specified value"))
+
+          // by position
+          // the column count does not match
+          val msg5 = intercept[Exception] {
+            sql("INSERT INTO TABLE t1 VALUES(1)")
+          }
+          assert(
+            msg5.getMessage.contains(
+              "the number of data columns doesn't match the table schema"))
+        }
       }
     }
   }
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/SparkVersionSupport.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/SparkVersionSupport.scala
index 5ac408934fd0..fed73ba0f9e2 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/SparkVersionSupport.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/SparkVersionSupport.scala
@@ -26,4 +26,6 @@ trait SparkVersionSupport {
   lazy val gteqSpark3_3: Boolean = sparkVersion >= "3.3"
 
   lazy val gteqSpark3_4: Boolean = sparkVersion >= "3.4"
+
+  lazy val gteqSpark3_5: Boolean = sparkVersion >= "3.5"
 }
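
Illustrative usage of the by-name write path added above (a minimal sketch, not part of the diff): it assumes a Spark 3.5+ session with a Paimon catalog configured, and the table name `t` and its columns are hypothetical.

    // Columns are matched by name and reordered; a nullable column missing from the
    // query (col3 here) is filled with NULL by resolveQueryColumnsByName.
    spark.sql("CREATE TABLE t (col1 STRING NOT NULL, col2 INT, col3 DOUBLE)")
    spark.sql("INSERT INTO t BY NAME SELECT 1 AS col2, 'Hello' AS col1")

    // Fails: col1 is non-nullable and the query does not provide a value for it.
    spark.sql("INSERT INTO t BY NAME SELECT 2 AS col2")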