diff --git a/benchmarking/debug.sh b/benchmarking/debug.sh
index 300bf80df2..86c7d2d902 100755
--- a/benchmarking/debug.sh
+++ b/benchmarking/debug.sh
@@ -6,21 +6,44 @@ set -e
 cd "${BASH_SOURCE[0]%/*}/../build"
 
 # MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
+# PROMPT="../benchmarking/test.json"
+PROMPT="/usr/FlexFlow/inference/prompt/peft.json"
 MODEL_NAME="JackFram/llama-160m"
+PEFT_MODEL_NAME="goliaro/llama-160m-lora"
 NGPUS=1
+NCPUS=4
+
+reset
+make -j install
+
+# python ../inference/utils/download_hf_model.py $MODEL_NAME
+# python ../inference/utils/download_peft_model.py $PEFT_MODEL_NAME
 
-python ../inference/utils/download_hf_model.py $MODEL_NAME
 export LEGION_BACKTRACE=1
+export FF_DEBG_NO_WEIGHTS=1
 
-./inference/incr_decoding/incr_decoding \
-    -ll:cpu 16 -ll:gpu $NGPUS -ll:util 16 \
+gdb -ex run --args ./inference/incr_decoding/incr_decoding \
+    -ll:cpu $NCPUS -ll:gpu $NGPUS -ll:util $NCPUS \
     -ll:fsize 20000 -ll:zsize 10000 \
-    --fusion \
+    --verbose -lg:prof 1 -lg:prof_logfile prof_%.gz \
     -llm-model $MODEL_NAME \
-    -prompt ../benchmarking/test.json \
+    -prompt $PROMPT \
     -tensor-parallelism-degree $NGPUS \
     -log-file ../inference/output/test.out \
     -output-file ../inference/output/test.json \
     --max-requests-per-batch 1 --max-tokens-per-batch 3000 --max-sequence-length 3000
+# ./inference/peft/peft \
+#     -ll:cpu 4 -ll:gpu $NGPUS -ll:util 2 \
+#     -ll:fsize 10000 -ll:zsize 10000 \
+#     --fusion \
+#     -llm-model $MODEL_NAME \
+#     -enable-peft -peft-model $PEFT_MODEL_NAME \
+#     -prompt /usr/FlexFlow/inference/prompt/peft.json \
+#     -finetuning-dataset /usr/FlexFlow/inference/prompt/peft_dataset.json \
+#     -tensor-parallelism-degree $NGPUS \
+#     -output-file ../inference/output/test.json \
+#     --max-requests-per-batch 1 --max-tokens-per-batch 3000 --max-sequence-length 3000
+
+    # -lg:prof 1 -lg:prof_logfile prof_%.gz --verbose --inference-debugging \
\ No newline at end of file
diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h
index e3c0e50396..d4be0631f6 100644
--- a/include/flexflow/batch_config.h
+++ b/include/flexflow/batch_config.h
@@ -177,6 +177,7 @@ struct InferenceResult {
   static int const MAX_NUM_TOKENS = BatchConfig::MAX_NUM_TOKENS;
   BatchConfig::TokenId token_ids[MAX_NUM_TOKENS];
   float finetuning_loss;
+  friend std::ostream &operator<<(std::ostream &os, InferenceResult const &result);
 };
 
 class BeamSearchBatchConfig : public BatchConfig {
diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index f2d97af89f..a920436d54 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -112,8 +112,7 @@ struct Request {
   std::vector beam_trees;
 
   Request() = default;
-  Request(Request const &other);
-  void load_token_ids();
+  static Request from_other(Request const &other);
   friend std::ostream &operator<<(std::ostream &os, Request const &req);
 };
 
@@ -152,6 +151,7 @@ class RequestManager {
 
   bool load_request_token_ids(Request &request);
 
+  void set_verbose(bool verbose);
   void set_max_requests_per_batch(int max_num_requests);
   int get_max_requests_per_batch();
   void set_max_tokens_per_batch(int max_num_tokens);
diff --git a/include/flexflow/utils/file_loader.h b/include/flexflow/utils/file_loader.h
index 8735f23571..c265083973 100644
--- a/include/flexflow/utils/file_loader.h
+++ b/include/flexflow/utils/file_loader.h
@@ -21,6 +21,8 @@ using namespace std;
 using namespace FlexFlow;
 
+using namespace Legion;
+
 class FileDataLoader {
 public:
@@ -36,16 +38,31 @@ class FileDataLoader {
   BatchConfig::TokenId *generate_requests(int num, int length);
 
   template <typename DT>
-  void load_single_weight_tensor(FFModel *ff, Layer *l, int weight_idx);
+  void load_single_weight_tensor(FFModel *ff,
+                                 Layer *l,
+                                 int weight_idx,
+                                 size_t volume,
+                                 size_t num_replicas,
+                                 DT *weight,
+                                 Domain weight_domain);
 
-  void load_quantization_weight(FFModel *ff, Layer *l, int weight_idx);
+  void load_quantization_weight(FFModel *ff,
+                                Layer *l,
+                                int weight_idx,
+                                size_t volume,
+                                size_t num_replicas,
+                                char *weight,
+                                DataType data_type,
+                                Domain weight_domain);
 
   static void load_weight_task(Legion::Task const *task,
                                std::vector<Legion::PhysicalRegion> const &regions,
                                Legion::Context ctx,
                                Legion::Runtime *runtime);
-  void load_weights_parallel(FFModel *ff, Context ctx, Runtime *runtime);
+  void load_weights_parallel(FFModel *ff,
+                             Legion::Context ctx,
+                             Legion::Runtime *runtime);
 
   void load_positions(FFModel *ff,
                       Tensor pt,
@@ -66,12 +83,15 @@ struct WeightLoadTaskArgs {
   FileDataLoader *loader;
   Layer *layer;
   int weight_idx;
+  size_t volume, num_replicas;
   DataType data_type;
   WeightLoadTaskArgs(FFModel *_ff,
                      FileDataLoader *_loader,
                      Layer *_l,
                      int _idx,
+                     size_t _volume,
+                     size_t _num_replicas,
                      DataType _data_type)
-      : ff(_ff), loader(_loader), layer(_l), weight_idx(_idx),
-        data_type(_data_type) {}
+      : ff(_ff), loader(_loader), layer(_l), weight_idx(_idx), volume(_volume),
+        num_replicas(_num_replicas), data_type(_data_type) {}
 };
diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc
index 4dfc2df474..23d302cae8 100644
--- a/inference/incr_decoding/incr_decoding.cc
+++ b/inference/incr_decoding/incr_decoding.cc
@@ -225,6 +225,7 @@ void FlexFlow::top_level_task(Task const *task,
 
   GenerationConfig generationConfig(do_sample, temperature, topp);
   RequestManager *rm = RequestManager::get_request_manager();
+  rm->set_verbose(verbose);
   rm->set_max_requests_per_batch(max_requests_per_batch);
   rm->set_max_tokens_per_batch(max_tokens_per_batch);
   rm->set_max_sequence_length(max_sequence_length);
@@ -271,74 +272,96 @@ void FlexFlow::top_level_task(Task const *task,
   }
 
   rm->start_background_server(&model);
+  // {
+  //   using json = nlohmann::json;
+  //   std::ifstream file_handle(file_paths.prompt_file_path);
+  //   assert(file_handle.good() && "Prompt file does not exist.");
+  //   nlohmann::ordered_json prompt_json =
+  //       nlohmann::ordered_json::parse(file_handle,
+  //                                     /*parser_callback_t */ nullptr,
+  //                                     /*allow_exceptions */ true,
+  //                                     /*ignore_comments */ true);
+  //   file_handle.close();
+  //   auto &metadata = prompt_json["metadata"];
+  //   int num_warmup_requests = metadata["num_warmup_requests"];
+  //   int num_regular_requests = 0, total_requests = 0;
+  //   std::vector<Request> warmup_requests, requests;
+  //   for (auto &entry : prompt_json["entries"]) {
+  //     int prompt_length = entry["prompt_length"];
+  //     int response_length = entry["response_length"];
+  //     std::string text = entry["prompt"];
+  //     bool is_warmup_request = total_requests < num_warmup_requests;
+
+  //     Request inference_req;
+  //     inference_req.prompt = text;
+  //     inference_req.add_special_tokens = false;
+  //     inference_req.max_new_tokens = response_length;
+
+  //     if (is_warmup_request) {
+  //       warmup_requests.push_back(inference_req);
+  //     } else {
+  //       printf("Prompt[%d]: %s\n", total_requests, text.c_str());
+  //       requests.push_back(inference_req);
+  //       num_regular_requests++;
+  //     }
+
+  //     total_requests++;
+  //   }
+  //   std::vector<GenerationResult> warmup_result =
+  //       model.generate(warmup_requests);
+  //   std::vector<GenerationResult> result = model.generate(requests);
+
+  //   assert(warmup_result.size() == warmup_requests.size());
+  //   assert(result.size() == requests.size());
+  //   assert(result.size() + warmup_result.size() == total_requests);
+  //   int i = 0;
+  //   for (auto &entry : prompt_json["entries"]) {
+  //     if (i < num_warmup_requests) {
+  //       i++;
+  //       continue;
+  //     }
+  //     int index = i - num_warmup_requests;
+  //     entry["original_response"] = entry["response"];
+  //     entry["original_response_length"] = entry["response_length"];
+  //     std::string ff_out = result[index].output_text;
+  //     int tot_length = result[index].output_text.length();
+  //     entry["response"] = ff_out;
+  //     entry["response_length"] = result[index].output_tokens.size();
+  //     i++;
+  //   }
+
+  //   // Write the modified JSON to a file
+  //   std::ofstream output_file(file_paths.output_file_path);
+  //   if (output_file.is_open()) {
+  //     output_file << prompt_json.dump(2);
+  //     output_file.close();
+  //     std::cout << "Modified JSON has been saved to "
+  //               << file_paths.output_file_path << std::endl;
+  //   } else {
+  //     std::cerr << "Unable to open file for writing." << std::endl;
+  //   }
+  // }
+
   int total_num_requests = 0;
   {
     using json = nlohmann::json;
     std::ifstream file_handle(file_paths.prompt_file_path);
     assert(file_handle.good() && "Prompt file does not exist.");
-    nlohmann::ordered_json prompt_json =
-        nlohmann::ordered_json::parse(file_handle,
-                                      /*parser_callback_t */ nullptr,
-                                      /*allow_exceptions */ true,
-                                      /*ignore_comments */ true);
-    file_handle.close();
-    auto &metadata = prompt_json["metadata"];
-    int num_warmup_requests = metadata["num_warmup_requests"];
-    int num_regular_requests = 0, total_requests = 0;
-    std::vector<Request> warmup_requests, requests;
-    for (auto &entry : prompt_json["entries"]) {
-      int prompt_length = entry["prompt_length"];
-      int response_length = entry["response_length"];
-      std::string text = entry["prompt"];
-      bool is_warmup_request = total_requests < num_warmup_requests;
+    json prompt_json = json::parse(file_handle,
+                                   /*parser_callback_t */ nullptr,
+                                   /*allow_exceptions */ true,
+                                   /*ignore_comments */ true);
+    std::vector<Request> requests;
+    for (auto &prompt : prompt_json) {
+      std::string text = prompt.get<std::string>();
+      printf("Prompt[%d]: %s\n", total_num_requests, text.c_str());
       Request inference_req;
       inference_req.prompt = text;
-      inference_req.add_special_tokens = false;
-      inference_req.max_new_tokens = response_length;
-
-      if (is_warmup_request) {
-        warmup_requests.push_back(inference_req);
-      } else {
-        printf("Prompt[%d]: %s\n", total_requests, text.c_str());
-        requests.push_back(inference_req);
-        num_regular_requests++;
-      }
-
-      total_requests++;
+      inference_req.max_length = 128;
+      requests.push_back(inference_req);
+      total_num_requests++;
     }
-    std::vector<GenerationResult> warmup_result =
-        model.generate(warmup_requests);
     std::vector<GenerationResult> result = model.generate(requests);
-
-    assert(warmup_result.size() == warmup_requests.size());
-    assert(result.size() == requests.size());
-    assert(result.size() + warmup_result.size() == total_requests);
-    int i = 0;
-    for (auto &entry : prompt_json["entries"]) {
-      if (i < num_warmup_requests) {
-        i++;
-        continue;
-      }
-      int index = i - num_warmup_requests;
-      entry["original_response"] = entry["response"];
-      entry["original_response_length"] = entry["response_length"];
-      std::string ff_out = result[index].output_text;
-      int tot_length = result[index].output_text.length();
-      entry["response"] = ff_out;
-      entry["response_length"] = result[index].output_tokens.size();
-      i++;
-    }
-
-    // Write the modified JSON to a file
-    std::ofstream output_file(file_paths.output_file_path);
-    if (output_file.is_open()) {
-      output_file << prompt_json.dump(2);
-      output_file.close();
-      std::cout << "Modified JSON has been saved to "
-                << file_paths.output_file_path << std::endl;
-    } else {
-      std::cerr << "Unable to open file for writing." << std::endl;
-    }
   }
 
   // terminate the request manager by stopping the background thread
diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc
index 484488ac65..f4348fd743 100644
--- a/inference/peft/peft.cc
+++ b/inference/peft/peft.cc
@@ -194,6 +194,10 @@ void FlexFlow::top_level_task(Task const *task,
               << std::endl;
     assert(false);
   }
+  if (!enable_peft) {
+    std::cerr << "Running PEFT script with PEFT not enabled" << std::endl;
+    assert(false);
+  }
   if (enable_peft && peft_model_name.empty()) {
     std::cout << "PEFT enabled, but no PEFT model id passed" << std::endl;
     assert(false);
   }
@@ -272,6 +276,7 @@ void FlexFlow::top_level_task(Task const *task,
 
   GenerationConfig generationConfig(do_sample, temperature, topp);
   RequestManager *rm = RequestManager::get_request_manager();
+  rm->set_verbose(verbose);
   rm->set_max_requests_per_batch(
       max_requests_per_batch +
       (int)enable_peft_finetuning); // add one slot for finetuning if needed
diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py
index c2804b6966..d6aa2b3f03 100644
--- a/python/flexflow/serve/serve.py
+++ b/python/flexflow/serve/serve.py
@@ -395,6 +395,7 @@ def download_and_convert_peft_model(hf_peft_model_id: str):
             weights_path = self.__get_resource_path(
                 hf_peft_model_id.lower(), CachedResourceType.WEIGHTS
             )
+            print(f"Opening {adapter_path}...")
             with safe_open(adapter_path, framework="pt", device="cpu") as f:
                 for tensor_name in f.keys():
                     tensor = f.get_tensor(tensor_name)
diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc
index 221a215afc..7f2a53b804 100644
--- a/src/runtime/batch_config.cc
+++ b/src/runtime/batch_config.cc
@@ -167,6 +167,20 @@ int BatchConfig::max_spec_tree_token_num() {
   return RequestManager::get_request_manager()->get_max_spec_tree_token_num();
 }
 
+// print InferenceResult
+std::ostream &operator<<(std::ostream &os, InferenceResult const &result) {
+  os << "InferenceResult {";
+  os << "MAX_NUM_TOKENS: " << InferenceResult::MAX_NUM_TOKENS << ", ";
+  os << "token_ids: [";
+  for (int i = 0; i < 16; i++) {
+    os << result.token_ids[i] << ", ";
+  }
+  os << "], ";
+  os << "finetuning_loss: " << result.finetuning_loss;
+  os << "}";
+  return os;
+}
+
 std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) {
   os << "@@@@@@@@@@@@@@ Batch Config (mode " << bc.get_mode()
      << ") @@@@@@@@@@@@@@" << std::endl;
diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc
index 3ebe6cf095..8ed4d75546 100644
--- a/src/runtime/file_loader.cc
+++ b/src/runtime/file_loader.cc
@@ -670,14 +670,20 @@ void load_from_quantized_file(char *ptr,
 
 void FileDataLoader::load_quantization_weight(FFModel *ff,
                                               Layer *l,
-                                              int weight_idx) {
-  Tensor weight = l->weights[weight_idx];
-  size_t volume = 1;
+                                              int weight_idx,
+                                              size_t volume,
+                                              size_t num_replicas,
+                                              char *weight,
+                                              DataType data_type,
+                                              Domain weight_domain) {
+  size_t volume_ = 1;
   std::vector<int> dims_vec;
-  for (int i = 0; i < weight->num_dims; i++) {
-    dims_vec.push_back(weight->dims[i]);
-    volume *= weight->dims[i];
+  for (int i = 0; i < weight_domain.get_dim(); i++) {
+    int dim_i = weight_domain.hi()[i] - weight_domain.lo()[i] + 1;
+    dims_vec.push_back(dim_i);
+    volume_ *= dim_i;
   }
+  assert(volume_ == volume * num_replicas);
   char *data = (char *)malloc(sizeof(char) * volume);
 
   std::string weight_filename = removeGuidOperatorName(std::string(l->name));
@@ -692,7 +698,7 @@ void FileDataLoader::load_quantization_weight(FFModel *ff,
                          qkv_inner_dim,
                          weight_filename,
                          weights_folder,
-                         weight->data_type,
+                         data_type,
                          use_full_precision);
   }
   // else {
@@ -714,31 +720,38 @@ void FileDataLoader::load_quantization_weight(FFModel *ff,
     load_from_quantized_file(data,
                              volume,
                              join_path({weights_folder, weight_filename}),
-                             weight->data_type,
+                             data_type,
                              use_full_precision);
   }
 
-  ParallelTensor weight_pt;
-  ff->get_parallel_tensor_from_tensor(weight, weight_pt);
-  weight_pt->set_tensor(ff, dims_vec, data);
+  char *ptr = weight;
+  for (size_t i = 0; i < num_replicas; i++) {
+    memcpy(ptr, data, volume * sizeof(char));
+    ptr += volume;
+  }
 
-  delete data;
+  free(data);
 }
 
 template <typename DT>
 void FileDataLoader::load_single_weight_tensor(FFModel *ff,
                                                Layer *l,
-                                               int weight_idx) {
-  Tensor weight = l->weights[weight_idx];
+                                               int weight_idx,
+                                               size_t volume,
+                                               size_t num_replicas,
+                                               DT *weight,
+                                               Domain weight_domain) {
   // Create a buffer to store weight data from the file
-  size_t volume = 1;
+  size_t volume_ = 1;
   std::vector<int> dims_vec;
-  for (int i = 0; i < weight->num_dims; i++) {
-    dims_vec.push_back(weight->dims[i]);
-    volume *= weight->dims[i];
+  for (int i = 0; i < weight_domain.get_dim(); i++) {
+    int dim_i = weight_domain.hi()[i] - weight_domain.lo()[i] + 1;
+    dims_vec.push_back(dim_i);
+    volume_ *= dim_i;
   }
-  assert(data_type_size(weight->data_type) == sizeof(DT));
+  assert(volume_ == volume * num_replicas);
+  // assert(data_type_size(weight->data_type) == sizeof(DT));
   DT *data = (DT *)malloc(sizeof(DT) * volume);
 
   std::string weight_filename = removeGuidOperatorName(std::string(l->name));
@@ -843,13 +856,15 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff,
     }
   }
 
-  // Copy the weight data from the buffer to the weight's ParallelTensor
-  ParallelTensor weight_pt;
-  ff->get_parallel_tensor_from_tensor(weight, weight_pt);
-  weight_pt->set_tensor(ff, dims_vec, data);
+  // Copy the weight data from the buffer to the weight
+  DT *ptr = weight;
+  for (size_t i = 0; i < num_replicas; i++) {
+    memcpy(ptr, data, volume * sizeof(DT));
+    ptr += volume;
+  }
 
   // Free buffer memory
-  delete data;
+  free(data);
 }
 
 void FileDataLoader::load_weight_task(
@@ -859,21 +874,44 @@
     Legion::Runtime *runtime) {
   WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
 
+  assert(task->regions.size() == regions.size());
+  assert(regions.size() == 1); // one weight only
+  GenericTensorAccessorW weight = helperGetGenericTensorAccessorWO(
+      args->data_type, regions[0], task->regions[0], FID_DATA, ctx, runtime);
+  Domain weight_domain = runtime->get_index_space_domain(
+      ctx, task->regions[0].region.get_index_space());
+
   switch (args->data_type) {
     case DT_HALF: {
-      args->loader->load_single_weight_tensor<half>(
-          args->ff, args->layer, args->weight_idx);
+      args->loader->load_single_weight_tensor(args->ff,
+                                              args->layer,
+                                              args->weight_idx,
+                                              args->volume,
+                                              args->num_replicas,
+                                              weight.get_half_ptr(),
+                                              weight_domain);
       break;
     }
     case DT_FLOAT: {
-      args->loader->load_single_weight_tensor<float>(
-          args->ff, args->layer, args->weight_idx);
+      args->loader->load_single_weight_tensor(args->ff,
+                                              args->layer,
+                                              args->weight_idx,
+                                              args->volume,
+                                              args->num_replicas,
+                                              weight.get_float_ptr(),
+                                              weight_domain);
       break;
     }
     case DT_INT4:
    case DT_INT8: {
-      args->loader->load_quantization_weight(
-          args->ff, args->layer, args->weight_idx);
+      args->loader->load_quantization_weight(args->ff,
+                                             args->layer,
+                                             args->weight_idx,
+                                             args->volume,
+                                             args->num_replicas,
+                                             weight.get_byte_ptr(),
+                                             args->data_type,
+                                             weight_domain);
       break;
     }
     default:
@@ -897,19 +935,38 @@ void FileDataLoader::load_weights_parallel(FFModel *ff,
         continue;
       }
 
-      if (l->op_type == OP_LORA) {
-        continue;
-      }
 
      if (weight->data_type != DT_FLOAT && weight->data_type != DT_HALF &&
          weight->data_type != DT_INT4 && weight->data_type != DT_INT8) {
        assert(false && "Unsupported data type");
      }
+      ParallelTensor weight_pt;
+      ff->get_parallel_tensor_from_tensor(weight, weight_pt);
+
      // Create task arguments
-      WeightLoadTaskArgs args(ff, this, l, i, weight->data_type);
+      size_t volume = 1, num_replicas = 1;
+      if (weight_pt->sync_type == ParameterSyncType::NCCL) {
+        for (int i = 0; i < weight_pt->num_dims; i++) {
+          if (weight_pt->dims[i].is_replica_dim) {
+            num_replicas *= weight_pt->dims[i].size;
+          }
+        }
+      } else if (weight_pt->sync_type == ParameterSyncType::PS) {
+        num_replicas = 1;
+      } else {
+        num_replicas = 1;
+      }
+      for (int i = 0; i < weight->num_dims; i++) {
+        volume *= weight->dims[i];
+      }
+      WeightLoadTaskArgs args(
+          ff, this, l, i, volume, num_replicas, weight->data_type);
+
      // launch task asynchronously
      TaskLauncher launcher(LOAD_WEIGHT_TASK_ID,
                            TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
+      launcher.add_region_requirement(RegionRequirement(
+          weight_pt->region, WRITE_ONLY, EXCLUSIVE, weight_pt->region));
+      launcher.add_field(0, FID_DATA);
      futures.push_back(runtime->execute_task(ctx, launcher));
    }
  }
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index c33c0c6e5e..6ad77652ea 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -54,44 +54,36 @@ RequestGuid RequestManager::assign_next_guid() {
   return next_available_guid++;
 }
 
-Request::Request(Request const &other)
-    : req_type(other.req_type), max_length(other.max_length),
-      max_new_tokens(other.max_new_tokens),
-      benchmarking_tokens(other.benchmarking_tokens),
-      add_special_tokens(other.add_special_tokens), warmup(other.warmup),
-      status(Request::PENDING), prompt(other.prompt), tokens(other.tokens),
-      peft_model_id(other.peft_model_id),
-      peft_finetuning_info(other.peft_finetuning_info),
-      initial_len(other.initial_len), ssm_cache_size(other.ssm_cache_size),
-      llm_cache_size(other.llm_cache_size), beam_trees(other.beam_trees) {
-
+Request Request::from_other(Request const &other) {
   RequestManager *rm = RequestManager::get_request_manager();
-  guid = rm->assign_next_guid();
+  Request req = other;
+  req.guid = rm->assign_next_guid();
   int max_seq_len = rm->get_max_sequence_length();
-  if (req_type == RequestType::REQ_INFERENCE) {
+  if (req.req_type == RequestType::REQ_INFERENCE) {
     // both unset
-    if (max_length == -1 && max_new_tokens == -1) {
-      max_length = max_seq_len - 1;
+    if (req.max_length == -1 && req.max_new_tokens == -1) {
+      req.max_length = max_seq_len - 1;
     }
     // both set
-    if (max_length != -1 && max_new_tokens != -1) {
-      max_length = -1;
+    if (req.max_length != -1 && req.max_new_tokens != -1) {
+      req.max_length = -1;
       std::cout
-          << "Both `max_new_tokens` (=" << max_new_tokens
-          << ") and `max_length`(=" << max_length
+          << "Both `max_new_tokens` (=" << req.max_new_tokens
+          << ") and `max_length`(=" << req.max_length
           << ") seem to have been set. `max_new_tokens` will take precedence.";
     }
   } else {
-    if (max_new_tokens != -1) {
+    if (req.max_new_tokens != -1) {
      std::cerr
          << "Error: max_new_tokens is not allowed for PEFT finetuning requests"
          << std::endl;
      assert(false);
    }
-    if (max_length == -1) {
-      max_length = max_seq_len - 1;
+    if (req.max_length == -1) {
+      req.max_length = max_seq_len - 1;
     }
   }
+  return req;
 }
 
 bool RequestManager::load_request_token_ids(Request &request) {
@@ -163,6 +155,7 @@
           request.benchmarking_tokens - (int)bos_added,
           15); // insert random number
       request.dataset.push_back(input_tokens);
+      std::cout << "Creating dataset with benchmarking tokens. Size of dataset: " << request.dataset.size() << std::endl;
     } else {
       using json = nlohmann::json;
       std::ifstream file_handle(request.peft_finetuning_info.dataset_filepath);
@@ -190,6 +183,7 @@
         request.dataset.push_back(input_tokens);
       }
     }
+    std::cout << "Creating dataset from json file: " << request.peft_finetuning_info.dataset_filepath << ". Size of dataset: " << request.dataset.size() << std::endl;
Size of dataset: " << request.dataset.size() << std::endl; } if (request.peft_finetuning_info.gradient_accumulation_steps == -1) { request.peft_finetuning_info.gradient_accumulation_steps = @@ -275,6 +269,8 @@ RequestManager::RequestManager() max_sequence_length = -1; } +void RequestManager::set_verbose(bool verbose_) { verbose = verbose_; } + void RequestManager::set_max_requests_per_batch(int max_num_requests) { assert(max_requests_per_batch == -1 || max_requests_per_batch == max_num_requests); @@ -438,6 +434,15 @@ void RequestManager::set_peft_config(PEFTModelID const &peft_model_id, LoraLinearConfig const & RequestManager::get_peft_config(PEFTModelID const &peft_model_id) { + if (peft_configs.find(peft_model_id) == peft_configs.end()) { + std::cout << "PEFT model ID not found" << std::endl; + std::cout << peft_model_id << std::endl; + // print all registerd peft model ids + std::cout << "Registered PEFT model IDs:" << std::endl; + for (auto const &pair : peft_configs) { + std::cout << pair.first << std::endl; + } + } assert(peft_configs.find(peft_model_id) != peft_configs.end() && "PEFT model ID not found"); return peft_configs[peft_model_id]; @@ -506,13 +511,15 @@ PEFTModelID * PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++); RequestManager *rm = RequestManager::get_request_manager(); rm->set_peft_config(*peft_model_id, peft_config); + std::cout << "Registered PEFT adapter with id: " << *peft_model_id + << std::endl; return peft_model_id; } RequestGuid RequestManager::register_new_request(Request const &request_) { const std::lock_guard lock(request_queue_mutex); // Add a new request - Request request(request_); + Request request = Request::from_other(request_); if (!load_request_token_ids(request)) { return BatchConfig::INVALID_GUID; } @@ -524,13 +531,18 @@ RequestGuid RequestManager::register_new_request(Request const &request_) { request_to_promise[request.guid] = new std::promise(); } - { - std::string output = "New request tokens:"; - output = "[" + std::to_string(request.guid) + "]" + output; - for (int i = 0; i < request.tokens.size(); i++) { - output = output + " " + std::to_string(request.tokens[i]); - } - log_req_mgr.print("%s", output.c_str()); + // { + // std::string output = "New request tokens:"; + // output = "[" + std::to_string(request.guid) + "]" + output; + // for (int i = 0; i < request.tokens.size(); i++) { + // output = output + " " + std::to_string(request.tokens[i]); + // } + // log_req_mgr.print("%s", output.c_str()); + // } + if (verbose) { + std::cout << "Registered new request with guid: " << request.guid + << std::endl; + std::cout << request << std::endl; } GenerationResult gr; @@ -552,7 +564,7 @@ RequestGuid RequestManager::register_new_peft_request(Request const &request_) { assert(enable_peft_finetuning && "PEFT finetuning is not enabled"); const std::lock_guard lock(request_queue_mutex); // Add a new request - Request request(request_); + Request request = Request::from_other(request_); if (!load_request_token_ids(request)) { return BatchConfig::INVALID_GUID; } @@ -564,12 +576,17 @@ RequestGuid RequestManager::register_new_peft_request(Request const &request_) { request_to_promise[request.guid] = new std::promise(); } - for (size_t r = 0; r < request.dataset.size(); r++) { - std::string input = "[" + std::to_string(r) + "] entry:"; - for (size_t i = 0; i < request.dataset[r].size(); i++) { - input = input + " " + std::to_string(request.dataset[r][i]); - } - log_req_mgr.print("%s", input.c_str()); + // for (size_t r = 0; r < 
+  //   std::string input = "[" + std::to_string(r) + "] entry:";
+  //   for (size_t i = 0; i < request.dataset[r].size(); i++) {
+  //     input = input + " " + std::to_string(request.dataset[r][i]);
+  //   }
+  //   log_req_mgr.print("%s", input.c_str());
+  // }
+  if (verbose) {
+    std::cout << "Registered new request with guid: " << request.guid
+              << std::endl;
+    std::cout << request << std::endl;
   }
 
   GenerationResult gr;
@@ -1349,6 +1366,13 @@ void RequestManager::process_work_from_old_batches(
     InferenceResult const &result) {
   const std::lock_guard<std::mutex> lock(request_queue_mutex);
 
+  if (verbose) {
+    std::cout << "\n############### process_work_from_old_batches ###############\n";
+    std::cout << "old_fwd_bc: " << old_fwd_bc << std::endl;
+    std::cout << "old_bwd_bc: " << old_bwd_bc << std::endl;
+    std::cout << "result: " << result << std::endl;
+  }
+
   // Step 1: Inference. Process work from previous fwd iteration: save generated
   // inference tokens and update records of finetuning fwd progress
   process_inf_req_progress(old_fwd_bc, result);
@@ -1380,6 +1404,12 @@ BatchConfig
     InferenceResult const &result) {
   const std::lock_guard<std::mutex> lock(request_queue_mutex);
 
+  if (verbose) {
+    std::cout << "\n############### prepare_next_fwd_batch ###############\n";
+    std::cout << "old_fwd_bc: " << old_fwd_bc << std::endl;
+    std::cout << "result: " << result << std::endl;
+  }
+
   // Step 1: Create new batch config
   BatchConfig new_bc;
   // params
@@ -3376,7 +3406,7 @@ void RequestManager::serve_incr_decoding(FFModel *llm) {
       BatchConfigFuture bcf_fwd = next_batches.first;
       BatchConfigFuture bcf_bwd = next_batches.second;
       InferenceResultFuture irf = im->inference(llm, 0, bcf_fwd);
-      FinetuningBwdFuture bwd_f = im->peft_bwd(llm, 0, bcf_bwd);
+      FinetuningBwdFuture bwd_f = (llm->config.enable_peft) ? im->peft_bwd(llm, 0, bcf_bwd) : Future::from_value(true);
       batch_pipeline.push(std::make_tuple(bcf_fwd, bcf_bwd, irf, bwd_f));
       last_bcf_fwd = bcf_fwd;
       last_bcf_bwd = bcf_bwd;