Commit d939739

[spark] SparkFilterConverter supports converting EqualNullSafe (#2878)
Zouxxyy authored Feb 21, 2024
1 parent d1f2acb · commit d939739

Showing 2 changed files with 24 additions and 4 deletions.

SparkFilterConverter.java

@@ -24,6 +24,7 @@
 import org.apache.paimon.types.RowType;
 
 import org.apache.spark.sql.sources.And;
+import org.apache.spark.sql.sources.EqualNullSafe;
 import org.apache.spark.sql.sources.EqualTo;
 import org.apache.spark.sql.sources.Filter;
 import org.apache.spark.sql.sources.GreaterThan;
@@ -49,6 +50,7 @@ public class SparkFilterConverter {
     public static final List<String> SUPPORT_FILTERS =
             Arrays.asList(
                     "EqualTo",
+                    "EqualNullSafe",
                     "GreaterThan",
                     "GreaterThanOrEqual",
                     "LessThan",
@@ -76,6 +78,15 @@ public Predicate convert(Filter filter) {
             int index = fieldIndex(eq.attribute());
             Object literal = convertLiteral(index, eq.value());
             return builder.equal(index, literal);
+        } else if (filter instanceof EqualNullSafe) {
+            EqualNullSafe eq = (EqualNullSafe) filter;
+            if (eq.value() == null) {
+                return builder.isNull(fieldIndex(eq.attribute()));
+            } else {
+                int index = fieldIndex(eq.attribute());
+                Object literal = convertLiteral(index, eq.value());
+                return builder.equal(index, literal);
+            }
         } else if (filter instanceof GreaterThan) {
             GreaterThan gt = (GreaterThan) filter;
             int index = fieldIndex(gt.attribute());
@@ -124,15 +135,13 @@ public Predicate convert(Filter filter)
             return builder.startsWith(index, literal);
         }
 
-        // TODO: In, NotIn, AlwaysTrue, AlwaysFalse, EqualNullSafe
+        // TODO: AlwaysTrue, AlwaysFalse
         throw new UnsupportedOperationException(
                 filter + " is unsupported. Support Filters: " + SUPPORT_FILTERS);
     }
 
     public Object convertLiteral(String field, Object value) {
-        int index = fieldIndex(field);
-        DataType type = rowType.getTypeAt(index);
-        return convertJavaObject(type, value);
+        return convertLiteral(fieldIndex(field), value);
     }
 
     private int fieldIndex(String field) {
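
For a quick feel of the new branch, here is a minimal sketch (not part of the commit) that exercises both EqualNullSafe cases end to end. The import paths for SparkFilterConverter and PredicateBuilder, their constructors, the RowType.of factory, and the column name "f0" are assumptions based on the identifiers visible in the diff.

import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.spark.SparkFilterConverter; // package path assumed
import org.apache.paimon.types.IntType;
import org.apache.paimon.types.RowType;

import org.apache.spark.sql.sources.EqualNullSafe;

public class EqualNullSafeSketch {
    public static void main(String[] args) {
        // Assumed: RowType.of builds a one-column schema with default name "f0".
        RowType rowType = RowType.of(new IntType());
        SparkFilterConverter converter = new SparkFilterConverter(rowType);
        PredicateBuilder builder = new PredicateBuilder(rowType);

        // Non-null literal: null-safe equality degenerates to plain equality.
        Predicate nonNull = converter.convert(EqualNullSafe.apply("f0", 1));
        System.out.println(nonNull.equals(builder.equal(0, 1))); // expected: true

        // Null literal: `f0 <=> NULL` holds exactly when f0 IS NULL,
        // so the converter maps it to an isNull predicate.
        Predicate nullCase = converter.convert(EqualNullSafe.apply("f0", null));
        System.out.println(nullCase.equals(builder.isNull(0))); // expected: true
    }
}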

Test file for SparkFilterConverter:

@@ -27,6 +27,7 @@
 import org.apache.paimon.types.RowType;
 import org.apache.paimon.types.TimestampType;
 
+import org.apache.spark.sql.sources.EqualNullSafe;
 import org.apache.spark.sql.sources.EqualTo;
 import org.apache.spark.sql.sources.GreaterThan;
 import org.apache.spark.sql.sources.GreaterThanOrEqual;
@@ -118,6 +119,16 @@ public void testAll() {
         Predicate actualEqNull = converter.convert(eqNull);
         assertThat(actualEqNull).isEqualTo(expectedEqNull);
 
+        EqualNullSafe eqSafe = EqualNullSafe.apply(field, 1);
+        Predicate expectedEqSafe = builder.equal(0, 1);
+        Predicate actualEqSafe = converter.convert(eqSafe);
+        assertThat(actualEqSafe).isEqualTo(expectedEqSafe);
+
+        EqualNullSafe eqNullSafe = EqualNullSafe.apply(field, null);
+        Predicate expectEqNullSafe = builder.isNull(0);
+        Predicate actualEqNullSafe = converter.convert(eqNullSafe);
+        assertThat(actualEqNullSafe).isEqualTo(expectEqNullSafe);
+
         In in = In.apply(field, new Object[] {1, null, 2});
         Predicate expectedIn = builder.in(0, Arrays.asList(1, null, 2));
         Predicate actualIn = converter.convert(in);
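
As background not stated in the diff: Spark typically plans the SQL null-safe equality operator <=> (IS NOT DISTINCT FROM) as a sources.EqualNullSafe filter during pushdown, so with this commit Paimon can translate such predicates into its own scan predicates instead of leaving Spark to evaluate them row by row. A hedged sketch of a query that exercises the operator; the table path and column name are illustrative:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class NullSafeFilterQuery {
    public static void main(String[] args) {
        SparkSession spark =
                SparkSession.builder().master("local[1]").appName("nullsafe").getOrCreate();

        // Illustrative: read an existing Paimon table with an INT column "id".
        Dataset<Row> df = spark.read().format("paimon").load("/tmp/paimon/default.db/t");

        // "id <=> 1" is planned as sources.EqualNullSafe("id", 1), which
        // SparkFilterConverter can now convert to a Paimon predicate.
        df.filter("id <=> 1").show();

        spark.stop();
    }
}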
