sdpa support init data
luo-cheng2021 committed Dec 14, 2023
1 parent 6fc0545 commit e77d445
Showing 4 changed files with 72 additions and 95 deletions.
22 changes: 12 additions & 10 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp
@@ -61,18 +61,18 @@ void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
});
}

void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output) {
static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output) {
size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
memcpy(&past_k_output.at<char>({b, h, m, 0}),
&k_input.at<char>({b, h, m, 0}),
S * k_input.m_element_size);
memcpy(&past_v_output.at<char>({b, h, m, 0}),
&v_input.at<char>({b, h, m, 0}),
S * v_input.m_element_size);
std::memcpy(&past_k_output.at<char>({b, h, m, 0}),
&k_input.at<char>({b, h, m, 0}),
S * k_input.m_element_size);
std::memcpy(&past_v_output.at<char>({b, h, m, 0}),
&v_input.at<char>({b, h, m, 0}),
S * v_input.m_element_size);
});
}

@@ -84,6 +84,8 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output);
} else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::f16) {
attn_memcpy_kernel<float, ov::float16>(k_input, v_input, past_k_output, past_v_output);
} else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) {
attn_memcpy_kernel<float, ov::bfloat16>(k_input, v_input, past_k_output, past_v_output);
} else {
OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in attn_memcpy");
}
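
For context on the branch added above: when source and destination precisions differ, the templated attn_memcpy_kernel presumably converts each element while copying rather than doing a raw byte copy. A minimal convert-on-copy sketch (illustrative only; copy_row is a made-up helper, not part of this commit, and SRC/DST stand in for the kernel's source and destination element types):

#include <cstddef>
#include <cstring>
#include <type_traits>

// Illustrative only: a same-precision row can be moved with a raw memcpy, while a
// mixed-precision pair (e.g. float -> ov::bfloat16 in the new branch above) needs an
// elementwise cast.
template <typename SRC, typename DST>
static void copy_row(const SRC* src, DST* dst, std::size_t n) {
    if constexpr (std::is_same_v<SRC, DST>) {
        std::memcpy(dst, src, n * sizeof(SRC));   // byte copy, no conversion
    } else {
        for (std::size_t i = 0; i < n; ++i)
            dst[i] = static_cast<DST>(src[i]);    // converting copy
    }
}
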
137 changes: 58 additions & 79 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -615,14 +615,23 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op);
m_config.config = node->get_config();
}
// BHLS->LBHS.. lookup table
std::vector<size_t> order;
if (m_config.config.permute_axes.empty()) {
order = {0, 1, 2, 3};
} else {
order = m_config.config.permute_axes;
}
m_config.reverse_order.resize(order.size());
for (size_t i = 0; i < order.size(); i++) {
m_config.reverse_order[order[i]] = i;
}
}

void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
if (rtPrecision != ov::element::bf16 && rtPrecision != ov::element::f32)
rtPrecision = ov::element::f32;
auto rtPrecision = getRuntimePrecision();
auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 3 : 0);

NodeConfig config;
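
As a side note on the reverse_order table built in the constructor above: it is the inverse permutation of permute_axes. A small self-contained check (illustrative only, not plugin code; it assumes permute(order) places dims[order[i]] at position i) shows the relationship:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    std::vector<std::size_t> order = {1, 2, 0, 3};      // e.g. BHLS stored as LBHS
    std::vector<std::size_t> dims  = {2, 8, 16, 64};    // B, H, L, S

    std::vector<std::size_t> reverse_order(order.size());
    for (std::size_t i = 0; i < order.size(); i++)
        reverse_order[order[i]] = i;                     // same rule as the diff above

    std::vector<std::size_t> permuted(order.size());
    for (std::size_t i = 0; i < order.size(); i++)
        permuted[i] = dims[order[i]];                    // assumed permute convention

    // Each original axis d is found at position reverse_order[d] of the permuted layout.
    for (std::size_t d = 0; d < dims.size(); d++)
        assert(dims[d] == permuted[reverse_order[d]]);
    return 0;
}
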
Expand Down Expand Up @@ -692,13 +701,11 @@ void ScaledDotProductAttention::createPrimitive() {
if (desc == nullptr)
OPENVINO_THROW("has unidentified preferable primitive descriptor");
}
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
auto rtPrecision = getRuntimePrecision();

if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
} else {
// only support bf16/f32
rtPrecision = ov::element::f32;
#ifdef OV_CPU_WITH_MLAS
m_executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(m_config);
#else
Expand Down Expand Up @@ -798,21 +805,11 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s

auto B = beam_idx.size(0);
size_t L0 = 0;

if (m_k_state->is_reset_state()) {
std::vector<size_t> order =
m_config.config.permute_axes.empty() ? std::vector<size_t>{0, 1, 2, 3} : m_config.config.permute_axes;

auto&& init_graph_k_dims = getParentEdgeAt(4)->getMemory().getStaticDims();
L0 = init_graph_k_dims.at(order.at(2));
B = init_graph_k_dims.at(order.at(0)); // TODO (BS): may be sanity check that init_graph B == cur_k B is just enough?
} else if (m_v_state->is_reset_state()) {
std::vector<size_t> order =
m_config.config.permute_axes.empty() ? std::vector<size_t>{0, 1, 2, 3} : m_config.config.permute_axes;

auto&& init_graph_v_dims = getParentEdgeAt(5)->getMemory().getStaticDims();
L0 = init_graph_v_dims.at(order.at(2));
B = init_graph_v_dims.at(order.at(0));
auto is_reset = m_k_state->is_reset_state();
if (is_reset) {
auto inputNumber = getOriginalInputsNumber();
auto&& init_graph_v_dims = getParentEdgeAt(inputNumber - 1)->getMemory().getStaticDims();
L0 = init_graph_v_dims.at(m_config.reverse_order[2]);
} else if (hidden_state_k) {
auto block_desc = hidden_state_k->getDescWithType<BlockedMemoryDesc>();
L0 = block_desc->getShape().getStaticDims()[1];
@@ -826,7 +823,7 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s
PlainTensor new_beam_table_k, new_beam_table_v;
new_beam_table_k.reset(new_hidden_state_k);
new_beam_table_v.reset(new_hidden_state_v);
if (L0 > 0 && !(m_k_state->is_reset_state() || m_v_state->is_reset_state())) {
if (L0 > 0 && !is_reset) {
beam_table_k.reset(hidden_state_k);
beam_table_v.reset(hidden_state_v);
for (size_t b = 0; b < B; b++) {
@@ -850,7 +847,7 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s
VectorDims{0, 1},
0,
VectorDims{},
hidden_state_k->getDescPtr()->as<CpuBlockedMemoryDesc>()->getStrides());
hidden_state_k->getDescWithType<BlockedMemoryDesc>()->getStrides());
hidden_state_k->redefineDesc(mem_desc);
hidden_state_v->redefineDesc(mem_desc);

@@ -859,22 +856,10 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s
beam_table_v.reset(hidden_state_v);
}

if (m_k_state->is_reset_state()) {
auto init_graph_k_mem = getParentEdgeAt(4)->getMemoryPtr();
auto&& shape = init_graph_k_mem->getShape();
// rebuild beam table according to the shape
}

if (m_v_state->is_reset_state()) {
auto init_graph_v_mem = getParentEdgeAt(5)->getMemoryPtr();
auto&& shape = init_graph_v_mem->getShape();
// rebuild beam table according to the shape
}

// first token
if (L0 == 0) {
if (L0 == 0 || is_reset) {
for (size_t b = 0; b < B; b++) {
for (size_t l = 0; l < L1; l++) {
for (size_t l = 0; l < L0 + L1; l++) {
beam_table_k.at<int32_t>({b, l}) = b;
beam_table_v.at<int32_t>({b, l}) = b;
}
@@ -895,19 +880,19 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s
if (!no_reorder) {
m_tmp_reorder.resize<int32_t>({B, L0});
for (size_t i = 0; i < B; i++) {
memcpy(&m_tmp_reorder.at<int32_t>({i}),
&beam_table_k.at<int32_t>({i}),
sizeof(int32_t) * L0);
std::memcpy(&m_tmp_reorder.at<int32_t>({i}),
&beam_table_k.at<int32_t>({i}),
sizeof(int32_t) * L0);
}
auto* table = beam_idx.data<int32_t>();
// beam table is same for both k,v state
for (size_t i = 0; i < B; i++) {
memcpy(&beam_table_k.at<int32_t>({i}),
&m_tmp_reorder.at<int32_t>({static_cast<size_t>(table[i])}),
sizeof(int32_t) * L0);
memcpy(&beam_table_v.at<int32_t>({i}),
&m_tmp_reorder.at<int32_t>({static_cast<size_t>(table[i])}),
sizeof(int32_t) * L0);
std::memcpy(&beam_table_k.at<int32_t>({i}),
&m_tmp_reorder.at<int32_t>({static_cast<size_t>(table[i])}),
sizeof(int32_t) * L0);
std::memcpy(&beam_table_v.at<int32_t>({i}),
&m_tmp_reorder.at<int32_t>({static_cast<size_t>(table[i])}),
sizeof(int32_t) * L0);
}
}
// second token itself
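
To illustrate the reorder above: it is a row gather driven by beam_idx, so each surviving beam inherits its parent's beam-table row and previously cached tokens keep resolving to the correct batch slots. A toy version of that gather (illustrative only, not plugin code):

#include <cstddef>
#include <cstdint>
#include <vector>

// out[b] becomes the parent row table[beam_idx[b]], mirroring the std::memcpy-based
// reorder of beam_table_k / beam_table_v above.
static std::vector<std::vector<int32_t>> reorder_rows(const std::vector<std::vector<int32_t>>& table,
                                                      const std::vector<int32_t>& beam_idx) {
    std::vector<std::vector<int32_t>> out(beam_idx.size());
    for (std::size_t b = 0; b < beam_idx.size(); ++b)
        out[b] = table[static_cast<std::size_t>(beam_idx[b])];
    return out;
}
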
Expand All @@ -933,10 +918,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
auto L1 = cur_k.size(2);
auto S = cur_k.size(3);
size_t L0 = 0;
std::vector<size_t> reverse_order(order.size());
for (size_t i = 0; i < order.size(); i++) {
reverse_order[order[i]] = i;
}
auto& reverse_order = m_config.reverse_order;
auto reverse = [&reverse_order] (const std::vector<size_t>& cur) {
std::vector<size_t> result(cur.size());
for (size_t i = 0; i < cur.size(); i++) {
@@ -947,14 +929,11 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
auto internal_mem_k = m_k_state->internal_state_mem();
auto internal_mem_v = m_v_state->internal_state_mem();

if (m_k_state->is_reset_state()) {
auto&& init_graph_k_dims = getParentEdgeAt(4)->getMemory().getStaticDims();
L0 = init_graph_k_dims.at(order.at(2));
B = init_graph_k_dims.at(order.at(0)); // TODO (BS): may be sanity check that init_graph B == cur_k B is just enough?
} else if (m_v_state->is_reset_state()) {
auto&& init_graph_v_dims = getParentEdgeAt(5)->getMemory().getStaticDims();
L0 = init_graph_v_dims.at(order.at(2));
B = init_graph_v_dims.at(order.at(0));
auto is_reset = m_k_state->is_reset_state();
if (is_reset) {
auto inputNumber = getOriginalInputsNumber();
auto&& init_graph_v_dims = getParentEdgeAt(inputNumber - 1)->getMemory().getStaticDims();
L0 = init_graph_v_dims.at(m_config.reverse_order[2]);
} else if (internal_mem_k) {
auto block_desc = internal_mem_k->getDescWithType<BlockedMemoryDesc>();
L0 = block_desc->getShape().getStaticDims()[reverse_order[2]];
@@ -976,7 +955,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
new_pastv.reset(new_internal_mem_v);
new_pastk = new_pastk.permute(order);
new_pastv = new_pastv.permute(order);
if (L0 > 0 && !(m_k_state->is_reset_state() || m_v_state->is_reset_state())) {
if (L0 > 0 && !is_reset) {
past_k.reset(internal_mem_k);
past_v.reset(internal_mem_v);
past_k = past_k.permute(order);
Expand All @@ -999,47 +978,47 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
order,
0,
VectorDims{},
internal_mem_k->getDescPtr()->as<CpuBlockedMemoryDesc>()->getStrides());
internal_mem_k->getDescWithType<BlockedMemoryDesc>()->getStrides());
internal_mem_k->redefineDesc(mem_desc);
internal_mem_v->redefineDesc(mem_desc);

if (m_k_state->is_reset_state()) {
auto init_graph_k_mem = getParentEdgeAt(4)->getMemoryPtr();
auto&& shape = init_graph_k_mem->getShape();
if (!shape.hasZeroDims()) {
//copy from init_graph_k_mem to the state
}
}

if (m_v_state->is_reset_state()) {
auto init_graph_v_mem = getParentEdgeAt(5)->getMemoryPtr();
auto&& shape = init_graph_v_mem->getShape();
if (!shape.hasZeroDims()) {
//copy from init_graph_v_mem to the state
}
}

if (!past_k) {
past_k.reset(internal_mem_k);
past_v.reset(internal_mem_v);
past_k = past_k.permute(order);
past_v = past_v.permute(order);
}
if (L0 > 0 && is_reset) {
PlainTensor init_k, init_v;
auto inputNumber = getOriginalInputsNumber();
init_k.reset(getParentEdgeAt(inputNumber - 2)->getMemoryPtr());
init_v.reset(getParentEdgeAt(inputNumber - 1)->getMemoryPtr());
init_k = init_k.permute(order);
init_v = init_v.permute(order);
attn_memcpy(init_k, init_v, past_k, past_v);
}

attn_memcpy(cur_k, cur_v, past_k.slice(2, L0, L0 + L1), past_v.slice(2, L0, L0 + L1));
}
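
Conceptually, updatePastkv lays the cache out along the length axis as [L0 earlier tokens | L1 current tokens]: on a reset state the init-graph K/V fill positions [0, L0), otherwise those positions already hold the previously cached tokens, and cur_k/cur_v are always written into the slice [L0, L0 + L1). A toy one-row model of that append (illustrative only; not the plugin's PlainTensor):

#include <cstddef>
#include <vector>

// cache_row must have room for at least L0 + L1 elements along the length axis.
static void append_tokens(std::vector<float>& cache_row,
                          const std::vector<float>& init_row,   // L0 tokens from the init subgraph
                          const std::vector<float>& cur_row,    // L1 tokens from this step
                          bool was_reset) {
    const std::size_t L0 = init_row.size();
    if (was_reset)
        for (std::size_t i = 0; i < L0; ++i) cache_row[i] = init_row[i];            // init data
    for (std::size_t i = 0; i < cur_row.size(); ++i) cache_row[L0 + i] = cur_row[i]; // new tokens
}
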

ov::element::Type ScaledDotProductAttention::getKVCachePrecision() {
if (m_kvcache_precision != ov::element::undefined)
return m_kvcache_precision;
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
if (rtPrecision != ov::element::bf16 && rtPrecision != ov::element::f32)
rtPrecision = ov::element::f32;
auto rtPrecision = getRuntimePrecision();
bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && rtPrecision != ov::element::bf16;
m_kvcache_precision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision;

return m_kvcache_precision;
}

ov::element::Type ScaledDotProductAttention::getRuntimePrecision() const {
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
// only support bf16 and f32
if (rtPrecision != ov::element::bf16 && rtPrecision != ov::element::f32)
rtPrecision = ov::element::f32;
return rtPrecision;
}

} // namespace node
} // namespace intel_cpu
} // namespace ov
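
Tying the two precision helpers above back to the attn_memcpy change in the first file: the KV cache is stored in f16 whenever concat is fused, avx2 is available, and the runtime precision is not bf16, so the cache write may have to convert from f32. A paraphrase of that decision (illustrative only; kv_cache_precision is a made-up free function, not part of this commit):

#include <openvino/core/type/element_type.hpp>

// Mirrors getKVCachePrecision above: f16 cache for the fused-concat f32 path
// (one reason attn_memcpy carries converting f32 -> f16/bf16 branches),
// otherwise the cache keeps the runtime precision.
ov::element::Type kv_cache_precision(ov::element::Type rt, bool fuse_concat, bool has_avx2) {
    const bool enable_fp16 = fuse_concat && has_avx2 && rt != ov::element::bf16;
    return enable_fp16 ? ov::element::f16 : rt;
}
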
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -55,6 +55,7 @@ class ScaledDotProductAttention : public Node {
void gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx);
void updateBeamTable(const MemoryPtr& mem_beam_idx, size_t new_q_len);
void updatePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v);
ov::element::Type getRuntimePrecision() const override;

struct Executor {
virtual void execute(dnnl::stream strm, const std::vector<MemoryPtr>& inputs, const MemoryPtr output, const MemoryPtr presentk_input,
@@ -63,6 +64,7 @@

struct Config {
ScaledDotProductAttentionWithKVCache::Config config;
std::vector<size_t> reverse_order;
};

Config m_config;
@@ -84,8 +84,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
const auto past_k_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
const auto past_v_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
if (!check_valid_children_type(past_k_node) || !check_valid_children_type(past_v_node)) {
// TODO: remove
std::cout << "StatefulSDPAFusion unexpected children of readvalue\n";
return false;
}
const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
@@ -111,8 +109,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
const auto gather_k_node = ov::as_type_ptr<opset8::Gather>(pattern_map.at(gather_input_k).get_node_shared_ptr());
const auto gather_v_node = ov::as_type_ptr<opset8::Gather>(pattern_map.at(gather_input_v).get_node_shared_ptr());
if (gather_k_node->input_value(1) != gather_v_node->input_value(1)) {
// TODO: remove
std::cout << "StatefulSDPAFusion beam_idx is not same for gather\n";
return false;
}
auto args = sdp_node->input_values();
@@ -140,8 +136,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
else
assign_v_node->set_arguments({new_node->output(2)});

// TODO: remove
std::cout << "StatefulSDPAFusion hits pattern\n";
return true;
};
