Skip to content

Commit

Permalink
choose fp16 for kvcaches in SDPA node.
Browse files Browse the repository at this point in the history
OV_ENABLE_SDPA_KVCACHE_FP16 (default false)

enable avx512 cpu_convert

OV_ENABLE_SDPA_KVCACHE_FP32 to disable fp16 kvcache

review fix

apply fp16 kvcache conditionally.
  • Loading branch information
ceciliapeng2011 committed Dec 1, 2023
1 parent 504b6e9 commit 15fb01c
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 66 deletions.
4 changes: 2 additions & 2 deletions cmake/developer_package/compile_flags/os_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ macro(ov_avx2_optimization_flags flags)
set(${flags} -xCORE-AVX2)
endif()
elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX)
set(${flags} -mavx2 -mfma)
set(${flags} -mavx2 -mfma -mf16c)
else()
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
endif()
Expand All @@ -147,7 +147,7 @@ macro(ov_avx512_optimization_flags flags)
set(${flags} -xCOMMON-AVX512)
endif()
elseif(OV_COMPILER_IS_CLANG OR CMAKE_COMPILER_IS_GNUCXX)
set(${flags} -mavx512f -mfma)
set(${flags} -mavx512f -mfma -mf16c)
else()
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
endif()
Expand Down
53 changes: 42 additions & 11 deletions src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ using namespace dnnl::impl::utils;
using namespace dnnl::impl::cpu::x64;
using namespace Xbyak;

template <typename src_t, typename dst_t>

template <typename src_t, typename dst_t, int isa>
void convert_vec(jit_generator & gen,
const RegExp & src,
const RegExp & dst);

template <>
void convert_vec<ov::float16, float>(jit_generator & gen,
void convert_vec<ov::float16, float, cpu_isa_t::avx2>(jit_generator & gen,
const RegExp & src,
const RegExp & dst) {
auto const & f16vec = gen.xmm3;
Expand All @@ -50,7 +51,7 @@ void convert_vec<ov::float16, float>(jit_generator & gen,
}

template <>
void convert_vec<float, ov::float16>(jit_generator & gen,
void convert_vec<float, ov::float16, cpu_isa_t::avx2>(jit_generator & gen,
const RegExp & src,
const RegExp & dst) {
auto const & f16vec = gen.xmm3;
Expand All @@ -61,12 +62,36 @@ void convert_vec<float, ov::float16>(jit_generator & gen,
gen.movdqu(gen.xword[dst], f16vec);
}

// Emits AVX-512 code converting 16 fp16 values loaded from [src] (one ymm's
// worth, 32 bytes) into 16 fp32 values stored to [dst] (one zmm, 64 bytes).
// Clobbers ymm3/zmm4.
// NOTE(review): movdqu is the legacy SSE mnemonic; with a ymm operand the AVX
// encoding (vmovdqu) is presumably what is intended — confirm Xbyak accepts
// this register/mnemonic combination.
template <>
void convert_vec<ov::float16, float, cpu_isa_t::avx512_core>(jit_generator & gen,
                                                             const RegExp & src,
                                                             const RegExp & dst) {
    auto const & f16vec = gen.ymm3;  // 16 x fp16 input lane
    auto const & f32vec = gen.zmm4;  // 16 x fp32 output lane

    gen.movdqu(f16vec, gen.yword[src]);   // load 32 bytes of packed fp16
    gen.vcvtph2ps(f32vec, f16vec);        // widen fp16 -> fp32 (exact, no rounding needed)
    gen.vmovups(gen.zword[dst], f32vec);  // store 64 bytes of fp32
}

// Emits AVX-512 code converting 16 fp32 values loaded from [src] (one zmm,
// 64 bytes) into 16 fp16 values stored to [dst] (one ymm's worth, 32 bytes).
// Clobbers ymm3/zmm4. The immediate 0 passed to vcvtps2ph selects
// round-to-nearest-even.
// NOTE(review): movdqu is the legacy SSE mnemonic; with a ymm operand the AVX
// encoding (vmovdqu) is presumably what is intended — confirm Xbyak accepts
// this register/mnemonic combination.
template <>
void convert_vec<float, ov::float16, cpu_isa_t::avx512_core>(jit_generator & gen,
                                                             const RegExp & src,
                                                             const RegExp & dst) {
    auto const & f16vec = gen.ymm3;  // 16 x fp16 output lane
    auto const & f32vec = gen.zmm4;  // 16 x fp32 input lane

    gen.vmovups(f32vec, gen.zword[src]);  // load 64 bytes of fp32
    gen.vcvtps2ph(f16vec, f32vec, 0);     // narrow fp32 -> fp16, round-to-nearest-even
    gen.movdqu(gen.yword[dst], f16vec);   // store 32 bytes of packed fp16
}

class jit_convert_array : public jit_kernel {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_convert_array)

void generate() override {
constexpr size_t vlen = 8u;
constexpr size_t vlen_log2 = 3;
const size_t vlen = mayiuse(cpu_isa_t::avx512_core) ? 16u : 8u;
const size_t vlen_log2 = mayiuse(cpu_isa_t::avx512_core) ? 4 : 3;

preamble();

Expand Down Expand Up @@ -131,12 +156,18 @@ class jit_convert_array : public jit_kernel {

template<typename src_t, typename dst_t>
static fn_t get() {
if (mayiuse(cpu_isa_t::avx2)
&& dnnl::impl::cpu::x64::cpu().has(Xbyak::util::Cpu::tF16C)) {
static jit_convert_array converter(convert_vec<src_t, dst_t>, sizeof(src_t), sizeof(dst_t));
auto & generator = static_cast<jit_generator&>(converter);
generator.create_kernel();
return (fn_t)generator.jit_ker();
if (dnnl::impl::cpu::x64::cpu().has(Xbyak::util::Cpu::tF16C)) {
if (mayiuse(cpu_isa_t::avx512_core)) {
static jit_convert_array converter(convert_vec<src_t, dst_t, cpu_isa_t::avx512_core>, sizeof(src_t), sizeof(dst_t));
auto & generator = static_cast<jit_generator&>(converter);
generator.create_kernel();
return (fn_t)generator.jit_ker();
} else if (mayiuse(cpu_isa_t::avx2)) {
static jit_convert_array converter(convert_vec<src_t, dst_t, cpu_isa_t::avx2>, sizeof(src_t), sizeof(dst_t));
auto & generator = static_cast<jit_generator&>(converter);
generator.create_kernel();
return (fn_t)generator.jit_ker();
}
}
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,19 @@ void attn_acc_values(float** outs, float* weights, void** vs, size_t vec_num, si
auto v_ptr = static_cast<float*>(vs[i]);
attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len);
}
} else {
assert(input_precision == ov::element::bf16);
} else if (input_precision == ov::element::bf16) {
for (size_t i = 0; i < vec_num; i++) {
auto out_ptr = outs[i];
auto v_ptr = static_cast<ov::bfloat16*>(vs[i]);
attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len);
}
} else {
assert(input_precision == ov::element::f16);
for (size_t i = 0; i < vec_num; i++) {
auto out_ptr = outs[i];
auto v_ptr = static_cast<ov::float16*>(vs[i]);
attn_acc_value_inner(out_ptr, weights[i], v_ptr, vec_len);
}
}
}

Expand Down
22 changes: 22 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include <cstdint>
#include <vector>

#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"

namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
Expand Down Expand Up @@ -55,6 +58,15 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16
_mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x);
}

// Loads 16 consecutive fp16 values from `a` and widens them to an fp32 zmm.
// The widening conversion is exact (every fp16 value is representable in fp32).
inline __m512 mm512_uni_loadu_ps(ov::float16* a) {
    auto vec_f16 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a));
    return _mm512_cvtph_ps(vec_f16);
}
// Narrows 16 fp32 values to fp16 and stores them to `addr`.
// _MM_FROUND_TO_NEAREST_INT (encodes as 0, same as the previous bare literal)
// selects round-to-nearest-even, the IEEE 754 default rounding mode.
inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) {
    __m256i vec_f16 = _mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
    _mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16);
}
#endif

#ifdef HAVE_AVX2
Expand Down Expand Up @@ -87,6 +99,16 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
_mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o);
}

// Loads 8 consecutive fp16 values from `a` and widens them to an fp32 ymm.
// The widening conversion is exact (every fp16 value is representable in fp32).
inline __m256 mm256_uni_loadu_ps(ov::float16* a) {
    auto vec_f16 = _mm_loadu_si128(reinterpret_cast<__m128i*>(a));
    auto o = _mm256_cvtph_ps(vec_f16);
    return o;
}
// Narrows 8 fp32 values to fp16 and stores them to `a`.
// Resolves the former "FIXME: rounding": the immediate _MM_FROUND_TO_NEAREST_INT
// (encodes as 0, same as the previous bare literal) explicitly selects
// round-to-nearest-even, the IEEE 754 default rounding mode.
inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) {
    __m128i vec_f16 = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
    _mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16);
}

inline void hsum(__m256& x) {
__m256 y; // x: 0 1 2 3 4 5 6 7
y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ namespace Extensions {
namespace Cpu {
namespace XARCH {

template<typename T>
float dot_product_inner(T* a, T* b, size_t n) {
template<typename TA, typename TB>
float dot_product_inner(TA* a, TB* b, size_t n) {
size_t i = 0;
float sum = 0.0f;
#if defined(HAVE_AVX512F)
Expand All @@ -51,15 +51,28 @@ float dot_product_inner(T* a, T* b, size_t n) {
return sum;
}

void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision) {
if (input_precision == ov::element::f32) {
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<float*>(a[i]);
auto b_ptr = static_cast<float*>(b[i]);
auto c_ptr = static_cast<float*>(c[i]);
c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len);
void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len,
ov::element::Type a_precision, ov::element::Type b_precision) {
if (a_precision == ov::element::f32) {
if (b_precision == ov::element::f32) {
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<float*>(a[i]);
auto b_ptr = static_cast<float*>(b[i]);
auto c_ptr = static_cast<float*>(c[i]);
c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len);
}
} else {
assert(b_precision == ov::element::f16);
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<float*>(a[i]);
auto b_ptr = static_cast<ov::float16*>(b[i]);
auto c_ptr = static_cast<float*>(c[i]);
c_ptr[0] = dot_product_inner(a_ptr, b_ptr, vec_len);
}
}
} else {
assert(a_precision == ov::element::bf16);
assert(b_precision == ov::element::bf16);
for (size_t i = 0; i < vec_num; i++) {
auto a_ptr = static_cast<ov::bfloat16*>(a[i]);
auto b_ptr = static_cast<ov::bfloat16*>(b[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Extensions {
namespace Cpu {
namespace XARCH {

void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type input_precision);
void attn_dot_products(void** a, void** b, void**c, size_t vec_num, size_t vec_len, ov::element::Type precision_a, ov::element::Type precision_b);

} // namespace XARCH
} // namespace Cpu
Expand Down
Loading

0 comments on commit 15fb01c

Please sign in to comment.