
Commit

get_state test and code clean up
maxnick committed Dec 16, 2023
1 parent b8de7bf commit 9aa7e4b
Showing 6 changed files with 33 additions and 81 deletions.
@@ -60,10 +60,6 @@ CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& s
         }
     }
 } else {
-    // TODO(BS): try to reuse allocated memory
-    // if (shape.hasZeroDims() && std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 0; } )) {
-    //     OPENVINO_THROW("Can't create CpuBlockedMemoryDesc with zero dim, but with non zero strides");
-    // }
     this->strides = strides;
 }

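The lines removed above were a long-commented-out validation that was never enabled. For reference, the invariant it sketched, restated as a standalone check (plain vectors stand in for the plugin's Shape and strides types; this is an illustration, not code from the repository):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// A shape with a zero dimension is expected to carry all-zero strides; a
// non-zero stride alongside a zero dim would indicate an inconsistent descriptor.
void check_zero_dim_strides(const std::vector<size_t>& dims,
                            const std::vector<size_t>& strides) {
    const bool zero_dim = std::any_of(dims.begin(), dims.end(),
                                      [](size_t d) { return d == 0; });
    const bool nonzero_stride = std::any_of(strides.begin(), strides.end(),
                                            [](size_t s) { return s != 0; });
    if (zero_dim && nonzero_stride)
        throw std::runtime_error(
            "Can't create CpuBlockedMemoryDesc with zero dim, but with non zero strides");
}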
46 changes: 0 additions & 46 deletions src/plugins/intel_cpu/src/memory_state.cpp
@@ -154,52 +154,6 @@ MemoryPtr VariableStateDoubleBuffer::internal_state_mem() const {
     return prime_mem();
 }

-VariableStateSingleBuffer::VariableStateSingleBuffer(const std::string& name,
-                                                     const MemoryPtr& buffer,
-                                                     const MemoryDescPtr& external_desc) :
-    VariableStateBase(name, external_desc) {
-    OPENVINO_ASSERT(buffer);
-    m_internal_mem = buffer;
-    m_internal_desc = m_internal_mem->getDescPtr();
-    auto&& shape = m_internal_desc->getShape();
-    //TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible?
-
-    if (shape.isStatic()) {
-        m_internal_mem->nullify();
-    } else {
-        //in the case of the original desc has dynamic shape we create an empty tensor
-        auto new_desc = to_static(m_internal_desc);
-        m_internal_mem->redefineDesc(new_desc);
-    }
-}
-
-void VariableStateSingleBuffer::reset_impl() {
-    auto new_desc = to_static(m_internal_desc);
-    m_internal_mem->redefineDesc(new_desc);
-    m_internal_mem->nullify();
-}
-
-MemoryPtr VariableStateSingleBuffer::input_mem() {
-    return m_internal_mem;
-}
-
-MemoryPtr VariableStateSingleBuffer::output_mem() {
-    return m_internal_mem;
-}
-
-MemoryDescPtr VariableStateSingleBuffer::internal_desc() const {
-    return m_internal_desc;
-}
-
-MemoryPtr VariableStateSingleBuffer::internal_state_mem() const {
-    return m_internal_mem;
-}
-
-void VariableStateSingleBuffer::commit_impl() {
-    //nothing to do
-}
-
-
 VariableStateKVcache::VariableStateKVcache(
     const std::string& name,
     const MemoryDescPtr& external_desc,
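For context, the class deleted above implemented the simplest state layout: a single Memory object served as both input_mem() and output_mem(), so commit_impl() had nothing to do, and reset_impl() just re-created an empty static descriptor and zeroed the buffer. A minimal sketch of that single-buffer contract, with simplified stand-in types rather than the plugin's real Memory/MemoryDesc API:

#include <memory>
#include <utility>

struct Memory { /* stand-in for ov::intel_cpu::Memory */ };
using MemoryPtr = std::shared_ptr<Memory>;

class SingleBufferState {
public:
    explicit SingleBufferState(MemoryPtr buffer) : m_mem(std::move(buffer)) {}
    MemoryPtr input_mem() { return m_mem; }   // the graph reads the state here...
    MemoryPtr output_mem() { return m_mem; }  // ...and writes the update in place
    void commit() { /* no-op: input and output alias the same buffer */ }
private:
    MemoryPtr m_mem;
};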
22 changes: 0 additions & 22 deletions src/plugins/intel_cpu/src/memory_state.h
@@ -98,28 +98,6 @@ class VariableStateDoubleBuffer : public VariableStateBase {
     size_t buffer_num = 0;
 };

-class VariableStateSingleBuffer : public VariableStateBase {
-public:
-    VariableStateSingleBuffer(const std::string& name,
-                              const MemoryPtr& buffer,
-                              const MemoryDescPtr& external_desc);
-
-    MemoryPtr input_mem() override;
-    MemoryPtr output_mem() override;
-    MemoryDescPtr internal_desc() const override;
-
-private:
-    //ov::intel_cpu::VariableStateBase
-    void reset_impl() override;
-    void commit_impl() override;
-
-    MemoryPtr internal_state_mem() const override;
-
-private:
-    MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor
-    MemoryPtr m_internal_mem;
-};
-
 class VariableStateKVcache : public VariableStateBase {
 public:
     VariableStateKVcache(const std::string& name,
26 changes: 17 additions & 9 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -916,8 +916,10 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s
     }
     // second token itself
     for (size_t i = 0; i < B; i++) {
-        beam_table_k.at<int32_t>({i, L0}) = i;
-        beam_table_v.at<int32_t>({i, L0}) = i;
+        for (size_t j = 0; j < L1; j++) {
+            beam_table_k.at<int32_t>({i, L0 + j}) = i;
+            beam_table_v.at<int32_t>({i, L0 + j}) = i;
+        }
     }
 }
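The hunk above fixes the second-token beam table update: previously only column L0 was written per batch row, leaving the remaining new token positions stale. A standalone sketch of the corrected fill, with a flat row-major int32_t table standing in for the node's PlainTensor (B, L0, L1 are assumed dimensions):

#include <cstdint>
#include <vector>

void fill_beam_table(std::vector<int32_t>& table, size_t B, size_t L0, size_t L1) {
    const size_t row = L0 + L1;  // each batch row spans past (L0) plus new (L1) tokens
    for (size_t i = 0; i < B; i++) {
        // Every freshly appended position now points back to its own batch lane,
        // not just the first one at column L0.
        for (size_t j = 0; j < L1; j++) {
            table[i * row + L0 + j] = static_cast<int32_t>(i);
        }
    }
}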

@@ -962,7 +964,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
         L0 = dims[order[2]];
         OPENVINO_ASSERT(B == B_state, "pastkv batch: ", B, " is not equal to batch of state: ", B_state);
     }
-
+    OPENVINO_ASSERT(B * (L0 + L1) > 0, "B or (L0+L1) is zero, B: ", B, ", L0: ", L0, ", L1: ", L1);
     // resize buffer
     if (B * H * (L0 + L1) * S > m_k_state->internal_state_max_size()) {
         auto new_shape = {B, H, (L0 + L1) * 2, S};
@@ -1013,13 +1015,19 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
         past_v = past_v.permute(order);
     }
     if (L0 > 0 && is_reset) {
-        PlainTensor init_k, init_v;
         auto inputNumber = getOriginalInputsNumber();
-        init_k.reset(getParentEdgeAt(inputNumber - 2)->getMemoryPtr());
-        init_v.reset(getParentEdgeAt(inputNumber - 1)->getMemoryPtr());
-        init_k = init_k.permute(order);
-        init_v = init_v.permute(order);
-        attn_memcpy(init_k, init_v, past_k, past_v);
+        auto k_mem = getParentEdgeAt(inputNumber - 2)->getMemoryPtr();
+        auto v_mem = getParentEdgeAt(inputNumber - 1)->getMemoryPtr();
+        auto&& k_shape = k_mem->getShape();
+        auto&& v_shape = v_mem->getShape();
+        if (!k_shape.hasZeroDims() && !v_shape.hasZeroDims()) {
+            PlainTensor init_k, init_v;
+            init_k.reset(k_mem);
+            init_v.reset(v_mem);
+            init_k = init_k.permute(order);
+            init_v = init_v.permute(order);
+            attn_memcpy(init_k, init_v, past_k, past_v);
+        }
     }

     attn_memcpy(cur_k, cur_v, past_k.slice(2, L0, L0 + L1), past_v.slice(2, L0, L0 + L1));
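Two guards are added to updatePastkv above: an assertion that both B and L0 + L1 are non-zero before the cache is resized, and a check that skips re-seeding the past K/V cache from the init inputs when either init tensor has a zero dimension (e.g. when the state was reset with empty init tensors). The decision logic, restated as a standalone predicate (Shape here is a simplified stand-in for the plugin's shape type):

#include <algorithm>
#include <cstddef>
#include <vector>

struct Shape {
    std::vector<size_t> dims;
    bool hasZeroDims() const {
        return std::any_of(dims.begin(), dims.end(), [](size_t d) { return d == 0; });
    }
};

// Re-seed past K/V from the init inputs only after a reset, when there is
// history to restore (L0 > 0) and both init tensors actually carry data.
bool should_copy_init_kv(bool is_reset, size_t L0, const Shape& k_shape, const Shape& v_shape) {
    return is_reset && L0 > 0 && !k_shape.hasZeroDims() && !v_shape.hasZeroDims();
}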
@@ -266,6 +266,14 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
         outputTensor.copy_to(copy);
         outputs.push_back(copy);
     }
+    auto states = inferRequest.query_state();
+    for (auto&& state : states) {
+        auto state_tensor = state.get_state();
+        ov::Tensor copy{state_tensor.get_element_type(), state_tensor.get_shape()};
+        state_tensor.copy_to(copy);
+        outputs.push_back(copy);
+    }
+
     reset();

     return outputs;
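Both test fixtures (this one and ConcatSDPTransposeTest below) now snapshot every variable state after inference, so get_state() results are compared between runs alongside the regular outputs; this is what gives the commit its get_state test coverage. A sketch of the same pattern as a standalone helper, assuming infer_request is an ov::InferRequest that has already run:

#include <vector>
#include <openvino/runtime/infer_request.hpp>

std::vector<ov::Tensor> snapshot_states(ov::InferRequest& infer_request) {
    std::vector<ov::Tensor> snapshots;
    for (auto&& state : infer_request.query_state()) {
        ov::Tensor src = state.get_state();
        // Deep-copy: the returned tensor may alias internal state memory that
        // a subsequent infer() or reset() would overwrite.
        ov::Tensor copy{src.get_element_type(), src.get_shape()};
        src.copy_to(copy);
        snapshots.push_back(copy);
    }
    return snapshots;
}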
@@ -250,6 +250,14 @@ class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTrans
         outputTensor.copy_to(copy);
         outputs.push_back(copy);
     }
+    auto states = inferRequest.query_state();
+    for (auto&& state : states) {
+        auto state_tensor = state.get_state();
+        ov::Tensor copy{state_tensor.get_element_type(), state_tensor.get_shape()};
+        state_tensor.copy_to(copy);
+        outputs.push_back(copy);
+    }
+
     reset();

     return outputs;
