Commit

build fix
goliaro committed Nov 25, 2024
1 parent 9c0d827 commit 784c8d9
Showing 11 changed files with 453 additions and 286 deletions.
58 changes: 38 additions & 20 deletions include/flexflow/request_manager.h
@@ -37,9 +37,12 @@ class InferenceManager {
static InferenceManager *get_inference_manager();
void compile_model_and_allocate_buffer(FFModel *model);
void init_operators_inference(FFModel *model);
InferenceResultFuture inference(FFModel *model, int index, BatchConfig const &bc);
InferenceResultFuture inference(FFModel *model, int index, BatchConfigFuture const &bc);
FinetuningBwdFuture peft_bwd(FFModel *model, int index, BatchConfigFuture const &bc);
InferenceResultFuture
inference(FFModel *model, int index, BatchConfig const &bc);
InferenceResultFuture
inference(FFModel *model, int index, BatchConfigFuture const &bc);
FinetuningBwdFuture
peft_bwd(FFModel *model, int index, BatchConfigFuture const &bc);
void load_input_tokens_from_batch_config(FFModel *model,
BatchConfigFuture const &bc,
ParallelTensor const input,
@@ -67,7 +70,7 @@ struct Request {
};
enum FinetuningStatus {
FORWARD_PHASE = 201,
BACKWARD_PHASE = 202,
BACKWARD_PHASE = 202,
};
struct PeftFinetuningInfo {
FinetuningStatus status = FORWARD_PHASE;
@@ -80,8 +83,8 @@
std::vector<float> finetuning_losses;
// bwd state
int last_processed_layer = INT_MAX;
// how many gradient accumulation steps to do before updating the weights. if
// left as -1, it will be set to the number of entries in the dataset
// how many gradient accumulation steps to do before updating the weights.
// if left as -1, it will be set to the number of entries in the dataset
int gradient_accumulation_steps = -1;
// std::vector<int> finetuning_tokens_per_batch;
};
@@ -96,20 +99,20 @@ struct Request {
// inference fields
std::string prompt;
std::vector<BatchConfig::TokenId> tokens;

// peft fields
PEFTModelID peft_model_id = PEFTModelID::NO_ID;
PeftFinetuningInfo peft_finetuning_info;
std::vector<std::vector<BatchConfig::TokenId>> dataset;

// speculation fields
int initial_len = 0;
int ssm_cache_size = 0;
int llm_cache_size = 0;
std::vector<struct BeamTree> beam_trees;

Request() = default;
Request(const Request& other);
Request(Request const &other);
void load_token_ids();

friend std::ostream &operator<<(std::ostream &os, Request const &req);
@@ -214,25 +217,40 @@ class RequestManager {
void add_peft_config_to_request_info(BatchConfig &bc,
int req_idx,
LoraLinearConfig const &peft_config);

// helpers for prepare_next_batch
void process_inf_req_progress(BatchConfig const &old_fwd_bc, InferenceResult const &result);
void process_inf_req_progress(BatchConfig const &old_fwd_bc,
InferenceResult const &result);
void handle_completed_inf_req(BatchConfig const &old_bc, int i);
void add_continuing_inf_req_to_new_batch(BatchConfig &new_bc, BatchConfig const &old_bc, int &num_active_req, int &num_concurrent_inf_adapters, int i);
void add_new_inf_req(BatchConfig &new_bc, int &num_active_req, int &num_concurrent_inf_adapters, int i);
void add_continuing_inf_req_to_new_batch(BatchConfig &new_bc,
BatchConfig const &old_bc,
int &num_active_req,
int &num_concurrent_inf_adapters,
int i);
void add_new_inf_req(BatchConfig &new_bc,
int &num_active_req,
int &num_concurrent_inf_adapters,
int i);
void handle_completed_finetuning_req(BatchConfig const &old_finetuning_bc);
void add_finetuning_req_fwd_batch(BatchConfig &new_bc);
void add_finetuning_req_bwd_batch(BatchConfig &new_bc);
bool finetuning_fwd_work_available();
bool finetuning_bwd_work_available();
void process_finetuning_req_fwd_progress(BatchConfig const &old_fwd_bc, InferenceResult const &result);
void process_finetuning_req_fwd_progress(BatchConfig const &old_fwd_bc,
InferenceResult const &result);
void process_finetuning_req_bwd_progress(BatchConfig const &old_bwd_bc);
void process_work_from_old_batches(BatchConfig const &old_fwd_bc, BatchConfig const &old_bwd_bc, InferenceResult const &result);
void process_work_from_old_batches(BatchConfig const &old_fwd_bc,
BatchConfig const &old_bwd_bc,
InferenceResult const &result);
BatchConfig prepare_next_bwd_batch();
BatchConfig prepare_next_fwd_batch(BatchConfig const &old_fwd_bc, InferenceResult const &result);
BatchConfigPairFuture prepare_next_batch(std::tuple<BatchConfigFuture, BatchConfigFuture, InferenceResultFuture, FinetuningBwdFuture> &batch_pipeline_entry,
Context ctx,
Runtime *runtime);
BatchConfig prepare_next_fwd_batch(BatchConfig const &old_fwd_bc,
InferenceResult const &result);
BatchConfigPairFuture
prepare_next_batch(std::tuple<BatchConfigPairFuture,
InferenceResultFuture,
FinetuningBwdFuture> &batch_pipeline_entry,
Context ctx,
Runtime *runtime);
// BatchConfig prepare_next_batch(BatchConfig const &bc,
// InferenceResult const &result);
// BatchConfigFuture prepare_next_batch(BatchConfigFuture const &bc,
@@ -311,7 +329,7 @@ class RequestManager {
Legion::Context ctx,
Legion::Runtime *runtime);
static std::pair<BatchConfig, BatchConfig> prepare_next_batch_task(
Legion::Task const *task,
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
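For orientation, here is a minimal sketch (not part of this commit) of how the reshaped prepare_next_batch entry point above might be driven: the pipeline entry now bundles a single BatchConfigPairFuture together with the inference and finetuning-backward futures, instead of carrying two separate BatchConfigFutures. The driver function, its variable names, and the assumption that prepare_next_batch is callable from this context are illustrative only.

#include "flexflow/request_manager.h"

using namespace FlexFlow;

// Sketch only: rm and the three futures are assumed to come from the previous
// pipeline step, mirroring the declaration in the diff above:
//   BatchConfigPairFuture prepare_next_batch(
//       std::tuple<BatchConfigPairFuture, InferenceResultFuture,
//                  FinetuningBwdFuture> &batch_pipeline_entry,
//       Context ctx, Runtime *runtime);
BatchConfigPairFuture drive_one_step(RequestManager *rm,
                                     BatchConfigPairFuture prev_batches,
                                     InferenceResultFuture inf_result,
                                     FinetuningBwdFuture bwd_done,
                                     Legion::Context ctx,
                                     Legion::Runtime *runtime) {
  // Bundle the previous step's outputs into one pipeline entry and ask the
  // request manager for the next forward/backward batch pair.
  std::tuple<BatchConfigPairFuture, InferenceResultFuture, FinetuningBwdFuture>
      batch_pipeline_entry{prev_batches, inf_result, bwd_done};
  return rm->prepare_next_batch(batch_pipeline_entry, ctx, runtime);
}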
19 changes: 10 additions & 9 deletions inference/incr_decoding/incr_decoding.cc
@@ -275,10 +275,11 @@ void FlexFlow::top_level_task(Task const *task,
using json = nlohmann::json;
std::ifstream file_handle(file_paths.prompt_file_path);
assert(file_handle.good() && "Prompt file does not exist.");
nlohmann::ordered_json prompt_json = nlohmann::ordered_json::parse(file_handle,
/*parser_callback_t */ nullptr,
/*allow_exceptions */ true,
/*ignore_comments */ true);
nlohmann::ordered_json prompt_json =
nlohmann::ordered_json::parse(file_handle,
/*parser_callback_t */ nullptr,
/*allow_exceptions */ true,
/*ignore_comments */ true);
file_handle.close();
auto &metadata = prompt_json["metadata"];
int num_warmup_requests = metadata["num_warmup_requests"];
@@ -289,7 +290,7 @@ void FlexFlow::top_level_task(Task const *task,
int response_length = entry["response_length"];
std::string text = entry["prompt"];
bool is_warmup_request = total_requests < num_warmup_requests;

Request inference_req;
inference_req.prompt = text;
inference_req.add_special_tokens = false;
@@ -302,18 +303,19 @@ void FlexFlow::top_level_task(Task const *task,
requests.push_back(inference_req);
num_regular_requests++;
}

total_requests++;
}
std::vector<GenerationResult> warmup_result = model.generate(warmup_requests);
std::vector<GenerationResult> warmup_result =
model.generate(warmup_requests);
std::vector<GenerationResult> result = model.generate(requests);

assert(warmup_result.size() == warmup_requests.size());
assert(result.size() == requests.size());
assert(result.size() + warmup_result.size() == total_requests);
int i = 0;
for (auto &entry : prompt_json["entries"]) {
if (i<num_warmup_requests) {
if (i < num_warmup_requests) {
i++;
continue;
}
@@ -337,7 +339,6 @@ void FlexFlow::top_level_task(Task const *task,
} else {
std::cerr << "Unable to open file for writing." << std::endl;
}

}

// terminate the request manager by stopping the background thread
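As a reference for the loader above, the following standalone sketch (not part of this commit) writes a prompt file containing only the fields the diff actually reads: metadata.num_warmup_requests plus per-entry "prompt" and "response_length", where the first num_warmup_requests entries are treated as warmup requests. The file name and the prompt texts are placeholders.

#include <fstream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::ordered_json prompt_json;
  prompt_json["metadata"]["num_warmup_requests"] = 1;
  prompt_json["entries"] = nlohmann::ordered_json::array();
  // Entry 0 is consumed as a warmup request (index < num_warmup_requests).
  prompt_json["entries"].push_back(
      {{"prompt", "Warm up the model."}, {"response_length", 16}});
  prompt_json["entries"].push_back(
      {{"prompt", "Summarize incremental decoding in one sentence."},
       {"response_length", 64}});
  // Write the file that would be passed as file_paths.prompt_file_path.
  std::ofstream("prompts.json") << prompt_json.dump(2);
  return 0;
}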
3 changes: 2 additions & 1 deletion inference/peft/peft.cc
@@ -375,7 +375,8 @@ void FlexFlow::top_level_task(Task const *task,
fine_tuning_req.peft_model_id = (peft_model_id_finetuning != nullptr)
? *peft_model_id_finetuning
: PEFTModelID::NO_ID;
fine_tuning_req.peft_finetuning_info.dataset_filepath = file_paths.dataset_file_path;
fine_tuning_req.peft_finetuning_info.dataset_filepath =
file_paths.dataset_file_path;
fine_tuning_req.peft_finetuning_info.max_training_steps = 2;
requests.push_back(fine_tuning_req);
}
4 changes: 2 additions & 2 deletions inference/peft/peft_bwd_benchmark.cc
@@ -334,7 +334,7 @@ void FlexFlow::top_level_task(Task const *task,
fine_tuning_req.warmup = true;
fine_tuning_req.peft_model_id =
(peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
fine_tuning_req.max_training_steps = 1;
fine_tuning_req.peft_finetuning_info.max_training_steps = 1;
requests.push_back(fine_tuning_req);
std::vector<GenerationResult> result = model.generate(requests);
}
@@ -377,7 +377,7 @@ void FlexFlow::top_level_task(Task const *task,
fine_tuning_req.max_length = lengths[i];
fine_tuning_req.peft_model_id =
(peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
fine_tuning_req.max_training_steps = 1;
fine_tuning_req.peft_finetuning_info.max_training_steps = 1;
requests.push_back(fine_tuning_req);
}
std::vector<GenerationResult> result = model.generate(requests);
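The peft.cc and benchmark changes above all follow the same pattern: training knobs that used to sit directly on Request (for example max_training_steps) now live under Request::peft_finetuning_info. A minimal sketch (not part of this commit) of populating a fine-tuning request after this move, with placeholder values for the dataset path and step counts:

// Sketch only: field names are taken from the header diff above; peft_model_id
// and requests are assumed to exist as in the surrounding benchmark code.
Request fine_tuning_req;
fine_tuning_req.peft_model_id =
    (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
fine_tuning_req.peft_finetuning_info.dataset_filepath = "/path/to/dataset.json";
fine_tuning_req.peft_finetuning_info.max_training_steps = 2;
// -1 keeps the default: one accumulation step per entry in the dataset.
fine_tuning_req.peft_finetuning_info.gradient_accumulation_steps = -1;
requests.push_back(fine_tuning_req);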
20 changes: 10 additions & 10 deletions inference/peft/req_rate_benchmark.cc
@@ -35,8 +35,8 @@ Legion::Logger log_app("llama");

class ConcurrentQueue {
public:
std::queue<RequestManager::RequestGuid> inf_queue;
std::queue<RequestManager::RequestGuid> peft_queue;
std::queue<BatchConfig::RequestGuid> inf_queue;
std::queue<BatchConfig::RequestGuid> peft_queue;
std::mutex request_queue_mutex;
bool producer_finished = false;
};
@@ -58,7 +58,7 @@ void consume() {
bool queue_is_empty = false;
// int i=0;
while (!producer_is_finished || !queue_is_empty) {
RequestManager::RequestGuid guid = RequestManager::INVALID_GUID;
BatchConfig::RequestGuid guid = BatchConfig::INVALID_GUID;
{
const std::lock_guard<std::mutex> lock(guids->request_queue_mutex);
queue_is_empty = guids->inf_queue.empty();
@@ -68,7 +68,7 @@ void consume() {
guids->inf_queue.pop();
}
}
if (guid != RequestManager::INVALID_GUID) {
if (guid != BatchConfig::INVALID_GUID) {
GenerationResult result = rm->get_generation_result(guid);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(nb_millisecs));
@@ -396,7 +396,7 @@ void FlexFlow::top_level_task(Task const *task,
fine_tuning_req.warmup = true;
fine_tuning_req.peft_model_id =
(peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
fine_tuning_req.max_training_steps = 1;
fine_tuning_req.peft_finetuning_info.max_training_steps = 1;
requests.push_back(fine_tuning_req);
std::vector<GenerationResult> result = model.generate(requests);
}
@@ -459,10 +459,10 @@ void FlexFlow::top_level_task(Task const *task,
fine_tuning_req.max_length = 1024;
fine_tuning_req.peft_model_id =
(peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID;
fine_tuning_req.max_training_steps = 1000000000;
RequestManager::RequestGuid ft_guid =
fine_tuning_req.peft_finetuning_info.max_training_steps = 1000000000;
BatchConfig::RequestGuid ft_guid =
rm->register_new_peft_request(fine_tuning_req);
if (ft_guid != RequestManager::INVALID_GUID) {
if (ft_guid != BatchConfig::INVALID_GUID) {
const std::lock_guard<std::mutex> lock(guids->request_queue_mutex);
guids->peft_queue.push(ft_guid);
}
@@ -495,9 +495,9 @@ void FlexFlow::top_level_task(Task const *task,
{
const std::lock_guard<std::mutex> lock(guids->request_queue_mutex);
for (int i = 0; i < requests.size(); i++) {
RequestManager::RequestGuid guid =
BatchConfig::RequestGuid guid =
rm->register_new_request(requests.at(i));
if (guid != RequestManager::INVALID_GUID) {
if (guid != BatchConfig::INVALID_GUID) {
guids->inf_queue.push(guid);
}
}
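The request-rate benchmark changes above move RequestGuid and INVALID_GUID from RequestManager to BatchConfig. A short sketch (not part of this commit) of registering a request with the relocated type, assuming rm and a populated Request req as in the code above:

// Sketch only: mirrors the producer/consumer pattern in req_rate_benchmark.cc.
BatchConfig::RequestGuid guid = rm->register_new_request(req);
if (guid != BatchConfig::INVALID_GUID) {
  // Fetch the result as the consumer loop above does.
  GenerationResult result = rm->get_generation_result(guid);
}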
3 changes: 2 additions & 1 deletion src/c/flexflow_c.cc
@@ -1747,7 +1747,8 @@ void flexflow_model_generate(flexflow_model_t handle_,
}
std::string const dataset_fp(dataset_filepaths[i]);
fine_tuning_req.peft_finetuning_info.dataset_filepath = dataset_fp;
fine_tuning_req.peft_finetuning_info.max_training_steps = training_steps[i];
fine_tuning_req.peft_finetuning_info.max_training_steps =
training_steps[i];
requests.push_back(fine_tuning_req);
DEBUG_PRINT("[Model] finetune[%d] %p %s %i %i %i %i",
i,
26 changes: 13 additions & 13 deletions src/ops/inc_multihead_self_attention.cpp
@@ -698,19 +698,19 @@ __global__ void scaling_query_kernel(DT *input_ptr,
template <typename DT>
__global__ void
apply_rotary_embedding_fwd(DT *input_ptr,
hipFloatComplex *complex_input,
BatchConfig::PerTokenInfo const *tokenInfos,
float rope_theta,
bool llama3_rope,
float factor,
float low_freq_factor,
float high_freq_factor,
int original_max_position_embeddings,
int qProjSize,
int kProjSize,
int num_tokens,
size_t q_array_size,
int hidden_size) {
hipFloatComplex *complex_input,
BatchConfig::PerTokenInfo const *tokenInfos,
float rope_theta,
bool llama3_rope,
float factor,
float low_freq_factor,
float high_freq_factor,
int original_max_position_embeddings,
int qProjSize,
int kProjSize,
int num_tokens,
size_t q_array_size,
int hidden_size) {
CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) {
// create complex number
bool q_tensor = i < (q_array_size / 2);
32 changes: 16 additions & 16 deletions src/ops/inc_multihead_self_attention.cu
@@ -654,19 +654,19 @@ __global__ void scaling_query_kernel(DT *input_ptr,
template <typename DT>
__global__ void
apply_rotary_embedding_fwd(DT *input_ptr,
cuFloatComplex *complex_input,
BatchConfig::PerTokenInfo const *tokenInfos,
float rope_theta,
bool llama3_rope,
float factor,
float low_freq_factor,
float high_freq_factor,
int original_max_position_embeddings,
int qProjSize,
int kProjSize,
int num_tokens,
size_t q_array_size,
int hidden_size) {
cuFloatComplex *complex_input,
BatchConfig::PerTokenInfo const *tokenInfos,
float rope_theta,
bool llama3_rope,
float factor,
float low_freq_factor,
float high_freq_factor,
int original_max_position_embeddings,
int qProjSize,
int kProjSize,
int num_tokens,
size_t q_array_size,
int hidden_size) {
CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) {
// create complex number
bool q_tensor = i < (q_array_size / 2);
@@ -826,9 +826,9 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
/*q&k*/
parallelism = num_tokens * m->hidden_size;
apply_rotary_embedding_fwd<<<GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(
output_ptr,
m->complex_input,
m->token_infos,
12 changes: 9 additions & 3 deletions src/runtime/inference_manager.cc
@@ -380,7 +380,9 @@ void InferenceManager::init_operators_inference(FFModel *model) {
}
}

InferenceResultFuture InferenceManager::inference(FFModel *model, int index, BatchConfig const &bc) {
InferenceResultFuture InferenceManager::inference(FFModel *model,
int index,
BatchConfig const &bc) {
if (bc.get_mode() == INC_DECODING_MODE) {
BatchConfigFuture bcf = Future::from_value<BatchConfig>(bc);
return inference(model, index, bcf);
@@ -403,7 +405,9 @@ InferenceResultFuture InferenceManager::inference(FFModel *model, int index, Bat
}
}

InferenceResultFuture InferenceManager::inference(FFModel *model, int index, BatchConfigFuture const &bc) {
InferenceResultFuture InferenceManager::inference(FFModel *model,
int index,
BatchConfigFuture const &bc) {
// log_inf_mgr.print("mode(%d) num_active_infr_tokens(%d)
// num_active_requests(%d)",
// bc.get_mode(),
@@ -465,7 +469,9 @@ InferenceResultFuture InferenceManager::inference(FFModel *model, int index, Bat
return irf;
};

FinetuningBwdFuture InferenceManager::peft_bwd(FFModel *model, int index, BatchConfigFuture const &bc) {
FinetuningBwdFuture InferenceManager::peft_bwd(FFModel *model,
int index,
BatchConfigFuture const &bc) {
int batch_index = index % model->config.data_parallelism_degree;
FutureMap fm;
bool found_input_operator = false;
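For completeness, a short sketch (not part of this commit) of the two inference() overloads above: passing a concrete BatchConfig simply wraps it in a future, mirroring the Future::from_value<BatchConfig>(bc) call in the implementation, and forwards to the future-based overload that the request manager uses. The surrounding variables and the index value are assumptions.

// Sketch only: assumes model and a populated BatchConfig bc already exist.
InferenceManager *im = InferenceManager::get_inference_manager();
InferenceResultFuture r1 = im->inference(model, /*index=*/0, bc);   // concrete-config overload
BatchConfigFuture bcf = Future::from_value<BatchConfig>(bc);        // same wrapping done internally
InferenceResultFuture r2 = im->inference(model, /*index=*/0, bcf);  // future-based overload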