fix: bugs in model_file and finetune script
duterscmy committed Dec 13, 2024
1 parent 5f19e8c commit 066d051
Showing 7 changed files with 44 additions and 77 deletions.
34 changes: 13 additions & 21 deletions README.md
@@ -35,16 +35,14 @@ Mixture-of-Experts (MoE) has garnered significant attention for their ability t
<!-- ![main result3](./images/figure3.png)
CD-MoE on finetuning. Left: Average accuracy with varying SpeedUp. Right: Average accuracy with varying Memory Ratio. The Gray dotted line is the dense model result. CD-MoE and LM+SFT represent condensed and supervision fine-tuned models, respectively. E(2+0) represents 2 shared experts and no routing experts, and E(2+6) represents 2 shared with 6 routing experts. -->

## Installation

Installation instructions can be found in [INSTALL.md](./INSTALL.md).

## Setup

```bash
git clone https://github.com/duterscmy/CD-MoE.git
cd CD-MoE
pip install -e .
pip install -r requirements.txt
```

## Usage
@@ -84,25 +82,29 @@ python cd-moe/greedy_search/greedy_search_layer.py \
--model $model_path \
--dynamic-weight-file $expert_weight_file \
--greedy-expert-file $greedy_search_expert_result_file \
--output $greedy_search_layer_result_file
--output $greedy_search_layer_result_file \
--prune-num-expert 6 \
--prune-num-layer 15
```

### 3. Fine-tune (Optional)
The numbers of pruned experts and layers in `cd-moe/exp_hyper.py` need to match the options passed to `cd-moe/finetune/finetune.py`, and the file paths in `cd-moe/modeling_deepseek.py` need to be replaced with real paths. You can use the `--no-c4` option to skip LM fine-tuning and fine-tune directly on downstream tasks.

```bash
cp cd-moe/modeling_deepseek.py $model_path
echo "num_route_experts=6;prune_layer_num=15" > cd-moe/exp_hyper.py
cp cd-moe/modeling_deepseek.py cd-moe/exp_hyper.py $model_path
python cd-moe/finetune/finetune.py \
--input $sft_data \
--c4-input $lm_data \
--model $model_path \
--dynamic-weight-file $expert_weight_file \
--greedy-expert-file $greedy_search_expert_result_file \
--greedy-expert-file $greedy_search_layer_result_file \
--output-dir $sft_model_path
--output-dir $sft_model_path \
--prune-num-expert 6 \
--prune-num-layer 15
```
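
For reference, here is a minimal sketch of what `$model_path/exp_hyper.py` is expected to contain after the `echo` above; the values must agree with the `--prune-num-expert 6 --prune-num-layer 15` flags, since the patched `modeling_deepseek.py` reads them via `from .exp_hyper import *`:

```python
# exp_hyper.py -- a sketch, equivalent to the one-liner written by the echo command above
num_route_experts = 6   # must match --prune-num-expert
prune_layer_num = 15    # must match --prune-num-layer
```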

You can use the `--no-c4` option to skip LM fine-tuning and fine-tune directly on downstream tasks.

We provide pre-generated results for some of the intermediate variables. The open-source model and the C4 training data need to be downloaded locally:
- calibration_data_file: `cd-moe/data/calibration_data.json`
- expert_weight_file: `cd-moe/data/dynamic_weight.json`
@@ -112,25 +114,15 @@ For some intermediate variables, we provide some already generated results. The
## Evaluation

Install [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
Evaluate the pruned model:
Evaluate the pruned model or the fine-tuned model:
```bash
cp $expert_weight_file $greedy_search_expert_result_file $greedy_search_layer_result_file cd-moe/modeling_deepseek.py $model_path
lm_eval --model hf \
--model_args $model_path \
--tasks arc-challenge,boolq,piqa,rte,obqa,winogrande,mmlu,hellaswag \
--model_args pretrained=$model_path,dtype="bfloat16",trust_remote_code=True \
--tasks arc_challenge,boolq,piqa,rte,obqa,winogrande,mmlu,hellaswag \
--device cuda:0 \
--batch_size 8
```
Evaluate the fine-tuned model:
```bash
lm_eval --model hf \
--model_args $sft_model_path \
--tasks arc-challenge,boolq,piqa,rte,obqa,winogrande,mmlu,hellaswag \
--device cuda:0 \
--batch_size 8 \
--ignore_mismatched_sizes
```
The `--ignore_mismatched_sizes` option is necessary because, to save GPU memory during fine-tuning, the pruned expert parameters are replaced with empty placeholders, so the parameter sizes saved in the model file no longer match the default parameter sizes in the model config.
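
As an illustration of where the mismatch comes from, the fine-tuning script shrinks each pruned expert's weights to a tiny placeholder, roughly as in the following sketch (simplified; `is_pruned_expert` stands in for the repository's `classify_pruned_experts` helper):

```python
import torch
import torch.nn as nn

def shrink_pruned_experts(model: nn.Module, is_pruned_expert) -> int:
    """Replace the weights of pruned expert Linear modules with 1x1 placeholders."""
    num_shrunk = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and is_pruned_expert(name):
            for param in module.parameters():
                param.requires_grad = False
                # A 1x1 tensor keeps the parameter object while freeing almost all memory;
                # the saved checkpoint therefore no longer matches the shapes in the config.
                param.data = torch.tensor([[0.1]], dtype=param.dtype, device=param.device)
            num_shrunk += 1
    return num_shrunk
```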

## Acknowledgement
This repository is built upon the [Transformers](https://github.com/huggingface/transformers) repository.
2 changes: 1 addition & 1 deletion cd-moe/exp_hyper.py
@@ -1 +1 @@
num_route_experts = 0;prune_layer_num = 9
num_route_experts=6;prune_layer_num=15
42 changes: 14 additions & 28 deletions cd-moe/finetune/finetune.py
@@ -165,37 +165,23 @@
classify_remained_experts(name, prune_layer_idx_to_expert_idx)):
for param in module.parameters():
param.requires_grad = True
print("set {} requires_grad=True".format(name))
print_trainable_parameters(model)

# set //prune experts// of prune layer to empty to reduce memory
num_prune_module = 0
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear)) and \
classify_pruned_experts(name, prune_layer_idx_to_expert_idx):
# print(name)
num_prune_module += 1
for param in module.parameters():
param.requires_grad = False
param.data = torch.tensor(
[[0.1]], dtype=param.dtype, device=param.device)
print("set {} modules to empty".format(num_prune_module))
print_trainable_parameters(model)


# for layer_idx, layer in enumerate(model.model.layers):
# if layer_idx == 0:
# continue
# moe_layer_idx = layer_idx - 1
# for expert_idx, param in enumerate(layer.mlp.expert_weights):
# static_weight = dynamic_weights[(moe_layer_idx, expert_idx)]
# if args.finetune_route_weight:
# param.requires_grad = True
# else:
# # set //prune experts// of prune layer to empty to reduce memory
# num_prune_module = 0
# for name, module in model.named_modules():
# if isinstance(module, (torch.nn.Linear)) and \
# classify_pruned_experts(name, prune_layer_idx_to_expert_idx):
# # print(name)
# num_prune_module += 1
# for param in module.parameters():
# param.requires_grad = False
# param.data = torch.tensor(
# [static_weight], dtype=param.dtype, device=param.device)
print("load static expert weight")
print_trainable_parameters(model)
# param.data = torch.tensor(
# [[0.1]], dtype=param.dtype, device=param.device)
# print("set {} modules to empty".format(num_prune_module))
# print_trainable_parameters(model)


# finetune
# load the dataset
4 changes: 2 additions & 2 deletions cd-moe/greedy_search/greedy_search_layer.py
@@ -125,9 +125,9 @@ def get_total_js_divergence(origin_layer_outputs, prune_layer_outputs):
parser.add_argument("--num-expert", type=int,
default=64, help="defaults to the number of experts in DeepSeek 16B")

parser.add_argument("--prune-layer", default=15, type=int,
parser.add_argument("--prune-num-layer", default=15, type=int,
help="剪枝层的数量")
parser.add_argument("--prune-expert", default=6, type=int,
parser.add_argument("--prune-num-expert", default=6, type=int,
help="剪枝专家的数量")
parser.add_argument("--prune-expert-strategy", default="greedy_jl", type=str,
help="剪枝专家的策略")
31 changes: 8 additions & 23 deletions cd-moe/modeling_deepseek.py
@@ -53,7 +53,7 @@
)
from transformers.utils.import_utils import is_torch_fx_available
from .configuration_deepseek import DeepseekConfig
# from .exp_hyper import num_route_experts, prune_layer_num
from .exp_hyper import *


# if is_flash_attn_2_available():
@@ -391,16 +391,17 @@ def backward(ctx, grad_output):
1, dtype=ctx.dtype, device=grad_output.device)
return grad_output, grad_loss

# hyper parameters

layer_num = 27
num_route_experts = 0
prune_layer_num = 9
num_route_experts = 6
prune_layer_num = 15

current_dir = os.path.dirname(os.path.abspath(__file__))
current_dir = "/home/work/mt_cmy/programs/CD-MoE/cd-moe/data" # change to your repo dir

# greedy search layer result
layer_result_file = "layer_idx_order.e6.json" if num_route_experts == 6 else "layer_idx_order.e0.json"
condense_layer_order_path = os.path.join(
current_dir, "layer_idx_order.e6.json")
current_dir, layer_result_file)
condense_layer_order = json.load(open(condense_layer_order_path, 'r'))
prune_layer_idxs = condense_layer_order[:prune_layer_num]
print("condense layer idx {}".format(prune_layer_idxs))
@@ -683,9 +684,8 @@ def forward(
if past_key_value is not None:
# print(self.layer_idx)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_idx = self.layer_idx
key_states, value_states = past_key_value.update(
key_states, value_states, cache_idx, cache_kwargs)
key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -1278,12 +1278,6 @@ def __init__(self, config: DeepseekConfig):

self.gradient_checkpointing = False
# Initialize weights and apply final processing

global layer_num, prune_layer_idxs
self.layer_num = layer_num
self.prune_layer_idxs = prune_layer_idxs


self.post_init()

def get_input_embeddings(self):
@@ -1380,15 +1374,6 @@ def forward(
next_decoder_cache = None

for tmp_layer_idx, decoder_layer in enumerate(self.layers):
if tmp_layer_idx > 0:
global global_layer
relative_layer = global_layer % self.layer_num
if relative_layer in self.prune_layer_idxs:
# print("layer_num {} current_layer {}, BLOCK_TRIM layer".format(
# self.layer_num, relative_layer))
global_layer +=1
continue

if output_hidden_states:
all_hidden_states += (hidden_states,)

3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
torch==2.0.0
peft==0.10.0
accelerate==1.2.0
5 changes: 3 additions & 2 deletions src/transformers/cache_utils.py
@@ -357,8 +357,9 @@ def update(

# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
for _ in range(layer_idx-len(self.key_cache)+1):
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
