From 65945a0c262b0037aec80ed4e57ae08ddd0e06ba Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Thu, 14 Mar 2024 11:44:29 +0800 Subject: [PATCH] [opt](Nereids) support cast agg state type as legacy planner (#32198) --- .../expression/ExpressionNormalization.java | 2 + .../expression/rules/ConvertAggStateCast.java | 81 +++++++++++++++++++ .../doris/nereids/types/AggStateType.java | 12 +++ .../nereids/test_agg_state_nereids.groovy | 3 +- 4 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index 23fca6b77b0af5..9886cb1787e9ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.nereids.rules.expression.check.CheckCast; +import org.apache.doris.nereids.rules.expression.rules.ConvertAggStateCast; import org.apache.doris.nereids.rules.expression.rules.DigitalMaskingConvert; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup; @@ -53,6 +54,7 @@ public class ExpressionNormalization extends ExpressionRewrite { DigitalMaskingConvert.INSTANCE, SimplifyArithmeticComparisonRule.INSTANCE, SupportJavaDateFormatter.INSTANCE, + ConvertAggStateCast.INSTANCE, CheckCast.INSTANCE ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java new file mode 100644 index 00000000000000..e5748eb1d59e2c --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; +import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable; +import org.apache.doris.nereids.types.AggStateType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.TypeCoercionUtils; + +import com.google.common.collect.ImmutableList; + +/** + * Follow legacy planner cast agg_state combinator's children if we need cast it to another agg_state type when insert + */ +public class ConvertAggStateCast extends AbstractExpressionRewriteRule { + + public static ConvertAggStateCast INSTANCE = new ConvertAggStateCast(); + + @Override + public Expression visitCast(Cast cast, ExpressionRewriteContext context) { + Expression child = cast.child(); + DataType originalType = child.getDataType(); + DataType targetType = cast.getDataType(); + if (originalType instanceof AggStateType + && targetType instanceof AggStateType + && child instanceof StateCombinator) { + AggStateType original = (AggStateType) originalType; + AggStateType target = (AggStateType) targetType; + if (original.getSubTypes().size() != target.getSubTypes().size()) { + return processCastChild(cast, context); + } + if (!original.getFunctionName().equalsIgnoreCase(target.getFunctionName())) { + return processCastChild(cast, context); + } + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(child.arity()); + for (int i = 0; i < child.arity(); i++) { + Expression newChild = TypeCoercionUtils.castIfNotSameType(child.child(i), target.getSubTypes().get(i)); + if (newChild.nullable() != target.getSubTypeNullables().get(i)) { + if (newChild.nullable()) { + newChild = new NonNullable(newChild); + } else { + newChild = new Nullable(newChild); + } + } + newChildren.add(newChild); + } + child = child.withChildren(newChildren.build()); + return processCastChild(cast.withChildren(ImmutableList.of(child)), context); + } + return processCastChild(cast, context); + } + + private Expression processCastChild(Cast cast, ExpressionRewriteContext context) { + Expression child = visit(cast.child(), context); + if (child != cast.child()) { + cast = cast.withChildren(ImmutableList.of(child)); + } + return cast; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/AggStateType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/AggStateType.java index 6acde37b74f400..6680a6ccee0827 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/AggStateType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/AggStateType.java @@ -66,6 +66,18 @@ public List getMockedExpressions() { return result; } + public List getSubTypes() { + return subTypes; + } + + public List getSubTypeNullables() { + return subTypeNullables; + } + + public String getFunctionName() { + return functionName; + } + @Override public Type toCatalogDataType() { List types = subTypes.stream().map(t -> t.toCatalogDataType()).collect(Collectors.toList()); diff --git a/regression-test/suites/datatype_p0/agg_state/nereids/test_agg_state_nereids.groovy b/regression-test/suites/datatype_p0/agg_state/nereids/test_agg_state_nereids.groovy index c7a0a6d748dab5..3adbfe9e43e968 100644 --- a/regression-test/suites/datatype_p0/agg_state/nereids/test_agg_state_nereids.groovy +++ b/regression-test/suites/datatype_p0/agg_state/nereids/test_agg_state_nereids.groovy @@ -55,7 +55,8 @@ suite("test_agg_state_nereids") { properties("replication_num" = "1"); """ - sql 'set enable_fallback_to_original_planner=true' + sql "explain insert into a_table select 1,max_by_state(1,3);" + sql "insert into a_table select 1,max_by_state(1,3);" sql "insert into a_table select 1,max_by_state(2,2);" sql "insert into a_table select 1,max_by_state(3,1);"