[core][spark] check column nullability when write (#3842)
YannByron authored Jul 30, 2024
1 parent 3682344 commit 6435dd2
Showing 5 changed files with 84 additions and 29 deletions.
Changed file 1 of 5
@@ -139,6 +139,7 @@ public TableWriteImpl<InternalRow> newWrite(
AppendOnlyFileStoreWrite writer =
store().newWrite(commitUser, manifestFilter).withBucketMode(bucketMode());
return new TableWriteImpl<>(
rowType(),
writer,
createRowKeyExtractor(),
(record, rowKind) -> {
Changed file 2 of 5
@@ -160,6 +160,7 @@ public TableWriteImpl<KeyValue> newWrite(
String commitUser, ManifestCacheFilter manifestFilter) {
KeyValue kv = new KeyValue();
return new TableWriteImpl<>(
rowType(),
store().newWrite(commitUser, manifestFilter),
createRowKeyExtractor(),
(record, rowKind) ->
Changed file 3 of 5
@@ -30,13 +30,16 @@
import org.apache.paimon.operation.FileStoreWrite;
import org.apache.paimon.operation.FileStoreWrite.State;
import org.apache.paimon.table.BucketMode;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.RowKind;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Restorable;

import javax.annotation.Nullable;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;

import static org.apache.paimon.utils.Preconditions.checkState;

@@ -47,6 +50,7 @@
*/
public class TableWriteImpl<T> implements InnerTableWrite, Restorable<List<State<T>>> {

private final RowType rowType;
private final FileStoreWrite<T> write;
private final KeyAndBucketExtractor<InternalRow> keyAndBucketExtractor;
private final RecordExtractor<T> recordExtractor;
@@ -56,17 +60,28 @@ public class TableWriteImpl<T> implements InnerTableWrite, Restorable<List<State
private boolean batchCommitted = false;
private BucketMode bucketMode;

private final int[] notNullFieldIndex;

public TableWriteImpl(
RowType rowType,
FileStoreWrite<T> write,
KeyAndBucketExtractor<InternalRow> keyAndBucketExtractor,
RecordExtractor<T> recordExtractor,
@Nullable RowKindGenerator rowKindGenerator,
boolean ignoreDelete) {
this.rowType = rowType;
this.write = write;
this.keyAndBucketExtractor = keyAndBucketExtractor;
this.recordExtractor = recordExtractor;
this.rowKindGenerator = rowKindGenerator;
this.ignoreDelete = ignoreDelete;

List<String> notNullColumnNames =
rowType.getFields().stream()
.filter(field -> !field.type().isNullable())
.map(DataField::name)
.collect(Collectors.toList());
this.notNullFieldIndex = rowType.getFieldIndices(notNullColumnNames);
}

@Override
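
For illustration only: the constructor above walks the table's fields once and keeps the positions declared NOT NULL. A minimal standalone sketch of that precomputation in plain Java, using a hypothetical Field record instead of Paimon's RowType/DataField classes:

    import java.util.ArrayList;
    import java.util.List;

    class NotNullIndexSketch {
        // Hypothetical stand-in for a table field: just a name and a nullability flag.
        record Field(String name, boolean nullable) {}

        // Mirrors the constructor logic above: collect the indices of NOT NULL fields once.
        static int[] notNullIndices(List<Field> fields) {
            List<Integer> indices = new ArrayList<>();
            for (int i = 0; i < fields.size(); i++) {
                if (!fields.get(i).nullable()) {
                    indices.add(i);
                }
            }
            return indices.stream().mapToInt(Integer::intValue).toArray();
        }

        public static void main(String[] args) {
            List<Field> fields =
                    List.of(new Field("id", false), new Field("name", true), new Field("pt", false));
            // Prints [0, 2]: only "id" and "pt" need a per-row null check.
            System.out.println(java.util.Arrays.toString(notNullIndices(fields)));
        }
    }

Precomputing the int[] once per writer keeps the per-row cost proportional to the number of NOT NULL columns rather than the full field count.
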
@@ -137,6 +152,7 @@ public void write(InternalRow row, int bucket) throws Exception {

@Nullable
public SinkRecord writeAndReturn(InternalRow row) throws Exception {
checkNullability(row);
RowKind rowKind = RowKindGenerator.getRowKind(rowKindGenerator, row);
if (ignoreDelete && rowKind.isRetract()) {
return null;
@@ -148,6 +164,7 @@ public SinkRecord writeAndReturn(InternalRow row) throws Exception {

@Nullable
public SinkRecord writeAndReturn(InternalRow row, int bucket) throws Exception {
checkNullability(row);
RowKind rowKind = RowKindGenerator.getRowKind(rowKindGenerator, row);
if (ignoreDelete && rowKind.isRetract()) {
return null;
@@ -157,6 +174,16 @@ public SinkRecord writeAndReturn(InternalRow row, int bucket) throws Exception {
return record;
}

private void checkNullability(InternalRow row) {
for (int idx : notNullFieldIndex) {
if (row.isNullAt(idx)) {
String columnName = rowType.getFields().get(idx).name();
throw new RuntimeException(
String.format("Cannot write null to non-null column(%s)", columnName));
}
}
}

private SinkRecord toSinkRecord(InternalRow row) {
keyAndBucketExtractor.setRecord(row);
return new SinkRecord(
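
For illustration only: a rough sketch of what the new checkNullability amounts to at runtime, assuming only Paimon's GenericRow.of(...) to build an InternalRow and isNullAt(...) to probe it; the {0} index array is a hypothetical stand-in for a single NOT NULL column at position 0. A value that is merely nullable by declared type but actually present passes, and only a real null is rejected:

    import org.apache.paimon.data.GenericRow;
    import org.apache.paimon.data.InternalRow;

    public class CheckNullabilitySketch {
        public static void main(String[] args) {
            int[] notNullFieldIndex = {0}; // assume only column 0 is declared NOT NULL

            InternalRow present = GenericRow.of(1, null); // null only in a nullable column
            InternalRow missing = GenericRow.of(null, 2); // null in the NOT NULL column

            check(present, notNullFieldIndex); // passes: the value is there
            check(missing, notNullFieldIndex); // throws "Cannot write null to non-null column(0)"
        }

        // Same loop as checkNullability above, with the column index in place of its name.
        static void check(InternalRow row, int[] notNullFieldIndex) {
            for (int idx : notNullFieldIndex) {
                if (row.isNullAt(idx)) {
                    throw new RuntimeException(
                            String.format("Cannot write null to non-null column(%s)", idx));
                }
            }
        }
    }

Because the rejection now happens per row at write time, the Spark-side plan-time check in the next file can be dropped.
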
Changed file 4 of 5
@@ -26,11 +26,11 @@ import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.ResolvedTable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

import scala.collection.JavaConverters._

@@ -58,8 +58,8 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}

private def schemaCompatible(
tableSchema: StructType,
dataSchema: StructType,
tableSchema: StructType,
partitionCols: Seq[String],
parent: Array[String] = Array.empty): Boolean = {

@@ -82,9 +82,8 @@
}
}

tableSchema.zip(dataSchema).forall {
dataSchema.zip(tableSchema).forall {
case (f1, f2) =>
checkNullability(f1, f2, partitionCols, parent)
f1.name == f2.name && dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
}
}
@@ -115,17 +114,6 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
cast.setTagValue(Compatibility.castByTableInsertionTag, ())
cast
}

private def checkNullability(
input: StructField,
expected: StructField,
partitionCols: Seq[String],
parent: Array[String] = Array.empty): Unit = {
val fullColumnName = (parent ++ Array(input.name)).mkString(".")
if (!partitionCols.contains(fullColumnName) && input.nullable && !expected.nullable) {
throw new RuntimeException("Cannot write nullable values to non-null column")
}
}
}

case class PaimonPostHocResolutionRules(session: SparkSession) extends Rule[LogicalPlan] {
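
For illustration only: with the removal above, writes are no longer rejected at analysis time based on the declared nullability of the source expression; the swapped zip simply pairs each incoming data field (f1) with its table field (f2) for the name/type check. The old rule was stricter than necessary, since an expression like TO_TIMESTAMP(...) is typically nullable by type even when every produced value is non-null; that case is exercised by the new test in the next file. A rough Java restatement of the removed rule, purely for contrast (the method name is ours, not Spark's or Paimon's):

    public class OldNullabilityRuleSketch {
        // The removed Spark-side rule decided from declared nullability alone,
        // before any data was seen; partition columns were exempt.
        static boolean oldRuleRejects(
                boolean sourceNullable, boolean targetNullable, boolean isPartitionColumn) {
            return !isPartitionColumn && sourceNullable && !targetNullable;
        }

        public static void main(String[] args) {
            // Nullable-typed source into a NOT NULL column: used to fail at analysis time,
            // now succeeds as long as no actual null value reaches the writer.
            System.out.println(oldRuleRejects(true, false, false)); // true (old behaviour: reject)
            // The replacement is the per-row check in TableWriteImpl#checkNullability above.
        }
    }
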
Changed file 5 of 5
@@ -23,6 +23,7 @@ import org.apache.paimon.schema.Schema
import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.types.DataTypes

import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.junit.jupiter.api.Assertions

@@ -33,33 +34,70 @@ abstract class DDLTestBase extends PaimonSparkTestBase {

import testImplicits._

test("Paimon DDL: create table with not null") {
test("Paimon DDL: create append table with not null") {
withTable("T") {
sql("""
|CREATE TABLE T (id INT NOT NULL, name STRING)
|""".stripMargin)
sql("CREATE TABLE T (id INT NOT NULL, name STRING)")

val exception = intercept[RuntimeException] {
sql("""
|INSERT INTO T VALUES (1, "a"), (2, "b"), (null, "c")
|""".stripMargin)
val e1 = intercept[SparkException] {
sql("""INSERT INTO T VALUES (1, "a"), (2, "b"), (null, "c")""")
}
Assertions.assertTrue(
exception.getMessage().contains("Cannot write nullable values to non-null column"))
Assertions.assertTrue(e1.getMessage().contains("Cannot write null to non-null column"))

sql("""INSERT INTO T VALUES (1, "a"), (2, "b"), (3, null)""")
checkAnswer(
sql("SELECT * FROM T ORDER BY id"),
Seq((1, "a"), (2, "b"), (3, null)).toDF()
)

val schema = spark.table("T").schema
Assertions.assertEquals(schema.size, 2)
Assertions.assertFalse(schema("id").nullable)
Assertions.assertTrue(schema("name").nullable)
}
}
test("Paimon DDL: create primary-key table with not null") {
withTable("T") {
sql("""
|INSERT INTO T VALUES (1, "a"), (2, "b"), (3, null)
|CREATE TABLE T (id INT, name STRING, pt STRING)
|TBLPROPERTIES ('primary-key' = 'id,pt')
|""".stripMargin)

val e1 = intercept[SparkException] {
sql("""INSERT INTO T VALUES (1, "a", "pt1"), (2, "b", null)""")
}
Assertions.assertTrue(e1.getMessage().contains("Cannot write null to non-null column"))

val e2 = intercept[SparkException] {
sql("""INSERT INTO T VALUES (1, "a", "pt1"), (null, "b", "pt2")""")
}
Assertions.assertTrue(e2.getMessage().contains("Cannot write null to non-null column"))

sql("""INSERT INTO T VALUES (1, "a", "pt1"), (2, "b", "pt1"), (3, null, "pt2")""")
checkAnswer(
sql("SELECT * FROM T ORDER BY id"),
Seq((1, "a"), (2, "b"), (3, null)).toDF()
Seq((1, "a", "pt1"), (2, "b", "pt1"), (3, null, "pt2")).toDF()
)

val schema = spark.table("T").schema
Assertions.assertEquals(schema.size, 2)
Assertions.assertEquals(schema.size, 3)
Assertions.assertFalse(schema("id").nullable)
Assertions.assertTrue(schema("name").nullable)
Assertions.assertFalse(schema("pt").nullable)
}
}

test("Paimon DDL: write nullable expression to non-null column") {
withTable("T") {
sql("""
|CREATE TABLE T (id INT NOT NULL, ts TIMESTAMP NOT NULL)
|""".stripMargin)

sql("INSERT INTO T SELECT 1, TO_TIMESTAMP('2024-07-01 16:00:00')")

checkAnswer(
sql("SELECT * FROM T ORDER BY id"),
Row(1, Timestamp.valueOf("2024-07-01 16:00:00")) :: Nil
)
}
}
