
Commit

fix error when fallback happened
liuneng1994 committed Aug 12, 2024
1 parent 5d89e56 commit bf78670
Showing 11 changed files with 154 additions and 55 deletions.
@@ -16,12 +16,15 @@
*/
package org.apache.gluten.vectorized;

import org.apache.gluten.execution.ColumnarNativeIterator;

import java.io.IOException;

public class CHShuffleSplitterJniWrapper {
public CHShuffleSplitterJniWrapper() {}

public long make(
ColumnarNativeIterator records,
NativePartitioning part,
int shuffleId,
long mapId,
@@ -36,6 +39,7 @@ public long make(
long maxSortBufferSize,
boolean forceMemorySort) {
return nativeMake(
records,
part.getShortName(),
part.getNumPartitions(),
part.getExprList(),
@@ -55,6 +59,7 @@ }
}

public long makeForRSS(
ColumnarNativeIterator records,
NativePartitioning part,
int shuffleId,
long mapId,
@@ -66,6 +71,7 @@ public long makeForRSS(
Object pusher,
boolean forceMemorySort) {
return nativeMakeForRSS(
records,
part.getShortName(),
part.getNumPartitions(),
part.getExprList(),
@@ -82,6 +88,7 @@ public long makeForRSS(
}

public native long nativeMake(
ColumnarNativeIterator records,
String shortName,
int numPartitions,
byte[] exprList,
@@ -100,6 +107,7 @@ public native long nativeMake(
boolean forceMemorySort);

public native long nativeMakeForRSS(
ColumnarNativeIterator records,
String shortName,
int numPartitions,
byte[] exprList,
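Both nativeMake and nativeMakeForRSS now receive the record iterator in addition to the partitioning description, so the native splitter can pull batches from the JVM when the stage that feeds the shuffle fell back. How the native side consumes that jobject is not visible in this excerpt; the following is only a rough sketch of such a pull using plain JNI, with an illustrative class name and an assumed next() signature.

```cpp
// Hedged sketch: pulling batches from the Java-side ColumnarNativeIterator with plain JNI.
// The adapter Gluten actually uses is not part of this excerpt; the class below and the
// assumption that next() hands back a native batch handle as a long are illustrative only.
#include <jni.h>

#include <stdexcept>

class JavaBatchIterator
{
public:
    // The JNIEnv pointer is only valid on the thread that received it.
    JavaBatchIterator(JNIEnv * env_, jobject records_) : env(env_), records(env_->NewGlobalRef(records_))
    {
        jclass cls = env->GetObjectClass(records);
        has_next_method = env->GetMethodID(cls, "hasNext", "()Z");
        next_method = env->GetMethodID(cls, "next", "()J"); // assumed signature
        if (!has_next_method || !next_method)
            throw std::runtime_error("iterator does not expose the expected methods");
    }

    ~JavaBatchIterator() { env->DeleteGlobalRef(records); }

    bool hasNext() const { return env->CallBooleanMethod(records, has_next_method) == JNI_TRUE; }

    // Opaque handle to the next batch; the caller turns it into a block for the splitter.
    jlong next() const { return env->CallLongMethod(records, next_method); }

private:
    JNIEnv * env;
    jobject records;
    jmethodID has_next_method = nullptr;
    jmethodID next_method = nullptr;
};
```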
@@ -18,12 +18,13 @@ package org.apache.spark.shuffle

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkDirectoryUtil, Utils}

import java.io.IOException
@@ -84,8 +85,21 @@ class CHColumnarShuffleWriter[K, V](
val splitterJniWrapper: CHShuffleSplitterJniWrapper = jniWrapper

val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId))
// Wrap the record iterator handed to this writer so that, when a fallback happened,
// the native splitter can pull the batches from the JVM side.
val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] {
override def hasNext: Boolean = records.hasNext

override def next(): ColumnarBatch = records.next()._2.asInstanceOf[ColumnarBatch]
})
if (nativeSplitter == 0) {
nativeSplitter = splitterJniWrapper.make(
iter,
dep.nativePartitioning,
dep.shuffleId,
mapId,
@@ -101,7 +115,6 @@ class CHColumnarShuffleWriter[K, V](
forceMemorySortShuffle
)
}

splitResult = splitterJniWrapper.stop(nativeSplitter)
if (splitResult.getTotalRows > 0) {
dep.metrics("numInputRows").add(splitResult.getTotalRows)
9 changes: 6 additions & 3 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1667,12 +1667,15 @@ void LocalExecutor::cancel()

void LocalExecutor::execute()
{
chassert(query_pipeline_builder);
push_executor = query_pipeline_builder->execute();
chassert(query_pipeline_builder || external_pipeline_builder);
if (external_pipeline_builder)
push_executor = external_pipeline_builder->execute();
else
push_executor = query_pipeline_builder->execute();
push_executor->execute(local_engine::QueryContextManager::instance().currentQueryContext()->getSettingsRef().max_threads, false);
}

Block & LocalExecutor::getHeader()
Block LocalExecutor::getHeader()
{
return header;
}
16 changes: 14 additions & 2 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -211,7 +211,6 @@ struct SparkBuffer
class LocalExecutor : public BlockIterator
{
public:
static thread_local LocalExecutor * current_executor;
static LocalExecutor * getCurrentExecutor() { return current_executor; }
static void resetCurrentExecutor() { current_executor = nullptr; }
LocalExecutor(QueryPlanPtr query_plan, QueryPipelineBuilderPtr pipeline, bool dump_pipeline_ = false);
@@ -220,26 +219,39 @@ class LocalExecutor : public BlockIterator
SparkRowInfoPtr next();
Block * nextColumnar();
bool hasNext();
// When a fallback occurs, hasNext will be called to trigger the initialization of the pulling executor
bool initByPulling()
{
return executor.get();
}

/// Stop execution; used when the task receives a shutdown command or the executor receives a SIGTERM signal
void cancel();
void setSinks(std::function<void(QueryPipelineBuilder &)> setter)
{
setter(*query_pipeline_builder);
}
// Set the shuffle write pipeline to be executed when a fallback occurs
void setExternalPipelineBuilder(QueryPipelineBuilderPtr builder)
{
external_pipeline_builder = std::move(builder);
}
void execute();
Block & getHeader();
Block getHeader();
RelMetricPtr getMetric() const { return metric; }
void setMetric(const RelMetricPtr & metric_) { metric = metric_; }
void setExtraPlanHolder(std::vector<QueryPlanPtr> & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); }

private:
static thread_local LocalExecutor * current_executor;
std::unique_ptr<SparkRowInfo> writeBlockToSparkRow(const DB::Block & block) const;
void initPullingPipelineExecutor();
/// Dump processor runtime information to log
std::string dumpPipeline() const;

QueryPipelineBuilderPtr query_pipeline_builder;
// Final shuffle write pipeline, used when a fallback occurs
QueryPipelineBuilderPtr external_pipeline_builder = nullptr;
QueryPipeline query_pipeline;
std::unique_ptr<DB::PullingAsyncPipelineExecutor> executor = nullptr;
PipelineExecutorPtr push_executor = nullptr;
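Taken together, getCurrentExecutor(), initByPulling() and setExternalPipelineBuilder() let the shuffle-write code ask whether the current task's plan is already being pulled from the JVM (that is, a fallback happened) and, if so, hand the executor a separately built shuffle-write pipeline. The actual call site lives in the splitter JNI code, which is not part of this excerpt; below is a minimal sketch of the detection half, under the assumptions that LocalExecutor sits in the local_engine namespace and that the header is reachable via the include path shown.

```cpp
// Hedged sketch of how a caller might use the API above to detect a fallback; the helper
// itself is illustrative and not part of this commit.
#include <Parser/SerializedPlanParser.h> // assumed include path within cpp-ch/local-engine

bool currentTaskFellBack()
{
    auto * executor = local_engine::LocalExecutor::getCurrentExecutor();
    // No executor is registered for this thread, or it is already driven by pulling
    // (hasNext was called from the JVM): the shuffle input comes from a fallback plan.
    return executor == nullptr || executor->initByPulling();
}
```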
8 changes: 4 additions & 4 deletions cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
@@ -140,7 +140,7 @@ size_t LocalPartitionWriter::evictPartitions()
{
auto file = getNextSpillFile();
WriteBufferFromFile output(file, options.io_buffer_size);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level);
CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size);
NativeWriter writer(compressed_output, output_header);

@@ -273,7 +273,7 @@ size_t MemorySortLocalPartitionWriter::evictPartitions()
return;
auto file = getNextSpillFile();
WriteBufferFromFile output(file, options.io_buffer_size);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level);
CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size);
NativeWriter writer(compressed_output, output_header);

@@ -368,7 +368,7 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions()
return;

WriteBufferFromOwnString output;
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level);
CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size);
NativeWriter writer(compressed_output, output_header);

@@ -475,7 +475,7 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id)
return;

WriteBufferFromOwnString output;
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level);
CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size);
NativeWriter writer(compressed_output, output_header);

15 changes: 9 additions & 6 deletions cpp-ch/local-engine/Shuffle/PartitionWriter.h
@@ -70,12 +70,15 @@ friend class Spillable;

void initialize(SplitResult * split_result_, const Block & output_header_)
{
chassert(split_result);
split_result = split_result_;
split_result->partition_lengths.resize(options.partition_num);
split_result->raw_partition_lengths.resize(options.partition_num);
output_header = output_header_;
init = true;
if (!init)
{
split_result = split_result_;
chassert(split_result != nullptr);
split_result->partition_lengths.resize(options.partition_num);
split_result->raw_partition_lengths.resize(options.partition_num);
output_header = output_header_;
init = true;
}
}
virtual String getName() const = 0;

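The old initialize() asserted on split_result before it had been assigned and unconditionally redid the setup; the new body runs the setup only once, presumably because initialize() can now be reached more than once for the same writer under the exchange-manager wiring. A minimal, self-contained illustration of this guarded-initialization shape (demo names, not the commit's):

```cpp
// Self-contained illustration of the new guard in PartitionWriter::initialize():
// only the first call performs the setup, later calls become no-ops.
#include <cassert>
#include <cstddef>
#include <vector>

struct DemoSplitResult
{
    std::vector<size_t> partition_lengths;
    std::vector<size_t> raw_partition_lengths;
};

class DemoPartitionWriter
{
public:
    explicit DemoPartitionWriter(size_t partition_num_) : partition_num(partition_num_) {}

    void initialize(DemoSplitResult * split_result_)
    {
        if (init)
            return; // already initialized by an earlier wiring attempt
        split_result = split_result_;
        assert(split_result != nullptr);
        split_result->partition_lengths.resize(partition_num);
        split_result->raw_partition_lengths.resize(partition_num);
        init = true;
    }

private:
    size_t partition_num;
    DemoSplitResult * split_result = nullptr;
    bool init = false;
};
```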
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Shuffle/ShuffleCommon.h
@@ -26,7 +26,7 @@

namespace local_engine
{
class SparkExechangeManager;
class SparkExchangeManager;
}

namespace local_engine
@@ -117,7 +117,7 @@ struct SplitResult

struct SplitterHolder
{
std::unique_ptr<SparkExechangeManager> exechange_manager;
std::unique_ptr<SparkExchangeManager> exchange_manager;
};


27 changes: 14 additions & 13 deletions cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
@@ -102,7 +102,7 @@ void SparkExchangeSink::initOutputHeader(const Block & block)
}
}

SparkExechangeManager::SparkExechangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(header), options(options_)
SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(header), options(options_)
{
if (rss_pusher)
{
@@ -143,7 +143,7 @@ std::shared_ptr<PartitionWriter> createPartitionWriter(const SplitOptions& optio
return std::make_shared<LocalPartitionWriter>(options);
}

void SparkExechangeManager::initSinks(size_t num)
void SparkExchangeManager::initSinks(size_t num)
{
if (num > 1 && celeborn_client)
{
@@ -158,7 +158,7 @@ void SparkExechangeManager::initSinks(size_t num)
}
}

void SparkExechangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const
void SparkExchangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const
{
size_t count = 0;
Pipe::ProcessorGetterWithStreamKind getter = [&](const Block & header, Pipe::StreamType stream_type) -> ProcessorPtr
@@ -173,12 +173,12 @@ void SparkExechangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeli
pipeline.setSinks(getter);
}

SelectBuilderPtr SparkExechangeManager::createRoundRobinSelectorBuilder(const SplitOptions & options_)
SelectBuilderPtr SparkExchangeManager::createRoundRobinSelectorBuilder(const SplitOptions & options_)
{
return std::make_unique<RoundRobinSelectorBuilder>(options_.partition_num);
}

SelectBuilderPtr SparkExechangeManager::createHashSelectorBuilder(const SplitOptions & options_)
SelectBuilderPtr SparkExchangeManager::createHashSelectorBuilder(const SplitOptions & options_)
{
Poco::StringTokenizer expr_list(options_.hash_exprs, ",");
std::vector<size_t> hash_fields;
@@ -189,18 +189,18 @@ SelectBuilderPtr SparkExechangeManager::createHashSelectorBuilder(const SplitOpt
return std::make_unique<HashSelectorBuilder>(options_.partition_num, hash_fields, options_.hash_algorithm);
}

SelectBuilderPtr SparkExechangeManager::createSingleSelectorBuilder(const SplitOptions & options_)
SelectBuilderPtr SparkExchangeManager::createSingleSelectorBuilder(const SplitOptions & options_)
{
chassert(options_.partition_num == 1);
return std::make_unique<RoundRobinSelectorBuilder>(options_.partition_num);
}

SelectBuilderPtr SparkExechangeManager::createRangeSelectorBuilder(const SplitOptions & options_)
SelectBuilderPtr SparkExchangeManager::createRangeSelectorBuilder(const SplitOptions & options_)
{
return std::make_unique<RangeSelectorBuilder>(options_.hash_exprs, options_.partition_num);
}

void SparkExechangeManager::finish()
void SparkExchangeManager::finish()
{
Stopwatch wall_time;
mergeSplitResult();
@@ -222,7 +222,7 @@ void SparkExechangeManager::finish()
split_result.wall_time += wall_time.elapsedNanoseconds();
}

void SparkExechangeManager::mergeSplitResult()
void SparkExchangeManager::mergeSplitResult()
{
for (const auto & sink : sinks)
{
@@ -242,7 +242,7 @@ void SparkExechangeManager::mergeSplitResult()
}
}

std::vector<SpillInfo> SparkExechangeManager::gatherAllSpillInfo()
std::vector<SpillInfo> SparkExchangeManager::gatherAllSpillInfo()
{
std::vector<SpillInfo> res;
for (const auto& writer : partition_writers)
@@ -256,9 +256,10 @@ std::vector<SpillInfo> SparkExechangeManager::gatherAllSpillInfo()
return res;
}

std::vector<UInt64> SparkExechangeManager::mergeSpills(DB::WriteBuffer & data_file, const std::vector<SpillInfo>& spill_infos, const std::vector<Spillable::ExtraData> & extra_datas)
std::vector<UInt64> SparkExchangeManager::mergeSpills(DB::WriteBuffer & data_file, const std::vector<SpillInfo>& spill_infos, const std::vector<Spillable::ExtraData> & extra_datas)
{
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), {});
if (sinks.empty()) return {};
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level);

CompressedWriteBuffer compressed_output(data_file, codec, options.io_buffer_size);
NativeWriter writer(compressed_output, sinks.front()->getOutputHeaderCopy());
@@ -340,7 +341,7 @@ std::vector<UInt64> SparkExechangeManager::mergeSpills(DB::WriteBuffer & data_fi
return partition_length;
}

std::unordered_map<String, SelectBuilderCreator> SparkExechangeManager::partitioner_creators = {
std::unordered_map<String, SelectBuilderCreator> SparkExchangeManager::partitioner_creators = {
{"rr", createRoundRobinSelectorBuilder},
{"hash", createHashSelectorBuilder},
{"single", createSingleSelectorBuilder},
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Shuffle/SparkExchangeSink.h
@@ -81,10 +81,10 @@ class SparkExchangeSink : public DB::ISink
using SelectBuilderPtr = std::unique_ptr<SelectorBuilder>;
using SelectBuilderCreator = std::function<SelectBuilderPtr(const SplitOptions &)>;

class SparkExechangeManager
class SparkExchangeManager
{
public:
SparkExechangeManager(const DB::Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher = nullptr);
SparkExchangeManager(const DB::Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher = nullptr);
void initSinks(size_t num);
void setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const;
void finish();
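For orientation, here is how these declarations are presumably meant to be combined with the LocalExecutor changes above on the write path. The real driver is the splitter JNI implementation, which is among the files not shown in this view; the sketch only uses names that appear in this diff, but the control flow itself is a reconstruction, not code from the commit.

```cpp
// Hedged sketch of the expected write-path wiring; a reconstruction for readability.
#include <Parser/SerializedPlanParser.h> // assumed include paths within cpp-ch/local-engine
#include <Shuffle/SparkExchangeSink.h>

void runShuffleWrite(
    local_engine::LocalExecutor * executor,
    local_engine::SparkExchangeManager & exchange_manager,
    DB::QueryPipelineBuilderPtr fallback_pipeline, // built from the JVM record iterator
    size_t num_sinks)
{
    exchange_manager.initSinks(num_sinks);
    if (executor->initByPulling())
    {
        // Fallback: rows already flow through the JVM iterator; attach the exchange sinks
        // to the separately built pipeline and let the executor run that one instead.
        exchange_manager.setSinksToPipeline(*fallback_pipeline);
        executor->setExternalPipelineBuilder(std::move(fallback_pipeline));
    }
    else
    {
        // Pure native plan: append the exchange sinks to the executor's own pipeline.
        executor->setSinks([&](DB::QueryPipelineBuilder & builder) { exchange_manager.setSinksToPipeline(builder); });
    }
    executor->execute();
    // Merge per-sink results and spills into the final SplitResult.
    exchange_manager.finish();
}
```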
(Diffs for the remaining changed files were not loaded in this view.)
