support cross partition insert
Zouxxyy committed Nov 2, 2023
1 parent 70959c9 commit f7b8af1
Showing 10 changed files with 257 additions and 27 deletions.
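In short, this change lets a Spark INSERT move an existing primary key into a new partition on tables configured for cross-partition update (primary keys that do not contain the partition keys, 'bucket' = '-1'). A minimal sketch adapted from the new GlobalDynamicBucketTableTest below, assuming a SparkSession `spark` with the Paimon catalog configured:

spark.sql("""
            |CREATE TABLE T (pt INT, pk INT, v INT)
            |TBLPROPERTIES (
            |  'primary-key' = 'pk',
            |  'bucket' = '-1',
            |  'dynamic-bucket.target-row-num' = '3'
            |)
            |PARTITIONED BY (pt)
            |""".stripMargin)

spark.sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2)")
// Re-inserting pk = 1 under pt = 2 moves that key across partitions;
// this is the Spark-side write path the commit adds.
spark.sql("INSERT INTO T VALUES (2, 1, 10)")
spark.sql("SELECT * FROM T ORDER BY pk").show()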
@@ -16,7 +16,7 @@
* limitations under the License.
*/

package org.apache.paimon.flink.sink.index;
package org.apache.paimon.crosspartition;

/** Type of record, key or full row. */
public enum KeyPartOrRow {
@@ -30,7 +30,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partiton fields. */
/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partition fields. */
public class KeyPartPartitionKeyExtractor implements PartitionKeyExtractor<InternalRow> {

private final Projection partitionProjection;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.IndexBootstrap;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.sink.Committable;
import org.apache.paimon.flink.sink.DynamicBucketRowWriteOperator;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.GlobalIndexAssigner;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.disk.IOManager;
import org.apache.paimon.table.Table;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.IndexBootstrap;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.utils.SerializableFunction;

@@ -20,6 +20,7 @@

import org.apache.paimon.codegen.CodeGenUtils;
import org.apache.paimon.codegen.Projection;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.sink.ChannelComputer;
@@ -18,6 +18,7 @@

package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.flink.utils.InternalTypeSerializer;

import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -28,7 +29,7 @@
import java.io.IOException;
import java.util.Objects;

import static org.apache.paimon.flink.sink.index.KeyPartOrRow.KEY_PART;
import static org.apache.paimon.crosspartition.KeyPartOrRow.KEY_PART;

/** A {@link InternalTypeSerializer} to serialize KeyPartOrRow with T. */
public class KeyWithRowSerializer<T> extends InternalTypeSerializer<Tuple2<KeyPartOrRow, T>> {
@@ -18,20 +18,25 @@
package org.apache.paimon.spark.commands

import org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE
import org.apache.paimon.data.BinaryRow
import org.apache.paimon.codegen.{CodeGenUtils, Projection}
import org.apache.paimon.crosspartition.{GlobalIndexAssigner, IndexBootstrap, KeyPartOrRow}
import org.apache.paimon.data.{BinaryRow, GenericRow, JoinedRow}
import org.apache.paimon.data.serializer.InternalSerializers
import org.apache.paimon.index.PartitionIndex
import org.apache.paimon.options.Options
import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, SaveMode, SparkConnectorOptions, SparkRow}
import org.apache.paimon.spark._
import org.apache.paimon.spark.SparkUtils.createIOManager
import org.apache.paimon.spark.schema.SparkSystemColumns
import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, ROW_KIND_COL}
import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils}
import org.apache.paimon.table.{BucketMode, FileStoreTable}
import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
import org.apache.paimon.types.RowType
import org.apache.paimon.types.{RowKind, RowType}
import org.apache.paimon.utils.SerializationUtils

import org.apache.spark.TaskContext
import org.apache.spark.{HashPartitioner, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -78,51 +83,154 @@ case class WriteIntoPaimonTable(
val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
val partitionCols = tableSchema.partitionKeys().asScala.map(col)

val dataEncoder = EncoderUtils.encode(dataSchema).resolveAndBind()
val originFromRow = dataEncoder.createDeserializer()
val (_, _, originFromRow) = EncoderUtils.getEncoderAndSerDe(dataSchema)

val rowkindColIdx = SparkRowUtils.getFieldIndex(data.schema, ROW_KIND_COL)
var newData = data

// append _bucket_ column as placeholder
val withBucketCol = data.withColumn(BUCKET_COL, lit(-1))
val bucketColIdx = withBucketCol.schema.size - 1
val withBucketDataEncoder = EncoderUtils.encode(withBucketCol.schema).resolveAndBind()
val toRow = withBucketDataEncoder.createSerializer()
val fromRow = withBucketDataEncoder.createDeserializer()
if (
bucketMode.equals(BucketMode.GLOBAL_DYNAMIC) && !newData.schema.fieldNames.contains(
ROW_KIND_COL)
) {
newData = data.withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
}
val rowkindColIdx = SparkRowUtils.getFieldIndex(newData.schema, ROW_KIND_COL)

// append bucket column as placeholder
newData = newData.withColumn(BUCKET_COL, lit(-1))
val bucketColIdx = SparkRowUtils.getFieldIndex(newData.schema, BUCKET_COL)

val (newDataEncoder, toRow, fromRow) = EncoderUtils.getEncoderAndSerDe(newData.schema)

def repartitionByBucket(ds: Dataset[Row]) = {
ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
def repartitionByBucket(df: DataFrame) = {
df.repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
}

val rowType = table.rowType()
val writeBuilder = table.newBatchWriteBuilder()

val df =
val df: Dataset[Row] =
bucketMode match {
case BucketMode.DYNAMIC =>
// Topology: input -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
val partitioned = if (primaryKeyCols.nonEmpty) {
// Make sure that the records with the same bucket values is within a task.
withBucketCol.repartition(primaryKeyCols: _*)
newData.repartition(primaryKeyCols: _*)
} else {
withBucketCol
newData
}
val numSparkPartitions = partitioned.rdd.getNumPartitions
val dynamicBucketProcessor =
DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)

repartitionByBucket(
partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
withBucketDataEncoder))
partitioned
.mapPartitions(dynamicBucketProcessor.processPartition)(newDataEncoder)
.toDF())
case BucketMode.GLOBAL_DYNAMIC =>
// Topology: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
val numSparkPartitions = newData.rdd.getNumPartitions
val primaryKeys: java.util.List[String] = table.schema().primaryKeys()
val bootstrapType: RowType = IndexBootstrap.bootstrapType(table.schema())
val rowType: RowType = SparkTypeUtils.toPaimonType(newData.schema).asInstanceOf[RowType]

// row: (keyHash, (kind, internalRow))
val bootstrapRow: RDD[(Int, (KeyPartOrRow, Array[Byte]))] = newData.rdd.mapPartitions {
iter =>
{
val sparkPartitionId = TaskContext.getPartitionId()

val keyPartProject: Projection =
CodeGenUtils.newProjection(bootstrapType, primaryKeys)
val rowProject: Projection = CodeGenUtils.newProjection(rowType, primaryKeys)
val bootstrapSer = InternalSerializers.create(bootstrapType)
val rowSer = InternalSerializers.create(rowType)

val lst = scala.collection.mutable.ListBuffer[(Int, (KeyPartOrRow, Array[Byte]))]()

val bootstrap = new IndexBootstrap(table)
bootstrap.bootstrap(
numSparkPartitions,
sparkPartitionId,
row => {
val bytes: Array[Byte] =
SerializationUtils.serializeBinaryRow(bootstrapSer.toBinaryRow(row))
lst.append((keyPartProject(row).hashCode(), (KeyPartOrRow.KEY_PART, bytes)))
}
)
lst.iterator ++ iter.map(
r => {
val sparkRow =
new SparkRow(rowType, r, SparkRowUtils.getRowKind(r, rowkindColIdx))
val bytes: Array[Byte] =
SerializationUtils.serializeBinaryRow(rowSer.toBinaryRow(sparkRow))
(rowProject(sparkRow).hashCode(), (KeyPartOrRow.ROW, bytes))
})
}
}

var assignerParallelism: Integer = table.coreOptions.dynamicBucketAssignerParallelism
if (assignerParallelism == null) {
assignerParallelism = numSparkPartitions
}

val rowRDD: RDD[Row] =
bootstrapRow.partitionBy(new HashPartitioner(assignerParallelism)).mapPartitions {
iter =>
{
val sparkPartitionId = TaskContext.getPartitionId()
val lst = scala.collection.mutable.ListBuffer[Row]()
val ioManager = createIOManager
val assigner = new GlobalIndexAssigner(table)
try {
assigner.open(
ioManager,
assignerParallelism,
sparkPartitionId,
(row, bucket) => {
val extraRow: GenericRow = new GenericRow(2)
extraRow.setField(0, row.getRowKind.toByteValue)
extraRow.setField(1, bucket)
lst.append(
fromRow(
SparkInternalRow.fromPaimon(new JoinedRow(row, extraRow), rowType)))
}
)
iter.foreach(
row => {
val tuple: (KeyPartOrRow, Array[Byte]) = row._2
val binaryRow = SerializationUtils.deserializeBinaryRow(tuple._2)
tuple._1 match {
case KeyPartOrRow.KEY_PART => assigner.bootstrapKey(binaryRow)
case KeyPartOrRow.ROW => assigner.processInput(binaryRow)
case _ =>
throw new UnsupportedOperationException(s"unknown kind ${tuple._1}")
}
})
assigner.endBoostrap(true)
lst.iterator
} finally {
assigner.close()
if (ioManager != null) {
ioManager.close()
}
}
}
}
repartitionByBucket(sparkSession.createDataFrame(rowRDD, newData.schema))
case BucketMode.UNAWARE =>
// Topology: input -> bucket-assigner
val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
withBucketCol
.mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
newData
.mapPartitions(unawareBucketProcessor.processPartition)(newDataEncoder)
.toDF()
case BucketMode.FIXED =>
// Topology: input -> bucket-assigner -> shuffle by partition & bucket
val commonBucketProcessor =
CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
repartitionByBucket(
withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
withBucketDataEncoder))
newData.mapPartitions(commonBucketProcessor.processPartition)(newDataEncoder).toDF())
case _ =>
throw new UnsupportedOperationException(s"unsupported bucket mode $bucketMode")
}

val commitMessages = df
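The GLOBAL_DYNAMIC branch above follows the topology in its comment: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket. Below is a stripped-down sketch of that shuffle pattern with plain types (no Paimon classes), only to illustrate the data flow; all names here are illustrative:

import org.apache.spark.{HashPartitioner, TaskContext}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("shuffle-sketch").getOrCreate()

// Stage 1: tag each record with the hash of its key, as the bootstrap RDD above does.
val keyed = spark.sparkContext
  .parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
  .map { case (k, v) => (k.hashCode, v) }

// Stage 2: shuffle by key hash so all records of one key meet in a single task,
// then run a per-task assignment step (the real code opens a GlobalIndexAssigner here).
val assigned = keyed
  .partitionBy(new HashPartitioner(2))
  .mapPartitions { iter =>
    val taskId = TaskContext.getPartitionId()
    iter.map { case (_, v) => (v, taskId) }
  }

assigned.collect().foreach(println)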
@@ -63,4 +63,10 @@ object EncoderUtils {
.reflectMethod(method)(schema)
.asInstanceOf[ExpressionEncoder[Row]]
}

def getEncoderAndSerDe(schema: StructType)
: (ExpressionEncoder[Row], ExpressionEncoder.Serializer[Row], ExpressionEncoder.Deserializer[Row]) = {
val encoder = encode(schema).resolveAndBind()
(encoder, encoder.createSerializer(), encoder.createDeserializer())
}
}
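For reference, a minimal usage sketch of the new helper, mirroring how WriteIntoPaimonTable above destructures the returned tuple (the schema here is illustrative):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("pk", IntegerType), StructField("v", IntegerType)))
// encoder is resolved and bound; toRow/fromRow are its serializer and deserializer.
val (encoder, toRow, fromRow) = EncoderUtils.getEncoderAndSerDe(schema)
val internal = toRow(Row(1, 10))   // external Row -> Catalyst InternalRow
val external = fromRow(internal)   // Catalyst InternalRow -> external Row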
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.paimon.spark

import org.assertj.core.api.Assertions.assertThat

class GlobalDynamicBucketTableTest extends PaimonSparkTestBase {

test(s"test global dynamic bucket") {
spark.sql(s"""
|CREATE TABLE T (
| pt INT,
| pk INT,
| v INT)
|TBLPROPERTIES (
| 'primary-key' = 'pk',
| 'bucket' = '-1',
| 'dynamic-bucket.target-row-num'='3',
| 'dynamic-bucket.assigner-parallelism'='1'
|)
|PARTITIONED BY (pt)
|""".stripMargin)

spark.sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5)")
val rows1 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows1.toString).isEqualTo("[[1,1,1], [1,2,2], [1,3,3], [1,4,4], [1,5,5]]")

spark.sql("INSERT INTO T VALUES (1, 3, 33), (1, 1, 11)")
val rows2 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows2.toString).isEqualTo("[[1,1,11], [1,2,2], [1,3,33], [1,4,4], [1,5,5]]")

val rows3 = spark.sql("SELECT DISTINCT bucket FROM `T$FILES`").collectAsList()
assertThat(rows3.toString).isEqualTo("[[0], [1]]")

// change partition
spark.sql("INSERT INTO T VALUES (2, 1, 2), (2, 2, 3)")
val rows4 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows4.toString).isEqualTo("[[2,1,2], [2,2,3], [1,3,33], [1,4,4], [1,5,5]]")
}

test(s"test global dynamic bucket with delete") {
spark.sql(s"""
|CREATE TABLE T (
| pt INT,
| pk INT,
| v INT)
|TBLPROPERTIES (
| 'primary-key' = 'pk',
| 'bucket' = '-1',
| 'dynamic-bucket.target-row-num'='3'
|)
|PARTITIONED BY (pt)
|""".stripMargin)

spark.sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5)")
val rows1 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows1.toString).isEqualTo("[[1,1,1], [1,2,2], [1,3,3], [1,4,4], [1,5,5]]")

spark.sql("DELETE FROM T WHERE pk = 1")
val rows2 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows2.toString).isEqualTo("[[1,2,2], [1,3,3], [1,4,4], [1,5,5]]")

// change partition
spark.sql("INSERT INTO T VALUES (2, 1, 2), (2, 2, 3)")
val rows3 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows3.toString).isEqualTo("[[2,1,2], [2,2,3], [1,3,3], [1,4,4], [1,5,5]]")

spark.sql("DELETE FROM T WHERE pk = 2")
val rows4 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows4.toString).isEqualTo("[[2,1,2], [1,3,3], [1,4,4], [1,5,5]]")
}

test(s"test write with assigner parallelism") {
spark.sql(s"""
|CREATE TABLE T (
| pt INT,
| pk INT,
| v INT)
|TBLPROPERTIES (
| 'primary-key' = 'pk',
| 'bucket' = '-1',
| 'dynamic-bucket.target-row-num'='3',
| 'dynamic-bucket.assigner-parallelism'='3'
|)
|PARTITIONED BY (pt)
|""".stripMargin)

spark.sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5)")
val rows1 = spark.sql("SELECT * FROM T ORDER BY pk").collectAsList()
assertThat(rows1.toString).isEqualTo("[[1,1,1], [1,2,2], [1,3,3], [1,4,4], [1,5,5]]")

val rows3 = spark.sql("SELECT DISTINCT bucket FROM `T$FILES`").collectAsList()
assertThat(rows3.toString).isEqualTo("[[0], [1], [2]]")
}

}
