Skip to content

Commit

Permalink
Fix SDPA fusion test case
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Dec 16, 2023
1 parent b64fb40 commit 546b555
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
10 changes: 8 additions & 2 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
bool has_out_transpose = config.config.output_BLHxS;
bool fuse_causal_attn = config.config.fuse_causal_attn;
bool is_causal = config.config.is_causal;
bool fuse_concat = config.config.fuse_concat;
auto input_num = inputs.size();
PlainTensor present_key, present_value;

Expand Down Expand Up @@ -547,8 +548,13 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
L0 = present_key.size(2) - L1;
auto Hk = k_input.size(1);

k_input.assert_dims({B, Hk, L1, S});
v_input.assert_dims({B, Hk, L1, S});
if (fuse_concat) {
k_input.assert_dims({B, Hk, L1, S});
v_input.assert_dims({B, Hk, L1, S});
} else {
k_input.assert_dims({B, Hk, L0 + L1, S});
v_input.assert_dims({B, Hk, L0 + L1, S});
}
present_key.assert_dims({B, Hk, L0 + L1, S});
present_value.assert_dims({B, Hk, L0 + L1, S});
if (beam_table)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
public CPUTestsBase {
public:
static std::string getTestCaseName(const testing::TestParamInfo<ConcatMultiQuerySDPParams>& obj) {
ElementType inType;
ElementType qkvType;
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeof;
std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
std::tie(qkvType, inputShapeAndOrders, hasShapeof) = obj.param;
std::ostringstream result;
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
Expand All @@ -75,8 +75,8 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
}
result << ")_";
}
result << "Prc=" << inType << "_";
result << "HasShapeOf=" << hasShapeof;
result << "Prc=" << qkvType << "_";
result << "HasShapeOf=" << hasShapeof << "_";
result << "TransposeOrder=";
result << "(";
for (const auto& itr : transposeOrder) {
Expand All @@ -90,34 +90,34 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
void SetUp() override {
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeOf;
ElementType inType;
std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
ElementType qkvType;
std::tie(qkvType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
targetDevice = ov::test::utils::DEVICE_CPU;
rel_threshold = 1e-2f;
configuration[ov::hint::inference_precision.name()] = ov::element::f32;
if (inType == ElementType::bf16) {
if (qkvType == ElementType::bf16) {
configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
rel_threshold = 0.01f;
}
init_input_shapes(inputShapes);
ov::ParameterVector inputParams;
// q,k,v
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(qkvType, inputDynamicShapes[0]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(qkvType, inputDynamicShapes[1]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(qkvType, inputDynamicShapes[1]));
inputParams[0]->set_friendly_name("q");
inputParams[1]->set_friendly_name("k");
inputParams[2]->set_friendly_name("v");
// pastkv init_cost
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(qkvType, inputDynamicShapes[2]));
auto var_k = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastk"});
ov::op::util::VariableInfo{inputDynamicShapes[2], qkvType, "pastk"});
auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
pastk->set_friendly_name("pastk_r");
auto var_v = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastv"});
ov::op::util::VariableInfo{inputDynamicShapes[2], qkvType, "pastv"});
auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
pastv->set_friendly_name("pastv_r");
std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
Expand All @@ -143,7 +143,7 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsquezeAxis);
auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsquezeAxis);

auto targetShape = op::v0::Constant::create(inType, {1, 1, 1, 4, 1}, {1});
auto targetShape = op::v0::Constant::create(qkvType, {1, 1, 1, 4, 1}, {1});
auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);

Expand Down Expand Up @@ -175,7 +175,7 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true); // BLHS -> B,L,HxS

auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(qkvType, {1}, {1.0f}));
auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
pastk_assign->set_friendly_name("pastk_w");
Expand Down Expand Up @@ -312,7 +312,6 @@ const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {{
},
// transposeOrder
{1, 2, 0, 3}},

}};

// TODO: BF16 test is disabled due to CI machine limitation
Expand Down

0 comments on commit 546b555

Please sign in to comment.