[core][spark] Enable limit pushdown and count optimization for dv table #4709

Merged · 1 commit · Dec 15, 2024
@@ -44,6 +44,7 @@

import static org.apache.paimon.io.DataFilePathFactory.INDEX_PATH_SUFFIX;
import static org.apache.paimon.utils.Preconditions.checkArgument;
import static org.apache.paimon.utils.Preconditions.checkState;

/** Input splits. Needed by most batch computation engines. */
public class DataSplit implements Split {
@@ -126,6 +127,45 @@ public long rowCount() {
return rowCount;
}

/** Whether it is possible to calculate the merged row count. */
public boolean mergedRowCountAvailable() {
return rawConvertible
&& (dataDeletionFiles == null
|| dataDeletionFiles.stream()
.allMatch(f -> f == null || f.cardinality() != null));
}

public long mergedRowCount() {
checkState(mergedRowCountAvailable());
return partialMergedRowCount();
}

/**
* Obtain the merged row count where possible. There are two scenarios in which an accurate row
* count can be calculated:
*
* <p>1. raw file and no deletion file.
*
* <p>2. raw file + deletion file with cardinality.
*/
public long partialMergedRowCount() {
long sum = 0L;
if (rawConvertible) {
List<RawFile> rawFiles = convertToRawFiles().orElse(null);
if (rawFiles != null) {
for (int i = 0; i < rawFiles.size(); i++) {
RawFile rawFile = rawFiles.get(i);
if (dataDeletionFiles == null || dataDeletionFiles.get(i) == null) {
sum += rawFile.rowCount();
} else if (dataDeletionFiles.get(i).cardinality() != null) {
sum += rawFile.rowCount() - dataDeletionFiles.get(i).cardinality();
}
}
}
}
return sum;
}

@Override
public Optional<List<RawFile>> convertToRawFiles() {
if (rawConvertible) {
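
The rule introduced above can be read off the new methods: a split's merged row count is exact only when it is rawConvertible and every deletion file reports a cardinality; otherwise partialMergedRowCount() sums whatever can be determined and skips the rest. Below is a minimal standalone Java sketch of that rule, using the same numbers as the SplitTest case later in this diff; MergedRowCountSketch and Deletion are hypothetical stand-ins, not Paimon classes.

import java.util.Arrays;
import java.util.List;

/** Standalone sketch of the merged-row-count rule added to DataSplit; not the Paimon API. */
public class MergedRowCountSketch {

    /** Hypothetical stand-in for a deletion file whose cardinality (deleted rows) may be unknown. */
    static class Deletion {
        final Long cardinality;

        Deletion(Long cardinality) {
            this.cardinality = cardinality;
        }
    }

    /** An exact count is available only if rawConvertible and every deletion file has a cardinality. */
    static boolean mergedRowCountAvailable(boolean rawConvertible, List<Deletion> deletions) {
        return rawConvertible
                && (deletions == null
                        || deletions.stream().allMatch(d -> d == null || d.cardinality != null));
    }

    /** Sum the row counts that can be determined; files with unknown deletions contribute nothing. */
    static long partialMergedRowCount(
            boolean rawConvertible, long[] rawRowCounts, List<Deletion> deletions) {
        long sum = 0L;
        if (!rawConvertible) {
            return sum;
        }
        for (int i = 0; i < rawRowCounts.length; i++) {
            Deletion d = deletions == null ? null : deletions.get(i);
            if (d == null) {
                sum += rawRowCounts[i]; // raw file with no deletion file
            } else if (d.cardinality != null) {
                sum += rawRowCounts[i] - d.cardinality; // deletion file with known cardinality
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        long[] rawRowCounts = {1000L, 2000L, 3000L};
        List<Deletion> deletions = Arrays.asList(null, new Deletion(200L), new Deletion(100L));
        // 1000 + (2000 - 200) + (3000 - 100) = 5700, and the count is exact.
        System.out.println(partialMergedRowCount(true, rawRowCounts, deletions)); // 5700
        System.out.println(mergedRowCountAvailable(true, deletions)); // true
    }
}
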
@@ -28,7 +28,6 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static org.apache.paimon.CoreOptions.MergeEngine.FIRST_ROW;

@@ -103,9 +102,9 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result)
List<Split> limitedSplits = new ArrayList<>();
for (DataSplit dataSplit : splits) {
if (dataSplit.rawConvertible()) {
long splitRowCount = getRowCountForSplit(dataSplit);
long partialMergedRowCount = dataSplit.partialMergedRowCount();
limitedSplits.add(dataSplit);
scannedRowCount += splitRowCount;
scannedRowCount += partialMergedRowCount;
if (scannedRowCount >= pushDownLimit) {
SnapshotReader.Plan newPlan =
new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits);
@@ -117,20 +116,6 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result)
return result;
}

/**
* 0 represents that we can't compute the row count of this split: 1. the split needs to be
* merged; 2. the table enabled deletion vector and there are some deletion files.
*/
private long getRowCountForSplit(DataSplit split) {
if (split.deletionFiles().isPresent()
&& split.deletionFiles().get().stream().anyMatch(Objects::nonNull)) {
return 0L;
}
return split.convertToRawFiles()
.map(files -> files.stream().map(RawFile::rowCount).reduce(Long::sum).orElse(0L))
.orElse(0L);
}

@Override
public DataTableScan withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) {
snapshotReader.withShard(indexOfThisSubtask, numberOfParallelSubtasks);
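
For context, the applyPushDownLimit change above swaps the removed getRowCountForSplit helper for DataSplit.partialMergedRowCount(), so deletion-vector tables with known cardinalities can also satisfy the limit early. A simplified standalone sketch of that truncation loop follows; SplitCount is a hypothetical stand-in for DataSplit, and the real method also rebuilds the snapshot plan.

import java.util.ArrayList;
import java.util.List;

/** Sketch of the limit-pushdown truncation; SplitCount is a hypothetical stand-in for DataSplit. */
public class LimitPushDownSketch {

    static class SplitCount {
        final boolean rawConvertible;
        final long partialMergedRowCount; // 0 when no row count can be determined

        SplitCount(boolean rawConvertible, long partialMergedRowCount) {
            this.rawConvertible = rawConvertible;
            this.partialMergedRowCount = partialMergedRowCount;
        }
    }

    /** Keep splits until the rows known to survive merging reach the limit; otherwise keep them all. */
    static List<SplitCount> applyPushDownLimit(List<SplitCount> splits, long pushDownLimit) {
        long scannedRowCount = 0L;
        List<SplitCount> limitedSplits = new ArrayList<>();
        for (SplitCount split : splits) {
            if (split.rawConvertible) {
                limitedSplits.add(split);
                scannedRowCount += split.partialMergedRowCount;
                if (scannedRowCount >= pushDownLimit) {
                    return limitedSplits; // enough guaranteed rows: drop the remaining splits
                }
            }
        }
        return splits; // limit not provably reached: scan the original plan
    }

    public static void main(String[] args) {
        List<SplitCount> splits = new ArrayList<>();
        splits.add(new SplitCount(true, 5700L)); // dv split with known cardinalities
        splits.add(new SplitCount(false, 0L)); // split that must be merged
        splits.add(new SplitCount(true, 6000L));
        System.out.println(applyPushDownLimit(splits, 100L).size()); // 1
    }
}
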
@@ -43,10 +43,10 @@
public class SplitGeneratorTest {

public static DataFileMeta newFileFromSequence(
String name, int rowCount, long minSequence, long maxSequence) {
String name, int fileSize, long minSequence, long maxSequence) {
return new DataFileMeta(
name,
rowCount,
fileSize,
1,
EMPTY_ROW,
EMPTY_ROW,
@@ -49,6 +49,41 @@
/** Test for {@link DataSplit}. */
public class SplitTest {

@Test
public void testSplitMergedRowCount() {
// not rawConvertible
List<DataFileMeta> dataFiles =
Arrays.asList(newDataFile(1000L), newDataFile(2000L), newDataFile(3000L));
DataSplit split = newDataSplit(false, dataFiles, null);
assertThat(split.partialMergedRowCount()).isEqualTo(0L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(false);

// rawConvertible without deletion files
split = newDataSplit(true, dataFiles, null);
assertThat(split.partialMergedRowCount()).isEqualTo(6000L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(true);
assertThat(split.mergedRowCount()).isEqualTo(6000L);

// rawConvertible with deletion files without cardinality
ArrayList<DeletionFile> deletionFiles = new ArrayList<>();
deletionFiles.add(null);
deletionFiles.add(new DeletionFile("p", 1, 2, null));
deletionFiles.add(new DeletionFile("p", 1, 2, 100L));
split = newDataSplit(true, dataFiles, deletionFiles);
assertThat(split.partialMergedRowCount()).isEqualTo(3900L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(false);

// rawConvertible with deletion files with cardinality
deletionFiles = new ArrayList<>();
deletionFiles.add(null);
deletionFiles.add(new DeletionFile("p", 1, 2, 200L));
deletionFiles.add(new DeletionFile("p", 1, 2, 100L));
split = newDataSplit(true, dataFiles, deletionFiles);
assertThat(split.partialMergedRowCount()).isEqualTo(5700L);
assertThat(split.mergedRowCountAvailable()).isEqualTo(true);
assertThat(split.mergedRowCount()).isEqualTo(5700L);
}

@Test
public void testSerializer() throws IOException {
DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build();
@@ -311,4 +346,36 @@ public void testSerializerCompatibleV3() throws Exception {
InstantiationUtil.deserializeObject(v2Bytes, DataSplit.class.getClassLoader());
assertThat(actual).isEqualTo(split);
}

private DataFileMeta newDataFile(long rowCount) {
return DataFileMeta.forAppend(
"my_data_file.parquet",
1024 * 1024,
rowCount,
null,
0L,
rowCount,
1,
Collections.emptyList(),
null,
null,
null);
}

private DataSplit newDataSplit(
boolean rawConvertible,
List<DataFileMeta> dataFiles,
List<DeletionFile> deletionFiles) {
DataSplit.Builder builder = DataSplit.builder();
builder.withSnapshot(1)
.withPartition(BinaryRow.EMPTY_ROW)
.withBucket(1)
.withBucketPath("my path")
.rawConvertible(rawConvertible)
.withDataFiles(dataFiles);
if (deletionFiles != null) {
builder.withDataDeletionFiles(deletionFiles);
}
return builder.build();
}
}
@@ -21,6 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.predicate.PredicateBuilder
import org.apache.paimon.spark.aggregate.LocalAggregator
import org.apache.paimon.table.Table
import org.apache.paimon.table.source.DataSplit

import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit}
@@ -36,12 +37,12 @@ class PaimonScanBuilder(table: Table)
override def pushLimit(limit: Int): Boolean = {
// It is safe, since we will do nothing if it is the primary table and the split is not `rawConvertible`
pushDownLimit = Some(limit)
// just make a best effort to push down limit
// just make the best effort to push down limit
false
}

override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
// for now we only support complete push down, so there is no difference with `pushAggregation`
// for now, we only support complete push down, so there is no difference with `pushAggregation`
pushAggregation(aggregation)
}

@@ -66,8 +67,11 @@
val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*)
readBuilder.withFilter(pushedPartitionPredicate)
}
val scan = readBuilder.newScan()
scan.listPartitionEntries.asScala.foreach(aggregator.update)
val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
if (!dataSplits.forall(_.mergedRowCountAvailable())) {
return false
}
dataSplits.foreach(aggregator.update)
localScan = Some(
PaimonLocalScan(
aggregator.result(),
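
The Scala change above moves count(*) pushdown from partition statistics to split-level counts: the scan is planned, and the aggregate is answered locally only if every DataSplit reports mergedRowCountAvailable(). A rough standalone Java sketch of that all-or-nothing guard follows; CountableSplit and tryPushDownCountStar are hypothetical names, not the Paimon or Spark API.

import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** Sketch of the scan-builder guard: push COUNT(*) down only when every split has an exact count. */
public class CountPushDownGuardSketch {

    /** Hypothetical view of the two DataSplit methods the guard relies on. */
    interface CountableSplit {
        boolean mergedRowCountAvailable();

        long mergedRowCount();
    }

    /** Returns the exact table-level count, or empty to fall back to a normal Spark aggregation. */
    static Optional<Long> tryPushDownCountStar(List<CountableSplit> splits) {
        if (!splits.stream().allMatch(CountableSplit::mergedRowCountAvailable)) {
            return Optional.empty(); // some split needs merging or lacks a deletion cardinality
        }
        return Optional.of(splits.stream().mapToLong(CountableSplit::mergedRowCount).sum());
    }

    public static void main(String[] args) {
        CountableSplit exact =
                new CountableSplit() {
                    public boolean mergedRowCountAvailable() {
                        return true;
                    }

                    public long mergedRowCount() {
                        return 5700L;
                    }
                };
        System.out.println(tryPushDownCountStar(Arrays.asList(exact, exact))); // Optional[11400]
    }
}
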
@@ -19,10 +19,10 @@
package org.apache.paimon.spark.aggregate

import org.apache.paimon.data.BinaryRow
import org.apache.paimon.manifest.PartitionEntry
import org.apache.paimon.spark.SparkTypeUtils
import org.apache.paimon.spark.data.SparkInternalRow
import org.apache.paimon.table.{DataTable, Table}
import org.apache.paimon.table.source.DataSplit
import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}

import org.apache.spark.sql.catalyst.InternalRow
@@ -78,13 +78,7 @@ class LocalAggregator(table: Table) {
}

def pushAggregation(aggregation: Aggregation): Boolean = {
if (
!table.isInstanceOf[DataTable] ||
!table.primaryKeys.isEmpty
) {
return false
}
if (table.asInstanceOf[DataTable].coreOptions.deletionVectorsEnabled) {
if (!table.isInstanceOf[DataTable]) {
return false
}

@@ -108,12 +102,12 @@
SparkInternalRow.create(partitionType).replace(genericRow)
}

def update(partitionEntry: PartitionEntry): Unit = {
def update(dataSplit: DataSplit): Unit = {
assert(isInitialized)
val groupByRow = requiredGroupByRow(partitionEntry.partition())
val groupByRow = requiredGroupByRow(dataSplit.partition())
val aggFuncEvaluator =
groupByEvaluatorMap.getOrElseUpdate(groupByRow, aggFuncEvaluatorGetter())
aggFuncEvaluator.foreach(_.update(partitionEntry))
aggFuncEvaluator.foreach(_.update(dataSplit))
}

def result(): Array[InternalRow] = {
@@ -147,7 +141,7 @@
}

trait AggFuncEvaluator[T] {
def update(partitionEntry: PartitionEntry): Unit
def update(dataSplit: DataSplit): Unit
def result(): T
def resultType: DataType
def prettyName: String
Expand All @@ -156,8 +150,8 @@ trait AggFuncEvaluator[T] {
class CountStarEvaluator extends AggFuncEvaluator[Long] {
private var _result: Long = 0L

override def update(partitionEntry: PartitionEntry): Unit = {
_result += partitionEntry.recordCount()
override def update(dataSplit: DataSplit): Unit = {
_result += dataSplit.mergedRowCount()
}

override def result(): Long = _result
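
With LocalAggregator now fed DataSplits instead of PartitionEntries, count(*) per group-by key becomes the sum of each split's exact merged row count. A small standalone Java sketch of that accumulation follows; PartitionedSplit is a hypothetical stand-in, and the real code keys groups by a projected partition row rather than a string.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Sketch of per-group COUNT(*) accumulation; PartitionedSplit is a hypothetical stand-in. */
public class GroupedCountSketch {

    static class PartitionedSplit {
        final String partition; // the real code uses a projected group-by row, not a string
        final long mergedRowCount;

        PartitionedSplit(String partition, long mergedRowCount) {
            this.partition = partition;
            this.mergedRowCount = mergedRowCount;
        }
    }

    /** Mirrors CountStarEvaluator.update: each split adds its exact merged row count to its group. */
    static Map<String, Long> countStarByPartition(List<PartitionedSplit> splits) {
        Map<String, Long> counts = new HashMap<>();
        for (PartitionedSplit split : splits) {
            counts.merge(split.partition, split.mergedRowCount, Long::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        List<PartitionedSplit> splits =
                Arrays.asList(
                        new PartitionedSplit("p1", 5700L),
                        new PartitionedSplit("p1", 300L),
                        new PartitionedSplit("p2", 6000L));
        System.out.println(countStarByPartition(splits)); // {p1=6000, p2=6000}
    }
}
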
@@ -18,7 +18,7 @@

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonScan, PaimonSparkTestBase, SparkTable}
import org.apache.paimon.spark.{PaimonScan, PaimonSparkTestBase, SparkTable}
import org.apache.paimon.table.source.DataSplit

import org.apache.spark.sql.Row
@@ -29,8 +29,6 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.junit.jupiter.api.Assertions

import scala.collection.JavaConverters._

class PaimonPushDownTest extends PaimonSparkTestBase {

import testImplicits._
@@ -64,7 +62,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)

// case 2
// filter "id = '1' or pt = 'p1'" can't push down completely, it still need to be evaluated after scanning
// filter "id = '1' or pt = 'p1'" can't push down completely, it still needs to be evaluated after scanning
q = "SELECT * FROM T WHERE id = '1' or pt = 'p1'"
Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
@@ -121,7 +119,7 @@
val dataSplitsWithoutLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
Assertions.assertTrue(dataSplitsWithoutLimit.length >= 2)

// It still return false even it can push down limit.
// It still returns false even it can push down limit.
Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1))
val dataSplitsWithLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
Assertions.assertEquals(1, dataSplitsWithLimit.length)
@@ -169,12 +167,7 @@
// Now we have 4 dataSplits: 2 are nonRawConvertible and 2 are rawConvertible.
Assertions.assertEquals(
2,
dataSplitsWithoutLimit2
.filter(
split => {
split.asInstanceOf[DataSplit].rawConvertible()
})
.length)
dataSplitsWithoutLimit2.count(split => { split.asInstanceOf[DataSplit].rawConvertible() }))

// Return 2 dataSplits.
Assertions.assertFalse(scanBuilder2.asInstanceOf[SupportsPushDownLimit].pushLimit(2))
@@ -206,7 +199,40 @@
// Need to scan all dataSplits.
Assertions.assertEquals(4, dataSplitsWithLimit3.length)
Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count())
}

test("Paimon pushDown: limit for table with deletion vector") {
Seq(true, false).foreach(
deletionVectorsEnabled => {
Seq(true, false).foreach(
primaryKeyTable => {
withTable("T") {
sql(s"""
|CREATE TABLE T (id INT)
|TBLPROPERTIES (
| 'deletion-vectors.enabled' = $deletionVectorsEnabled,
| '${if (primaryKeyTable) "primary-key" else "bucket-key"}' = 'id',
| 'bucket' = '10'
|)
|""".stripMargin)

sql("INSERT INTO T SELECT id FROM range (1, 50000)")
sql("DELETE FROM T WHERE id % 13 = 0")

val withoutLimit = getScanBuilder().build().asInstanceOf[PaimonScan].getOriginSplits
assert(withoutLimit.length == 10)

val scanBuilder = getScanBuilder().asInstanceOf[SupportsPushDownLimit]
scanBuilder.pushLimit(1)
val withLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
if (deletionVectorsEnabled || !primaryKeyTable) {
assert(withLimit.length == 1)
} else {
assert(withLimit.length == 10)
}
}
})
})
}

test("Paimon pushDown: runtime filter") {
Expand Down Expand Up @@ -250,8 +276,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
}

private def getScanBuilder(tableName: String = "T"): ScanBuilder = {
new SparkTable(loadTable(tableName))
.newScanBuilder(CaseInsensitiveStringMap.empty())
SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty())
}

private def checkFilterExists(sql: String): Boolean = {
Expand All @@ -272,5 +297,4 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
case _ => false
}
}

}