Commit

fix
xinhaoc committed Oct 24, 2023
1 parent f572737 commit 0aec1b6
Showing 2 changed files with 133 additions and 219 deletions.
188 changes: 80 additions & 108 deletions include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh
@@ -3,157 +3,82 @@

 namespace FlexFlow {

-template <typename T, int Dh>
-struct Qk_vec_k_ {};
+////////////////basic datatype//////////////////////
+struct half4 {
+  half2 a, b;
+};
+
+////////////////data type///////////////
+template <typename DT, int VECPSIZE>
+struct VEC_K {};
 template <>
-struct Qk_vec_k_<float, 32> {
+struct VEC_K<float, 1> {
   using Type = float;
 };
 template <>
-struct Qk_vec_k_<float, 64> {
+struct VEC_K<float, 2> {
   using Type = float2;
 };
 template <>
-struct Qk_vec_k_<float, 128> {
-  using Type = float4;
-};
-template <>
-struct Qk_vec_k_<float, 256> {
+struct VEC_K<float, 4> {
   using Type = float4;
 };

 template <>
-struct Qk_vec_k_<half, 32> {
-  using Type = half;
-};
-template <>
-struct Qk_vec_k_<half, 64> {
-  using Type = half;
-};
-template <>
-struct Qk_vec_k_<half, 128> {
+struct VEC_K<half, 1> {
   using Type = half;
 };
 template <>
-struct Qk_vec_k_<half, 256> {
-  using Type = half;
-};
-
-template <typename T, int THREADS_PER_KEY>
-struct K_vec_k_ {};
-
-template <>
-struct K_vec_k_<float, 4> {
-  using Type = float4;
-};
-template <>
-struct K_vec_k_<float, 2> {
-  using Type = float2;
-};
-template <>
-struct K_vec_k_<float, 1> {
-  using Type = float4;
+struct VEC_K<half, 2> {
+  using Type = half2;
 };

 template <>
-struct K_vec_k_<half, 4> {
-  using Type = float;
-};
-template <>
-struct K_vec_k_<half, 2> {
-  using Type = float2;
-};
-template <>
-struct K_vec_k_<half, 1> {
-  using Type = float4;
+struct VEC_K<half, 4> {
+  using Type = half4;
 };

+// data type for QK production
 template <typename T>
-struct K_vec_acum_fp32_ {};
+struct Vec_fp32_ {};

 template <>
-struct K_vec_acum_fp32_<float> {
+struct Vec_fp32_<float> {
   using Type = float;
 };
 template <>
-struct K_vec_acum_fp32_<float2> {
+struct Vec_fp32_<float2> {
   using Type = float2;
 };
 template <>
-struct K_vec_acum_fp32_<float4> {
+struct Vec_fp32_<float4> {
   using Type = float4;
 };

 template <>
-struct K_vec_acum_fp32_<half> {
-  using Type = half;
-};
-
-template <typename T, int V_VEC_SIZE>
-struct V_vec_k_ {};
-
-template <>
-struct V_vec_k_<float, 1> {
+struct Vec_fp32_<half> {
   using Type = float;
 };
 template <>
-struct V_vec_k_<float, 2> {
+struct Vec_fp32_<half2> {
   using Type = float2;
 };
 template <>
-struct V_vec_k_<float, 4> {
+struct Vec_fp32_<half4> {
   using Type = float4;
 };

-template <>
-struct V_vec_k_<half, 1> {
-  using Type = half;
-};
-template <>
-struct V_vec_k_<half, 2> {
-  using Type = half;
-};
-template <>
-struct V_vec_k_<half, 4> {
-  using Type = half;
-};
-
-template <typename T>
-struct V_vec_acum_fp32_ {};
-
-template <>
-struct V_vec_acum_fp32_<float> {
-  using Type = float;
-};
-template <>
-struct V_vec_acum_fp32_<float2> {
-  using Type = float2;
-};
-template <>
-struct V_vec_acum_fp32_<float4> {
-  using Type = float4;
-};
-////////////////data structures half///////////////
-
-template <>
-struct V_vec_acum_fp32_<half> {
-  using Type = float;
-};
 ////////////////////////////////////floating point
 /// operations///////////////////////////////////////////

 template <typename Acc, typename A, typename B>
 inline __device__ Acc mul(A a, B b) {
   return Acc{}; // for compile
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <>
 inline __device__ float mul<float, float>(float a, float b) {
   return a * b;
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <>
 inline __device__ float2 mul(float2 a, float2 b) {
   float2 c;
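The hunk above collapses the old head-size-keyed traits (Qk_vec_k_, K_vec_k_) into a single VEC_K trait keyed by element type and per-thread vector width, and merges the separate accumulator traits into one Vec_fp32_ trait that names the matching fp32 vector. A minimal usage sketch under these definitions (the function and variable names are illustrative, not part of the commit):

// Hypothetical example: one thread loads four half elements from the K cache
// as a single half4, widens them to float4 with cast_to_float (added further
// down in this file), and accumulates in fp32. Assumes k_cache + base is
// 8-byte aligned so the vectorized load is legal.
__device__ float sum4_fp32(half const *k_cache, int base) {
  using K_vec = typename VEC_K<half, 4>::Type;        // -> half4 (this file)
  using K_vec_fp32 = typename Vec_fp32_<K_vec>::Type; // -> float4
  K_vec k = *reinterpret_cast<K_vec const *>(k_cache + base);
  K_vec_fp32 kf = cast_to_float(k);                   // lane-wise __half2float
  return kf.x + kf.y + kf.z + kf.w;                   // fp32 accumulation
}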
@@ -297,6 +222,56 @@ inline __device__ float4 cast_to_float(float4 u) {
   return u;
 }

+inline __device__ float cast_to_float(half u) {
+  return __half2float(u);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 cast_to_float(half2 u) {
+  float2 tmp;
+  tmp.x = __half2float(u.x);
+  tmp.y = __half2float(u.y);
+  return tmp;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 cast_to_float(half4 u) {
+  float4 tmp;
+  tmp.x = __half2float(u.a.x);
+  tmp.y = __half2float(u.a.y);
+  tmp.z = __half2float(u.b.x);
+  tmp.w = __half2float(u.b.y);
+  return tmp;
+}
+
+inline __device__ void convert_from_float(float4 &dst, float4 src) {
+  dst = src;
+}
+inline __device__ void convert_from_float(float &dst, float src) {
+  dst = src;
+}
+inline __device__ void convert_from_float(float2 &dst, float2 src) {
+  dst = src;
+}
+
+inline __device__ void convert_from_float(half4 &dst, float4 src) {
+  dst.a.x = __float2half(src.x);
+  dst.a.y = __float2half(src.y);
+  dst.b.x = __float2half(src.z);
+  dst.b.y = __float2half(src.w);
+}
+inline __device__ void convert_from_float(half2 &dst, float2 src) {
+  dst.x = __float2half(src.x);
+  dst.y = __float2half(src.y);
+}
+inline __device__ void convert_from_float(half &dst, float src) {
+  dst = __float2half(src);
+}
+
+//////////////////////////////////////utils///////////////////////////////////////////////
+
 template <typename T>
 inline __device__ void zero(T &dst) {
   constexpr int WORDS = sizeof(T) / 4;
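Together, cast_to_float and convert_from_float give each half vector an fp32 counterpart for accumulation plus a narrowing path for the final store. A small round-trip sketch (store_half4, out, and acc are hypothetical names, not from this commit):

// Hypothetical: keep the running sum in float4 and narrow to half4 only at
// the store, so half precision is lost once at the end rather than per step.
__device__ void store_half4(half4 *out, float4 acc) {
  half4 tmp;
  convert_from_float(tmp, acc); // one __float2half per lane
  *out = tmp;
}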
@@ -313,16 +288,13 @@ inline __device__ void zero(T &dst) {

 template <int THREADS_PER_KEY, typename K_vec, int N>
 inline __device__ float qk_dot_(K_vec const (&q)[N], K_vec const (&k)[N]) {
-#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
-  using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
-#else
-  using K_vec_acum = K_vec;
-#endif
+  // use float32 to get better accuracy
+  using Vec_sum = typename Vec_fp32_<K_vec>::Type;
   // Compute the parallel products for Q*K^T (treat vector lanes separately).
-  K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
+  Vec_sum qk_vec = mul<Vec_sum, K_vec, K_vec>(q[0], k[0]);
 #pragma unroll
   for (int ii = 1; ii < N; ++ii) {
-    qk_vec = fma(q[ii], k[ii], qk_vec);
+    qk_vec = FlexFlow::fma(cast_to_float(q[ii]), cast_to_float(k[ii]), qk_vec);
   }

   // Finalize the reduction across lanes.
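With this hunk the MMHA_USE_FP32_ACUM_FOR_FMA escape hatch is gone: qk_dot_ now always widens q and k to fp32 before the fused multiply-adds, matching the new "use float32 to get better accuracy" comment. For reference, the elementwise fma overload the loop relies on has roughly this shape (the actual overloads live in a collapsed part of this file; this is a sketch, not the committed code):

// Sketch of an elementwise fma on float4: d = a * b + c, lane by lane.
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = a.x * b.x + c.x;
  d.y = a.y * b.y + c.y;
  d.z = a.z * b.z + c.z;
  d.w = a.w * b.w + c.w;
  return d;
}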
