From 9aa7e4b0f80165061bed9ae2d49d7dc7665270dc Mon Sep 17 00:00:00 2001 From: Maksim Kutakov Date: Sat, 16 Dec 2023 18:04:50 +0100 Subject: [PATCH] get_state test and code clean up --- .../memory_desc/cpu_blocked_memory_desc.cpp | 4 -- src/plugins/intel_cpu/src/memory_state.cpp | 46 ------------------- src/plugins/intel_cpu/src/memory_state.h | 22 --------- .../intel_cpu/src/nodes/scaled_attn.cpp | 26 +++++++---- .../src/concat_multiple_query_sdp.cpp | 8 ++++ .../src/concat_transpose_sdp_transpose.cpp | 8 ++++ 6 files changed, 33 insertions(+), 81 deletions(-) diff --git a/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp b/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp index 27ada57910b3bd..ab3df84a06f041 100644 --- a/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp +++ b/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp @@ -60,10 +60,6 @@ CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& s } } } else { - // TODO(BS): try to reuse allocated memory - // if (shape.hasZeroDims() && std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 0; } )) { - // OPENVINO_THROW("Can't create CpuBlockedMemoryDesc with zero dim, but with non zero strides"); - // } this->strides = strides; } diff --git a/src/plugins/intel_cpu/src/memory_state.cpp b/src/plugins/intel_cpu/src/memory_state.cpp index 62e5b17261bea0..525deb01070309 100644 --- a/src/plugins/intel_cpu/src/memory_state.cpp +++ b/src/plugins/intel_cpu/src/memory_state.cpp @@ -154,52 +154,6 @@ MemoryPtr VariableStateDoubleBuffer::internal_state_mem() const { return prime_mem(); } -VariableStateSingleBuffer::VariableStateSingleBuffer(const std::string& name, - const MemoryPtr& buffer, - const MemoryDescPtr& external_desc) : - VariableStateBase(name, external_desc) { - OPENVINO_ASSERT(buffer); - m_internal_mem = buffer; - m_internal_desc = m_internal_mem->getDescPtr(); - auto&& shape = m_internal_desc->getShape(); - //TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible? - - if (shape.isStatic()) { - m_internal_mem->nullify(); - } else { - //in the case of the original desc has dynamic shape we create an empty tensor - auto new_desc = to_static(m_internal_desc); - m_internal_mem->redefineDesc(new_desc); - } -} - -void VariableStateSingleBuffer::reset_impl() { - auto new_desc = to_static(m_internal_desc); - m_internal_mem->redefineDesc(new_desc); - m_internal_mem->nullify(); -} - -MemoryPtr VariableStateSingleBuffer::input_mem() { - return m_internal_mem; -} - -MemoryPtr VariableStateSingleBuffer::output_mem() { - return m_internal_mem; -} - -MemoryDescPtr VariableStateSingleBuffer::internal_desc() const { - return m_internal_desc; -} - -MemoryPtr VariableStateSingleBuffer::internal_state_mem() const { - return m_internal_mem; -} - -void VariableStateSingleBuffer::commit_impl() { - //nothing to do -} - - VariableStateKVcache::VariableStateKVcache( const std::string& name, const MemoryDescPtr& external_desc, diff --git a/src/plugins/intel_cpu/src/memory_state.h b/src/plugins/intel_cpu/src/memory_state.h index 7bcae231f9fb0c..ef407bddaa802f 100644 --- a/src/plugins/intel_cpu/src/memory_state.h +++ b/src/plugins/intel_cpu/src/memory_state.h @@ -98,28 +98,6 @@ class VariableStateDoubleBuffer : public VariableStateBase { size_t buffer_num = 0; }; -class VariableStateSingleBuffer : public VariableStateBase { -public: - VariableStateSingleBuffer(const std::string& name, - const MemoryPtr& buffer, - const MemoryDescPtr& external_desc); - - MemoryPtr input_mem() override; - MemoryPtr output_mem() override; - MemoryDescPtr internal_desc() const override; - -private: - //ov::intel_cpu::VariableStateBase - void reset_impl() override; - void commit_impl() override; - - MemoryPtr internal_state_mem() const override; - -private: - MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor - MemoryPtr m_internal_mem; -}; - class VariableStateKVcache : public VariableStateBase { public: VariableStateKVcache(const std::string& name, diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index a0f350ccab181d..bf30fd2e5c4878 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -916,8 +916,10 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s } // second token itself for (size_t i = 0; i < B; i++) { - beam_table_k.at({i, L0}) = i; - beam_table_v.at({i, L0}) = i; + for (size_t j = 0; j < L1; j++) { + beam_table_k.at({i, L0 + j}) = i; + beam_table_v.at({i, L0 + j}) = i; + } } } @@ -962,7 +964,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M L0 = dims[order[2]]; OPENVINO_ASSERT(B == B_state, "pastkv batch: ", B, " is not equal to batch of state: ", B_state); } - + OPENVINO_ASSERT(B * (L0 + L1) > 0, "B or (L0+L1) is zero, B: ", B, ", L0: ", L0, ", L1: ", L1); // resize buffer if (B * H * (L0 + L1) * S > m_k_state->internal_state_max_size()) { auto new_shape = {B, H, (L0 + L1) * 2, S}; @@ -1013,13 +1015,19 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M past_v = past_v.permute(order); } if (L0 > 0 && is_reset) { - PlainTensor init_k, init_v; auto inputNumber = getOriginalInputsNumber(); - init_k.reset(getParentEdgeAt(inputNumber - 2)->getMemoryPtr()); - init_v.reset(getParentEdgeAt(inputNumber - 1)->getMemoryPtr()); - init_k = init_k.permute(order); - init_v = init_v.permute(order); - attn_memcpy(init_k, init_v, past_k, past_v); + auto k_mem = getParentEdgeAt(inputNumber - 2)->getMemoryPtr(); + auto v_mem = getParentEdgeAt(inputNumber - 1)->getMemoryPtr(); + auto&& k_shape = k_mem->getShape(); + auto&& v_shape = v_mem->getShape(); + if (!k_shape.hasZeroDims() && !v_shape.hasZeroDims()) { + PlainTensor init_k, init_v; + init_k.reset(k_mem); + init_v.reset(v_mem); + init_k = init_k.permute(order); + init_v = init_v.permute(order); + attn_memcpy(init_k, init_v, past_k, past_v); + } } attn_memcpy(cur_k, cur_v, past_k.slice(2, L0, L0 + L1), past_v.slice(2, L0, L0 + L1)); diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp index 13d7c4123aaf6c..1c8ad07f8fd549 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp @@ -266,6 +266,14 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface