From ed407907a29484abe59319b937fa92ac7524fb63 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 12 Aug 2024 16:42:46 +0800 Subject: [PATCH 1/4] try to use multi join on clauses when possible --- .../execution/CHHashJoinExecTransformer.scala | 75 +++++ ...tenClickHouseColumnarShuffleAQESuite.scala | 47 +++ cpp-ch/local-engine/Common/GlutenConfig.h | 21 ++ .../Parser/AdvancedParametersParseUtil.cpp | 16 + .../Parser/AdvancedParametersParseUtil.h | 6 + cpp-ch/local-engine/Parser/JoinRelParser.cpp | 302 ++++++++++++++---- cpp-ch/local-engine/Parser/JoinRelParser.h | 19 ++ 7 files changed, 432 insertions(+), 54 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 7080e55dc186..15e21681b1c6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} @@ -25,10 +26,13 @@ import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch +import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel object JoinTypeTransform { @@ -66,6 +70,8 @@ object JoinTypeTransform { } } +case class ShuffleStageStaticstics(numPartitions: Int, numMappers: Int, rowCount: Option[BigInt]) + case class CHShuffledHashJoinExecTransformer( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -104,6 +110,75 @@ case class CHShuffledHashJoinExecTransformer( private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = JoinTypeTransform.toSubstraitType(joinType, buildSide) + + override def genJoinParameters(): Any = { + val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") + + // Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer` + val leftStats = getShuffleStageStatistics(streamedPlan) + val rightStats = getShuffleStageStatistics(buildPlan) + // Start with "JoinParameters:" + val joinParametersStr = new StringBuffer("JoinParameters:") + // isBHJ: 0 for SHJ, 1 for BHJ + // isNullAwareAntiJoin: 0 for false, 1 for true + // buildHashTableId: the unique id for the hash table of build plan + joinParametersStr + .append("isBHJ=") + .append(isBHJ) + .append("\n") + .append("isNullAwareAntiJoin=") + .append(isNullAwareAntiJoin) + .append("\n") + .append("buildHashTableId=") + .append(buildHashTableId) + .append("\n") + .append("isExistenceJoin=") + .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0) + .append("\n") + .append("leftRowCount=") + .append(leftStats.rowCount.getOrElse(-1)) + .append("\n") + .append("leftNumPartitions=") + .append(leftStats.numPartitions) + .append("\n") + 
.append("leftNumMappers=") + .append(leftStats.numMappers) + .append("\n") + .append("rightRowCount=") + .append(rightStats.rowCount.getOrElse(-1)) + .append("\n") + .append("rightNumPartitions=") + .append(rightStats.numPartitions) + .append("\n") + .append("rightNumMappers=") + .append(rightStats.numMappers) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } + + private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = { + plan match { + case queryStage: ShuffleQueryStageExec => + ShuffleStageStaticstics( + queryStage.shuffle.numPartitions, + queryStage.shuffle.numMappers, + queryStage.getRuntimeStatistics.rowCount) + case shuffle: ColumnarShuffleExchangeExec => + // FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here. + // Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch` + ShuffleStageStaticstics(-1, -1, None) + case _ => + if (plan.children.length == 1) { + getShuffleStageStatistics(plan.children.head) + } else { + ShuffleStageStaticstics(-1, -1, None) + } + } + } } case class CHBroadcastBuildSideRDD( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index fc22add2d880..ebeb69c16350 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite override protected val tablesPath: String = basePath + "/tpch-data-ch" override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch" override protected val queriesResults: String = rootPath + "mergetree-queries-output" + private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch." 
/** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */ override protected def sparkConf: SparkConf = { @@ -261,4 +262,50 @@ class GlutenClickHouseColumnarShuffleAQESuite spark.sql("drop table t2") } } + + test("GLUTEN-6768 change mixed join condition into multi join on clauses") { + withSQLConf( + (backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"), + ( + backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit", + "1000000") + ) { + + spark.sql("create table t1(a int, b int, c int, d int) using parquet") + spark.sql("create table t2(a int, b int, c int, d int) using parquet") + + spark.sql(""" + |insert into t1 + |select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000) + |""".stripMargin) + spark.sql(""" + |insert into t2 + |select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000) + |""".stripMargin) + + var sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.c = t2.c and t1.d = t2.d)) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.d = t2.d and t1.c >= t2.c)) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + spark.sql("drop table t1") + spark.sql("drop table t2") + } + } } diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index abb7295adc0d..f8f0f41fe21a 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -92,6 +92,27 @@ struct StreamingAggregateConfig } }; +struct JoinConfig +{ + /// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi + /// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k and t1.id2 = t2.id2)` + inline static const String PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES = "prefer_inequal_join_to_multi_join_on_clauses"; + /// Only hash join supports multi join on clauses, the right table cannot be too large. If the row number of right + /// table is larger than this limit, this transform will not work. 
+ inline static const String INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT = "inequal_join_to_multi_join_on_clauses_row_limit"; + + bool prefer_inequal_join_to_multi_join_on_clauses = true; + size_t inequal_join_to_multi_join_on_clauses_rows_limit = 10000000; + + static JoinConfig loadFromContext(DB::ContextPtr context) + { + JoinConfig config; + config.prefer_inequal_join_to_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES, true); + config.inequal_join_to_multi_join_on_clauses_rows_limit = context->getConfigRef().getUInt64(INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT, 10000000); + return config; + } +}; + struct ExecutorConfig { inline static const String DUMP_PIPELINE = "dump_pipeline"; diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index a7a07c0bf31a..59de24de5275 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -57,6 +57,16 @@ void tryAssign(const std::unordered_map & kvs, const Strin } } +template<> +void tryAssign(const std::unordered_map & kvs, const String & key, Int64 & v) +{ + auto it = kvs.find(key); + if (it != kvs.end()) + { + v = std::stol(it->second); + } +} + template void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf) { @@ -121,6 +131,12 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance) tryAssign(kvs, "buildHashTableId", info.storage_join_key); tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join); tryAssign(kvs, "isExistenceJoin", info.is_existence_join); + tryAssign(kvs, "leftRowCount", info.left_table_rows); + tryAssign(kvs, "leftNumPartitions", info.left_table_partitions_num); + tryAssign(kvs, "leftNumMappers", info.left_table_mappers_num); + tryAssign(kvs, "rightRowCount", info.right_table_rows); + tryAssign(kvs, "rightNumPartitions", info.right_table_partitions_num); + tryAssign(kvs, "rightNumMappers", info.right_table_mappers_num); return info; } } diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h index 5a15a3ea8abc..08bd520760d7 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h @@ -29,6 +29,12 @@ struct JoinOptimizationInfo bool is_smj = false; bool is_null_aware_anti_join = false; bool is_existence_join = false; + Int64 left_table_rows = -1; + Int64 left_table_partitions_num = -1; + Int64 left_table_mappers_num = -1; + Int64 right_table_rows = -1; + Int64 right_table_partitions_num = -1; + Int64 right_table_mappers_num = -1; String storage_join_key; static JoinOptimizationInfo parse(const String & advance); diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index b217a9bd9da2..02691fceb7ba 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -16,6 +16,8 @@ */ #include "JoinRelParser.h" +#include +#include #include #include #include @@ -25,15 +27,15 @@ #include #include #include -#include #include +#include #include #include #include #include #include #include -#include +#include #include @@ -42,9 +44,9 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; 
+extern const int BAD_ARGUMENTS; } } using namespace DB; @@ -98,7 +100,8 @@ DB::QueryPlanPtr JoinRelParser::parseOp(const substrait::Rel & rel, std::list JoinRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) +std::unordered_set JoinRelParser::extractTableSidesFromExpression( + const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) { std::unordered_set table_sides; if (expr.has_scalar_function()) @@ -169,8 +172,7 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ storage_join.getRightSampleBlock().getColumnsWithTypeAndName(), ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr right_project_step = - std::make_unique(right.getCurrentDataStream(), std::move(right_project)); + QueryPlanStepPtr right_project_step = std::make_unique(right.getCurrentDataStream(), std::move(right_project)); right_project_step->setStepDescription("Rename Broadcast Table Name"); steps.emplace_back(right_project_step.get()); right.addStep(std::move(right_project_step)); @@ -193,12 +195,9 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ } } ActionsDAG left_project = ActionsDAG::makeConvertingActions( - left.getCurrentDataStream().header.getColumnsWithTypeAndName(), - new_left_cols, - ActionsDAG::MatchColumnsMode::Position); + left.getCurrentDataStream().header.getColumnsWithTypeAndName(), new_left_cols, ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr left_project_step = - std::make_unique(left.getCurrentDataStream(), std::move(left_project)); + QueryPlanStepPtr left_project_step = std::make_unique(left.getCurrentDataStream(), std::move(left_project)); left_project_step->setStepDescription("Rename Left Table Name for broadcast join"); steps.emplace_back(left_project_step.get()); left.addStep(std::move(left_project_step)); @@ -206,9 +205,11 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) { + auto join_config = JoinConfig::loadFromContext(getContext()); google::protobuf::StringValue optimization_info; optimization_info.ParseFromString(join.advanced_extension().optimization().value()); auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value()); + LOG_ERROR(getLogger("JoinRelParser"), "{}", optimization_info.value()); auto storage_join = join_opt_info.is_broadcast ? 
BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr; if (storage_join) { @@ -239,7 +240,9 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } if (is_col_names_changed) { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", left->getCurrentDataStream().header.dumpStructure(), right_header_before_convert_step.dumpStructure(), right->getCurrentDataStream().header.dumpStructure()); @@ -266,7 +269,6 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q if (storage_join) { - applyJoinFilter(*table_join, join, *left, *right, true); auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); @@ -288,15 +290,13 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q /// TODO: make smj support mixed conditions if (need_post_filter && table_join->kind() != DB::JoinKind::Inner) { - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "Sort merge join doesn't support mixed join conditions, except inner join."); + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Sort merge join doesn't support mixed join conditions, except inner join."); } JoinPtr smj_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty(), -1); MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; QueryPlanStepPtr join_step - = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false); + = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false); join_step->setStepDescription("SORT_MERGE_JOIN"); steps.emplace_back(join_step.get()); @@ -311,41 +311,22 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } else { - applyJoinFilter(*table_join, join, *left, *right, true); - - /// Following is some configurations for grace hash join. - /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will - /// enable grace hash join. - /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This setup - /// the memory limitation fro grace hash join. If the memory consumption exceeds the limitation, - /// data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number - /// will be too large and the performance will be bad. 
- JoinPtr hash_join = nullptr; - MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; - if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH)) + std::vector join_on_clauses; + if (table_join->getClauses().empty()) + table_join->addDisjunct(); + bool is_multi_join_on_clauses + = isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); + if (is_multi_join_on_clauses && join_config.prefer_inequal_join_to_multi_join_on_clauses && join_opt_info.right_table_rows > 0 + && join_opt_info.right_table_mappers_num > 0 + && join_opt_info.right_table_rows / join_opt_info.right_table_mappers_num + < join_config.inequal_join_to_multi_join_on_clauses_rows_limit) { - hash_join = std::make_shared( - context, - table_join, - left->getCurrentDataStream().header, - right->getCurrentDataStream().header, - context->getTempDataOnDisk()); + query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses); } else { - hash_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty()); + query_plan = buildSingleOnClauseHashJoin(join, table_join, std::move(left), std::move(right)); } - QueryPlanStepPtr join_step - = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false); - - join_step->setStepDescription("HASH_JOIN"); - steps.emplace_back(join_step.get()); - std::vector plans; - plans.emplace_back(std::move(left)); - plans.emplace_back(std::move(right)); - - query_plan = std::make_unique(); - query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } JoinUtil::reorderJoinOutput(*query_plan, after_join_names); @@ -508,7 +489,11 @@ void JoinRelParser::collectJoinKeys( } bool JoinRelParser::applyJoinFilter( - DB::TableJoin & table_join, const substrait::JoinRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) + DB::TableJoin & table_join, + const substrait::JoinRel & join_rel, + DB::QueryPlan & left, + DB::QueryPlan & right, + bool allow_mixed_condition) { if (!join_rel.has_post_join_filter()) return true; @@ -594,12 +579,13 @@ bool JoinRelParser::applyJoinFilter( return false; auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); mixed_join_expressions_actions.removeUnusedActions(); - table_join.getMixedJoinExpression() - = std::make_shared(std::move(mixed_join_expressions_actions), ExpressionActionsSettings::fromContext(context)); + table_join.getMixedJoinExpression() = std::make_shared( + std::move(mixed_join_expressions_actions), ExpressionActionsSettings::fromContext(context)); } else { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); } return true; } @@ -610,7 +596,7 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J ActionsDAG actions_dag{query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName()}; if (!join.post_join_filter().has_scalar_function()) { - // It may be singular_or_list + // It may be singular_or_list auto * in_node = getPlanParser()->parseExpression(actions_dag, join.post_join_filter()); filter_name = in_node->result_name; } @@ -624,6 +610,214 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J query_plan.addStep(std::move(filter_step)); } +/// Only 
support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4) +bool JoinRelParser::isJoinWithMultiJoinOnClauses( + const DB::TableJoin::JoinOnClause & prefix_clause, + std::vector & clauses, + const substrait::JoinRel & join_rel, + const DB::Block & left_header, + const DB::Block & right_header) +{ + /// There is only on join clause + if (!join_rel.has_post_join_filter()) + return false; + + const auto & filter_expr = join_rel.post_join_filter(); + std::list expression_stack; + expression_stack.push_back(&filter_expr); + + auto check_function = [&](const String function_name_, const substrait::Expression & e) + { + if (!e.has_scalar_function()) + { + return false; + } + auto function_name = parseFunctionName(e.scalar_function().function_reference(), e.scalar_function()); + return function_name.has_value() && *function_name == function_name_; + }; + + auto get_field_ref = [](const substrait::Expression & e) -> std::optional + { + if (e.has_selection() && e.selection().has_direct_reference() && e.selection().direct_reference().has_struct_field()) + { + return std::optional(e.selection().direct_reference().struct_field().field()); + } + return {}; + }; + + auto parse_join_keys = [&](const substrait::Expression & e) -> std::optional> + { + const auto & args = e.scalar_function().arguments(); + auto l_field_ref = get_field_ref(args[0].value()); + auto r_field_ref = get_field_ref(args[1].value()); + if (!l_field_ref.has_value() || !r_field_ref.has_value()) + return {}; + size_t l_pos = static_cast(*l_field_ref); + size_t r_pos = static_cast(*r_field_ref); + size_t l_cols = left_header.columns(); + size_t total_cols = l_cols + right_header.columns(); + + if (l_pos < l_cols && r_pos >= l_cols && r_pos < total_cols) + return std::make_pair(left_header.getByPosition(l_pos).name, right_header.getByPosition(r_pos - l_cols).name); + else if (r_pos < l_cols && l_pos >= l_cols && l_pos < total_cols) + return std::make_pair(left_header.getByPosition(r_pos).name, right_header.getByPosition(l_pos - l_cols).name); + return {}; + }; + + auto parse_and_expression = [&](const substrait::Expression & e, DB::TableJoin::JoinOnClause & join_on_clause) + { + std::vector and_expression_stack; + and_expression_stack.push_back(&e); + while (!and_expression_stack.empty()) + { + const auto & current_expr = *(and_expression_stack.back()); + and_expression_stack.pop_back(); + if (check_function("and", current_expr)) + { + for (const auto & arg : current_expr.scalar_function().arguments()) + and_expression_stack.push_back(&arg.value()); + } + else if (check_function("equals", current_expr)) + { + auto optional_keys = parse_join_keys(current_expr); + if (!optional_keys) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not equal comparison for keys from both tables"); + return false; + } + join_on_clause.addKey(optional_keys->first, optional_keys->second, false); + } + else + { + LOG_ERROR(getLogger("JoinRelParser"), "And or equals function is expected"); + return false; + } + } + return true; + }; + + while (!expression_stack.empty()) + { + const auto & current_expr = *(expression_stack.back()); + expression_stack.pop_back(); + if (!check_function("or", current_expr)) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not an or expression"); + return false; + } + + auto get_current_join_on_clause = [&]() + { + DB::TableJoin::JoinOnClause new_clause = prefix_clause; + clauses.push_back(new_clause); + return &clauses.back(); + }; + + const auto & args = current_expr.scalar_function().arguments(); + for (const auto & arg : args) + { + if 
(check_function("equals", arg.value())) + { + auto optional_keys = parse_join_keys(arg.value()); + if (!optional_keys) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not equal comparison for keys from both tables"); + return false; + } + get_current_join_on_clause()->addKey(optional_keys->first, optional_keys->second, false); + } + else if (check_function("and", arg.value())) + { + if (!parse_and_expression(arg.value(), *get_current_join_on_clause())) + { + LOG_ERROR(getLogger("JoinRelParser"), "Parse and expression failed"); + return false; + } + } + else if (check_function("or", arg.value())) + { + expression_stack.push_back(&arg.value()); + } + else + { + LOG_ERROR(getLogger("JoinRelParser"), "Unknow function"); + return false; + } + } + } + return true; +} + + +DB::QueryPlanPtr JoinRelParser::buildMultiOnClauseHashJoin( + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan, + const std::vector & join_on_clauses) +{ + DB::TableJoin::JoinOnClause & base_join_on_clause = table_join->getOnlyClause(); + base_join_on_clause = join_on_clauses[0]; + for (size_t i = 1; i < join_on_clauses.size(); ++i) + { + table_join->addDisjunct(); + auto & join_on_clause = table_join->getClauses().back(); + join_on_clause = join_on_clauses[i]; + } + + LOG_INFO(getLogger("JoinRelParser"), "multi join on clauses:\n{}", DB::TableJoin::formatClauses(table_join->getClauses())); + + JoinPtr hash_join = std::make_shared(table_join, right_plan->getCurrentDataStream().header); + QueryPlanStepPtr join_step + = std::make_unique(left_plan->getCurrentDataStream(), right_plan->getCurrentDataStream(), hash_join, 8192, 1, false); + join_step->setStepDescription("Multi join on clause hash join"); + steps.emplace_back(join_step.get()); + std::vector plans; + plans.emplace_back(std::move(left_plan)); + plans.emplace_back(std::move(right_plan)); + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + return query_plan; +} + +DB::QueryPlanPtr JoinRelParser::buildSingleOnClauseHashJoin( + const substrait::JoinRel & join_rel, std::shared_ptr table_join, DB::QueryPlanPtr left_plan, DB::QueryPlanPtr right_plan) +{ + applyJoinFilter(*table_join, join_rel, *left_plan, *right_plan, true); + /// Following is some configurations for grace hash join. + /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will + /// enable grace hash join. + /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This setup + /// the memory limitation fro grace hash join. If the memory consumption exceeds the limitation, + /// data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number + /// will be too large and the performance will be bad. 
+ JoinPtr hash_join = nullptr; + MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; + if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH)) + { + hash_join = std::make_shared( + context, + table_join, + left_plan->getCurrentDataStream().header, + right_plan->getCurrentDataStream().header, + context->getTempDataOnDisk()); + } + else + { + hash_join = std::make_shared(table_join, right_plan->getCurrentDataStream().header.cloneEmpty()); + } + QueryPlanStepPtr join_step + = std::make_unique(left_plan->getCurrentDataStream(), right_plan->getCurrentDataStream(), hash_join, 8192, 1, false); + + join_step->setStepDescription("HASH_JOIN"); + steps.emplace_back(join_step.get()); + std::vector plans; + plans.emplace_back(std::move(left_plan)); + plans.emplace_back(std::move(right_plan)); + + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + return query_plan; +} + void registerJoinRelParser(RelParserFactory & factory) { auto builder = [](SerializedPlanParser * plan_paser) { return std::make_shared(plan_paser); }; diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index ee1155cb4712..0c0d07d6fdd2 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -70,6 +71,24 @@ class JoinRelParser : public RelParser static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); + + bool isJoinWithMultiJoinOnClauses( + const DB::TableJoin::JoinOnClause & prefix_clause, + std::vector & clauses, + const substrait::JoinRel & join_rel, + const DB::Block & left_header, + const DB::Block & right_header); + + DB::QueryPlanPtr buildMultiOnClauseHashJoin( + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan, + const std::vector & join_on_clauses); + DB::QueryPlanPtr buildSingleOnClauseHashJoin( + const substrait::JoinRel & join_rel, + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan); }; } From 8081b622e2feb5d8f44377fb7f76d66a1899adfb Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 15 Aug 2024 13:21:12 +0800 Subject: [PATCH 2/4] update --- .../clickhouse/CHSparkPlanExecApi.scala | 8 +++- .../execution/CHHashJoinExecTransformer.scala | 48 ++++++++++++------- .../velox/VeloxSparkPlanExecApi.scala | 3 +- cpp-ch/local-engine/Common/GlutenConfig.h | 12 ++--- .../Parser/AdvancedParametersParseUtil.cpp | 7 ++- .../Parser/AdvancedParametersParseUtil.h | 7 ++- cpp-ch/local-engine/Parser/JoinRelParser.cpp | 8 ++-- .../gluten/backendsapi/SparkPlanExecApi.scala | 3 +- .../extension/columnar/FallbackRules.scala | 3 +- .../columnar/OffloadSingleNode.scala | 3 +- 10 files changed, 60 insertions(+), 42 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 03e5aaa538a9..cea30ae284d7 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -309,8 +309,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - 
isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = - CHShuffledHashJoinExecTransformer( + isSkewJoin: Boolean, + logicalLink: Option[LogicalPlan]): ShuffledHashJoinExecTransformerBase = { + val res = CHShuffledHashJoinExecTransformer( leftKeys, rightKeys, joinType, @@ -319,6 +320,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { left, right, isSkewJoin) + res.setLogicalLink(logicalLink.getOrElse(null)) + res + } /** Generate BroadcastHashJoinExecTransformer. */ def genBroadcastHashJoinExecTransformer( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 15e21681b1c6..252b9bc03fd6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -26,6 +26,7 @@ import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive._ @@ -135,24 +136,35 @@ case class CHShuffledHashJoinExecTransformer( .append("isExistenceJoin=") .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0) .append("\n") - .append("leftRowCount=") - .append(leftStats.rowCount.getOrElse(-1)) - .append("\n") - .append("leftNumPartitions=") - .append(leftStats.numPartitions) - .append("\n") - .append("leftNumMappers=") - .append(leftStats.numMappers) - .append("\n") - .append("rightRowCount=") - .append(rightStats.rowCount.getOrElse(-1)) - .append("\n") - .append("rightNumPartitions=") - .append(rightStats.numPartitions) - .append("\n") - .append("rightNumMappers=") - .append(rightStats.numMappers) - .append("\n") + logicalLink match { + case Some(join: Join) => + val leftRowCount = + if (needSwitchChildren) join.left.stats.rowCount else join.right.stats.rowCount + val rightRowCount = + if (needSwitchChildren) join.right.stats.rowCount else join.left.stats.rowCount + val leftSizeInBytes = + if (needSwitchChildren) join.left.stats.sizeInBytes else join.right.stats.sizeInBytes + val rightSizeInBytes = + if (needSwitchChildren) join.right.stats.sizeInBytes else join.left.stats.sizeInBytes + val numPartitions = outputPartitioning.numPartitions + joinParametersStr + .append("leftRowCount=") + .append(leftRowCount.getOrElse(-1)) + .append("\n") + .append("leftSizeInBytes=") + .append(leftSizeInBytes) + .append("\n") + .append("rightRowCount=") + .append(rightRowCount.getOrElse(-1)) + .append("\n") + .append("rightSizeInBytes=") + .append(rightSizeInBytes) + .append("\n") + .append("numPartitions=") + .append(numPartitions) + .append("\n") + case _ => + } val message = StringValue .newBuilder() .setValue(joinParametersStr.toString) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 89d29781c079..cc813daf1ff6 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -437,7 +437,8 @@ 
class VeloxSparkPlanExecApi extends SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = + isSkewJoin: Boolean, + logicalLink: Option[LogicalPlan]): ShuffledHashJoinExecTransformerBase = ShuffledHashJoinExecTransformer( leftKeys, rightKeys, diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index f8f0f41fe21a..84744dab21b8 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -96,19 +96,19 @@ struct JoinConfig { /// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi /// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k and t1.id2 = t2.id2)` - inline static const String PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES = "prefer_inequal_join_to_multi_join_on_clauses"; + inline static const String PREFER_MULTI_JOIN_ON_CLAUSES = "prefer_multi_join_on_clauses"; /// Only hash join supports multi join on clauses, the right table cannot be too large. If the row number of right /// table is larger than this limit, this transform will not work. - inline static const String INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT = "inequal_join_to_multi_join_on_clauses_row_limit"; + inline static const String MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT = "multi_join_on_clauses_build_side_row_limit"; - bool prefer_inequal_join_to_multi_join_on_clauses = true; - size_t inequal_join_to_multi_join_on_clauses_rows_limit = 10000000; + bool prefer_multi_join_on_clauses = true; + size_t multi_join_on_clauses_build_side_rows_limit = 10000000; static JoinConfig loadFromContext(DB::ContextPtr context) { JoinConfig config; - config.prefer_inequal_join_to_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES, true); - config.inequal_join_to_multi_join_on_clauses_rows_limit = context->getConfigRef().getUInt64(INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT, 10000000); + config.prefer_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_MULTI_JOIN_ON_CLAUSES, true); + config.multi_join_on_clauses_build_side_rows_limit = context->getConfigRef().getUInt64(MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT, 10000000); return config; } }; diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index 59de24de5275..344bf939fe09 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -132,11 +132,10 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance) tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join); tryAssign(kvs, "isExistenceJoin", info.is_existence_join); tryAssign(kvs, "leftRowCount", info.left_table_rows); - tryAssign(kvs, "leftNumPartitions", info.left_table_partitions_num); - tryAssign(kvs, "leftNumMappers", info.left_table_mappers_num); + tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes); tryAssign(kvs, "rightRowCount", info.right_table_rows); - tryAssign(kvs, "rightNumPartitions", info.right_table_partitions_num); - tryAssign(kvs, "rightNumMappers", info.right_table_mappers_num); + tryAssign(kvs, "rightSizeInBytes", info.right_table_bytes); + tryAssign(kvs, "numPartitions", info.partitions_num); return info; } } diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h index 08bd520760d7..5f6fe6d256e3 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h @@ -30,11 +30,10 @@ struct JoinOptimizationInfo bool is_null_aware_anti_join = false; bool is_existence_join = false; Int64 left_table_rows = -1; - Int64 left_table_partitions_num = -1; - Int64 left_table_mappers_num = -1; + Int64 left_table_bytes = -1; Int64 right_table_rows = -1; - Int64 right_table_partitions_num = -1; - Int64 right_table_mappers_num = -1; + Int64 right_table_bytes = -1; + Int64 partitions_num = -1; String storage_join_key; static JoinOptimizationInfo parse(const String & advance); diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index 02691fceb7ba..fbab2609412a 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -316,10 +316,10 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q table_join->addDisjunct(); bool is_multi_join_on_clauses = isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); - if (is_multi_join_on_clauses && join_config.prefer_inequal_join_to_multi_join_on_clauses && join_opt_info.right_table_rows > 0 - && join_opt_info.right_table_mappers_num > 0 - && join_opt_info.right_table_rows / join_opt_info.right_table_mappers_num - < join_config.inequal_join_to_multi_join_on_clauses_rows_limit) + if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0 + && join_opt_info.partitions_num > 0 + && join_opt_info.right_table_rows / join_opt_info.partitions_num + < join_config.multi_join_on_clauses_build_side_rows_limit) { query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses); } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 69392a353d7f..b11102aea907 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -114,7 +114,8 @@ trait SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase + isSkewJoin: Boolean, + logicalLink: Option[LogicalPlan]): ShuffledHashJoinExecTransformerBase /** Generate BroadcastHashJoinExecTransformer. 
*/ def genBroadcastHashJoinExecTransformer( diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index f9eaa4179c67..c6392d13c70d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -392,7 +392,8 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] { plan.condition, plan.left, plan.right, - plan.isSkewJoin) + plan.isSkewJoin, + plan.logicalLink) transformer.doValidate().tagOnFallback(plan) case plan: BroadcastExchangeExec => val transformer = ColumnarBroadcastExchangeExec(plan.mode, plan.child) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index a8cc791286b2..a75f8f64de09 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -142,7 +142,8 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil { plan.condition, left, right, - plan.isSkewJoin) + plan.isSkewJoin, + plan.logicalLink) case plan: SortMergeJoinExec => val left = plan.left val right = plan.right From 1376873592a1b3ced138f36ca08b3d6c29236357 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 15 Aug 2024 14:41:29 +0800 Subject: [PATCH 3/4] update --- .../execution/CHHashJoinExecTransformer.scala | 39 +++---------------- ...tenClickHouseColumnarShuffleAQESuite.scala | 6 +-- .../Parser/AdvancedParametersParseUtil.cpp | 9 ++++- cpp-ch/local-engine/Parser/JoinRelParser.cpp | 8 ++-- cpp-ch/local-engine/Parser/JoinRelParser.h | 2 +- 5 files changed, 21 insertions(+), 43 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 252b9bc03fd6..990a0163ee82 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -27,9 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch @@ -115,9 +113,6 @@ case class CHShuffledHashJoinExecTransformer( override def genJoinParameters(): Any = { val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") - // Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer` - val leftStats = getShuffleStageStatistics(streamedPlan) - val rightStats = getShuffleStageStatistics(buildPlan) // Start with "JoinParameters:" val joinParametersStr = new StringBuffer("JoinParameters:") // isBHJ: 0 for SHJ, 1 for BHJ @@ -138,14 +133,12 @@ case class CHShuffledHashJoinExecTransformer( .append("\n") logicalLink match { case Some(join: Join) => - val leftRowCount = - if 
(needSwitchChildren) join.left.stats.rowCount else join.right.stats.rowCount - val rightRowCount = - if (needSwitchChildren) join.right.stats.rowCount else join.left.stats.rowCount - val leftSizeInBytes = - if (needSwitchChildren) join.left.stats.sizeInBytes else join.right.stats.sizeInBytes - val rightSizeInBytes = - if (needSwitchChildren) join.right.stats.sizeInBytes else join.left.stats.sizeInBytes + val left = if (!needSwitchChildren) join.left else join.right + val right = if (!needSwitchChildren) join.right else join.left + val leftRowCount = left.stats.rowCount + val rightRowCount = right.stats.rowCount + val leftSizeInBytes = left.stats.sizeInBytes + val rightSizeInBytes = right.stats.sizeInBytes val numPartitions = outputPartitioning.numPartitions joinParametersStr .append("leftRowCount=") @@ -171,26 +164,6 @@ case class CHShuffledHashJoinExecTransformer( .build() BackendsApiManager.getTransformerApiInstance.packPBMessage(message) } - - private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = { - plan match { - case queryStage: ShuffleQueryStageExec => - ShuffleStageStaticstics( - queryStage.shuffle.numPartitions, - queryStage.shuffle.numMappers, - queryStage.getRuntimeStatistics.rowCount) - case shuffle: ColumnarShuffleExchangeExec => - // FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here. - // Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch` - ShuffleStageStaticstics(-1, -1, None) - case _ => - if (plan.children.length == 1) { - getShuffleStageStatistics(plan.children.head) - } else { - ShuffleStageStaticstics(-1, -1, None) - } - } - } } case class CHBroadcastBuildSideRDD( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index ebeb69c16350..10e5c7534d35 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -265,10 +265,8 @@ class GlutenClickHouseColumnarShuffleAQESuite test("GLUTEN-6768 change mixed join condition into multi join on clauses") { withSQLConf( - (backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"), - ( - backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit", - "1000000") + (backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"), + (backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000") ) { spark.sql("create table t1(a int, b int, c int, d int) using parquet") diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index 344bf939fe09..e642d28d7f2e 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -63,7 +63,14 @@ void tryAssign(const std::unordered_map & kvs, const Stri auto it = kvs.find(key); if (it != kvs.end()) { - v = std::stol(it->second); + try + { + v = std::stol(it->second); + } + catch (...) 
+ { + LOG_ERROR(getLogger("tryAssign"), "Invalid number: {}", it->second); + } } } diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index fbab2609412a..ef19e007d439 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -209,7 +209,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q google::protobuf::StringValue optimization_info; optimization_info.ParseFromString(join.advanced_extension().optimization().value()); auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value()); - LOG_ERROR(getLogger("JoinRelParser"), "{}", optimization_info.value()); + LOG_ERROR(getLogger("JoinRelParser"), "optimizaiton info:{}", optimization_info.value()); auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr; if (storage_join) { @@ -315,7 +315,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q if (table_join->getClauses().empty()) table_join->addDisjunct(); bool is_multi_join_on_clauses - = isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); + = couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0 && join_opt_info.partitions_num > 0 && join_opt_info.right_table_rows / join_opt_info.partitions_num @@ -611,14 +611,14 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J } /// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4) -bool JoinRelParser::isJoinWithMultiJoinOnClauses( +bool JoinRelParser::couldRewriteToMultiJoinOnClauses( const DB::TableJoin::JoinOnClause & prefix_clause, std::vector & clauses, const substrait::JoinRel & join_rel, const DB::Block & left_header, const DB::Block & right_header) { - /// There is only on join clause + /// There is only one join clause if (!join_rel.has_post_join_filter()) return false; diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index 0c0d07d6fdd2..7e43187be308 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -72,7 +72,7 @@ class JoinRelParser : public RelParser static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); - bool isJoinWithMultiJoinOnClauses( + bool couldRewriteToMultiJoinOnClauses( const DB::TableJoin::JoinOnClause & prefix_clause, std::vector & clauses, const substrait::JoinRel & join_rel, From 3b923af654bcb705de595cabba6d2cbae5c6ce1a Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 16 Aug 2024 10:02:21 +0800 Subject: [PATCH 4/4] update --- .../clickhouse/CHSparkPlanExecApi.scala | 7 +--- .../execution/CHHashJoinExecTransformer.scala | 36 ++++++++--------- .../org/apache/gluten/utils/CHAQEUtil.scala | 39 +++++++++++++++++++ .../velox/VeloxSparkPlanExecApi.scala | 3 +- .../Parser/AdvancedParametersParseUtil.cpp | 1 + .../gluten/backendsapi/SparkPlanExecApi.scala | 3 +- .../extension/columnar/FallbackRules.scala | 3 +- .../columnar/OffloadSingleNode.scala | 3 +- 8 files changed, 63 insertions(+), 32 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala 
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index cea30ae284d7..5a49d6ea3d66 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -309,9 +309,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean, - logicalLink: Option[LogicalPlan]): ShuffledHashJoinExecTransformerBase = { - val res = CHShuffledHashJoinExecTransformer( + isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = { + CHShuffledHashJoinExecTransformer( leftKeys, rightKeys, joinType, @@ -320,8 +319,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { left, right, isSkewJoin) - res.setLogicalLink(logicalLink.getOrElse(null)) - res } /** Generate BroadcastHashJoinExecTransformer. */ diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 990a0163ee82..adb824804718 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -26,7 +26,6 @@ import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch @@ -69,8 +68,6 @@ object JoinTypeTransform { } } -case class ShuffleStageStaticstics(numPartitions: Int, numMappers: Int, rowCount: Option[BigInt]) - case class CHShuffledHashJoinExecTransformer( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -131,33 +128,34 @@ case class CHShuffledHashJoinExecTransformer( .append("isExistenceJoin=") .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0) .append("\n") - logicalLink match { - case Some(join: Join) => - val left = if (!needSwitchChildren) join.left else join.right - val right = if (!needSwitchChildren) join.right else join.left - val leftRowCount = left.stats.rowCount - val rightRowCount = right.stats.rowCount - val leftSizeInBytes = left.stats.sizeInBytes - val rightSizeInBytes = right.stats.sizeInBytes - val numPartitions = outputPartitioning.numPartitions + + CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match { + case Some(stats) => joinParametersStr .append("leftRowCount=") - .append(leftRowCount.getOrElse(-1)) + .append(stats.rowCount.getOrElse(-1)) .append("\n") .append("leftSizeInBytes=") - .append(leftSizeInBytes) + .append(stats.sizeInBytes) .append("\n") + case _ => + } + CHAQEUtil.getShuffleQueryStageStats(buildPlan) match { + case Some(stats) => + joinParametersStr .append("rightRowCount=") - .append(rightRowCount.getOrElse(-1)) + .append(stats.rowCount.getOrElse(-1)) .append("\n") .append("rightSizeInBytes=") - .append(rightSizeInBytes) - .append("\n") - .append("numPartitions=") - .append(numPartitions) + .append(stats.sizeInBytes) .append("\n") case _ => } + 
joinParametersStr + .append("numPartitions=") + .append(outputPartitioning.numPartitions) + .append("\n") + val message = StringValue .newBuilder() .setValue(joinParametersStr.toString) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala new file mode 100644 index 000000000000..9a35517f54fc --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive._ + +object CHAQEUtil { + + // All TransformSupports have lost the logicalLink. So we need to iterate the plan to find the + // first ShuffleQueryStageExec and get the runtime stats. + def getShuffleQueryStageStats(plan: SparkPlan): Option[Statistics] = { + plan match { + case stage: ShuffleQueryStageExec => + Some(stage.getRuntimeStatistics) + case _ => + if (plan.children.length == 1) { + getShuffleQueryStageStats(plan.children.head) + } else { + None + } + } + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index cc813daf1ff6..89d29781c079 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -437,8 +437,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean, - logicalLink: Option[LogicalPlan]): ShuffledHashJoinExecTransformerBase = + isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = ShuffledHashJoinExecTransformer( leftKeys, rightKeys, diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index e642d28d7f2e..42d4f4d4d8cd 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -70,6 +70,7 @@ void tryAssign(const std::unordered_map & kvs, const Stri catch (...) 
{ LOG_ERROR(getLogger("tryAssign"), "Invalid number: {}", it->second); + throw; } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index b11102aea907..69392a353d7f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -114,8 +114,7 @@ trait SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean, - logicalLink: Option[LogicalPlan]): ShuffledHashJoinExecTransformerBase + isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase /** Generate BroadcastHashJoinExecTransformer. */ def genBroadcastHashJoinExecTransformer( diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala index c6392d13c70d..f9eaa4179c67 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala @@ -392,8 +392,7 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] { plan.condition, plan.left, plan.right, - plan.isSkewJoin, - plan.logicalLink) + plan.isSkewJoin) transformer.doValidate().tagOnFallback(plan) case plan: BroadcastExchangeExec => val transformer = ColumnarBroadcastExchangeExec(plan.mode, plan.child) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index a75f8f64de09..a8cc791286b2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -142,8 +142,7 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil { plan.condition, left, right, - plan.isSkewJoin, - plan.logicalLink) + plan.isSkewJoin) case plan: SortMergeJoinExec => val left = plan.left val right = plan.right
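
Usage sketch (not part of the patches above): a minimal Scala example of how the rewrite introduced by this series is driven from the Spark side. It assumes a Spark build with the Gluten ClickHouse backend enabled; the two runtime_config keys are the ones defined in GlutenConfig.h after PATCH 2, while the table names and the query are illustrative only.

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("multi-join-on-clauses-sketch")
  // Enable rewriting a mixed join condition into multiple join on clauses.
  .config("spark.gluten.sql.columnar.backend.ch.runtime_config.prefer_multi_join_on_clauses", "true")
  // The rewrite is skipped when the estimated per-partition build-side row count
  // (rightRowCount / numPartitions in the JoinParameters above) exceeds this limit.
  .config("spark.gluten.sql.columnar.backend.ch.runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000")
  .getOrCreate()

// A condition of the shape `k and (e1 or e2 or ...)`, where k and every ei are
// equalities (or conjunctions of equalities) between the two tables, is executed
// by the CH backend as the multi-clause hash join `(k and e1) or (k and e2) or ...`
// instead of a single-clause join followed by a post filter.
spark
  .sql("select * from t1 join t2 on t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c)")
  .show()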