From a1f43e7a2544d94b526c2dd643bd9db6ece995f5 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 15 Oct 2024 20:19:45 +0000 Subject: [PATCH] fix --- include/flexflow/request_manager.h | 2 +- inference/peft/peft.cc | 2 +- src/runtime/request_manager.cc | 21 +++++++++++++-------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index ca83acb4f4..94bfc74244 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -68,7 +68,7 @@ struct Request { BatchConfig::RequestGuid guid; PEFTModelID peft_model_id = PEFTModelID::NO_ID; int max_length = -1; - int max_new_tokens = 128; + int max_new_tokens = -1; int initial_len; int ssm_cache_size = 0; int llm_cache_size = 0; diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc index ee5bd1b460..14fc653eba 100644 --- a/inference/peft/peft.cc +++ b/inference/peft/peft.cc @@ -340,7 +340,7 @@ void FlexFlow::top_level_task(Task const *task, printf("Inference prompt[%d]: %s\n", total_num_requests, text.c_str()); Request inference_req; inference_req.prompt = text; - inference_req.max_length = 128; + inference_req.max_new_tokens = 128; inference_req.peft_model_id = (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID; requests.push_back(inference_req); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 151ccee19b..612bbfb152 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -271,7 +271,13 @@ RequestManager::RequestGuid request.guid = next_available_guid++; request.max_length = request_.max_length; request.max_new_tokens = request_.max_new_tokens; + // both unset + if (request.max_length == -1 && request.max_new_tokens == -1) { + request.max_length = get_max_sequence_length(); + } + // both set if (request.max_length != -1 && request.max_new_tokens != -1) { + request.max_length = -1; std::cout << "Both `max_new_tokens` (=" << request.max_new_tokens << ") and `max_length`(=" << request.max_length @@ -372,15 +378,14 @@ RequestManager::RequestGuid request.initial_len = 0; request.max_length = request_.max_length; request.max_new_tokens = request_.max_new_tokens; - if (request.max_length != -1) { - std::cout << "Warning: max_length is set for PEFT finetuning, but it will " - "be ignored." - << std::endl; - } if (request.max_new_tokens != -1) { - std::cout << "Warning: max_new_tokens is set for PEFT finetuning, but " - "it will be ignored." - << std::endl; + std::cerr + << "Error: max_new_tokens is not allowed for PEFT finetuning requests" + << std::endl; + assert(false); + } + if (request.max_length == -1) { + request.max_length = get_max_sequence_length(); } request.peft_model_id = request_.peft_model_id; request.req_type = RequestType::REQ_FINETUNING;