diff --git a/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index f9546c961e91e..bc2c2bc8e9fcb 100644
--- a/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -27,5 +27,6 @@ case class PaimonScan(
     table: Table,
     requiredSchema: StructType,
     filters: Array[Predicate],
+    reservedFilters: Array[Filter],
     pushDownLimit: Option[Int])
-  extends PaimonBaseScan(table, requiredSchema, filters, pushDownLimit)
+  extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelper.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelper.scala
new file mode 100644
index 0000000000000..66d1265b0e937
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelper.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.statistics
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.connector.read.Statistics
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+trait StatisticsHelper extends StatisticsHelperBase {
+  protected def toV1Stats(v2Stats: Statistics, attrs: Seq[Attribute]): logical.Statistics = {
+    DataSourceV2Relation.transformV2Stats(v2Stats, None, conf.defaultSizeInBytes)
+  }
+}
diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala
index c6f788eab13e4..2abc0c46ba0bf 100644
--- a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala
+++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala
@@ -17,4 +17,44 @@
  */
 package org.apache.paimon.spark.sql
 
-class AnalyzeTableTest extends AnalyzeTableTestBase {}
+import org.junit.jupiter.api.Assertions
+
+class AnalyzeTableTest extends AnalyzeTableTestBase {
+  test("Paimon analyze: spark use col stats") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id STRING, name STRING, i INT, l LONG)
+                 |USING PAIMON
+                 |TBLPROPERTIES ('primary-key'='id')
+                 |""".stripMargin)
+
+    spark.sql(s"INSERT INTO T VALUES ('1', 'a', 1, 1)")
+    spark.sql(s"INSERT INTO T VALUES ('2', 'aaa', 1, 2)")
+    spark.sql(s"ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS")
+
+    val stats = getScanStatistic("SELECT * FROM T")
+    Assertions.assertEquals(2L, stats.rowCount.get.longValue())
+    Assertions.assertEquals(4, stats.attributeStats.size)
+  }
+
+  test("Paimon analyze: partition filter push down hit") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id INT, name STRING, pt INT)
+                 |TBLPROPERTIES ('primary-key'='id, pt', 'bucket'='2')
+                 |PARTITIONED BY (pt)
+                 |""".stripMargin)
+
+    spark.sql("INSERT INTO T VALUES (1, 'a', 1), (2, 'b', 1), (3, 'c', 2), (4, 'd', 3)")
+    spark.sql(s"ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS")
+
+    // Paimon reserves the partition filter rather than returning it to Spark, so we need to ensure the statistics are filtered correctly.
+    // partition push down hit
+    var sql = "SELECT * FROM T WHERE pt < 1"
+    Assertions.assertEquals(0L, getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+
+    // partition push down not hit
+    sql = "SELECT * FROM T WHERE id < 1"
+    Assertions.assertEquals(4L, getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+  }
+}
diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala
index c6f788eab13e4..2abc0c46ba0bf 100644
--- a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala
+++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTest.scala
@@ -17,4 +17,44 @@
  */
 package org.apache.paimon.spark.sql
 
-class AnalyzeTableTest extends AnalyzeTableTestBase {}
+import org.junit.jupiter.api.Assertions
+
+class AnalyzeTableTest extends AnalyzeTableTestBase {
+  test("Paimon analyze: spark use col stats") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id STRING, name STRING, i INT, l LONG)
+                 |USING PAIMON
+                 |TBLPROPERTIES ('primary-key'='id')
+                 |""".stripMargin)
+
+    spark.sql(s"INSERT INTO T VALUES ('1', 'a', 1, 1)")
+    spark.sql(s"INSERT INTO T VALUES ('2', 'aaa', 1, 2)")
+    spark.sql(s"ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS")
+
+    val stats = getScanStatistic("SELECT * FROM T")
+    Assertions.assertEquals(2L, stats.rowCount.get.longValue())
+    Assertions.assertEquals(4, stats.attributeStats.size)
+  }
+
+  test("Paimon analyze: partition filter push down hit") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id INT, name STRING, pt INT)
+                 |TBLPROPERTIES ('primary-key'='id, pt', 'bucket'='2')
+                 |PARTITIONED BY (pt)
+                 |""".stripMargin)
+
+    spark.sql("INSERT INTO T VALUES (1, 'a', 1), (2, 'b', 1), (3, 'c', 2), (4, 'd', 3)")
+    spark.sql(s"ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS")
+
+    // Paimon reserves the partition filter rather than returning it to Spark, so we need to ensure the statistics are filtered correctly.
+    // partition push down hit
+    var sql = "SELECT * FROM T WHERE pt < 1"
+    Assertions.assertEquals(0L, getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+
+    // partition push down not hit
+    sql = "SELECT * FROM T WHERE id < 1"
+    Assertions.assertEquals(4L, getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index 5ede8883dfdc0..d234092a5365b 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -17,11 +17,13 @@
  */
 package org.apache.paimon.spark
 
-import org.apache.paimon.CoreOptions
+import org.apache.paimon.{stats, CoreOptions}
 import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
 import org.apache.paimon.spark.sources.PaimonMicroBatchStream
+import org.apache.paimon.spark.statistics.StatisticsHelper
 import org.apache.paimon.table.{DataTable, FileStoreTable, Table}
 import org.apache.paimon.table.source.{ReadBuilder, Split}
+import org.apache.paimon.types.RowType
 
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.read.{Batch, Scan, Statistics, SupportsReportStatistics}
@@ -29,18 +31,22 @@ import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 
+import java.util.Optional
+
 import scala.collection.JavaConverters._
 
 abstract class PaimonBaseScan(
     table: Table,
     requiredSchema: StructType,
     filters: Array[Predicate],
+    reservedFilters: Array[Filter],
     pushDownLimit: Option[Int])
   extends Scan
   with SupportsReportStatistics
-  with ScanHelper {
+  with ScanHelper
+  with StatisticsHelper {
 
-  private val tableRowType = table.rowType
+  val tableRowType: RowType = table.rowType
 
   private lazy val tableSchema = SparkTypeUtils.fromPaimonRowType(tableRowType)
 
@@ -50,6 +56,8 @@ abstract class PaimonBaseScan(
 
   override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())
 
+  lazy val statistics: Optional[stats.Statistics] = table.statistics()
+
   lazy val readBuilder: ReadBuilder = {
     val _readBuilder = table.newReadBuilder()
 
@@ -89,7 +97,12 @@
   }
 
   override def estimateStatistics(): Statistics = {
-    PaimonStatistics(this)
+    val stats = PaimonStatistics(this)
+    if (reservedFilters.nonEmpty) {
+      filterStatistics(stats, reservedFilters)
+    } else {
+      stats
+    }
   }
 
   override def supportedCustomMetrics: Array[CustomMetric] = {
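Note on the estimateStatistics() change above: reserved (partition-only) filters are folded into the reported statistics because Spark will not re-evaluate them after the scan. A minimal way to observe this from a Spark session, mirroring the getScanStatistic helper added to AnalyzeTableTestBase further down; the table T, columns pt/id, and the ANALYZE step are assumptions taken from those tests:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation

// Returns the row count that the Paimon scan reports to Spark's optimizer for a given query.
def reportedRowCount(spark: SparkSession, query: String): Option[BigInt] =
  spark.sql(query).queryExecution.optimizedPlan
    .collectFirst { case r: DataSourceV2ScanRelation => r }
    .flatMap(_.computeStats().rowCount)

// After ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS:
// reportedRowCount(spark, "SELECT * FROM T WHERE pt < 1")  // expected Some(0): pt < 1 is reserved
// reportedRowCount(spark, "SELECT * FROM T WHERE id < 1")  // expected Some(4): id < 1 stays post-scan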
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
index e6ec877074c59..420f39637cdb7 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
@@ -37,10 +37,12 @@ abstract class PaimonBaseScanBuilder(table: Table)
 
   protected var pushed: Array[(Filter, Predicate)] = Array.empty
 
+  protected var reservedFilters: Array[Filter] = Array.empty
+
   protected var pushDownLimit: Option[Int] = None
 
   override def build(): Scan = {
-    PaimonScan(table, requiredSchema, pushed.map(_._2), pushDownLimit)
+    PaimonScan(table, requiredSchema, pushed.map(_._2), reservedFilters, pushDownLimit)
   }
 
   /**
@@ -51,6 +53,7 @@ abstract class PaimonBaseScanBuilder(table: Table)
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)]
     val postScan = mutable.ArrayBuffer.empty[Filter]
+    val reserved = mutable.ArrayBuffer.empty[Filter]
 
     val converter = new SparkFilterConverter(table.rowType)
     val visitor = new PartitionPredicateVisitor(table.partitionKeys())
@@ -59,7 +62,11 @@ abstract class PaimonBaseScanBuilder(table: Table)
       try {
         val predicate = converter.convert(filter)
         pushable.append((filter, predicate))
-        if (!predicate.visit(visitor)) postScan.append(filter)
+        if (!predicate.visit(visitor)) {
+          postScan.append(filter)
+        } else {
+          reserved.append(filter)
+        }
       } catch {
         case e: UnsupportedOperationException =>
           logWarning(e.getMessage)
@@ -70,6 +77,9 @@ abstract class PaimonBaseScanBuilder(table: Table)
     if (pushable.nonEmpty) {
       this.pushed = pushable.toArray
     }
+    if (reserved.nonEmpty) {
+      this.reservedFilters = reserved.toArray
+    }
     postScan.toArray
   }
 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonImplicits.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonImplicits.scala
index 8b794259681d7..2aa4431ca431c 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonImplicits.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonImplicits.scala
@@ -24,10 +24,16 @@ import scala.language.implicitConversions
 object PaimonImplicits {
   implicit def toScalaOption[T](o: Optional[T]): Option[T] = {
     if (o.isPresent) {
-      Option.apply(o.get())
+      Some(o.get)
     } else {
       None
     }
   }
 
+  implicit def toJavaOptional[T, U](o: Option[T]): Optional[U] = {
+    o match {
+      case Some(t) => Optional.ofNullable(t.asInstanceOf[U])
+      case _ => Optional.empty[U]
+    }
+  }
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 2bb3ff0c9a65f..d4bda95ec95ad 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -32,8 +32,9 @@ case class PaimonScan(
     table: Table,
     requiredSchema: StructType,
     filters: Array[Predicate],
+    reservedFilters: Array[Filter],
     pushDownLimit: Option[Int])
-  extends PaimonBaseScan(table, requiredSchema, filters, pushDownLimit)
+  extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
   with SupportsRuntimeFiltering {
 
   override def filterAttributes(): Array[NamedReference] = {
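To make the pushFilters change above concrete, here is an illustrative sketch (not part of the patch) of how two simple filters land in the three buckets; the column names pt and id are borrowed from the tests:

import org.apache.spark.sql.sources.{Filter, LessThan}

val filters: Array[Filter] = Array(LessThan("pt", 1), LessThan("id", 1))
// For a table partitioned by pt, pushFilters(filters) would:
//  - convert both filters and push them down to Paimon (pushed),
//  - keep LessThan("pt", 1) in reservedFilters because it references only partition keys;
//    it is not handed back to Spark and is used solely to filter the reported statistics,
//  - return Array(LessThan("id", 1)) as the post-scan filters Spark still evaluates itself.
// A filter that SparkFilterConverter cannot convert is neither pushed nor reserved; it is
// only returned to Spark for post-scan evaluation.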
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
index 208c5816c98ee..e5129ae3cb015 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
@@ -17,9 +17,19 @@
  */
 package org.apache.paimon.spark
 
+import org.apache.paimon.stats
+import org.apache.paimon.stats.ColStats
+import org.apache.paimon.types.DataType
+
+import org.apache.spark.sql.Utils
+import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
+import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.connector.read.Statistics
+import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
+
+import java.util.{Optional, OptionalLong}
 
-import java.util.OptionalLong
+import scala.collection.JavaConverters._
 
 case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {
 
@@ -27,9 +37,69 @@ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {
 
   private lazy val scannedTotalSize: Long = rowCount * scan.readSchema().defaultSize
 
-  override def sizeInBytes(): OptionalLong = OptionalLong.of(scannedTotalSize)
+  private lazy val paimonStats: Optional[stats.Statistics] = scan.statistics
+
+  override def sizeInBytes(): OptionalLong = if (paimonStats.isPresent)
+    paimonStats.get().mergedRecordSize()
+  else OptionalLong.of(scannedTotalSize)
+
+  override def numRows(): OptionalLong =
+    if (paimonStats.isPresent) paimonStats.get().mergedRecordCount() else OptionalLong.of(rowCount)
+
+  override def columnStats(): java.util.Map[NamedReference, ColumnStatistics] = {
+    val requiredFields = scan.readSchema().fieldNames.toList.asJava
+    val resultMap = new java.util.HashMap[NamedReference, ColumnStatistics]()
+    if (paimonStats.isPresent) {
+      val paimonColStats = paimonStats.get().colStats()
+      scan.tableRowType.getFields
+        .stream()
+        .filter(
+          field => requiredFields.contains(field.name) && paimonColStats.containsKey(field.name()))
+        .forEach(
+          f =>
+            resultMap.put(
+              Utils.fieldReference(f.name()),
+              PaimonColumnStats(f.`type`(), paimonColStats.get(f.name()))))
+    }
+    resultMap
+  }
+}
+
+case class PaimonColumnStats(
+    override val nullCount: OptionalLong,
+    override val min: Optional[Object],
+    override val max: Optional[Object],
+    override val distinctCount: OptionalLong,
+    override val avgLen: OptionalLong,
+    override val maxLen: OptionalLong)
+  extends ColumnStatistics
 
-  override def numRows(): OptionalLong = OptionalLong.of(rowCount)
+object PaimonColumnStats {
+  def apply(dataType: DataType, paimonColStats: ColStats[_]): PaimonColumnStats = {
+    PaimonColumnStats(
+      paimonColStats.nullCount,
+      Optional.ofNullable(SparkInternalRow.fromPaimon(paimonColStats.min().orElse(null), dataType)),
+      Optional.ofNullable(SparkInternalRow.fromPaimon(paimonColStats.max().orElse(null), dataType)),
+      paimonColStats.distinctCount,
+      paimonColStats.avgLen,
+      paimonColStats.maxLen
+    )
+  }
 
-  // TODO: extend columnStats for CBO
+  def apply(v1ColStats: ColumnStat): PaimonColumnStats = {
+    import PaimonImplicits._
+    PaimonColumnStats(
+      if (v1ColStats.nullCount.isDefined) OptionalLong.of(v1ColStats.nullCount.get.longValue())
+      else OptionalLong.empty(),
+      v1ColStats.min,
+      v1ColStats.max,
+      if (v1ColStats.distinctCount.isDefined)
+        OptionalLong.of(v1ColStats.distinctCount.get.longValue())
+      else OptionalLong.empty(),
+      if (v1ColStats.avgLen.isDefined) OptionalLong.of(v1ColStats.avgLen.get.longValue())
+      else OptionalLong.empty(),
+      if (v1ColStats.maxLen.isDefined) OptionalLong.of(v1ColStats.maxLen.get.longValue())
+      else OptionalLong.empty()
+    )
+  }
 }
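PaimonStatistics above prefers the table-level Paimon statistics and only falls back to a schema-based size estimate when no statistics file exists. A small worked example of that fallback path (schema and row count are made-up numbers, not from the patch):

import org.apache.spark.sql.types._

// readSchema().defaultSize sums the per-type default sizes Spark assigns to each column.
val readSchema = StructType(Seq(
  StructField("i", IntegerType),   // 4 bytes
  StructField("l", LongType),      // 8 bytes
  StructField("name", StringType)  // 20 bytes, Spark's default estimate for strings
))
val rowCount = 1000000L                                   // sum over the scanned splits
val scannedTotalSize = rowCount * readSchema.defaultSize  // 1000000 * (4 + 8 + 20) = 32000000

// When a statistics file is present, sizeInBytes()/numRows() come from mergedRecordSize()/
// mergedRecordCount() instead, and columnStats() is populated so Spark's CBO can use it.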
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelper.scala
new file mode 100644
index 0000000000000..e0ff2507733d7
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelper.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.statistics
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.connector.read.Statistics
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+trait StatisticsHelper extends StatisticsHelperBase {
+  protected def toV1Stats(v2Stats: Statistics, attrs: Seq[Attribute]): logical.Statistics = {
+    DataSourceV2Relation.transformV2Stats(v2Stats, None, conf.defaultSizeInBytes, attrs)
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
new file mode 100644
index 0000000000000..26d442c640749
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.statistics
+
+import org.apache.paimon.spark.PaimonColumnStats
+
+import org.apache.spark.sql.Utils
+import org.apache.spark.sql.catalyst.{SQLConfHelper, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.read.Statistics
+import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
+import org.apache.spark.sql.sources.{And, Filter}
+import org.apache.spark.sql.types.{StructField, StructType}
+
+import java.util.OptionalLong
+
+trait StatisticsHelperBase extends SQLConfHelper {
+
+  val requiredSchema: StructType
+
+  def filterStatistics(v2Stats: Statistics, filters: Array[Filter]): Statistics = {
+    val attrs: Seq[AttributeReference] =
+      requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+    val condition = filterToCondition(filters, attrs)
+
+    if (condition.isDefined && v2Stats.numRows().isPresent) {
+      val filteredStats = FilterEstimation(
+        logical.Filter(condition.get, FakePlanWithStats(toV1Stats(v2Stats, attrs)))).estimate.get
+      toV2Stats(filteredStats)
+    } else {
+      v2Stats
+    }
+  }
+
+  private def filterToCondition(
+      filters: Array[Filter],
+      attrs: Seq[Attribute]): Option[Expression] = {
+    StructFilters.filterToExpression(filters.reduce(And), toRef).map {
+      expression =>
+        expression.transform {
+          case ref: BoundReference => attrs.find(_.name == requiredSchema(ref.ordinal).name).get
+        }
+    }
+  }
+
+  private def toRef(attr: String): Option[BoundReference] = {
+    val index = requiredSchema.fieldIndex(attr)
+    val field = requiredSchema(index)
+    Option.apply(BoundReference(index, field.dataType, field.nullable))
+  }
+
+  protected def toV1Stats(v2Stats: Statistics, attrs: Seq[Attribute]): logical.Statistics
+
+  private def toV2Stats(v1Stats: logical.Statistics): Statistics = {
+    new Statistics() {
+      override def sizeInBytes(): OptionalLong = if (v1Stats.sizeInBytes != null)
+        OptionalLong.of(v1Stats.sizeInBytes.longValue())
+      else OptionalLong.empty()
+
+      override def numRows(): OptionalLong = if (v1Stats.rowCount.isDefined)
+        OptionalLong.of(v1Stats.rowCount.get.longValue())
+      else OptionalLong.empty()
+
+      override def columnStats(): java.util.Map[NamedReference, ColumnStatistics] = {
+        val columnStatsMap = new java.util.HashMap[NamedReference, ColumnStatistics]()
+        v1Stats.attributeStats.foreach {
+          case (attr, v1ColStats) =>
+            columnStatsMap.put(
+              Utils.fieldReference(attr.name),
+              PaimonColumnStats(v1ColStats)
+            )
+        }
+        columnStatsMap
+      }
+    }
+  }
+}
+
+case class FakePlanWithStats(v1Stats: logical.Statistics) extends LogicalPlan {
+  override def output: Seq[Attribute] = Seq.empty
+  override def children: Seq[LogicalPlan] = Seq.empty
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = throw new UnsupportedOperationException
+  override def stats: logical.Statistics = v1Stats
+}
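FakePlanWithStats exists only so that Catalyst's FilterEstimation can be fed arbitrary V1 statistics. A self-contained sketch of what that estimation yields for the reserved partition filter used in the tests, assuming FakePlanWithStats from the file above is in scope; the numbers are illustrative and the exact figures depend on Spark's internal heuristics:

import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, LessThan, Literal}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
import org.apache.spark.sql.types.IntegerType

// Column stats roughly as ANALYZE would produce them for pt in the test table (values 1, 1, 2, 3).
val pt = AttributeReference("pt", IntegerType)()
val ptStats = ColumnStat(
  distinctCount = Some(BigInt(3)),
  min = Some(1),
  max = Some(3),
  nullCount = Some(BigInt(0)),
  avgLen = Some(4L),
  maxLen = Some(4L))
val v1Stats = logical.Statistics(
  sizeInBytes = BigInt(1000),
  rowCount = Some(BigInt(4)),
  attributeStats = AttributeMap(Seq(pt -> ptStats)))

// pt < 1 cannot match anything given min(pt) = 1, so the estimated row count collapses to 0,
// which is what the "partition filter push down hit" tests assert end to end.
val filtered = FilterEstimation(
  logical.Filter(LessThan(pt, Literal(1)), FakePlanWithStats(v1Stats))).estimate
// filtered.flatMap(_.rowCount) is expected to be Some(0)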
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
index d8f121f6e1339..5656fa60e64ec 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
@@ -24,11 +24,11 @@ import org.apache.paimon.stats.ColStats
 import org.apache.paimon.utils.DateTimeUtils
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.assertj.core.api.Assertions.assertThatThrownBy
 import org.junit.jupiter.api.Assertions
 
-import java.util.concurrent.TimeUnit
-
 abstract class AnalyzeTableTestBase extends PaimonSparkTestBase {
 
   test("Paimon analyze: analyze table only") {
@@ -283,7 +283,32 @@ abstract class AnalyzeTableTestBase extends PaimonSparkTestBase {
     Assertions.assertEquals(1, statsFileCount(tableLocation, fileIO))
   }
 
-  private def statsFileCount(tableLocation: Path, fileIO: FileIO): Int = {
+  test("Paimon analyze: spark use table stats") {
+    spark.sql(s"""
+                 |CREATE TABLE T (id STRING, name STRING, i INT, l LONG)
+                 |USING PAIMON
+                 |TBLPROPERTIES ('primary-key'='id')
+                 |""".stripMargin)
+
+    spark.sql(s"INSERT INTO T VALUES ('1', 'a', 1, 1)")
+    spark.sql(s"INSERT INTO T VALUES ('2', 'aaa', 1, 2)")
+    spark.sql(s"ANALYZE TABLE T COMPUTE STATISTICS")
+
+    val stats = getScanStatistic("SELECT * FROM T")
+    Assertions.assertEquals(2L, stats.rowCount.get.longValue())
+  }
+
+  protected def statsFileCount(tableLocation: Path, fileIO: FileIO): Int = {
     fileIO.listStatus(new Path(tableLocation, "statistics")).length
   }
+
+  protected def getScanStatistic(sql: String): Statistics = {
+    val relation = spark
+      .sql(sql)
+      .queryExecution
+      .optimizedPlan
+      .collectFirst { case relation: DataSourceV2ScanRelation => relation }
+      .get
+    relation.computeStats()
+  }
 }
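Outside the test suite, the effect of the new statistics can be checked interactively with EXPLAIN COST, assuming the partitioned table T and the ANALYZE statements from the tests above:

// EXPLAIN COST prints the optimizer's Statistics for each node of the optimized plan.
spark.sql("ANALYZE TABLE T COMPUTE STATISTICS FOR ALL COLUMNS")
spark.sql("EXPLAIN COST SELECT * FROM T WHERE pt < 1").show(truncate = false)
// The scan node should report rowCount=0 once the reserved partition filter has been folded
// into the statistics, while a query on a non-partition column keeps the full row count.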