[core][spark] Enable limit pushdown and count optimization for dv table (#4709)
Zouxxyy authored Dec 15, 2024
1 parent 72e7150 commit 9179d65
Showing 8 changed files with 216 additions and 66 deletions.
@@ -44,6 +44,7 @@

import static org.apache.paimon.io.DataFilePathFactory.INDEX_PATH_SUFFIX;
import static org.apache.paimon.utils.Preconditions.checkArgument;
import static org.apache.paimon.utils.Preconditions.checkState;

/** Input splits. Needed by most batch computation engines. */
public class DataSplit implements Split {
@@ -126,6 +127,45 @@ public long rowCount() {
return rowCount;
}

/** Whether it is possible to calculate the merged row count. */
public boolean mergedRowCountAvailable() {
return rawConvertible
&& (dataDeletionFiles == null
|| dataDeletionFiles.stream()
.allMatch(f -> f == null || f.cardinality() != null));
}

public long mergedRowCount() {
checkState(mergedRowCountAvailable());
return partialMergedRowCount();
}

/**
* Obtain the merged row count as accurately as possible. There are two scenarios in which an
* accurate row count can be calculated:
*
* <p>1. raw file and no deletion file.
*
* <p>2. raw file + deletion file with cardinality.
*/
public long partialMergedRowCount() {
long sum = 0L;
if (rawConvertible) {
List<RawFile> rawFiles = convertToRawFiles().orElse(null);
if (rawFiles != null) {
for (int i = 0; i < rawFiles.size(); i++) {
RawFile rawFile = rawFiles.get(i);
if (dataDeletionFiles == null || dataDeletionFiles.get(i) == null) {
sum += rawFile.rowCount();
} else if (dataDeletionFiles.get(i).cardinality() != null) {
sum += rawFile.rowCount() - dataDeletionFiles.get(i).cardinality();
}
}
}
}
return sum;
}
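
To make the two scenarios above concrete, here is a minimal, self-contained sketch (an editorial illustration, not part of the commit) of the partialMergedRowCount() arithmetic, reusing the raw-file row counts and deletion-file cardinalities from the SplitTest case later in this diff:

public class MergedRowCountSketch {
    public static void main(String[] args) {
        // Raw file row counts and deletion-file cardinalities taken from the SplitTest
        // case below; null means "no deletion file for this raw file".
        long[] rawRowCounts = {1000L, 2000L, 3000L};
        Long[] deletionCardinalities = {null, 200L, 100L};

        long merged = 0L;
        for (int i = 0; i < rawRowCounts.length; i++) {
            if (deletionCardinalities[i] == null) {
                merged += rawRowCounts[i]; // scenario 1: raw file, no deletion file
            } else {
                // scenario 2: raw file + deletion file with a known cardinality
                merged += rawRowCounts[i] - deletionCardinalities[i];
            }
        }
        System.out.println(merged); // 1000 + 1800 + 2900 = 5700
    }
}

If a deletion file is present but carries no cardinality, that file is simply skipped in the partial sum, which is why callers check mergedRowCountAvailable() before treating the result as an exact merged row count.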

@Override
public Optional<List<RawFile>> convertToRawFiles() {
if (rawConvertible) {
@@ -28,7 +28,6 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static org.apache.paimon.CoreOptions.MergeEngine.FIRST_ROW;

@@ -103,9 +102,9 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result)
List<Split> limitedSplits = new ArrayList<>();
for (DataSplit dataSplit : splits) {
if (dataSplit.rawConvertible()) {
long splitRowCount = getRowCountForSplit(dataSplit);
long partialMergedRowCount = dataSplit.partialMergedRowCount();
limitedSplits.add(dataSplit);
scannedRowCount += splitRowCount;
scannedRowCount += partialMergedRowCount;
if (scannedRowCount >= pushDownLimit) {
SnapshotReader.Plan newPlan =
new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits);
@@ -117,20 +116,6 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result)
return result;
}

/**
* 0 represents that we can't compute the row count of this split: 1. the split needs to be
* merged; 2. the table enabled deletion vector and there are some deletion files.
*/
private long getRowCountForSplit(DataSplit split) {
if (split.deletionFiles().isPresent()
&& split.deletionFiles().get().stream().anyMatch(Objects::nonNull)) {
return 0L;
}
return split.convertToRawFiles()
.map(files -> files.stream().map(RawFile::rowCount).reduce(Long::sum).orElse(0L))
.orElse(0L);
}

@Override
public DataTableScan withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) {
snapshotReader.withShard(indexOfThisSubtask, numberOfParallelSubtasks);
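
As a side note, the truncation logic above can be summarized with a small, self-contained Java sketch (an editorial illustration with hypothetical class names, not the Paimon API): rawConvertible splits are accumulated until their summed partial merged row counts cover the pushed-down limit, after which the remaining splits can be dropped from the plan; if the limit is never reached, the original plan is kept.

import java.util.ArrayList;
import java.util.List;

public class LimitPushDownSketch {

    // Hypothetical stand-in for DataSplit, carrying only the two facts the loop above consults.
    record SplitSketch(boolean rawConvertible, long partialMergedRowCount) {}

    // Keep rawConvertible splits until their summed partial merged row counts cover the limit.
    static List<SplitSketch> applyLimit(List<SplitSketch> splits, long pushDownLimit) {
        List<SplitSketch> limited = new ArrayList<>();
        long scannedRowCount = 0L;
        for (SplitSketch split : splits) {
            if (split.rawConvertible()) {
                limited.add(split);
                scannedRowCount += split.partialMergedRowCount();
                if (scannedRowCount >= pushDownLimit) {
                    return limited; // at least `pushDownLimit` rows are guaranteed; truncate here
                }
            }
        }
        return splits; // limit not reached from rawConvertible splits alone; keep the full plan
    }

    public static void main(String[] args) {
        List<SplitSketch> splits =
                List.of(
                        new SplitSketch(true, 3000L),
                        new SplitSketch(false, 0L),
                        new SplitSketch(true, 5000L));
        // With LIMIT 4000, the running total reaches 8000 >= 4000 at the second
        // rawConvertible split, so only two of the three splits are kept.
        System.out.println(applyLimit(splits, 4000L).size()); // prints 2
    }
}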
@@ -43,10 +43,10 @@
public class SplitGeneratorTest {

public static DataFileMeta newFileFromSequence(
String name, int rowCount, long minSequence, long maxSequence) {
String name, int fileSize, long minSequence, long maxSequence) {
return new DataFileMeta(
name,
rowCount,
fileSize,
1,
EMPTY_ROW,
EMPTY_ROW,
@@ -49,6 +49,41 @@
/** Test for {@link DataSplit}. */
public class SplitTest {

@Test
public void testSplitMergedRowCount() {
// not rawConvertible
List<DataFileMeta> dataFiles =
Arrays.asList(newDataFile(1000L), newDataFile(2000L), newDataFile(3000L));
DataSplit split = newDataSplit(false, dataFiles, null);
assertThat(split.partialMergedRowCount()).isEqualTo(0L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(false);

// rawConvertible without deletion files
split = newDataSplit(true, dataFiles, null);
assertThat(split.partialMergedRowCount()).isEqualTo(6000L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(true);
assertThat(split.mergedRowCount()).isEqualTo(6000L);

// rawConvertible with deletion files, one of which has no cardinality
ArrayList<DeletionFile> deletionFiles = new ArrayList<>();
deletionFiles.add(null);
deletionFiles.add(new DeletionFile("p", 1, 2, null));
deletionFiles.add(new DeletionFile("p", 1, 2, 100L));
split = newDataSplit(true, dataFiles, deletionFiles);
assertThat(split.partialMergedRowCount()).isEqualTo(3900L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(false);

// rawConvertible with deletion files, all of which have a cardinality
deletionFiles = new ArrayList<>();
deletionFiles.add(null);
deletionFiles.add(new DeletionFile("p", 1, 2, 200L));
deletionFiles.add(new DeletionFile("p", 1, 2, 100L));
split = newDataSplit(true, dataFiles, deletionFiles);
assertThat(split.partialMergedRowCount()).isEqualTo(5700L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(true);
assertThat(split.mergedRowCount()).isEqualTo(5700L);
}

@Test
public void testSerializer() throws IOException {
DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build();
@@ -311,4 +346,36 @@ public void testSerializerCompatibleV3() throws Exception {
InstantiationUtil.deserializeObject(v2Bytes, DataSplit.class.getClassLoader());
assertThat(actual).isEqualTo(split);
}

private DataFileMeta newDataFile(long rowCount) {
return DataFileMeta.forAppend(
"my_data_file.parquet",
1024 * 1024,
rowCount,
null,
0L,
rowCount,
1,
Collections.emptyList(),
null,
null,
null);
}

private DataSplit newDataSplit(
boolean rawConvertible,
List<DataFileMeta> dataFiles,
List<DeletionFile> deletionFiles) {
DataSplit.Builder builder = DataSplit.builder();
builder.withSnapshot(1)
.withPartition(BinaryRow.EMPTY_ROW)
.withBucket(1)
.withBucketPath("my path")
.rawConvertible(rawConvertible)
.withDataFiles(dataFiles);
if (deletionFiles != null) {
builder.withDataDeletionFiles(deletionFiles);
}
return builder.build();
}
}
@@ -21,6 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.predicate.PredicateBuilder
import org.apache.paimon.spark.aggregate.LocalAggregator
import org.apache.paimon.table.Table
import org.apache.paimon.table.source.DataSplit

import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit}
@@ -36,12 +37,12 @@ class PaimonScanBuilder(table: Table)
override def pushLimit(limit: Int): Boolean = {
// It is safe, since we will do nothing if it is a primary key table and the split is not `rawConvertible`
pushDownLimit = Some(limit)
// just make a best effort to push down limit
// just make the best effort to push down limit
false
}

override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
// for now we only support complete push down, so there is no difference with `pushAggregation`
// for now, we only support complete push down, so there is no difference with `pushAggregation`
pushAggregation(aggregation)
}

@@ -66,8 +67,11 @@
val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*)
readBuilder.withFilter(pushedPartitionPredicate)
}
val scan = readBuilder.newScan()
scan.listPartitionEntries.asScala.foreach(aggregator.update)
val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
if (!dataSplits.forall(_.mergedRowCountAvailable())) {
return false
}
dataSplits.foreach(aggregator.update)
localScan = Some(
PaimonLocalScan(
aggregator.result(),
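
For context, the guard above can be illustrated with a short, self-contained Java sketch (hypothetical names, not the Paimon or Spark API): COUNT(*) is answered from split metadata only when every planned split can report an exact merged row count, for example because each deletion vector carries a cardinality; otherwise the push-down is rejected and the query falls back to a normal scan.

import java.util.List;
import java.util.OptionalLong;

public class CountPushDownGuardSketch {

    // Hypothetical split summary: an exact merged row count, if one can be computed.
    record SplitSummary(OptionalLong mergedRowCount) {}

    // COUNT(*) can be answered locally only if every split knows its exact merged row count.
    static OptionalLong tryLocalCountStar(List<SplitSummary> splits) {
        long total = 0L;
        for (SplitSummary split : splits) {
            if (split.mergedRowCount().isEmpty()) {
                return OptionalLong.empty(); // reject the push-down; fall back to a normal scan
            }
            total += split.mergedRowCount().getAsLong();
        }
        return OptionalLong.of(total);
    }

    public static void main(String[] args) {
        // Every split knows its merged count: COUNT(*) = 5700 without reading any data files.
        System.out.println(
                tryLocalCountStar(
                        List.of(
                                new SplitSummary(OptionalLong.of(1000L)),
                                new SplitSummary(OptionalLong.of(4700L)))));
        // One split cannot be counted exactly (e.g. a deletion file without a cardinality),
        // so the aggregation is not pushed down.
        System.out.println(
                tryLocalCountStar(
                        List.of(
                                new SplitSummary(OptionalLong.of(1000L)),
                                new SplitSummary(OptionalLong.empty()))));
    }
}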
@@ -19,10 +19,10 @@
package org.apache.paimon.spark.aggregate

import org.apache.paimon.data.BinaryRow
import org.apache.paimon.manifest.PartitionEntry
import org.apache.paimon.spark.SparkTypeUtils
import org.apache.paimon.spark.data.SparkInternalRow
import org.apache.paimon.table.{DataTable, Table}
import org.apache.paimon.table.source.DataSplit
import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}

import org.apache.spark.sql.catalyst.InternalRow
@@ -78,13 +78,7 @@ class LocalAggregator(table: Table) {
}

def pushAggregation(aggregation: Aggregation): Boolean = {
if (
!table.isInstanceOf[DataTable] ||
!table.primaryKeys.isEmpty
) {
return false
}
if (table.asInstanceOf[DataTable].coreOptions.deletionVectorsEnabled) {
if (!table.isInstanceOf[DataTable]) {
return false
}

@@ -108,12 +102,12 @@
SparkInternalRow.create(partitionType).replace(genericRow)
}

def update(partitionEntry: PartitionEntry): Unit = {
def update(dataSplit: DataSplit): Unit = {
assert(isInitialized)
val groupByRow = requiredGroupByRow(partitionEntry.partition())
val groupByRow = requiredGroupByRow(dataSplit.partition())
val aggFuncEvaluator =
groupByEvaluatorMap.getOrElseUpdate(groupByRow, aggFuncEvaluatorGetter())
aggFuncEvaluator.foreach(_.update(partitionEntry))
aggFuncEvaluator.foreach(_.update(dataSplit))
}

def result(): Array[InternalRow] = {
@@ -147,7 +141,7 @@
}

trait AggFuncEvaluator[T] {
def update(partitionEntry: PartitionEntry): Unit
def update(dataSplit: DataSplit): Unit
def result(): T
def resultType: DataType
def prettyName: String
@@ -156,8 +150,8 @@ trait AggFuncEvaluator[T] {
class CountStarEvaluator extends AggFuncEvaluator[Long] {
private var _result: Long = 0L

override def update(partitionEntry: PartitionEntry): Unit = {
_result += partitionEntry.recordCount()
override def update(dataSplit: DataSplit): Unit = {
_result += dataSplit.mergedRowCount()
}

override def result(): Long = _result
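
To show how the partition grouping above works, here is a compact, self-contained sketch in plain Java (hypothetical names; the actual aggregator above is Scala and operates on DataSplit and InternalRow): a grouped COUNT(*) is served from split metadata by summing each split's merged row count into a per-partition bucket.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GroupedCountSketch {

    // Hypothetical split summary: the partition it belongs to and its exact merged row count.
    record SplitSummary(String partition, long mergedRowCount) {}

    // COUNT(*) ... GROUP BY partition, computed purely from split metadata.
    static Map<String, Long> countStarByPartition(List<SplitSummary> splits) {
        Map<String, Long> counts = new HashMap<>();
        for (SplitSummary split : splits) {
            counts.merge(split.partition(), split.mergedRowCount(), Long::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        List<SplitSummary> splits =
                List.of(
                        new SplitSummary("p1", 1000L),
                        new SplitSummary("p1", 2000L),
                        new SplitSummary("p2", 2700L));
        System.out.println(countStarByPartition(splits)); // {p1=3000, p2=2700}
    }
}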
@@ -18,7 +18,7 @@

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonScan, PaimonSparkTestBase, SparkTable}
import org.apache.paimon.spark.{PaimonScan, PaimonSparkTestBase, SparkTable}
import org.apache.paimon.table.source.DataSplit

import org.apache.spark.sql.Row
@@ -29,8 +29,6 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.junit.jupiter.api.Assertions

import scala.collection.JavaConverters._

class PaimonPushDownTest extends PaimonSparkTestBase {

import testImplicits._
@@ -64,7 +62,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)

// case 2
// filter "id = '1' or pt = 'p1'" can't push down completely, it still need to be evaluated after scanning
// filter "id = '1' or pt = 'p1'" can't push down completely, it still needs to be evaluated after scanning
q = "SELECT * FROM T WHERE id = '1' or pt = 'p1'"
Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
@@ -121,7 +119,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
val dataSplitsWithoutLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
Assertions.assertTrue(dataSplitsWithoutLimit.length >= 2)

// It still return false even it can push down limit.
// It still returns false even if it can push down limit.
Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1))
val dataSplitsWithLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
Assertions.assertEquals(1, dataSplitsWithLimit.length)
@@ -169,12 +167,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
// Now we have 4 dataSplits: 2 are nonRawConvertible and 2 are rawConvertible.
Assertions.assertEquals(
2,
dataSplitsWithoutLimit2
.filter(
split => {
split.asInstanceOf[DataSplit].rawConvertible()
})
.length)
dataSplitsWithoutLimit2.count(split => { split.asInstanceOf[DataSplit].rawConvertible() }))

// Return 2 dataSplits.
Assertions.assertFalse(scanBuilder2.asInstanceOf[SupportsPushDownLimit].pushLimit(2))
@@ -206,7 +199,40 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
// Need to scan all dataSplits.
Assertions.assertEquals(4, dataSplitsWithLimit3.length)
Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count())
}

test("Paimon pushDown: limit for table with deletion vector") {
Seq(true, false).foreach(
deletionVectorsEnabled => {
Seq(true, false).foreach(
primaryKeyTable => {
withTable("T") {
sql(s"""
|CREATE TABLE T (id INT)
|TBLPROPERTIES (
| 'deletion-vectors.enabled' = $deletionVectorsEnabled,
| '${if (primaryKeyTable) "primary-key" else "bucket-key"}' = 'id',
| 'bucket' = '10'
|)
|""".stripMargin)

sql("INSERT INTO T SELECT id FROM range (1, 50000)")
sql("DELETE FROM T WHERE id % 13 = 0")

val withoutLimit = getScanBuilder().build().asInstanceOf[PaimonScan].getOriginSplits
assert(withoutLimit.length == 10)

val scanBuilder = getScanBuilder().asInstanceOf[SupportsPushDownLimit]
scanBuilder.pushLimit(1)
val withLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
if (deletionVectorsEnabled || !primaryKeyTable) {
assert(withLimit.length == 1)
} else {
assert(withLimit.length == 10)
}
}
})
})
}

test("Paimon pushDown: runtime filter") {
@@ -250,8 +276,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
}

private def getScanBuilder(tableName: String = "T"): ScanBuilder = {
new SparkTable(loadTable(tableName))
.newScanBuilder(CaseInsensitiveStringMap.empty())
SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty())
}

private def checkFilterExists(sql: String): Boolean = {
@@ -272,5 +297,4 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
case _ => false
}
}

}