Skip to content

Commit

Permalink
[Bug](function) fix wrong result on group_concat with distinct+order_…
Browse files Browse the repository at this point in the history
…by+nullable (#45313)

fix wrong result on group_concat with distinct+order_by+nullable
  • Loading branch information
BiteTheDDDDt authored and Your Name committed Dec 17, 2024
1 parent 820b300 commit f2d31aa
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 39 deletions.
6 changes: 4 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#pragma once

#include <utility>

#include "common/exception.h"
#include "common/status.h"
#include "util/defer_op.h"
Expand Down Expand Up @@ -80,7 +82,7 @@ using ConstAggregateDataPtr = const char*;
*/
class IAggregateFunction {
public:
IAggregateFunction(const DataTypes& argument_types_) : argument_types(argument_types_) {}
IAggregateFunction(DataTypes argument_types_) : argument_types(std::move(argument_types_)) {}

/// Get main function name.
virtual String get_name() const = 0;
Expand Down Expand Up @@ -224,7 +226,7 @@ class IAggregateFunction {

virtual void set_version(const int version_) { version = version_; }

virtual AggregateFunctionPtr transmit_to_stable() { return nullptr; }
virtual IAggregateFunction* transmit_to_stable() { return nullptr; }

/// Verify function signature
virtual Status verify_result_type(const bool without_key, const DataTypes& argument_types,
Expand Down
16 changes: 13 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_distinct.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,20 @@ class AggregateFunctionDistinct

DataTypePtr get_return_type() const override { return nested_func->get_return_type(); }

AggregateFunctionPtr transmit_to_stable() override {
return AggregateFunctionPtr(new AggregateFunctionDistinct<Data, true>(
nested_func, IAggregateFunction::argument_types));
IAggregateFunction* transmit_to_stable() override {
return new AggregateFunctionDistinct<Data, true>(nested_func,
IAggregateFunction::argument_types);
}
};

template <typename T>
struct FunctionStableTransfer {
using FunctionStable = T;
};

template <template <bool stable> typename Data>
struct FunctionStableTransfer<AggregateFunctionDistinct<Data, false>> {
using FunctionStable = AggregateFunctionDistinct<Data, true>;
};

} // namespace doris::vectorized
39 changes: 33 additions & 6 deletions be/src/vec/aggregate_functions/aggregate_function_null.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "common/logging.h"
#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_distinct.h"
#include "vec/columns/column_nullable.h"
#include "vec/common/assert_cast.h"
#include "vec/data_types/data_type_nullable.h"
Expand Down Expand Up @@ -165,7 +166,7 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived>

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
if constexpr (result_is_nullable) {
ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
auto& to_concrete = assert_cast<ColumnNullable&>(to);
if (get_flag(place)) {
nested_function->insert_result_into(nested_place(place),
to_concrete.get_nested_column());
Expand Down Expand Up @@ -197,7 +198,7 @@ class AggregateFunctionNullUnaryInline final

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
const ColumnNullable* column =
const auto* column =
assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(columns[0]);
if (!column->is_null_at(row_num)) {
this->set_flag(place);
Expand All @@ -206,6 +207,19 @@ class AggregateFunctionNullUnaryInline final
}
}

IAggregateFunction* transmit_to_stable() override {
auto f = AggregateFunctionNullBaseInline<
NestFuction, result_is_nullable,
AggregateFunctionNullUnaryInline<NestFuction, result_is_nullable>>::
nested_function->transmit_to_stable();
if (!f) {
return nullptr;
}
return new AggregateFunctionNullUnaryInline<
typename FunctionStableTransfer<NestFuction>::FunctionStable, result_is_nullable>(
f, IAggregateFunction::argument_types);
}

void add_batch(size_t batch_size, AggregateDataPtr* __restrict places, size_t place_offset,
const IColumn** columns, Arena* arena, bool agg_many) const override {
const auto* column = assert_cast<const ColumnNullable*>(columns[0]);
Expand Down Expand Up @@ -235,7 +249,7 @@ class AggregateFunctionNullUnaryInline final

void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Arena* arena) const override {
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
const auto* column = assert_cast<const ColumnNullable*>(columns[0]);
bool has_null = column->has_null();

if (has_null) {
Expand All @@ -252,7 +266,7 @@ class AggregateFunctionNullUnaryInline final

void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
const IColumn** columns, Arena* arena, bool has_null) override {
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
const auto* column = assert_cast<const ColumnNullable*>(columns[0]);

if (has_null) {
for (size_t i = batch_begin; i <= batch_end; ++i) {
Expand Down Expand Up @@ -282,13 +296,13 @@ class AggregateFunctionNullVariadicInline final
nested_function_, arguments),
number_of_arguments(arguments.size()) {
if (number_of_arguments == 1) {
throw doris::Exception(
throw Exception(
ErrorCode::INTERNAL_ERROR,
"Logical error: single argument is passed to AggregateFunctionNullVariadic");
}

if (number_of_arguments > MAX_ARGS) {
throw doris::Exception(
throw Exception(
ErrorCode::INTERNAL_ERROR,
"Maximum number of arguments for aggregate function with Nullable types is {}",
size_t(MAX_ARGS));
Expand All @@ -299,6 +313,19 @@ class AggregateFunctionNullVariadicInline final
}
}

IAggregateFunction* transmit_to_stable() override {
auto f = AggregateFunctionNullBaseInline<
NestFuction, result_is_nullable,
AggregateFunctionNullVariadicInline<NestFuction, result_is_nullable>>::
nested_function->transmit_to_stable();
if (!f) {
return nullptr;
}
return new AggregateFunctionNullVariadicInline<
typename FunctionStableTransfer<NestFuction>::FunctionStable, result_is_nullable>(
f, IAggregateFunction::argument_types);
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
/// This container stores the columns we really pass to the nested function.
Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ class AggregateFunctionSort
_arguments(arguments),
_sort_desc(sort_desc),
_state(state) {
if (auto f = _nested_func->transmit_to_stable(); f) {
_nested_func = f;
if (auto* f = _nested_func->transmit_to_stable(); f) {
_nested_func = AggregateFunctionPtr(f);
}
for (const auto& type : _arguments) {
_block.insert({type, ""});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test --
abc,abcd,eee

61 changes: 35 additions & 26 deletions regression-test/suites/nereids_p0/aggregate/agg_group_concat.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,39 @@ suite("agg_group_concat") {
exception "doesn't support order by expression"
}

sql """select multi_distinct_sum(kint) from agg_group_concat_table;"""

sql """select group_concat(distinct kstr order by kint), group_concat(distinct kstr2 order by kbint) from agg_group_concat_table;"""
sql """select multi_distinct_group_concat(kstr order by kint), multi_distinct_group_concat(kstr2 order by kbint) from agg_group_concat_table;"""
sql """select group_concat(distinct kstr), group_concat(distinct kstr2) from agg_group_concat_table;"""
sql """select multi_distinct_group_concat(kstr), multi_distinct_group_concat(kstr2) from agg_group_concat_table;"""

sql """select group_concat(distinct kstr order by kint), group_concat(distinct kstr2 order by kbint) from agg_group_concat_table group by kbint;"""
sql """select multi_distinct_group_concat(kstr order by kint), multi_distinct_group_concat(kstr2 order by kbint) from agg_group_concat_table group by kbint;"""
sql """select group_concat(distinct kstr), group_concat(distinct kstr2) from agg_group_concat_table group by kbint;"""
sql """select multi_distinct_group_concat(kstr), multi_distinct_group_concat(kstr2) from agg_group_concat_table group by kbint;"""

sql """select group_concat(distinct kstr order by kbint), group_concat(distinct kstr2 order by kint) from agg_group_concat_table group by kint;"""
sql """select multi_distinct_group_concat(kstr order by kbint), multi_distinct_group_concat(kstr2 order by kint) from agg_group_concat_table group by kint;"""
sql """select group_concat(distinct kstr), group_concat(distinct kstr2) from agg_group_concat_table group by kint;"""
sql """select multi_distinct_group_concat(kstr), multi_distinct_group_concat(kstr2) from agg_group_concat_table group by kint;"""

sql """select group_concat(distinct kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table;"""
sql """select multi_distinct_group_concat(kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table;"""
sql """select group_concat(distinct kstr), group_concat(kstr2) from agg_group_concat_table;"""
sql """select multi_distinct_group_concat(kstr), group_concat(kstr2) from agg_group_concat_table;"""

sql """select group_concat(distinct kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table group by kbint;"""
sql """select multi_distinct_group_concat(kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table group by kbint;"""
sql """select group_concat(distinct kstr), group_concat(kstr2) from agg_group_concat_table group by kbint;"""
sql """select multi_distinct_group_concat(kstr), group_concat(kstr2) from agg_group_concat_table group by kbint;"""
sql """select multi_distinct_sum(kint) from agg_group_concat_table order by 1;"""

sql """select group_concat(distinct kstr order by kint), group_concat(distinct kstr2 order by kbint) from agg_group_concat_table order by 1,2;"""
sql """select multi_distinct_group_concat(kstr order by kint), multi_distinct_group_concat(kstr2 order by kbint) from agg_group_concat_table order by 1,2;"""
sql """select group_concat(distinct kstr), group_concat(distinct kstr2) from agg_group_concat_table order by 1,2;"""
sql """select multi_distinct_group_concat(kstr), multi_distinct_group_concat(kstr2) from agg_group_concat_table order by 1,2;"""

sql """select group_concat(distinct kstr order by kint), group_concat(distinct kstr2 order by kbint) from agg_group_concat_table group by kbint order by 1,2;"""
sql """select multi_distinct_group_concat(kstr order by kint), multi_distinct_group_concat(kstr2 order by kbint) from agg_group_concat_table group by kbint order by 1,2;"""
sql """select group_concat(distinct kstr), group_concat(distinct kstr2) from agg_group_concat_table group by kbint order by 1,2;"""
sql """select multi_distinct_group_concat(kstr), multi_distinct_group_concat(kstr2) from agg_group_concat_table group by kbint order by 1,2;"""

sql """select group_concat(distinct kstr order by kbint), group_concat(distinct kstr2 order by kint) from agg_group_concat_table group by kint order by 1,2;"""
sql """select multi_distinct_group_concat(kstr order by kbint), multi_distinct_group_concat(kstr2 order by kint) from agg_group_concat_table group by kint order by 1,2;"""
sql """select group_concat(distinct kstr), group_concat(distinct kstr2) from agg_group_concat_table group by kint order by 1,2;"""
sql """select multi_distinct_group_concat(kstr), multi_distinct_group_concat(kstr2) from agg_group_concat_table group by kint order by 1,2;"""

sql """select group_concat(distinct kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table order by 1,2;"""
sql """select multi_distinct_group_concat(kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table order by 1,2;"""
sql """select group_concat(distinct kstr), group_concat(kstr2) from agg_group_concat_table order by 1,2;"""
sql """select multi_distinct_group_concat(kstr), group_concat(kstr2) from agg_group_concat_table order by 1,2;"""

sql """select group_concat(distinct kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table group by kbint order by 1,2;"""
sql """select multi_distinct_group_concat(kstr order by kint), group_concat(kstr2 order by kbint) from agg_group_concat_table group by kbint order by 1,2;"""
sql """select group_concat(distinct kstr), group_concat(kstr2) from agg_group_concat_table group by kbint order by 1,2;"""
sql """select multi_distinct_group_concat(kstr), group_concat(kstr2) from agg_group_concat_table group by kbint order by 1,2;"""

sql "drop table if exists test_distinct_multi"
sql """
create table test_distinct_multi(a int, b int, c int, d varchar(10), e date) distributed by hash(a) properties('replication_num'='1');
"""
sql """
insert into test_distinct_multi values(1,2,3,'abc','2024-01-02'),(1,2,4,'abc','2024-01-03'),(2,2,4,'abcd','2024-01-02'),(1,2,3,'abcd','2024-01-04'),(1,2,4,'eee','2024-02-02'),(2,2,4,'abc','2024-01-02');
"""
qt_test "select group_concat( distinct d order by d) from test_distinct_multi order by 1; "
}

0 comments on commit f2d31aa

Please sign in to comment.