forked from thunlp/Few-NERD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train_demo.py
173 lines (152 loc) · 6.78 KB
/
train_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from transformers import BertTokenizer
from util.data_loader import get_loader
from util.framework import FewShotNERFramework
from util.word_encoder import BERTWordEncoder
from model.proto import Proto
from model.nnshot import NNShot
import sys
import torch
from torch import optim, nn
import numpy as np
import json
import argparse
import os
import torch
import random
def set_seed(seed):
    """Seed every relevant RNG (python, numpy, torch CPU/CUDA) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force cuDNN to pick deterministic kernels so repeated runs match exactly.
    torch.backends.cudnn.deterministic = True
def main():
    """Entry point: parse CLI args, build data loaders and model, then train and/or test.

    Supports three few-shot NER models ('proto', 'nnshot', 'structshot') on the
    Few-NERD inter/intra/supervised splits. Downloads the data split if missing,
    trains unless --only_test is given, and always runs a final evaluation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='inter',
            help='training mode, must be in [inter, intra, supervised]')
    parser.add_argument('--trainN', default=2, type=int,
            help='N in train')
    parser.add_argument('--N', default=2, type=int,
            help='N way')
    parser.add_argument('--K', default=2, type=int,
            help='K shot')
    parser.add_argument('--Q', default=3, type=int,
            help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int,
            help='batch size')
    parser.add_argument('--train_iter', default=600, type=int,
            help='num of iters in training')
    parser.add_argument('--val_iter', default=100, type=int,
            help='num of iters in validation')
    parser.add_argument('--test_iter', default=500, type=int,
            help='num of iters in testing')
    parser.add_argument('--val_step', default=20, type=int,
            help='val after training how many iters')
    parser.add_argument('--model', default='proto',
            help='model name, must be basic-bert, proto, nnshot, or structshot')
    parser.add_argument('--max_length', default=100, type=int,
            help='max length')
    parser.add_argument('--lr', default=1e-4, type=float,
            help='learning rate')
    parser.add_argument('--grad_iter', default=1, type=int,
            help='accumulate gradient every x iterations')
    parser.add_argument('--load_ckpt', default=None,
            help='load ckpt')
    parser.add_argument('--save_ckpt', default=None,
            help='save ckpt')
    parser.add_argument('--fp16', action='store_true',
            help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true',
            help='only test')
    parser.add_argument('--ckpt_name', type=str, default='',
            help='checkpoint name.')
    parser.add_argument('--seed', type=int, default=0,
            help='random seed')
    parser.add_argument('--ignore_index', type=int, default=-1,
            help='label index to ignore when calculating loss and metrics')

    # only for bert / roberta
    parser.add_argument('--pretrain_ckpt', default=None,
            help='bert / roberta pre-trained checkpoint')

    # only for prototypical networks
    parser.add_argument('--dot', action='store_true',
            help='use dot instead of L2 distance for proto')

    # only for structshot
    parser.add_argument('--tau', default=0.05, type=float,
            help='StructShot parameter to re-normalizes the transition probabilities')

    # experiment
    parser.add_argument('--use_sgd_for_bert', action='store_true',
            help='use SGD instead of AdamW for BERT.')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot NER".format(N, K))
    print("model: {}".format(model_name))
    print("max_length: {}".format(max_length))
    print('mode: {}'.format(opt.mode))

    # Seed everything before any model/data construction so runs are reproducible.
    set_seed(opt.seed)
    print('loading model and tokenizer...')
    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    word_encoder = BERTWordEncoder(
            pretrain_ckpt)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    print('loading data...')
    opt.train = f'data/{opt.mode}/train.txt'
    opt.test = f'data/{opt.mode}/test.txt'
    opt.dev = f'data/{opt.mode}/dev.txt'
    # Fetch the selected split if any of the three files is missing.
    if not (os.path.exists(opt.train) and os.path.exists(opt.dev) and os.path.exists(opt.test)):
        os.system(f'bash data/download.sh {opt.mode}')

    # Training episodes use trainN ways; validation/testing use N ways.
    train_data_loader = get_loader(opt.train, tokenizer,
            N=trainN, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index)
    val_data_loader = get_loader(opt.dev, tokenizer,
            N=N, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index)
    test_data_loader = get_loader(opt.test, tokenizer,
            N=N, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index)

    # Checkpoint prefix encodes the experiment configuration for later lookup.
    prefix = '-'.join([model_name, opt.mode, str(N), str(K), 'seed'+str(opt.seed)])
    if opt.dot:
        prefix += '-dot'
    if len(opt.ckpt_name) > 0:
        prefix += '-' + opt.ckpt_name

    if model_name == 'proto':
        print('use proto')
        model = Proto(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
        framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader)
    elif model_name == 'nnshot':
        print('use nnshot')
        model = NNShot(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
        framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader)
    elif model_name == 'structshot':
        print('use structshot')
        # StructShot reuses the NNShot encoder but decodes with Viterbi.
        model = NNShot(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
        framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader, N=opt.N, tau=opt.tau, train_fname=opt.train, viterbi=True)
    else:
        raise NotImplementedError

    # BUGFIX: the original `if not exists: os.mkdir(...)` was race-prone and
    # crashed if 'checkpoint' already existed as a file; makedirs with
    # exist_ok=True is atomic with respect to concurrent runs.
    os.makedirs('checkpoint', exist_ok=True)
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt
    print('model-save-path:', ckpt)

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        # -1 is a sentinel meaning "use the BERT fine-tuning default".
        if opt.lr == -1:
            opt.lr = 2e-5
        framework.train(model, prefix,
                load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
                val_step=opt.val_step, fp16=opt.fp16,
                train_iter=opt.train_iter, warmup_step=int(opt.train_iter * 0.1), val_iter=opt.val_iter, learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert)
    else:
        ckpt = opt.load_ckpt
        if ckpt is None:
            # BUGFIX: corrected "Hugginface" typo in the user-facing warning.
            print("Warning: --load_ckpt is not specified. Will load Hugging Face pre-trained checkpoint.")
            ckpt = 'none'

    # test
    precision, recall, f1, fp, fn, within, outer = framework.eval(model, opt.test_iter, ckpt=ckpt)
    print("RESULT: precision: %.4f, recall: %.4f, f1:%.4f" % (precision, recall, f1))
    print('ERROR ANALYSIS: fp: %.4f, fn: %.4f, within:%.4f, outer: %.4f'%(fp, fn, within, outer))
# Standard script entry point: run training/evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()