From 7cc978287f190a9e86b557bf1d08506e4e85957b Mon Sep 17 00:00:00 2001 From: zouxxyy Date: Thu, 16 Jan 2025 00:51:45 +0800 Subject: [PATCH] use v1 --- .../paimon/spark/PaimonScanBuilder.scala | 57 ++++- .../paimon/spark/sql/PaimonPushDownTest.scala | 54 ++++ .../internal/connector/PredicateUtils.scala | 147 +++++++++++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/PaimonBaseScanBuilder.scala | 58 +---- .../apache/paimon/spark/PaimonLocalScan.scala | 4 +- .../org/apache/paimon/spark/PaimonScan.scala | 2 +- .../paimon/spark/PaimonScanBuilder.scala | 63 ++++- .../apache/paimon/spark/PaimonSplitScan.scala | 2 +- .../paimon/spark/SparkV2FilterConverter.scala | 19 +- .../expressions/ExpressionHelper.scala | 1 + .../paimon/spark/commands/PaimonCommand.scala | 1 + .../org/apache/spark/sql/PaimonUtils.scala | 5 + ...est.scala => PaimonPushDownTestBase.scala} | 9 +- .../sql/SparkV2FilterConverterTestBase.scala | 233 ++++++++++++++---- 18 files changed, 629 insertions(+), 110 deletions(-) create mode 100644 paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala create mode 100644 paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala rename paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/{PaimonPushDownTest.scala => PaimonPushDownTestBase.scala} (97%) diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 10b83ccf08b1..395f8707ab9d 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -18,6 +18,61 @@ package org.apache.paimon.spark +import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate} import org.apache.paimon.table.Table -class PaimonScanBuilder(table: Table) extends PaimonBaseScanBuilder(table) +import org.apache.spark.sql.connector.read.SupportsPushDownFilters +import org.apache.spark.sql.sources.Filter + +import scala.collection.mutable + +class PaimonScanBuilder(table: Table) + extends PaimonBaseScanBuilder(table) + with SupportsPushDownFilters { + + private var pushedSparkFilters = Array.empty[Filter] + + /** + * Pushes down filters, and returns filters that need to be evaluated after scanning.
Rows + * should be returned from the data source if and only if all the filters match. That is, filters + * must be interpreted as ANDed together. + */ + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)] + val postScan = mutable.ArrayBuffer.empty[Filter] + val reserved = mutable.ArrayBuffer.empty[Filter] + + val converter = new SparkFilterConverter(table.rowType) + val visitor = new PartitionPredicateVisitor(table.partitionKeys()) + filters.foreach { + filter => + val predicate = converter.convertIgnoreFailure(filter) + if (predicate == null) { + postScan.append(filter) + } else { + pushable.append((filter, predicate)) + if (predicate.visit(visitor)) { + reserved.append(filter) + } else { + postScan.append(filter) + } + } + } + + if (pushable.nonEmpty) { + this.pushedSparkFilters = pushable.map(_._1).toArray + this.pushedPaimonPredicates = pushable.map(_._2).toArray + } + if (reserved.nonEmpty) { + this.reservedFilters = reserved.toArray + } + if (postScan.nonEmpty) { + this.hasPostScanPredicates = true + } + postScan.toArray + } + + override def pushedFilters(): Array[Filter] = { + pushedSparkFilters + } +} diff --git a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..e0705b761ab9 --- /dev/null +++ b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.Filter + +class PaimonPushDownTest extends PaimonPushDownTestBase { + + override def checkFilterExists(sql: String): Boolean = { + spark + .sql(sql) + .queryExecution + .optimizedPlan + .find { + case Filter(_: Expression, _) => true + case _ => false + } + .isDefined + } + + override def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = { + spark + .sql(sql) + .queryExecution + .optimizedPlan + .find { + case Filter(c: Expression, _) => + c.find { + case EqualTo(a: AttributeReference, r: Literal) => + a.name.equals(name) && r.equals(value) + case _ => false + }.isDefined + case _ => false + } + .isDefined + } +} diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala new file mode 100644 index 000000000000..43459648095d --- /dev/null +++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} +import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith} +import org.apache.spark.sql.types.StringType + +// Copy from Spark 3.4+ +private[sql] object PredicateUtils { + + def toV1(predicate: Predicate): Option[Filter] = { + + def isValidBinaryPredicate(): Boolean = { + if ( + predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]] + ) { + true + } else { + false + } + } + + predicate.name() match { + case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString + val values = predicate.children().drop(1) + if (values.length > 0) { + if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None + val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType + if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) { + return None + } + val inValues = values.map( + v => + CatalystTypeConverters.convertToScala( + v.asInstanceOf[LiteralValue[_]].value, + dataType)) + Some(In(attribute, inValues)) + } else { + Some(In(attribute, Array.empty[Any])) + } + + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate() => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType) + val v1Filter = predicate.name() match { + case "=" => EqualTo(attribute, v1Value) + case "<=>" => EqualNullSafe(attribute, v1Value) + case ">" => GreaterThan(attribute, v1Value) + case ">=" => GreaterThanOrEqual(attribute, v1Value) + case "<" => LessThan(attribute, v1Value) + case "<=" => LessThanOrEqual(attribute, v1Value) + } + Some(v1Filter) + + case "IS_NULL" | "IS_NOT_NULL" + if predicate.children().length == 1 && + predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString + val v1Filter = predicate.name() match { + case "IS_NULL" => IsNull(attribute) + case "IS_NOT_NULL" => IsNotNull(attribute) + } + Some(v1Filter) + + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate() => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + if (!value.dataType.sameType(StringType)) return None + val v1Value = value.value.toString + val v1Filter = predicate.name() match { + case "STARTS_WITH" => + StringStartsWith(attribute, v1Value) + case "ENDS_WITH" => + StringEndsWith(attribute, v1Value) + case "CONTAINS" => + StringContains(attribute, v1Value) + } + Some(v1Filter) + + case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty => + val v1Filter = predicate.name() match { + case "ALWAYS_TRUE" => AlwaysTrue() + case "ALWAYS_FALSE" => AlwaysFalse() + } + Some(v1Filter) + + case "AND" => + val and = predicate.asInstanceOf[V2And] + val left = toV1(and.left()) + val right = toV1(and.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(And(left.get, right.get)) + } else 
{ + None + } + + case "OR" => + val or = predicate.asInstanceOf[V2Or] + val left = toV1(or.left()) + val right = toV1(or.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(Or(left.get, right.get)) + } else if (left.nonEmpty) { + left + } else { + right + } + + case "NOT" => + val child = toV1(predicate.asInstanceOf[V2Not].child()) + if (child.nonEmpty) { + Some(Not(child.get)) + } else { + None + } + + case _ => None + } + } + + def toV1(predicates: Array[Predicate]): Array[Filter] = { + predicates.flatMap(toV1(_)) + } +} diff --git a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala index a265ee78f5b9..1db178448413 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala @@ -18,77 +18,31 @@ package org.apache.paimon.spark -import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate} +import org.apache.paimon.predicate.Predicate import org.apache.paimon.table.Table import org.apache.spark.internal.Logging -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns} import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType -import scala.collection.mutable - abstract class PaimonBaseScanBuilder(table: Table) extends ScanBuilder - with SupportsPushDownFilters with SupportsPushDownRequiredColumns with Logging { protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType()) - protected var pushedPredicates: Array[(Filter, Predicate)] = Array.empty + protected var pushedPaimonPredicates: Array[Predicate] = Array.empty - protected var partitionFilters: Array[Filter] = Array.empty + protected var reservedFilters: Array[Filter] = Array.empty - protected var postScanFilters: Array[Filter] = Array.empty + protected var hasPostScanPredicates = false protected var pushDownLimit: Option[Int] = None override def build(): Scan = { - PaimonScan(table, requiredSchema, pushedPredicates.map(_._2), partitionFilters, pushDownLimit) - } - - /** - * Pushes down filters, and returns filters that need to be evaluated after scanning.
Rows - * should be returned from the data source if and only if all of the filters match. That is, - * filters must be interpreted as ANDed together. - */ - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)] - val postScan = mutable.ArrayBuffer.empty[Filter] - val partitionFilter = mutable.ArrayBuffer.empty[Filter] - - val converter = new SparkFilterConverter(table.rowType) - val visitor = new PartitionPredicateVisitor(table.partitionKeys()) - filters.foreach { - filter => - val predicate = converter.convertIgnoreFailure(filter) - if (predicate == null) { - postScan.append(filter) - } else { - pushable.append((filter, predicate)) - if (predicate.visit(visitor)) { - partitionFilter.append(filter) - } else { - postScan.append(filter) - } - } - } - - if (pushable.nonEmpty) { - this.pushedPredicates = pushable.toArray - } - if (partitionFilter.nonEmpty) { - this.partitionFilters = partitionFilter.toArray - } - if (postScan.nonEmpty) { - this.postScanFilters = postScan.toArray - } - postScan.toArray - } - - override def pushedFilters(): Array[Filter] = { - pushedPredicates.map(_._1) + PaimonScan(table, requiredSchema, pushedPaimonPredicates, reservedFilters, pushDownLimit) } override def pruneColumns(requiredSchema: StructType): Unit = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala index 490a1b133f6f..1f4e88e8d160 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala @@ -18,11 +18,11 @@ package org.apache.paimon.spark +import org.apache.paimon.predicate.Predicate import org.apache.paimon.table.Table import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.LocalScan -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType /** A scan does not require [[RDD]] to execute */ @@ -30,7 +30,7 @@ case class PaimonLocalScan( rows: Array[InternalRow], readSchema: StructType, table: Table, - filters: Array[Filter]) + filters: Array[Predicate]) extends LocalScan { override def description(): String = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index 2f1e6c53ab0a..d02c6edd84d2 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -156,6 +156,7 @@ case class PaimonScan( .map(fieldReference) } + // todo: replace it with SupportsRuntimeV2Filtering override def filter(filters: Array[Filter]): Unit = { val converter = new SparkFilterConverter(table.rowType()) val partitionFilter = filters.flatMap { @@ -170,5 +171,4 @@ case class PaimonScan( inputPartitions = null } } - } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 0393a1cd1578..7b6c65c37f76 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -18,22 +18,71 @@ package org.apache.paimon.spark -import org.apache.paimon.predicate.PredicateBuilder +import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, PredicateBuilder} import org.apache.paimon.spark.aggregate.LocalAggregator import org.apache.paimon.table.Table import org.apache.paimon.table.source.DataSplit +import org.apache.spark.sql.PaimonUtils import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit} +import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownV2Filters} +import org.apache.spark.sql.sources.Filter import scala.collection.JavaConverters._ +import scala.collection.mutable class PaimonScanBuilder(table: Table) extends PaimonBaseScanBuilder(table) + with SupportsPushDownV2Filters with SupportsPushDownLimit with SupportsPushDownAggregates { + private var localScan: Option[Scan] = None + private var pushedSparkPredicates = Array.empty[SparkPredicate] + + /** Pushes down filters, and returns filters that need to be evaluated after scanning. */ + override def pushPredicates(predicates: Array[SparkPredicate]): Array[SparkPredicate] = { + val pushable = mutable.ArrayBuffer.empty[(SparkPredicate, Predicate)] + val postScan = mutable.ArrayBuffer.empty[SparkPredicate] + val reserved = mutable.ArrayBuffer.empty[Filter] + + val converter = SparkV2FilterConverter(table.rowType) + val visitor = new PartitionPredicateVisitor(table.partitionKeys()) + predicates.foreach { + predicate => + converter.convert(predicate, ignoreFailure = true) match { + case Some(paimonPredicate) => + pushable.append((predicate, paimonPredicate)) + if (paimonPredicate.visit(visitor)) { + // We need to filter the stats using filter instead of predicate. + reserved.append(PaimonUtils.filterV2ToV1(predicate).get) + } else { + postScan.append(predicate) + } + case None => + postScan.append(predicate) + } + } + + if (pushable.nonEmpty) { + this.pushedSparkPredicates = pushable.map(_._1).toArray + this.pushedPaimonPredicates = pushable.map(_._2).toArray + } + if (reserved.nonEmpty) { + this.reservedFilters = reserved.toArray + } + if (postScan.nonEmpty) { + this.hasPostScanPredicates = true + } + postScan.toArray + } + + override def pushedPredicates: Array[SparkPredicate] = { + pushedSparkPredicates + } + override def pushLimit(limit: Int): Boolean = { // It is safe, since we will do nothing if it is the primary table and the split is not `rawConvertible` pushDownLimit = Some(limit) @@ -52,8 +101,8 @@ class PaimonScanBuilder(table: Table) return true } - // Only support with push down partition filter - if (postScanFilters.nonEmpty) { + // Only support when there is no post scan predicates. 
+ if (hasPostScanPredicates) { return false } @@ -63,8 +112,8 @@ class PaimonScanBuilder(table: Table) } val readBuilder = table.newReadBuilder - if (pushedPredicates.nonEmpty) { - val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*) + if (pushedPaimonPredicates.nonEmpty) { + val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava) readBuilder.withFilter(pushedPartitionPredicate) } val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit]) @@ -77,7 +126,7 @@ class PaimonScanBuilder(table: Table) aggregator.result(), aggregator.resultSchema(), table, - pushedPredicates.map(_._1))) + pushedPaimonPredicates)) true } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala index 8d9e643f9485..be67cc5e28b3 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonBaseScanBuilder(table) { override def build(): Scan = { - PaimonSplitScan(table, table.splits(), requiredSchema, pushedPredicates.map(_._2)) + PaimonSplitScan(table, table.splits(), requiredSchema, pushedPaimonPredicates) } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala index 11ef302672e1..a3e1077b507f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala @@ -20,9 +20,11 @@ package org.apache.paimon.spark import org.apache.paimon.data.{BinaryString, Decimal, Timestamp} import org.apache.paimon.predicate.{Predicate, PredicateBuilder} +import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType import org.apache.paimon.types.{DataTypeRoot, DecimalType, RowType} import org.apache.paimon.types.DataTypeRoot._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.{Literal, NamedReference} import org.apache.spark.sql.connector.expressions.filter.{And, Not, Or, Predicate => SparkPredicate} @@ -35,6 +37,15 @@ case class SparkV2FilterConverter(rowType: RowType) { val builder = new PredicateBuilder(rowType) + def convert(sparkPredicate: SparkPredicate, ignoreFailure: Boolean): Option[Predicate] = { + try { + Some(convert(sparkPredicate)) + } catch { + case _ if ignoreFailure => None + case e: Exception => throw e + } + } + def convert(sparkPredicate: SparkPredicate): Predicate = { sparkPredicate.name() match { case EQUAL_TO => @@ -205,8 +216,14 @@ case class SparkV2FilterConverter(rowType: RowType) { value.asInstanceOf[org.apache.spark.sql.types.Decimal].toJavaBigDecimal, precision, scale) - case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE | DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE => + case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE => Timestamp.fromMicros(value.asInstanceOf[Long]) + case DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE => + if (treatPaimonTimestampTypeAsSparkTimestampType()) { + 
Timestamp.fromSQLTimestamp(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long])) + } else { + Timestamp.fromMicros(value.asInstanceOf[Long]) + } case _ => throw new UnsupportedOperationException( s"Convert value: $value to datatype: $dataType is unsupported.") diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala index 2eef2c41aebe..ba660a7a8e98 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala @@ -168,6 +168,7 @@ trait ExpressionHelper extends PredicateHelper { output: Seq[Attribute], rowType: RowType, ignorePartialFailure: Boolean = false): Option[Predicate] = { + // todo: replace it with SparkV2FilterConverter when we drop Spark3.2 val converter = new SparkFilterConverter(rowType) val filters = normalizeExprs(Seq(condition), output) .flatMap(splitConjunctivePredicates(_).flatMap { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala index 28ac1623fb59..d41fd7d4d287 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala @@ -63,6 +63,7 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper with SQLCon def convertPartitionFilterToMap( filter: Filter, partitionRowType: RowType): Map[String, String] = { + // todo: replace it with SparkV2FilterConverter when we drop Spark3.2 val converter = new SparkFilterConverter(partitionRowType) splitConjunctiveFilters(filter).map { case EqualNullSafe(attribute, value) => diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala index d01a840f8ece..a1ce25137436 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy.translateFilterV2WithMapping +import org.apache.spark.sql.internal.connector.PredicateUtils import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.PartitioningUtils @@ -74,6 +75,10 @@ object PaimonUtils { translateFilterV2WithMapping(predicate, None) } + def filterV2ToV1(predicate: Predicate): Option[Filter] = { + PredicateUtils.toV1(predicate) + } + def fieldReference(name: String): FieldReference = { fieldReference(Seq(name)) } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala similarity index 97% 
rename from paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala rename to paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala index 503f1c8e3e9d..15c021babb75 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.junit.jupiter.api.Assertions -class PaimonPushDownTest extends PaimonSparkTestBase { +abstract class PaimonPushDownTestBase extends PaimonSparkTestBase { import testImplicits._ @@ -101,6 +101,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } test("Paimon pushDown: limit for append-only tables") { + assume(gteqSpark3_3) spark.sql(s""" |CREATE TABLE T (a INT, b STRING, c STRING) |PARTITIONED BY (c) @@ -128,6 +129,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } test("Paimon pushDown: limit for primary key table") { + assume(gteqSpark3_3) spark.sql(s""" |CREATE TABLE T (a INT, b STRING, c STRING) |TBLPROPERTIES ('primary-key'='a') @@ -202,6 +204,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } test("Paimon pushDown: limit for table with deletion vector") { + assume(gteqSpark3_3) Seq(true, false).foreach( deletionVectorsEnabled => { Seq(true, false).foreach( @@ -279,14 +282,14 @@ class PaimonPushDownTest extends PaimonSparkTestBase { SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty()) } - private def checkFilterExists(sql: String): Boolean = { + def checkFilterExists(sql: String): Boolean = { spark.sql(sql).queryExecution.optimizedPlan.exists { case Filter(_: Expression, _) => true case _ => false } } - private def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = { + def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = { spark.sql(sql).queryExecution.optimizedPlan.exists { case Filter(c: Expression, _) => c.exists { diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala index b9cbc29b3aa3..e51a8f7a0114 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala @@ -21,10 +21,13 @@ package org.apache.paimon.spark.sql import org.apache.paimon.data.{BinaryString, Decimal, Timestamp} import org.apache.paimon.predicate.PredicateBuilder import org.apache.paimon.spark.{PaimonSparkTestBase, SparkV2FilterConverter} +import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType +import org.apache.paimon.table.source.DataSplit import org.apache.paimon.types.RowType import org.apache.spark.SparkConf import org.apache.spark.sql.PaimonUtils.translateFilterV2 +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -54,9 +57,26 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { | decimal_col DECIMAL(10, 5), | boolean_col BOOLEAN, | date_col DATE, 
- | binary BINARY + | binary_col BINARY |) USING paimon |""".stripMargin) + + sql(""" + |INSERT INTO test_tbl VALUES + |('hello', 1, 1, 1, 1, 1.0, 1.0, 12.12345, true, date('2025-01-15'), binary('b1')) + |""".stripMargin) + sql(""" + |INSERT INTO test_tbl VALUES + |('world', 2, 2, 2, 2, 2.0, 2.0, 22.12345, false, date('2025-01-16'), binary('b2')) + |""".stripMargin) + sql(""" + |INSERT INTO test_tbl VALUES + |('hi', 3, 3, 3, 3, 3.0, 3.0, 32.12345, false, date('2025-01-17'), binary('b3')) + |""".stripMargin) + sql(""" + |INSERT INTO test_tbl VALUES + |('paimon', 4, 4, null, 4, 4.0, 4.0, 42.12345, true, date('2025-01-18'), binary('b4')) + |""".stripMargin) } override protected def afterAll(): Unit = { @@ -71,147 +91,269 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { lazy val converter: SparkV2FilterConverter = SparkV2FilterConverter(rowType) test("V2Filter: all types") { - var actual = converter.convert(v2Filter("string_col = 'hello'")) + var filter = "string_col = 'hello'" + var actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(0, BinaryString.fromString("hello")))) + checkAnswer(sql(s"SELECT string_col from test_tbl WHERE $filter"), Seq(Row("hello"))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("byte_col = 1")) + filter = "byte_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(1, 1.toByte))) + checkAnswer(sql(s"SELECT byte_col from test_tbl WHERE $filter"), Seq(Row(1.toByte))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("short_col = 1")) + filter = "short_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(2, 1.toShort))) + checkAnswer(sql(s"SELECT short_col from test_tbl WHERE $filter"), Seq(Row(1.toShort))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("int_col = 1")) + filter = "int_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(3, 1))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter"), Seq(Row(1))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("long_col = 1")) + filter = "long_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(4, 1L))) + checkAnswer(sql(s"SELECT long_col from test_tbl WHERE $filter"), Seq(Row(1L))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("float_col = 1.0")) + filter = "float_col = 1.0" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(5, 1.0f))) + checkAnswer(sql(s"SELECT float_col from test_tbl WHERE $filter"), Seq(Row(1.0f))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("double_col = 1.0")) + filter = "double_col = 1.0" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(6, 1.0d))) + checkAnswer(sql(s"SELECT double_col from test_tbl WHERE $filter"), Seq(Row(1.0d))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("decimal_col = 12.12345")) + filter = "decimal_col = 12.12345" + actual = converter.convert(v2Filter(filter)) assert( actual.equals( builder.equal(7, Decimal.fromBigDecimal(new java.math.BigDecimal("12.12345"), 10, 5)))) + checkAnswer( + sql(s"SELECT decimal_col from test_tbl WHERE $filter"), + Seq(Row(new java.math.BigDecimal("12.12345")))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("boolean_col = true")) + filter 
= "boolean_col = true" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(8, true))) + checkAnswer(sql(s"SELECT boolean_col from test_tbl WHERE $filter"), Seq(Row(true), Row(true))) + assert(scanFilesCount(filter) == 2) - actual = converter.convert(v2Filter("date_col = cast('2025-01-15' as date)")) + filter = "date_col = date('2025-01-15')" + actual = converter.convert(v2Filter(filter)) val localDate = LocalDate.parse("2025-01-15") val epochDay = localDate.toEpochDay.toInt assert(actual.equals(builder.equal(9, epochDay))) + checkAnswer( + sql(s"SELECT date_col from test_tbl WHERE $filter"), + sql("SELECT date('2025-01-15')")) + assert(scanFilesCount(filter) == 1) + filter = "binary_col = binary('b1')" intercept[UnsupportedOperationException] { - actual = converter.convert(v2Filter("binary = binary('b1')")) + actual = converter.convert(v2Filter(filter)) } + checkAnswer(sql(s"SELECT binary_col from test_tbl WHERE $filter"), sql("SELECT binary('b1')")) + assert(scanFilesCount(filter) == 4) } test("V2Filter: timestamp and timestamp_ntz") { withTimeZone("Asia/Shanghai") { withTable("ts_tbl", "ts_ntz_tbl") { sql("CREATE TABLE ts_tbl (ts_col TIMESTAMP) USING paimon") + sql("INSERT INTO ts_tbl VALUES (timestamp'2025-01-15 00:00:00.123')") + sql("INSERT INTO ts_tbl VALUES (timestamp'2025-01-16 00:00:00.123')") + + val filter1 = "ts_col = timestamp'2025-01-15 00:00:00.123'" val rowType1 = loadTable("ts_tbl").rowType() val converter1 = SparkV2FilterConverter(rowType1) - val actual1 = - converter1.convert(v2Filter("ts_col = timestamp'2025-01-15 00:00:00.123'", "ts_tbl")) - assert( - actual1.equals(new PredicateBuilder(rowType1) + val actual1 = converter1.convert(v2Filter(filter1, "ts_tbl")) + if (treatPaimonTimestampTypeAsSparkTimestampType()) { + assert(actual1.equals(new PredicateBuilder(rowType1) + .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-15T00:00:00.123"))))) + } else { + assert(actual1.equals(new PredicateBuilder(rowType1) .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-14T16:00:00.123"))))) + } + checkAnswer( + sql(s"SELECT ts_col from ts_tbl WHERE $filter1"), + sql("SELECT timestamp'2025-01-15 00:00:00.123'")) + assert(scanFilesCount(filter1, "ts_tbl") == 1) // Spark support TIMESTAMP_NTZ since Spark 3.4 if (gteqSpark3_4) { sql("CREATE TABLE ts_ntz_tbl (ts_ntz_col TIMESTAMP_NTZ) USING paimon") + sql("INSERT INTO ts_ntz_tbl VALUES (timestamp_ntz'2025-01-15 00:00:00.123')") + sql("INSERT INTO ts_ntz_tbl VALUES (timestamp_ntz'2025-01-16 00:00:00.123')") + val filter2 = "ts_ntz_col = timestamp_ntz'2025-01-15 00:00:00.123'" val rowType2 = loadTable("ts_ntz_tbl").rowType() val converter2 = SparkV2FilterConverter(rowType2) - val actual2 = converter2.convert( - v2Filter("ts_ntz_col = timestamp_ntz'2025-01-15 00:00:00.123'", "ts_ntz_tbl")) + val actual2 = converter2.convert(v2Filter(filter2, "ts_ntz_tbl")) assert(actual2.equals(new PredicateBuilder(rowType2) .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-15T00:00:00.123"))))) + checkAnswer( + sql(s"SELECT ts_ntz_col from ts_ntz_tbl WHERE $filter2"), + sql("SELECT timestamp_ntz'2025-01-15 00:00:00.123'")) + assert(scanFilesCount(filter2, "ts_ntz_tbl") == 1) } } } } test("V2Filter: EqualTo") { - val actual = converter.convert(v2Filter("int_col = 1")) + val filter = "int_col = 1" + val actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(3, 1))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), 
Seq(Row(1))) + assert(scanFilesCount(filter) == 1) } test("V2Filter: EqualNullSafe") { - var actual = converter.convert(v2Filter("int_col <=> 1")) + var filter = "int_col <=> 1" + var actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(3, 1))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("int_col <=> null")) + filter = "int_col <=> null" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.isNull(3))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(null))) + assert(scanFilesCount(filter) == 1) } test("V2Filter: GreaterThan") { - val actual = converter.convert(v2Filter("int_col > 1")) - assert(actual.equals(builder.greaterThan(3, 1))) + val filter = "int_col > 2" + val actual = converter.convert(v2Filter(filter)) + assert(actual.equals(builder.greaterThan(3, 2))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(3))) + assert(scanFilesCount(filter) == 1) } test("V2Filter: GreaterThanOrEqual") { - val actual = converter.convert(v2Filter("int_col >= 1")) - assert(actual.equals(builder.greaterOrEqual(3, 1))) + val filter = "int_col >= 2" + val actual = converter.convert(v2Filter(filter)) + assert(actual.equals(builder.greaterOrEqual(3, 2))) + checkAnswer( + sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), + Seq(Row(2), Row(3))) + assert(scanFilesCount(filter) == 2) } test("V2Filter: LessThan") { - val actual = converter.convert(v2Filter("int_col < 1")) - assert(actual.equals(builder.lessThan(3, 1))) + val filter = "int_col < 2" + val actual = converter.convert(v2Filter("int_col < 2")) + assert(actual.equals(builder.lessThan(3, 2))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1))) + assert(scanFilesCount(filter) == 1) } test("V2Filter: LessThanOrEqual") { - val actual = converter.convert(v2Filter("int_col <= 1")) - assert(actual.equals(builder.lessOrEqual(3, 1))) + val filter = "int_col <= 2" + val actual = converter.convert(v2Filter("int_col <= 2")) + assert(actual.equals(builder.lessOrEqual(3, 2))) + checkAnswer( + sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col "), + Seq(Row(1), Row(2))) + assert(scanFilesCount(filter) == 2) } test("V2Filter: In") { - val actual = converter.convert(v2Filter("int_col IN (1, 2, 3)")) - assert(actual.equals(builder.in(3, List(1, 2, 3).map(_.asInstanceOf[AnyRef]).asJava))) + val filter = "int_col IN (1, 2)" + val actual = converter.convert(v2Filter("int_col IN (1, 2)")) + assert(actual.equals(builder.in(3, List(1, 2).map(_.asInstanceOf[AnyRef]).asJava))) + checkAnswer( + sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), + Seq(Row(1), Row(2))) + assert(scanFilesCount(filter) == 2) } test("V2Filter: IsNull") { - val actual = converter.convert(v2Filter("int_col IS NULL")) + val filter = "int_col IS NULL" + val actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.isNull(3))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(null))) + assert(scanFilesCount(filter) == 1) } test("V2Filter: IsNotNull") { - val actual = converter.convert(v2Filter("int_col IS NOT NULL")) + val filter = "int_col IS NOT NULL" + val actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.isNotNull(3))) + checkAnswer( + sql(s"SELECT int_col from test_tbl 
WHERE $filter ORDER BY int_col"), + Seq(Row(1), Row(2), Row(3))) + assert(scanFilesCount(filter) == 3) } test("V2Filter: And") { - val actual = converter.convert(v2Filter("int_col > 1 AND int_col < 10")) - assert(actual.equals(PredicateBuilder.and(builder.greaterThan(3, 1), builder.lessThan(3, 10)))) + val filter = "int_col > 1 AND int_col < 3" + val actual = converter.convert(v2Filter(filter)) + assert(actual.equals(PredicateBuilder.and(builder.greaterThan(3, 1), builder.lessThan(3, 3)))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(2))) + assert(scanFilesCount(filter) == 1) } test("V2Filter: Or") { - val actual = converter.convert(v2Filter("int_col > 1 OR int_col < 10")) - assert(actual.equals(PredicateBuilder.or(builder.greaterThan(3, 1), builder.lessThan(3, 10)))) + val filter = "int_col > 2 OR int_col IS NULL" + val actual = converter.convert(v2Filter(filter)) + assert(actual.equals(PredicateBuilder.or(builder.greaterThan(3, 2), builder.isNull(3)))) + checkAnswer( + sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), + Seq(Row(null), Row(3))) + assert(scanFilesCount(filter) == 2) } test("V2Filter: Not") { - val actual = converter.convert(v2Filter("NOT (int_col > 1)")) - assert(actual.equals(builder.greaterThan(3, 1).negate().get())) + val filter = "NOT (int_col > 2)" + val actual = converter.convert(v2Filter(filter)) + assert(actual.equals(builder.greaterThan(3, 2).negate().get())) + checkAnswer( + sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), + Seq(Row(1), Row(2))) + assert(scanFilesCount(filter) == 2) } test("V2Filter: StartWith") { - val actual = converter.convert(v2Filter("string_col LIKE 'h%'")) + val filter = "string_col LIKE 'h%'" + val actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.startsWith(0, BinaryString.fromString("h")))) + checkAnswer( + sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"), + Seq(Row("hello"), Row("hi"))) + assert(scanFilesCount(filter) == 2) } test("V2Filter: EndWith") { - val actual = converter.convert(v2Filter("string_col LIKE '%o'")) + val filter = "string_col LIKE '%o'" + val actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.endsWith(0, BinaryString.fromString("o")))) + checkAnswer( + sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"), + Seq(Row("hello"))) + // EndWith does not have file skipping effect now. + assert(scanFilesCount(filter) == 4) } test("V2Filter: Contains") { - val actual = converter.convert(v2Filter("string_col LIKE '%e%'")) + val filter = "string_col LIKE '%e%'" + val actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.contains(0, BinaryString.fromString("e")))) + checkAnswer( + sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"), + Seq(Row("hello"))) + // Contains does not have file skipping effect now. + assert(scanFilesCount(filter) == 4) } private def v2Filter(str: String, tableName: String = "test_tbl"): Predicate = { @@ -221,4 +363,11 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { .condition translateFilterV2(condition).get } + + private def scanFilesCount(str: String, tableName: String = "test_tbl"): Int = { + getPaimonScan(s"SELECT * FROM $tableName WHERE $str").lazyInputPartitions + .flatMap(_.splits) + .map(_.asInstanceOf[DataSplit].dataFiles().size()) + .sum + } }