diff --git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h
index 7138f62e90..b17868fb96 100644
--- a/include/flexflow/ops/kernels/lora_linear_kernels.h
+++ b/include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -52,6 +52,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
                              Runtime *runtime,
                              LoraLinearMeta *m,
                              BatchConfig const *bc,
+                             int shard_id,
                              GenericTensorAccessorW const &input_grad,
                              GenericTensorAccessorR const &output_grad);
@@ -71,6 +72,7 @@ void peft_bwd_kernel(Context ctx,
                     Runtime *runtime,
                     LoraLinearMeta *m,
                     BatchConfig const *bc,
+                    int shard_id,
                     DT *input_grad_ptr,
                     DT const *output_grad_ptr,
                     int in_dim,
diff --git a/include/flexflow/utils/peft_weight_allocator.h b/include/flexflow/utils/peft_weight_allocator.h
index bd8ddb1dce..21ac9bf426 100644
--- a/include/flexflow/utils/peft_weight_allocator.h
+++ b/include/flexflow/utils/peft_weight_allocator.h
@@ -23,74 +23,6 @@
 
 namespace FlexFlow {
 
-#ifdef DEADCODE
-class PEFTWeightAllocator {
-public:
-  PEFTWeightAllocator(void *_base_ptr, size_t _total_size)
-      : base_ptr(_base_ptr), total_size(_total_size), sync_offset(0),
-        local_offset(_total_size) {}
-
-  inline void *allocate_sync_weights_untyped(PEFTModelID const &peft_model_id,
-                                             size_t datalen) {
-    const std::lock_guard<std::mutex> lock(peft_weight_allocator_mutex);
-    void *ptr = static_cast<char *>(base_ptr) + sync_offset;
-    off_t model_sync_weights_offset = sync_offset;
-    size_t model_sync_weights_size = datalen;
-    if (sync_weights.find(peft_model_id) != sync_weights.end()) {
-      // Assert that sync weights for each PEFT model is consecutive
-      std::pair<off_t, size_t> offset_and_size = sync_weights[peft_model_id];
-      assert(sync_offset == offset_and_size.first + offset_and_size.second);
-      model_sync_weights_offset = offset_and_size.first;
-      model_sync_weights_size = offset_and_size.second + datalen;
-    }
-    sync_offset += datalen;
-    assert(sync_offset < local_offset);
-    sync_weights[peft_model_id] =
-        std::make_pair(model_sync_weights_offset, model_sync_weights_size);
-    return ptr;
-  }
-
-  std::pair<void *, size_t>
-      get_sync_weights_ptr_and_size(PEFTModelID const &peft_model_id) {
-    const std::lock_guard<std::mutex> lock(peft_weight_allocator_mutex);
-    assert(sync_weights.find(peft_model_id) != sync_weights.end());
-    std::pair<off_t, size_t> offset_and_size = sync_weights[peft_model_id];
-    return std::make_pair(static_cast<char *>(base_ptr) +
-                              offset_and_size.first,
-                          offset_and_size.second);
-  }
-
-  inline void *allocate_local_weights_untyped(PEFTModelID const &peft_model_id,
-                                              size_t datalen) {
-    const std::lock_guard<std::mutex> lock(peft_weight_allocator_mutex);
-    local_offset -= datalen;
-    assert(sync_offset < local_offset);
-    void *ptr = static_cast<char *>(base_ptr) + local_offset;
-    return ptr;
-  }
-
-  template <typename DT>
-  inline DT *allocate_sync_weights(PEFTModelID const &peft_model_id,
-                                   size_t count) {
-    return static_cast<DT *>(
-        allocate_sync_weights_untyped(peft_model_id, sizeof(DT) * count));
-  }
-
-  template <typename DT>
-  inline DT *allocate_local_weights(PEFTModelID const &peft_model_id,
-                                    size_t count) {
-    return static_cast<DT *>(
-        allocate_local_weights_untyped(peft_model_id, sizeof(DT) * count));
-  }
-
-public:
-  void *base_ptr;
-  size_t total_size;
-  off_t sync_offset, local_offset;
-  std::unordered_map<PEFTModelID, std::pair<off_t, size_t>> sync_weights;
-  std::mutex peft_weight_allocator_mutex;
-};
-#endif
-
 struct LoraLinearWeight {
   // weights
   void *w0_ptr, *w1_ptr;
diff --git a/src/ops/fused.cu b/src/ops/fused.cu
index 62845c0f8e..c615a104d2 100644
--- a/src/ops/fused.cu
+++ b/src/ops/fused.cu
@@ -889,11 +889,13 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
       // Assert that the output and the second input are at the same place
       // since we ``inplace'' the output for LoRA
       assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr);
+      int shard_id = task->index_point.point_data[0];
       Kernels::LoraLinear::peft_bwd_kernel_wrapper(
           ctx,
           runtime,
          m,
           bc,
+          shard_id,
           my_input_grad_accessor[0],
           my_output_grad_accessor[0]);
       break;
diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu
index 09d79809a7..dabe40c501 100644
--- a/src/ops/kernels/lora_linear_kernels.cu
+++ b/src/ops/kernels/lora_linear_kernels.cu
@@ -24,31 +24,34 @@ namespace FlexFlow {
 
 LoraLinearMeta::LoraLinearMeta(FFHandler handler, LoraLinear const *li)
     : OpMeta(handler, li) {
-#ifdef DEADCODE
-  allocated_peft_buffer_size1 = 0;
-  allocated_peft_buffer_size2 = 0;
-#endif
 }
 
 LoraLinearMeta::~LoraLinearMeta(void) {}
 
+std::string get_peft_dbg_folder(LoraLinearMeta const *m,
+                                int shard_id,
+                                bool is_fwd) {
+  std::string op_name_without_uid = LoraLinear::get_op_name_without_uid(m);
+  fs::path dst_filepath;
+  if (is_fwd) {
+    dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id);
+  } else {
+    dst_filepath = get_dst_folder("bwd", m->bwd_step, shard_id);
+  }
+  if (m->layer_guid.model_id > 0) {
+    assert(false && "Model ID > 0 not supported yet");
+  }
+  std::string layername = "layers." +
+                          std::to_string(m->layer_guid.transformer_layer_id) +
+                          "." + op_name_without_uid;
+  dst_filepath /= layername;
+  return dst_filepath.string();
+}
+
 namespace Kernels {
 namespace LoraLinear {
 
-#ifdef DEADCODE
-void init_kernel_wrapper(LoraLinearMeta *m, int seed) {
-  cudaStream_t stream;
-  checkCUDA(get_legion_stream(&stream));
-  if (m->input_type[0] == DT_FLOAT) {
-    Internal::init_kernel<float>(m, seed, stream);
-  } else if (m->input_type[0] == DT_HALF) {
-    Internal::init_kernel<half>(m, seed, stream);
-  } else {
-    assert(false && "Unsupported data type");
-  }
-}
-#endif
 
 void inference_kernel_wrapper(LoraLinearMeta *m,
                               BatchConfig const *bc,
@@ -104,6 +107,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
                              Runtime *runtime,
                              LoraLinearMeta *m,
                              BatchConfig const *bc,
+                             int shard_id,
                              GenericTensorAccessorW const &input_grad,
                              GenericTensorAccessorR const &output_grad) {
   cudaStream_t stream;
@@ -121,6 +125,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
                              runtime,
                              m,
                              bc,
+                             shard_id,
                              input_grad.get_float_ptr(),
                              output_grad.get_float_ptr(),
                              in_dim,
@@ -131,6 +136,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
                              runtime,
                              m,
                              bc,
+                             shard_id,
                              input_grad.get_half_ptr(),
                              output_grad.get_half_ptr(),
                              in_dim,
@@ -168,146 +174,6 @@ bool lora_applies_to_this_layer(LoraLinearMeta *m,
 
 namespace Internal {
 
-#ifdef DEADCODE
-template <typename DT>
-void inference_kernel(LoraLinearMeta *m,
-                      BatchConfig const *bc,
-                      DT const *input_ptr,
-                      DT *output_ptr,
-                      int in_dim,
-                      int out_dim,
-                      ffStream_t stream) {
-  checkCUDA(cublasSetStream(m->handle.blas, stream));
-  checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
-  DT alpha = 1.0f, beta = 0.0f;
-  cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]);
-  cudaDataType_t output_type = ff_to_cuda_datatype(m->input_type[1]);
-  cudaDataType_t lr_actv_type = output_type;
-  assert(input_type == output_type);
-  cudaDataType_t weight_type = output_type;
-  cudaDataType_t compute_type = output_type;
-  // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
-  //   cudaDataType_t compute_type = output_type;
-  // #else
-  //   // For best performance, set the default cublas compute type to
-  //   // CUBLAS_COMPUTE_16F for half precision and to
-  //   // CUBLAS_COMPUTE_32F_FAST_16F for full precision
-  //   cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
-  //   if (m->input_type[0] == DT_FLOAT) {
-  //     compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
-  //   }
-  // #endif
-  int num_peft_requests = 0;
-  for (int i = 0; i < bc->max_requests_per_batch(); i++) {
-    if (bc->request_completed[i]) {
-      continue;
-    }
-    if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) {
-      continue;
-    }
-    if (bc->requestsInfo[i].peft_bwd) {
-      num_peft_requests++;
-    }
-  }
-  // Assert that we have at most one request that requires peft_bwd
-  assert(num_peft_requests <= 1);
-  for (int i = 0; i < bc->max_requests_per_batch(); i++) {
-    if (bc->request_completed[i]) {
-      continue;
-    }
-    // Skip non-PEFT requests
-    if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) {
-      continue;
-    }
-    int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
-    int max_peft_tokens = bc->requestsInfo[i].max_length;
-    int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
-    assert(m->model_state.find(bc->requestsInfo[i].peft_model_id) !=
-           m->model_state.end());
-    LoraLinearWeight weight =
-        m->model_state[bc->requestsInfo[i].peft_model_id].weights;
-    int rank = weight.rank;
-    void *intermediate_result_ptr = nullptr;
-    if (bc->requestsInfo[i].peft_bwd) {
-      size_t activation_size_needed1 =
-          data_type_size(m->input_type[0]) * max_peft_tokens * in_dim;
-      size_t activation_size_needed2 =
-          data_type_size(m->input_type[1]) * max_peft_tokens * rank;
-      MemoryAllocator *allocator = m->handle.peft_activation_allocator;
-      if (activation_size_needed1 > m->allocated_peft_buffer_size1) {
-        m->input_activation =
-            allocator->allocate_instance_untyped(activation_size_needed1);
-        m->allocated_peft_buffer_size1 = activation_size_needed1;
-      }
-      if (activation_size_needed2 > m->allocated_peft_buffer_size2) {
-        m->low_rank_activation =
-            allocator->allocate_instance_untyped(activation_size_needed2);
-        m->allocated_peft_buffer_size2 = activation_size_needed2;
-      }
-      // copy input activation
-      checkCUDA(cudaMemcpyAsync(m->input_activation,
-                                input_ptr + first_token_offset * in_dim,
-                                data_type_size(m->input_type[0]) *
-                                    num_peft_tokens * in_dim,
-                                cudaMemcpyDeviceToDevice,
-                                stream));
-      intermediate_result_ptr = m->low_rank_activation;
-    } else {
-      // use workspace to save intermediate result
-      assert(m->handle.workSpaceSize >=
-             data_type_size(m->input_type[1]) * num_peft_tokens * rank);
-      intermediate_result_ptr = m->handle.workSpace;
-    }
-    // buffer = weight_first * input
-    // [rank, num_peft_tokens] = [in_dim, rank].T * [in_dim, num_peft_tokens]
-    checkCUDA(cublasGemmEx(m->handle.blas,
-                           CUBLAS_OP_T,
-                           CUBLAS_OP_N,
-                           rank,
-                           num_peft_tokens,
-                           in_dim,
-                           &alpha,
-                           weight.w0_ptr,
-                           weight_type,
-                           in_dim,
-                           input_ptr + first_token_offset * in_dim,
-                           input_type,
-                           in_dim,
-                           &beta,
-                           intermediate_result_ptr,
-                           lr_actv_type,
-                           rank,
-                           compute_type,
-                           CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-    // output = weight_second * buffer
-    // [out_dim, num_peft_tokens] = [rank, out_dim].T * [rank, num_peft_tokens]
-    // Note that we use alpha in both places since we do
-    // an in-place update for LoraLinear
-    float lora_alpha =
-        m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha;
-    DT scaling_constant = (DT)(lora_alpha / rank);
-    checkCUDA(cublasGemmEx(m->handle.blas,
-                           CUBLAS_OP_T,
-                           CUBLAS_OP_N,
-                           out_dim,
-                           num_peft_tokens,
-                           rank,
-                           &scaling_constant,
-                           weight.w1_ptr,
-                           weight_type,
-                           rank,
-                           intermediate_result_ptr,
-                           lr_actv_type,
-                           rank,
-                           &alpha,
-                           output_ptr + first_token_offset * out_dim,
-                           output_type,
-                           out_dim,
-                           compute_type,
-                           CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  }
-}
-#endif
 
 template <typename DT>
 void inference_kernel(LoraLinearMeta *m,
@@ -342,6 +208,8 @@ void inference_kernel(LoraLinearMeta *m,
     if (!lora_applies_to_this_layer(m, lora_config)) {
       continue;
     }
+    std::cout << "Lora layer activated!" << std::endl;
+    std::cout << "Lora Config: " << peft_model_config_str << std::endl;
     assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd &&
            "Trainable flag mismatch");
     int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
@@ -443,6 +311,7 @@ void peft_bwd_kernel(Context ctx,
                      Runtime *runtime,
                      LoraLinearMeta *m,
                      BatchConfig const *bc,
+                     int shard_id,
                      DT *input_grad_ptr,
                      DT const *output_grad_ptr,
                      int in_dim,
@@ -471,6 +340,8 @@ void peft_bwd_kernel(Context ctx,
     if (!lora_applies_to_this_layer(m, lora_config)) {
       continue;
     }
+    std::cout << "Lora layer activated!" << std::endl;
+    std::cout << "Lora Config: " << peft_model_config_str << std::endl;
     assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd &&
            "Trainable flag mismatch");
     m->peft_memory_manager->check_ft_model_id(
@@ -488,6 +359,13 @@ void peft_bwd_kernel(Context ctx,
     DT beta = (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero)
                   ? 0.0f : 1.0f;
+    std::cout << "Lora B gradient computation, beta = " << (float)beta << std::endl;
+    if (m->inference_debugging) {
+      // save result to file for checking
+      std::string filename = get_peft_dbg_folder(m, shard_id, false) + ".low_rank_activation";
+      std::cout << "Save low_rank_activation (" << lora_config.rank << ", " << num_peft_tokens << ") to " << filename << std::endl;
+      save_tensor(static_cast<DT *>(weight.low_rank_activation), lora_config.rank * num_peft_tokens, filename.c_str());
+    }
     checkCUDA(cublasGemmEx(m->handle.blas,
                            CUBLAS_OP_N,
                            CUBLAS_OP_T,
diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc
index f17f69a7c9..5f67709358 100644
--- a/src/ops/lora_linear.cc
+++ b/src/ops/lora_linear.cc
@@ -136,133 +136,6 @@ void FFModel::add_lora_layers(std::vector<std::string> target_modules) {
   }
 }
 
-#ifdef DEADCODE
-PEFTModelID *FFModel::add_lora_layer(LoraLinearConfig const peft_config) {
-  assert(config.enable_peft &&
-         "Cannot add a LoRA layer if PEFT mode is not enabled");
-  if (peft_config.target_modules.size() == 0) {
-    printf("PEFT config does not contain any target module\n");
-    std::cout << peft_config << std::endl;
-    assert(false);
-  }
-  PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++);
-  peft_configs[*peft_model_id] = peft_config;
-
-  for (std::string target_module_name : peft_config.target_modules) {
-    assert(target_module_name.length() > 0 &&
-           "LoRA target module name is empty");
-    // find target layer
-    for (auto it = layers.begin(); it != layers.end(); ++it) {
-      Layer *target_module = *it;
-      bool match = check_lora_layer_match(target_module, target_module_name);
-      if (!match) {
-        continue;
-      }
-
-      if (base_layer_to_peft_layer.find(target_module) !=
-          base_layer_to_peft_layer.end()) {
-        // lora linear layer already added, no need to add again
-        Layer *peft_layer = base_layer_to_peft_layer[target_module];
-        peft_layer_to_peft_id[peft_layer].push_back(*peft_model_id);
-      } else {
-        Tensor const input = target_module->inputs[0];
-        Tensor const output = target_module->outputs[0];
-        assert(input->data_type == output->data_type);
-        std::string name_ = target_module->name
-                                ? std::string(target_module->name)
-                                : std::string("");
-        size_t last_underscore = name_.length() - 1;
-        for (int i = name_.length() - 1; i > 0; i--) {
-          if (!(std::isdigit(target_module->name[i]) ||
-                target_module->name[i] == '_')) {
-            break;
-          } else if (target_module->name[i] == '_') {
-            last_underscore = i;
-          }
-        }
-        name_.erase(last_underscore);
-
-        name_ += ".lora";
-        std::cout << "Adding layer " << name_ << std::endl;
-        Layer *peft_layer = new Layer(this,
-                                      OP_LORA,
-                                      output->data_type,
-                                      name_.c_str(),
-                                      2 /*inputs*/,
-                                      0 /*weights*/,
-                                      1 /*outputs*/,
-                                      input,
-                                      output);
-        // fix LoRA layer's transformer layer ID and model ID
-        peft_layer->layer_guid.transformer_layer_id =
-            target_module->layer_guid.transformer_layer_id;
-        peft_layer->layer_guid.model_id = target_module->layer_guid.model_id;
-        {
-          int numdims = output->num_dims;
-          int dims[MAX_TENSOR_DIM];
-          for (int i = 0; i < numdims; i++) {
-            dims[i] = output->dims[i];
-          }
-          peft_layer->outputs[0] =
-              create_tensor_legion_ordering(numdims,
-                                            dims,
-                                            output->data_type,
-                                            peft_layer,
-                                            0,
-                                            true /*create_grad*/);
-        }
-        it = layers.insert(it + 1, peft_layer);
-        ++it;
-        base_layer_to_peft_layer[target_module] = peft_layer;
-        peft_layer_to_peft_id[peft_layer] = std::vector<PEFTModelID>();
-        peft_layer_to_peft_id[peft_layer].push_back(*peft_model_id);
-      }
-    }
-  }
-
-  // save finetuned lora model configs to file
-  if (peft_config.trainable) {
-    std::string finetuned_model_folder = join_path({
-        peft_config.cache_folder,
-        "finetuned_models",
-        peft_config.peft_model_id,
-    });
-    fs::remove_all(finetuned_model_folder);
-    std::string finetuned_model_config_folder = join_path({
-        finetuned_model_folder,
-        "config",
-    });
-    fs::create_directories(finetuned_model_config_folder);
-    std::string lora_linear_config_filepath = join_path({
-        finetuned_model_config_folder,
-        "ff_config.json",
-    });
-    serialize_to_json_file(peft_config, lora_linear_config_filepath);
-    std::string optimizer_config_filepath = join_path({
-        finetuned_model_config_folder,
-        "ff_optimizer_config.json",
-    });
-    if (typeid(*peft_config.optimizer_config) ==
-        typeid(LoraSGDOptimizerConfig)) {
-      LoraSGDOptimizerConfig const *sgd_config =
-          static_cast<LoraSGDOptimizerConfig const *>(
-              peft_config.optimizer_config);
-      serialize_to_json_file(*sgd_config, optimizer_config_filepath);
-    } else if (typeid(*peft_config.optimizer_config) ==
-               typeid(LoraAdamOptimizerConfig)) {
-      LoraAdamOptimizerConfig const *adam_config =
-          static_cast<LoraAdamOptimizerConfig const *>(
-              peft_config.optimizer_config);
-      serialize_to_json_file(*adam_config, optimizer_config_filepath);
-    } else {
-      assert(false && "Optimizer not supported");
-    }
-  }
-
-  return peft_model_id;
-}
-#endif
-
 Op *LoraLinear::create_operator_from_layer(
     FFModel &model,
     Layer const *layer,
@@ -272,15 +145,6 @@ Op *LoraLinear::create_operator_from_layer(
   int max_rank = value;
   layer->get_int_property("max_concurrent_adapters", value);
   int max_concurrent_adapters = value;
-#ifdef DEADCODE
-  std::unordered_map<PEFTModelID, LoraLinearConfig> _peft_configs;
-  std::vector<PEFTModelID> const &peft_ids =
-      model.peft_layer_to_peft_id[(Layer *)layer];
-  for (int i = 0; i < peft_ids.size(); i++) {
-    _peft_configs.emplace(
-        std::make_pair(peft_ids[i], model.peft_configs[peft_ids[i]]));
-  }
-#endif
   return new LoraLinear(model,
                         layer->layer_guid,
                         inputs[0],
@@ -982,7 +846,7 @@ void LoraLinear::peft_bwd_task(Task const *task,
   int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
   // int num_infr_tokens = bc->num_active_infr_tokens();
   // int num_peft_tokens = bc->num_active_peft_tokens();
-  peft_bwd_kernel_wrapper(ctx, runtime, m, bc, input_grad, output_grad);
+  peft_bwd_kernel_wrapper(ctx, runtime, m, bc, shard_id, input_grad, output_grad);
 
   save_peft_weights_if_needed(m, bc, in_dim, out_dim, shard_id);
 
@@ -1018,14 +882,6 @@ bool operator==(LoraLinearParams const &lhs, LoraLinearParams const &rhs) {
   if (lhs.layer_guid == rhs.layer_guid && lhs.max_rank == rhs.max_rank &&
      lhs.max_concurrent_adapters == rhs.max_concurrent_adapters &&
       strcmp(lhs.name, rhs.name) == 0) {
-#ifdef DEADCODE
-    for (auto const &kv : lhs.peft_configs) {
-      auto it = rhs.peft_configs.find(kv.first);
-      if (it == rhs.peft_configs.end() || !(it->second == kv.second)) {
-        return false;
-      }
-    }
-#endif
     return true;
   }
   return false;
@@ -1066,50 +922,6 @@ void LoraLinear::serialize(Legion::Serializer &sez) const {
   sez.serialize(this->layer_guid.model_id);
   sez.serialize(this->max_rank);
   sez.serialize(this->max_concurrent_adapters);
-#ifdef DEADCODE
-  sez.serialize(this->op_type);
-  sez.serialize(this->peft_configs.size());
-  for (auto const &kv : this->peft_configs) {
-    // Serialize PEFTModelID
-    sez.serialize(kv.first.id);
-
-    // Serialize LoraLinearConfig and OptimizerConfig to tmp folder
-    // 1. Create tmp dir and serialize it
-    fs::path unique_temp_dir = create_unique_temp_directory();
-    serialize_string(sez, unique_temp_dir.string());
-    // 2. Dump LoraLinearConfig to json file in tmp dir
-    std::string lora_config_filename = std::string("lora_linear_config_") +
-                                       std::to_string(kv.first.id) +
-                                       std::string(".json");
-    fs::path lora_config_json_filepath = unique_temp_dir / lora_config_filename;
-    serialize_to_json_file(kv.second, lora_config_json_filepath);
-    // 3. Dump optimizer to json file in tmp dir, and serialize optimizer type
-    std::string optimizer_filename = std::string("optimizer_config_") +
-                                     std::to_string(kv.first.id) +
-                                     std::string(".json");
-    fs::path optim_config_filepath = unique_temp_dir / optimizer_filename;
-    assert((kv.second.trainable) == (kv.second.optimizer_config != nullptr));
-    if (kv.second.trainable) {
-      if (typeid(*kv.second.optimizer_config) ==
-          typeid(LoraSGDOptimizerConfig)) {
-        sez.serialize(OPTIMIZER_TYPE_SGD);
-        LoraSGDOptimizerConfig const *sgd_config =
-            static_cast<LoraSGDOptimizerConfig const *>(
-                kv.second.optimizer_config);
-        serialize_to_json_file(*sgd_config, optim_config_filepath);
-      } else if (typeid(*kv.second.optimizer_config) ==
-                 typeid(LoraAdamOptimizerConfig)) {
-        sez.serialize(OPTIMIZER_TYPE_ADAM);
-        LoraAdamOptimizerConfig const *adam_config =
-            static_cast<LoraAdamOptimizerConfig const *>(
-                kv.second.optimizer_config);
-        serialize_to_json_file(*adam_config, optim_config_filepath);
-      } else {
-        assert(false && "Optimizer type not yet supported");
-      }
-    }
-  }
-#endif
   sez.serialize(strlen(this->name));
   sez.serialize(this->name, strlen(this->name));
 }
@@ -1135,58 +947,6 @@ Node LoraLinear::deserialize(FFModel &ff,
   dez.deserialize(deserialized_model_id);
   dez.deserialize(max_rank);
   dez.deserialize(max_concurrent_adapters);
-#ifdef DEADCODE
-  dez.deserialize(op_type);
-  dez.deserialize(num_pefts);
-  for (int i = 0; i < num_pefts; i++) {
-    // Deserialize PEFTModelID
-    size_t pid;
-    dez.deserialize(pid);
-    PEFTModelID peft_model_id(pid);
-    // Deserialize tmp folder containing LoraLinearConfig and optimizer config
-    fs::path unique_temp_dir = fs::path(deserialize_string(dez));
-    // 1. Deserialize LoraLinearConfig
-    std::string lora_config_filename = std::string("lora_linear_config_") +
-                                       std::to_string(pid) +
-                                       std::string(".json");
-    fs::path lora_config_json_filepath = unique_temp_dir / lora_config_filename;
-    std::unique_ptr<LoraLinearConfig> lora_linear_config =
-        deserialize_from_json_file<LoraLinearConfig>(lora_config_json_filepath);
-    // 2. Deserialize optimizer if needed
-    if (lora_linear_config->trainable) {
-      std::string optimizer_filename = std::string("optimizer_config_") +
-                                       std::to_string(pid) +
-                                       std::string(".json");
-      fs::path optim_config_filepath = unique_temp_dir / optimizer_filename;
-      OptimizerType type_;
-      dez.deserialize(type_);
-      if (type_ == OPTIMIZER_TYPE_SGD) {
-        std::unique_ptr<LoraSGDOptimizerConfig> sgd_optimizer_config =
-            deserialize_from_json_file<LoraSGDOptimizerConfig>(
-                optim_config_filepath);
-        lora_linear_config->optimizer_config =
-            dynamic_cast<LoraOptimizerConfig *>(sgd_optimizer_config.release());
-      } else if (type_ == OPTIMIZER_TYPE_ADAM) {
-        std::unique_ptr<LoraAdamOptimizerConfig> adam_optimizer_config =
-            deserialize_from_json_file<LoraAdamOptimizerConfig>(
-                optim_config_filepath);
-        lora_linear_config->optimizer_config =
-            dynamic_cast<LoraOptimizerConfig *>(
-                adam_optimizer_config.release());
-      } else {
-        printf("Optimizer type: %d\n", type_);
-        assert(false && "Optimizer type not yet supported");
-      }
-    }
-    try {
-      fs::remove_all(unique_temp_dir);
-    } catch (fs::filesystem_error const &e) {
-      std::cerr << "Error removing tmp directory: " << e.what() << std::endl;
-    }
-    params.peft_configs.emplace(
-        std::make_pair(peft_model_id, *lora_linear_config));
-  }
-#endif
   dez.deserialize(name_len);
   dez.deserialize(name, name_len);
   LayerID layer_guid(id, transformer_layer_id, deserialized_model_id);
@@ -1236,19 +996,6 @@ size_t hash<FlexFlow::LoraLinearParams>::operator()(
   hash_combine(key, params.layer_guid.model_id);
   hash_combine(key, params.max_rank);
   hash_combine(key, params.max_concurrent_adapters);
-#ifdef DEADCODE
-  for (auto const &kv : params.peft_configs) {
-    hash_combine(key, kv.first.id);
-    hash_combine(key, kv.second.rank);
-    hash_combine(key, kv.second.trainable);
-    hash_combine(key, kv.second.cache_folder);
-    hash_combine(key, kv.second.peft_model_id);
-    hash_combine(key, kv.second.lora_alpha);
-    hash_combine(key, kv.second.lora_dropout);
-    hash_combine(key, kv.second.target_modules);
-    hash_combine(key, kv.second.init_lora_weights);
-  }
-#endif
   return key;
 }
 }; // namespace std
diff --git a/src/ops/lora_linear_params.cc b/src/ops/lora_linear_params.cc
index 4eb59bc53f..4bc75d17e4 100644
--- a/src/ops/lora_linear_params.cc
+++ b/src/ops/lora_linear_params.cc
@@ -282,18 +282,18 @@ LoraLinearConfig LoraLinearConfig::deserialize_from_json_string(
   if (!j["optimizer_config"].is_null()) {
     optimizer_config_ = LoraOptimizerConfig::fromJson(j["optimizer_config"]);
   }
-  LoraLinearConfig config(
-      j["cache_folder"].get<std::string>(),
-      j["peft_model_id"].get<std::string>(),
-      j["trainable"].get<bool>(),
-      optimizer_config_, // optimizer_config will be set later if present
-      j["init_lora_weights"].get<bool>(),
-      j["base_model_name_or_path"].get<std::string>(),
-      j["precision"].get<std::string>(),
-      j["rank"].get<int>(),
-      j["lora_alpha"].get<float>(),
-      j["lora_dropout"].get<float>(),
-      j["target_modules"].get<std::vector<std::string>>());
+  LoraLinearConfig config = LoraLinearConfig::EmptyConfig;
+  config.cache_folder = j["cache_folder"].get<std::string>();
+  config.peft_model_id = j["peft_model_id"].get<std::string>();
+  config.rank = j["rank"].get<int>();
+  config.lora_alpha = j["lora_alpha"].get<float>();
+  config.lora_dropout = j["lora_dropout"].get<float>();
+  config.target_modules = j["target_modules"].get<std::vector<std::string>>();
+  config.trainable = j["trainable"].get<bool>();
+  config.init_lora_weights = j["init_lora_weights"].get<bool>();
+  config.base_model_name_or_path =
+      j["base_model_name_or_path"].get<std::string>();
j["base_model_name_or_path"].get(); + config.precision = j["precision"].get(); + config.optimizer_config = optimizer_config_; return config; } diff --git a/src/runtime/peft_weight_allocator.cc b/src/runtime/peft_weight_allocator.cc index 2dd9a4711b..bd33076309 100644 --- a/src/runtime/peft_weight_allocator.cc +++ b/src/runtime/peft_weight_allocator.cc @@ -23,7 +23,7 @@ using Legion::TaskLauncher; void PEFTMemoryManager::allocate_inference_memory() { // allocate chunk of memory for all the PEFT adapters Realm::Rect<1, coord_t> bounds(Realm::Point<1, coord_t>(0), - Realm::Point<1, coord_t>(max_lora_size - 1)); + Realm::Point<1, coord_t>(max_lora_size*max_concurrent_adapters - 1)); std::vector field_sizes; field_sizes.push_back(sizeof(char)); Realm::RegionInstance::create_instance(peftLegionInst, @@ -39,7 +39,7 @@ void PEFTMemoryManager::allocate_inference_memory() { void PEFTMemoryManager::allocate_finetuning_memory() { size_t ft_size = max_lora_size * 3; // weights, gradients, momentum values ft_size += - max_peft_tokens * (in_dim + max_rank); // input, low-rank activations + max_peft_tokens * (in_dim + max_rank) * data_type_size(dt); // input, low-rank activations // allocate chunk of memory for PEFT adapter Realm::Rect<1, coord_t> bounds(Realm::Point<1, coord_t>(0), Realm::Point<1, coord_t>(ft_size - 1)); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index a25677b22e..7d1e338d8f 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -268,8 +268,9 @@ void RequestManager::set_peft_config(PEFTModelID const &peft_model_id, // check that peft_model_id is not already in use assert(peft_configs.find(peft_model_id) == peft_configs.end() && "PEFT model ID already in use"); - peft_configs[peft_model_id] = LoraLinearConfig::deserialize_from_json_string( - peft_config.serialize_to_json_string()); + // LoraLinearConfig new_config = LoraLinearConfig::deserialize_from_json_string( + // peft_config.serialize_to_json_string()); + peft_configs[peft_model_id] = peft_config; } LoraLinearConfig const & @@ -304,6 +305,7 @@ PEFTModelID * std::cout << peft_config << std::endl; assert(false); } + std::cout << "Registering PEFT adapter" << peft_config.serialize_to_json_string() << std::endl; // go over base_layer_to_peft_layer and check that you can find at least one // match for (int i = 0; i < peft_config.target_modules.size(); i++) { @@ -699,6 +701,8 @@ void RequestManager::add_peft_config_to_request_info( std::string peft_config_str = peft_config.serialize_to_json_string(); std::strcpy(bc.requestsInfo[req_idx].peft_model_config_str, peft_config_str.c_str()); + // std::cout << "Added PEFT config to request info: " + // << bc.requestsInfo[req_idx].peft_model_config_str << std::endl; } BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py index a2fc5548ab..8a53ef8c9c 100644 --- a/tests/peft/hf_finetune.py +++ b/tests/peft/hf_finetune.py @@ -77,7 +77,7 @@ def main(): if args.save_peft_tensors: make_debug_dirs() register_peft_hooks(model) - save_model_weights(model, target_modules=["lora", "lm_head", "down_proj"]) + save_model_weights(model, target_modules=["lora", "lm_head", "down_proj", "up_proj"]) # Load fine-tuning dataset data = load_dataset("Abirate/english_quotes") diff --git a/tests/peft/peft_alignment_test.py b/tests/peft/peft_alignment_test.py index cc677cd51a..bc9d8d9d24 100644 --- a/tests/peft/peft_alignment_test.py +++ b/tests/peft/peft_alignment_test.py @@ 
@@ -17,7 +17,7 @@ def check_bwd_pass(self):
     def check_step(self, step_idx, learning_rate=0.001):
         raise NotImplementedError()
 
-class LllamaAlignmentTest(AlignmentTest):
+class LlamaAlignmentTest(AlignmentTest):
     def __init__(self, model_name, tp_degree=1):
         self.model_name = model_name
         self.peft_config = PeftConfig.from_pretrained(model_name)
@@ -538,11 +538,47 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance=1e-4):
         output_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="output_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
         hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
         ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+        print(f"W3 {i} grad output")
+        print("flexflow tensor shape:", ff_tensor.squeeze().shape)
+        print(ff_tensor.squeeze())
+        print("huggingface tensor shape:", hf_tensor.squeeze().T.shape)
+        print(hf_tensor.squeeze().T)
         compare(hf_tensor, ff_tensor, label=f"W3 {i} gradient output")
+        # print(f"W3 {i} output matches!")
+        # print(f"FF shape: {ff_tensor.shape}")
+        # print(f"HF shape: {hf_tensor.shape}")
+
+        # hf_w3_output = hf_tensor.clone()
+
+        # W3 (up_proj) input
         input_comparison = TensorComparisonIdxs(hf_tensor_type="input_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
         hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
         ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
+
+        # w3_input_torch = torch.matmul(hf_tensor, torch.transpose(ff_tensor, 0, 1))
+        # ff_up_proj_weight_path="/usr/.cache/flexflow/debug/flexflow/weights/step_0/shard_0/layers.11.layers.11.mlp.up_proj.weight_0"
+        # hf_up_proj_weight_path="/usr/.cache/flexflow/debug/huggingface/weights/step_0/layers.11.mlp.up_proj.weight"
+        # hf_up_proj_weight = torch.load(hf_up_proj_weight_path, map_location='cpu')
+        # print(hf_up_proj_weight.shape)
+        # ff_up_proj_weight = load_ff_tensor(ff_up_proj_weight_path, hf_up_proj_weight.shape[::-1])
+        # print(ff_up_proj_weight.shape)
+        # ff_up_proj_weight = torch.from_numpy(ff_up_proj_weight).to(hf_up_proj_weight.dtype)
+        # assert torch.allclose(hf_up_proj_weight.T, ff_up_proj_weight, atol=1e-5)
+
+        # print("HF W3 output shape:", hf_w3_output.shape)
+        # print("HF W3 weight shape:", hf_up_proj_weight.shape)
+        # print("HF W3 input shape:", hf_tensor.shape)
+
+        # simulated_w3_input = torch.matmul(hf_w3_output.squeeze(), hf_up_proj_weight)
+        # print("simulated W3 input shape:", simulated_w3_input.T.shape)
+        # print(simulated_w3_input.T)
+
+        print(f"W3 {i} grad input")
+        print("flexflow tensor shape:", ff_tensor.squeeze().shape)
+        print(ff_tensor.squeeze())
+        print("huggingface tensor shape:", hf_tensor.squeeze().T.shape)
+        print(hf_tensor.squeeze().T)
+
         compare(hf_tensor, ff_tensor, label=f"W3 {i} gradient input")
 
         # Attn O-proj
@@ -695,7 +731,24 @@ def compare(hf_tensor, ff_tensor, label="", tolerance=1e-4):
                 torch.testing.assert_close(hf_gradient, (hf_original_weight-hf_finetuned_weight)/learning_rate, rtol=1.3e-6, atol=1e-5)
                 ff_gradient_name = convert_hf_filename_to_ff(hf_gradient_name)
                 ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.REPLICATE)
+
+                lora_low_rank_activation_fwd_path = f"/usr/.cache/flexflow/debug/flexflow/fwd/step_{step_idx}/shard_0/layers.{i}.layers.{i}.mlp.down_proj.lora.low_rank_activation"
f"/usr/.cache/flexflow/debug/flexflow/bwd/step_{step_idx}/shard_0/layers.{i}.layers.{i}.mlp.down_proj.lora.low_rank_activation" + lora_low_rank_activation_fwd = load_ff_tensor(lora_low_rank_activation_fwd_path, [16, 128])[:,:self.num_tokens] + lora_low_rank_activation_fwd = torch.from_numpy(lora_low_rank_activation_fwd) + lora_low_rank_activation_bwd = load_ff_tensor(lora_low_rank_activation_bwd_path, [16, 24]) + lora_low_rank_activation_bwd = torch.from_numpy(lora_low_rank_activation_bwd) + torch.testing.assert_close(lora_low_rank_activation_fwd, lora_low_rank_activation_bwd, rtol=1.3e-6, atol=1e-5) + + print(f"LoRA_B {i} gradient") + print("FlexFlow shape: ", ff_gradient.shape) + print(ff_gradient) + print("HuggingFace shape: ", hf_gradient.shape) + print(hf_gradient.squeeze().T) compare(hf_gradient, ff_gradient, label=f"LoRA_B {i} gradient") + + + # ff_out_gradient_name = f"layers.{i}.layers.{i}.mlp.down_proj.lora.output_gradient_0" # ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0") # ff_bwd_folder = os.path.join(ff_path, "bwd", f"step_{step_idx}", "shard_0") @@ -737,7 +790,7 @@ def compare(hf_tensor, ff_tensor, label="", tolerance=1e-4): args = parser.parse_args() if __name__ == "__main__": - llama_alignment = LllamaAlignmentTest(args.model_name, tp_degree=args.tensor_parallelism_degree) + llama_alignment = LlamaAlignmentTest(args.model_name, tp_degree=args.tensor_parallelism_degree) # llama_alignment.check_weights_alignment() for i in range(args.num_steps): llama_alignment.check_fwd_pass(i) diff --git a/tests/peft_test.sh b/tests/peft_test.sh index 173fb37fd9..b7adce8028 100755 --- a/tests/peft_test.sh +++ b/tests/peft_test.sh @@ -34,7 +34,7 @@ export LEGION_BACKTRACE=1 python ./inference/utils/download_peft_model.py goliaro/llama-160m-lora --base_model_name JackFram/llama-160m # Run PEFT in Huggingface to get ground truth tensors -python ./tests/peft/hf_finetune.py --peft-model-id goliaro/llama-160m-lora --save-peft-tensors --use-full-precision +python ./tests/peft/hf_finetune.py --peft-model-id goliaro/llama-160m-lora --save-peft-tensors --use-full-precision -lr 1.0 # Python test echo "Python test" @@ -45,8 +45,8 @@ echo "Python test" # C++ test echo "C++ test" ./build/inference/peft/peft \ - -ll:gpu 2 -ll:cpu 4 -ll:util 4 \ - -tensor-parallelism-degree 2 \ + -ll:gpu 1 -ll:cpu 4 -ll:util 4 \ + -tensor-parallelism-degree 1 \ -ll:fsize 8192 -ll:zsize 12000 \ -llm-model JackFram/llama-160m \ -finetuning-dataset ./inference/prompt/peft_dataset.json \ @@ -55,7 +55,7 @@ echo "C++ test" --use-full-precision \ --inference-debugging # Check alignment -python ./tests/peft/peft_alignment_test.py -tp 2 +python ./tests/peft/peft_alignment_test.py -tp 1 -lr 1.0 # Print succeess message echo ""