Skip to content

Commit

Permalink
[CPU] Fix SDPA pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Mar 20, 2024
1 parent 3d45a64 commit 980ac27
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {

PlainTensor k_scale_zp, v_scale_zp;
if (m_config.config.fuse_concat) {
CPU_NODE_ASSERT(m_k_state && m_v_state, "has null input states");
// initialization will be also completed in this func
gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {

auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
auto present_to = out.get_target_inputs();
if (present_to.size() < 2)
return false;
for (auto& to : present_to) {
auto to_node = to.get_node();
if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
Expand Down Expand Up @@ -149,6 +147,28 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());

for (auto&& item : {concat_k_node, concat_v_node}) {
auto&& children = item->get_output_target_inputs(0);
switch (children.size()) {
case 2:
// pass, as the existence of Assign will be checked later
break;
case 3:
// the fist one leads to SDPA, otherwise the matcher don't find the pattern
// the second one leads to Assign, and this is checked later
// the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf
if (!std::any_of(children.begin(), children.end(), [](const ov::Input<ov::Node>& child) {
return ov::is_type<ov::op::v3::ShapeOf>(child.get_node()) ||
ov::is_type<ov::op::v0::ShapeOf>(child.get_node());
})) {
return false;
}
break;
default:
return false;
}
}

opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))
Expand Down

0 comments on commit 980ac27

Please sign in to comment.