Skip to content

Commit

Permalink
[spark] Enhance spark push down filter (#2376)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zouxxyy authored Nov 28, 2023
1 parent 15f61c2 commit 686a6eb
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.predicate;

import java.util.List;

/** Visit the predicate and check if it only contains partition key's predicate. */
public class PartitionPredicateVisitor implements PredicateVisitor<Boolean> {

private final List<String> partitionKeys;

public PartitionPredicateVisitor(List<String> partitionKeys) {
this.partitionKeys = partitionKeys;
}

@Override
public Boolean visit(LeafPredicate predicate) {
return partitionKeys.contains(predicate.fieldName());
}

@Override
public Boolean visit(CompoundPredicate predicate) {
for (Predicate child : predicate.children()) {
Boolean matched = child.visit(this);

if (!matched) {
return false;
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

import org.apache.paimon.flink.LogicalTypeConversion;
import org.apache.paimon.flink.PredicateConverter;
import org.apache.paimon.predicate.CompoundPredicate;
import org.apache.paimon.predicate.LeafPredicate;
import org.apache.paimon.predicate.PartitionPredicateVisitor;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.predicate.PredicateVisitor;
Expand Down Expand Up @@ -83,25 +82,7 @@ public List<ResolvedExpression> pushFilters(List<ResolvedExpression> filters) {
List<ResolvedExpression> unConsumedFilters = new ArrayList<>();
List<ResolvedExpression> consumedFilters = new ArrayList<>();
List<Predicate> converted = new ArrayList<>();
PredicateVisitor<Boolean> visitor =
new PredicateVisitor<Boolean>() {
@Override
public Boolean visit(LeafPredicate predicate) {
return partitionKeys.contains(predicate.fieldName());
}

@Override
public Boolean visit(CompoundPredicate predicate) {
for (Predicate child : predicate.children()) {
Boolean matched = child.visit(this);

if (!matched) {
return false;
}
}
return true;
}
};
PredicateVisitor<Boolean> visitor = new PartitionPredicateVisitor(partitionKeys);

for (ResolvedExpression filter : filters) {
Optional<Predicate> predicateOptional = PredicateConverter.convert(rowType, filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ public void testApplyPartitionTable() throws Exception {
Assertions.assertThat(tableSource.pushFilters(filters))
.isEqualTo(ImmutableList.of(filters.get(0)));

// col1 = 1 && p2 like '%a' => [p2 like '%a']
filters = ImmutableList.of(p2Like("%a"));
// col1 = 1 && p2 like '%a' => None
filters = ImmutableList.of(col1Equal1(), p2Like("%a"));
Assertions.assertThat(tableSource.pushFilters(filters)).isEqualTo(filters);

// col1 = 1 && p2 like 'a%' => None
// col1 = 1 && p2 like 'a%' => [p2 like 'a%']
filters = ImmutableList.of(col1Equal1(), p2Like("a%"));
Assertions.assertThat(tableSource.pushFilters(filters))
.isEqualTo(ImmutableList.of(filters.get(0)));
Expand All @@ -99,7 +99,7 @@ public void testApplyPartitionTable() throws Exception {
filters = ImmutableList.of(rand());
Assertions.assertThat(tableSource.pushFilters(filters)).isEqualTo(filters);

// upper(p1) = "A"
// upper(p1) = "A" => [upper(p1) = "A"]
filters = ImmutableList.of(upperP2EqualA());
Assertions.assertThat(tableSource.pushFilters(filters)).isEqualTo(filters);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.paimon.spark;

import org.apache.paimon.predicate.PartitionPredicateVisitor;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.table.Table;

Expand All @@ -27,14 +28,16 @@
import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/** A Spark {@link ScanBuilder} for paimon. */
public class SparkScanBuilder
implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns {

private static final Logger LOG = LoggerFactory.getLogger(SparkScanBuilder.class);
private final Table table;

private List<Predicate> predicates = new ArrayList<>();
Expand All @@ -47,19 +50,35 @@ public SparkScanBuilder(Table table) {

@Override
public Filter[] pushFilters(Filter[] filters) {
// There are 3 kinds of filters:
// (1) pushable filters which don't need to be evaluated again after scanning, e.g. filter
// partitions.
// (2) pushable filters which still need to be evaluated after scanning.
// (3) non-pushable filters.
// case 1 and 2 are considered as pushable filters and will be returned by pushedFilters().
// case 2 and 3 are considered as postScan filters and should be return by this method.
List<Filter> pushable = new ArrayList<>(filters.length);
List<Filter> postScan = new ArrayList<>(filters.length);
List<Predicate> predicates = new ArrayList<>(filters.length);

SparkFilterConverter converter = new SparkFilterConverter(table.rowType());
List<Predicate> predicates = new ArrayList<>();
List<Filter> pushed = new ArrayList<>();
PartitionPredicateVisitor visitor = new PartitionPredicateVisitor(table.partitionKeys());
for (Filter filter : filters) {
try {
predicates.add(converter.convert(filter));
pushed.add(filter);
} catch (UnsupportedOperationException ignore) {
Predicate predicate = converter.convert(filter);
predicates.add(predicate);
pushable.add(filter);
if (!predicate.visit(visitor)) {
postScan.add(filter);
}
} catch (UnsupportedOperationException e) {
LOG.warn(e.getMessage());
postScan.add(filter);
}
}
this.predicates = predicates;
this.pushedFilters = pushed.toArray(new Filter[0]);
return filters;
this.pushedFilters = pushable.toArray(new Filter[0]);
return postScan.toArray(new Filter[0]);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.paimon.spark

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.junit.jupiter.api.Assertions

class SparkPushDownTest extends PaimonSparkTestBase {

test(s"Paimon push down: apply partition filter push down with non-partitioned table") {
spark.sql(s"""
|CREATE TABLE T (id INT, name STRING, pt STRING)
|TBLPROPERTIES ('primary-key'='id, pt', 'bucket'='2')
|""".stripMargin)

spark.sql("INSERT INTO T VALUES (1, 'a', 'p1'), (2, 'b', 'p1'), (3, 'c', 'p2')")

val q = "SELECT * FROM T WHERE pt = 'p1'"
Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
}

test(s"Paimon push down: apply partition filter push down with partitioned table") {
spark.sql(s"""
|CREATE TABLE T (id INT, name STRING, pt STRING)
|TBLPROPERTIES ('primary-key'='id, pt', 'bucket'='2')
|PARTITIONED BY (pt)
|""".stripMargin)

spark.sql("INSERT INTO T VALUES (1, 'a', 'p1'), (2, 'b', 'p1'), (3, 'c', 'p2'), (4, 'd', 'p3')")

// partition filter push down did not hit cases:
// case 1
var q = "SELECT * FROM T WHERE id = '1'"
Assertions.assertTrue(checkFilterExists(q))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)

// case 2
// filter "id = '1' or pt = 'p1'" can't push down completely, it still need to be evaluated after scanning
q = "SELECT * FROM T WHERE id = '1' or pt = 'p1'"
Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)

// partition filter push down hit cases:
// case 1
q = "SELECT * FROM T WHERE pt = 'p1'"
Assertions.assertFalse(checkFilterExists(q))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)

// case 2
q = "SELECT * FROM T WHERE id = '1' and pt = 'p1'"
Assertions.assertFalse(checkEqualToFilterExists(q, "pt", Literal("p1")))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)

// case 3
q = "SELECT * FROM T WHERE pt < 'p3'"
Assertions.assertFalse(checkFilterExists(q))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Row(3, "c", "p2") :: Nil)
}

private def checkFilterExists(sql: String): Boolean = {
spark.sql(sql).queryExecution.optimizedPlan.exists {
case Filter(_: Expression, _) => true
case _ => false
}
}

private def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = {
spark.sql(sql).queryExecution.optimizedPlan.exists {
case Filter(c: Expression, _) =>
c.exists {
case EqualTo(a: AttributeReference, r: Literal) =>
a.name.equals(name) && r.equals(value)
case _ => false
}
case _ => false
}
}

}

0 comments on commit 686a6eb

Please sign in to comment.