Skip to content

Commit

Permalink
fix different thread per key case
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Oct 24, 2023
1 parent 0aec1b6 commit 5d2dbbd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ struct Vec_fp32_<half4> {
using Type = float4;
};

// VEC_V<DT>::Type names the vector type paired with the scalar type DT
// (for example, the attention kernel aliases it as V_vec).
// The primary template has no `Type` member, so VEC_V<DT>::Type only
// compiles for the scalar types specialized below.
template <typename DT>
struct VEC_V {};

// float -> float4
template <>
struct VEC_V<float> {
  typedef float4 Type;
};

// half -> half4
template <>
struct VEC_V<half> {
  typedef half4 Type;
};

////////////////data structures half///////////////

////////////////////////////////////floating point
Expand Down
25 changes: 5 additions & 20 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ __global__ void compute_attention_kernel_generation_kernel(
// q, k
using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
using K_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
using V_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
using V_vec = typename VEC_V<DT>::Type;
using Out_sum = typename Vec_fp32_<V_vec>::Type;

constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
Expand All @@ -66,13 +66,12 @@ __global__ void compute_attention_kernel_generation_kernel(
// then K_VEC_SIZE = 1, QK_VEC_SIZE = 4
// K_ELTS_PER_THREAD = 128 / 4 = 32
// K_VECS_PER_THREAD = 32 / 1 = 32
// todo fix
constexpr int K_VEC_SIZE = 16 / sizeof(DT);
constexpr int QK_VEC_SIZE = 16 / sizeof(DT);
// constexpr int QK_VEC_SIZE = sizeof(Qk_vec_k) / sizeof(DT);
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT);
// constexpr int QK_VEC_SIZE = 16 / sizeof(DT);
// // constexpr int QK_VEC_SIZE = sizeof(Qk_vec_k) / sizeof(DT);
constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT);
// constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT);

// thread id
int const tidx = threadIdx.x;
Expand Down Expand Up @@ -152,20 +151,6 @@ __global__ void compute_attention_kernel_generation_kernel(
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
}
// if (blockIdx.x == 0 && blockIdx.y == 0 && tidx == 0) {
// printf("query and key %.10f, %.10f, %.10f, %.10f\n",
// q_vec[0].x,
// q_vec[1].x,
// k[0].x,
// k[1].x);
// }
// if (blockIdx.x == 0 && blockIdx.y == 10 && tidx == 0) {
// printf("query and key second thread %.10f, %.10f, %.10f, %.10f\n",
// q_vec[0].x,
// q_vec[1].x,
// k[0].x,
// k[1].x);
// }
float qk = scale * Qk_dot<DT, THREADS_PER_KEY>::dot(q_vecs[ki_o], k);
// // todo add positional embedding to the qk product
// // Store the product to shared memory. There's one qk value per
Expand Down

0 comments on commit 5d2dbbd

Please sign in to comment.