Skip to content

Commit

Permalink
[BLAZE-287][FOLLOWUP] BlazeCelebornShuffleWriter should use mapped sh…
Browse files Browse the repository at this point in the history
…uffle id for rerunning stage of fetch failure (#712)
  • Loading branch information
SteNicholas authored Dec 21, 2024
1 parent 092c9a0 commit cbeda0a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ class BlazeCelebornShuffleManager(conf: SparkConf, isDriver: Boolean)
.asInstanceOf[ShuffleClient]

val celebornHandle = handle.asInstanceOf[CelebornShuffleHandle[_, _, _]]
val writer = new BlazeCelebornShuffleWriter(shuffleClient, context, celebornHandle, metrics)
val shuffleIdTracker = FieldUtils
.readField(celebornShuffleManager, "shuffleIdTracker", true)
.asInstanceOf[ExecutorShuffleIdTracker]
val writer = new BlazeCelebornShuffleWriter(shuffleClient, context, celebornHandle, metrics, shuffleIdTracker)
writer.asInstanceOf[BlazeRssShuffleWriterBase[K, V]]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.blaze.shuffle.celeborn
import org.apache.celeborn.client.ShuffleClient
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ExecutorShuffleIdTracker, SparkUtils}
import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
import org.apache.spark.TaskContext
Expand All @@ -29,7 +29,8 @@ class BlazeCelebornShuffleWriter[K, C](
shuffleClient: ShuffleClient,
taskContext: TaskContext,
handle: CelebornShuffleHandle[K, _, C],
metrics: ShuffleWriteMetricsReporter)
metrics: ShuffleWriteMetricsReporter,
shuffleIdTracker: ExecutorShuffleIdTracker)
extends BlazeRssShuffleWriterBase[K, C](metrics) {

private val numMappers = handle.numMappers
Expand All @@ -41,9 +42,11 @@ class BlazeCelebornShuffleWriter[K, C](
metrics: ShuffleWriteMetricsReporter,
numPartitions: Int): RssPartitionWriterBase = {

val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, taskContext, true)
shuffleIdTracker.track(handle.shuffleId, shuffleId)
new CelebornPartitionWriter(
shuffleClient,
handle.shuffleId,
shuffleId,
encodedAttemptId,
numMappers,
numPartitions,
Expand Down

0 comments on commit cbeda0a

Please sign in to comment.