[CH] Shuffle writer connects to CH pipeline #6723

Merged Sep 11, 2024 · 11 commits
@@ -75,6 +75,7 @@ private native long nativeCreate(
private native void nativeFlush(long instance);

public void write(ColumnarBatch cb) {
if (cb.numCols() == 0 || cb.numRows() == 0) return;
CHNativeBlock block = CHNativeBlock.fromColumnarBatch(cb);
dataSize.add(block.totalBytes());
nativeWrite(instance, block.blockAddress());
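The new guard makes empty batches a no-op before anything crosses the JNI boundary. A minimal sketch of the resulting caller-side contract — the `NativeBlockWriter` trait below is hypothetical, since the diff does not show this wrapper class's name:

import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical stand-in for the JNI wrapper above. The point is the contract:
// write() silently ignores 0-row or 0-column batches, so callers can pass
// batches through unfiltered.
trait NativeBlockWriter { def write(cb: ColumnarBatch): Unit }

def writeAll(batches: Iterator[ColumnarBatch], writer: NativeBlockWriter): Unit =
  batches.foreach(writer.write) // empty batches are skipped inside write()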
@@ -16,12 +16,15 @@
*/
package org.apache.gluten.vectorized;

import org.apache.gluten.execution.ColumnarNativeIterator;

import java.io.IOException;

public class CHShuffleSplitterJniWrapper {
public CHShuffleSplitterJniWrapper() {}

public long make(
ColumnarNativeIterator records,
NativePartitioning part,
int shuffleId,
long mapId,
@@ -36,6 +39,7 @@ public long make(
long maxSortBufferSize,
boolean forceMemorySort) {
return nativeMake(
records,
part.getShortName(),
part.getNumPartitions(),
part.getExprList(),
@@ -55,6 +59,7 @@
}

public long makeForRSS(
ColumnarNativeIterator records,
NativePartitioning part,
int shuffleId,
long mapId,
@@ -66,6 +71,7 @@ public long makeForRSS(
Object pusher,
boolean forceMemorySort) {
return nativeMakeForRSS(
records,
part.getShortName(),
part.getNumPartitions(),
part.getExprList(),
@@ -82,6 +88,7 @@ public long makeForRSS(
}

public native long nativeMake(
ColumnarNativeIterator records,
String shortName,
int numPartitions,
byte[] exprList,
@@ -100,6 +107,7 @@ public native long nativeMake(
boolean forceMemorySort);

public native long nativeMakeForRSS(
ColumnarNativeIterator records,
String shortName,
int numPartitions,
byte[] exprList,
@@ -114,8 +122,6 @@ public native long nativeMakeForRSS(
Object pusher,
boolean forceMemorySort);

public native CHSplitResult stop(long splitterId) throws IOException;

public native void close(long splitterId);
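Both factory methods now take a ColumnarNativeIterator, and the push-style split(long, long) native entry point is removed: the native splitter pulls blocks from the iterator itself. A rough lifecycle sketch, where `makeSplitter` is a hypothetical stand-in for the full make(...) call whose buffer and spill arguments are collapsed in this diff:

import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.vectorized.{CHShuffleSplitterJniWrapper, CHSplitResult}

// Sketch of the pull-based lifecycle: create the splitter with the iterator,
// then stop() drains it on the native side and reports what was consumed.
def runSplit(
    wrapper: CHShuffleSplitterJniWrapper,
    records: ColumnarNativeIterator,
    makeSplitter: ColumnarNativeIterator => Long): CHSplitResult = {
  val handle = makeSplitter(records) // e.g. wrapper.make(records, part, ...)
  try wrapper.stop(handle) // pulls every batch, then returns the split result
  finally wrapper.close(handle)
}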
@@ -33,6 +33,7 @@ import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext}
import org.apache.spark.affinity.CHAffinity
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.CHColumnarShuffleWriter
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -322,8 +323,12 @@
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false
// Whether the stage is executed completely using ClickHouse pipeline.
private var wholeStagePipeline = true

override def hasNext: Boolean = {
// The hasNext call is triggered only when there is a fallback.
wholeStagePipeline = false
nativeIterator.hasNext
}

@@ -347,6 +352,11 @@
private def collectStageMetrics(): Unit = {
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
if (wholeStagePipeline) {
outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows)
outputVectorCount =
Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches)
}
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
updateInputMetrics.foreach(_(inputMetrics))
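The flag works by elimination: Spark only drives this iterator's hasNext when part of the stage fell back, so an untouched flag means the whole stage ran inside the ClickHouse pipeline and the row/batch counts must come from the shuffle writer instead. A condensed sketch of that decision as a free-standing function, assuming the counters from the surrounding class:

import org.apache.spark.shuffle.CHColumnarShuffleWriter

// Condensed from collectStageMetrics(): pick the metrics source.
// wholeStagePipeline stays true iff hasNext was never called (no fallback).
def finalCounts(
    wholeStagePipeline: Boolean,
    iterRows: Long,
    iterBatches: Long): (Long, Long) =
  if (wholeStagePipeline) {
    (math.max(iterRows, CHColumnarShuffleWriter.getTotalOutputRows),
      math.max(iterBatches, CHColumnarShuffleWriter.getTotalOutputBatches))
  } else {
    (iterRows, iterBatches)
  }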
@@ -62,6 +62,7 @@ private object CHRuleApi {
def injectLegacy(injector: LegacyInjector): Unit = {
// Gluten columnar: Transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
injector.injectTransform(_ => RewriteSubqueryBroadcast())
@@ -72,6 +73,7 @@
injector.injectTransform(_ => TransformPreOverrides())
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectTransform(c => RewriteTransformer.apply(c.session))
injector.injectTransform(_ => PushDownInputFileExpression.PostOffload)
injector.injectTransform(_ => EnsureLocalSortRequirements)
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
@@ -63,12 +63,13 @@ private class CHColumnarBatchSerializerInstance(
compressionCodec,
GlutenConfig.getConf.columnarShuffleCodecBackend.orNull)

private val useColumnarShuffle: Boolean = GlutenConfig.getConf.isUseColumnarShuffleManager

override def deserializeStream(in: InputStream): DeserializationStream = {
// Don't read GlutenConfig in this method; it may execute on a non-task thread.
new DeserializationStream {
private val reader: CHStreamReader =
  new CHStreamReader(in, useColumnarShuffle, CHBackendSettings.useCustomizedShuffleCodec)
private var cb: ColumnarBatch = _

private var numBatchesTotal: Long = _
@@ -97,7 +98,6 @@
var nativeBlock = reader.next()
while (nativeBlock.numRows() == 0) {
if (nativeBlock.numColumns() == 0) {
this.close()
throw new EOFException
}
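The reader's construction is unchanged in substance; what moved is the GlutenConfig read, now a field initialized while the serializer instance is still being built on the task thread, because deserializeStream may later run on a thread with no task context. A minimal sketch of the pattern (types simplified; this is not the real serializer API):

// Sketch: snapshot session-bound config on the constructing (task) thread,
// then only use the captured value from other threads.
class StreamFactory(readFlag: () => Boolean) {
  private val useColumnarShuffle: Boolean = readFlag() // evaluated once, here

  def open(in: java.io.InputStream): (java.io.InputStream, Boolean) =
    (in, useColumnarShuffle) // safe on any thread: no config lookup
}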
@@ -20,6 +20,9 @@ public class CHSplitResult extends SplitResult {
private final long splitTime;
private final long diskWriteTime;
private final long serializationTime;
private final long totalRows;
private final long totalBatches;
private final long wallTime;

public CHSplitResult(long totalComputePidTime,
long totalWriteTime,
@@ -31,7 +34,10 @@ public CHSplitResult(long totalComputePidTime,
long[] rawPartitionLengths,
long splitTime,
long diskWriteTime,
long serializationTime,
long totalRows,
long totalBatches,
long wallTime) {
super(totalComputePidTime,
totalWriteTime,
totalEvictTime,
@@ -43,6 +49,9 @@
this.splitTime = splitTime;
this.diskWriteTime = diskWriteTime;
this.serializationTime = serializationTime;
this.totalRows = totalRows;
this.totalBatches = totalBatches;
this.wallTime = wallTime;
}

public long getSplitTime() {
@@ -56,4 +65,16 @@ public long getDiskWriteTime() {
public long getSerializationTime() {
return serializationTime;
}

public long getTotalRows() {
return totalRows;
}

public long getTotalBatches() {
return totalBatches;
}

public long getWallTime() {
return wallTime;
}
}
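The three new fields let the JVM side recover input statistics it can no longer count itself, since batches now flow into the native splitter without a JVM-side loop. A sketch of how a caller folds them into Spark metrics, using the metric names that appear later in this diff:

import org.apache.gluten.vectorized.CHSplitResult
import org.apache.spark.sql.execution.metric.SQLMetric

// Sketch: after stop(), the native side reports what it consumed.
def recordSplit(result: CHSplitResult, metrics: Map[String, SQLMetric]): Unit = {
  metrics("numInputRows").add(result.getTotalRows)
  metrics("inputBatches").add(result.getTotalBatches)
  metrics("shuffleWallTime").add(result.getWallTime)
}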
@@ -18,6 +18,7 @@ package org.apache.spark.shuffle

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._

@@ -75,8 +76,6 @@

private var rawPartitionLengths: Array[Long] = _

@throws[IOException]
override def write(records: Iterator[Product2[K, V]]): Unit = {
CHThreadGroup.registerNewThreadGroup()
@@ -85,20 +84,23 @@

private def internalCHWrite(records: Iterator[Product2[K, V]]): Unit = {
val splitterJniWrapper: CHShuffleSplitterJniWrapper = jniWrapper
    val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId))
    // for fallback
    val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] {
      override def hasNext: Boolean = records.hasNext

      override def next(): ColumnarBatch = records.next()._2.asInstanceOf[ColumnarBatch]
    })
    if (nativeSplitter == 0) {
      nativeSplitter = splitterJniWrapper.make(
        iter,
        dep.nativePartitioning,
        dep.shuffleId,
        mapId,
@@ -114,50 +116,49 @@
        forceMemorySortShuffle
      )
    }
    splitResult = splitterJniWrapper.stop(nativeSplitter)

    if (splitResult.getTotalRows > 0) {
      dep.metrics("numInputRows").add(splitResult.getTotalRows)
      dep.metrics("inputBatches").add(splitResult.getTotalBatches)
      dep.metrics("splitTime").add(splitResult.getSplitTime)
      dep.metrics("IOTime").add(splitResult.getDiskWriteTime)
      dep.metrics("serializeTime").add(splitResult.getSerializationTime)
      dep.metrics("spillTime").add(splitResult.getTotalSpillTime)
      dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
      dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime)
      dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
      dep.metrics("dataSize").add(splitResult.getTotalBytesWritten)
      dep.metrics("shuffleWallTime").add(splitResult.getWallTime)
      writeMetrics.incRecordsWritten(splitResult.getTotalRows)
      writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
      writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
      partitionLengths = splitResult.getPartitionLengths
      rawPartitionLengths = splitResult.getRawPartitionLengths
      CHColumnarShuffleWriter.setOutputMetrics(splitResult)
      try {
        shuffleBlockResolver.writeMetadataFileAndCommit(
          dep.shuffleId,
          mapId,
          partitionLengths,
          Array[Long](),
          dataTmp)
      } finally {
        if (dataTmp.exists() && !dataTmp.delete()) {
          logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
        }
      }
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
    } else {
      partitionLengths = new Array[Long](dep.partitioner.numPartitions)
      shuffleBlockResolver.writeMetadataFileAndCommit(
        dep.shuffleId,
        mapId,
        partitionLengths,
        Array[Long](),
        null)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
    }
    closeCHSplitter()
  }
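The rewritten write path keeps no JVM-side batch loop: stop() drains the iterator natively, and the empty-input case is detected afterwards via getTotalRows instead of an up-front records.hasNext check. A skeleton of that branching, with the handle/commit details elided into function parameters:

import org.apache.gluten.vectorized.CHSplitResult
import org.apache.spark.shuffle.CHColumnarShuffleWriter

// Sketch: only the branching logic here mirrors the diff; `stop`, `commit`
// and `commitEmpty` stand in for the JNI and block-resolver plumbing.
def writeSkeleton(
    stop: () => CHSplitResult,
    commit: CHSplitResult => Unit,
    commitEmpty: () => Unit): Unit = {
  val result = stop() // the native pipeline has pulled every batch by now
  if (result.getTotalRows > 0) {
    CHColumnarShuffleWriter.setOutputMetrics(result) // hand totals to CollectMetricIterator
    commit(result)
  } else {
    commitEmpty() // zero-length partitions, no data file content
  }
}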

override def stop(success: Boolean): Option[MapStatus] = {
@@ -172,18 +173,46 @@
None
}
} finally {
closeCHSplitter()
}
}

private def closeCHSplitter(): Unit = {
if (nativeSplitter != 0) {
jniWrapper.close(nativeSplitter)
nativeSplitter = 0
}
}

// VisibleForTesting
def getPartitionLengths(): Array[Long] = partitionLengths

}

case class OutputMetrics(totalRows: Long, totalBatches: Long)

object CHColumnarShuffleWriter {

private val metric = new ThreadLocal[OutputMetrics]()

// Pass the statistics of the last operator before shuffle to CollectMetricIterator.
def setOutputMetrics(splitResult: CHSplitResult): Unit = {
metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches))
}

def getTotalOutputRows: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalRows
}
}

def getTotalOutputBatches: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalBatches
}
}
}
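This companion object is the bridge back to CollectMetricIterator: the writer records the totals after stop(), and the iterator, running on the same task thread, reads them in collectStageMetrics(). The ThreadLocal isolates concurrent tasks within one executor. A simplified, self-contained sketch of the handoff, reusing the OutputMetrics case class defined above:

// Simplified model of the writer -> iterator handoff.
object TaskLocalMetrics {
  private val metric = new ThreadLocal[OutputMetrics]()

  def set(rows: Long, batches: Long): Unit =
    metric.set(OutputMetrics(rows, batches)) // writer side, after stop()

  def rows: Long =
    Option(metric.get()).map(_.totalRows).getOrElse(0L) // iterator side
}

This only works because Spark executes the shuffle writer and the metric collection for a task on the same thread; a task scheduled on a fresh thread simply sees the zero default.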