From 048bc998db0bee5cd968ce6bd324e8c711439097 Mon Sep 17 00:00:00 2001 From: Gabriel Date: Thu, 18 Apr 2024 14:27:40 +0800 Subject: [PATCH] [refactor](partitioner) refine get channel id logics (#33765) --- be/src/pipeline/exec/exchange_sink_operator.cpp | 4 ++-- be/src/pipeline/exec/exchange_sink_operator.h | 3 +-- .../partitioned_hash_join_probe_operator.cpp | 2 +- .../partitioned_hash_join_sink_operator.cpp | 4 ++-- .../local_exchange/local_exchanger.cpp | 2 +- be/src/vec/runtime/partitioner.h | 17 +++++++++++++++-- be/src/vec/sink/vdata_stream_sender.cpp | 4 ++-- 7 files changed, 24 insertions(+), 12 deletions(-) diff --git a/be/src/pipeline/exec/exchange_sink_operator.cpp b/be/src/pipeline/exec/exchange_sink_operator.cpp index 2c37f24eac4237..580e8e525d65ca 100644 --- a/be/src/pipeline/exec/exchange_sink_operator.cpp +++ b/be/src/pipeline/exec/exchange_sink_operator.cpp @@ -493,11 +493,11 @@ Status ExchangeSinkOperatorX::sink(RuntimeState* state, vectorized::Block* block if (_part_type == TPartitionType::HASH_PARTITIONED) { RETURN_IF_ERROR(channel_add_rows( state, local_state.channels, local_state._partition_count, - (uint32_t*)local_state._partitioner->get_channel_ids(), rows, block, eos)); + local_state._partitioner->get_channel_ids().get(), rows, block, eos)); } else { RETURN_IF_ERROR(channel_add_rows( state, local_state.channel_shared_ptrs, local_state._partition_count, - (uint32_t*)local_state._partitioner->get_channel_ids(), rows, block, eos)); + local_state._partitioner->get_channel_ids().get(), rows, block, eos)); } } else if (_part_type == TPartitionType::TABLET_SINK_SHUFFLE_PARTITIONED) { // check out of limit diff --git a/be/src/pipeline/exec/exchange_sink_operator.h b/be/src/pipeline/exec/exchange_sink_operator.h index 9c40242cd030be..f275365c0e85a3 100644 --- a/be/src/pipeline/exec/exchange_sink_operator.h +++ b/be/src/pipeline/exec/exchange_sink_operator.h @@ -76,8 +76,7 @@ class ExchangeSinkLocalState final : public PipelineXSinkLocalState<> { : _partitioner(partitioner) {} int get_partition(vectorized::Block* block, int position) { - uint32_t* partition_ids = (uint32_t*)_partitioner->get_channel_ids(); - return partition_ids[position]; + return _partitioner->get_channel_ids().get()[position]; } private: diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp index 78dcaf1e6c5385..0f57a03fc64507 100644 --- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp +++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp @@ -535,7 +535,7 @@ Status PartitionedHashJoinProbeOperatorX::push(RuntimeState* state, vectorized:: } std::vector partition_indexes[_partition_count]; - auto* channel_ids = reinterpret_cast(local_state._partitioner->get_channel_ids()); + auto* channel_ids = local_state._partitioner->get_channel_ids().get(); for (uint32_t i = 0; i != rows; ++i) { partition_indexes[channel_ids[i]].emplace_back(i); } diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp index c9d61757461c5b..d0ca832630e5fa 100644 --- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp +++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp @@ -145,7 +145,7 @@ Status PartitionedHashJoinSinkLocalState::_revoke_unpartitioned_block(RuntimeSta } auto& p = _parent->cast(); SCOPED_TIMER(_partition_shuffle_timer); - auto* channel_ids = reinterpret_cast(_partitioner->get_channel_ids()); + auto* channel_ids = _partitioner->get_channel_ids().get(); auto& partitioned_blocks = _shared_state->partitioned_build_blocks; std::vector partition_indices; @@ -293,7 +293,7 @@ Status PartitionedHashJoinSinkLocalState::_partition_block(RuntimeState* state, auto& p = _parent->cast(); SCOPED_TIMER(_partition_shuffle_timer); - auto* channel_ids = reinterpret_cast(_partitioner->get_channel_ids()); + auto* channel_ids = _partitioner->get_channel_ids().get(); std::vector partition_indexes[p._partition_count]; DCHECK_LT(begin, end); for (size_t i = begin; i != end; ++i) { diff --git a/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp b/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp index da395fefdd5eb2..0837a1212b98f6 100644 --- a/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp +++ b/be/src/pipeline/pipeline_x/local_exchange/local_exchanger.cpp @@ -32,7 +32,7 @@ Status ShuffleExchanger::sink(RuntimeState* state, vectorized::Block* in_block, { SCOPED_TIMER(local_state._distribute_timer); RETURN_IF_ERROR(_split_rows(state, - (const uint32_t*)local_state._partitioner->get_channel_ids(), + local_state._partitioner->get_channel_ids().get(), in_block, eos, local_state)); } diff --git a/be/src/vec/runtime/partitioner.h b/be/src/vec/runtime/partitioner.h index 66ed8809d7ce7c..8d715a41285800 100644 --- a/be/src/vec/runtime/partitioner.h +++ b/be/src/vec/runtime/partitioner.h @@ -26,6 +26,17 @@ class MemTracker; namespace vectorized { +struct ChannelField { + const void* channel_id; + const uint32_t len; + + template + const T* get() const { + CHECK_EQ(sizeof(T), len) << " sizeof(T): " << sizeof(T) << " len: " << len; + return reinterpret_cast(channel_id); + } +}; + class PartitionerBase { public: PartitionerBase(size_t partition_count) : _partition_count(partition_count) {} @@ -40,7 +51,7 @@ class PartitionerBase { virtual Status do_partitioning(RuntimeState* state, Block* block, MemTracker* mem_tracker) const = 0; - virtual void* get_channel_ids() const = 0; + virtual ChannelField get_channel_ids() const = 0; virtual Status clone(RuntimeState* state, std::unique_ptr& partitioner) = 0; @@ -67,7 +78,9 @@ class Partitioner : public PartitionerBase { Status do_partitioning(RuntimeState* state, Block* block, MemTracker* mem_tracker) const override; - void* get_channel_ids() const override { return _hash_vals.data(); } + ChannelField get_channel_ids() const override { + return {_hash_vals.data(), sizeof(HashValueType)}; + } protected: Status _get_partition_column_result(Block* block, std::vector& result) const { diff --git a/be/src/vec/sink/vdata_stream_sender.cpp b/be/src/vec/sink/vdata_stream_sender.cpp index ce6a5317fd4be6..69b7054f5005c8 100644 --- a/be/src/vec/sink/vdata_stream_sender.cpp +++ b/be/src/vec/sink/vdata_stream_sender.cpp @@ -739,11 +739,11 @@ Status VDataStreamSender::send(RuntimeState* state, Block* block, bool eos) { } if (_part_type == TPartitionType::HASH_PARTITIONED) { RETURN_IF_ERROR(channel_add_rows(state, _channels, _partition_count, - (uint64_t*)_partitioner->get_channel_ids(), rows, + _partitioner->get_channel_ids().get(), rows, block, _enable_pipeline_exec ? eos : false)); } else { RETURN_IF_ERROR(channel_add_rows(state, _channel_shared_ptrs, _partition_count, - (uint32_t*)_partitioner->get_channel_ids(), rows, + _partitioner->get_channel_ids().get(), rows, block, _enable_pipeline_exec ? eos : false)); } } else if (_part_type == TPartitionType::TABLET_SINK_SHUFFLE_PARTITIONED) {