diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 2dd45281e4169..9b6b2958ccc71 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -34,15 +34,12 @@ import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel object JoinTypeTransform { - def toNativeJoinType(joinType: JoinType): JoinType = { - joinType match { - case ExistenceJoin(_) => - LeftSemi - case _ => - joinType - } - } + // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with + // a new column to indecate whether the row is matched in the right table. + // Indeed, the ExistenceJoin is transformed into left any join in CH. + // We don't have left any join in substrait, so use left semi join instead. + // and isExistenceJoin is set to true to indicate that it is an existence join. def toSubstraitJoinType(sparkJoin: JoinType, buildRight: Boolean): JoinRel.JoinType = sparkJoin match { case _: InnerLike => @@ -104,7 +101,7 @@ case class CHShuffledHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - ShuffleHashJoinStrategy(finalJoinType), + ShuffleHashJoinStrategy(joinType), left.outputSet, right.outputSet, condition) @@ -113,7 +110,6 @@ case class CHShuffledHashJoinExecTransformer( } super.doValidateInternal() } - private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override def genJoinParameters(): Any = { val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") @@ -226,7 +222,7 @@ case class CHBroadcastHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - BroadcastHashJoinStrategy(finalJoinType), + BroadcastHashJoinStrategy(joinType), left.outputSet, right.outputSet, condition) @@ -255,7 +251,7 @@ case class CHBroadcastHashJoinExecTransformer( val context = BroadCastHashJoinContext( buildKeyExprs, - finalJoinType, + joinType, buildSide == BuildRight, isMixedCondition(condition), joinType.isInstanceOf[ExistenceJoin], @@ -278,12 +274,6 @@ case class CHBroadcastHashJoinExecTransformer( res } - // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with - // a new column to indecate whether the row is matched in the right table. - // Indeed, the ExistenceJoin is transformed into left any join in CH. - // We don't have left any join in substrait, so use left semi join instead. - // and isExistenceJoin is set to true to indicate that it is an existence join. - private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = { JoinTypeTransform.toSubstraitJoinType(joinType, buildSide == BuildRight) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala index 0f5b5e2c4fd5a..2256ed4a40fbb 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala @@ -29,13 +29,10 @@ case class BroadcastHashJoinStrategy(joinType: JoinType) extends JoinStrategy {} case class SortMergeJoinStrategy(joinType: JoinType) extends JoinStrategy {} /** - * The logic here is that if it is not an equi-join spark will create BNLJ, which will fallback, if - * it is an equi-join, spark will create BroadcastHashJoin or ShuffleHashJoin, for these join types, - * we need to filter For cases that cannot be handled by the backend, 1 there are at least two - * different tables column and Literal in the condition Or condition for comparison, for example: (a - * join b on a.a1 = b.b1 and (a.a2 > 1 or b.b2 < 2) ) 2 tow join key for inequality comparison (!= , - * > , <), for example: (a join b on a.a1 > b.b1) There will be a fallback for Nullaware Jion For - * Existence Join which is just an optimization of exist subquery, it will also fallback + * BroadcastHashJoinStrategy and ShuffleHashJoinStrategy are relatively complete, They support + * left/right/inner full/anti/semi join, existence Join, and also support join contiditions with + * columns from both sides. e.g. (a join b on a.a1 = b.b1 and a.a2 > 1 and b.b2 < 2) + * SortMergeJoinStrategy is not fully supported for all cases in CH. */ object CHJoinValidateUtil extends Logging { @@ -52,33 +49,20 @@ object CHJoinValidateUtil extends Logging { leftOutputSet: AttributeSet, rightOutputSet: AttributeSet, condition: Option[Expression]): Boolean = { - var shouldFallback = false - val joinType = joinStrategy.joinType - if (!joinType.isInstanceOf[ExistenceJoin] && joinType.sql.contains("INNER")) { - shouldFallback = false; - } else if ( + val hasMixedFiltCondition = condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, condition.get) - ) { - shouldFallback = joinStrategy match { - case BroadcastHashJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") - case SortMergeJoinStrategy(_) => true - case ShuffleHashJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") - case UnknownJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") - } - } else { - shouldFallback = joinStrategy match { - case SortMergeJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") || joinTy.toString.contains( - "ExistenceJoin") - case _ => false - } + val shouldFallback = joinStrategy match { + case SortMergeJoinStrategy(joinType) => + joinType.sql.contains("SEMI") || joinType.sql.contains("ANTI") || joinType.toString + .contains("ExistenceJoin") || hasMixedFiltCondition + case UnknownJoinStrategy(joinType) => + throw new IllegalArgumentException(s"Unknown join type $joinStrategy") + case _ => false } + if (shouldFallback) { - logError(s"Fallback for join type $joinType") + logError(s"Fallback for join type $joinStrategy") } shouldFallback }