From aa46eeafcdf96e04c9316b9207e48f2b4261e5df Mon Sep 17 00:00:00 2001 From: liuneng1994 Date: Tue, 10 Sep 2024 10:10:13 +0800 Subject: [PATCH] test --- .../clickhouse/CHIteratorApi.scala | 4 +- .../shuffle/CHColumnarShuffleWriter.scala | 39 ++++++++----------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index e90a3821a41ba..f33e767e13e0f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -353,9 +353,9 @@ class CollectMetricIterator( if (!metricsUpdated) { val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics] if (wholeStagePipeline) { - outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows()) + outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows) outputVectorCount = - Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches()) + Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches) } nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount) updateNativeMetrics(nativeMetrics) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala index 53f85d84672b9..a883e97847bc4 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala @@ -21,8 +21,7 @@ import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings import org.apache.gluten.execution.ColumnarNativeIterator import 
org.apache.gluten.memory.CHThreadGroup import org.apache.gluten.vectorized._ - -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus import org.apache.spark.sql.vectorized.ColumnarBatch @@ -189,34 +188,30 @@ class CHColumnarShuffleWriter[K, V]( } +case class OutputMetrics(totalRows: Long, totalBatches: Long) + object CHColumnarShuffleWriter { - private val TOTAL_OUTPUT_ROWS = "total_output_rows" - private val TOTAL_OUTPUT_BATCHES = "total_output_batches" + private val metric = new ThreadLocal[OutputMetrics]() // Pass the statistics of the last operator before shuffle to CollectMetricIterator. def setOutputMetrics(splitResult: CHSplitResult): Unit = { - TaskContext - .get() - .getLocalProperties - .setProperty(TOTAL_OUTPUT_ROWS, splitResult.getTotalRows.toString) - TaskContext - .get() - .getLocalProperties - .setProperty(TOTAL_OUTPUT_BATCHES, splitResult.getTotalBatches.toString) + metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches)) } - def getTotalOutputRows(): Long = { - val output_rows = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_ROWS) - var output_rows_value = 0L - if (output_rows != null && output_rows.nonEmpty) output_rows_value = output_rows.toLong - output_rows_value + def getTotalOutputRows: Long = { + if (metric.get() == null) { + 0 + } else { + metric.get().totalRows + } } - def getTotalOutputBatches(): Long = { - val output_batches = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_BATCHES) - var output_batches_value = 0L - if (output_batches != null) output_batches_value = output_batches.toLong - output_batches_value + def getTotalOutputBatches: Long = { + if (metric.get() == null) { + 0 + } else { + metric.get().totalBatches + } } }