diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleManager.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleManager.scala index a8cb1149..8e15bece 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleManager.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleManager.scala @@ -134,7 +134,10 @@ class BlazeCelebornShuffleManager(conf: SparkConf, isDriver: Boolean) .asInstanceOf[ShuffleClient] val celebornHandle = handle.asInstanceOf[CelebornShuffleHandle[_, _, _]] - val writer = new BlazeCelebornShuffleWriter(shuffleClient, context, celebornHandle, metrics) + val shuffleIdTracker = FieldUtils + .readField(celebornShuffleManager, "shuffleIdTracker", true) + .asInstanceOf[ExecutorShuffleIdTracker] + val writer = new BlazeCelebornShuffleWriter(shuffleClient, context, celebornHandle, metrics, shuffleIdTracker) writer.asInstanceOf[BlazeRssShuffleWriterBase[K, V]] } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala index 02c24a52..323f959c 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.blaze.shuffle.celeborn import org.apache.celeborn.client.ShuffleClient import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.shuffle.ShuffleWriteMetricsReporter -import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle +import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ExecutorShuffleIdTracker, SparkUtils} import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase import org.apache.spark.TaskContext @@ -29,7 +29,8 @@ class BlazeCelebornShuffleWriter[K, C]( shuffleClient: ShuffleClient, taskContext: TaskContext, handle: CelebornShuffleHandle[K, _, C], - metrics: ShuffleWriteMetricsReporter) + metrics: ShuffleWriteMetricsReporter, + shuffleIdTracker: ExecutorShuffleIdTracker) extends BlazeRssShuffleWriterBase[K, C](metrics) { private val numMappers = handle.numMappers @@ -41,9 +42,11 @@ class BlazeCelebornShuffleWriter[K, C]( metrics: ShuffleWriteMetricsReporter, numPartitions: Int): RssPartitionWriterBase = { + val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, taskContext, true) + shuffleIdTracker.track(handle.shuffleId, shuffleId) new CelebornPartitionWriter( shuffleClient, - handle.shuffleId, + shuffleId, encodedAttemptId, numMappers, numPartitions,