[spark] check column nullability when write #3842

Merged on Jul 30, 2024 (3 commits)

@@ -139,6 +139,7 @@ public TableWriteImpl<InternalRow> newWrite(
AppendOnlyFileStoreWrite writer =
store().newWrite(commitUser, manifestFilter).withBucketMode(bucketMode());
return new TableWriteImpl<>(
rowType(),
writer,
createRowKeyExtractor(),
(record, rowKind) -> {

@@ -160,6 +160,7 @@ public TableWriteImpl<KeyValue> newWrite(
String commitUser, ManifestCacheFilter manifestFilter) {
KeyValue kv = new KeyValue();
return new TableWriteImpl<>(
rowType(),
store().newWrite(commitUser, manifestFilter),
createRowKeyExtractor(),
(record, rowKind) ->

@@ -30,13 +30,16 @@
import org.apache.paimon.operation.FileStoreWrite;
import org.apache.paimon.operation.FileStoreWrite.State;
import org.apache.paimon.table.BucketMode;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.RowKind;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Restorable;

import javax.annotation.Nullable;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;

import static org.apache.paimon.utils.Preconditions.checkState;

@@ -47,6 +50,7 @@
*/
public class TableWriteImpl<T> implements InnerTableWrite, Restorable<List<State<T>>> {

private final RowType rowType;
private final FileStoreWrite<T> write;
private final KeyAndBucketExtractor<InternalRow> keyAndBucketExtractor;
private final RecordExtractor<T> recordExtractor;
@@ -56,17 +60,28 @@ public class TableWriteImpl<T> implements InnerTableWrite, Restorable<List<State
private boolean batchCommitted = false;
private BucketMode bucketMode;

private final int[] notNullFieldIndex;

public TableWriteImpl(
RowType rowType,
FileStoreWrite<T> write,
KeyAndBucketExtractor<InternalRow> keyAndBucketExtractor,
RecordExtractor<T> recordExtractor,
@Nullable RowKindGenerator rowKindGenerator,
boolean ignoreDelete) {
this.rowType = rowType;
this.write = write;
this.keyAndBucketExtractor = keyAndBucketExtractor;
this.recordExtractor = recordExtractor;
this.rowKindGenerator = rowKindGenerator;
this.ignoreDelete = ignoreDelete;

List<String> notNullColumnNames =
rowType.getFields().stream()
.filter(field -> !field.type().isNullable())
.map(DataField::name)
.collect(Collectors.toList());
this.notNullFieldIndex = rowType.getFieldIndices(notNullColumnNames);
}

@Override
@@ -137,6 +152,7 @@ public void write(InternalRow row, int bucket) throws Exception {

@Nullable
public SinkRecord writeAndReturn(InternalRow row) throws Exception {
checkNullability(row);
RowKind rowKind = RowKindGenerator.getRowKind(rowKindGenerator, row);
if (ignoreDelete && rowKind.isRetract()) {
return null;
@@ -148,6 +164,7 @@ public SinkRecord writeAndReturn(InternalRow row) throws Exception {

@Nullable
public SinkRecord writeAndReturn(InternalRow row, int bucket) throws Exception {
checkNullability(row);
RowKind rowKind = RowKindGenerator.getRowKind(rowKindGenerator, row);
if (ignoreDelete && rowKind.isRetract()) {
return null;
@@ -157,6 +174,16 @@ public SinkRecord writeAndReturn(InternalRow row, int bucket) throws Exception {
return record;
}

private void checkNullability(InternalRow row) {
for (int idx : notNullFieldIndex) {
if (row.isNullAt(idx)) {
String columnName = rowType.getFields().get(idx).name();
throw new RuntimeException(
String.format("Cannot write null to non-null column(%s)", columnName));
}
}
}

private SinkRecord toSinkRecord(InternalRow row) {
keyAndBucketExtractor.setRecord(row);
return new SinkRecord(
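Taken together, the TableWriteImpl hunks above add a per-row guard: the constructor precomputes the indices of the table's NOT NULL fields from its RowType, and both writeAndReturn overloads call checkNullability before the row kind is resolved. A condensed sketch of that added logic, using the same names as in the diff, purely to read the scattered hunks as one unit:

    // Computed once in the constructor from the table schema.
    List<String> notNullColumnNames =
            rowType.getFields().stream()
                    .filter(field -> !field.type().isNullable())
                    .map(DataField::name)
                    .collect(Collectors.toList());
    int[] notNullFieldIndex = rowType.getFieldIndices(notNullColumnNames);

    // Evaluated at the start of each writeAndReturn(...) call.
    for (int idx : notNullFieldIndex) {
        if (row.isNullAt(idx)) {
            String columnName = rowType.getFields().get(idx).name();
            throw new RuntimeException(
                    String.format("Cannot write null to non-null column(%s)", columnName));
        }
    }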

@@ -26,11 +26,11 @@ import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.ResolvedTable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

import scala.collection.JavaConverters._

@@ -58,8 +58,8 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}

private def schemaCompatible(
tableSchema: StructType,
dataSchema: StructType,
tableSchema: StructType,
partitionCols: Seq[String],
parent: Array[String] = Array.empty): Boolean = {

@@ -82,9 +82,8 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}
}

tableSchema.zip(dataSchema).forall {
dataSchema.zip(tableSchema).forall {
case (f1, f2) =>
checkNullability(f1, f2, partitionCols, parent)
f1.name == f2.name && dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
}
}
@@ -115,17 +114,6 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
cast.setTagValue(Compatibility.castByTableInsertionTag, ())
cast
}

private def checkNullability(
input: StructField,
expected: StructField,
partitionCols: Seq[String],
parent: Array[String] = Array.empty): Unit = {
val fullColumnName = (parent ++ Array(input.name)).mkString(".")
if (!partitionCols.contains(fullColumnName) && input.nullable && !expected.nullable) {
throw new RuntimeException("Cannot write nullable values to non-null column")
}
}
}

case class PaimonPostHocResolutionRules(session: SparkSession) extends Rule[LogicalPlan] {

@@ -23,6 +23,7 @@ import org.apache.paimon.schema.Schema
import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.types.DataTypes

import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.junit.jupiter.api.Assertions

@@ -33,33 +34,70 @@ abstract class DDLTestBase extends PaimonSparkTestBase {

import testImplicits._

test("Paimon DDL: create table with not null") {
test("Paimon DDL: create append table with not null") {
withTable("T") {
sql("""
|CREATE TABLE T (id INT NOT NULL, name STRING)
|""".stripMargin)
sql("CREATE TABLE T (id INT NOT NULL, name STRING)")

val exception = intercept[RuntimeException] {
sql("""
|INSERT INTO T VALUES (1, "a"), (2, "b"), (null, "c")
|""".stripMargin)
val e1 = intercept[SparkException] {
sql("""INSERT INTO T VALUES (1, "a"), (2, "b"), (null, "c")""")
}
Assertions.assertTrue(
exception.getMessage().contains("Cannot write nullable values to non-null column"))
Assertions.assertTrue(e1.getMessage().contains("Cannot write null to non-null column"))

sql("""INSERT INTO T VALUES (1, "a"), (2, "b"), (3, null)""")
checkAnswer(
sql("SELECT * FROM T ORDER BY id"),
Seq((1, "a"), (2, "b"), (3, null)).toDF()
)

val schema = spark.table("T").schema
Assertions.assertEquals(schema.size, 2)
Assertions.assertFalse(schema("id").nullable)
Assertions.assertTrue(schema("name").nullable)
}
}
test("Paimon DDL: create primary-key table with not null") {
withTable("T") {
sql("""
|INSERT INTO T VALUES (1, "a"), (2, "b"), (3, null)
|CREATE TABLE T (id INT, name STRING, pt STRING)
|TBLPROPERTIES ('primary-key' = 'id,pt')
|""".stripMargin)

val e1 = intercept[SparkException] {
sql("""INSERT INTO T VALUES (1, "a", "pt1"), (2, "b", null)""")
}
Assertions.assertTrue(e1.getMessage().contains("Cannot write null to non-null column"))

val e2 = intercept[SparkException] {
sql("""INSERT INTO T VALUES (1, "a", "pt1"), (null, "b", "pt2")""")
}
Assertions.assertTrue(e2.getMessage().contains("Cannot write null to non-null column"))

sql("""INSERT INTO T VALUES (1, "a", "pt1"), (2, "b", "pt1"), (3, null, "pt2")""")
checkAnswer(
sql("SELECT * FROM T ORDER BY id"),
Seq((1, "a"), (2, "b"), (3, null)).toDF()
Seq((1, "a", "pt1"), (2, "b", "pt1"), (3, null, "pt2")).toDF()
)

val schema = spark.table("T").schema
Assertions.assertEquals(schema.size, 2)
Assertions.assertEquals(schema.size, 3)
Assertions.assertFalse(schema("id").nullable)
Assertions.assertTrue(schema("name").nullable)
Assertions.assertFalse(schema("pt").nullable)
}
}

test("Paimon DDL: write nullable expression to non-null column") {
withTable("T") {
sql("""
|CREATE TABLE T (id INT NOT NULL, ts TIMESTAMP NOT NULL)
|""".stripMargin)

sql("INSERT INTO T SELECT 1, TO_TIMESTAMP('2024-07-01 16:00:00')")

checkAnswer(
sql("SELECT * FROM T ORDER BY id"),
Row(1, Timestamp.valueOf("2024-07-01 16:00:00")) :: Nil
)
}
}
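
Because the check now lives in TableWriteImpl rather than in Spark's analysis rule, it should also apply to writes issued through Paimon's Java batch write API, not only to Spark SQL. A rough, hypothetical sketch under that assumption; the `table` variable (with schema id INT NOT NULL, name STRING) is assumed to be obtained from a catalog, and commit handling is omitted:

    import org.apache.paimon.data.BinaryString;
    import org.apache.paimon.data.GenericRow;
    import org.apache.paimon.table.Table;
    import org.apache.paimon.table.sink.BatchTableWrite;
    import org.apache.paimon.table.sink.BatchWriteBuilder;

    // Hypothetical usage sketch, not part of this PR.
    void writeRows(Table table) throws Exception {
        BatchWriteBuilder builder = table.newBatchWriteBuilder();
        try (BatchTableWrite write = builder.newWrite()) {
            write.write(GenericRow.of(1, BinaryString.fromString("a"))); // accepted
            // Expected to fail fast with "Cannot write null to non-null column(id)".
            write.write(GenericRow.of(null, BinaryString.fromString("b")));
        }
    }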
