diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index a18ff819d36676..1d4c73e2a60d99 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt; import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; @@ -48,6 +49,7 @@ import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import org.apache.doris.nereids.util.Utils; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; @@ -135,38 +137,74 @@ private boolean canCompress(Expression expression) { return false; } - private LogicalAggregate encode(LogicalAggregate aggregate, Optional> having) { - List newGroupByExpressions = Lists.newArrayList(); - List encodedExpressions = Lists.newArrayList(); - Map encodeMap = Maps.newHashMap(); - for (Expression gp : aggregate.getGroupByExpressions()) { - if (gp instanceof SlotReference && canCompress(gp)) { - Alias alias = new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName()); - newGroupByExpressions.add(alias); - encodedExpressions.add(gp); - encodeMap.put(gp, alias); - } else { - newGroupByExpressions.add(gp); + /* + example: + [support] select sum(v) from t group by substring(k, 1,2) + [not support] select substring(k, 1,2), sum(v) from t group by substring(k, 1,2) + [support] select k, sum(v) from t group by k + [not support] select substring(k, 1,2), sum(v) from t group by k + [support] select A as B from T group by A + */ + private Set getEncodableGroupByExpressions(LogicalAggregate aggregate) { + Set encodableGroupbyExpressions = Sets.newHashSet(); + Set slotShouldNotEncode = Sets.newHashSet(); + for (NamedExpression ne : aggregate.getOutputExpressions()) { + if (ne instanceof Alias) { + Expression child = ((Alias) ne).child(); + //support: select A as B from T group by A + if (!(child instanceof SlotReference)) { + slotShouldNotEncode.addAll(child.getInputSlots()); + } + } + } + for (Expression gb : aggregate.getGroupByExpressions()) { + if (canCompress(gb)) { + boolean encodable = true; + for (Slot gbs : gb.getInputSlots()) { + if (slotShouldNotEncode.contains(gbs)) { + encodable = false; + break; + } + } + if (encodable) { + encodableGroupbyExpressions.add(gb); + } } } - if (!encodedExpressions.isEmpty()) { - // aggregate = aggregate.withGroupByExpressions(newGroupByExpressions); - // boolean hasNewOutput = false; + return encodableGroupbyExpressions; + } + + private LogicalAggregate encode(LogicalAggregate aggregate, Optional> having) { + List encodedExpressions = Lists.newArrayList(); + Set encodableGroupByExpressions = getEncodableGroupByExpressions(aggregate); + if (!encodableGroupByExpressions.isEmpty()) { + List newGroupByExpressions = Lists.newArrayList(); + List encodedGroupByExpressions = Lists.newArrayList(); + for (Expression gp : aggregate.getGroupByExpressions()) { + if (encodableGroupByExpressions.contains(gp)) { + Alias alias = new Alias(new EncodeAsBigInt(gp)); + newGroupByExpressions.add(alias); + encodedExpressions.add(alias); + } else { + newGroupByExpressions.add(gp); + } + } List newOutput = Lists.newArrayList(); - List output = aggregate.getOutputExpressions(); - for (NamedExpression ne : output) { - if (ne instanceof SlotReference && encodedExpressions.contains(ne)) { + for (NamedExpression ne : aggregate.getOutputExpressions()) { + if (ne instanceof SlotReference && encodableGroupByExpressions.contains(ne)) { newOutput.add(new Alias(ne.getExprId(), new AnyValue(ne), ne.getName())); - newOutput.add(encodeMap.get(ne)); - // hasNewOutput = true; + } else if (ne instanceof Alias && encodableGroupByExpressions.contains(((Alias) ne).child())) { + Expression child = ((Alias) ne).child(); + Preconditions.checkArgument(child instanceof SlotReference, + "encode %s failed, not a slot", child); + newOutput.add(new Alias(((SlotReference) child).getExprId(), new AnyValue(child), + "any_value(" + child + ")")); } else { newOutput.add(ne); } } + newOutput.addAll(encodedExpressions); aggregate = aggregate.withGroupByAndOutput(newGroupByExpressions, newOutput); - // if (hasNewOutput) { - // aggregate = aggregate.withAggOutput(newOutput); - // } } return aggregate; } diff --git a/regression-test/data/nereids_p0/compress_materialize.out b/regression-test/data/nereids_p0/compress_materialize.out new file mode 100644 index 00000000000000..54b2d6f4c29913 --- /dev/null +++ b/regression-test/data/nereids_p0/compress_materialize.out @@ -0,0 +1,28 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !agg_exec -- +aaaaa +bbbbb + +-- !not_support -- +aaa +bbb + +-- !not_support -- +aaa +bbb + +-- !encodeexpr -- +12 +3 + +-- !join -- +PhysicalResultSink +--hashJoin[INNER_JOIN broadcast] hashCondition=((T.k = cmt2.k2)) otherCondition=() build RFs:RF0 k2->[k] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------PhysicalProject +--------------PhysicalOlapScan[compress] apply RFs: RF0 +----PhysicalOlapScan[cmt2] + diff --git a/regression-test/suites/nereids_p0/compress_materialize.groovy b/regression-test/suites/nereids_p0/compress_materialize.groovy new file mode 100644 index 00000000000000..ce2f03cfeb5a44 --- /dev/null +++ b/regression-test/suites/nereids_p0/compress_materialize.groovy @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("compress_materialize") { + sql """ + drop table if exists compress; + CREATE TABLE `compress` ( + `k` varchar(5) NOT NULL, + `v` int NOT NULL + ) ENGINE=OLAP + duplicate KEY(`k`) + DISTRIBUTED BY HASH(`k`) BUCKETS AUTO + PROPERTIES ( + "replication_num" = "1" + ); + + + insert into compress values ("aaaaaa", 1), ("aaaaaa", 2), ("bbbbb", 3), ("bbbbb", 4), ("bbbbb", 5); + + + drop table if exists cmt2; + CREATE TABLE `cmt2` ( + `k2` varchar(5) NOT NULL, + `v2` int NOT NULL + ) ENGINE=OLAP + duplicate KEY(`k2`) + DISTRIBUTED BY random + PROPERTIES ( + "replication_num" = "1" + ); + + insert into cmt2 values ("aaaa", 1), ("b", 3); +insert into cmt2 values("123456", 123456); + """ + +// expected explain contains partial_any_value(k) +// 3:VAGGREGATE (merge finalize)(167) +// | output: any_value(partial_any_value(k)[#5])[#7] +// | group by: k[#4] +// | sortByGroupKey:false +// | cardinality=1 +// | final projections: k[#7] +// | final project output tuple id: 4 +// | distribute expr lists: k[#4] + explain{ + sql (""" + select k from compress group by k; + """) + contains("any_value(partial_any_value(k)") + } + order_qt_agg_exec "select k from compress group by k;" + order_qt_not_support """ select substring(k,1,3) from compress group by substring(k,1,3);""" + order_qt_not_support """ select substring(k,1,3) from compress group by k;""" + + explain { + sql("select sum(v) from compress group by substring(k, 1, 3);") + contains("group by: encode_as_bigint(substring(k, 1, 3))") + } + order_qt_encodeexpr "select sum(v) from compress group by substring(k, 1, 3);" + + + // verify that compressed materialization do not block runtime filter generation + sql """ + set disable_join_reorder=true; + set runtime_filter_mode = GLOBAL; + set runtime_filter_type=2; + set enable_runtime_filter_prune=false; + """ + + qt_join """ + explain shape plan + select * + from ( + select k from compress group by k + ) T join cmt2 on T.k = cmt2.k2; + """ +} +