Commit 02e4fad by xinhaoc, committed Oct 29, 2023 (parent: a53ff87)
Showing 3 changed files with 193 additions and 79 deletions.
src/ops/inc_multihead_self_attention.cu: 183 changes (114 additions, 69 deletions)
@@ -52,7 +52,9 @@ __global__ void compute_attention_kernel_generation_kernel(
int max_seq_length,
int per_head_size,
int hidden_size,
-    BatchConfig::PerRequestInfo *request_infos) {
+    BatchConfig::PerRequestInfo *request_infos,
+    bool is_beam,
+    int max_beam_width) {

// q, k
using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
@@ -80,10 +82,15 @@ __global__ void compute_attention_kernel_generation_kernel(
// request idx
int const request_idx = blockIdx.y;

+  int const beam_request_idx =
+      is_beam ? request_idx / max_beam_width : request_idx;
+  int const beam_sub_request_idx = is_beam ? request_idx % max_beam_width : 0;
+
int const first_step = 0;

-  int const tlength = request_infos[request_idx].first_token_depth_in_request +
-                      request_infos[request_idx].num_tokens_in_batch;
+  int const tlength =
+      request_infos[beam_request_idx].first_token_depth_in_request +
+      request_infos[beam_request_idx].num_tokens_in_batch;

// shared memory objects
extern __shared__ char smem_[];
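
A note on the beam indexing added in this hunk: when is_beam is true, each request occupies max_beam_width consecutive values of blockIdx.y, one per beam sub-request, and the two new indices are just the quotient and remainder of that decomposition. A minimal sketch of the same mapping as a standalone helper (the helper name is illustrative, not part of this commit):

// Illustrative only: mirrors how the kernel derives (beam_request_idx,
// beam_sub_request_idx) from blockIdx.y when beam search is enabled.
__host__ __device__ inline void decompose_block_y(int block_y,
                                                  bool is_beam,
                                                  int max_beam_width,
                                                  int &beam_request_idx,
                                                  int &beam_sub_request_idx) {
  beam_request_idx = is_beam ? block_y / max_beam_width : block_y;
  beam_sub_request_idx = is_beam ? block_y % max_beam_width : 0;
}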
@@ -96,7 +103,7 @@ __global__ void compute_attention_kernel_generation_kernel(
// first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum
__shared__ float red_smem[WARPS_PER_BLOCK * 2];

-  const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM +
+  const DT *q_ptr = query + beam_request_idx * hidden_size * QKV_WEIGHT_NUM +
head_idx * per_head_size;
__shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD];
// DT const *q_ptr =
@@ -131,7 +138,10 @@ __global__ void compute_attention_kernel_generation_kernel(
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

DT const *k_cache_batch =
-      key_cache + request_idx * max_seq_length * hidden_size + ki;
+      key_cache +
+      (beam_request_idx * max_beam_width + beam_sub_request_idx) *
+          max_seq_length * hidden_size +
+      ki;

int ti_end =
div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
@@ -269,7 +279,10 @@ __global__ void compute_attention_kernel_generation_kernel(

// The base pointer for the value in the cache buffer.
DT const *v_cache_batch =
-      value_cache + request_idx * max_seq_length * hidden_size + vi;
+      value_cache +
+      (beam_request_idx * max_beam_width + beam_sub_request_idx) *
+          max_seq_length * hidden_size +
+      vi;

if (Dh == Dh_MAX || vi < Dh) {
for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
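
The key- and value-cache edits in the two hunks above compute the same base offset: with beam search, the caches behave as a [request][beam_slot][token][hidden] array flattened row-major, so a (request, sub-request) pair selects one contiguous slab of max_seq_length * hidden_size elements. A hedged sketch of that offset computation (the helper name and the explicit size_t widening are mine, not from the commit):

// Assumed layout: cache[request][beam_slot][token][hidden], flattened.
// Returns the element offset of the first token of one beam slot.
__host__ __device__ inline size_t beam_cache_offset(int beam_request_idx,
                                                    int beam_sub_request_idx,
                                                    int max_beam_width,
                                                    int max_seq_length,
                                                    int hidden_size) {
  return (static_cast<size_t>(beam_request_idx) * max_beam_width +
          beam_sub_request_idx) *
         max_seq_length * hidden_size;
}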
@@ -651,8 +664,6 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m,
DT const *weight_ptr,
DT const *bias_ptr,
cudaStream_t stream) {
-  checkCUDA(cublasSetStream(m->handle.blas, stream));
-  checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]);
assert(data_type_size(m->output_type[0]) == sizeof(DT));
@@ -729,7 +740,10 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig::max_sequence_length(), \
m->qProjSize, \
m->hidden_size, \
-      m->request_infos)
+      m->request_infos, \
+      false, \
+      0)
template <typename DT>
void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
@@ -739,24 +753,26 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
int const per_head_size = m->qProjSize;
float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
size_t smem_sz;
+  // std::cout << "bc->num_active_requests(): " << bc->num_active_requests() <<
+  // "\n";
switch (per_head_size) {
case 64:
LAUNCH_ATTENTION_SCORE_KERNEL(DT, 64, 64, 4, 16, 128, stream);
break;
case 128:
-      LAUNCH_ATTENTION_SCORE_KERNEL(DT, 128, 128, 4, 32, 64, stream);
+      LAUNCH_ATTENTION_SCORE_KERNEL(DT, 128, 128, 4, 32, 128, stream);
break;
default:
assert(false);
}
-  // check for errors
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) {
+  // // check for errors
+  // cudaError_t error = cudaGetLastError();
+  // if (error != cudaSuccess) {
-    fprintf(stderr, "ERROR: %s \n", cudaGetErrorString(error));
-    assert(false);
-  }
+  //   fprintf(stderr, "ERROR: %s \n", cudaGetErrorString(error));
+  //   assert(false);
+  // }
}
template <typename DT>
@@ -830,6 +846,16 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m,
cudaStream_t stream) {
// here because we need position info in inference 1
+  cudaEvent_t t_start, t_end1, t_end2, t_end3, t_end4, t_end5, t_end6;
+  cudaEventCreate(&t_start);
+  cudaEventCreate(&t_end1);
+  cudaEventCreate(&t_end2);
+  cudaEventCreate(&t_end3);
+  cudaEventCreate(&t_end4);
+  cudaEventCreate(&t_end5);
+  cudaEventCreate(&t_end6);
+  cudaEventRecord(t_start, stream);
if (m->offload && m->biasSize > 0) {
cudaMemcpyAsync(
m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream);
@@ -848,6 +874,12 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m,
sizeof(BatchConfig::PerRequestInfo),
cudaMemcpyHostToDevice,
stream);
+  float elapsed4 = 0;
+  cudaEventRecord(t_end4, stream);
+  checkCUDA(cudaEventSynchronize(t_end4));
+  checkCUDA(cudaEventElapsedTime(&elapsed4, t_start, t_end4));
+  printf("IncMultiHeadSelfAttention copy element kernel time = %.9fms\n",
+         elapsed4);
// phase 1: Implement kernel to compute KQV for input tokens
compute_qkv_kernel(m,
@@ -858,31 +890,60 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m,
static_cast<DT *>(m->devQKVProjArray),
bias_ptr,
stream);
+  float elapsed1 = 0;
+  cudaEventRecord(t_end1, stream);
+  checkCUDA(cudaEventSynchronize(t_end1));
+  checkCUDA(cudaEventElapsedTime(&elapsed1, t_start, t_end1));
+  printf("IncMultiHeadSelfAttention qkv kernel time = %.9fms\n", elapsed1);
// phase 2: Update key/val cache
update_kv_cache_kernel<DT>(m, bc, stream);
// cudaMemsetAsync((float *)m->attn_heads, 0.0f, 12 * 64 * 4, stream);
// checkCUDA(cudaStreamSynchronize(stream));
// save_tensor<float>((float *)m->attn_heads,
// 1 * 12 * 64 * 1,
// "/home/ubuntu/FlexFlow/inference/vprojbefore.txt");
+  float elapsed2 = 0;
+  cudaEventRecord(t_end2, stream);
+  checkCUDA(cudaEventSynchronize(t_end2));
+  checkCUDA(cudaEventElapsedTime(&elapsed2, t_start, t_end2));
+  printf("IncMultiHeadSelfAttention update kv cache time = %.9fms\n", elapsed2);
if (bc->num_generation_tokens > 0) {
// phase 3: Compute attention score for generation tokens
compute_attention_kernel_generation<DT>(
m, bc, static_cast<DT *>(m->attn_heads), stream);
}
+  float elapsed3 = 0;
+  cudaEventRecord(t_end3, stream);
+  checkCUDA(cudaEventSynchronize(t_end3));
+  checkCUDA(cudaEventElapsedTime(&elapsed3, t_start, t_end3));
+  printf("IncMultiHeadSelfAttention attention score time = %.9fms\n", elapsed3);
if (bc->num_tokens > bc->num_generation_tokens) {
// phase 4: Compute attention score for prompt tokens;
compute_attention_kernel_prompt(
m, bc, shard_id, bias_ptr, weight_ptr, stream);
}
+  float elapsed5 = 0;
+  cudaEventRecord(t_end5, stream);
+  checkCUDA(cudaEventSynchronize(t_end5));
+  checkCUDA(cudaEventElapsedTime(&elapsed5, t_start, t_end5));
+  printf("IncMultiHeadSelfAttention is there a thing? time = %.9fms\n",
+         elapsed5);
// compute output production and bias together for all tokens
compute_o_prod_bias(
m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, stream);
+  float elapsed6 = 0;
+  cudaEventRecord(t_end6, stream);
+  checkCUDA(cudaEventSynchronize(t_end6));
+  checkCUDA(cudaEventElapsedTime(&elapsed6, t_start, t_end6));
+  printf("IncMultiHeadSelfAttention final projection time = %.9fms\n",
+         elapsed6);
+  cudaEventDestroy(t_start);
+  cudaEventDestroy(t_end1);
+  cudaEventDestroy(t_end2);
+  cudaEventDestroy(t_end3);
+  cudaEventDestroy(t_end4);
+  cudaEventDestroy(t_end5);
+  cudaEventDestroy(t_end6);
}
} // namespace IncMultiHeadAttention
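
One note on the instrumentation added throughout inference_kernel above: every cudaEventElapsedTime call measures against the same t_start, so each printf reports cumulative time since kernel entry rather than the duration of that phase. For reference, a small sketch of the underlying CUDA-event timing pattern with a dedicated event pair per phase (the helper and its name are illustrative, not part of this commit):

#include <cuda_runtime.h>

// Illustrative helper: times one span of stream-ordered work with its own
// event pair, so the returned value is per-phase rather than cumulative.
template <typename F>
float time_phase_on_stream(cudaStream_t stream, F &&enqueue_phase) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, stream);
  enqueue_phase();                 // enqueue kernels/copies on `stream`
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);      // block until the timed work has finished
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}

For example, time_phase_on_stream(stream, [&] { update_kv_cache_kernel<DT>(m, bc, stream); }) would report only the KV-cache update, assuming no other work is enqueued on the stream in between.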
@@ -1150,13 +1211,6 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
checkCUDA(get_legion_stream(&stream));
bool use_bias = *m->qkv_bias || *m->final_bias;
-  cudaEvent_t t_start, t_end;
-  if (m->profiling) {
-    cudaEventCreate(&t_start);
-    cudaEventCreate(&t_end);
-    cudaEventRecord(t_start, stream);
-  }
// assert(input.data_type == weight.data_type);
assert(input.data_type == output.data_type);
if (use_bias) {
@@ -1197,47 +1251,38 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
} else {
assert(false && "Unspported data type");
}
-  if (m->profiling) {
-    cudaEventRecord(t_end, stream);
-    checkCUDA(cudaEventSynchronize(t_end));
-    float elapsed = 0;
-    checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end));
-    cudaEventDestroy(t_start);
-    cudaEventDestroy(t_end);
-    printf("IncMultiHeadSelfAttention forward time = %.9fms\n", elapsed);
-    // if (input.data_type == DT_HALF) {
-    //   print_tensor<half>(input.get_half_ptr(),
-    //                      32,
-    //                      "[IncMultiHeadSelfAttention:forward:input]");
-    //   print_tensor<half>(weight.get_half_ptr(),
-    //                      32,
-    //                      "[IncMultiHeadSelfAttention:forward:weight]");
-    //   print_tensor<half>(output.get_half_ptr(),
-    //                      32,
-    //                      "[IncMultiHeadSelfAttention:forward:output]");
-    //   print_tensor<half>(
-    //       bias.get_half_ptr(), 32,
-    //       "[IncMultiHeadSelfAttention:forward:bias]");
-    // } else {
-    //   print_tensor<float>(input.get_float_ptr(),
-    //                       32,
-    //                       "[IncMultiHeadSelfAttention:forward:input]");
-    //   print_tensor<float>(weight.get_float_ptr(),
-    //                       32,
-    //                       "[IncMultiHeadSelfAttention:forward:weight]");
-    //   print_tensor<float>(output.get_float_ptr(),
-    //                       32,
-    //                       "[IncMultiHeadSelfAttention:forward:output]");
-    //   print_tensor<float>(
-    //       bias.get_float_ptr(), 32,
-    //       "[IncMultiHeadSelfAttention:forward:bias]");
-    // }
-    // print_tensor<3, float>(acc_query.ptr, acc_query.rect,
-    // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr,
-    // acc_output.rect, "[Attention:forward:output]");
-  }
+  // if (input.data_type == DT_HALF) {
+  //   print_tensor<half>(input.get_half_ptr(),
+  //                      32,
+  //                      "[IncMultiHeadSelfAttention:forward:input]");
+  //   print_tensor<half>(weight.get_half_ptr(),
+  //                      32,
+  //                      "[IncMultiHeadSelfAttention:forward:weight]");
+  //   print_tensor<half>(output.get_half_ptr(),
+  //                      32,
+  //                      "[IncMultiHeadSelfAttention:forward:output]");
+  //   print_tensor<half>(
+  //       bias.get_half_ptr(), 32,
+  //       "[IncMultiHeadSelfAttention:forward:bias]");
+  // } else {
+  //   print_tensor<float>(input.get_float_ptr(),
+  //                       32,
+  //                       "[IncMultiHeadSelfAttention:forward:input]");
+  //   print_tensor<float>(weight.get_float_ptr(),
+  //                       32,
+  //                       "[IncMultiHeadSelfAttention:forward:weight]");
+  //   print_tensor<float>(output.get_float_ptr(),
+  //                       32,
+  //                       "[IncMultiHeadSelfAttention:forward:output]");
+  //   print_tensor<float>(
+  //       bias.get_float_ptr(), 32,
+  //       "[IncMultiHeadSelfAttention:forward:bias]");
+  // }
+  // print_tensor<3, float>(acc_query.ptr, acc_query.rect,
+  // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr,
+  // acc_output.rect, "[Attention:forward:output]");
// print_tensor<float>(input.get_float_ptr(), 32, "ip");
// print_tensor<float>(output.get_float_ptr(), 32, "op");
}
[Diff truncated: the remaining changes in this file and the other two changed files are not shown.]
