Skip to content

Commit

Permalink
add set_state test
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Dec 17, 2023
1 parent 9aa7e4b commit e129722
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/memory_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,11 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) {
auto buff = reinterpret_cast<int*>(m_hidden_state->getData());
for (size_t i = 0; i < size_B; ++i) {
for (size_t j = 0; j < size_L; ++j) {
buff[i * size_B + j] = i;
buff[i * size_L + j] = i;
}
}
m_internal_mem_max_size = dense_internal_desc->getCurrentMemSize() / dense_internal_desc->getPrecision().size();
m_hidden_state_max_size = mem_desc->getCurrentMemSize() / mem_desc->getPrecision().size();
}

void VariableStateKVcache::reset_impl() {
Expand Down
26 changes: 14 additions & 12 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,9 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
}
}

if (L1 > 1) {
// second token, or first token with pastkv fusing
bool use_one_token = L1 == 1 || (fuse_concat && L0 > 0);
if (!use_one_token) {
// multi-token version
kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor(),
output_emb, has_out_transpose, auto_causal, scale_input);
Expand Down Expand Up @@ -620,17 +622,17 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
} else {
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op);
m_config.config = node->get_config();
}
// BHLS->LBHS.. lookup table
std::vector<size_t> order;
if (m_config.config.permute_axes.empty()) {
order = {0, 1, 2, 3};
} else {
order = m_config.config.permute_axes;
}
m_config.reverse_order.resize(order.size());
for (size_t i = 0; i < order.size(); i++) {
m_config.reverse_order[order[i]] = i;
// BHLS->LBHS.. lookup table
std::vector<size_t> order;
if (m_config.config.permute_axes.empty()) {
order = {0, 1, 2, 3};
} else {
order = m_config.config.permute_axes;
}
m_config.reverse_order.resize(order.size());
for (size_t i = 0; i < order.size(); i++) {
m_config.reverse_order[order[i]] = i;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ using ConcatSDPTransposeTestParams = std::tuple<ElementType,
* Result
*/

class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTransposeTestParams>,
virtual public ov::test::SubgraphBaseTest,
public CPUTestsBase {
class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPTransposeTestParams>,
virtual public ov::test::SubgraphBaseTest,
public CPUTestsBase {
public:
static std::string getTestCaseName(const testing::TestParamInfo<ConcatSDPTransposeTestParams>& obj) {
ElementType inType;
Expand Down Expand Up @@ -91,7 +91,7 @@ class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTrans
bool hasShapeOf;
std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
transposeOrder = inputShapeAndOrders.second;
targetDevice = ov::test::utils::DEVICE_CPU;
rel_threshold = 1e-2f;
configuration[ov::hint::inference_precision.name()] = ov::element::f32;
Expand Down Expand Up @@ -189,7 +189,7 @@ class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTrans
SubgraphBaseTest::generate_inputs(shapes);
}
template <typename IT, typename T>
void strided_iota(IT first, size_t n, T value, T stride) {
static void strided_iota(IT first, size_t n, T value, T stride) {
for (size_t i = 0; i < n; i++) {
*first++ = value;
value += stride;
Expand All @@ -212,6 +212,7 @@ class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTrans
strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
inputs.insert({param, t});
} else {
ASSERT_TRUE(param->get_element_type() == element::bf16);
ov::Tensor t{ov::element::bf16, shape};
strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
inputs.insert({param, t});
Expand All @@ -234,6 +235,11 @@ class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTrans
state.reset();
}
}
std::vector<size_t> transposeOrder;
};

class ConcatSDPTransposeTest : public ConcatSDPTransposeTestBase {
public:
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
function = model;
prepare();
Expand Down Expand Up @@ -310,6 +316,125 @@ INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeTest,
::testing::ValuesIn(inputShapeAndReorders),
::testing::Values(true, false)),
ConcatSDPTransposeTest::getTestCaseName);
} // namespace

class ConcatSDPTransposeTestSetState : public ConcatSDPTransposeTestBase {
public:
void reduce_state() {
auto states = inferRequest.query_state();
for (auto&& state : states) {
auto state_tensor = state.get_state();
ov::Tensor copy{state_tensor.get_element_type(), state_tensor.get_shape()};
state_tensor.copy_to(copy);
auto new_shape = state_tensor.get_shape();
ASSERT_TRUE(new_shape[transposeOrder[2]] >= 1);
new_shape[transposeOrder[2]] -= 1;
ov::Tensor new_state{state_tensor.get_element_type(), new_shape, copy.data()};
state.set_state(new_state);
}
}
void new_state() {
auto fill = [] (ov::Tensor& t, float val) {
auto shape = t.get_shape();
if (t.get_element_type() == element::f32) {
strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
} else if (t.get_element_type() == element::f16) {
strided_iota(static_cast<ov::float16*>(t.data()), t.get_size(), val, 0.1f);
} else {
ASSERT_TRUE(t.get_element_type() == element::bf16);
strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
}
};
float val = 0;
auto states = inferRequest.query_state();
for (auto&& state : states) {
auto state_tensor = state.get_state();
auto new_shape = state_tensor.get_shape();
new_shape[transposeOrder[2]] = 3;
ov::Tensor new_state{state_tensor.get_element_type(), new_shape};
fill(new_state, val);
val += 0.13f;

state.set_state(new_state);
}
}
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
function = model;
prepare();
std::vector<ov::Tensor> outputs;
// case 1: initialization + pastkv reaches limitation, remove some state
int idx = 0;
for (auto&& shapes : targetStaticShapes) {
generate(idx++, shapes);
for (const auto& input : inputs) {
inferRequest.set_tensor(input.first, input.second);
}
inferRequest.infer();
auto outputTensor = inferRequest.get_output_tensor(0);
ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
outputTensor.copy_to(copy);
outputs.push_back(copy);
if (idx > 1) {
reduce_state();
}
}

// case 2: after reset, set_state at once
reset();
new_state();
idx = 0;
for (auto&& shapes : targetStaticShapes) {
generate(idx++, shapes);
for (const auto& input : inputs) {
inferRequest.set_tensor(input.first, input.second);
}
inferRequest.infer();
auto outputTensor = inferRequest.get_output_tensor(0);
ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
outputTensor.copy_to(copy);
outputs.push_back(copy);
}

return outputs;
}
};

// Runs the set_state scenario on the model under test and on the reference model, checks the
// node-type counts of the compiled graph after each run, and compares the outputs pairwise.
TEST_P(ConcatSDPTransposeTestSetState, CompareWithRefs) {
    auto actualOutputs = run_test(function);
    // Expected node-type counts in the compiled graph of the model under test.
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
    CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
    CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
    CheckNumberOfNodesWithType(compiledModel, "Gather", 0);
    auto expectedOutputs = run_test(functionRefs);
    // After the reference run, compiledModel must not contain a ScaledDotProductAttention node.
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
    // The loop below indexes both vectors with the same index; stop here on a size mismatch
    // instead of reading past the end of expectedOutputs.
    ASSERT_EQ(expectedOutputs.size(), actualOutputs.size());
    for (size_t i = 0; i < actualOutputs.size(); i++) {
        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
    }
}

namespace {
// Input shapes and transpose order used to instantiate ConcatSDPTransposeTestSetState.
const std::vector<InputShapeAndTransposeOrder> inputShapeAndReordersSetState = {
    {
        // beam search
        {{
            // B, L1, H, S
            {{-1, -1, 8, 64}, {{4, 10, 8, 64}, {4, 1, 8, 64}, {4, 1, 8, 64}, {4, 1, 8, 64}}},
            // B, L0, H, S and init tensor
            {{-1, -1, 8, 64}, {{4, 2, 8, 64}, {4, 12, 8, 64}, {4, 13, 8, 64}, {4, 14, 8, 64}}},
        },
        // transposeOrder
        {0, 2, 1, 3}
        }
    }
};

// The test-name generator is the static inherited from ConcatSDPTransposeTestBase; it is
// referenced through the fixture being instantiated rather than through a sibling fixture.
INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeTestSetState,
                         ConcatSDPTransposeTestSetState,
                         ::testing::Combine(::testing::Values(ElementType::f32),
                                            ::testing::ValuesIn(inputShapeAndReordersSetState),
                                            ::testing::Values(false)),
                         ConcatSDPTransposeTestSetState::getTestCaseName);

}  // namespace
} // namespace SubgraphTestsDefinitions

0 comments on commit e129722

Please sign in to comment.