From db7d4be2290ab3ff3994d0deebd9924e4fc4c2d5 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 20 Aug 2024 16:44:34 +0800 Subject: [PATCH 1/3] fix: total_bytes_written is not updated in celeborn partition writer --- cpp-ch/local-engine/Shuffle/PartitionWriter.cpp | 2 ++ cpp-ch/local-engine/local_engine_jni.cpp | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index 2f22d0e24139..79d640d3b2bc 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -469,6 +469,7 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions() celeborn_client->pushPartitionData(cur_partition_id, data.data(), data.size()); shuffle_writer->split_result.total_io_time += push_time_watch.elapsedNanoseconds(); shuffle_writer->split_result.partition_lengths[cur_partition_id] += data.size(); + shuffle_writer->split_result.total_bytes_written += data.size(); } output.restart(); }; @@ -586,6 +587,7 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id) shuffle_writer->split_result.total_write_time += push_time_watch.elapsedNanoseconds(); shuffle_writer->split_result.total_io_time += push_time_watch.elapsedNanoseconds(); shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); + shuffle_writer->split_result.total_bytes_written += written_bytes; }; Stopwatch spill_time_watch; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 828556b4abf6..ce536799d94a 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -677,6 +677,11 @@ JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_ const auto * raw_src = reinterpret_cast(raw_partition_lengths.data()); env->SetLongArrayRegion(raw_partition_length_arr, 0, raw_partition_lengths.size(), raw_src); + // 
AQE has dependency on total_bytes_written, if the data is wrong, it will generate inappropriate plan + // add a log here for remining this. + if (!result.total_bytes_written) + LOG_WARNING(getLogger("_CHShuffleSplitterJniWrapper"), "total_bytes_written is 0, something may be wrong"); + jobject split_result = env->NewObject( split_result_class, split_result_constructor, From e7bbd400b6d4faa81091c3856c5eb5010cc69834 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 20 Aug 2024 18:10:59 +0800 Subject: [PATCH 2/3] add checks --- .../backendsapi/clickhouse/CHBackend.scala | 7 +++++++ .../clickhouse/CHSparkPlanExecApi.scala | 17 +++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 9884a0c6ef39..ffd9068b166a 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -407,4 +407,11 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } } + def getBroadcastThreshold: Long = { + val conf = SQLConf.get + conf + .getConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD) + .getOrElse(conf.autoBroadcastJoinThreshold) + } + } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 8fdc2645a5fb..cb261e6e416c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -18,6 +18,7 @@ package org.apache.gluten.backendsapi.clickhouse import org.apache.gluten.GlutenConfig import 
org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} +import org.apache.gluten.exception.GlutenException import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ @@ -31,7 +32,7 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy} import org.apache.gluten.vectorized.CHColumnarBatchSerializer -import org.apache.spark.{ShuffleDependency, SparkException} +import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper} @@ -539,9 +540,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { CHExecUtil.buildSideRDD(dataSize, newChild).collect val batches = countsAndBytes.map(_._2) + val totalBatchesBytes = batches.map(_.length).sum + // totalBatchesBytes could be larger than the shuffle written bytes, so we double the threshold + // here. + if ( + totalBatchesBytes < 0 || + totalBatchesBytes.toLong > CHBackendSettings.getBroadcastThreshold * 2 + ) { + throw new GlutenException( + s"Cannot broadcast the table ($totalBatchesBytes) that is larger than threshold:" + + s" ${CHBackendSettings.getBroadcastThreshold}. 
Ensure the shuffle written" + + s"bytes is collected properly.") + } val rawSize = dataSize.value if (rawSize >= BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES) { - throw new SparkException( + throw new GlutenException( s"Cannot broadcast the table that is larger than 8GB: ${rawSize >> 30} GB") } val rowCount = countsAndBytes.map(_._1).sum From aec2764ea857bef5a2e46146b5386a27f4ba51b8 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 20 Aug 2024 20:32:24 +0800 Subject: [PATCH 3/3] remove check, it's not safe --- .../backendsapi/clickhouse/CHBackend.scala | 8 ------- .../clickhouse/CHSparkPlanExecApi.scala | 23 ++++++++----------- cpp-ch/local-engine/local_engine_jni.cpp | 2 +- 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index ffd9068b166a..d0dbd98a88c2 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -406,12 +406,4 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } } } - - def getBroadcastThreshold: Long = { - val conf = SQLConf.get - conf - .getConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD) - .getOrElse(conf.autoBroadcastJoinThreshold) - } - } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index cb261e6e416c..99a417a5f9cd 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -33,6 +33,7 @@ import org.apache.gluten.utils.{CHJoinValidateUtil, 
UnknownJoinStrategy} import org.apache.gluten.vectorized.CHColumnarBatchSerializer import org.apache.spark.ShuffleDependency +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper} @@ -70,7 +71,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -class CHSparkPlanExecApi extends SparkPlanExecApi { +class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { /** The columnar-batch type this backend is using. */ override def batchType: Convention.BatchType = CHBatch @@ -540,22 +541,16 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { CHExecUtil.buildSideRDD(dataSize, newChild).collect val batches = countsAndBytes.map(_._2) - val totalBatchesBytes = batches.map(_.length).sum - // totalBatchesBytes could be larger than the shuffle written bytes, so we double the threshold - // here. - if ( - totalBatchesBytes < 0 || - totalBatchesBytes.toLong > CHBackendSettings.getBroadcastThreshold * 2 - ) { - throw new GlutenException( - s"Cannot broadcast the table ($totalBatchesBytes) that is larger than threshold:" + - s" ${CHBackendSettings.getBroadcastThreshold}. Ensure the shuffle written" + - s"bytes is collected properly.") - } + val totalBatchesSize = batches.map(_.length).sum val rawSize = dataSize.value if (rawSize >= BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES) { throw new GlutenException( - s"Cannot broadcast the table that is larger than 8GB: ${rawSize >> 30} GB") + s"Cannot broadcast the table that is larger than 8GB: $rawSize bytes") + } + if ((rawSize == 0 && totalBatchesSize != 0) || totalBatchesSize < 0) { + throw new GlutenException( + s"Invalid rawSize($rawSize) or totalBatchesSize ($totalBatchesSize). 
Ensure the shuffle" + + s" written bytes is correct.") } val rowCount = countsAndBytes.map(_._1).sum numOutputRows += rowCount diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index ce536799d94a..9727fca1937d 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -680,7 +680,7 @@ JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_ // AQE has dependency on total_bytes_written, if the data is wrong, it will generate inappropriate plan // add a log here for remining this. if (!result.total_bytes_written) - LOG_WARNING(getLogger("_CHShuffleSplitterJniWrapper"), "total_bytes_written is 0, something may be wrong"); + LOG_WARNING(getLogger("CHShuffleSplitterJniWrapper"), "total_bytes_written is 0, something may be wrong"); jobject split_result = env->NewObject( split_result_class,