Skip to content

Commit

Permalink
[pipelineX](feature) support assert rows num operator (apache#23857)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel39 authored Sep 4, 2023
1 parent d694f4a commit 21aea76
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 160 deletions.
87 changes: 87 additions & 0 deletions be/src/pipeline/exec/assert_num_rows_operator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "assert_num_rows_operator.h"

namespace doris::pipeline {

// Creates the AssertNumRowsOperator for this builder, handing it the builder itself
// and the node the builder holds.
OperatorPtr AssertNumRowsOperatorBuilder::build_operator() {
    auto assert_operator = std::make_shared<AssertNumRowsOperator>(this, _node);
    return assert_operator;
}

// Builds the operator from the thrift plan node: copies the expected row count and
// the subquery text, and picks the comparison to apply. When the plan node does not
// carry an explicit assertion, LE is used so that plans produced by older code keep
// their previous behaviour.
AssertNumRowsOperatorX::AssertNumRowsOperatorX(ObjectPool* pool, const TPlanNode& tnode,
                                               const DescriptorTbl& descs)
        : StreamingOperatorX<AssertNumRowsLocalState>(pool, tnode, descs),
          _desired_num_rows(tnode.assert_num_rows_node.desired_num_rows),
          _subquery_string(tnode.assert_num_rows_node.subquery_string),
          _assertion(tnode.assert_num_rows_node.__isset.assertion
                             ? tnode.assert_num_rows_node.assertion
                             : TAssertion::LE) {}

// Checks the row-count assertion after accounting for the rows in `block`.
//
// The rows of `block` are added to this operator's local state, and the resulting
// row count is compared with `_desired_num_rows` using the comparison selected by
// `_assertion`.
//
// @param state         runtime state used to look up this operator's local state.
// @param block         block whose row count is added to the local state's counter.
// @param source_state  not read or modified by this function.
// @return Status::OK() when the assertion holds; Status::Cancelled(...) describing
//         the expectation when it does not, or when `_assertion` is not one of the
//         handled comparison kinds.
Status AssertNumRowsOperatorX::pull(doris::RuntimeState* state, vectorized::Block* block,
                                    SourceState& source_state) {
    auto& local_state = state->get_local_state(id())->cast<AssertNumRowsLocalState>();
    local_state.add_num_rows_returned(block->rows());
    const int64_t num_rows_returned = local_state.num_rows_returned();

    bool assert_res = false;
    switch (_assertion) {
    case TAssertion::EQ:
        assert_res = num_rows_returned == _desired_num_rows;
        break;
    case TAssertion::NE:
        assert_res = num_rows_returned != _desired_num_rows;
        break;
    case TAssertion::LT:
        assert_res = num_rows_returned < _desired_num_rows;
        break;
    case TAssertion::LE:
        assert_res = num_rows_returned <= _desired_num_rows;
        break;
    case TAssertion::GT:
        assert_res = num_rows_returned > _desired_num_rows;
        break;
    case TAssertion::GE:
        assert_res = num_rows_returned >= _desired_num_rows;
        break;
    default:
        // An unrecognised assertion kind leaves assert_res false and is reported
        // as a failed assertion below.
        break;
    }

    if (!assert_res) {
        // The lookup result must be compared against end() of the SAME map that
        // find() was called on; comparing iterators of different containers is
        // undefined and would skip the "NULL" fallback for unknown values.
        auto to_string_lambda = [](TAssertion::type assertion) -> const char* {
            const auto it = _TAssertion_VALUES_TO_NAMES.find(assertion);
            if (it == _TAssertion_VALUES_TO_NAMES.end()) {
                return "NULL";
            }
            return it->second;
        };
        // Resolve the name once; it is used for both the log line and the status.
        const char* assertion_name = to_string_lambda(_assertion);
        LOG(INFO) << "Expected " << assertion_name << " " << _desired_num_rows
                  << " to be returned by expression " << _subquery_string;
        return Status::Cancelled("Expected {} {} to be returned by expression {}",
                                 assertion_name, _desired_num_rows, _subquery_string);
    }
    COUNTER_SET(local_state.rows_returned_counter(), local_state.num_rows_returned());
    return Status::OK();
}

} // namespace doris::pipeline
30 changes: 26 additions & 4 deletions be/src/pipeline/exec/assert_num_rows_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include "operator.h"
#include "pipeline/pipeline_x/operator.h"
#include "vec/exec/vassert_num_rows_node.h"

namespace doris {
Expand All @@ -38,9 +39,30 @@ class AssertNumRowsOperator final : public StreamingOperator<AssertNumRowsOperat
: StreamingOperator(operator_builder, node) {}
};

// Creates the AssertNumRowsOperator for this builder, handing it the builder itself
// and the node the builder holds.
OperatorPtr AssertNumRowsOperatorBuilder::build_operator() {
    auto assert_operator = std::make_shared<AssertNumRowsOperator>(this, _node);
    return assert_operator;
}
// Local state type used by AssertNumRowsOperatorX. It declares no data members of
// its own; everything it offers comes from PipelineXLocalState<FakeDependency>.
class AssertNumRowsLocalState final : public PipelineXLocalState<FakeDependency> {
public:
// Project macro applied to this type; its expansion is not visible in this file.
ENABLE_FACTORY_CREATOR(AssertNumRowsLocalState);

// Forwards the runtime state and the owning operator to the base class.
AssertNumRowsLocalState(RuntimeState* state, OperatorXBase* parent)
: PipelineXLocalState<FakeDependency>(state, parent) {}
~AssertNumRowsLocalState() = default;
};

// PipelineX operator that checks a row-count assertion. It is a
// StreamingOperatorX specialised on AssertNumRowsLocalState.
class AssertNumRowsOperatorX final : public StreamingOperatorX<AssertNumRowsLocalState> {
public:
// Constructs the operator from the thrift plan node `tnode`; `pool` and `descs`
// are passed through to the base class.
AssertNumRowsOperatorX(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs);

// Returns Status::OK() when the row-count assertion holds after counting the rows
// in `block`, otherwise a non-OK status (see the definition in the .cpp file).
Status pull(RuntimeState* state, vectorized::Block* block, SourceState& source_state) override;

// This operator is never a source operator.
[[nodiscard]] bool is_source() const override { return false; }

private:
friend class AssertNumRowsLocalState;

// Row count the assertion compares against.
int64_t _desired_num_rows;
// Subquery text included in the message reported when the assertion fails.
const std::string _subquery_string;
// Comparison kind (EQ, NE, LT, LE, GT, GE) applied between the counted rows and
// _desired_num_rows.
TAssertion::type _assertion;
};

} // namespace pipeline
} // namespace doris
} // namespace doris
65 changes: 20 additions & 45 deletions be/src/pipeline/exec/hashjoin_probe_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,21 @@ Status HashJoinProbeLocalState::close(RuntimeState* state) {
if (_closed) {
return Status::OK();
}
std::visit(vectorized::Overload {[&](std::monostate&) {},
[&](auto&& process_hashtable_ctx) {
if (process_hashtable_ctx._arena) {
process_hashtable_ctx._arena.reset();
}

if (process_hashtable_ctx._serialize_key_arena) {
process_hashtable_ctx._serialize_key_arena.reset();
process_hashtable_ctx._serialized_key_buffer_size = 0;
}
}},
*_process_hashtable_ctx_variants);
if (_process_hashtable_ctx_variants) {
std::visit(vectorized::Overload {[&](std::monostate&) {},
[&](auto&& process_hashtable_ctx) {
if (process_hashtable_ctx._arena) {
process_hashtable_ctx._arena.reset();
}

if (process_hashtable_ctx._serialize_key_arena) {
process_hashtable_ctx._serialize_key_arena.reset();
process_hashtable_ctx._serialized_key_buffer_size =
0;
}
}},
*_process_hashtable_ctx_variants);
}
_shared_state->arena = nullptr;
_shared_state->hash_table_variants.reset();
_process_hashtable_ctx_variants = nullptr;
Expand Down Expand Up @@ -180,39 +183,10 @@ HashJoinProbeOperatorX::HashJoinProbeOperatorX(ObjectPool* pool, const TPlanNode
? tnode.hash_join_node.hash_output_slot_ids
: std::vector<SlotId> {}) {}

// Drives one step of the probe side: fetches and pushes a child block when more
// input is needed, then pulls an output block when the operator has data to emit.
// `source_state` is updated to tell the caller whether more output is pending.
Status HashJoinProbeOperatorX::get_block(RuntimeState* state, vectorized::Block* block,
                                         SourceState& source_state) {
    auto& probe_state = state->get_local_state(id())->cast<HashJoinProbeLocalState>();
    probe_state.init_for_probe(state);

    // Input phase: read the next child block and feed it to push().
    if (need_more_input_data(state)) {
        probe_state._child_block->clear_column_data();
        RETURN_IF_ERROR(_child_x->get_next_after_projects(state, probe_state._child_block.get(),
                                                          probe_state._child_source_state));
        source_state = probe_state._child_source_state;
        const bool got_no_rows = probe_state._child_block->rows() == 0;
        if (got_no_rows && probe_state._child_source_state != SourceState::FINISHED) {
            // Empty block but the child is not done yet: nothing to push this round.
            return Status::OK();
        }
        probe_state.prepare_for_next();
        RETURN_IF_ERROR(
                push(state, probe_state._child_block.get(), probe_state._child_source_state));
    }

    // Output phase is skipped while the operator still wants input.
    if (need_more_input_data(state)) {
        return Status::OK();
    }

    RETURN_IF_ERROR(pull(state, block, source_state));
    if (source_state == SourceState::FINISHED) {
        return Status::OK();
    }
    if (!need_more_input_data(state)) {
        // Output is still pending for the current input block.
        source_state = SourceState::MORE_DATA;
    } else if (source_state == SourceState::MORE_DATA) {
        // Current input is drained; report whatever the child reported.
        source_state = probe_state._child_source_state;
    }
    return Status::OK();
}

Status HashJoinProbeOperatorX::pull(doris::RuntimeState* state, vectorized::Block* output_block,
SourceState& source_state) {
SourceState& source_state) const {
auto& local_state = state->get_local_state(id())->cast<HashJoinProbeLocalState>();
local_state.init_for_probe(state);
SCOPED_TIMER(local_state._probe_timer);
if (local_state._shared_state->short_circuit_for_probe) {
// If we use a short-circuit strategy, should return empty block directly.
Expand Down Expand Up @@ -331,7 +305,7 @@ bool HashJoinProbeOperatorX::need_more_input_data(RuntimeState* state) const {
Status HashJoinProbeOperatorX::_do_evaluate(vectorized::Block& block,
vectorized::VExprContextSPtrs& exprs,
RuntimeProfile::Counter& expr_call_timer,
std::vector<int>& res_col_ids) {
std::vector<int>& res_col_ids) const {
for (size_t i = 0; i < exprs.size(); ++i) {
int result_col_id = -1;
// execute build column
Expand All @@ -349,8 +323,9 @@ Status HashJoinProbeOperatorX::_do_evaluate(vectorized::Block& block,
}

Status HashJoinProbeOperatorX::push(RuntimeState* state, vectorized::Block* input_block,
SourceState source_state) {
SourceState source_state) const {
auto& local_state = state->get_local_state(id())->cast<HashJoinProbeLocalState>();
local_state.prepare_for_next();
local_state._probe_eos = source_state == SourceState::FINISHED;
if (input_block->rows() > 0) {
COUNTER_UPDATE(local_state._probe_rows_counter, input_block->rows());
Expand Down
13 changes: 6 additions & 7 deletions be/src/pipeline/exec/hashjoin_probe_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,17 @@ class HashJoinProbeOperatorX final : public JoinProbeOperatorX<HashJoinProbeLoca
Status open(RuntimeState* state) override;
bool can_read(RuntimeState* state) override;

Status get_block(RuntimeState* state, vectorized::Block* block,
SourceState& source_state) override;

Status push(RuntimeState* state, vectorized::Block* input_block, SourceState source_state);
Status push(RuntimeState* state, vectorized::Block* input_block,
SourceState source_state) const override;
Status pull(doris::RuntimeState* state, vectorized::Block* output_block,
SourceState& source_state);
SourceState& source_state) const override;

bool need_more_input_data(RuntimeState* state) const;
bool need_more_input_data(RuntimeState* state) const override;

private:
Status _do_evaluate(vectorized::Block& block, vectorized::VExprContextSPtrs& exprs,
RuntimeProfile::Counter& expr_call_timer, std::vector<int>& res_col_ids);
RuntimeProfile::Counter& expr_call_timer,
std::vector<int>& res_col_ids) const;
friend class HashJoinProbeLocalState;
friend struct vectorized::HashJoinProbeContext;

Expand Down
6 changes: 4 additions & 2 deletions be/src/pipeline/exec/join_probe_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class JoinProbeLocalState : public PipelineXLocalState<DependencyType> {
virtual void add_tuple_is_null_column(vectorized::Block* block) = 0;

protected:
template <typename LocalStateType>
friend class StatefulOperatorX;
JoinProbeLocalState(RuntimeState* state, OperatorXBase* parent)
: Base(state, parent),
_child_block(vectorized::Block::create_unique()),
Expand All @@ -62,9 +64,9 @@ class JoinProbeLocalState : public PipelineXLocalState<DependencyType> {
};

template <typename LocalStateType>
class JoinProbeOperatorX : public OperatorX<LocalStateType> {
class JoinProbeOperatorX : public StatefulOperatorX<LocalStateType> {
public:
using Base = OperatorX<LocalStateType>;
using Base = StatefulOperatorX<LocalStateType>;
JoinProbeOperatorX(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs);
virtual Status init(const TPlanNode& tnode, RuntimeState* state) override;

Expand Down
28 changes: 0 additions & 28 deletions be/src/pipeline/exec/nested_loop_join_probe_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,34 +509,6 @@ Status NestedLoopJoinProbeOperatorX::push(doris::RuntimeState* state, vectorized
return Status::OK();
}

// Drives one step of the probe side: fetches and pushes a child block when more
// input is needed, then pulls an output block when the operator has data to emit.
// `source_state` is updated to tell the caller whether more output is pending.
Status NestedLoopJoinProbeOperatorX::get_block(RuntimeState* state, vectorized::Block* block,
                                               SourceState& source_state) {
    auto& probe_state = state->get_local_state(id())->cast<NestedLoopJoinProbeLocalState>();

    // Input phase: read the next child block and feed it to push().
    if (need_more_input_data(state)) {
        probe_state._child_block->clear_column_data();
        RETURN_IF_ERROR(_child_x->get_next_after_projects(state, probe_state._child_block.get(),
                                                          probe_state._child_source_state));
        source_state = probe_state._child_source_state;
        const bool got_no_rows = probe_state._child_block->rows() == 0;
        if (got_no_rows && probe_state._child_source_state != SourceState::FINISHED) {
            // Empty block but the child is not done yet: nothing to push this round.
            return Status::OK();
        }
        RETURN_IF_ERROR(
                push(state, probe_state._child_block.get(), probe_state._child_source_state));
    }

    // Output phase is skipped while the operator still wants input.
    if (need_more_input_data(state)) {
        return Status::OK();
    }

    RETURN_IF_ERROR(pull(state, block, source_state));
    if (source_state == SourceState::FINISHED) {
        return Status::OK();
    }
    if (!need_more_input_data(state)) {
        // Output is still pending for the current input block.
        source_state = SourceState::MORE_DATA;
    } else if (source_state == SourceState::MORE_DATA) {
        // Current input is drained; report whatever the child reported.
        source_state = probe_state._child_source_state;
    }
    return Status::OK();
}

Status NestedLoopJoinProbeOperatorX::pull(RuntimeState* state, vectorized::Block* block,
SourceState& source_state) const {
auto& local_state = state->get_local_state(id())->cast<NestedLoopJoinProbeLocalState>();
Expand Down
9 changes: 3 additions & 6 deletions be/src/pipeline/exec/nested_loop_join_probe_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,10 @@ class NestedLoopJoinProbeOperatorX final
Status open(RuntimeState* state) override;
bool can_read(RuntimeState* state) override;

Status get_block(RuntimeState* state, vectorized::Block* block,
SourceState& source_state) override;

Status push(RuntimeState* state, vectorized::Block* input_block,
SourceState source_state) const;
SourceState source_state) const override;
Status pull(doris::RuntimeState* state, vectorized::Block* output_block,
SourceState& source_state) const;
SourceState& source_state) const override;
const RowDescriptor& intermediate_row_desc() const override {
return _old_version_flag ? _row_descriptor : *_intermediate_row_desc;
}
Expand All @@ -227,7 +224,7 @@ class NestedLoopJoinProbeOperatorX final
: *_output_row_desc;
}

bool need_more_input_data(RuntimeState* state) const;
bool need_more_input_data(RuntimeState* state) const override;

private:
friend class NestedLoopJoinProbeLocalState;
Expand Down
Loading

0 comments on commit 21aea76

Please sign in to comment.