-
Notifications
You must be signed in to change notification settings - Fork 1
/
finetune.py
311 lines (262 loc) · 9.56 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
import json
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence
import torch
import transformers
from accelerate.utils import DistributedType
from data_mix import Mix_dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
from transformers import Trainer, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
# Label id that Hugging Face's ``LabelSmoother`` ignores when computing the
# loss. Not referenced by the code in this file.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass
class ModelArguments:
    """Command-line options that choose which pretrained checkpoint to load."""

    # Checkpoint name or path; ``train()`` passes it to ``from_pretrained``
    # for the config, the model and the tokenizer.
    model_name_or_path: Optional[str] = ''
@dataclass
class DataArguments:
    """Command-line options describing the training data and its batching."""

    # Either a ``.json`` file or a ``.txt`` manifest listing one JSON file per
    # line, optionally followed by a ratio or count.
    data_path: str = field(
        default='data.txt',
        metadata={'help': 'Path to the training data.'},
    )
    # When True, the second column of each manifest line is read as a sample
    # count in thousands instead of a ratio.
    given_num: bool = field(default=False)
    # Image size forwarded to ``Mix_dataset``; ``train()`` calls
    # ``model.vit.resize_pos()`` whenever it differs from 336.
    img_size: int = field(default=224)
    # Forwarded to ``Mix_dataset``.
    batch_size: int = field(default=4)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face ``TrainingArguments`` extended with fine-tuning switches."""

    cache_dir: Optional[str] = None
    optim: str = 'adamw_torch'
    # Copied onto the model config as ``config.max_length`` in ``train()``.
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    # Wrap the model with peft LoRA adapters.
    use_lora: bool = field(default=False)
    # Freeze ``model.vit`` when True.
    fix_vit: bool = field(default=True)
    # Freeze ``model.vision_proj`` when True.
    fix_sampler: bool = field(default=False)
    # Batch keys the Trainer treats as labels; the collator in this file
    # returns a single ``samples`` key.
    label_names: List[str] = field(default_factory=lambda: ['samples'])
@dataclass
class LoraArguments:
    """LoRA hyper-parameters; ``train()`` forwards them to ``peft.LoraConfig``."""

    lora_r: int = field(default=64)
    lora_alpha: int = field(default=64)
    lora_dropout: float = field(default=0.05)
    # Names of the modules that receive LoRA adapters.
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            'attention.wqkv',
            'attention.wo',
            'feed_forward.w1',
            'feed_forward.w2',
            'feed_forward.w3',
        ])
    # Not used by the code in this file.
    lora_weight_path: str = field(default='')
    # Bias handling: 'none', 'all' or 'lora_only' (the values accepted when
    # the adapter state dict is collected for saving).
    lora_bias: str = field(default='none')
def maybe_zero_3(param):
    """Return a detached CPU clone of ``param``.

    A parameter carrying DeepSpeed's ``ds_id`` attribute (i.e. one managed by
    ZeRO-3) must currently be partitioned (asserted). It is temporarily
    gathered with ``zero.GatheredParameters`` so the full tensor can be
    copied. Any other tensor is copied directly.
    """
    if not hasattr(param, 'ds_id'):
        return param.detach().cpu().clone()
    assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
    with zero.GatheredParameters([param]):
        full_copy = param.data.detach().cpu().clone()
    return full_copy
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA adapter tensors (and optionally biases) as CPU copies.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, such as
            ``model.named_parameters()``; it is consumed once.
        bias: which bias tensors to keep alongside the ``lora_`` tensors:
            ``'none'``: no biases; ``'all'``: every tensor whose name contains
            ``'bias'``; ``'lora_only'``: only ``<prefix>bias`` tensors whose
            ``<prefix>`` also owns a ``lora_`` tensor.

    Returns:
        dict mapping each kept name to ``maybe_zero_3(tensor)``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == 'none':
        to_return = {k: t for k, t in named_params if 'lora_' in k}
    elif bias == 'all':
        to_return = {
            k: t
            for k, t in named_params if 'lora_' in k or 'bias' in k
        }
    elif bias == 'lora_only':
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if 'lora_' in k:
                to_return[k] = t
                # '<prefix>lora_...' -> '<prefix>bias'
                lora_bias_names.add(k.split('lora_')[0] + 'bias')
            elif 'bias' in k:
                maybe_lora_bias[k] = t
        # A bias may be seen before its LoRA tensors, so biases are matched
        # only after every name has been collected. Iterate items() (the
        # dict itself yields bare keys) and match on each bias's own name.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v) for k, v in to_return.items()}
# Local rank of this process; ``train()`` sets it from
# ``TrainingArguments.local_rank``. ``None`` until then.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only on the main process.

    Rank 0 is the first process of a distributed launch. Rank -1 is the Hugging
    Face ``TrainingArguments.local_rank`` value when not running distributed;
    that single process is the main one too. Without it, nothing would ever be
    printed in single-process runs. ``None`` (before ``train()`` runs) prints
    nothing.
    """
    if local_rank in (0, -1):
        print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str,
                                   bias='none'):
    """Collects the state dict and dump to disk.

    With ZeRO-3 enabled, the consolidated 16-bit state dict of the wrapped
    model is saved (also in LoRA mode). Otherwise, in LoRA mode only the
    adapter tensors selected by ``bias`` are saved, else the full
    ``state_dict()``.

    Args:
        trainer: the trainer whose model is saved.
        output_dir: directory handed to ``trainer._save``.
        bias: LoRA bias mode forwarded to ``get_peft_state_maybe_zero_3``
            ('none', 'all' or 'lora_only'); used only when
            ``trainer.args.use_lora`` is set and ZeRO-3 is off.
    """
    if deepspeed.is_deepspeed_zero3_enabled():
        # Weights are partitioned across ranks under ZeRO-3. This call is made
        # on every rank, before the save guard below (it presumably involves
        # collective communication; verify against the DeepSpeed docs).
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict(
        )
    elif trainer.args.use_lora:
        state_dict = get_peft_state_maybe_zero_3(
            trainer.model.named_parameters(), bias)
    else:
        state_dict = trainer.model.state_dict()
    # ``should_save`` already selects the single process that writes to disk.
    # Do not also require ``local_rank == 0``: outside distributed launches
    # ``local_rank`` is -1, which would silently skip saving the model.
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=state_dict)
@dataclass
class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning.

    Each incoming item is a dict whose ``'samples'`` entry holds
    ``text_input`` and ``data_type``, plus ``image`` for image data. Every
    field is gathered into a list across the batch, and the result is
    returned as ``{'samples': batch}``.

    Whether images are present is decided from the first item only. When
    they are absent, only element ``[0]`` of each item's ``text_input`` is
    kept.
    """

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        samples = [item['samples'] for item in instances]
        texts = [sample['text_input'] for sample in samples]
        kinds = [sample['data_type'] for sample in samples]
        with_images = 'image' in samples[0]
        if not with_images:
            texts = [text[0] for text in texts]
        batch = {'text_input': texts, 'data_type': kinds}
        if with_images:
            batch['image'] = [sample['image'] for sample in samples]
        return {'samples': batch}
def _load_json_file(path):
    """Return the parsed JSON content of ``path``, closing the file afterwards."""
    with open(path) as f:
        return json.load(f)


def _resample_to(samples, target):
    """Return ``samples`` resized to ``target`` entries.

    A longer list is down-sampled without replacement via ``random.sample``.
    A shorter one keeps every entry and is topped up with entries drawn with
    replacement. A list already of length ``target`` is returned as-is.

    Raises:
        ValueError: if ``samples`` is empty but ``target`` is positive, or if
            ``target`` is negative (raised by ``random.sample``).
    """
    if len(samples) > target:
        return random.sample(samples, target)
    shortfall = target - len(samples)
    if shortfall > 0:
        if not samples:
            raise ValueError(
                f'cannot up-sample an empty sample list to {target} items')
        # Drawn one at a time with random.choice so a fixed seed yields the
        # same extra samples as drawing them individually.
        samples = samples + [random.choice(samples) for _ in range(shortfall)]
    return samples


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    ``data_args.data_path`` is either:

    * a path ending in ``json``: loaded as-is and handed to ``Mix_dataset``;
    * a path ending in ``txt``: a manifest with one ``<json_path> [value]``
      entry per whitespace-separated line (blank lines are skipped). Each
      referenced JSON file must hold a list of samples. If
      ``data_args.given_num`` is set, ``value`` is required and is the
      target sample count in thousands. Otherwise an optional ``value`` is a
      ratio applied to the file's size: below 1 down-samples, above 1
      up-samples with replacement.

    Args:
        tokenizer: not used here; kept for interface compatibility.
        data_args: a ``DataArguments`` instance.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator``, suitable for ``Trainer(**...)``.

    Raises:
        ValueError: if ``data_path`` ends in neither ``json`` nor ``txt``, or
            if a manifest line lacks its count while ``given_num`` is set.
    """
    rank0_print('Loading data...')
    data_path = data_args.data_path
    if data_path.endswith('json'):
        train_json = _load_json_file(data_path)
    elif data_path.endswith('txt'):
        train_json = {}
        with open(data_path) as manifest:
            entries = [raw.split() for raw in manifest]
        for line in entries:
            if not line:
                continue  # blank line
            temp = _load_json_file(line[0])
            if data_args.given_num:
                if len(line) != 2:
                    raise ValueError(
                        f'given_num expects "<json_path> <count>" lines, got {line}')
                temp = _resample_to(temp, int(float(line[1]) * 1000))
            elif len(line) == 2:
                temp = _resample_to(temp, int(len(temp) * float(line[1])))
            rank0_print(f'Load {len(temp)} samples from {line}')
            train_json[line[0]] = temp
    else:
        raise ValueError(
            f'Unsupported data_path {data_path!r}: expected a .json or .txt file')
    train_dataset = Mix_dataset(
        train_json,
        data_args.batch_size,
        img_size=data_args.img_size,
        local_rank=local_rank)
    print(str(len(train_dataset)) + 'samples is loaded')
    return dict(
        train_dataset=train_dataset,
        eval_dataset=None,
        data_collator=DataCollatorForSupervisedDataset(),
    )
def train():
    """Parse command-line arguments, build the model, and run fine-tuning.

    Arguments are parsed into ``ModelArguments``, ``DataArguments``,
    ``TrainingArguments`` and ``LoraArguments``. The config, model and
    tokenizer are loaded from ``model_name_or_path`` with
    ``trust_remote_code=True``. The vision encoder and projection are then
    frozen or unfrozen, LoRA adapters are optionally attached, and training
    runs through a Hugging Face ``Trainer``. Weights are finally written to
    ``training_args.output_dir`` via ``safe_save_model_for_hf_trainer``.

    Side effect: sets the module-level ``local_rank`` read by ``rank0_print``
    and ``make_supervised_data_module``.
    """
    global local_rank
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()
    # When a DeepSpeed config is supplied, mark the accelerate distributed
    # state as DeepSpeed.
    if getattr(training_args, 'deepspeed', None):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
    local_rank = training_args.local_rank
    device_map = None
    # Load the model config, turn off ``use_cache`` and record the configured
    # ``max_length`` on it.
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
    )
    config.use_cache = False
    config.max_length = training_args.max_length
    # Load model and tokenizer
    print(f'Load model from: {model_args.model_name_or_path}')
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        trust_remote_code=True,
    )
    # NOTE(review): the remote model code presumably sizes the ViT position
    # embeddings for 336px input, hence resize_pos() for any other size;
    # confirm against the model implementation.
    if data_args.img_size != 336:
        model.vit.resize_pos()
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        padding_side='right',
        use_fast=False,
        trust_remote_code=True,
    )
    # Attach the tokenizer to the model (presumably used by the remote
    # model's own text handling).
    model.tokenizer = tokenizer
    if training_args.fix_vit:
        model.vit.requires_grad_(False)
    else:
        model.vit.requires_grad_(True)
        # When the ViT is trained, its final post_layernorm is replaced by an
        # identity op (presumably because its output is unused, so its weights
        # would get no gradient; TODO confirm).
        model.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity(
        )
    if training_args.fix_sampler:
        model.vision_proj.requires_grad_(False)
    else:
        model.vision_proj.requires_grad_(True)
    if training_args.use_lora:
        # Freeze every parameter under ``model.model`` (presumably the
        # language-model backbone) before adding the LoRA adapters.
        for name, param in model.model.named_parameters():
            param.requires_grad = False
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type='CAUSAL_LM',
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        # With frozen weights, enable gradients on the input embeddings so
        # gradient checkpointing can still back-propagate into the adapters.
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()
    # Load data
    data_module = make_supervised_data_module(
        tokenizer=tokenizer, data_args=data_args)
    # Print whether transformers' progress bars are enabled, then force them on.
    print(transformers.processing_utils.logging.is_progress_bar_enabled())
    transformers.processing_utils.logging.enable_progress_bar()
    # Start trainer
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        bias=lora_args.lora_bias)
# Run fine-tuning when this file is executed as a script.
if __name__ == '__main__':
    train()