diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/PartitionPredicateVisitor.java b/paimon-common/src/main/java/org/apache/paimon/predicate/PartitionPredicateVisitor.java
new file mode 100644
index 000000000000..b859c7a56c1b
--- /dev/null
+++ b/paimon-common/src/main/java/org/apache/paimon/predicate/PartitionPredicateVisitor.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.predicate;
+
+import java.util.List;
+
+/** A {@link PredicateVisitor} that checks whether a predicate references only partition keys. */
+public class PartitionPredicateVisitor implements PredicateVisitor<Boolean> {
+
+    private final List<String> partitionKeys;
+
+    public PartitionPredicateVisitor(List<String> partitionKeys) {
+        this.partitionKeys = partitionKeys;
+    }
+
+    @Override
+    public Boolean visit(LeafPredicate predicate) {
+        return partitionKeys.contains(predicate.fieldName());
+    }
+
+    @Override
+    public Boolean visit(CompoundPredicate predicate) {
+        for (Predicate child : predicate.children()) {
+            Boolean matched = child.visit(this);
+
+            if (!matched) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkTableSource.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkTableSource.java
index c35b46444358..68209240ed17 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkTableSource.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/FlinkTableSource.java
@@ -20,8 +20,7 @@
 import org.apache.paimon.flink.LogicalTypeConversion;
 import org.apache.paimon.flink.PredicateConverter;
-import org.apache.paimon.predicate.CompoundPredicate;
-import org.apache.paimon.predicate.LeafPredicate;
+import org.apache.paimon.predicate.PartitionPredicateVisitor;
 import org.apache.paimon.predicate.Predicate;
 import org.apache.paimon.predicate.PredicateBuilder;
 import org.apache.paimon.predicate.PredicateVisitor;
 
@@ -83,25 +82,7 @@ public List<ResolvedExpression> pushFilters(List<ResolvedExpression> filters) {
         List<ResolvedExpression> unConsumedFilters = new ArrayList<>();
         List<ResolvedExpression> consumedFilters = new ArrayList<>();
         List<Predicate> converted = new ArrayList<>();
-        PredicateVisitor<Boolean> visitor =
-                new PredicateVisitor<Boolean>() {
-                    @Override
-                    public Boolean visit(LeafPredicate predicate) {
-                        return partitionKeys.contains(predicate.fieldName());
-                    }
-
-                    @Override
-                    public Boolean visit(CompoundPredicate predicate) {
-                        for (Predicate child : predicate.children()) {
-                            Boolean matched = child.visit(this);
-
-                            if (!matched) {
-                                return false;
-                            }
-                        }
-                        return true;
-                    }
-                };
+        PredicateVisitor<Boolean> visitor = new PartitionPredicateVisitor(partitionKeys);
         for (ResolvedExpression filter : filters) {
             Optional<Predicate> predicateOptional =
                     PredicateConverter.convert(rowType, filter);
diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java
index 350fe9e3a01f..ef48099c7203 100644
--- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java
+++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java
@@ -86,11 +86,11 @@ public void testApplyPartitionTable() throws Exception {
         Assertions.assertThat(tableSource.pushFilters(filters))
                 .isEqualTo(ImmutableList.of(filters.get(0)));
 
-        // col1 = 1 && p2 like '%a' => [p2 like '%a']
-        filters = ImmutableList.of(p2Like("%a"));
+        // col1 = 1 && p2 like '%a' => None
+        filters = ImmutableList.of(col1Equal1(), p2Like("%a"));
         Assertions.assertThat(tableSource.pushFilters(filters)).isEqualTo(filters);
 
-        // col1 = 1 && p2 like 'a%' => None
+        // col1 = 1 && p2 like 'a%' => [p2 like 'a%']
         filters = ImmutableList.of(col1Equal1(), p2Like("a%"));
         Assertions.assertThat(tableSource.pushFilters(filters))
                 .isEqualTo(ImmutableList.of(filters.get(0)));
@@ -99,7 +99,7 @@ public void testApplyPartitionTable() throws Exception {
         filters = ImmutableList.of(rand());
         Assertions.assertThat(tableSource.pushFilters(filters)).isEqualTo(filters);
 
-        // upper(p1) = "A"
+        // upper(p1) = "A" => [upper(p1) = "A"]
         filters = ImmutableList.of(upperP2EqualA());
         Assertions.assertThat(tableSource.pushFilters(filters)).isEqualTo(filters);
 
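
To make the new visitor's contract concrete, here is a minimal sketch, for illustration only and not part of the patch. The two-column schema, the field names "id"/"dt", and the literal values are invented; it assumes the existing paimon-common helpers (RowType.of, DataTypes, PredicateBuilder and its static and(), Predicate.visit) behave as in the current API:

import org.apache.paimon.predicate.PartitionPredicateVisitor;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.predicate.PredicateVisitor;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

import java.util.Collections;

public class PartitionPredicateVisitorSketch {

    public static void main(String[] args) {
        // Invented schema: field 0 = "id" (data column), field 1 = "dt" (partition key).
        RowType rowType =
                RowType.of(
                        new DataType[] {DataTypes.INT(), DataTypes.INT()},
                        new String[] {"id", "dt"});
        PredicateBuilder builder = new PredicateBuilder(rowType);
        PredicateVisitor<Boolean> visitor =
                new PartitionPredicateVisitor(Collections.singletonList("dt"));

        // A leaf predicate on the partition key alone passes the check.
        Predicate onPartition = builder.equal(1, 20240101);
        System.out.println(onPartition.visit(visitor)); // true

        // A compound predicate fails as soon as one child touches a non-partition field.
        Predicate mixed = PredicateBuilder.and(builder.equal(0, 1), onPartition);
        System.out.println(mixed.visit(visitor)); // false: "id" is not a partition key
    }
}

A true result is what lets FlinkTableSource above treat a filter as fully consumed, and what keeps a filter out of the post-scan list in SparkScanBuilder below.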
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkScanBuilder.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkScanBuilder.java
index ec7367ed2e80..06652ff9fc0f 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkScanBuilder.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkScanBuilder.java
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.spark;
 
+import org.apache.paimon.predicate.PartitionPredicateVisitor;
 import org.apache.paimon.predicate.Predicate;
 import org.apache.paimon.table.Table;
 
@@ -27,6 +28,8 @@
 import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
 import org.apache.spark.sql.sources.Filter;
 import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -34,7 +37,7 @@
 /** A Spark {@link ScanBuilder} for paimon. */
 public class SparkScanBuilder implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns {
-
+    private static final Logger LOG = LoggerFactory.getLogger(SparkScanBuilder.class);
 
     private final Table table;
 
     private List<Predicate> predicates = new ArrayList<>();
@@ -47,19 +50,35 @@ public SparkScanBuilder(Table table) {
 
     @Override
     public Filter[] pushFilters(Filter[] filters) {
+        // There are three kinds of filters:
+        // (1) pushable filters which don't need to be evaluated again after scanning, e.g.
+        // filters on partitions;
+        // (2) pushable filters which still need to be evaluated after scanning;
+        // (3) non-pushable filters.
+        // Cases (1) and (2) are pushable filters and will be returned by pushedFilters();
+        // cases (2) and (3) are post-scan filters and should be returned by this method.
+        List<Filter> pushable = new ArrayList<>(filters.length);
+        List<Filter> postScan = new ArrayList<>(filters.length);
+        List<Predicate> predicates = new ArrayList<>(filters.length);
+
         SparkFilterConverter converter = new SparkFilterConverter(table.rowType());
-        List<Predicate> predicates = new ArrayList<>();
-        List<Filter> pushed = new ArrayList<>();
+        PartitionPredicateVisitor visitor = new PartitionPredicateVisitor(table.partitionKeys());
         for (Filter filter : filters) {
             try {
-                predicates.add(converter.convert(filter));
-                pushed.add(filter);
-            } catch (UnsupportedOperationException ignore) {
+                Predicate predicate = converter.convert(filter);
+                predicates.add(predicate);
+                pushable.add(filter);
+                if (!predicate.visit(visitor)) {
+                    postScan.add(filter);
+                }
+            } catch (UnsupportedOperationException e) {
+                LOG.warn(e.getMessage());
+                postScan.add(filter);
             }
         }
         this.predicates = predicates;
-        this.pushedFilters = pushed.toArray(new Filter[0]);
-        return filters;
+        this.pushedFilters = pushable.toArray(new Filter[0]);
+        return postScan.toArray(new Filter[0]);
     }
 
     @Override
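
The effect of the new classification is easiest to see with a concrete call. A hedged sketch, not part of the patch: "table" is assumed to be an already-loaded Paimon table partitioned by "pt" with a data column "name", and StringContains is assumed here to be one of the filters SparkFilterConverter cannot convert (so it lands in case 3):

package org.apache.paimon.spark;

import org.apache.paimon.table.Table;

import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.StringContains;

public class PushFiltersSketch {

    static void classify(Table table) {
        SparkScanBuilder builder = new SparkScanBuilder(table);

        Filter onPartition = new EqualTo("pt", "p1"); // case (1): pushable, fully consumed
        Filter onDataColumn = new EqualTo("name", "a"); // case (2): pushable, re-evaluated
        Filter notConvertible = new StringContains("name", "a"); // case (3), by assumption

        Filter[] postScan =
                builder.pushFilters(new Filter[] {onPartition, onDataColumn, notConvertible});

        // postScan -> [onDataColumn, notConvertible]: Spark evaluates these after the scan.
        // builder.pushedFilters() -> [onPartition, onDataColumn]: reported as pushed down.
    }
}

Previously pushFilters returned the input array unchanged, forcing Spark to re-evaluate every filter; returning only the post-scan subset lets pure partition filters be skipped entirely.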
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/SparkPushDownTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/SparkPushDownTest.scala
new file mode 100644
index 000000000000..6d1360b95c02
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/SparkPushDownTest.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.junit.jupiter.api.Assertions
+
+class SparkPushDownTest extends PaimonSparkTestBase {
+
+  test(s"Paimon push down: apply partition filter push down with non-partitioned table") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id INT, name STRING, pt STRING)
+                 |TBLPROPERTIES ('primary-key'='id, pt', 'bucket'='2')
+                 |""".stripMargin)
+
+    spark.sql("INSERT INTO T VALUES (1, 'a', 'p1'), (2, 'b', 'p1'), (3, 'c', 'p2')")
+
+    val q = "SELECT * FROM T WHERE pt = 'p1'"
+    Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
+    checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
+  }
+
+  test(s"Paimon push down: apply partition filter push down with partitioned table") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id INT, name STRING, pt STRING)
+                 |TBLPROPERTIES ('primary-key'='id, pt', 'bucket'='2')
+                 |PARTITIONED BY (pt)
+                 |""".stripMargin)
+
+    spark.sql("INSERT INTO T VALUES (1, 'a', 'p1'), (2, 'b', 'p1'), (3, 'c', 'p2'), (4, 'd', 'p3')")
+
+    // cases where partition filter push down does not take effect:
+    // case 1
+    var q = "SELECT * FROM T WHERE id = '1'"
+    Assertions.assertTrue(checkFilterExists(q))
+    checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)
+
+    // case 2
+    // the filter "id = '1' or pt = 'p1'" cannot be pushed down completely; it still needs to
+    // be evaluated after scanning
+    q = "SELECT * FROM T WHERE id = '1' or pt = 'p1'"
+    Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
+    checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
+
+    // cases where partition filter push down takes effect:
+    // case 1
+    q = "SELECT * FROM T WHERE pt = 'p1'"
+    Assertions.assertFalse(checkFilterExists(q))
+    checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
+
+    // case 2
+    q = "SELECT * FROM T WHERE id = '1' and pt = 'p1'"
+    Assertions.assertFalse(checkEqualToFilterExists(q, "pt", Literal("p1")))
+    checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)
+
+    // case 3
+    q = "SELECT * FROM T WHERE pt < 'p3'"
+    Assertions.assertFalse(checkFilterExists(q))
+    checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Row(3, "c", "p2") :: Nil)
+  }
+
+  private def checkFilterExists(sql: String): Boolean = {
+    spark.sql(sql).queryExecution.optimizedPlan.exists {
+      case Filter(_: Expression, _) => true
+      case _ => false
+    }
+  }
+
+  private def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = {
+    spark.sql(sql).queryExecution.optimizedPlan.exists {
+      case Filter(c: Expression, _) =>
+        c.exists {
+          case EqualTo(a: AttributeReference, r: Literal) =>
+            a.name.equals(name) && r.equals(value)
+          case _ => false
+        }
+      case _ => false
+    }
+  }
+
+}
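
The "case 2" comment in the partitioned-table test above is worth spelling out: PartitionPredicateVisitor visits every child of a CompoundPredicate, so an OR that mixes a data column with a partition key fails the partition-only check and the filter is kept for post-scan evaluation even though it is pushed into the scan. A minimal sketch of that check in isolation, with an invented schema ("id" as data column, "pt" as partition key, both INT for simplicity) and assuming PredicateBuilder.or behaves as in the current paimon-common API:

import org.apache.paimon.predicate.PartitionPredicateVisitor;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

import java.util.Collections;

public class OrAcrossPartitionSketch {

    public static void main(String[] args) {
        // Invented schema: field 0 = "id" (data column), field 1 = "pt" (partition key).
        RowType rowType =
                RowType.of(
                        new DataType[] {DataTypes.INT(), DataTypes.INT()},
                        new String[] {"id", "pt"});
        PredicateBuilder builder = new PredicateBuilder(rowType);

        // id = 1 OR pt = 1: the OR's children are visited one by one; "id" fails the check.
        Predicate or = PredicateBuilder.or(builder.equal(0, 1), builder.equal(1, 1));
        Boolean partitionOnly =
                or.visit(new PartitionPredicateVisitor(Collections.singletonList("pt")));
        System.out.println(partitionOnly); // false -> the filter stays in the post-scan list
    }
}

This is why the test asserts that the EqualTo(pt, 'p1') filter survives in the optimized plan for the OR query, but disappears for the pure partition queries.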