diff --git a/README.md b/README.md
index 4349a53279..75cdda3c35 100644
--- a/README.md
+++ b/README.md
@@ -35,9 +35,6 @@ Mixture-of-Experts (MoE) has garnered significant attention for their ability t
 
-## Installation
-
-Installation instructions can be found in [INSTALL.md](./INSTALL.md).
 
 ## Setup
 
@@ -45,6 +42,7 @@ Installation instructions can be found in [INSTALL.md](./INSTALL.md).
 git clone https://github.com/duterscmy/CD-MoE.git
 cd CD-MoE
 pip install -e .
+pip install -r requirements.txt
 ```
 
 ## Usage
@@ -84,13 +82,17 @@ python cd-moe/greedy_search/greedy_search_layer.py \
     --model $model_path \
     --dynamic-weight-file $expert_weight_file \
     --greedy-expert-file $greedy_search_expert_result_file \
-    --output $greedy_search_layer_result_file
+    --output $greedy_search_layer_result_file \
+    --prune-num-expert 6 \
+    --prune-num-layer 15
 ```
 
 ### 3. Fine-tune (Optional)
+The numbers of pruned experts and layers in `cd-moe/exp_hyper.py` need to match the options passed to `cd-moe/finetune/finetune.py`, and the file paths in `cd-moe/modeling_deepseek.py` need to be replaced with real paths. You can use the `--no-c4` option to skip LM fine-tuning and fine-tune directly on the downstream tasks.
 ```bash
-cp cd-moe/modeling_deepseek.py $model_path
+echo "num_route_experts=6;prune_layer_num=15" > cd-moe/exp_hyper.py
+cp cd-moe/modeling_deepseek.py cd-moe/exp_hyper.py $model_path
 python cd-moe/finetune/finetune.py \
     --input $sft_data \
     --c4-input $lm_data \
@@ -98,11 +100,11 @@ python cd-moe/finetune/finetune.py \
     --dynamic-weight-file $expert_weight_file \
     --greedy-expert-file $greedy_search_expert_result_file \
     --greedy-expert-file $greedy_search_layer_result_file \
-    --output-dir $sft_model_path
+    --output-dir $sft_model_path \
+    --prune-num-expert 6 \
+    --prune-num-layer 15
 ```
 
-You can use the `--no-c4` option to skip lm fine-tuning and directly fine-tune for downstream tasks.
-
 For some intermediate variables, we provide some already generated results. The open-source model and C4 training data need to be downloaded locally:
 - calibration_data_file: `cd-moe/data/calibration_data.json`
 - expert_weight_file: `cd-moe/data/dynamic_weight.json`
@@ -112,25 +114,15 @@ For some intermediate variables, we provide some already generated results. The
 ## Evaluation
 Install [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
 
-Evaluate the pruned model:
+Evaluate the pruned model or the fine-tuned model:
 ```bash
 cp $expert_weight_file $greedy_search_expert_result_file $greedy_search_layer_result_file cd-moe/modeling_deepseek.py $model_path
 lm_eval --model hf \
-    --model_args $model_path \
-    --tasks arc-challenge,boolq,piqa,rte,obqa,winogrande,mmlu,hellaswag \
+    --model_args pretrained=$model_path,dtype="bfloat16",trust_remote_code=True \
+    --tasks arc_challenge,boolq,piqa,rte,obqa,winogrande,mmlu,hellaswag \
     --device cuda:0 \
     --batch_size 8
 ```
 
-Evaluate the fine-tuned model:
-```bash
-lm_eval --model hf \
-    --model_args $sft_model_path \
-    --tasks arc-challenge,boolq,piqa,rte,obqa,winogrande,mmlu,hellaswag \
-    --device cuda:0 \
-    --batch_size 8 \
-    --ignore_mismatched_sizes
-```
-`--ignore_mismatched_sizes` option is necessary because, during fine-tuning, to save GPU memory, the unnecessary expert parameters in the model are set to empty, causing a mismatch between the parameter sizes saved in the model file and the default parameter sizes in the model config.
 
 ## Acknowledgement
 This repository is build upon the [Transformers](https://github.com/huggingface/transformers) repositories.
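Note (not part of the patch): the snippet below is a minimal, self-contained reading aid for the fine-tuning step above. It mirrors the hyperparameter logic that the modified `cd-moe/modeling_deepseek.py` applies after `from .exp_hyper import *` (see the hunk further down), which is why the values written to `cd-moe/exp_hyper.py` must match the `--prune-num-expert`/`--prune-num-layer` options. The layer-order list is inlined with illustrative values; the real code loads `layer_idx_order.e6.json` or `layer_idx_order.e0.json` from `cd-moe/data`.

```python
# Reading aid, not repository code: how num_route_experts / prune_layer_num
# (written to cd-moe/exp_hyper.py) select the greedy-search layer order.
num_route_experts = 6   # must match --prune-num-expert 6
prune_layer_num = 15    # must match --prune-num-layer 15

# The patched modeling_deepseek.py picks the layer-order file from num_route_experts.
layer_result_file = (
    "layer_idx_order.e6.json" if num_route_experts == 6 else "layer_idx_order.e0.json"
)

# Stand-in for json.load(open(os.path.join(<cd-moe/data>, layer_result_file))):
# an ordered list of MoE layer indices; the values here are illustrative only.
condense_layer_order = list(range(27))

# The first prune_layer_num indices are the condensed (pruned) layers.
prune_layer_idxs = condense_layer_order[:prune_layer_num]
print("condense layer idx {}".format(prune_layer_idxs))
```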
diff --git a/cd-moe/exp_hyper.py b/cd-moe/exp_hyper.py
index cea91b16fd..69462b20f8 100644
--- a/cd-moe/exp_hyper.py
+++ b/cd-moe/exp_hyper.py
@@ -1 +1 @@
-num_route_experts = 0;prune_layer_num = 9
\ No newline at end of file
+num_route_experts=6;prune_layer_num=15
\ No newline at end of file
diff --git a/cd-moe/finetune/finetune.py b/cd-moe/finetune/finetune.py
index ed7b6e8a43..93771c68e5 100644
--- a/cd-moe/finetune/finetune.py
+++ b/cd-moe/finetune/finetune.py
@@ -165,37 +165,23 @@
             classify_remained_experts(name, prune_layer_idx_to_expert_idx)):
         for param in module.parameters():
             param.requires_grad = True
+        print("set {} requires_grad=True".format(name))
 print_trainable_parameters(model)
 
-# set //prune experts// of prune layer to empty to reduce memory
-num_prune_module = 0
-for name, module in model.named_modules():
-    if isinstance(module, (torch.nn.Linear)) and \
-            classify_pruned_experts(name, prune_layer_idx_to_expert_idx):
-        # print(name)
-        num_prune_module += 1
-        for param in module.parameters():
-            param.requires_grad = False
-            param.data = torch.tensor(
-                [[0.1]], dtype=param.dtype, device=param.device)
-print("set {} modules to empty".format(num_prune_module))
-print_trainable_parameters(model)
-
-
-# for layer_idx, layer in enumerate(model.model.layers):
-#     if layer_idx == 0:
-#         continue
-#     moe_layer_idx = layer_idx - 1
-#     for expert_idx, param in enumerate(layer.mlp.expert_weights):
-#         static_weight = dynamic_weights[(moe_layer_idx, expert_idx)]
-#         if args.finetune_route_weight:
-#             param.requires_grad = True
-#         else:
+# # set //prune experts// of prune layer to empty to reduce memory
+# num_prune_module = 0
+# for name, module in model.named_modules():
+#     if isinstance(module, (torch.nn.Linear)) and \
+#             classify_pruned_experts(name, prune_layer_idx_to_expert_idx):
+#         # print(name)
+#         num_prune_module += 1
+#         for param in module.parameters():
 #             param.requires_grad = False
-#             param.data = torch.tensor(
-#                 [static_weight], dtype=param.dtype, device=param.device)
-print("load static expert weight")
-print_trainable_parameters(model)
+#             param.data = torch.tensor(
+#                 [[0.1]], dtype=param.dtype, device=param.device)
+# print("set {} modules to empty".format(num_prune_module))
+# print_trainable_parameters(model)
+
 
 # finetune
 # 加载数据集
diff --git a/cd-moe/greedy_search/greedy_search_layer.py b/cd-moe/greedy_search/greedy_search_layer.py
index d34fd3de80..4d68c60d80 100644
--- a/cd-moe/greedy_search/greedy_search_layer.py
+++ b/cd-moe/greedy_search/greedy_search_layer.py
@@ -125,9 +125,9 @@ def get_total_js_divergence(origin_layer_outputs, prune_layer_outputs):
 parser.add_argument("--num-expert", type=int, default=64,
                     help="默认为deepseek16B专家数")
-parser.add_argument("--prune-layer", default=15, type=int,
+parser.add_argument("--prune-num-layer", default=15, type=int,
                     help="剪枝层的数量")
-parser.add_argument("--prune-expert", default=6, type=int,
+parser.add_argument("--prune-num-expert", default=6, type=int,
                     help="剪枝专家的数量")
 parser.add_argument("--prune-expert-strategy", default="greedy_jl", type=str,
                     help="剪枝专家的策略")
diff --git a/cd-moe/modeling_deepseek.py b/cd-moe/modeling_deepseek.py
index 3a99dc7de6..30a9b11cfb 100644
--- a/cd-moe/modeling_deepseek.py
+++ b/cd-moe/modeling_deepseek.py
@@ -53,7 +53,7 @@
 )
 from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_deepseek import DeepseekConfig
-# from .exp_hyper import num_route_experts, prune_layer_num
+from .exp_hyper import *
 
 
 # if is_flash_attn_2_available():
@@ -391,16 +391,17 @@ def backward(ctx, grad_output):
                 1, dtype=ctx.dtype, device=grad_output.device)
         return grad_output, grad_loss
 
-# hyper parameters
+
 layer_num = 27
-num_route_experts = 0
-prune_layer_num = 9
+num_route_experts = 6
+prune_layer_num = 15
 
-current_dir = os.path.dirname(os.path.abspath(__file__))
+current_dir = "/home/work/mt_cmy/programs/CD-MoE/cd-moe/data"  # change to your repo dir
 # greedy search layer result
+layer_result_file = "layer_idx_order.e6.json" if num_route_experts == 6 else "layer_idx_order.e0.json"
 condense_layer_order_path = os.path.join(
-    current_dir, "layer_idx_order.e6.json")
+    current_dir, layer_result_file)
 condense_layer_order = json.load(open(condense_layer_order_path, 'r'))
 prune_layer_idxs = condense_layer_order[:prune_layer_num]
 print("condense layer idx {}".format(prune_layer_idxs))
@@ -683,9 +684,8 @@ def forward(
         if past_key_value is not None:
             # print(self.layer_idx)
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            cache_idx = self.layer_idx
             key_states, value_states = past_key_value.update(
-                key_states, value_states, cache_idx, cache_kwargs)
+                key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -1278,12 +1278,6 @@ def __init__(self, config: DeepseekConfig):
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
-
-        global layer_num, prune_layer_idxs
-        self.layer_num = layer_num
-        self.prune_layer_idxs = prune_layer_idxs
-
-
         self.post_init()
 
     def get_input_embeddings(self):
@@ -1380,15 +1374,6 @@ def forward(
         next_decoder_cache = None
 
         for tmp_layer_idx, decoder_layer in enumerate(self.layers):
-            if tmp_layer_idx > 0:
-                global global_layer
-                relative_layer = global_layer % self.layer_num
-                if relative_layer in self.prune_layer_idxs:
-                    # print("layer_num {} current_layer {}, BLOCK_TRIM layer".format(
-                    #     self.layer_num, relative_layer))
-                    global_layer += 1
-                    continue
-
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000..2d645da399
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+torch==2.0.0
+peft==0.10.0
+accelerate==1.2.0
\ No newline at end of file
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 04ba337ef4..ccba77b31e 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -357,8 +357,9 @@ def update(
 
         # Update the cache
         if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
+            for _ in range(layer_idx - len(self.key_cache) + 1):
+                self.key_cache.append(key_states)
+                self.value_cache.append(value_states)
         else:
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
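Note (not part of the patch): the `src/transformers/cache_utils.py` change is easiest to see in isolation. `ToyCache` below is a hypothetical stand-in for the library's `DynamicCache`, not the real class; it only reproduces the patched `update` branch. When some decoder layers never call `update()` (for example because they have been pruned or skipped), `layer_idx` can run ahead of `len(self.key_cache)` by more than one, and the stock single `append` would store the new states at the wrong list index. The patched loop pads both lists up to `layer_idx`, so the intermediate slots hold placeholder copies of the current states and the real entry lands at its own index.

```python
# Toy reproduction of the patched DynamicCache.update branch (sketch only).
from typing import List, Tuple

import torch


class ToyCache:
    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            # Pad until index layer_idx exists; intermediate (skipped-layer) slots
            # receive placeholder copies of the current states, as in the patch.
            for _ in range(layer_idx - len(self.key_cache) + 1):
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


cache = ToyCache()
k = v = torch.zeros(1, 2, 1, 4)  # (batch, num_heads, seq_len, head_dim)
cache.update(k, v, layer_idx=0)
# Layers 1-3 never call update (pruned); layer 4 still lands at index 4.
cache.update(k, v, layer_idx=4)
assert len(cache.key_cache) == 5
```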