Skip to content

Commit

Permalink
Fix SDPA decomposition
Browse files (browse the repository at this point in the history)
  • Loading branch information
maxnick committed Nov 22, 2024
1 parent db44ef0 commit 037c2dd
Showing 1 changed file with 15 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,22 @@ std::shared_ptr<ov::Node> ov::pass::ScaledDotProductAttentionDecomposition::deco
auto one_f = register_new_node<v1::ConvertLike>(one_i, query);
auto zero_f = register_new_node<v1::ConvertLike>(zero_i, query);

auto last_dim = [&](const ov::Output<ov::Node>& output) -> ov::Output<ov::Node> {
auto& inp_shape = output.get_partial_shape();
if (inp_shape.rank().is_static()) {
auto& last_dim = *(inp_shape.rbegin());
Output<Node> scale;
if (node->get_input_size() < 5) {
auto&& query_shape = query.get_partial_shape();
// Often the embedding space size is known statically, so the dimension can be extracted into a constant.
if (query_shape.rank().is_static()) {
auto&& last_dim = *(query_shape.rbegin());
if (last_dim.is_static()) {
return register_new_node(v0::Constant::create(element::i32, Shape{}, {last_dim.get_length()}));
scale = register_new_node(v0::Constant::create(element::i32, Shape{}, {last_dim.get_length()}));
}
}
auto shape = register_new_node<v3::ShapeOf>(output, element::i32);
return register_new_node<v8::Gather>(shape, minus_one, zero_i);
};

Output<Node> scale;
if (node->get_input_size() < 5) {
scale = last_dim(query);
if (!scale.get_node()) {
auto shape = register_new_node<v3::ShapeOf>(query, element::i32);
scale = register_new_node<v8::Gather>(shape, minus_one, zero_i);
}

scale = register_new_node<v1::ConvertLike>(scale, query);
auto sqrt_scale = register_new_node<v0::Sqrt>(scale);
scale = register_new_node<v1::Divide>(one_f, sqrt_scale);
Expand Down Expand Up @@ -124,8 +125,9 @@ std::shared_ptr<ov::Node> ov::pass::ScaledDotProductAttentionDecomposition::deco
atten_mask = mask;
}
} else {
auto target_s_len = last_dim(query);
auto source_s_len = last_dim(key);
auto q_shape = register_new_node<v3::ShapeOf>(query, element::i32);
auto target_s_len = register_new_node<v8::Gather>(q_shape, minus_two, zero_i);
auto source_s_len = register_new_node<v8::Gather>(k_shape, minus_two, zero_i);
auto ssl = register_new_node<v0::Unsqueeze>(source_s_len, zero_i);
auto tsl = register_new_node<v0::Unsqueeze>(target_s_len, zero_i);
auto mask_shape = register_new_node<v0::Concat>(OutputVector{tsl, ssl}, 0);
Expand Down

0 comments on commit 037c2dd

Please sign in to comment.