Skip to content

Commit

Permalink
choose fp16 for kvcaches in SDPA node.
Browse files Browse the repository at this point in the history
OV_ENABLE_SDPA_KVCACHE_FP16 (default false)
  • Loading branch information
ceciliapeng2011 committed Nov 27, 2023
1 parent ad723b5 commit a81f36b
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 51 deletions.
4 changes: 2 additions & 2 deletions cmake/developer_package/compile_flags/os_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ macro(ov_avx2_optimization_flags flags)
set(${flags} -xCORE-AVX2)
endif()
elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX)
set(${flags} -mavx2 -mfma)
set(${flags} -mavx2 -mfma -mf16c)
else()
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
endif()
Expand All @@ -147,7 +147,7 @@ macro(ov_avx512_optimization_flags flags)
set(${flags} -xCOMMON-AVX512)
endif()
elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX)
set(${flags} -mavx512f -mfma)
set(${flags} -mavx512f -mfma -mf16c)
else()
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
endif()
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ cross_compiled_file(${TARGET_NAME}
NAME attn_reduce
NAMESPACE InferenceEngine::Extensions::Cpu::XARCH
)

# system dependencies must go last
target_link_libraries(${TARGET_NAME} PRIVATE openvino::pugixml)
ov_set_threading_interface_for(${TARGET_NAME})
Expand Down
15 changes: 15 additions & 0 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,21 @@ void Graph::InitDescriptors() {
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.selectOptimalPrimitiveDescriptor);
DEBUG_LOG("Select optimal primitive descriptors for node: ", node->getName());
node->selectOptimalPrimitiveDescriptor();
#ifdef CPU_DEBUG_CAPS
const auto& SPDs = node->getSupportedPrimitiveDescriptors();
for (size_t i = 0; i < SPDs.size(); i++) {
DEBUG_LOG("#",
node->getExecIndex(),
" ",
node->getName(),
" SupportedPrimitiveDescriptors [",
i,
"/",
SPDs.size(),
"]: \n",
SPDs[i]);
}
#endif
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,19 @@ void attn_acc_values(float** outs, float* weights, void** vs, size_t vec_num, si
auto v_ptr = static_cast<float*>(vs[i]);
attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len);
}
} else {
assert(input_precision == ov::element::bf16);
} else if (input_precision == ov::element::bf16) {
for (size_t i = 0; i < vec_num; i++) {
auto out_ptr = outs[i];
auto v_ptr = static_cast<ov::bfloat16*>(vs[i]);
attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len);
}
} else {
assert(input_precision == ov::element::f16);
for (size_t i = 0; i < vec_num; i++) {
auto out_ptr = outs[i];
auto v_ptr = static_cast<ov::float16*>(vs[i]);
attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len);
}
}
}

Expand Down
22 changes: 22 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include <cstdint>
#include <vector>

#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
Expand Down Expand Up @@ -55,6 +58,15 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16
_mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x);
}

// Load 16 consecutive fp16 values from `a` and widen them to a __m512 of fp32.
inline __m512 mm512_uni_loadu_ps(ov::float16* a) {
    return _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<__m256i*>(a)));
}
// Narrow a __m512 of fp32 to 16 fp16 values and store them at `addr`.
// Rounding is spelled out as _MM_FROUND_TO_NEAREST_INT (round-to-nearest-even,
// the IEEE-754 default); this named constant has the same encoding (0) as the
// previous magic literal, so the generated code is unchanged.
inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) {
    __m256i vec_f16 = _mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
    _mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16);
}
#endif

#ifdef HAVE_AVX2
Expand Down Expand Up @@ -87,6 +99,16 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
_mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o);
}

// Load 8 consecutive fp16 values from `a` and widen them to a __m256 of fp32
// (F16C vcvtph2ps).
inline __m256 mm256_uni_loadu_ps(ov::float16* a) {
    return _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i*>(a)));
}
// Narrow a __m256 of fp32 to 8 fp16 values and store them at `a` (F16C
// vcvtps2ph). Resolves the former "FIXME: rounding": the mode is now
// explicitly _MM_FROUND_TO_NEAREST_INT (round-to-nearest-even, the IEEE-754
// default), which encodes to the same immediate (0) the literal used before,
// so results are bit-identical.
inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) {
    __m128i vec_f16 = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
    _mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16);
}

inline void hsum(__m256& x) {
__m256 y; // x: 0 1 2 3 4 5 6 7
y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ namespace Extensions {
namespace Cpu {
namespace XARCH {

template<typename T>
float dot_product_inner(T* a, T* b, size_t n) {
template<typename TA, typename TB>
float dot_product_inner(TA* a, TB* b, size_t n) {
size_t i = 0;
float sum = 0.0f;
#if defined(HAVE_AVX512F)
Expand All @@ -51,15 +51,28 @@ float dot_product_inner(T* a, T* b, size_t n) {
return sum;
}

void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision) {
if (input_precision == ov::element::f32) {
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<float*>(a[i]);
auto b_ptr = static_cast<float*>(b[i]);
auto c_ptr = static_cast<float*>(c[i]);
c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len);
void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len,
ov::element::Type a_precision, ov::element::Type b_precision) {
if (a_precision == ov::element::f32) {
if (b_precision == ov::element::f32) {
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<float*>(a[i]);
auto b_ptr = static_cast<float*>(b[i]);
auto c_ptr = static_cast<float*>(c[i]);
c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len);
}
} else {
assert(b_precision == ov::element::f16);
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<float*>(a[i]);
auto b_ptr = static_cast<ov::float16*>(b[i]);
auto c_ptr = static_cast<float*>(c[i]);
c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len);
}
}
} else {
assert(a_precision == ov::element::bf16);
assert(b_precision == ov::element::bf16);
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<ov::bfloat16*>(a[i]);
auto b_ptr = static_cast<ov::bfloat16*>(b[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Extensions {
namespace Cpu {
namespace XARCH {

void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision);
void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type precision_a, ov::element::Type precision_b);

} // namespace XARCH
} // namespace Cpu
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,4 +611,4 @@ MemStatePtr MemoryInputSDPA::makeState() const {

} // namespace node
} // namespace intel_cpu
} // namespace ov
} // namespace ov
Loading

0 comments on commit a81f36b

Please sign in to comment.