diff --git a/src/plugins/intel_cpu/src/memory_state.cpp b/src/plugins/intel_cpu/src/memory_state.cpp index 14bf36fbded16a..8b73d3947cd2e6 100644 --- a/src/plugins/intel_cpu/src/memory_state.cpp +++ b/src/plugins/intel_cpu/src/memory_state.cpp @@ -313,5 +313,15 @@ void VariableStateKVcache::assign_hidden_state(const MemoryPtr& mem) { m_hidden_state = mem; } +void VariableStateKVcache::init_from(const MemoryPtr& init_val) { + // TODO + OPENVINO_ASSERT(false, "Implement VariableStateKVcache::init_from"); +} + +void VariableStateKVcache::gather_concat_pastkv(MemoryPtr cur_kv, MemoryPtr beam_idx) { + // TODO + OPENVINO_ASSERT(false, "Implement VariableStateKVcache::gather_concat_pastkv"); +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_state.h b/src/plugins/intel_cpu/src/memory_state.h index 0dc32e084f75eb..234e91c6acaa40 100644 --- a/src/plugins/intel_cpu/src/memory_state.h +++ b/src/plugins/intel_cpu/src/memory_state.h @@ -140,6 +140,9 @@ class VariableStateKVcache : public VariableStateBase { MemoryPtr hidden_state_mem() const; void assign_hidden_state(const MemoryPtr& mem); + void init_from(const MemoryPtr& init_val); + void gather_concat_pastkv(MemoryPtr cur_kv, MemoryPtr beam_idx); + private: //ov::intel_cpu::VariableStateBase void set_state_impl(const ov::SoPtr& state) override; diff --git a/src/plugins/intel_cpu/src/nodes/memory.cpp b/src/plugins/intel_cpu/src/nodes/memory.cpp index 27fc9d058a344c..03f98f18c44e9b 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.cpp +++ b/src/plugins/intel_cpu/src/nodes/memory.cpp @@ -13,6 +13,7 @@ #include "memory_desc/dnnl_blocked_memory_desc.h" #include "utils/ngraph_utils.hpp" #include "shape_inference/shape_inference_pass_through.hpp" +#include "shape_inference/shape_inference_internal_dyn.hpp" #include "common/arbitrary_order_desc_creator.h" using namespace dnnl; @@ -668,20 +669,73 @@ MemoryInputSDPA::MemoryInputSDPA(const std::string id, const ov::optional& input_shape, const ov::optional& input_prc, const std::shared_ptr& sdpaNode) : - MemoryInputBase(id, name, type, output_shape, output_prc, context, input_shape, input_prc), m_sdpaNode(sdpaNode) {} + MemoryInputBase(id, name, type, output_shape, output_prc, context, input_shape, input_prc), m_sdpaNode(sdpaNode) { + if (isDynamic) { + // Two scenarios: + // 1. after reset (first token): + // a. if there is an init subgraph, the shape is taken from the init subgraph + // b. if there is no init subgraph, the shape is computed from the state + // 2. second token: the shape is computed from the state + // since the source depends on the condition above, use InternalDynShapeInferFactory to + // get the shape dynamically + shapeInference = InternalDynShapeInferFactory().makeShapeInfer(); + } +} void MemoryInputSDPA::createPrimitive() { MemoryInputBase::createPrimitive(); -// determine the output node idx -// child_port_idx = + // determine the SDPA input port index this node is connected to + auto memDesc = getBaseMemDescAtOutputPort(0); + auto sdpaNode = m_sdpaNode.lock(); + for (auto&& edge : getChildEdgesAtPort(0)) { // always only one child port + auto node = edge->getChild(); + if (node == sdpaNode) { + child_port_idx = edge->getOutputNum(); + break; + } + } + OPENVINO_ASSERT(child_port_idx != -1, getName(), " should connect to SDPA node."); +} + +void MemoryInputSDPA::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + + auto&& shape = getOutputShapeAtPort(0); + auto precision = getOriginalOutputPrecisionAtPort(0); + auto&& descCreators =
ov::intel_cpu::BlockedDescCreator::getCommonCreators(); + NodeConfig config; + if (!getParentEdges().empty()) { + PortConfig inPortConfig; + inPortConfig.inPlace(-1); + inPortConfig.constant(false); + inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape)); + config.inConfs.push_back(std::move(inPortConfig)); + } + + auto node = m_sdpaNode.lock(); + // retrieve the internal precision and axis order from the SDPA node + auto kv_precision = node->getKVCachePrecision(); + VectorDims order = {0, 1, 2, 3}; + if (!node->getKVCacheOrder().empty()) + order = node->getKVCacheOrder(); + ArbitraryOrderDescCreator cabdDescCreator(order); + + PortConfig outPortConfig; + // the output edge holds a fake memory object; the real memory is stored in the state + outPortConfig.inPlace(-1); + outPortConfig.constant(false); + outPortConfig.setMemDesc(cabdDescCreator.createSharedDesc(kv_precision, shape)); + config.outConfs.push_back(std::move(outPortConfig)); + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown); } void MemoryInputSDPA::assignState(MemStatePtr newState) { auto sdpaNode = m_sdpaNode.lock(); OPENVINO_ASSERT(sdpaNode); - auto sdpaState = std::dynamic_pointer_cast(newState); - OPENVINO_ASSERT(sdpaState); - sdpaNode->assignState(sdpaState, child_port_idx); + m_sdpaState = std::dynamic_pointer_cast(newState); + OPENVINO_ASSERT(m_sdpaState); + sdpaNode->assignState(m_sdpaState, child_port_idx); } MemStatePtr MemoryInputSDPA::makeState() const { @@ -699,25 +753,33 @@ MemStatePtr MemoryInputSDPA::makeState() const { state_name = state_name.substr(0, suffix_idx); } - // somehow retrieve the internal precision and axis order from the SDPA node - // m_sdpa->get_kv_precision(); - // m_sdpa->get_kv_axis_order(); - - auto kv_precision = element::bf16; + auto node = m_sdpaNode.lock(); + // retrieve the internal precision and axis order from the SDPA node + auto kv_precision = node->getKVCachePrecision(); VectorDims order = {0, 1, 2, 3}; + if (!node->getKVCacheOrder().empty()) + order = node->getKVCacheOrder(); auto internal_desc = ArbitraryOrderDescCreator(order).createSharedDesc(kv_precision, outputShapes.at(0)); return std::make_shared(state_name, original_desc, internal_desc); } -bool MemoryInputSDPA::isExecutable() const { - // this node is mostly a proxy to transfer the state to the SDPA - // so the SDPA itself should handle the reset state as it handles the memory manipulation - return false; -} - void MemoryInputSDPA::execute(dnnl::stream strm) { + if (m_sdpaState->is_reset_state()) { + // a parent edge means there is an init subgraph + if (!getParentEdges().empty()) { + auto input = getParentEdgeAt(0)->getMemoryPtr(); + m_sdpaState->init_from(input); + } + } + // 1. in reset state: + // a. with an init subgraph, the state shape is defined after init_from + // b. without an init subgraph, the state shape is defined in VariableStateKVcache::reset + // 2. not in reset state: the state is updated in the SDPA infer call + // Update the shape of the fake memory object so that ShapeOf can get the correct shape + this->redefineOutputMemory(0, m_sdpaState->internal_desc()->getShape().getStaticDims()); + return; } diff --git a/src/plugins/intel_cpu/src/nodes/memory.hpp b/src/plugins/intel_cpu/src/nodes/memory.hpp index 96a0195c90eaa7..ea0aa7a9396f66 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.hpp +++ b/src/plugins/intel_cpu/src/nodes/memory.hpp @@ -211,18 +211,20 @@ class MemoryInputSDPA : public MemoryInputBase { static bool isSupportedOperation(const std::shared_ptr& op, std::string&
errorMessage) noexcept; void createPrimitive() override; + void initSupportedPrimitiveDescriptors() override; - bool isExecutable() const override; void execute(dnnl::stream strm) override; void resolveInPlaceEdges(Edge::LOOK look) override; void assignState(MemStatePtr newState) override; MemStatePtr makeState() const override; + bool needShapeInfer() const override { return false; } private: std::weak_ptr m_sdpaNode; - int child_port_idx; + std::shared_ptr m_sdpaState; + int child_port_idx = -1; }; } // namespace node } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 6dcbedbf1473f6..6f0f815f2c43b0 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -517,12 +517,6 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt past_k_output.reset(outputs[1]); past_v_output.reset(outputs[2]); attn_memcpy(k_input, v_input, past_k_output.slice(2, L0, L0 + L1), past_v_output.slice(2, L0, L0 + L1)); - if (!config.is_concat_inplaced) { - PlainTensor past_k_input, past_v_input; - past_k_input.reset(past_k_mem); - past_v_input.reset(inputs[past_k_idx + 1]); - attn_memcpy(past_k_input, past_v_input, past_k_output, past_v_output); - } } else { // k,v inputs are already concatenated L0 = k_input.size(2) - L1; @@ -533,16 +527,21 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt } } - void execute(dnnl::stream strm, const std::vector& inputs, const std::vector& outputs) override { + void execute(dnnl::stream strm, const std::vector& inputs, const MemoryPtr output, const MemoryPtr presentk_input, + const MemoryPtr presentv_input, const MemoryPtr beam_input) override { bool has_out_transpose = config.config.output_BLHxS; bool fuse_causal_attn = config.config.fuse_causal_attn; bool is_causal = config.config.is_causal; const bool fuse_concat = config.config.fuse_concat; - auto input_num = inputs.size() - (fuse_concat ? 
2 : 0); + auto input_num = inputs.size(); + PlainTensor present_key, present_value; q_input.reset(inputs[0]); k_input.reset(inputs[1]); v_input.reset(inputs[2]); + present_key.reset(presentk_input); + present_value.reset(presentv_input); + beam_table.reset(beam_input); PlainTensor attn_mask; if (input_num > 3) { // attn_mask @@ -560,15 +559,28 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt } // q: [B, H, L1, S] + const auto & permute_axes = config.config.permute_axes; + if (!permute_axes.empty()) { + q_input = q_input.permute(permute_axes); + k_input = k_input.permute(permute_axes); + v_input = v_input.permute(permute_axes); + present_key = present_key.permute(permute_axes); + present_value = present_value.permute(permute_axes); + } B = q_input.size(0); H = q_input.size(1); L1 = q_input.size(2); - S = q_input.size(-1); + S = q_input.size(3); + L0 = present_key.size(2) - L1; + auto Hk = k_input.size(1); - PlainTensor present_key, present_value; - concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value); + k_input.assert_dims({B, Hk, L1, S}); + v_input.assert_dims({B, Hk, L1, S}); + present_key.assert_dims({B, Hk, L0 + L1, S}); + present_value.assert_dims({B, Hk, L0 + L1, S}); + beam_table.assert_dims({B, L0 + L1}); - ov::intel_cpu::PlainTensor output_emb(outputs[0]); + ov::intel_cpu::PlainTensor output_emb(output); bool auto_causal; bool use_attn_mask; @@ -635,11 +647,11 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; auto rtPrecision = getOriginalInputPrecisionAtPort(0); - auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0); + auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 3 : 0); bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && rtPrecision != ov::element::bf16; - auto kvCachePrecision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision; + m_kvcache_precision = enableKVCacheFP16 ? 
ov::element::f16 : rtPrecision; NodeConfig config; auto& creatorsMap = BlockedDescCreator::getCommonCreators(); @@ -669,39 +681,33 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { } if (m_config.config.fuse_concat) { - ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3}); - - config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc( - kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); - config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc( - kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); - - config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc( - kvCachePrecision, getOutputShapeAtPort(1))); - config.outConfs[1].inPlace(orginSDPInputNumber + 0); - config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc( - kvCachePrecision, getOutputShapeAtPort(2))); - config.outConfs[2].inPlace(orginSDPInputNumber + 1); + VectorDims order = {0, 1, 2, 3}; + if (!m_config.config.permute_axes.empty()) + order = m_config.config.permute_axes; + ArbitraryOrderDescCreator descCreator(order); + + // beam_idx + config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(orginSDPInputNumber + 0))); + // pastk + config.inConfs[orginSDPInputNumber + 1].setMemDesc(descCreator.createSharedDesc( + m_kvcache_precision, getInputShapeAtPort(orginSDPInputNumber + 1))); + // pastv + config.inConfs[orginSDPInputNumber + 2].setMemDesc(descCreator.createSharedDesc( + m_kvcache_precision, getInputShapeAtPort(orginSDPInputNumber + 2))); + + config.outConfs[1].setMemDesc(descCreator.createSharedDesc( + m_kvcache_precision, getOutputShapeAtPort(1))); + config.outConfs[1].inPlace(-1); + config.outConfs[2].setMemDesc(descCreator.createSharedDesc( + m_kvcache_precision, getOutputShapeAtPort(2))); + config.outConfs[2].inPlace(-1); } config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getOutputShapeAtPort(0))); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); - // may fallback to abcd without inplace - if (m_config.config.fuse_concat) { - config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); - config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); - config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - kvCachePrecision, getOutputShapeAtPort(1))); - config.outConfs[1].inPlace(-1); - config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - kvCachePrecision, getOutputShapeAtPort(2))); - config.outConfs[2].inPlace(-1); - supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); - } } void ScaledDotProductAttention::createPrimitive() { @@ -709,8 +715,6 @@ void ScaledDotProductAttention::createPrimitive() { auto desc = getSelectedPrimitiveDescriptor(); if (desc == nullptr) OPENVINO_THROW("has unidentified preferable primitive descriptor"); - - m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0; } auto rtPrecision = getOriginalInputPrecisionAtPort(0); @@ -728,36 +732,25 @@ void ScaledDotProductAttention::createPrimitive() { } void ScaledDotProductAttention::execute(dnnl::stream strm) { - if (k_state && k_state->is_reset_state()) { - 
constexpr int k_idx = 3; - //The memory from the initialization graph is bypassed using inplace memory usage - //so it may be captured from the input edge - auto k_input = getParentEdgeAt(k_idx)->getMemoryPtr(); - // perform the K tensor and the corresponding beam_table reinitialization - // using k_input data - // this is an example, most likely some kind of a strided tensor would be used - k_state->assign_internal_state(k_input); - } - - if (v_state && v_state->is_reset_state()) { - constexpr int v_idx = 3; - //The memory from the initialization graph is bypassed using inplace memory usage - //so it may be captured from the input edge - auto v_input = getParentEdgeAt(v_idx)->getMemoryPtr(); - // perform the V tensor and the corresponding beam_table reinitialization - // using v_input data - // this is an example, most likely some kind of a strided tensor would be used - v_state->assign_internal_state(v_input); - } - - std::vector inputs(getParentEdges().size()), outputs(getChildEdges().size()); - for (size_t i = 0; i < inputs.size(); i++) { + auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 3 : 0); + std::vector inputs(orginSDPInputNumber); + auto output = getChildEdgeAt(0)->getMemoryPtr(); + MemoryPtr presentk_input, presentv_input, beam_input; + for (size_t i = 0; i < orginSDPInputNumber; i++) { inputs[i] = getParentEdgeAt(i)->getMemoryPtr(); } - for (size_t i = 0; i < outputs.size(); i++) { - outputs[i] = getChildEdgeAt(i)->getMemoryPtr(); + + if (m_config.config.fuse_concat) { + k_state->gather_concat_pastkv(inputs[1], getParentEdgeAt(orginSDPInputNumber)->getMemoryPtr()); + v_state->gather_concat_pastkv(inputs[2], getParentEdgeAt(orginSDPInputNumber)->getMemoryPtr()); + presentk_input = k_state->internal_state_mem(); + presentv_input = v_state->internal_state_mem(); + beam_input = k_state->hidden_state_mem(); + } else { + presentk_input = inputs[1]; + presentv_input = inputs[2]; } - m_executor->execute(strm, inputs, outputs); + m_executor->execute(strm, inputs, output, presentk_input, presentv_input, beam_input); } bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { @@ -777,7 +770,7 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr(op); if (node) { if (node->get_config().fuse_concat) { - orgSDPAInput -= 2; + orgSDPAInput -= 3; } } if (orgSDPAInput > 3) { @@ -799,21 +792,15 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& state, int idx) { - constexpr int k_idx = 3; - constexpr int v_idx = 4; - - if (k_idx == idx) { + auto inputNumber = getOriginalInputsNumber(); + if (inputNumber - 2 == idx) { k_state = state; - } else if (v_idx == idx) { + } else if (inputNumber - 1 == idx) { v_state = state; } else { OPENVINO_THROW( "Unexpected idx ", idx , " for a state in a node with type: ", getTypeStr(), " and name ", getName()); } - - if (state->is_reset_state()) { - //do some preliminary state modifications when necessary - } } } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 644a5d3cab3587..38cdb336dc6c47 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -44,14 +44,22 @@ class ScaledDotProductAttention : public Node { void assignState(const std::shared_ptr& state, int idx); + const std::vector& getKVCacheOrder() const { + return m_config.config.permute_axes; + } + + ov::element::Type 
getKVCachePrecision() const { + return m_kvcache_precision; + } + private: struct Executor { - virtual void execute(dnnl::stream strm, const std::vector& inputs, const std::vector& outputs) = 0; + virtual void execute(dnnl::stream strm, const std::vector& inputs, const MemoryPtr output, const MemoryPtr presentk_input, + const MemoryPtr presentv_input, const MemoryPtr beam_input) = 0; }; struct Config { ScaledDotProductAttentionWithKVCache::Config config; - bool is_concat_inplaced = false; }; Config m_config; @@ -60,6 +68,8 @@ class ScaledDotProductAttention : public Node { std::shared_ptr k_state; std::shared_ptr v_state; + + ov::element::Type m_kvcache_precision; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp index 4dc5ba799dd4eb..f163cdc9d8217b 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp @@ -29,18 +29,43 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_typ // [B, H, L0, S] auto past_kv_ps = get_input_partial_shape(input_num - 1); + auto output_logits = q_ps; NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false); NODE_VALIDATION_CHECK(this, q_ps.size() >= 3); + // permute_axes maps the original input layout to [B, H, L, S] + const auto& permute_axes = this->m_config.permute_axes; if (past_kv_ps.rank().is_static()) { + const size_t length_index = permute_axes.empty() ? q_ps.size() - 2 : permute_axes[permute_axes.size() - 2]; + const size_t head_num_index = permute_axes.empty() ? q_ps.size() - 3 : permute_axes[permute_axes.size() - 3]; NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size()); for (size_t i = 0; i < q_ps.size(); i++) { - if (i == q_ps.size() - 2) - continue; - NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i])); + if (i == head_num_index) { + if (q_ps[i].is_static() && past_kv_ps[i].is_static()) { + NODE_VALIDATION_CHECK(this, + q_ps[i].get_length() % past_kv_ps[i].get_length() == 0, + "shape not compatible at index ", + i); + } + } else if (i == length_index) { + continue; + } else { + NODE_VALIDATION_CHECK(this, + q_ps[i].compatible(past_kv_ps[i]), + "shape not compatible at index ", + i); + } } - past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2]; + past_kv_ps[length_index] += q_ps[length_index]; } - set_output_type(0, get_input_element_type(0), q_ps); + if (!permute_axes.empty()) { + if (q_ps.rank().is_static()) { + // the output logits shape is q_ps permuted to [B, H, L, S] + for (size_t i = 0; i < q_ps.size(); i++) { + output_logits[i] = q_ps[permute_axes[i]]; + } + } + } + set_output_type(0, get_input_element_type(0), output_logits); set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps); set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps); } @@ -52,6 +77,7 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn); visitor.on_attribute("is_causal", m_config.is_causal); visitor.on_attribute("fuse_concat", m_config.fuse_concat); + visitor.on_attribute("permute_axes", m_config.permute_axes); visitor.finish_structure(); return true; -} +} \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp index 94406caeab016e..753de527dc73f3 100644 ---
a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp @@ -21,11 +21,13 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op { ScaledDotProductAttentionWithKVCache() = default; struct Config { - bool output_BLHxS = false; // true implies that output is [B,L,H*S] + bool output_BLHxS = false; // true implies that output is [B,L,H*S] - bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask - bool is_causal = false; // apply causal mask internally - bool fuse_concat = false; // fuse (concat->sdp) ==> sdp + bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask + bool is_causal = false; // apply causal mask internally + bool fuse_concat = false; // fuse (concat->sdp) ==> sdp + std::vector permute_axes; // not empty means input has transpose. output of permutation is [B,H,L,S] + // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] ->[B, H, L, S] }; ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg); @@ -47,4 +49,4 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op { }; } // namespace intel_cpu -} // namespace ov +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index 683609e968c900..dea6753e7b90c8 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -16,6 +16,7 @@ #include #include "itt.hpp" +#include #include "ov_ops/type_relaxed.hpp" #include "transformations/cpu_opset/common/op/sdpa.hpp" @@ -30,10 +31,12 @@ StatefulSDPAFusion::StatefulSDPAFusion() { auto past_v = wrap_type(); auto convert_past_k = wrap_type({past_k}); auto convert_past_v = wrap_type({past_v}); - auto concat_input_k = std::make_shared(OutputVector{past_k, convert_past_k}); - auto concat_input_v = std::make_shared(OutputVector{past_v, convert_past_v}); - auto concat_k = wrap_type({concat_input_k, any_input()}); - auto concat_v = wrap_type({concat_input_v, any_input()}); + auto select_input_k = std::make_shared(OutputVector{past_k, convert_past_k}); + auto select_input_v = std::make_shared(OutputVector{past_v, convert_past_v}); + auto gather_input_k = wrap_type({select_input_k, any_input(), any_input()}); + auto gather_input_v = wrap_type({select_input_v, any_input(), any_input()}); + auto concat_k = wrap_type({gather_input_k, any_input()}); + auto concat_v = wrap_type({gather_input_v, any_input()}); auto sdp0 = wrap_type({any_input(), concat_k, concat_v}); auto sdp1 = wrap_type({any_input(), concat_k, concat_v, any_input()}); auto sdp2 = wrap_type({any_input(), concat_k, concat_v, any_input(), any_input()}); @@ -61,11 +64,23 @@ StatefulSDPAFusion::StatefulSDPAFusion() { return; } }; + auto check_valid_children_type = [] (const ov::Output& out) { + auto children = out.get_target_inputs(); + for (auto& child : children) { + auto node = child.get_node(); + if (!one_of(node->get_type_info(), ov::op::v13::ScaledDotProductAttention::get_type_info_static(), + ov::op::v3::ShapeOf::get_type_info_static(), ov::op::v0::Convert::get_type_info_static())) + return false; + } + return true; + }; std::shared_ptr read_cvt_k_node, read_cvt_v_node; const auto sdp_node = ov::as_type_ptr(root); const auto past_k_node = 
ov::as_type_ptr(pattern_map.at(past_k).get_node_shared_ptr()); const auto past_v_node = ov::as_type_ptr(pattern_map.at(past_v).get_node_shared_ptr()); + if (!check_valid_children_type(past_k_node) || !check_valid_children_type(past_v_node)) + return false; const auto concat_k_node = ov::as_type_ptr(pattern_map.at(concat_k).get_node_shared_ptr()); const auto concat_v_node = ov::as_type_ptr(pattern_map.at(concat_v).get_node_shared_ptr()); if (pattern_map.count(convert_past_k)) { @@ -86,9 +101,16 @@ StatefulSDPAFusion::StatefulSDPAFusion() { if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) return false; + const auto gather_k_node = ov::as_type_ptr(pattern_map.at(gather_input_k).get_node_shared_ptr()); + const auto gather_v_node = ov::as_type_ptr(pattern_map.at(gather_input_v).get_node_shared_ptr()); + if (gather_k_node->input_value(1) != gather_v_node->input_value(1)) { + std::cout << "StatefulSDPAFusion beam_idx is not same for gather\n"; + return false; + } auto args = sdp_node->input_values(); args[1] = concat_k_node->input_value(1); args[2] = concat_v_node->input_value(1); + args.push_back(gather_k_node->input_value(1)); args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0)); args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0)); ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config; diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp index 7f7ea0f30f9997..2115634070d0e9 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp @@ -28,7 +28,8 @@ using ConcatSDPTestParams = std::tuple, v pastk_shapeof = std::make_shared(pastk); pastv_shapeof = std::make_shared(pastv); } - auto concatK = std::make_shared(OutputVector{pastk, inputParams[1]}, 2); - auto concatV = std::make_shared(OutputVector{pastv, inputParams[2]}, 2); + auto beam_idx = std::make_shared(ElementType::i32, ov::PartialShape{-1, -1}); + beam_idx->set_friendly_name("beam_idx"); + inputParams.push_back(beam_idx); + auto gatherK = std::make_shared(pastk, beam_idx, op::v0::Constant::create(ElementType::i32, {1}, {0})); + auto gatherV = std::make_shared(pastv, beam_idx, op::v0::Constant::create(ElementType::i32, {1}, {0})); + auto concatK = std::make_shared(OutputVector{gatherK, inputParams[1]}, 2); + auto concatV = std::make_shared(OutputVector{gatherV, inputParams[2]}, 2); auto sdp = std::make_shared(inputParams[0], concatK, concatV, false); sdp->set_friendly_name("mha"); auto add = std::make_shared(sdp, op::v0::Constant::create(inType, {1}, {1.0f})); @@ -145,7 +151,11 @@ class ConcatSDPTest : public testing::WithParamInterface, v void generate(int idx, const std::vector& targetInputStaticShapes) { inputs.clear(); auto create_input = [this] (std::shared_ptr param, ov::Shape shape, float val) { - if (param->get_element_type() == element::f32) { + if (param->get_element_type() == element::i32) { + ov::Tensor t{ov::element::i32, shape}; + std::iota(static_cast(t.data()), static_cast(t.data()) + t.get_size(), 0); + inputs.insert({param, t}); + } else if (param->get_element_type() == element::f32) { ov::Tensor t{ov::element::f32, shape}; strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); inputs.insert({param, t}); @@ -155,11 +165,12 @@ class ConcatSDPTest : public testing::WithParamInterface, v inputs.insert({param,
t}); } }; - // q, k, v + // q, k, v, pastkv, beam_idx create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f); create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f); create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f); create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f); + create_input(function->get_parameters()[4], ov::Shape{targetInputStaticShapes[0][0], targetInputStaticShapes[0][2]}, idx + 0.0f); } void prepare() { compile_model(); @@ -198,6 +209,7 @@ TEST_P(ConcatSDPTest, CompareWithRefs) { CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1); CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0); CheckNumberOfNodesWithType(compiledModel, "Reorder", 0); + CheckNumberOfNodesWithType(compiledModel, "Gather", 0); auto expectedOutputs = run_test(functionRefs); CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0); for (size_t i = 0; i < actualOutputs.size(); i++) {
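
Note: VariableStateKVcache::gather_concat_pastkv is left as a TODO stub in this patch, so its intended behavior is only described by the comments around ScaledDotProductAttention::execute. The standalone C++ sketch below shows one plausible reading of that behavior: gather the past KV cache along the batch axis using beam_idx, then append the freshly computed K/V tokens along the length axis. The flat std::vector storage, the free-function signature, and the [B, H, L, S] packing are illustrative assumptions, not the plugin's real memory objects.

// Minimal sketch of the gather + concat semantics expected from
// VariableStateKVcache::gather_concat_pastkv (illustrative only).
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// past_kv: [B, H, L0, S], cur_kv: [B, H, L1, S], beam_idx: [B]
// result : [B, H, L0 + L1, S]; the past part of batch b is taken from
// batch beam_idx[b] of the previous state (beam reordering).
static std::vector<float> gather_concat_pastkv(const std::vector<float>& past_kv,
                                               const std::vector<float>& cur_kv,
                                               const std::vector<int>& beam_idx,
                                               size_t B, size_t H, size_t L0, size_t L1, size_t S) {
    assert(past_kv.size() == B * H * L0 * S);
    assert(cur_kv.size() == B * H * L1 * S);
    assert(beam_idx.size() == B);
    const size_t L = L0 + L1;
    std::vector<float> out(B * H * L * S);
    for (size_t b = 0; b < B; ++b) {
        const size_t src_b = static_cast<size_t>(beam_idx[b]);  // gather along batch
        for (size_t h = 0; h < H; ++h) {
            for (size_t l = 0; l < L0; ++l)       // copy the reordered past tokens
                for (size_t s = 0; s < S; ++s)
                    out[((b * H + h) * L + l) * S + s] =
                        past_kv[((src_b * H + h) * L0 + l) * S + s];
            for (size_t l = 0; l < L1; ++l)       // append the current tokens
                for (size_t s = 0; s < S; ++s)
                    out[((b * H + h) * L + (L0 + l)) * S + s] =
                        cur_kv[((b * H + h) * L1 + l) * S + s];
        }
    }
    return out;
}

int main() {
    const size_t B = 2, H = 1, L0 = 2, L1 = 1, S = 1;
    std::vector<float> past = {0.f, 1.f, 10.f, 11.f};   // batch0: {0,1}, batch1: {10,11}
    std::vector<float> cur  = {2.f, 12.f};              // one new token per batch
    std::vector<int> beam   = {1, 1};                   // both beams continue from batch 1
    auto out = gather_concat_pastkv(past, cur, beam, B, H, L0, L1, S);
    for (float v : out) std::printf("%.0f ", v);        // expected: 10 11 2 10 11 12
    std::printf("\n");
    return 0;
}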
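Note: the permute_axes handling in ScaledDotProductAttentionWithKVCache::validate_and_infer_types is compact in the diff. The standalone sketch below (illustrative only, not plugin code) shows how permute_axes = {1, 2, 0, 3} maps an original [L, B, H, S] layout to [B, H, L, S], and why length_index and head_num_index are taken from permute_axes[rank - 2] and permute_axes[rank - 3] when reading the unpermuted shapes.

// Minimal sketch of the permute_axes index math (illustrative only).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    // original q/k/v layout [L, B, H, S]; permuting with {1, 2, 0, 3} gives [B, H, L, S]
    std::vector<size_t> q_dims = {7, 2, 4, 64};          // L=7, B=2, H=4, S=64
    std::vector<size_t> permute_axes = {1, 2, 0, 3};

    const size_t rank = q_dims.size();
    const size_t length_index = permute_axes.empty() ? rank - 2 : permute_axes[rank - 2];
    const size_t head_num_index = permute_axes.empty() ? rank - 3 : permute_axes[rank - 3];
    assert(length_index == 0);    // L sits at axis 0 of the [L, B, H, S] layout
    assert(head_num_index == 2);  // H sits at axis 2 of the [L, B, H, S] layout

    // the output logits shape follows q permuted to [B, H, L, S]
    std::vector<size_t> output_logits(rank);
    for (size_t i = 0; i < rank; ++i)
        output_logits[i] = q_dims[permute_axes[i]];
    std::printf("output logits: [%zu, %zu, %zu, %zu]\n",
                output_logits[0], output_logits[1], output_logits[2], output_logits[3]);
    return 0;   // prints: output logits: [2, 4, 7, 64]
}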