Skip to content

Commit

Permalink
Fix SDPA tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Apr 26, 2024
1 parent 672a10d commit cc2383e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,12 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
outputs.push_back(copy);
}
auto states = inferRequest.query_state();
for (auto&& state : states) {
for (std::string name : {"pastk", "pastv"}) {
auto itr = std::find_if(states.begin(), states.end(), [&](const ov::VariableState& state) {
return name == state.get_name();
});
OPENVINO_ASSERT(itr != states.end(), "Failed to find ", name, " state");
const auto& state = *itr;
auto state_tensor = state.get_state();
ov::Tensor copy{state_tensor.get_element_type(), state_tensor.get_shape()};
state_tensor.copy_to(copy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,12 @@ class ConcatSDPTransposeTest : public ConcatSDPTransposeTestBase {
outputs.push_back(copy);
}
auto states = inferRequest.query_state();
for (auto&& state : states) {
for (std::string name : {"pastk", "pastv"}) {
auto itr = std::find_if(states.begin(), states.end(), [&](const ov::VariableState& state) {
return name == state.get_name();
});
OPENVINO_ASSERT(itr != states.end(), "Failed to find ", name, " state");
const auto& state = *itr;
auto state_tensor = state.get_state();
ov::Tensor copy{state_tensor.get_element_type(), state_tensor.get_shape()};
state_tensor.copy_to(copy);
Expand Down

0 comments on commit cc2383e

Please sign in to comment.