[GLUTEN-6768][CH] Try to use multi join on clauses instead of inequal join condition #6787

Merged · 4 commits · Aug 16, 2024
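This PR teaches the ClickHouse backend to rewrite a mixed join condition (an equi-key conjoined with an OR of further equalities) into multiple equi-join ON clauses that a hash join can evaluate, instead of falling back to an inequal join. A minimal sketch of the idea, assuming a Gluten-enabled session and two illustrative tables t1 and t2:

// --- Illustrative sketch, not part of the diff ---
import org.apache.spark.sql.SparkSession

object MultiJoinOnClausesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("gluten-6768-sketch").getOrCreate()
    // A mixed condition: one equi-key plus an OR of extra equalities.
    // With prefer_multi_join_on_clauses enabled, the backend can evaluate it as
    //   (t1.a = t2.a AND t1.b = t2.b) OR (t1.a = t2.a AND t1.c = t2.c)
    // i.e. multiple hash-joinable ON clauses instead of an inequal join.
    spark.sql(
      """
        |SELECT * FROM t1 JOIN t2
        |ON t1.a = t2.a AND (t1.b = t2.b OR t1.c = t2.c)
        |""".stripMargin).show()
    spark.stop()
  }
}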
@@ -309,7 +309,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase =
isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = {
CHShuffledHashJoinExecTransformer(
leftKeys,
rightKeys,
@@ -319,6 +319,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
left,
right,
isSkewJoin)
}

/** Generate BroadcastHashJoinExecTransformer. */
def genBroadcastHashJoinExecTransformer(
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}

@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel

object JoinTypeTransform {
@@ -104,6 +106,62 @@ case class CHShuffledHashJoinExecTransformer(
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)

override def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "")

// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
// isBHJ: 0 for SHJ, 1 for BHJ
// isNullAwareAntiJoin: 0 for false, 1 for true
// buildHashTableId: the unique id for the hash table of build plan
joinParametersStr
.append("isBHJ=")
.append(isBHJ)
.append("\n")
.append("isNullAwareAntiJoin=")
.append(isNullAwareAntiJoin)
.append("\n")
.append("buildHashTableId=")
.append(buildHashTableId)
.append("\n")
.append("isExistenceJoin=")
.append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
.append("\n")

CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match {
case Some(stats) =>
joinParametersStr
.append("leftRowCount=")
.append(stats.rowCount.getOrElse(-1))
.append("\n")
.append("leftSizeInBytes=")
.append(stats.sizeInBytes)
.append("\n")
case _ =>
}
CHAQEUtil.getShuffleQueryStageStats(buildPlan) match {
case Some(stats) =>
joinParametersStr
.append("rightRowCount=")
.append(stats.rowCount.getOrElse(-1))
.append("\n")
.append("rightSizeInBytes=")
.append(stats.sizeInBytes)
.append("\n")
case _ =>
}
joinParametersStr
.append("numPartitions=")
.append(outputPartitioning.numPartitions)
.append("\n")

val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}
}
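
For reference, a hedged sketch of the payload this method produces for a plain SHJ (values hypothetical). Note that the first key follows "JoinParameters:" on the same line, since no newline is appended after the prefix, and the row-count/size lines appear only when AQE shuffle-stage statistics are available:

// --- Illustrative sketch, not part of the diff: a hypothetical payload ---
val exampleJoinParameters: String =
  """JoinParameters:isBHJ=0
    |isNullAwareAntiJoin=0
    |buildHashTableId=
    |isExistenceJoin=0
    |leftRowCount=1000
    |leftSizeInBytes=16000
    |rightRowCount=1000
    |rightSizeInBytes=16000
    |numPartitions=5
    |""".stripMargin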

case class CHBroadcastBuildSideRDD(
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._

object CHAQEUtil {

// All TransformSupports have lost the logicalLink, so we need to iterate the plan to find
// the first ShuffleQueryStageExec and get its runtime stats.
def getShuffleQueryStageStats(plan: SparkPlan): Option[Statistics] = {
plan match {
case stage: ShuffleQueryStageExec =>
Some(stage.getRuntimeStatistics)
case _ =>
if (plan.children.length == 1) {
getShuffleQueryStageStats(plan.children.head)
} else {
None
}
}
}
}
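
A hedged usage sketch: the helper walks single-child chains (for example, through transformer wrappers that no longer carry a logicalLink) and gives up at the first multi-child node, so callers should treat the result as best effort. Statistics.rowCount is an Option[BigInt]:

// --- Illustrative sketch, not part of the diff ---
import org.apache.spark.sql.execution.SparkPlan

def bestEffortRowCount(plan: SparkPlan): BigInt =
  CHAQEUtil.getShuffleQueryStageStats(plan)
    .flatMap(_.rowCount)
    .getOrElse(BigInt(-1))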
@@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
override protected val tablesPath: String = basePath + "/tpch-data-ch"
override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
override protected val queriesResults: String = rootPath + "mergetree-queries-output"
private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch."

/** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */
override protected def sparkConf: SparkConf = {
@@ -261,4 +262,48 @@ class GlutenClickHouseColumnarShuffleAQESuite
spark.sql("drop table t2")
}
}

test("GLUTEN-6768 change mixed join condition into multi join on clauses") {
withSQLConf(
(backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"),
(backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000")
) {

spark.sql("create table t1(a int, b int, c int, d int) using parquet")
spark.sql("create table t2(a int, b int, c int, d int) using parquet")

spark.sql("""
|insert into t1
|select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000)
|""".stripMargin)
spark.sql("""
|insert into t2
|select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000)
|""".stripMargin)

var sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d)
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.c = t2.c and t1.d = t2.d))
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.d = t2.d and t1.c >= t2.c))
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

spark.sql("drop table t1")
spark.sql("drop table t2")
}
}
}
21 changes: 21 additions & 0 deletions cpp-ch/local-engine/Common/GlutenConfig.h
@@ -92,6 +92,27 @@ struct StreamingAggregateConfig
}
};

struct JoinConfig
{
/// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with
/// multiple join-on clauses: `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k and t1.id2 = t2.id2)`
inline static const String PREFER_MULTI_JOIN_ON_CLAUSES = "prefer_multi_join_on_clauses";
/// Only hash join supports multiple join-on clauses, and the right (build-side) table cannot be too large:
/// if the right table's row count exceeds this limit, the transform is not applied.
inline static const String MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT = "multi_join_on_clauses_build_side_row_limit";

bool prefer_multi_join_on_clauses = true;
size_t multi_join_on_clauses_build_side_rows_limit = 10000000;

static JoinConfig loadFromContext(DB::ContextPtr context)
{
JoinConfig config;
config.prefer_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_MULTI_JOIN_ON_CLAUSES, true);
config.multi_join_on_clauses_build_side_rows_limit = context->getConfigRef().getUInt64(MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT, 10000000);
return config;
}
};
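
On the Spark side, these keys are set through the CH runtime-config prefix, exactly as the new test above does. A minimal sketch, assuming a Gluten-enabled SparkConf:

// --- Illustrative sketch, not part of the diff ---
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.gluten.sql.columnar.backend.ch.runtime_config.prefer_multi_join_on_clauses", "true")
  .set("spark.gluten.sql.columnar.backend.ch.runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000")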

struct ExecutorConfig
{
inline static const String DUMP_PIPELINE = "dump_pipeline";
23 changes: 23 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -57,6 +57,24 @@ void tryAssign<bool>(const std::unordered_map<String, String> & kvs, const String & key, bool & v)
}
}

template<>
void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const String & key, Int64 & v)
{
auto it = kvs.find(key);
if (it != kvs.end())
{
try
{
v = std::stoll(it->second); // stoll parses to long long, which is 64-bit on all platforms
}
catch (...)
{
LOG_ERROR(getLogger("tryAssign"), "Invalid number: {}", it->second);
throw;
}
}
}

template <char... chars>
void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
{
@@ -121,6 +139,11 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance)
tryAssign(kvs, "buildHashTableId", info.storage_join_key);
tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
tryAssign(kvs, "leftRowCount", info.left_table_rows);
tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes);
tryAssign(kvs, "rightRowCount", info.right_table_rows);
tryAssign(kvs, "rightSizeInBytes", info.right_table_bytes);
tryAssign(kvs, "numPartitions", info.partitions_num);
return info;
}
}
5 changes: 5 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -29,6 +29,11 @@ struct JoinOptimizationInfo
bool is_smj = false;
bool is_null_aware_anti_join = false;
bool is_existence_join = false;
Int64 left_table_rows = -1;
Int64 left_table_bytes = -1;
Int64 right_table_rows = -1;
Int64 right_table_bytes = -1;
Int64 partitions_num = -1;
String storage_join_key;

static JoinOptimizationInfo parse(const String & advance);