Skip to content

Commit

Permalink
try to use multi join on clause as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 14, 2024
1 parent fc7f9cd commit fe33925
Show file tree
Hide file tree
Showing 7 changed files with 417 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}

Expand All @@ -25,10 +26,13 @@ import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel

object JoinTypeTransform {
Expand Down Expand Up @@ -66,6 +70,8 @@ object JoinTypeTransform {
}
}

case class ShuffleStageStaticstics(numPartitions: Int, numMappers: Int, rowCount: Option[BigInt])

case class CHShuffledHashJoinExecTransformer(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
Expand Down Expand Up @@ -104,6 +110,75 @@ case class CHShuffledHashJoinExecTransformer(
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)

override def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "")

// Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer`
val leftStats = getShuffleStageStatistics(streamedPlan)
val rightStats = getShuffleStageStatistics(buildPlan)
// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
// isBHJ: 0 for SHJ, 1 for BHJ
// isNullAwareAntiJoin: 0 for false, 1 for true
// buildHashTableId: the unique id for the hash table of build plan
joinParametersStr
.append("isBHJ=")
.append(isBHJ)
.append("\n")
.append("isNullAwareAntiJoin=")
.append(isNullAwareAntiJoin)
.append("\n")
.append("buildHashTableId=")
.append(buildHashTableId)
.append("\n")
.append("isExistenceJoin=")
.append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
.append("\n")
.append("leftRowCount=")
.append(leftStats.rowCount.getOrElse(-1))
.append("\n")
.append("leftNumPartitions=")
.append(leftStats.numPartitions)
.append("\n")
.append("leftNumMappers=")
.append(leftStats.numMappers)
.append("\n")
.append("rightRowCount=")
.append(rightStats.rowCount.getOrElse(-1))
.append("\n")
.append("rightNumPartitions=")
.append(rightStats.numPartitions)
.append("\n")
.append("rightNumMappers=")
.append(rightStats.numMappers)
.append("\n")
val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = {
plan match {
case queryStage: ShuffleQueryStageExec =>
ShuffleStageStaticstics(
queryStage.shuffle.numPartitions,
queryStage.shuffle.numMappers,
queryStage.getRuntimeStatistics.rowCount)
case shuffle: ColumnarShuffleExchangeExec =>
// FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here.
// Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch`
ShuffleStageStaticstics(-1, -1, None)
case _ =>
if (plan.children.length == 1) {
getShuffleStageStatistics(plan.children.head)
} else {
ShuffleStageStaticstics(-1, -1, None)
}
}
}
}

case class CHBroadcastBuildSideRDD(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
override protected val tablesPath: String = basePath + "/tpch-data-ch"
override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
override protected val queriesResults: String = rootPath + "mergetree-queries-output"
private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch"

/** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */
override protected def sparkConf: SparkConf = {
Expand Down Expand Up @@ -261,4 +262,36 @@ class GlutenClickHouseColumnarShuffleAQESuite
spark.sql("drop table t2")
}
}

test("GLUTEN-6768 change mixed join condition into multi join on clauses") {
withSQLConf(
(backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"),
(
backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit",
"1000000")
) {

spark.sql("create table t1(a int, b int, c int, d int) using parquet")
spark.sql("create table t2(a int, b int, c int, d int) using parquet")

spark.sql("""
|insert into t1
|select id % 100 as a, id as b, id + 1 as c, id + 2 as d from range(1000)
|""".stripMargin)
spark.sql("""
|insert into t2
|select id % 100 as a, id as b, id + 1 as c, id + 2 as d from range(1000)
|""".stripMargin)

val sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d)
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

spark.sql("drop table t1")
spark.sql("drop table t2")
}
}
}
21 changes: 21 additions & 0 deletions cpp-ch/local-engine/Common/GlutenConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ struct StreamingAggregateConfig
}
};

struct JoinConfig
{
/// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi
/// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k or t1.id2 = t2.id2)`
inline static const String PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES = "prefer_inequal_join_to_multi_join_on_clauses";
/// Only hash join supports multi join on clauses, the right table cannot be too large. If the row number of right
/// table is larger then this limit, this transform will not work.
inline static const String INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT = "inequal_join_to_multi_join_on_clauses_row_limit";

bool prefer_inequal_join_to_multi_join_on_clauses = true;
size_t inequal_join_to_multi_join_on_clauses_rows_limit = 10000000;

static JoinConfig loadFromContext(DB::ContextPtr context)
{
JoinConfig config;
config.prefer_inequal_join_to_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES, true);
config.inequal_join_to_multi_join_on_clauses_rows_limit = context->getConfigRef().getUInt64(INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT, 10000000);
return config;
}
};

struct ExecutorConfig
{
inline static const String DUMP_PIPELINE = "dump_pipeline";
Expand Down
16 changes: 16 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ void tryAssign<bool>(const std::unordered_map<String, String> & kvs, const Strin
}
}

template<>
void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const String & key, Int64 & v)
{
auto it = kvs.find(key);
if (it != kvs.end())
{
v = std::stol(it->second);
}
}

template <char... chars>
void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
{
Expand Down Expand Up @@ -121,6 +131,12 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance)
tryAssign(kvs, "buildHashTableId", info.storage_join_key);
tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
tryAssign(kvs, "leftRowCount", info.left_table_rows);
tryAssign(kvs, "leftNumPartitions", info.left_table_partitions_num);
tryAssign(kvs, "leftNumMappers", info.left_table_mappers_num);
tryAssign(kvs, "rightRowCount", info.right_table_rows);
tryAssign(kvs, "rightNumPartitions", info.right_table_partitions_num);
tryAssign(kvs, "rightNumMappers", info.right_table_mappers_num);
return info;
}
}
Expand Down
6 changes: 6 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ struct JoinOptimizationInfo
bool is_smj = false;
bool is_null_aware_anti_join = false;
bool is_existence_join = false;
Int64 left_table_rows = -1;
Int64 left_table_partitions_num = -1;
Int64 left_table_mappers_num = -1;
Int64 right_table_rows = -1;
Int64 right_table_partitions_num = -1;
Int64 right_table_mappers_num = -1;
String storage_join_key;

static JoinOptimizationInfo parse(const String & advance);
Expand Down
Loading

0 comments on commit fe33925

Please sign in to comment.