From 04688da1d016630e36cdca60751861fbfdb05b90 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 12 Aug 2024 16:42:46 +0800 Subject: [PATCH] try to use multi join on clause as possible --- .../execution/CHHashJoinExecTransformer.scala | 75 +++++ ...tenClickHouseColumnarShuffleAQESuite.scala | 33 ++ cpp-ch/local-engine/Common/GlutenConfig.h | 21 ++ .../Parser/AdvancedParametersParseUtil.cpp | 16 + .../Parser/AdvancedParametersParseUtil.h | 6 + cpp-ch/local-engine/Parser/JoinRelParser.cpp | 301 ++++++++++++++---- cpp-ch/local-engine/Parser/JoinRelParser.h | 19 ++ 7 files changed, 417 insertions(+), 54 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 7080e55dc1863..15e21681b1c63 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} @@ -25,10 +26,13 @@ import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch +import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel object JoinTypeTransform { @@ -66,6 +70,8 @@ object JoinTypeTransform { } } +case class ShuffleStageStaticstics(numPartitions: Int, numMappers: Int, rowCount: Option[BigInt]) + case class CHShuffledHashJoinExecTransformer( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -104,6 +110,75 @@ case class CHShuffledHashJoinExecTransformer( private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = JoinTypeTransform.toSubstraitType(joinType, buildSide) + + override def genJoinParameters(): Any = { + val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") + + // Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer` + val leftStats = getShuffleStageStatistics(streamedPlan) + val rightStats = getShuffleStageStatistics(buildPlan) + // Start with "JoinParameters:" + val joinParametersStr = new StringBuffer("JoinParameters:") + // isBHJ: 0 for SHJ, 1 for BHJ + // isNullAwareAntiJoin: 0 for false, 1 for true + // buildHashTableId: the unique id for the hash table of build plan + joinParametersStr + .append("isBHJ=") + .append(isBHJ) + .append("\n") + .append("isNullAwareAntiJoin=") + .append(isNullAwareAntiJoin) + .append("\n") + .append("buildHashTableId=") + .append(buildHashTableId) + .append("\n") + .append("isExistenceJoin=") + .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0) + .append("\n") + .append("leftRowCount=") + .append(leftStats.rowCount.getOrElse(-1)) + .append("\n") + .append("leftNumPartitions=") + .append(leftStats.numPartitions) + .append("\n") + .append("leftNumMappers=") + .append(leftStats.numMappers) + .append("\n") + .append("rightRowCount=") + .append(rightStats.rowCount.getOrElse(-1)) + .append("\n") + .append("rightNumPartitions=") + .append(rightStats.numPartitions) + .append("\n") + .append("rightNumMappers=") + .append(rightStats.numMappers) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } + + private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = { + plan match { + case queryStage: ShuffleQueryStageExec => + ShuffleStageStaticstics( + queryStage.shuffle.numPartitions, + queryStage.shuffle.numMappers, + queryStage.getRuntimeStatistics.rowCount) + case shuffle: ColumnarShuffleExchangeExec => + // FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here. + // Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch` + ShuffleStageStaticstics(-1, -1, None) + case _ => + if (plan.children.length == 1) { + getShuffleStageStatistics(plan.children.head) + } else { + ShuffleStageStaticstics(-1, -1, None) + } + } + } } case class CHBroadcastBuildSideRDD( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index fc22add2d8802..044be6dfa2519 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite override protected val tablesPath: String = basePath + "/tpch-data-ch" override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch" override protected val queriesResults: String = rootPath + "mergetree-queries-output" + private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch" /** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */ override protected def sparkConf: SparkConf = { @@ -261,4 +262,36 @@ class GlutenClickHouseColumnarShuffleAQESuite spark.sql("drop table t2") } } + + test("GLUTEN-6768 change mixed join condition into multi join on clauses") { + withSQLConf( + (backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"), + ( + backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit", + "1000000") + ) { + + spark.sql("create table t1(a int, b int, c int, d int) using parquet") + spark.sql("create table t2(a int, b int, c int, d int) using parquet") + + spark.sql(""" + |insert into t1 + |select id % 100 as a, id as b, id + 1 as c, id + 2 as d from range(1000) + |""".stripMargin) + spark.sql(""" + |insert into t2 + |select id % 100 as a, id as b, id + 1 as c, id + 2 as d from range(1000) + |""".stripMargin) + + val sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + spark.sql("drop table t1") + spark.sql("drop table t2") + } + } } diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index abb7295adc0d0..91a58d8585208 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -92,6 +92,27 @@ struct StreamingAggregateConfig } }; +struct JoinConfig +{ + /// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi + /// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k or t1.id2 = t2.id2)` + inline static const String PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES = "prefer_inequal_join_to_multi_join_on_clauses"; + /// Only hash join supports multi join on clauses, the right table cannot be to large. If the row number of right + /// table is larger then this limit, this transform will not work. + inline static const String INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT = "inequal_join_to_multi_join_on_clauses_row_limit"; + + bool prefer_inequal_join_to_multi_join_on_clauses = true; + size_t inequal_join_to_multi_join_on_clauses_rows_limit = 10000000; + + static JoinConfig loadFromContext(DB::ContextPtr context) + { + JoinConfig config; + config.prefer_inequal_join_to_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES, true); + config.inequal_join_to_multi_join_on_clauses_rows_limit = context->getConfigRef().getUInt64(INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT, 10000000); + return config; + } +}; + struct ExecutorConfig { inline static const String DUMP_PIPELINE = "dump_pipeline"; diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index a7a07c0bf31ae..59de24de52757 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -57,6 +57,16 @@ void tryAssign(const std::unordered_map & kvs, const Strin } } +template<> +void tryAssign(const std::unordered_map & kvs, const String & key, Int64 & v) +{ + auto it = kvs.find(key); + if (it != kvs.end()) + { + v = std::stol(it->second); + } +} + template void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf) { @@ -121,6 +131,12 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance) tryAssign(kvs, "buildHashTableId", info.storage_join_key); tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join); tryAssign(kvs, "isExistenceJoin", info.is_existence_join); + tryAssign(kvs, "leftRowCount", info.left_table_rows); + tryAssign(kvs, "leftNumPartitions", info.left_table_partitions_num); + tryAssign(kvs, "leftNumMappers", info.left_table_mappers_num); + tryAssign(kvs, "rightRowCount", info.right_table_rows); + tryAssign(kvs, "rightNumPartitions", info.right_table_partitions_num); + tryAssign(kvs, "rightNumMappers", info.right_table_mappers_num); return info; } } diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h index 5a15a3ea8abc6..08bd520760d7d 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h @@ -29,6 +29,12 @@ struct JoinOptimizationInfo bool is_smj = false; bool is_null_aware_anti_join = false; bool is_existence_join = false; + Int64 left_table_rows = -1; + Int64 left_table_partitions_num = -1; + Int64 left_table_mappers_num = -1; + Int64 right_table_rows = -1; + Int64 right_table_partitions_num = -1; + Int64 right_table_mappers_num = -1; String storage_join_key; static JoinOptimizationInfo parse(const String & advance); diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index b217a9bd9da2c..ad7c88203b627 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -16,6 +16,8 @@ */ #include "JoinRelParser.h" +#include +#include #include #include #include @@ -25,15 +27,15 @@ #include #include #include -#include #include +#include #include #include #include #include #include #include -#include +#include #include @@ -42,9 +44,9 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; +extern const int BAD_ARGUMENTS; } } using namespace DB; @@ -98,7 +100,8 @@ DB::QueryPlanPtr JoinRelParser::parseOp(const substrait::Rel & rel, std::list JoinRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) +std::unordered_set JoinRelParser::extractTableSidesFromExpression( + const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) { std::unordered_set table_sides; if (expr.has_scalar_function()) @@ -169,8 +172,7 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ storage_join.getRightSampleBlock().getColumnsWithTypeAndName(), ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr right_project_step = - std::make_unique(right.getCurrentDataStream(), std::move(right_project)); + QueryPlanStepPtr right_project_step = std::make_unique(right.getCurrentDataStream(), std::move(right_project)); right_project_step->setStepDescription("Rename Broadcast Table Name"); steps.emplace_back(right_project_step.get()); right.addStep(std::move(right_project_step)); @@ -193,12 +195,9 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ } } ActionsDAG left_project = ActionsDAG::makeConvertingActions( - left.getCurrentDataStream().header.getColumnsWithTypeAndName(), - new_left_cols, - ActionsDAG::MatchColumnsMode::Position); + left.getCurrentDataStream().header.getColumnsWithTypeAndName(), new_left_cols, ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr left_project_step = - std::make_unique(left.getCurrentDataStream(), std::move(left_project)); + QueryPlanStepPtr left_project_step = std::make_unique(left.getCurrentDataStream(), std::move(left_project)); left_project_step->setStepDescription("Rename Left Table Name for broadcast join"); steps.emplace_back(left_project_step.get()); left.addStep(std::move(left_project_step)); @@ -206,9 +205,11 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) { + auto join_config = JoinConfig::loadFromContext(getContext()); google::protobuf::StringValue optimization_info; optimization_info.ParseFromString(join.advanced_extension().optimization().value()); auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value()); + LOG_ERROR(getLogger("JoinRelParser"), "xxxx optimization info:{}", optimization_info.value()); auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr; if (storage_join) { @@ -239,7 +240,9 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } if (is_col_names_changed) { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", left->getCurrentDataStream().header.dumpStructure(), right_header_before_convert_step.dumpStructure(), right->getCurrentDataStream().header.dumpStructure()); @@ -266,7 +269,6 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q if (storage_join) { - applyJoinFilter(*table_join, join, *left, *right, true); auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); @@ -288,15 +290,13 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q /// TODO: make smj support mixed conditions if (need_post_filter && table_join->kind() != DB::JoinKind::Inner) { - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "Sort merge join doesn't support mixed join conditions, except inner join."); + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Sort merge join doesn't support mixed join conditions, except inner join."); } JoinPtr smj_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty(), -1); MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; QueryPlanStepPtr join_step - = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false); + = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false); join_step->setStepDescription("SORT_MERGE_JOIN"); steps.emplace_back(join_step.get()); @@ -311,41 +311,20 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } else { - applyJoinFilter(*table_join, join, *left, *right, true); - - /// Following is some configurations for grace hash join. - /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will - /// enable grace hash join. - /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This setup - /// the memory limitation fro grace hash join. If the memory consumption exceeds the limitation, - /// data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number - /// will be too large and the performance will be bad. - JoinPtr hash_join = nullptr; - MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; - if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH)) + std::vector join_on_clauses; + bool is_multi_join_on_clauses + = isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); + if (is_multi_join_on_clauses && join_config.prefer_inequal_join_to_multi_join_on_clauses && join_opt_info.right_table_rows > 0 + && join_opt_info.right_table_mappers_num > 0 + && join_opt_info.right_table_rows / join_opt_info.right_table_mappers_num + < join_config.inequal_join_to_multi_join_on_clauses_rows_limit) { - hash_join = std::make_shared( - context, - table_join, - left->getCurrentDataStream().header, - right->getCurrentDataStream().header, - context->getTempDataOnDisk()); + query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses); } else { - hash_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty()); + query_plan = buildSingleOnClauseHashJoin(join, table_join, std::move(left), std::move(right)); } - QueryPlanStepPtr join_step - = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false); - - join_step->setStepDescription("HASH_JOIN"); - steps.emplace_back(join_step.get()); - std::vector plans; - plans.emplace_back(std::move(left)); - plans.emplace_back(std::move(right)); - - query_plan = std::make_unique(); - query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } JoinUtil::reorderJoinOutput(*query_plan, after_join_names); @@ -508,7 +487,11 @@ void JoinRelParser::collectJoinKeys( } bool JoinRelParser::applyJoinFilter( - DB::TableJoin & table_join, const substrait::JoinRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) + DB::TableJoin & table_join, + const substrait::JoinRel & join_rel, + DB::QueryPlan & left, + DB::QueryPlan & right, + bool allow_mixed_condition) { if (!join_rel.has_post_join_filter()) return true; @@ -594,12 +577,13 @@ bool JoinRelParser::applyJoinFilter( return false; auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); mixed_join_expressions_actions.removeUnusedActions(); - table_join.getMixedJoinExpression() - = std::make_shared(std::move(mixed_join_expressions_actions), ExpressionActionsSettings::fromContext(context)); + table_join.getMixedJoinExpression() = std::make_shared( + std::move(mixed_join_expressions_actions), ExpressionActionsSettings::fromContext(context)); } else { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); } return true; } @@ -610,7 +594,7 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J ActionsDAG actions_dag{query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName()}; if (!join.post_join_filter().has_scalar_function()) { - // It may be singular_or_list + // It may be singular_or_list auto * in_node = getPlanParser()->parseExpression(actions_dag, join.post_join_filter()); filter_name = in_node->result_name; } @@ -624,6 +608,215 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J query_plan.addStep(std::move(filter_step)); } +bool JoinRelParser::isJoinWithMultiJoinOnClauses( + const DB::TableJoin::JoinOnClause & prefix_clause, + std::vector clauses, + const substrait::JoinRel & join_rel, + const DB::Block & left_header, + const DB::Block & right_header) +{ + /// There is only on join clause + if (!join_rel.has_post_join_filter()) + return false; + + const auto & filter_expr = join_rel.post_join_filter(); + std::list expression_stack; + expression_stack.push_back(&filter_expr); + + auto check_function = [&](const String function_name_, const substrait::Expression & e) + { + if (!e.has_scalar_function()) + { + return false; + } + auto function_name = parseFunctionName(e.scalar_function().function_reference(), e.scalar_function()); + return function_name.has_value() && *function_name == function_name_; + }; + + auto get_field_ref = [](const substrait::Expression & e) -> std::optional + { + if (e.has_selection() && e.selection().has_direct_reference() && e.selection().direct_reference().has_struct_field()) + { + return std::optional(e.selection().direct_reference().struct_field().field()); + } + return {}; + }; + + auto parse_join_keys = [&](const substrait::Expression & e) -> std::optional> + { + const auto & args = e.scalar_function().arguments(); + auto l_field_ref = get_field_ref(args[0].value()); + auto r_field_ref = get_field_ref(args[1].value()); + if (!l_field_ref.has_value() || !r_field_ref.has_value()) + return {}; + size_t l_pos = static_cast(*l_field_ref); + size_t r_pos = static_cast(*r_field_ref); + size_t l_cols = left_header.columns(); + size_t total_cols = l_cols + right_header.columns(); + + if (l_pos < l_cols && r_pos >= l_cols && r_pos < total_cols) + return std::make_pair(left_header.getByPosition(l_pos).name, right_header.getByPosition(r_pos - l_cols).name); + else if (r_pos < l_cols && l_pos >= l_cols && l_pos < total_cols) + return std::make_pair(left_header.getByPosition(r_pos).name, right_header.getByPosition(l_pos - l_cols).name); + return {}; + }; + + auto parse_and_expression = [&](const substrait::Expression & e, DB::TableJoin::JoinOnClause & join_on_clause) + { + std::vector and_expression_stack; + and_expression_stack.push_back(&e); + while (!and_expression_stack.empty()) + { + const auto & current_expr = *(and_expression_stack.back()); + and_expression_stack.pop_back(); + if (check_function("and", current_expr)) + { + for (const auto & arg : e.scalar_function().arguments()) + and_expression_stack.push_back(&arg.value()); + } + else if (check_function("equals", current_expr)) + { + auto optional_keys = parse_join_keys(current_expr); + if (!optional_keys) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not equal comparison for keys from both tables"); + return false; + } + join_on_clause.addKey(optional_keys->first, optional_keys->second, false); + } + else + { + LOG_ERROR(getLogger("JoinRelParser"), "And or equals function is expected"); + return false; + } + } + return true; + }; + + while (!expression_stack.empty()) + { + const auto & current_expr = *(expression_stack.back()); + expression_stack.pop_back(); + if (!check_function("or", current_expr)) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not an or expression"); + } + + auto get_current_join_on_clause = [&]() + { + DB::TableJoin::JoinOnClause new_clause = prefix_clause; + clauses.push_back(new_clause); + return &clauses.back(); + }; + + const auto & args = current_expr.scalar_function().arguments(); + for (const auto & arg : args) + { + if (check_function("equals", arg.value())) + { + auto optional_keys = parse_join_keys(arg.value()); + if (!optional_keys) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not equal comparison for keys from both tables"); + return false; + } + get_current_join_on_clause()->addKey(optional_keys->first, optional_keys->second, false); + } + else if (check_function("and", arg.value())) + { + if (!parse_and_expression(arg.value(), *get_current_join_on_clause())) + { + LOG_ERROR(getLogger("JoinRelParser"), "Parse and expression failed"); + return false; + } + } + else if (check_function("or", arg.value())) + { + expression_stack.push_back(&arg.value()); + } + else + { + LOG_ERROR(getLogger("JoinRelParser"), "Unknow function"); + return false; + } + } + } + return true; +} + + +DB::QueryPlanPtr JoinRelParser::buildMultiOnClauseHashJoin( + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan, + const std::vector & join_on_clauses) +{ + DB::TableJoin::JoinOnClause base_join_on_clause = table_join->getOnlyClause(); + base_join_on_clause = join_on_clauses[0]; + for (size_t i = 1; i < join_on_clauses.size(); ++i) + { + table_join->addDisjunct(); + auto & join_on_clause = table_join->getClauses().back(); + join_on_clause = join_on_clauses[i]; + } + + LOG_ERROR(getLogger("JoinRelParser"), "xxx join on clauses:\n{}", DB::TableJoin::formatClauses(table_join->getClauses())); + + JoinPtr hash_join = std::make_shared(table_join, right_plan->getCurrentDataStream().header); + QueryPlanStepPtr join_step + = std::make_unique(left_plan->getCurrentDataStream(), right_plan->getCurrentDataStream(), hash_join, 8192, 1, false); + join_step->setStepDescription("Multi join on clause hash join"); + steps.emplace_back(join_step.get()); + std::vector plans; + plans.emplace_back(std::move(left_plan)); + plans.emplace_back(std::move(right_plan)); + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + return query_plan; +} + +DB::QueryPlanPtr JoinRelParser::buildSingleOnClauseHashJoin( + const substrait::JoinRel & join_rel, std::shared_ptr table_join, DB::QueryPlanPtr left_plan, DB::QueryPlanPtr right_plan) +{ + LOG_ERROR(getLogger("JoinRelParser"), "xxx buildSingleOnClauseHashJoin"); + applyJoinFilter(*table_join, join_rel, *left_plan, *right_plan, true); + /// Following is some configurations for grace hash join. + /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will + /// enable grace hash join. + /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This setup + /// the memory limitation fro grace hash join. If the memory consumption exceeds the limitation, + /// data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number + /// will be too large and the performance will be bad. + JoinPtr hash_join = nullptr; + MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; + if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH)) + { + hash_join = std::make_shared( + context, + table_join, + left_plan->getCurrentDataStream().header, + right_plan->getCurrentDataStream().header, + context->getTempDataOnDisk()); + } + else + { + hash_join = std::make_shared(table_join, right_plan->getCurrentDataStream().header.cloneEmpty()); + } + QueryPlanStepPtr join_step + = std::make_unique(left_plan->getCurrentDataStream(), right_plan->getCurrentDataStream(), hash_join, 8192, 1, false); + + join_step->setStepDescription("HASH_JOIN"); + steps.emplace_back(join_step.get()); + std::vector plans; + plans.emplace_back(std::move(left_plan)); + plans.emplace_back(std::move(right_plan)); + + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + LOG_ERROR(getLogger("JoinRelParser"), "xxxx buildSingleOnClauseHashJoin, output:{}", query_plan->getCurrentDataStream().header.dumpStructure()); + return query_plan; +} + void registerJoinRelParser(RelParserFactory & factory) { auto builder = [](SerializedPlanParser * plan_paser) { return std::make_shared(plan_paser); }; diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index ee1155cb47128..7ede6100d5fda 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -70,6 +71,24 @@ class JoinRelParser : public RelParser static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); + + bool isJoinWithMultiJoinOnClauses( + const DB::TableJoin::JoinOnClause & prefix_clause, + std::vector clauses, + const substrait::JoinRel & join_rel, + const DB::Block & left_header, + const DB::Block & right_header); + + DB::QueryPlanPtr buildMultiOnClauseHashJoin( + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan, + const std::vector & join_on_clauses); + DB::QueryPlanPtr buildSingleOnClauseHashJoin( + const substrait::JoinRel & join_rel, + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan); }; }