Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Sep 10, 2024
1 parent 0f7982b commit e00a93c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ class CollectMetricIterator(
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
if (wholeStagePipeline) {
outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows())
outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows)
outputVectorCount =
Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches())
Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches)
}
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -189,34 +189,30 @@ class CHColumnarShuffleWriter[K, V](

}

case class OutputMetrics(totalRows: Long, totalBatches: Long)

object CHColumnarShuffleWriter {
private val TOTAL_OUTPUT_ROWS = "total_output_rows"

private val TOTAL_OUTPUT_BATCHES = "total_output_batches"
private var metric = new ThreadLocal[OutputMetrics]()

// Pass the statistics of the last operator before shuffle to CollectMetricIterator.
def setOutputMetrics(splitResult: CHSplitResult): Unit = {
TaskContext
.get()
.getLocalProperties
.setProperty(TOTAL_OUTPUT_ROWS, splitResult.getTotalRows.toString)
TaskContext
.get()
.getLocalProperties
.setProperty(TOTAL_OUTPUT_BATCHES, splitResult.getTotalBatches.toString)
metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches))
}

def getTotalOutputRows(): Long = {
val output_rows = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_ROWS)
var output_rows_value = 0L
if (output_rows != null && output_rows.nonEmpty) output_rows_value = output_rows.toLong
output_rows_value
def getTotalOutputRows: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalRows
}
}

def getTotalOutputBatches(): Long = {
val output_batches = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_BATCHES)
var output_batches_value = 0L
if (output_batches != null) output_batches_value = output_batches.toLong
output_batches_value
def getTotalOutputBatches: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalBatches
}
}
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/LocalExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct SparkBuffer
class LocalExecutor : public BlockIterator
{
public:
static LocalExecutor * getCurrentExecutor() { return current_executor; }
static std::optional<LocalExecutor *> getCurrentExecutor() { return current_executor; }
static void resetCurrentExecutor() { current_executor = nullptr; }
LocalExecutor(DB::QueryPlanPtr query_plan, DB::QueryPipelineBuilderPtr pipeline, bool dump_pipeline_ = false);
~LocalExecutor();
Expand Down
13 changes: 7 additions & 6 deletions cpp-ch/local-engine/local_engine_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,11 +544,12 @@ local_engine::SplitterHolder * buildAndExecuteShuffle(JNIEnv * env,
jobject rss_pusher = nullptr
)
{
auto * current_executor = local_engine::LocalExecutor::getCurrentExecutor();
chassert(current_executor);
auto current_executor = local_engine::LocalExecutor::getCurrentExecutor();
local_engine::SplitterHolder * splitter = nullptr;
// handle fallback, whole stage fallback or partial fallback
if (!current_executor || current_executor->fallbackMode())
// There are two modes of fallback, one is full fallback but uses columnar shuffle,
// and the other is partial fallback that creates one or more LocalExecutor.
// In full fallback, the current executor does not exist.
if (!current_executor.has_value() || current_executor.value()->fallbackMode())
{
auto first_block = local_engine::SourceFromJavaIter::peekBlock(env, iter);
if (first_block.has_value())
Expand All @@ -574,9 +575,9 @@ local_engine::SplitterHolder * buildAndExecuteShuffle(JNIEnv * env,
splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique<local_engine::SparkExchangeManager>(current_executor->getHeader().cloneEmpty(), name, options, rss_pusher)};
// TODO support multiple sinks
splitter->exchange_manager->initSinks(1);
current_executor->setSinks([&](auto & pipeline_builder) { splitter->exchange_manager->setSinksToPipeline(pipeline_builder);});
current_executor.value()->setSinks([&](auto & pipeline_builder) { splitter->exchange_manager->setSinksToPipeline(pipeline_builder);});
// execute pipeline
current_executor->execute();
current_executor.value()->execute();
}
return splitter;
}
Expand Down

0 comments on commit e00a93c

Please sign in to comment.