diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 38345396ef7db4..c2034036b811cc 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -505,6 +505,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt bool has_out_transpose = config.config.output_BLHxS; bool fuse_causal_attn = config.config.fuse_causal_attn; bool is_causal = config.config.is_causal; + bool fuse_concat = config.config.fuse_concat; auto input_num = inputs.size(); PlainTensor present_key, present_value; @@ -547,8 +548,13 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt L0 = present_key.size(2) - L1; auto Hk = k_input.size(1); - k_input.assert_dims({B, Hk, L1, S}); - v_input.assert_dims({B, Hk, L1, S}); + if (fuse_concat) { + k_input.assert_dims({B, Hk, L1, S}); + v_input.assert_dims({B, Hk, L1, S}); + } else { + k_input.assert_dims({B, Hk, L0 + L1, S}); + v_input.assert_dims({B, Hk, L0 + L1, S}); + } present_key.assert_dims({B, Hk, L0 + L1, S}); present_value.assert_dims({B, Hk, L0 + L1, S}); if (beam_table) diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp index 924c5627079fbf..3b55b25c0e293d 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp @@ -54,10 +54,10 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface& obj) { - ElementType inType; + ElementType qkvType; InputShapeAndTransposeOrder inputShapeAndOrders; bool hasShapeof; - std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param; + std::tie(qkvType, inputShapeAndOrders, hasShapeof) = obj.param; std::ostringstream result; std::vector& inputShapes = inputShapeAndOrders.first; std::vector& transposeOrder = inputShapeAndOrders.second; @@ -75,8 +75,8 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterfaceGetParam(); + ElementType qkvType; + std::tie(qkvType, inputShapeAndOrders, hasShapeOf) = this->GetParam(); std::vector& inputShapes = inputShapeAndOrders.first; std::vector& transposeOrder = inputShapeAndOrders.second; targetDevice = ov::test::utils::DEVICE_CPU; rel_threshold = 1e-2f; configuration[ov::hint::inference_precision.name()] = ov::element::f32; - if (inType == ElementType::bf16) { + if (qkvType == ElementType::bf16) { configuration[ov::hint::inference_precision.name()] = ov::element::bf16; rel_threshold = 0.01f; } init_input_shapes(inputShapes); ov::ParameterVector inputParams; // q,k,v - inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); - inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); - inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); + inputParams.push_back(std::make_shared(qkvType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(qkvType, inputDynamicShapes[1])); + inputParams.push_back(std::make_shared(qkvType, inputDynamicShapes[1])); inputParams[0]->set_friendly_name("q"); inputParams[1]->set_friendly_name("k"); inputParams[2]->set_friendly_name("v"); // pastkv init_cost - inputParams.push_back(std::make_shared(inType, inputDynamicShapes[2])); + inputParams.push_back(std::make_shared(qkvType, inputDynamicShapes[2])); auto var_k = std::make_shared( - ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastk"}); + ov::op::util::VariableInfo{inputDynamicShapes[2], qkvType, "pastk"}); auto pastk = std::make_shared(inputParams[3], var_k); pastk->set_friendly_name("pastk_r"); auto var_v = std::make_shared( - ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastv"}); + ov::op::util::VariableInfo{inputDynamicShapes[2], qkvType, "pastv"}); auto pastv = std::make_shared(inputParams[3], var_v); pastv->set_friendly_name("pastv_r"); std::shared_ptr pastk_shapeof, pastv_shapeof; @@ -143,7 +143,7 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface(concatK, unsquezeAxis); auto unsqueezeV = std::make_shared(concatV, unsquezeAxis); - auto targetShape = op::v0::Constant::create(inType, {1, 1, 1, 4, 1}, {1}); + auto targetShape = op::v0::Constant::create(qkvType, {1, 1, 1, 4, 1}, {1}); auto broadcastK = std::make_shared(unsqueezeK, targetShape); auto broadcastV = std::make_shared(unsqueezeV, targetShape); @@ -175,7 +175,7 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface(transposeSDP, constReshape, true); // BLHS -> B,L,HxS - auto add = std::make_shared(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f})); + auto add = std::make_shared(reshapeSDP, op::v0::Constant::create(qkvType, {1}, {1.0f})); auto pastk_assign = std::make_shared(concatK, var_k); auto pastv_assign = std::make_shared(concatV, var_v); pastk_assign->set_friendly_name("pastk_w"); @@ -312,7 +312,6 @@ const std::vector inputShapeAndReorders = {{ }, // transposeOrder {1, 2, 0, 3}}, - }}; // TODO: BF16 test is disabled due to CI machine limitation