Commit

fixes
goliaro committed Nov 27, 2024
1 parent 98e025c commit d963933
Showing 5 changed files with 61 additions and 29 deletions.
42 changes: 22 additions & 20 deletions benchmarking/debug.sh
@@ -23,28 +23,30 @@ make -j install
export LEGION_BACKTRACE=1
export FF_DEBG_NO_WEIGHTS=1

gdb -ex run --args ./inference/incr_decoding/incr_decoding \
-ll:cpu $NCPUS -ll:gpu $NGPUS -ll:util $NCPUS \
-ll:fsize 20000 -ll:zsize 10000 \
-llm-model $MODEL_NAME --verbose \
-prompt $PROMPT \
-tensor-parallelism-degree $NGPUS \
-log-file ../inference/output/test.out \
-output-file ../inference/output/test.json \
--max-requests-per-batch 1 --max-tokens-per-batch 3000 --max-sequence-length 3000
export CUDA_VISIBLE_DEVICES=1

#--verbose -lg:prof 1 -lg:prof_logfile prof_%.gz \

# ./inference/peft/peft \
# -ll:cpu 4 -ll:gpu $NGPUS -ll:util 2 \
# -ll:fsize 10000 -ll:zsize 10000 \
# --fusion \
# -llm-model $MODEL_NAME \
# -enable-peft -peft-model $PEFT_MODEL_NAME \
# -prompt /usr/FlexFlow/inference/prompt/peft.json \
# -finetuning-dataset /usr/FlexFlow/inference/prompt/peft_dataset.json \
# gdb -ex run --args ./inference/incr_decoding/incr_decoding \
# -ll:cpu $NCPUS -ll:gpu $NGPUS -ll:util $NCPUS \
# -ll:fsize 20000 -ll:zsize 10000 \
# -llm-model $MODEL_NAME --verbose \
# -prompt $PROMPT \
# -tensor-parallelism-degree $NGPUS \
# -log-file ../inference/output/test.out \
# -output-file ../inference/output/test.json \
# --max-requests-per-batch 1 --max-tokens-per-batch 3000 --max-sequence-length 3000

# -lg:prof 1 -lg:prof_logfile prof_%.gz --verbose --inference-debugging \
#--verbose -lg:prof 1 -lg:prof_logfile prof_%.gz \

./inference/peft/peft \
-ll:cpu 4 -ll:gpu $NGPUS -ll:util 2 \
-ll:fsize 20000 -ll:zsize 10000 \
--fusion \
-llm-model $MODEL_NAME \
-enable-peft -peft-model $PEFT_MODEL_NAME \
-prompt /usr/FlexFlow/inference/prompt/peft.json \
-finetuning-dataset /usr/FlexFlow/inference/prompt/peft_dataset.json \
-tensor-parallelism-degree $NGPUS \
-output-file ../inference/output/test.json \
--max-requests-per-batch 1 --max-tokens-per-batch 3000 --max-sequence-length 3000

# -lg:prof 1 -lg:prof_logfile prof_%.gz --verbose --inference-debugging \
2 changes: 2 additions & 0 deletions inference/peft/peft.cc
@@ -178,6 +178,8 @@ void FlexFlow::top_level_task(Task const *task,
assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree *
ffconfig.pipeline_parallelism_degree ==
ffconfig.numNodes * ffconfig.workersPerNode);

ffconfig.enable_peft_finetuning = enable_peft_finetuning;

std::string config_filepath = join_path(
{file_paths.cache_folder_path, "configs", llm_model_name, "config.json"});
30 changes: 22 additions & 8 deletions src/ops/inc_multihead_self_attention.cu
@@ -126,12 +126,23 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m,
int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
bc->requestsInfo[i].num_tokens_in_batch;
int max_peft_tokens = bc->requestsInfo[i].max_length;
int max_peft_tokens = BatchConfig::max_sequence_length();
// Copy query to m->query_activation_buffer if we need to compute
// PEFT backward
if (bc->requestsInfo[i].finetuning_request) {
size_t activation_size_needed =
    sizeof(DT) * max_peft_tokens * m->num_q_heads * m->qProjSize;
if (activation_size_needed != m->allocated_peft_buffer_size1) {
  std::cout << "activation_size_needed: " << activation_size_needed
            << std::endl;
  std::cout << "m->allocated_peft_buffer_size1: "
            << m->allocated_peft_buffer_size1 << std::endl;
  std::cout << "max_peft_tokens: " << max_peft_tokens << std::endl;
  std::cout << "m->num_q_heads: " << m->num_q_heads << std::endl;
  std::cout << "m->qProjSize: " << m->qProjSize << std::endl;
  std::cout << "BatchConfig::max_sequence_length(): "
            << BatchConfig::max_sequence_length() << std::endl;
  std::cout << "sizeof(DT): " << sizeof(DT) << std::endl;
}
assert(activation_size_needed == m->allocated_peft_buffer_size1);
int parallelism = m->hidden_size * num_tokens;
store_query_cache<<<GET_BLOCKS(parallelism),
@@ -1697,11 +1708,16 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
size_t complex_size = (max_tokens_per_batch * (qProjSize * num_q_heads +
kProjSize * num_q_heads)) /
2;
allocated_peft_buffer_size1 = BatchConfig::max_sequence_length() *
                              num_q_heads * qProjSize * size_of_dt;
allocated_peft_buffer_size2 = BatchConfig::max_sequence_length() *
                              BatchConfig::max_sequence_length() *
                              num_q_heads * size_of_dt;
if (enable_peft_finetuning) {
  allocated_peft_buffer_size1 = BatchConfig::max_sequence_length() *
                                num_q_heads * qProjSize * size_of_dt;
  allocated_peft_buffer_size2 = BatchConfig::max_sequence_length() *
                                BatchConfig::max_sequence_length() *
                                num_q_heads * size_of_dt;
} else {
  allocated_peft_buffer_size1 = 0;
  allocated_peft_buffer_size2 = 0;
}
size_t totalSize =
(qkv_max_proj_size + key_cache_size + value_cache_size +
2 * qk_prod_size + attn_heads_size) *
@@ -1791,8 +1807,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
gpu_mem_allocator.reserved_allocated_size);
}
}
allocated_peft_buffer_size1 = 0;
allocated_peft_buffer_size2 = 0;
cudaStreamSynchronize(stream);
}
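For reference, the size relationship that the assert in compute_attention_kernel_prompt now relies on can be checked in isolation. The sketch below is illustrative only: the values for num_q_heads, qProjSize, the element size, and the sequence length are stand-ins, not taken from any particular model configuration.

#include <cassert>
#include <cstddef>
#include <iostream>

int main() {
  // Stand-in values; in FlexFlow these come from the model and BatchConfig.
  const std::size_t max_sequence_length = 3000; // e.g. --max-sequence-length 3000
  const std::size_t num_q_heads = 32;
  const std::size_t qProjSize = 128;
  const std::size_t size_of_dt = 2; // bytes per element, e.g. half precision

  // Buffer sized once at construction time (only when PEFT finetuning is
  // enabled, per the change above).
  const std::size_t allocated_peft_buffer_size1 =
      max_sequence_length * num_q_heads * qProjSize * size_of_dt;

  // Bytes needed when caching query activations for a finetuning request.
  // With max_peft_tokens = BatchConfig::max_sequence_length(), this matches
  // the allocation exactly; with the old per-request max_length it need not.
  const std::size_t max_peft_tokens = max_sequence_length;
  const std::size_t activation_size_needed =
      size_of_dt * max_peft_tokens * num_q_heads * qProjSize;

  assert(activation_size_needed == allocated_peft_buffer_size1);
  std::cout << "peft query buffer bytes: " << allocated_peft_buffer_size1
            << std::endl;
  return 0;
}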
13 changes: 13 additions & 0 deletions src/ops/kernels/rms_norm_kernels.cu
@@ -33,6 +33,7 @@ RMSNormMeta::RMSNormMeta(FFHandler handler,

in_dim = rms->data_dim;
batch_size = rms->effective_batch_size;
enable_peft_finetuning = rms->enable_peft_finetuning;
num_elements = in_dim * batch_size;

DataType data_type = rms->weights[0]->data_type;
@@ -218,6 +219,18 @@ void inference_kernel_wrapper(RMSNormMeta *m,
assert(bc->requestsInfo[i].peft_model_id != PEFTModelID::NO_ID);
assert(!bc->requestsInfo[i].finetuning_backward_phase);
int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1;
if (m->allocated_peft_buffer_size !=
    data_type_size(m->input_type[0]) *
        BatchConfig::max_sequence_length() * in_dim) {
  std::cout << "allocated_peft_buffer_size = "
            << m->allocated_peft_buffer_size << ", expected = "
            << data_type_size(m->input_type[0]) *
                   BatchConfig::max_sequence_length() * in_dim
            << std::endl;
  std::cout << "in_dim = " << in_dim << std::endl;
  std::cout << "max_sequence_length = " << BatchConfig::max_sequence_length()
            << std::endl;
  std::cout << "data_type_size = " << data_type_size(m->input_type[0])
            << std::endl;
}
assert(m->allocated_peft_buffer_size ==
       data_type_size(m->input_type[0]) *
           BatchConfig::max_sequence_length() * in_dim);
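The same print-before-assert idiom appears here as in the attention kernel above. A minimal standalone sketch of that idiom follows; the check_buffer_size helper and the numeric sizes are hypothetical and not part of FlexFlow.

#include <cassert>
#include <cstddef>
#include <iostream>

// Hypothetical helper illustrating the diagnostic idiom used in this commit:
// print both sides of a size invariant before asserting it, so a failing run
// still reports the mismatching operands.
static void check_buffer_size(std::size_t allocated,
                              std::size_t expected,
                              const char *what) {
  if (allocated != expected) {
    std::cout << what << ": allocated = " << allocated
              << ", expected = " << expected << std::endl;
  }
  assert(allocated == expected);
}

int main() {
  // Stand-in values mirroring the RMSNorm check:
  // expected = data_type_size(input_type) * max_sequence_length * in_dim.
  const std::size_t data_type_size = 2;         // e.g. half precision
  const std::size_t max_sequence_length = 3000;
  const std::size_t in_dim = 4096;
  const std::size_t expected = data_type_size * max_sequence_length * in_dim;
  const std::size_t allocated = expected;       // sized at construction time
  check_buffer_size(allocated, expected, "rms_norm peft input buffer");
  return 0;
}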
3 changes: 2 additions & 1 deletion src/ops/sigmoid_silu_multi.cu
@@ -121,10 +121,11 @@ void SigmoidSiluMulti::inference_kernel_wrapper(
int in_dim = input1.domain.hi()[0] - input1.domain.lo()[0] + 1;
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
assert(num_peft_tokens == bc->num_finetuning_tokens());
int max_peft_tokens = BatchConfig::max_sequence_length();
int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
size_t input_tensor_size =
data_type_size(m->input_type[0]) * num_peft_tokens * in_dim;
assert(m->allocated_peft_buffer_size == 2 * input_tensor_size);
assert(m->allocated_peft_buffer_size ==
       2 * (data_type_size(m->input_type[0]) * max_peft_tokens * in_dim));
// copy input activation
if (m->input_type[0] == DT_FLOAT) {
checkCUDA(
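For intuition on why the assert changes: the buffer is allocated once for the maximum sequence length, so comparing it against the per-batch tensor size only holds when a finetuning batch happens to fill the entire maximum sequence length. A small sketch with stand-in values (not taken from any real configuration):

#include <cstddef>
#include <iostream>

int main() {
  // Stand-in values; num_peft_tokens is what the batch actually contains,
  // max_peft_tokens plays the role of BatchConfig::max_sequence_length().
  const std::size_t data_type_size = 2;      // e.g. half precision
  const std::size_t in_dim = 11008;          // e.g. an MLP intermediate size
  const std::size_t num_peft_tokens = 128;
  const std::size_t max_peft_tokens = 3000;

  // The old assert compared the buffer against the per-batch size ...
  const std::size_t input_tensor_size =
      data_type_size * num_peft_tokens * in_dim;
  const std::size_t per_batch_bytes = 2 * input_tensor_size;

  // ... while the buffer is sized for the maximum sequence length, which is
  // what the new assert checks against.
  const std::size_t allocated_bytes =
      2 * (data_type_size * max_peft_tokens * in_dim);

  std::cout << "per-batch bytes: " << per_batch_bytes
            << ", allocated bytes: " << allocated_bytes << std::endl;
  return 0;
}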
