inc_mha
xinhaoc committed Sep 30, 2023
1 parent 01330ee commit 5ed09ed
Showing 5 changed files with 690 additions and 278 deletions.
9 changes: 9 additions & 0 deletions include/flexflow/ops/inc_multihead_self_attention.h
@@ -204,6 +204,15 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
// typedef hipFloatComplex attFloatComplex;
hipFloatComplex *complex_input;
#endif
// metadata for padding
void *query, *key, *value, *padded_output,
    *padded_input; // temporary storage for query, key, value, and the
                   // padded input/output

// if positive, it indicates the real token index in the original request;
// if negative, it indicates a padding token
int *real_token_idx, *total_tokens_per_req;
int *real_token_idx_gpu, *total_tokens_per_req_gpu;
int *max_req_length, *max_total_tokens;
};

}; // namespace FlexFlow
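The sign convention on real_token_idx maps each padded slot back to the packed token stream. A minimal host-side sketch of that encoding, assuming requests are padded to max_req_length and their tokens are otherwise packed back-to-back (build_real_token_idx and tokens_per_req are illustrative names, not part of this commit):

#include <vector>

// Sketch only: positive entries name the token's index in the packed
// (unpadded) layout; negative entries flag padding slots.
std::vector<int> build_real_token_idx(std::vector<int> const &tokens_per_req,
                                      int max_req_length) {
  std::vector<int> real_token_idx;
  int packed_idx = 0;
  for (int len : tokens_per_req) {
    for (int pos = 0; pos < max_req_length; pos++) {
      real_token_idx.push_back(pos < len ? packed_idx++ : -1);
    }
  }
  return real_token_idx;
}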
@@ -26,6 +26,21 @@ __global__ void apply_proj_bias_w(DT *input_ptr,
int qkv_weight_size,
int oProjSize);

template <typename DT>
__global__ void copy_output(DT const *padded_output,
DT *output,
int num_total_tokens,
int oProjSize,
int *real_token_idx);
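
One plausible body for copy_output, reading num_total_tokens as the number of padded slots; this is a sketch inferred from the declaration above, not the committed implementation:

template <typename DT>
__global__ void copy_output_sketch(DT const *padded_output,
                                   DT *output,
                                   int num_total_tokens, // assumed: padded slots
                                   int oProjSize,
                                   int *real_token_idx) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num_total_tokens * oProjSize) {
    return;
  }
  int token = i / oProjSize;
  int dim = i % oProjSize;
  int real = real_token_idx[token];
  if (real >= 0) {
    // real token: scatter its row back into the compact output
    output[real * oProjSize + dim] = padded_output[token * oProjSize + dim];
  } // a negative index marks padding; nothing to copy
}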

template <typename DT>
__global__ void pad_input_ptr(DT *input_ptr,
DT *padded_input,
BatchConfig const *bc,
int num_padded_tokens,
int hidden_size,
int max_req_length);
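
In the opposite direction, a hypothetical sketch of the padding scatter: the per-request lengths and packed offsets that pad_input_ptr would presumably read from BatchConfig are replaced here by explicit arrays (tokens_per_req, req_token_offset), since BatchConfig's fields are not shown in this diff:

template <typename DT>
__global__ void pad_input_sketch(DT const *input,
                                 DT *padded_input,
                                 int const *tokens_per_req,   // real length per request (stand-in for BatchConfig)
                                 int const *req_token_offset, // prefix sum of lengths: first packed token per request
                                 int num_padded_tokens,
                                 int hidden_size,
                                 int max_req_length) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num_padded_tokens * hidden_size) {
    return;
  }
  int token = i / hidden_size;
  int dim = i % hidden_size;
  int req = token / max_req_length;
  int pos = token % max_req_length;
  if (pos < tokens_per_req[req]) {
    // copy the real token's activation into its padded slot
    padded_input[i] = input[(req_token_offset[req] + pos) * hidden_size + dim];
  } else {
    padded_input[i] = static_cast<DT>(0); // zero-fill padding slots
  }
}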

template <typename DT>
__global__ void apply_proj_bias_qkv(DT *input_ptr,
DT const *bias_ptr,