From d9397390159bf1c7dbaf1dbdd2946ce9d622543f Mon Sep 17 00:00:00 2001
From: Zouxxyy
Date: Wed, 21 Feb 2024 09:10:49 +0800
Subject: [PATCH] [spark] SparkFilterConverter supports convert EqualNullSafe (#2878)

---
 .../paimon/spark/SparkFilterConverter.java     | 17 +++++++++++++----
 .../paimon/spark/SparkFilterConverterTest.java | 11 +++++++++++
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index 4f7cee52ce43..f944ae1a2cc6 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -24,6 +24,7 @@
 import org.apache.paimon.types.RowType;
 
 import org.apache.spark.sql.sources.And;
+import org.apache.spark.sql.sources.EqualNullSafe;
 import org.apache.spark.sql.sources.EqualTo;
 import org.apache.spark.sql.sources.Filter;
 import org.apache.spark.sql.sources.GreaterThan;
@@ -49,6 +50,7 @@ public class SparkFilterConverter {
     public static final List<String> SUPPORT_FILTERS =
             Arrays.asList(
                     "EqualTo",
+                    "EqualNullSafe",
                     "GreaterThan",
                     "GreaterThanOrEqual",
                     "LessThan",
@@ -76,6 +78,15 @@ public Predicate convert(Filter filter) {
             int index = fieldIndex(eq.attribute());
             Object literal = convertLiteral(index, eq.value());
             return builder.equal(index, literal);
+        } else if (filter instanceof EqualNullSafe) {
+            EqualNullSafe eq = (EqualNullSafe) filter;
+            if (eq.value() == null) {
+                return builder.isNull(fieldIndex(eq.attribute()));
+            } else {
+                int index = fieldIndex(eq.attribute());
+                Object literal = convertLiteral(index, eq.value());
+                return builder.equal(index, literal);
+            }
         } else if (filter instanceof GreaterThan) {
             GreaterThan gt = (GreaterThan) filter;
             int index = fieldIndex(gt.attribute());
@@ -124,15 +135,13 @@ public Predicate convert(Filter filter) {
             return builder.startsWith(index, literal);
         }
 
-        // TODO: In, NotIn, AlwaysTrue, AlwaysFalse, EqualNullSafe
+        // TODO: AlwaysTrue, AlwaysFalse
         throw new UnsupportedOperationException(
                 filter + " is unsupported. Support Filters: " + SUPPORT_FILTERS);
     }
 
     public Object convertLiteral(String field, Object value) {
-        int index = fieldIndex(field);
-        DataType type = rowType.getTypeAt(index);
-        return convertJavaObject(type, value);
+        return convertLiteral(fieldIndex(field), value);
     }
 
     private int fieldIndex(String field) {
diff --git a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java
index 4f7e4643376b..9f669d493715 100644
--- a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java
+++ b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java
@@ -27,6 +27,7 @@
 import org.apache.paimon.types.RowType;
 import org.apache.paimon.types.TimestampType;
 
+import org.apache.spark.sql.sources.EqualNullSafe;
 import org.apache.spark.sql.sources.EqualTo;
 import org.apache.spark.sql.sources.GreaterThan;
 import org.apache.spark.sql.sources.GreaterThanOrEqual;
@@ -118,6 +119,16 @@ public void testAll() {
         Predicate actualEqNull = converter.convert(eqNull);
         assertThat(actualEqNull).isEqualTo(expectedEqNull);
 
+        EqualNullSafe eqSafe = EqualNullSafe.apply(field, 1);
+        Predicate expectedEqSafe = builder.equal(0, 1);
+        Predicate actualEqSafe = converter.convert(eqSafe);
+        assertThat(actualEqSafe).isEqualTo(expectedEqSafe);
+
+        EqualNullSafe eqNullSafe = EqualNullSafe.apply(field, null);
+        Predicate expectEqNullSafe = builder.isNull(0);
+        Predicate actualEqNullSafe = converter.convert(eqNullSafe);
+        assertThat(actualEqNullSafe).isEqualTo(expectEqNullSafe);
+
         In in = In.apply(field, new Object[] {1, null, 2});
         Predicate expectedIn = builder.in(0, Arrays.asList(1, null, 2));
         Predicate actualIn = converter.convert(in);