From 02c3893e7a85694a3a6a7ea37b9515a9f8197b15 Mon Sep 17 00:00:00 2001 From: liuneng1994 Date: Mon, 12 Aug 2024 17:48:25 +0800 Subject: [PATCH] fix error when fallback happened --- .../CHShuffleSplitterJniWrapper.java | 8 ++ .../shuffle/CHColumnarShuffleWriter.scala | 16 +++- .../Parser/SerializedPlanParser.cpp | 9 +- .../Parser/SerializedPlanParser.h | 16 +++- .../local-engine/Shuffle/PartitionWriter.cpp | 8 +- cpp-ch/local-engine/Shuffle/PartitionWriter.h | 15 ++-- cpp-ch/local-engine/Shuffle/ShuffleCommon.h | 4 +- .../Shuffle/SparkExchangeSink.cpp | 27 +++--- .../local-engine/Shuffle/SparkExchangeSink.h | 4 +- cpp-ch/local-engine/local_engine_jni.cpp | 83 +++++++++++++++---- .../CHCelebornColumnarShuffleWriter.scala | 18 +++- 11 files changed, 154 insertions(+), 54 deletions(-) diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java index 66cefa62e2393..64d41f306a661 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java @@ -16,12 +16,15 @@ */ package org.apache.gluten.vectorized; +import org.apache.gluten.execution.ColumnarNativeIterator; + import java.io.IOException; public class CHShuffleSplitterJniWrapper { public CHShuffleSplitterJniWrapper() {} public long make( + ColumnarNativeIterator records, NativePartitioning part, int shuffleId, long mapId, @@ -36,6 +39,7 @@ public long make( long maxSortBufferSize, boolean forceMemorySort) { return nativeMake( + records, part.getShortName(), part.getNumPartitions(), part.getExprList(), @@ -55,6 +59,7 @@ public long make( } public long makeForRSS( + ColumnarNativeIterator records, NativePartitioning part, int shuffleId, long mapId, @@ -66,6 +71,7 @@ public long makeForRSS( Object pusher, boolean forceMemorySort) { return nativeMakeForRSS( + records, part.getShortName(), part.getNumPartitions(), part.getExprList(), @@ -82,6 +88,7 @@ public long makeForRSS( } public native long nativeMake( + ColumnarNativeIterator records, String shortName, int numPartitions, byte[] exprList, @@ -100,6 +107,7 @@ public native long nativeMake( boolean forceMemorySort); public native long nativeMakeForRSS( + ColumnarNativeIterator records, String shortName, int numPartitions, byte[] exprList, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala index 3f25762b0947a..41fc114509141 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala @@ -18,12 +18,14 @@ package org.apache.spark.shuffle import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings +import org.apache.gluten.execution.ColumnarNativeIterator import org.apache.gluten.memory.CHThreadGroup import org.apache.gluten.vectorized._ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkDirectoryUtil, Utils} import java.io.IOException @@ -84,8 +86,21 @@ class CHColumnarShuffleWriter[K, V]( val splitterJniWrapper: CHShuffleSplitterJniWrapper = jniWrapper val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)) + // for fallback + val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val has_value = records.hasNext + has_value + } + + override def next(): ColumnarBatch = { + val batch = records.next()._2.asInstanceOf[ColumnarBatch] + batch + } + }) if (nativeSplitter == 0) { nativeSplitter = splitterJniWrapper.make( + iter, dep.nativePartitioning, dep.shuffleId, mapId, @@ -101,7 +116,6 @@ class CHColumnarShuffleWriter[K, V]( forceMemorySortShuffle ) } - splitResult = splitterJniWrapper.stop(nativeSplitter) if (splitResult.getTotalRows > 0) { dep.metrics("numInputRows").add(splitResult.getTotalRows) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 3537083d9342e..0e07b3ccbddd0 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -1667,12 +1667,15 @@ void LocalExecutor::cancel() void LocalExecutor::execute() { - chassert(query_pipeline_builder); - push_executor = query_pipeline_builder->execute(); + chassert(query_pipeline_builder || external_pipeline_builder); + if (external_pipeline_builder) + push_executor = external_pipeline_builder->execute(); + else + push_executor = query_pipeline_builder->execute(); push_executor->execute(local_engine::QueryContextManager::instance().currentQueryContext()->getSettingsRef().max_threads, false); } -Block & LocalExecutor::getHeader() +Block LocalExecutor::getHeader() { return header; } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 2f3c365ebd3ea..04573b7581c52 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -211,7 +211,6 @@ struct SparkBuffer class LocalExecutor : public BlockIterator { public: - static thread_local LocalExecutor * current_executor; static LocalExecutor * getCurrentExecutor() { return current_executor; } static void resetCurrentExecutor() { current_executor = nullptr; } LocalExecutor(QueryPlanPtr query_plan, QueryPipelineBuilderPtr pipeline, bool dump_pipeline_ = false); @@ -220,6 +219,11 @@ class LocalExecutor : public BlockIterator SparkRowInfoPtr next(); Block * nextColumnar(); bool hasNext(); + // When a fallback occurs, hasNext will be called to trigger the initialization of the pulling executor + bool initByPulling() + { + return executor.get(); + } /// Stop execution, used when task receives shutdown command or executor receives SIGTERM signal void cancel(); @@ -227,19 +231,27 @@ class LocalExecutor : public BlockIterator { setter(*query_pipeline_builder); } + // set shuffle write pipeline for fallback + void setExternalPipelineBuilder(QueryPipelineBuilderPtr builder) + { + external_pipeline_builder = std::move(builder); + } void execute(); - Block & getHeader(); + Block getHeader(); RelMetricPtr getMetric() const { return metric; } void setMetric(const RelMetricPtr & metric_) { metric = metric_; } void setExtraPlanHolder(std::vector & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); } private: + static thread_local LocalExecutor * current_executor; std::unique_ptr writeBlockToSparkRow(const DB::Block & block) const; void initPullingPipelineExecutor(); /// Dump processor runtime information to log std::string dumpPipeline() const; QueryPipelineBuilderPtr query_pipeline_builder; + // final shuffle write pipeline for fallback + QueryPipelineBuilderPtr external_pipeline_builder = nullptr; QueryPipeline query_pipeline; std::unique_ptr executor = nullptr; PipelineExecutorPtr push_executor = nullptr; diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index c76dd6fff9c4c..0110c7380abaa 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -140,7 +140,7 @@ size_t LocalPartitionWriter::evictPartitions() { auto file = getNextSpillFile(); WriteBufferFromFile output(file, options.io_buffer_size); - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); NativeWriter writer(compressed_output, output_header); @@ -273,7 +273,7 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() return; auto file = getNextSpillFile(); WriteBufferFromFile output(file, options.io_buffer_size); - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); NativeWriter writer(compressed_output, output_header); @@ -368,7 +368,7 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions() return; WriteBufferFromOwnString output; - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); NativeWriter writer(compressed_output, output_header); @@ -475,7 +475,7 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id) return; WriteBufferFromOwnString output; - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), shuffle_writer->options.compress_level); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); NativeWriter writer(compressed_output, output_header); diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.h b/cpp-ch/local-engine/Shuffle/PartitionWriter.h index 3ca251abea957..407998121287c 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.h +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.h @@ -70,12 +70,15 @@ friend class Spillable; void initialize(SplitResult * split_result_, const Block & output_header_) { - chassert(split_result); - split_result = split_result_; - split_result->partition_lengths.resize(options.partition_num); - split_result->raw_partition_lengths.resize(options.partition_num); - output_header = output_header_; - init = true; + if (!init) + { + split_result = split_result_; + chassert(split_result != nullptr); + split_result->partition_lengths.resize(options.partition_num); + split_result->raw_partition_lengths.resize(options.partition_num); + output_header = output_header_; + init = true; + } } virtual String getName() const = 0; diff --git a/cpp-ch/local-engine/Shuffle/ShuffleCommon.h b/cpp-ch/local-engine/Shuffle/ShuffleCommon.h index ab8e582ebf2d9..a2aa447f50ced 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleCommon.h +++ b/cpp-ch/local-engine/Shuffle/ShuffleCommon.h @@ -26,7 +26,7 @@ namespace local_engine { - class SparkExechangeManager; + class SparkExchangeManager; } namespace local_engine @@ -117,7 +117,7 @@ struct SplitResult struct SplitterHolder { - std::unique_ptr exechange_manager; + std::unique_ptr exchange_manager; }; diff --git a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp index 150fbc740f120..f5af7ca7680b0 100644 --- a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp +++ b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp @@ -102,7 +102,7 @@ void SparkExchangeSink::initOutputHeader(const Block & block) } } -SparkExechangeManager::SparkExechangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(header), options(options_) +SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(header), options(options_) { if (rss_pusher) { @@ -143,7 +143,7 @@ std::shared_ptr createPartitionWriter(const SplitOptions& optio return std::make_shared(options); } -void SparkExechangeManager::initSinks(size_t num) +void SparkExchangeManager::initSinks(size_t num) { if (num > 1 && celeborn_client) { @@ -158,7 +158,7 @@ void SparkExechangeManager::initSinks(size_t num) } } -void SparkExechangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const +void SparkExchangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const { size_t count = 0; Pipe::ProcessorGetterWithStreamKind getter = [&](const Block & header, Pipe::StreamType stream_type) -> ProcessorPtr @@ -173,12 +173,12 @@ void SparkExechangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeli pipeline.setSinks(getter); } -SelectBuilderPtr SparkExechangeManager::createRoundRobinSelectorBuilder(const SplitOptions & options_) +SelectBuilderPtr SparkExchangeManager::createRoundRobinSelectorBuilder(const SplitOptions & options_) { return std::make_unique(options_.partition_num); } -SelectBuilderPtr SparkExechangeManager::createHashSelectorBuilder(const SplitOptions & options_) +SelectBuilderPtr SparkExchangeManager::createHashSelectorBuilder(const SplitOptions & options_) { Poco::StringTokenizer expr_list(options_.hash_exprs, ","); std::vector hash_fields; @@ -189,18 +189,18 @@ SelectBuilderPtr SparkExechangeManager::createHashSelectorBuilder(const SplitOpt return std::make_unique(options_.partition_num, hash_fields, options_.hash_algorithm); } -SelectBuilderPtr SparkExechangeManager::createSingleSelectorBuilder(const SplitOptions & options_) +SelectBuilderPtr SparkExchangeManager::createSingleSelectorBuilder(const SplitOptions & options_) { chassert(options_.partition_num == 1); return std::make_unique(options_.partition_num); } -SelectBuilderPtr SparkExechangeManager::createRangeSelectorBuilder(const SplitOptions & options_) +SelectBuilderPtr SparkExchangeManager::createRangeSelectorBuilder(const SplitOptions & options_) { return std::make_unique(options_.hash_exprs, options_.partition_num); } -void SparkExechangeManager::finish() +void SparkExchangeManager::finish() { Stopwatch wall_time; mergeSplitResult(); @@ -222,7 +222,7 @@ void SparkExechangeManager::finish() split_result.wall_time += wall_time.elapsedNanoseconds(); } -void SparkExechangeManager::mergeSplitResult() +void SparkExchangeManager::mergeSplitResult() { for (const auto & sink : sinks) { @@ -242,7 +242,7 @@ void SparkExechangeManager::mergeSplitResult() } } -std::vector SparkExechangeManager::gatherAllSpillInfo() +std::vector SparkExchangeManager::gatherAllSpillInfo() { std::vector res; for (const auto& writer : partition_writers) @@ -256,9 +256,10 @@ std::vector SparkExechangeManager::gatherAllSpillInfo() return res; } -std::vector SparkExechangeManager::mergeSpills(DB::WriteBuffer & data_file, const std::vector& spill_infos, const std::vector & extra_datas) +std::vector SparkExchangeManager::mergeSpills(DB::WriteBuffer & data_file, const std::vector& spill_infos, const std::vector & extra_datas) { - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), {}); + if (sinks.empty()) return {}; + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); CompressedWriteBuffer compressed_output(data_file, codec, options.io_buffer_size); NativeWriter writer(compressed_output, sinks.front()->getOutputHeaderCopy()); @@ -340,7 +341,7 @@ std::vector SparkExechangeManager::mergeSpills(DB::WriteBuffer & data_fi return partition_length; } -std::unordered_map SparkExechangeManager::partitioner_creators = { +std::unordered_map SparkExchangeManager::partitioner_creators = { {"rr", createRoundRobinSelectorBuilder}, {"hash", createHashSelectorBuilder}, {"single", createSingleSelectorBuilder}, diff --git a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h index 69215fac09592..409a021d96082 100644 --- a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h +++ b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h @@ -81,10 +81,10 @@ class SparkExchangeSink : public DB::ISink using SelectBuilderPtr = std::unique_ptr; using SelectBuilderCreator = std::function; -class SparkExechangeManager +class SparkExchangeManager { public: - SparkExechangeManager(const DB::Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher = nullptr); + SparkExchangeManager(const DB::Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher = nullptr); void initSinks(size_t num); void setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const; void finish(); diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 57493f5bfd6bb..087a37b1b4578 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -49,11 +49,11 @@ #include #include #include -#include #include #include #include #include +#include #include @@ -532,10 +532,68 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_CHStreamReader_nativeClose(JNIE LOCAL_ENGINE_JNI_METHOD_END(env, ) } +local_engine::SplitterHolder * buildAndExecuteShuffle(JNIEnv * env, + jobject iter, + const String & name, + const local_engine::SplitOptions& options, + jobject rss_pusher = nullptr + ) +{ + auto * current_executor = local_engine::LocalExecutor::getCurrentExecutor(); + chassert(current_executor); + local_engine::SplitterHolder * splitter = nullptr; + // handle fallback, whole stage fallback or partial fallback + if (!current_executor || current_executor->initByPulling()) + { + auto * first_block = local_engine::SourceFromJavaIter::peekBlock(env, iter); + if (first_block) + { + /// Try to decide header from the first block read from Java iterator. + auto header = first_block->cloneEmpty(); + auto context = local_engine::QueryContextManager::instance().currentQueryContext(); + auto global_iter = env->NewGlobalRef(iter); + auto source = std::make_shared(context, first_block->cloneEmpty(), global_iter, true, first_block); + splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique(first_block->cloneEmpty(), name, options, rss_pusher)}; + + DB::QueryPipelineBuilderPtr builder = std::make_unique(); + builder->init(DB::Pipe(source)); + // fallback only support one sink + splitter->exchange_manager->initSinks(1); + splitter->exchange_manager->setSinksToPipeline(*builder); + if (current_executor) + { + // partial fallback, can't build whole stage pipeline + current_executor->setExternalPipelineBuilder(std::move(builder)); + current_executor->execute(); + } + else + { + // whole stage fallback which no LocalExecutor created but use columnar shuffle + auto executor = builder->execute(); + executor->execute(1, false); + } + } + else + // empty iterator + splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique(DB::Block(), name, options, rss_pusher)}; + } + else + { + splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique(current_executor->getHeader().cloneEmpty(), name, options, rss_pusher)}; + // TODO support multiple sinks + splitter->exchange_manager->initSinks(1); + current_executor->setSinks([&](auto & pipeline_builder) { splitter->exchange_manager->setSinksToPipeline(pipeline_builder);}); + // execute pipeline + current_executor->execute(); + } + return splitter; +} + // Splitter Jni Wrapper JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_nativeMake( JNIEnv * env, jobject, + jobject iter, jstring short_name, jint num_partitions, jbyteArray expr_list, @@ -592,20 +650,15 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na .max_sort_buffer_size = static_cast(max_sort_buffer_size), .force_memory_sort = static_cast(force_memory_sort)}; auto name = jstring2string(env, short_name); - auto * current_executor = local_engine::LocalExecutor::getCurrentExecutor(); - chassert(current_executor); - local_engine::SplitterHolder * splitter = new local_engine::SplitterHolder{.exechange_manager = std::make_unique(current_executor->getHeader(), name, options)}; - splitter->exechange_manager->initSinks(1); - current_executor->setSinks([&](auto & pipeline_builder) { splitter->exechange_manager->setSinksToPipeline(pipeline_builder);}); - // execute pipeline - current_executor->execute(); - return reinterpret_cast(splitter); + + return reinterpret_cast(buildAndExecuteShuffle(env, iter, name, options)); LOCAL_ENGINE_JNI_METHOD_END(env, -1) } JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_nativeMakeForRSS( JNIEnv * env, jobject, + jobject iter, jstring short_name, jint num_partitions, jbyteArray expr_list, @@ -651,13 +704,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na .hash_algorithm = jstring2string(env, hash_algorithm), .force_memory_sort = static_cast(force_memory_sort)}; auto name = jstring2string(env, short_name); - auto * current_executor = local_engine::LocalExecutor::getCurrentExecutor(); - chassert(current_executor); - local_engine::SplitterHolder * splitter = new local_engine::SplitterHolder{.exechange_manager = std::make_unique(current_executor->getHeader(), name, options, pusher)}; - splitter->exechange_manager->initSinks(local_engine::QueryContextManager::instance().currentQueryContext()->getSettingsRef().max_threads); - current_executor->setSinks([&](auto & pipeline_builder) { splitter->exechange_manager->setSinksToPipeline(pipeline_builder);}); - current_executor->execute(); - return reinterpret_cast(splitter); + return reinterpret_cast(buildAndExecuteShuffle(env, iter, name, options, pusher)); LOCAL_ENGINE_JNI_METHOD_END(env, -1) } @@ -665,8 +712,8 @@ JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_ { LOCAL_ENGINE_JNI_METHOD_START local_engine::SplitterHolder * splitter = reinterpret_cast(splitterId); - splitter->exechange_manager->finish(); - auto result = splitter->exechange_manager->getSplitResult(); + splitter->exchange_manager->finish(); + auto result = splitter->exchange_manager->getSplitResult(); const auto & partition_lengths = result.partition_lengths; auto * partition_length_arr = env->NewLongArray(partition_lengths.size()); diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala index 13a453f064338..6c97dc0ccf1f0 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala @@ -20,14 +20,14 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings import org.apache.gluten.memory.CHThreadGroup import org.apache.gluten.vectorized._ - import org.apache.spark._ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle - import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.gluten.execution.ColumnarNativeIterator +import org.apache.spark.sql.vectorized.ColumnarBatch import java.io.IOException import java.util.Locale @@ -57,7 +57,20 @@ class CHCelebornColumnarShuffleWriter[K, V]( @throws[IOException] override def internalWrite(records: Iterator[Product2[K, V]]): Unit = { CHThreadGroup.registerNewThreadGroup() + // for fallback + val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val has_value = records.hasNext + has_value + } + + override def next(): ColumnarBatch = { + val batch = records.next()._2.asInstanceOf[ColumnarBatch] + batch + } + }) nativeShuffleWriter = jniWrapper.makeForRSS( + iter, dep.nativePartitioning, shuffleId, mapId, @@ -71,7 +84,6 @@ class CHCelebornColumnarShuffleWriter[K, V]( || ShuffleMode.SORT.name.equalsIgnoreCase(shuffleWriterType) ) - val startTime = System.nanoTime() splitResult = jniWrapper.stop(nativeShuffleWriter) // If all of the ColumnarBatch have empty rows, the nativeShuffleWriter still equals -1 if (splitResult.getTotalRows == 0) {