Skip to content

Commit

Permalink
[fix](hash join) fix stack overflow caused by evaluate case expr on huge build block (apache#28851)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacktengg authored Dec 22, 2023
1 parent cb61a07 commit d75300f
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 29 deletions.
18 changes: 12 additions & 6 deletions be/src/pipeline/exec/hashjoin_build_sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,6 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
vectorized::ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size());

vectorized::ColumnUInt8::MutablePtr null_map_val;
std::vector<int> res_col_ids(_build_expr_ctxs.size());
RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, *_build_expr_call_timer, res_col_ids));
if (p._join_op == TJoinOp::LEFT_OUTER_JOIN || p._join_op == TJoinOp::FULL_OUTER_JOIN) {
_convert_block_to_null(block);
// first row is mocked
Expand All @@ -247,15 +245,15 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
// so we have to initialize this flag by the first build block.
if (!_has_set_need_null_map_for_build) {
_has_set_need_null_map_for_build = true;
_set_build_ignore_flag(block, res_col_ids);
_set_build_ignore_flag(block, _build_col_ids);
}
if (p._short_circuit_for_null_in_build_side || _build_side_ignore_null) {
null_map_val = vectorized::ColumnUInt8::create();
null_map_val->get_data().assign(rows, (uint8_t)0);
}

// Get the key column that needs to be built
Status st = _extract_join_column(block, null_map_val, raw_ptrs, res_col_ids);
Status st = _extract_join_column(block, null_map_val, raw_ptrs, _build_col_ids);

st = std::visit(
Overload {[&](std::monostate& arg, auto join_op, auto has_null_value,
Expand Down Expand Up @@ -458,13 +456,21 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* state, vectorized::Block*
if (local_state._build_side_mutable_block.empty()) {
auto tmp_build_block = vectorized::VectorizedUtils::create_empty_columnswithtypename(
_child_x->row_desc());
tmp_build_block = *(tmp_build_block.create_same_struct_block(1, false));
local_state._build_col_ids.resize(_build_expr_ctxs.size());
RETURN_IF_ERROR(local_state._do_evaluate(tmp_build_block, local_state._build_expr_ctxs,
*local_state._build_expr_call_timer,
local_state._build_col_ids));
local_state._build_side_mutable_block =
vectorized::MutableBlock::build_mutable_block(&tmp_build_block);
RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(
*(tmp_build_block.create_same_struct_block(1, false))));
}

if (in_block->rows() != 0) {
std::vector<int> res_col_ids(_build_expr_ctxs.size());
RETURN_IF_ERROR(local_state._do_evaluate(*in_block, local_state._build_expr_ctxs,
*local_state._build_expr_call_timer,
res_col_ids));

SCOPED_TIMER(local_state._build_side_merge_block_timer);
RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(*in_block));
if (local_state._build_side_mutable_block.rows() >
Expand Down
1 change: 1 addition & 0 deletions be/src/pipeline/exec/hashjoin_build_sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class HashJoinBuildSinkLocalState final
bool _build_side_ignore_null = false;
std::unordered_set<const vectorized::Block*> _inserted_blocks;
std::shared_ptr<SharedHashTableDependency> _shared_hash_table_dependency;
std::vector<int> _build_col_ids;

RuntimeProfile::Counter* _build_table_timer = nullptr;
RuntimeProfile::Counter* _build_expr_call_timer = nullptr;
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/columns/column_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ ColumnPtr ColumnVector<T>::replicate(const IColumn::Offsets& offsets) const {
res_data.reserve(offsets.back());

// vectorized this code to speed up
IColumn::Offset counts[size];
auto counts_uptr = std::unique_ptr<IColumn::Offset[]>(new IColumn::Offset[size]);
IColumn::Offset* counts = counts_uptr.get();
for (ssize_t i = 0; i < size; ++i) {
counts[i] = offsets[i] - offsets[i - 1];
}
Expand Down
16 changes: 10 additions & 6 deletions be/src/vec/exec/join/vhash_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,18 @@ Status HashJoinNode::sink(doris::RuntimeState* state, vectorized::Block* in_bloc
if (_build_side_mutable_block.empty()) {
auto tmp_build_block =
VectorizedUtils::create_empty_columnswithtypename(child(1)->row_desc());
tmp_build_block = *(tmp_build_block.create_same_struct_block(1, false));
_build_col_ids.resize(_build_expr_ctxs.size());
RETURN_IF_ERROR(_do_evaluate(tmp_build_block, _build_expr_ctxs, *_build_expr_call_timer,
_build_col_ids));
_build_side_mutable_block = MutableBlock::build_mutable_block(&tmp_build_block);
RETURN_IF_ERROR(_build_side_mutable_block.merge(
*(tmp_build_block.create_same_struct_block(1, false))));
}

if (in_block->rows() != 0) {
std::vector<int> res_col_ids(_build_expr_ctxs.size());
RETURN_IF_ERROR(_do_evaluate(*in_block, _build_expr_ctxs, *_build_expr_call_timer,
res_col_ids));

SCOPED_TIMER(_build_side_merge_block_timer);
RETURN_IF_ERROR(_build_side_mutable_block.merge(*in_block));
if (_build_side_mutable_block.rows() > JOIN_BUILD_SIZE_LIMIT) {
Expand Down Expand Up @@ -952,8 +958,6 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block) {
ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size());

ColumnUInt8::MutablePtr null_map_val;
std::vector<int> res_col_ids(_build_expr_ctxs.size());
RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, *_build_expr_call_timer, res_col_ids));
if (_join_op == TJoinOp::LEFT_OUTER_JOIN || _join_op == TJoinOp::FULL_OUTER_JOIN) {
_convert_block_to_null(block);
// first row is mocked
Expand All @@ -969,15 +973,15 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block) {
// so we have to initialize this flag by the first build block.
if (!_has_set_need_null_map_for_build) {
_has_set_need_null_map_for_build = true;
_set_build_ignore_flag(block, res_col_ids);
_set_build_ignore_flag(block, _build_col_ids);
}
if (_short_circuit_for_null_in_build_side || _build_side_ignore_null) {
null_map_val = ColumnUInt8::create();
null_map_val->get_data().assign(rows, (uint8_t)0);
}

// Get the key column that needs to be built
Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs, res_col_ids);
Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs, _build_col_ids);

st = std::visit(
Overload {[&](std::monostate& arg, auto join_op, auto has_null_value,
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/exec/join/vhash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ class HashJoinNode final : public VJoinNodeBase {

std::vector<IRuntimeFilter*> _runtime_filters;
std::atomic_bool _probe_open_finish = false;
std::vector<int> _build_col_ids;
};
} // namespace vectorized
} // namespace doris
5 changes: 3 additions & 2 deletions be/src/vec/functions/function_binary_arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ struct DecimalBinaryOperation {
make_bool_variant(need_adjust_scale && check_overflow));

if (OpTraits::is_multiply && need_adjust_scale && !check_overflow) {
int8_t sig[size];
auto sig_uptr = std::unique_ptr<int8_t[]>(new int8_t[size]);
int8_t* sig = sig_uptr.get();
for (size_t i = 0; i < size; i++) {
sig[i] = sgn(c[i].value);
}
Expand Down Expand Up @@ -917,7 +918,7 @@ class FunctionBinaryArithmetic : public IFunction {
if constexpr (!std::is_same_v<ResultDataType, InvalidType>) {
need_replace_null_data_to_default_ =
IsDataTypeDecimal<ResultDataType> ||
(name == "pow" &&
(get_name() == "pow" &&
std::is_floating_point_v<typename ResultDataType::FieldType>);
if constexpr (IsDataTypeDecimal<LeftDataType> &&
IsDataTypeDecimal<RightDataType>) {
Expand Down
21 changes: 11 additions & 10 deletions be/src/vec/functions/function_case.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ class FunctionCase : public IFunction {
int rows_count = column_holder.rows_count;

// `then` data index corresponding to each row of results, 0 represents `else`.
int then_idx[rows_count];
int* __restrict then_idx_ptr = then_idx;
memset(then_idx_ptr, 0, sizeof(then_idx));
auto then_idx_uptr = std::unique_ptr<int[]>(new int[rows_count]);
int* __restrict then_idx_ptr = then_idx_uptr.get();
memset(then_idx_ptr, 0, rows_count * sizeof(int));

for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) {
for (int i = 1; i < column_holder.pair_count; i++) {
Expand Down Expand Up @@ -189,7 +189,7 @@ class FunctionCase : public IFunction {
}

auto result_column_ptr = data_type->create_column();
update_result_normal<int, ColumnType, then_null>(result_column_ptr, then_idx,
update_result_normal<int, ColumnType, then_null>(result_column_ptr, then_idx_ptr,
column_holder);
block.replace_by_position(result, std::move(result_column_ptr));
return Status::OK();
Expand All @@ -206,9 +206,9 @@ class FunctionCase : public IFunction {
int rows_count = column_holder.rows_count;

// `then` data index corresponding to each row of results, 0 represents `else`.
uint8_t then_idx[rows_count];
uint8_t* __restrict then_idx_ptr = then_idx;
memset(then_idx_ptr, 0, sizeof(then_idx));
auto then_idx_uptr = std::unique_ptr<uint8_t[]>(new uint8_t[rows_count]);
uint8_t* __restrict then_idx_ptr = then_idx_uptr.get();
memset(then_idx_ptr, 0, rows_count);

auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr);

Expand Down Expand Up @@ -245,13 +245,13 @@ class FunctionCase : public IFunction {
}
}

return execute_update_result<ColumnType, then_null>(data_type, result, block, then_idx,
return execute_update_result<ColumnType, then_null>(data_type, result, block, then_idx_ptr,
column_holder);
}

template <typename ColumnType, bool then_null>
Status execute_update_result(const DataTypePtr& data_type, size_t result, Block& block,
uint8* then_idx, CaseWhenColumnHolder& column_holder) const {
const uint8* then_idx, CaseWhenColumnHolder& column_holder) const {
auto result_column_ptr = data_type->create_column();

if constexpr (std::is_same_v<ColumnType, ColumnString> ||
Expand Down Expand Up @@ -282,7 +282,8 @@ class FunctionCase : public IFunction {
}

template <typename IndexType, typename ColumnType, bool then_null>
void update_result_normal(MutableColumnPtr& result_column_ptr, IndexType* then_idx,
void update_result_normal(MutableColumnPtr& result_column_ptr,
const IndexType* __restrict then_idx,
CaseWhenColumnHolder& column_holder) const {
std::vector<uint8_t> is_consts(column_holder.then_ptrs.size());
std::vector<ColumnPtr> raw_columns(column_holder.then_ptrs.size());
Expand Down
30 changes: 27 additions & 3 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ class FunctionTrim : public IFunction {
}
};

// Size (in chars) of the fixed stack buffer used for the output of the
// unhex / to_base64 / from_base64 implementations below. When the computed
// output length exceeds this value, those functions switch to a heap buffer
// instead of the stack array.
static constexpr int MAX_STACK_CIPHER_LEN = 1024 * 64;
struct UnHexImpl {
static constexpr auto name = "unhex";
using ReturnType = DataTypeString;
Expand Down Expand Up @@ -654,8 +655,16 @@ struct UnHexImpl {
continue;
}

char dst_array[MAX_STACK_CIPHER_LEN];
char* dst = dst_array;

int cipher_len = srclen / 2;
char dst[cipher_len];
std::unique_ptr<char[]> dst_uptr;
if (cipher_len > MAX_STACK_CIPHER_LEN) {
dst_uptr.reset(new char[cipher_len]);
dst = dst_uptr.get();
}

int outlen = hex_decode(source, srclen, dst);

if (outlen < 0) {
Expand Down Expand Up @@ -725,8 +734,16 @@ struct ToBase64Impl {
continue;
}

char dst_array[MAX_STACK_CIPHER_LEN];
char* dst = dst_array;

int cipher_len = (int)(4.0 * ceil((double)srclen / 3.0));
char dst[cipher_len];
std::unique_ptr<char[]> dst_uptr;
if (cipher_len > MAX_STACK_CIPHER_LEN) {
dst_uptr.reset(new char[cipher_len]);
dst = dst_uptr.get();
}

int outlen = base64_encode((const unsigned char*)source, srclen, (unsigned char*)dst);

if (outlen < 0) {
Expand Down Expand Up @@ -765,8 +782,15 @@ struct FromBase64Impl {
continue;
}

char dst_array[MAX_STACK_CIPHER_LEN];
char* dst = dst_array;

int cipher_len = srclen;
char dst[cipher_len];
std::unique_ptr<char[]> dst_uptr;
if (cipher_len > MAX_STACK_CIPHER_LEN) {
dst_uptr.reset(new char[cipher_len]);
dst = dst_uptr.get();
}
int outlen = base64_decode(source, srclen, dst);

if (outlen < 0) {
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/functions/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ struct MultiplyImpl {
static void vector_vector(const ColumnDecimal128::Container::value_type* __restrict a,
const ColumnDecimal128::Container::value_type* __restrict b,
ColumnDecimal128::Container::value_type* c, size_t size) {
int8 sgn[size];
auto sng_uptr = std::unique_ptr<int8[]>(new int8[size]);
int8* sgn = sng_uptr.get();
auto max = DecimalV2Value::get_max_decimal();
auto min = DecimalV2Value::get_min_decimal();

Expand Down

0 comments on commit d75300f

Please sign in to comment.