train.py
from component.ofa.modeling_ofa import OFAModelForCaption
from component.ofa.tokenization_ofa import OFATokenizer
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    Trainer,
    BertTokenizerFast
)
from loguru import logger
from component.dataset import CaptionDataset
from component.argument import CaptionArguments
import argparse
import os
import json
from os.path import join
from component.datacollator import CaptionCollator
from component.scst import ScstTrainer
def main():
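    """
    Finetune OFAModelForCaption on an image captioning dataset.

    The workflow is: parse --train_args_file, load CaptionArguments and
    TrainingArguments from that JSON file, optionally freeze the encoder,
    build the dataset/collator/Trainer, train, and finally evaluate on the
    test split if one is configured.

    Usage (path is illustrative):
        python train.py --train_args_file train_args/train_ofa.json
    """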
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_args_file", type=str, default='train_args/train_ofa.json', help="")
    args = parser.parse_args()
    train_args_file = args.train_args_file
    # Parse the training configuration
    parser = HfArgumentParser((CaptionArguments, TrainingArguments))
    args, training_args = parser.parse_json_file(json_file=train_args_file)
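    # The JSON file is expected to hold both CaptionArguments and TrainingArguments
    # fields; the values below are placeholders, not taken from this repo. A minimal sketch:
    # {
    #   "model_name_or_path": "path/to/ofa-checkpoint",
    #   "train_caption_file": "data/train_captions.jsonl",
    #   "train_image_file": "data/train_images.tsv",
    #   "max_seq_length": 128,
    #   "freeze_encoder": true,
    #   "freeze_word_embed": false,
    #   "output_dir": "output/ofa-caption",
    #   "seed": 42
    # }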
    # Create the output directory
    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir)
    # Save a copy of the training arguments for reproducibility
    with open(train_args_file, 'r', encoding='utf8') as f:
        train_args = json.load(f)
    with open(join(training_args.output_dir, 'train_args.json'), 'w', encoding='utf8') as f:
        json.dump(train_args, f, indent=2)
    # Set the random seed
    set_seed(training_args.seed)
    # Initialize the tokenizer and model
    tokenizer = OFATokenizer.from_pretrained('./vocab')
    model = OFAModelForCaption.from_pretrained(args.model_name_or_path)
    # Optionally freeze the encoder weights so that only the decoder is finetuned
    if args.freeze_encoder:
        for name, param in model.encoder.named_parameters():
            # the encoder and decoder share the word embeddings, so keep them
            # trainable unless freeze_word_embed is also set
            if 'embed_tokens' in name and not args.freeze_word_embed:
                param.requires_grad = True
            # freeze all other encoder weights
            else:
                param.requires_grad = False
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("Total training params: %.2fM" % (total / 1e6))
    # Load the training dataset
    train_dataset = CaptionDataset(args.train_caption_file, args.train_image_file)
    # Initialize the data collator
    data_collator = CaptionCollator(tokenizer=tokenizer, max_seq_length=args.max_seq_length)
    # Initialize the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )
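    # Alternatively, ScstTrainer (imported above from component.scst) provides a
    # self-critical sequence training variant; the commented-out block below swaps it in.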
    # trainer = ScstTrainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=train_dataset,
    #     data_collator=data_collator,
    #     tokenizer=tokenizer
    # )
    # Start training
    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    trainer.save_model(join(training_args.output_dir, 'checkpoint-final'))
    # Evaluate on the test set if one is configured
    if args.test_caption_file is not None and args.test_image_file is not None:
        logger.info("*** start test ***")
        test_dataset = CaptionDataset(args.test_caption_file, args.test_image_file)
        metrics = trainer.evaluate(test_dataset)
        trainer.log_metrics("test", metrics)
        trainer.save_metrics("test", metrics)


if __name__ == '__main__':
    main()