Commit

apply review comments

luo-cheng2021 committed Nov 28, 2023
1 parent 504b6e9 commit bceb7aa
Showing 5 changed files with 63 additions and 42 deletions.
52 changes: 26 additions & 26 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -578,7 +578,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
     PlainTensor<T> k_input;          // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
     PlainTensor<T> v_input;          // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
     PlainTensor<int32_t> beam_table; // i32[B, max_kvLen]
-    PlainTensor<float> attn_mask;    // f32[[B|1],[H|1], L1|1, L0+L1]
+    PlainTensor<float> attn_buf;     // f32[[B|1],[H|1], L1|1, L0+L1]
     float scale_input = 0.0f;
 
     MHAKernel<KType, T> kernel;
@@ -587,13 +587,13 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
     size_t B, H, L1, L0, S;
 
     Config config;
-    AttentionExecutor(const Config& _config) : config(_config) {}
+    AttentionExecutor(const Config& _config) : attn_buf(true), config(_config) {}
 
     void prepare_attn_mask(MemoryPtr attn_input) {
-        attn_mask.resize(attn_input->getStaticDims());
+        attn_buf.resize(attn_input->getStaticDims());
         auto p = reinterpret_cast<uint8_t*>(attn_input->getData());
         for (size_t i = 0; i < attn_input->getSize(); i++)
-            attn_mask.data()[i] = p[i] ? 0.0f : -FLT_MAX;
+            attn_buf.data()[i] = p[i] ? 0.0f : -FLT_MAX;
     }
 
     void concat_pastkv(const std::vector<MemoryPtr>& inputs,
@@ -610,11 +610,9 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         L0 = past_k_mem->getStaticDims()[2];
         // k,v may support multiquery
         auto Hk = past_k_mem->getStaticDims()[1];
-        // [S, B, L0, S]
-        past_k_output.resize({L0 + L1, B, Hk, S}, static_cast<T*>(outputs[1]->getData()));
-        past_v_output.resize({L0 + L1, B, Hk, S}, static_cast<T*>(outputs[2]->getData()));
-        past_k_output = past_k_output.permute({1, 2, 0, 3});
-        past_v_output = past_v_output.permute({1, 2, 0, 3});
+        // [B, H, L0, S]
+        past_k_output.reset(outputs[1]);
+        past_v_output.reset(outputs[2]);
         parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) {
             std::memcpy(&past_k_output.at({b, h, m + L0, 0}),
                         &k_input.at({b, h, m, 0}),
@@ -623,12 +621,10 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
                         &v_input.at({b, h, m, 0}),
                         S * sizeof(T));
         });
-        if (!config.skipPastKVCopy) {
+        if (!config.is_concat_inplaced) {
             PlainTensor<T> past_k_input, past_v_input;
-            past_k_input.resize({L0, B, Hk, S}, static_cast<T*>(past_k_mem->getData()));
-            past_v_input.resize({L0, B, Hk, S}, static_cast<T*>(inputs[past_k_idx + 1]->getData()));
-            past_k_input = past_k_input.permute({1, 2, 0, 3});
-            past_v_input = past_v_input.permute({1, 2, 0, 3});
+            past_k_input.reset(past_k_mem);
+            past_v_input.reset(inputs[past_k_idx + 1]);
             parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) {
                 std::memcpy(&past_k_output.at({b, h, m, 0}),
                             &past_k_input.at({b, h, m, 0}),
@@ -658,11 +654,13 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         q_input.reset(inputs[0]);
         k_input.reset(inputs[1]);
         v_input.reset(inputs[2]);
+        PlainTensor<float> attn_mask;
         if (input_num > 3) {
             // attn_mask
             if (inputs[3]->getDesc().getPrecision() == ov::element::u8) {
                 // bool->f32
                 prepare_attn_mask(inputs[3]);
+                attn_mask = attn_buf;
             } else {
                 attn_mask.reset(inputs[3]);
             }
@@ -749,17 +747,6 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
         return;
     auto rtPrecision = getOriginalInputPrecisionAtPort(0);
 
-    if (rtPrecision == ov::element::bf16) {
-        m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
-    } else {
-        // only support bf16/f32
-        rtPrecision = ov::element::f32;
-#ifdef OV_CPU_WITH_MLAS
-        m_executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(m_config);
-#else
-        m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, float>>(m_config);
-#endif
-    }
     NodeConfig config;
     auto& creatorsMap = BlockedDescCreator::getCommonCreators();
     auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0);
@@ -830,7 +817,20 @@ void ScaledDotProductAttention::createPrimitive() {
         if (desc == nullptr)
            OPENVINO_THROW("has unidentified preferable primitive descriptor");
 
-        m_config.skipPastKVCopy = desc->getConfig().outConfs[1].inPlace() >= 0;
+        m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
     }
+    auto rtPrecision = getOriginalInputPrecisionAtPort(0);
+
+    if (rtPrecision == ov::element::bf16) {
+        m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
+    } else {
+        // only support bf16/f32
+        rtPrecision = ov::element::f32;
+#ifdef OV_CPU_WITH_MLAS
+        m_executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(m_config);
+#else
+        m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, float>>(m_config);
+#endif
+    }
 }
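A note on the attn_mask to attn_buf change above: the executor member is now only the owned storage for the bool-to-f32 mask conversion, while the mask the kernel actually consumes becomes a short-lived local PlainTensor<float> view that aliases either attn_buf (u8 input) or the caller's f32 memory directly. Below is a minimal standalone sketch of that owning-buffer/non-owning-view split, using simplified stand-in types (F32Buffer and F32View are illustrative, not the plugin's PlainTensor API):

#include <cfloat>
#include <cstdint>
#include <cstdio>
#include <vector>

// Persistent owned storage, one per executor (plays the role of attn_buf).
struct F32Buffer {
    std::vector<float> data;
};

// Non-owning alias handed to the kernel (plays the role of the local attn_mask).
struct F32View {
    const float* ptr = nullptr;
    size_t size = 0;
};

// bool/u8 mask -> additive f32 mask: nonzero keeps a position (0.0f),
// zero masks it out (-FLT_MAX), matching prepare_attn_mask() above.
F32View prepare_attn_mask(const uint8_t* u8_mask, size_t n, F32Buffer& buf) {
    buf.data.resize(n);
    for (size_t i = 0; i < n; i++)
        buf.data[i] = u8_mask[i] ? 0.0f : -FLT_MAX;
    return {buf.data.data(), n};
}

int main() {
    F32Buffer attn_buf;                       // owned by the "executor"
    const uint8_t u8_mask[4] = {1, 0, 1, 1};  // boolean attention mask
    F32View attn_mask = prepare_attn_mask(u8_mask, 4, attn_buf);
    for (size_t i = 0; i < attn_mask.size; i++)
        std::printf("%g ", attn_mask.ptr[i]);
    std::printf("\n");
    return 0;
}

On the f32 path the view would simply wrap the input pointer, so no copy or conversion is paid there.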
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -48,7 +48,7 @@ class ScaledDotProductAttention : public Node {
 
     struct Config {
         ScaledDotProductAttentionStub::Config config;
-        bool skipPastKVCopy = false;
+        bool is_concat_inplaced = false;
     };
 
     struct Config m_config;
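The renamed flag above is only known once a primitive descriptor has been selected, which is why the executor construction in scaled_attn.cpp moved from initSupportedPrimitiveDescriptors() into createPrimitive(): AttentionExecutor copies its Config at construction (config(_config)), so building it earlier would bake in a stale is_concat_inplaced. A toy illustration of that snapshot pitfall (Config and Executor here are illustrative stand-ins, not the plugin's classes):

#include <cassert>
#include <memory>

struct Config {
    bool is_concat_inplaced = false;
};

struct Executor {
    Config config;  // copied at construction, not referenced
    explicit Executor(const Config& c) : config(c) {}
};

int main() {
    Config cfg;
    // Constructed before the in-place decision: snapshots flag == false.
    auto too_early = std::make_shared<Executor>(cfg);
    cfg.is_concat_inplaced = true;  // decided during createPrimitive()
    // Constructed after the decision: snapshots flag == true.
    auto correct = std::make_shared<Executor>(cfg);
    assert(!too_early->config.is_concat_inplaced);
    assert(correct->config.is_concat_inplaced);
    return 0;
}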
1 change: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ bool ov::intel_cpu::ScaledDotProductAttentionStub::visit_attributes(ov::Attribut
     visitor.on_attribute("output_BLHxS", m_config.output_BLHxS);
     visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn);
     visitor.on_attribute("is_causal", m_config.is_causal);
+    visitor.on_attribute("fuse_concat", m_config.fuse_concat);
     visitor.finish_structure();
     return true;
 }
17 changes: 13 additions & 4 deletions src/plugins/intel_cpu/src/utils/plain_tensor.hpp
@@ -169,10 +169,19 @@ struct PlainTensor {
     }
 
     void reset(MemoryPtr mem) {
-        assert_dt<DT>(mem->getDesc().getPrecision());
+        const auto& mem_desc = mem->getDesc();
+        assert_dt<DT>(mem_desc.getPrecision());
+        const auto* desc_ptr = mem_desc.as<BlockedMemoryDesc>();
+        // not support block layout
+        OPENVINO_ASSERT(desc_ptr && desc_ptr->getOrder().size() == mem->getStaticDims().size());
         m_mem = mem;
+        VectorDims strides(desc_ptr->getStrides().size());
+        const auto& orders = desc_ptr->getOrder();
+        for (size_t i = 0; i < orders.size(); i++) {
+            strides[orders[i]] = desc_ptr->getStrides()[i];
+        }
         // this reshape_to() can do reshape w/o additional cost
-        resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()));
+        resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()), &strides);
     }
 
     ov::element::Type get_precision(void) {
@@ -327,14 +336,14 @@ struct PlainTensor {
         return new_tensor_view;
     }
 
-    void resize(const VectorDims& new_dims, DT* data = nullptr) {
+    void resize(const VectorDims& new_dims, DT* data = nullptr, const VectorDims* strides = nullptr) {
         // initialize strides for compact/dense tensor
         m_rank = new_dims.size();
         assert(m_rank <= PLAINTENSOR_RANK_MAX);
         size_t stride = 1;
         for (int i = m_rank - 1; i >= 0; i--) {
             m_dims[i] = new_dims[i];
-            m_strides[i] = stride;
+            m_strides[i] = strides ? (*strides)[i] : stride;
             stride *= new_dims[i];
         }
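The reordering loop added to reset() above is what makes the explicit permute() calls in scaled_attn.cpp unnecessary: BlockedMemoryDesc reports strides in memory (physical) order together with an order vector saying which logical axis sits at each physical position, and strides[orders[i]] = getStrides()[i] scatters those strides back to logical axes so the tensor can be indexed with logical coordinates regardless of layout. A self-contained sketch of that remapping, assuming a non-blocked layout (plain std::vector stand-ins for the plugin types):

#include <cassert>
#include <cstddef>
#include <vector>

using VectorDims = std::vector<size_t>;

// order[i] = logical axis stored at physical position i.
VectorDims logical_strides(const VectorDims& order, const VectorDims& phys_strides) {
    VectorDims strides(order.size());
    for (size_t i = 0; i < order.size(); i++)
        strides[order[i]] = phys_strides[i];
    return strides;
}

int main() {
    // A logical [B=2, H=3, L=4, S=5] tensor stored physically as [L, B, H, S],
    // i.e. order = {2, 0, 1, 3}; dense physical strides are {30, 15, 5, 1}.
    VectorDims order{2, 0, 1, 3};
    VectorDims phys_strides{30, 15, 5, 1};
    VectorDims strides = logical_strides(order, phys_strides);
    // Logical [B, H, L, S] strides: {15, 5, 30, 1} -- the same view the
    // removed resize-then-permute({1, 2, 0, 3}) pairs used to build by hand.
    assert((strides == VectorDims{15, 5, 30, 1}));
    return 0;
}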
33 changes: 22 additions & 11 deletions src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp
@@ -28,10 +28,10 @@
 using namespace ov::intel_cpu;
 
 TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
-    auto build_graph = [](const ov::Shape& shape, float qkv_val, float past_kv_val) {
-        auto qkv = ov::op::v0::Constant::create(ov::element::f32, shape, {qkv_val});
+    auto build_graph = [](const ov::Shape& shape, float* qkv_val, float* past_kv_val) {
+        auto qkv = ov::op::v0::Constant::create(ov::element::f32, shape, qkv_val);
         qkv->set_friendly_name("qkv_const");
-        auto pastkv = ov::op::v0::Constant::create(ov::element::f32, shape, {past_kv_val});
+        auto pastkv = ov::op::v0::Constant::create(ov::element::f32, shape, past_kv_val);
         pastkv->set_friendly_name("pastkv_const");
         // only need a dynamic parameter but its value will not be used
         auto attn = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1});
@@ -101,7 +101,7 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
         graph.Infer();
     };
 
-    auto check_graph = [] (Graph& graph, std::map<std::string, std::pair<float, ov::Shape>>& expected) {
+    auto check_graph = [] (Graph& graph, std::map<std::string, std::pair<float*, ov::Shape>>& expected) {
         auto& outputNodesMap = graph.GetOutputNodesMap();
         auto is_same = [] (float a, float b) {
             return std::abs(a - b) < 0.0001f;
@@ -116,7 +116,9 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
             const auto& memory = parentEdge->getMemoryPtr();
             auto size = memory->getSize() / sizeof(float);
             auto p = reinterpret_cast<float*>(memory->getData());
-            ASSERT_EQ(std::all_of(p, p + size, [&](float v) { return is_same(v, expected.at(name).first); }), true);
+            for (size_t i = 0; i < size; i++) {
+                ASSERT_EQ(is_same(p[i], expected.at(name).first[i]), true);
+            }
             ASSERT_EQ(memory->getShape(), ov::intel_cpu::Shape(expected.at(name).second));
         }
     };
};
Expand All @@ -133,16 +135,25 @@ TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
return (*itr);
};

float qkv_val = 3.0f, past_kv_val = 3.0f;
ov::Shape shape{2, 2, 8, 8};
auto graph = build_graph(shape, qkv_val, past_kv_val);
auto strided_iota = [] (float* first, size_t n, float value, float stride) {
for (size_t i = 0; i < n; i++) {
*first++ = value;
value += stride;
}
};

ov::Shape shape{1, 1, 8, 8};
const size_t elements_count = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
std::vector<float> val(elements_count * 2);
strided_iota(val.data(), val.size(), -10.0f, 0.1f);
auto graph = build_graph(shape, val.data() + elements_count, val.data());
run_graph(graph);
// if no inplace, the pastk and pastv will concat, check shape and value
ov::Shape expectedShape(shape);
expectedShape[2] *= 2;
std::map<std::string, std::pair<float, ov::Shape>> expected{
{"pastk", std::make_pair(past_kv_val, expectedShape)},
{"pastv", std::make_pair(past_kv_val, expectedShape)}};
std::map<std::string, std::pair<float*, ov::Shape>> expected{
{"pastk", std::make_pair(val.data(), expectedShape)},
{"pastv", std::make_pair(val.data(), expectedShape)}};
check_graph(graph, expected);
auto spd = find_node_type(graph, Type::ScaledDotProductAttention)->getSelectedPrimitiveDescriptor();
ASSERT_EQ(spd->getConfig().outConfs[1].inPlace(), -1);
Expand Down
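Why the test now fills the qkv and pastkv constants with strided_iota data instead of a single constant: when every element is equal, an element-wise compare cannot tell a correct concat from one that copies through wrong strides or a transposed layout, and that is also why check_graph switched from std::all_of against one value to a per-element loop over a float* of expected values. A small self-contained illustration of the failure mode a constant fill would hide (a toy 2x3 copy, not the test's tensors):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const size_t rows = 2, cols = 3;

    // strided_iota-style fill: every element distinct.
    std::vector<float> src(rows * cols);
    float v = -10.0f;
    for (auto& x : src) { x = v; v += 0.1f; }

    // A buggy "copy" that swaps the two axes (i.e. wrong strides).
    std::vector<float> dst(rows * cols);
    for (size_t r = 0; r < rows; r++)
        for (size_t c = 0; c < cols; c++)
            dst[c * rows + r] = src[r * cols + c];

    // With a constant fill src == dst would still hold and hide the bug;
    // with distinct values the element-wise check catches it.
    assert(!std::equal(src.begin(), src.end(), dst.begin()));
    return 0;
}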
