[CPU] Optimize ScaledDotProductAttention performance (openvinotoolkit…
luo-cheng2021 authored Dec 6, 2023
1 parent 55d7765 commit 0a7d1d7
Showing 29 changed files with 1,291 additions and 1,174 deletions.
4 changes: 2 additions & 2 deletions cmake/developer_package/compile_flags/os_flags.cmake
@@ -125,7 +125,7 @@ macro(ov_avx2_optimization_flags flags)
set(${flags} -xCORE-AVX2)
endif()
elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX)
- set(${flags} -mavx2 -mfma)
+ set(${flags} -mavx2 -mfma -mf16c)
else()
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
endif()
@@ -147,7 +147,7 @@ macro(ov_avx512_optimization_flags flags)
set(${flags} -xCOMMON-AVX512)
endif()
elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX)
- set(${flags} -mavx512f -mfma)
+ set(${flags} -mavx512f -mfma -mf16c)
else()
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
endif()
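The -mf16c flag added above enables the F16C instruction set, which provides the half-precision conversion intrinsics (_mm256_cvtph_ps / _mm256_cvtps_ph) that the new ov::float16 load/store helpers further down in common.hpp rely on. A minimal standalone sketch of that conversion round trip, illustrative only and not part of this commit, assuming a compiler invoked with -mf16c:

    // f16c_roundtrip.cpp : illustrative sketch; build with e.g. "g++ -mf16c f16c_roundtrip.cpp"
    #include <immintrin.h>
    #include <cstdint>
    #include <cstdio>

    int main() {
        float src[8] = {0.5f, 1.0f, 1.5f, 2.0f, 2.5f, 3.0f, 3.5f, 4.0f};
        uint16_t half[8];  // packed IEEE fp16 bit patterns
        float dst[8];

        __m256 v = _mm256_loadu_ps(src);                                // 8 x f32
        __m128i h = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT |      // f32 -> f16
                                        _MM_FROUND_NO_EXC);
        _mm_storeu_si128(reinterpret_cast<__m128i*>(half), h);

        __m256 back = _mm256_cvtph_ps(                                  // f16 -> f32
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(half)));
        _mm256_storeu_ps(dst, back);

        for (int i = 0; i < 8; i++)
            std::printf("%g -> 0x%04x -> %g\n", src[i], static_cast<unsigned>(half[i]), dst[i]);
        return 0;
    }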
19 changes: 6 additions & 13 deletions src/plugins/intel_cpu/CMakeLists.txt
@@ -161,23 +161,16 @@ cross_compiled_file(${TARGET_NAME}
)
cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 ANY
- src/nodes/kernels/scaled_attn/dot_product.cpp
- API src/nodes/kernels/scaled_attn/dot_product.hpp
- NAME attn_dot_products
+ src/nodes/kernels/scaled_attn/mha_single_token.cpp
+ API src/nodes/kernels/scaled_attn/mha_single_token.hpp
+ NAME mha_single_token
NAMESPACE InferenceEngine::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 ANY
- src/nodes/kernels/scaled_attn/acc_value.cpp
- API src/nodes/kernels/scaled_attn/acc_value.hpp
- NAME attn_acc_values
- NAMESPACE InferenceEngine::Extensions::Cpu::XARCH
- )
- cross_compiled_file(${TARGET_NAME}
- ARCH AVX512F AVX2 ANY
- src/nodes/kernels/scaled_attn/reduce.cpp
- API src/nodes/kernels/scaled_attn/reduce.hpp
- NAME attn_reduce
+ src/nodes/kernels/scaled_attn/attn_memcpy.cpp
+ API src/nodes/kernels/scaled_attn/attn_memcpy.hpp
+ NAME attn_memcpy
NAMESPACE InferenceEngine::Extensions::Cpu::XARCH
)
# system dependencies must go last
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/cpu_types.cpp
@@ -215,7 +215,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{ "Unique", Type::Unique},
{ "Ngram", Type::Ngram},
{ "ScaledDotProductAttention", Type::ScaledDotProductAttention},
{ "ScaledDotProductAttentionStub", Type::ScaledDotProductAttention},
{ "ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
{ "RoPE", Type::RoPE},
};
return type_to_name_tbl;
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/extension.cpp
@@ -6,7 +6,7 @@
#include "transformations/cpu_opset/common/op/fully_connected.hpp"
#include "transformations/cpu_opset/common/op/leaky_relu.hpp"
#include "transformations/cpu_opset/common/op/power_static.hpp"
#include "transformations/cpu_opset/common/op/sdp.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/cpu_opset/common/op/ngram.hpp"
#include "transformations/cpu_opset/x64/op/mha.hpp"
@@ -61,7 +61,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(NgramNode, ov::intel_cpu)
NGRAPH_OP_X64(MHANode, ov::intel_cpu)
NGRAPH_OP_X64(InteractionNode, ov::intel_cpu)
- NGRAPH_OP_X64(ScaledDotProductAttentionStub, ov::intel_cpu)
+ NGRAPH_OP_X64(ScaledDotProductAttentionWithKVCache, ov::intel_cpu)
#undef NGRAPH_OP

return opset;
71 changes: 0 additions & 71 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/acc_value.cpp

This file was deleted.

100 changes: 100 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp
@@ -0,0 +1,100 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <float.h>

#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <type_traits>

#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
# include <immintrin.h>
#endif

#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/parallel.hpp"
#include "common.hpp"
#include "attn_memcpy.hpp"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
namespace XARCH {

using namespace ov;

// float16 <- float
template<typename TA, typename TB>
void attn_copy(TA* a, TB* b, size_t n) {
size_t i = 0;
#if defined(HAVE_AVX512F)
for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) {
auto vb = mm512_uni_loadu_ps(b + i);
mm512_uni_storeu_ps(a + i, vb);
}
#elif defined(HAVE_AVX2)
for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) {
auto vb = mm256_uni_loadu_ps(b + i);
mm256_uni_storeu_ps(a + i, vb);
}
#endif
for (; i < n; i++) {
a[i] = b[i];
}
}

template <typename T, typename T2>
void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output) {
size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
attn_copy(&past_k_output.at<T2>({b, h, m, 0}),
&k_input.at<T>({b, h, m, 0}),
S);
attn_copy(&past_v_output.at<T2>({b, h, m, 0}),
&v_input.at<T>({b, h, m, 0}),
S);
});
}

template <typename T>
void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output) {
size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
memcpy(&past_k_output.at<T>({b, h, m, 0}),
&k_input.at<T>({b, h, m, 0}),
S * sizeof(T));
memcpy(&past_v_output.at<T>({b, h, m, 0}),
&v_input.at<T>({b, h, m, 0}),
S * sizeof(T));
});
}

void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output) {
if (past_k_output.get_precision() == k_input.get_precision()) {
if (past_k_output.get_precision() == ov::element::bf16) {
attn_memcpy_kernel<ov::bfloat16>(k_input, v_input, past_k_output, past_v_output);
} else {
assert(past_k_output.get_precision() == ov::element::f16);
attn_memcpy_kernel<ov::float16>(k_input, v_input, past_k_output, past_v_output);
}
} else if (past_k_output.get_precision() == ov::element::f16) {
attn_memcpy_kernel<float, ov::float16>(k_input, v_input, past_k_output, past_v_output);
} else {
attn_memcpy_kernel<float, float>(k_input, v_input, past_k_output, past_v_output);
}
}
} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
} // namespace InferenceEngine
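In plainer terms, attn_memcpy copies the freshly computed K/V rows for the current tokens into the past_k/past_v cache tensors, converting on the fly when the cache is kept in half precision and copying bit-for-bit when the precisions match. A simplified scalar reference, illustrative only and not plugin code; it assumes all four tensors are densely packed with the same [B, H, L1, S] layout, whereas the real kernel indexes through PlainTensor and its strides:

    #include <cstddef>

    // Scalar sketch of the copy performed above. Dst may be a half-precision type
    // (e.g. ov::float16), Src is typically float; the cast models the down-conversion.
    template <typename Dst, typename Src>
    void attn_memcpy_ref(Dst* past_k, Dst* past_v, const Src* k, const Src* v,
                         std::size_t B, std::size_t H, std::size_t L1, std::size_t S) {
        for (std::size_t b = 0; b < B; b++)
            for (std::size_t h = 0; h < H; h++)
                for (std::size_t m = 0; m < L1; m++) {
                    const std::size_t row = ((b * H + h) * L1 + m) * S;  // start of one head's row
                    for (std::size_t s = 0; s < S; s++) {
                        past_k[row + s] = static_cast<Dst>(k[row + s]);
                        past_v[row + s] = static_cast<Dst>(v[row + s]);
                    }
                }
    }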
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp
@@ -1,19 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>
#include <openvino/core/type/element_type.hpp>
#include "utils/plain_tensor.hpp"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
namespace XARCH {

- void attn_acc_values(float** outs, float* weights, void** vs, size_t vec_num, size_t vec_len, ov::element::Type input_precision);
+ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
+                  const ov::intel_cpu::PlainTensor& v_input,
+                  const ov::intel_cpu::PlainTensor& past_k_output,
+                  const ov::intel_cpu::PlainTensor& past_v_output);

} // namespace XARCH
} // namespace Cpu
23 changes: 23 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
@@ -1,12 +1,16 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
@@ -55,6 +59,15 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16
_mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x);
}

inline __m512 mm512_uni_loadu_ps(ov::float16* a) {
auto vec_f16 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a));
return _mm512_cvtph_ps(vec_f16);
}
inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) {
__m256i vec_f16 = _mm512_cvtps_ph(v, 0);
_mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16);
}
#endif

#ifdef HAVE_AVX2
@@ -87,6 +100,16 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
_mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o);
}

inline __m256 mm256_uni_loadu_ps(ov::float16* a) {
auto vec_f16 = _mm_loadu_si128(reinterpret_cast<__m128i*>(a));
auto o = _mm256_cvtph_ps(vec_f16);
return o;
}
inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) {
__m128i vec_f16 = _mm256_cvtps_ph(v, 0);
_mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16);
}

inline void hsum(__m256& x) {
__m256 y; // x: 0 1 2 3 4 5 6 7
y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4
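The new ov::float16 overloads follow the same mm256_uni_*/mm512_uni_* naming as the existing float and bfloat16 helpers, so a single templated loop can serve all three storage types. A hypothetical usage sketch (scale_inplace is not plugin code), assuming it is compiled inside the plugin with common.hpp on the include path and HAVE_AVX2/HAVE_AVX512F set by the build:

    #include <cstddef>
    #include "common.hpp"  // the header shown above

    using namespace InferenceEngine::Extensions::Cpu::XARCH;

    // Hypothetical kernel: scales n values in place; T may be float, ov::bfloat16 or
    // ov::float16, and overload resolution picks the matching uni load/store helpers.
    template <typename T>
    void scale_inplace(T* data, float alpha, std::size_t n) {
        std::size_t i = 0;
    #if defined(HAVE_AVX512F)
        auto va = _mm512_set1_ps(alpha);
        for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512)
            mm512_uni_storeu_ps(data + i, _mm512_mul_ps(mm512_uni_loadu_ps(data + i), va));
    #elif defined(HAVE_AVX2)
        auto va = _mm256_set1_ps(alpha);
        for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2)
            mm256_uni_storeu_ps(data + i, _mm256_mul_ps(mm256_uni_loadu_ps(data + i), va));
    #endif
        for (; i < n; i++)  // scalar tail
            data[i] = static_cast<float>(data[i]) * alpha;
    }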