From 0c24f1a9e79ac88543f88d392342792bab7a7b49 Mon Sep 17 00:00:00 2001 From: zouxxyy Date: Sat, 14 Dec 2024 15:52:31 +0800 Subject: [PATCH] first version --- .../apache/paimon/table/source/DataSplit.java | 40 +++++++++++ .../table/source/DataTableBatchScan.java | 19 +----- .../table/source/SplitGeneratorTest.java | 4 +- .../apache/paimon/table/source/SplitTest.java | 67 +++++++++++++++++++ .../paimon/spark/PaimonScanBuilder.scala | 12 ++-- .../spark/aggregate/LocalAggregator.scala | 22 +++--- .../paimon/spark/sql/PaimonPushDownTest.scala | 52 ++++++++++---- .../spark/sql/PushDownAggregatesTest.scala | 66 +++++++++++++----- 8 files changed, 216 insertions(+), 66 deletions(-) diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java index 29405466b93f..b9460f28b4e7 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java @@ -44,6 +44,7 @@ import static org.apache.paimon.io.DataFilePathFactory.INDEX_PATH_SUFFIX; import static org.apache.paimon.utils.Preconditions.checkArgument; +import static org.apache.paimon.utils.Preconditions.checkState; /** Input splits. Needed by most batch computation engines. */ public class DataSplit implements Split { @@ -126,6 +127,45 @@ public long rowCount() { return rowCount; } + /** Whether it is possible to calculate the merged row count. */ + public boolean mergedRowCountAvailable() { + return rawConvertible + && (dataDeletionFiles == null + || dataDeletionFiles.stream() + .allMatch(f -> f == null || f.cardinality() != null)); + } + + public long mergedRowCount() { + checkState(mergedRowCountAvailable()); + return partialMergedRowCount(); + } + + /** + * Obtain merged row count as much as possible. There are two scenarios where accurate row count + * can be calculated: + * + *
<p>1. raw file and no deletion file. + * + *
<p>
2. raw file + deletion file with cardinality. + */ + public long partialMergedRowCount() { + long sum = 0L; + if (rawConvertible) { + List rawFiles = convertToRawFiles().orElse(null); + if (rawFiles != null) { + for (int i = 0; i < rawFiles.size(); i++) { + RawFile rawFile = rawFiles.get(i); + if (dataDeletionFiles == null || dataDeletionFiles.get(i) == null) { + sum += rawFile.rowCount(); + } else if (dataDeletionFiles.get(i).cardinality() != null) { + sum += rawFile.rowCount() - dataDeletionFiles.get(i).cardinality(); + } + } + } + } + return sum; + } + @Override public Optional> convertToRawFiles() { if (rawConvertible) { diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java index 635802cc9dcb..a4fe6d73bba1 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java @@ -28,7 +28,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Objects; import static org.apache.paimon.CoreOptions.MergeEngine.FIRST_ROW; @@ -103,9 +102,9 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result) List limitedSplits = new ArrayList<>(); for (DataSplit dataSplit : splits) { if (dataSplit.rawConvertible()) { - long splitRowCount = getRowCountForSplit(dataSplit); + long partialMergedRowCount = dataSplit.partialMergedRowCount(); limitedSplits.add(dataSplit); - scannedRowCount += splitRowCount; + scannedRowCount += partialMergedRowCount; if (scannedRowCount >= pushDownLimit) { SnapshotReader.Plan newPlan = new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits); @@ -117,20 +116,6 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result) return result; } - /** - * 0 represents that we can't compute the row count of this split: 1. the split needs to be - * merged; 2. the table enabled deletion vector and there are some deletion files. 
- */ - private long getRowCountForSplit(DataSplit split) { - if (split.deletionFiles().isPresent() - && split.deletionFiles().get().stream().anyMatch(Objects::nonNull)) { - return 0L; - } - return split.convertToRawFiles() - .map(files -> files.stream().map(RawFile::rowCount).reduce(Long::sum).orElse(0L)) - .orElse(0L); - } - @Override public DataTableScan withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) { snapshotReader.withShard(indexOfThisSubtask, numberOfParallelSubtasks); diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java index a9e093dab124..a1f7d69e2877 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java @@ -43,10 +43,10 @@ public class SplitGeneratorTest { public static DataFileMeta newFileFromSequence( - String name, int rowCount, long minSequence, long maxSequence) { + String name, int fileSize, long minSequence, long maxSequence) { return new DataFileMeta( name, - rowCount, + fileSize, 1, EMPTY_ROW, EMPTY_ROW, diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java index 359d38c973db..0219941a0ac0 100644 --- a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java @@ -49,6 +49,41 @@ /** Test for {@link DataSplit}. */ public class SplitTest { + @Test + public void testSplitMergedRowCount() { + // not rawConvertible + List dataFiles = + Arrays.asList(newDataFile(1000L), newDataFile(2000L), newDataFile(3000L)); + DataSplit split = newDataSplit(false, dataFiles, null); + assertThat(split.partialMergedRowCount()).isEqualTo(0L); + assertThat(split.mergedRowCountAvailable()).isEqualTo(false); + + // rawConvertible without deletion files + split = newDataSplit(true, dataFiles, null); + assertThat(split.partialMergedRowCount()).isEqualTo(6000L); + assertThat(split.mergedRowCountAvailable()).isEqualTo(true); + assertThat(split.mergedRowCount()).isEqualTo(6000L); + + // rawConvertible with deletion files without cardinality + ArrayList deletionFiles = new ArrayList<>(); + deletionFiles.add(null); + deletionFiles.add(new DeletionFile("p", 1, 2, null)); + deletionFiles.add(new DeletionFile("p", 1, 2, 100L)); + split = newDataSplit(true, dataFiles, deletionFiles); + assertThat(split.partialMergedRowCount()).isEqualTo(3900L); + assertThat(split.mergedRowCountAvailable()).isEqualTo(false); + + // rawConvertible with deletion files with cardinality + deletionFiles = new ArrayList<>(); + deletionFiles.add(null); + deletionFiles.add(new DeletionFile("p", 1, 2, 200L)); + deletionFiles.add(new DeletionFile("p", 1, 2, 100L)); + split = newDataSplit(true, dataFiles, deletionFiles); + assertThat(split.partialMergedRowCount()).isEqualTo(5700L); + assertThat(split.mergedRowCountAvailable()).isEqualTo(true); + assertThat(split.mergedRowCount()).isEqualTo(5700L); + } + @Test public void testSerializer() throws IOException { DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build(); @@ -311,4 +346,36 @@ public void testSerializerCompatibleV3() throws Exception { InstantiationUtil.deserializeObject(v2Bytes, DataSplit.class.getClassLoader()); assertThat(actual).isEqualTo(split); } + + private DataFileMeta newDataFile(long rowCount) { 
+ return DataFileMeta.forAppend( + "my_data_file.parquet", + 1024 * 1024, + rowCount, + null, + 0L, + rowCount, + 1, + Collections.emptyList(), + null, + null, + null); + } + + private DataSplit newDataSplit( + boolean rawConvertible, + List dataFiles, + List deletionFiles) { + DataSplit.Builder builder = DataSplit.builder(); + builder.withSnapshot(1) + .withPartition(BinaryRow.EMPTY_ROW) + .withBucket(1) + .withBucketPath("my path") + .rawConvertible(rawConvertible) + .withDataFiles(dataFiles); + if (deletionFiles != null) { + builder.withDataDeletionFiles(deletionFiles); + } + return builder.build(); + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index d8b66e1cd1e0..0393a1cd1578 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -21,6 +21,7 @@ package org.apache.paimon.spark import org.apache.paimon.predicate.PredicateBuilder import org.apache.paimon.spark.aggregate.LocalAggregator import org.apache.paimon.table.Table +import org.apache.paimon.table.source.DataSplit import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit} @@ -36,12 +37,12 @@ class PaimonScanBuilder(table: Table) override def pushLimit(limit: Int): Boolean = { // It is safe, since we will do nothing if it is the primary table and the split is not `rawConvertible` pushDownLimit = Some(limit) - // just make a best effort to push down limit + // just make the best effort to push down limit false } override def supportCompletePushDown(aggregation: Aggregation): Boolean = { - // for now we only support complete push down, so there is no difference with `pushAggregation` + // for now, we only support complete push down, so there is no difference with `pushAggregation` pushAggregation(aggregation) } @@ -66,8 +67,11 @@ class PaimonScanBuilder(table: Table) val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*) readBuilder.withFilter(pushedPartitionPredicate) } - val scan = readBuilder.newScan() - scan.listPartitionEntries.asScala.foreach(aggregator.update) + val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit]) + if (!dataSplits.forall(_.mergedRowCountAvailable())) { + return false + } + dataSplits.foreach(aggregator.update) localScan = Some( PaimonLocalScan( aggregator.result(), diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala index 41e7fd3c3ce9..8988e7218d1f 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala @@ -19,10 +19,10 @@ package org.apache.paimon.spark.aggregate import org.apache.paimon.data.BinaryRow -import org.apache.paimon.manifest.PartitionEntry import org.apache.paimon.spark.SparkTypeUtils import org.apache.paimon.spark.data.SparkInternalRow import org.apache.paimon.table.{DataTable, Table} +import org.apache.paimon.table.source.DataSplit import 
org.apache.paimon.utils.{InternalRowUtils, ProjectedRow} import org.apache.spark.sql.catalyst.InternalRow @@ -78,13 +78,7 @@ class LocalAggregator(table: Table) { } def pushAggregation(aggregation: Aggregation): Boolean = { - if ( - !table.isInstanceOf[DataTable] || - !table.primaryKeys.isEmpty - ) { - return false - } - if (table.asInstanceOf[DataTable].coreOptions.deletionVectorsEnabled) { + if (!table.isInstanceOf[DataTable]) { return false } @@ -108,12 +102,12 @@ class LocalAggregator(table: Table) { SparkInternalRow.create(partitionType).replace(genericRow) } - def update(partitionEntry: PartitionEntry): Unit = { + def update(dataSplit: DataSplit): Unit = { assert(isInitialized) - val groupByRow = requiredGroupByRow(partitionEntry.partition()) + val groupByRow = requiredGroupByRow(dataSplit.partition()) val aggFuncEvaluator = groupByEvaluatorMap.getOrElseUpdate(groupByRow, aggFuncEvaluatorGetter()) - aggFuncEvaluator.foreach(_.update(partitionEntry)) + aggFuncEvaluator.foreach(_.update(dataSplit)) } def result(): Array[InternalRow] = { @@ -147,7 +141,7 @@ class LocalAggregator(table: Table) { } trait AggFuncEvaluator[T] { - def update(partitionEntry: PartitionEntry): Unit + def update(dataSplit: DataSplit): Unit def result(): T def resultType: DataType def prettyName: String @@ -156,8 +150,8 @@ trait AggFuncEvaluator[T] { class CountStarEvaluator extends AggFuncEvaluator[Long] { private var _result: Long = 0L - override def update(partitionEntry: PartitionEntry): Unit = { - _result += partitionEntry.recordCount() + override def update(dataSplit: DataSplit): Unit = { + _result += dataSplit.mergedRowCount() } override def result(): Long = _result diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala index ba314e3afa81..503f1c8e3e9d 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -18,7 +18,7 @@ package org.apache.paimon.spark.sql -import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonScan, PaimonSparkTestBase, SparkTable} +import org.apache.paimon.spark.{PaimonScan, PaimonSparkTestBase, SparkTable} import org.apache.paimon.table.source.DataSplit import org.apache.spark.sql.Row @@ -29,8 +29,6 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.junit.jupiter.api.Assertions -import scala.collection.JavaConverters._ - class PaimonPushDownTest extends PaimonSparkTestBase { import testImplicits._ @@ -64,7 +62,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil) // case 2 - // filter "id = '1' or pt = 'p1'" can't push down completely, it still need to be evaluated after scanning + // filter "id = '1' or pt = 'p1'" can't push down completely, it still needs to be evaluated after scanning q = "SELECT * FROM T WHERE id = '1' or pt = 'p1'" Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1"))) checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil) @@ -121,7 +119,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { val dataSplitsWithoutLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits Assertions.assertTrue(dataSplitsWithoutLimit.length >= 2) - // 
It still return false even it can push down limit. + // It still returns false even it can push down limit. Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1)) val dataSplitsWithLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits Assertions.assertEquals(1, dataSplitsWithLimit.length) @@ -169,12 +167,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { // Now, we have 4 dataSplits, and 2 dataSplit is nonRawConvertible, 2 dataSplit is rawConvertible. Assertions.assertEquals( 2, - dataSplitsWithoutLimit2 - .filter( - split => { - split.asInstanceOf[DataSplit].rawConvertible() - }) - .length) + dataSplitsWithoutLimit2.count(split => { split.asInstanceOf[DataSplit].rawConvertible() })) // Return 2 dataSplits. Assertions.assertFalse(scanBuilder2.asInstanceOf[SupportsPushDownLimit].pushLimit(2)) @@ -206,7 +199,40 @@ class PaimonPushDownTest extends PaimonSparkTestBase { // Need to scan all dataSplits. Assertions.assertEquals(4, dataSplitsWithLimit3.length) Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count()) + } + test("Paimon pushDown: limit for table with deletion vector") { + Seq(true, false).foreach( + deletionVectorsEnabled => { + Seq(true, false).foreach( + primaryKeyTable => { + withTable("T") { + sql(s""" + |CREATE TABLE T (id INT) + |TBLPROPERTIES ( + | 'deletion-vectors.enabled' = $deletionVectorsEnabled, + | '${if (primaryKeyTable) "primary-key" else "bucket-key"}' = 'id', + | 'bucket' = '10' + |) + |""".stripMargin) + + sql("INSERT INTO T SELECT id FROM range (1, 50000)") + sql("DELETE FROM T WHERE id % 13 = 0") + + val withoutLimit = getScanBuilder().build().asInstanceOf[PaimonScan].getOriginSplits + assert(withoutLimit.length == 10) + + val scanBuilder = getScanBuilder().asInstanceOf[SupportsPushDownLimit] + scanBuilder.pushLimit(1) + val withLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits + if (deletionVectorsEnabled || !primaryKeyTable) { + assert(withLimit.length == 1) + } else { + assert(withLimit.length == 10) + } + } + }) + }) } test("Paimon pushDown: runtime filter") { @@ -250,8 +276,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } private def getScanBuilder(tableName: String = "T"): ScanBuilder = { - new SparkTable(loadTable(tableName)) - .newScanBuilder(CaseInsensitiveStringMap.empty()) + SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty()) } private def checkFilterExists(sql: String): Boolean = { @@ -272,5 +297,4 @@ class PaimonPushDownTest extends PaimonSparkTestBase { case _ => false } } - } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala index 501e7bfb4a51..78c02644a7ce 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala @@ -117,22 +117,58 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH } } - test("Push down aggregate - primary table") { - withTable("T") { - spark.sql("CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES ('primary-key' = 'c1')") - runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 2) - spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')") - runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 2) - } + test("Push 
down aggregate - primary key table with deletion vector") { + Seq(true, false).foreach( + deletionVectorsEnabled => { + withTable("T") { + spark.sql(s""" + |CREATE TABLE T (c1 INT, c2 STRING) + |TBLPROPERTIES ( + |'primary-key' = 'c1', + |'deletion-vectors.enabled' = $deletionVectorsEnabled + |) + |""".stripMargin) + runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 0) + + spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')") + runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 0) + + spark.sql("INSERT INTO T VALUES(1, 'x_1')") + if (deletionVectorsEnabled) { + runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 0) + } else { + runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 2) + } + } + }) } - test("Push down aggregate - enable deletion vector") { - withTable("T") { - spark.sql( - "CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES('deletion-vectors.enabled' = 'true')") - runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 2) - spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')") - runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(4) :: Nil, 2) - } + test("Push down aggregate - table with deletion vector") { + Seq(true, false).foreach( + deletionVectorsEnabled => { + Seq(true, false).foreach( + primaryKeyTable => { + withTable("T") { + sql(s""" + |CREATE TABLE T (id INT) + |TBLPROPERTIES ( + | 'deletion-vectors.enabled' = $deletionVectorsEnabled, + | '${if (primaryKeyTable) "primary-key" else "bucket-key"}' = 'id', + | 'bucket' = '1' + |) + |""".stripMargin) + + sql("INSERT INTO T SELECT id FROM range (0, 5000)") + runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(5000)), 0) + + sql("DELETE FROM T WHERE id > 100 and id <= 400") + if (deletionVectorsEnabled || !primaryKeyTable) { + runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(4700)), 0) + } else { + runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(4700)), 2) + } + } + }) + }) } }