Skip to content

Commit

Permalink
[improve](function) add error msg if exceeded maximum default value i…
Browse files Browse the repository at this point in the history
…n repeat function
  • Loading branch information
zhangstar333 committed Mar 14, 2024
1 parent 3e07897 commit c85115d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 25 deletions.
75 changes: 51 additions & 24 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <ostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -1439,6 +1440,14 @@ class FunctionStringRepeat : public IFunction {
static FunctionPtr create() { return std::make_shared<FunctionStringRepeat>(); }
String get_name() const override { return name; }
size_t get_number_of_arguments() const override { return 2; }
std::string error_msg(int default_value, int repeat_value) const {
auto error_msg = fmt::format(
"The second parameter of repeat function exceeded maximum default value, "
"default_value is {}, and now input is {} . you could try change default value "
"greater than value eg: set repeat_max_num = {}.",
default_value, repeat_value, repeat_value + 10);
return error_msg;
}

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return std::make_shared<DataTypeString>();
Expand All @@ -1455,15 +1464,18 @@ class FunctionStringRepeat : public IFunction {

if (auto* col1 = check_and_get_column<ColumnString>(*argument_ptr[0])) {
if (auto* col2 = check_and_get_column<ColumnInt32>(*argument_ptr[1])) {
vector_vector(col1->get_chars(), col1->get_offsets(), col2->get_data(),
res->get_chars(), res->get_offsets(),
context->state()->repeat_max_num());
RETURN_IF_ERROR(vector_vector(
col1->get_chars(), col1->get_offsets(), col2->get_data(), res->get_chars(),
res->get_offsets(), context->state()->repeat_max_num()));
block.replace_by_position(result, std::move(res));
return Status::OK();
} else if (auto* col2_const = check_and_get_column<ColumnConst>(*argument_ptr[1])) {
DCHECK(check_and_get_column<ColumnInt32>(col2_const->get_data_column()));
int repeat = 0;
repeat = std::min<int>(col2_const->get_int(0), context->state()->repeat_max_num());
int repeat = col2_const->get_int(0);
if (repeat > context->state()->repeat_max_num()) {
return Status::InvalidArgument(
error_msg(context->state()->repeat_max_num(), repeat));
}

if (repeat <= 0) {
res->insert_many_defaults(input_rows_count);
Expand All @@ -1480,9 +1492,9 @@ class FunctionStringRepeat : public IFunction {
argument_ptr[0]->get_name(), argument_ptr[1]->get_name());
}

void vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets,
const ColumnInt32::Container& repeats, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets, const int repeat_max_num) const {
Status vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets,
const ColumnInt32::Container& repeats, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets, const int repeat_max_num) const {
size_t input_row_size = offsets.size();

fmt::memory_buffer buffer;
Expand All @@ -1491,9 +1503,10 @@ class FunctionStringRepeat : public IFunction {
buffer.clear();
const char* raw_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
size_t size = offsets[i] - offsets[i - 1];
int repeat = 0;
repeat = std::min<int>(repeats[i], repeat_max_num);

int repeat = repeats[i];
if (repeat > repeat_max_num) {
return Status::InvalidArgument(error_msg(repeat_max_num, repeat));
}
if (repeat <= 0) {
StringOP::push_empty_string(i, res_data, res_offsets);
} else {
Expand All @@ -1504,6 +1517,7 @@ class FunctionStringRepeat : public IFunction {
res_data, res_offsets);
}
}
return Status::OK();
}

// TODO: 1. use pmr::vector<char> replace fmt_buffer may speed up the code
Expand Down Expand Up @@ -1536,7 +1550,14 @@ class FunctionStringRepeatOld : public IFunction {
static FunctionPtr create() { return std::make_shared<FunctionStringRepeatOld>(); }
String get_name() const override { return name; }
size_t get_number_of_arguments() const override { return 2; }

std::string error_msg(int default_value, int repeat_value) const {
auto error_msg = fmt::format(
"The second parameter of repeat function exceeded maximum default value, "
"default_value is {}, and now input is {} . you could try change default value "
"greater than value eg: set repeat_max_num = {}.",
default_value, repeat_value, repeat_value + 10);
return error_msg;
}
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return make_nullable(std::make_shared<DataTypeString>());
}
Expand All @@ -1553,17 +1574,20 @@ class FunctionStringRepeatOld : public IFunction {

if (auto* col1 = check_and_get_column<ColumnString>(*argument_ptr[0])) {
if (auto* col2 = check_and_get_column<ColumnInt32>(*argument_ptr[1])) {
vector_vector(col1->get_chars(), col1->get_offsets(), col2->get_data(),
res->get_chars(), res->get_offsets(), null_map->get_data(),
context->state()->repeat_max_num());
RETURN_IF_ERROR(vector_vector(col1->get_chars(), col1->get_offsets(),
col2->get_data(), res->get_chars(),
res->get_offsets(), null_map->get_data(),
context->state()->repeat_max_num()));
block.replace_by_position(
result, ColumnNullable::create(std::move(res), std::move(null_map)));
return Status::OK();
} else if (auto* col2_const = check_and_get_column<ColumnConst>(*argument_ptr[1])) {
DCHECK(check_and_get_column<ColumnInt32>(col2_const->get_data_column()));
int repeat = 0;
repeat = std::min<int>(col2_const->get_int(0), context->state()->repeat_max_num());

int repeat = col2_const->get_int(0);
if (repeat > context->state()->repeat_max_num()) {
return Status::InvalidArgument(
error_msg(context->state()->repeat_max_num(), repeat));
}
if (repeat <= 0) {
null_map->get_data().resize_fill(input_rows_count, 0);
res->insert_many_defaults(input_rows_count);
Expand All @@ -1581,10 +1605,10 @@ class FunctionStringRepeatOld : public IFunction {
argument_ptr[0]->get_name(), argument_ptr[1]->get_name());
}

void vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets,
const ColumnInt32::Container& repeats, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets, ColumnUInt8::Container& null_map,
const int repeat_max_num) const {
Status vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets,
const ColumnInt32::Container& repeats, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets, ColumnUInt8::Container& null_map,
const int repeat_max_num) const {
size_t input_row_size = offsets.size();

fmt::memory_buffer buffer;
Expand All @@ -1594,8 +1618,10 @@ class FunctionStringRepeatOld : public IFunction {
buffer.clear();
const char* raw_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
size_t size = offsets[i] - offsets[i - 1];
int repeat = 0;
repeat = std::min<int>(repeats[i], repeat_max_num);
int repeat = repeats[i];
if (repeat > repeat_max_num) {
return Status::InvalidArgument(error_msg(repeat_max_num, repeat));
}

if (repeat <= 0) {
StringOP::push_empty_string(i, res_data, res_offsets);
Expand All @@ -1609,6 +1635,7 @@ class FunctionStringRepeatOld : public IFunction {
res_data, res_offsets);
}
}
return Status::OK();
}

// TODO: 1. use pmr::vector<char> replace fmt_buffer may speed up the code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ suite("test_string_basic") {
(2, repeat("test1111", 131072))
"""
order_qt_select_str_tb "select k1, md5(v1), length(v1) from ${tbName}"

try {
sql """ SELECT repeat("test1111", 131073 + 100); """
} catch (Exception ex) {
log.info(ex.getMessage());
assertTrue(ex.getMessage().contains("repeat function exceeded maximum default value"))
}
sql """drop table if exists test_string_cmp;"""

sql """
Expand Down

0 comments on commit c85115d

Please sign in to comment.