Skip to content

Commit

Permalink
[spark] Use Seq in the PaimonScan's constructor to ensure equals (#3185)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zouxxyy authored Apr 10, 2024
1 parent f067979 commit 8e77b1e
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType
/**
 * Paimon's Spark V2 [[Scan]] implementation.
 *
 * Note: `filters` and `reservedFilters` are `Seq` (not `Array`) on purpose — case-class
 * equality delegates to element-wise `==`, and `Array` only compares by reference, which
 * would break scan equality (e.g. scalar-subquery merging relies on `PaimonScan` equals).
 *
 * @param table           the underlying Paimon table
 * @param requiredSchema  columns required by the query (pruned schema)
 * @param filters         predicates pushed down to the Paimon reader
 * @param reservedFilters source filters kept for post-scan evaluation / statistics
 * @param pushDownLimit   optional LIMIT pushed down to the scan
 */
case class PaimonScan(
    table: Table,
    requiredSchema: StructType,
    filters: Seq[Predicate],
    reservedFilters: Seq[Filter],
    pushDownLimit: Option[Int])
  extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ import scala.collection.JavaConverters._
abstract class PaimonBaseScan(
table: Table,
requiredSchema: StructType,
filters: Array[Predicate],
reservedFilters: Array[Filter],
filters: Seq[Predicate],
reservedFilters: Seq[Filter],
pushDownLimit: Option[Int])
extends Scan
with SupportsReportStatistics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ import scala.collection.JavaConverters._
case class PaimonScan(
table: Table,
requiredSchema: StructType,
filters: Array[Predicate],
reservedFilters: Array[Filter],
filters: Seq[Predicate],
reservedFilters: Seq[Filter],
pushDownLimit: Option[Int])
extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
with SupportsRuntimeFiltering {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ trait MergePaimonScalarSubqueriersBase extends Rule[LogicalPlan] with PredicateH
protected def mergePaimonScan(scan1: PaimonScan, scan2: PaimonScan): Option[PaimonScan] = {
if (
scan1.table == scan2.table &&
scan1.filters.sameElements(scan2.filters) &&
scan1.filters == scan2.filters &&
scan1.pushDownLimit == scan2.pushDownLimit
) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ trait StatisticsHelperBase extends SQLConfHelper {

val requiredSchema: StructType

def filterStatistics(v2Stats: Statistics, filters: Array[Filter]): Statistics = {
def filterStatistics(v2Stats: Statistics, filters: Seq[Filter]): Statistics = {
val attrs: Seq[AttributeReference] =
requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
val condition = filterToCondition(filters, attrs)
Expand All @@ -52,9 +52,7 @@ trait StatisticsHelperBase extends SQLConfHelper {
}
}

private def filterToCondition(
filters: Array[Filter],
attrs: Seq[Attribute]): Option[Expression] = {
private def filterToCondition(filters: Seq[Filter], attrs: Seq[Attribute]): Option[Expression] = {
StructFilters.filterToExpression(filters.reduce(And), toRef).map {
expression =>
expression.transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import org.apache.paimon.spark.catalyst.optimizer.MergePaimonScalarSubqueriers
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LogicalPlan, OneRowRelation, WithCTE}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.functions._
import org.junit.jupiter.api.Assertions

import scala.collection.immutable

Expand Down Expand Up @@ -91,6 +93,25 @@ abstract class PaimonOptimizationTestBase extends PaimonSparkTestBase {
}
}

test("Paimon Optimization: paimon scan equals") {
  withTable("T") {
    spark.sql(s"CREATE TABLE T (id INT, name STRING, pt STRING) PARTITIONED BY (pt)")
    spark.sql(s"INSERT INTO T VALUES (1, 'a', 'p1'), (2, 'b', 'p1'), (3, 'c', 'p2')")

    val query = "SELECT * FROM T WHERE id = 1 AND pt = 'p1' LIMIT 1"

    // Plan the query and pull the V2 scan out of the optimized plan.
    def planScan(q: String) =
      spark
        .sql(q)
        .queryExecution
        .optimizedPlan
        .collectFirst { case rel: DataSourceV2ScanRelation => rel.scan }
        .get

    // Planning the same query twice must yield equal scans; this holds only
    // because the scan's constructor uses Seq (element-wise equality) rather
    // than Array (reference equality).
    Assertions.assertEquals(planScan(query), planScan(query))
  }
}

private def definitionNode(plan: LogicalPlan, cteIndex: Int) = {
CTERelationDef(plan, cteIndex, underSubquery = true)
}
Expand Down

0 comments on commit 8e77b1e

Please sign in to comment.