Skip to content

Commit

Permalink
test
Browse files: browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Sep 10, 2024
1 parent 0f7982b commit aa46eea
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ class CollectMetricIterator(
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
if (wholeStagePipeline) {
outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows())
outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows)
outputVectorCount =
Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches())
Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches)
}
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -189,34 +188,30 @@ class CHColumnarShuffleWriter[K, V](

}

case class OutputMetrics(totalRows: Long, totalBatches: Long)

object CHColumnarShuffleWriter {
private val TOTAL_OUTPUT_ROWS = "total_output_rows"

private val TOTAL_OUTPUT_BATCHES = "total_output_batches"
private var metric = new ThreadLocal[OutputMetrics]()

// Pass the statistics of the last operator before shuffle to CollectMetricIterator.
def setOutputMetrics(splitResult: CHSplitResult): Unit = {
TaskContext
.get()
.getLocalProperties
.setProperty(TOTAL_OUTPUT_ROWS, splitResult.getTotalRows.toString)
TaskContext
.get()
.getLocalProperties
.setProperty(TOTAL_OUTPUT_BATCHES, splitResult.getTotalBatches.toString)
metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches))
}

def getTotalOutputRows(): Long = {
val output_rows = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_ROWS)
var output_rows_value = 0L
if (output_rows != null && output_rows.nonEmpty) output_rows_value = output_rows.toLong
output_rows_value
def getTotalOutputRows: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalRows
}
}

def getTotalOutputBatches(): Long = {
val output_batches = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_BATCHES)
var output_batches_value = 0L
if (output_batches != null) output_batches_value = output_batches.toLong
output_batches_value
def getTotalOutputBatches: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalBatches
}
}
}

0 comments on commit aa46eea

Please sign in to comment.