diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 252b9bc03fd68..3321ae8801799 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch @@ -115,9 +114,6 @@ case class CHShuffledHashJoinExecTransformer( override def genJoinParameters(): Any = { val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") - // Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer` - val leftStats = getShuffleStageStatistics(streamedPlan) - val rightStats = getShuffleStageStatistics(buildPlan) // Start with "JoinParameters:" val joinParametersStr = new StringBuffer("JoinParameters:") // isBHJ: 0 for SHJ, 1 for BHJ @@ -138,14 +134,12 @@ case class CHShuffledHashJoinExecTransformer( .append("\n") logicalLink match { case Some(join: Join) => - val leftRowCount = - if (needSwitchChildren) join.left.stats.rowCount else join.right.stats.rowCount - val rightRowCount = - if (needSwitchChildren) join.right.stats.rowCount else join.left.stats.rowCount - val leftSizeInBytes = - if (needSwitchChildren) join.left.stats.sizeInBytes else join.right.stats.sizeInBytes - val rightSizeInBytes = - if (needSwitchChildren) join.right.stats.sizeInBytes else join.left.stats.sizeInBytes + val left = if (!needSwitchChildren) join.left else join.right + val right = if (!needSwitchChildren) join.right else join.left + val leftRowCount = left.stats.rowCount + val rightRowCount = right.stats.rowCount + val leftSizeInBytes = left.stats.sizeInBytes + val rightSizeInBytes = right.stats.sizeInBytes val numPartitions = outputPartitioning.numPartitions joinParametersStr .append("leftRowCount=") @@ -171,26 +165,6 @@ case class CHShuffledHashJoinExecTransformer( .build() BackendsApiManager.getTransformerApiInstance.packPBMessage(message) } - - private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = { - plan match { - case queryStage: ShuffleQueryStageExec => - ShuffleStageStaticstics( - queryStage.shuffle.numPartitions, - queryStage.shuffle.numMappers, - queryStage.getRuntimeStatistics.rowCount) - case shuffle: ColumnarShuffleExchangeExec => - // FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here. - // Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch` - ShuffleStageStaticstics(-1, -1, None) - case _ => - if (plan.children.length == 1) { - getShuffleStageStatistics(plan.children.head) - } else { - ShuffleStageStaticstics(-1, -1, None) - } - } - } } case class CHBroadcastBuildSideRDD( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index ebeb69c16350a..94d9b6f015ffc 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -265,9 +265,9 @@ class GlutenClickHouseColumnarShuffleAQESuite test("GLUTEN-6768 change mixed join condition into multi join on clauses") { withSQLConf( - (backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"), + (backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"), ( - backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit", + backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000") ) { diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index fbab2609412ab..70e051b02cadc 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -315,7 +315,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q if (table_join->getClauses().empty()) table_join->addDisjunct(); bool is_multi_join_on_clauses - = isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); + = couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0 && join_opt_info.partitions_num > 0 && join_opt_info.right_table_rows / join_opt_info.partitions_num @@ -611,14 +611,14 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J } /// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4) -bool JoinRelParser::isJoinWithMultiJoinOnClauses( +bool JoinRelParser::couldRewriteToMultiJoinOnClauses( const DB::TableJoin::JoinOnClause & prefix_clause, std::vector & clauses, const substrait::JoinRel & join_rel, const DB::Block & left_header, const DB::Block & right_header) { - /// There is only on join clause + /// There is only one join clause if (!join_rel.has_post_join_filter()) return false; diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index 0c0d07d6fdd27..7e43187be308b 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -72,7 +72,7 @@ class JoinRelParser : public RelParser static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); - bool isJoinWithMultiJoinOnClauses( + bool couldRewriteToMultiJoinOnClauses( const DB::TableJoin::JoinOnClause & prefix_clause, std::vector & clauses, const substrait::JoinRel & join_rel,