train_adv.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorch_transformers import AdamW, WEIGHTS_NAME, WarmupLinearSchedule
import csv
import numpy as np
import os
import logging
from fp16 import FP16_Module, FP16_Optimizer
from parallel import DataParallelModel, DataParallelCriterion
from collections import OrderedDict
from utils import *
from settings import args, TASK_DICT, init_logging, MODEL_CONFIG, MODEL_CLASS, SPECIAL_TOKENS, CONFIG_CLASS
from settings import TOKENIZER, SPECIAL_TOKEN_IDS, FILL_VAL, SAVE_NAME, FINAL_SAVE_NAME, TOKENS_WEIGHT, CONFIG_NAME
from scheduler import AnnealingLR
from regularizers import REG_TYPES, REG_TYPE_KEYS, Weight_Regularized_AdamW, Weight_Regularized_SGD
from torch.nn import CrossEntropyLoss
logger = logging.getLogger(__name__)
import OpenAttack
import datasets
torch.set_printoptions(edgeitems=1024,linewidth=160)
global toadd_adv_embedding
global save_curr_embedding
toadd_adv_embedding = None
save_curr_embedding = None
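
# The two globals above are the bridge between the training loop and the embedding-layer hooks below:
# `toadd_adv_embedding` holds the current adversarial perturbation, which the forward hook adds to the
# output of the word embedding layer; `save_curr_embedding` receives the gradient w.r.t. that output on
# the backward pass, so the PGD attack can read it.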
def my_forward_hook(module, input_, output_):
    global toadd_adv_embedding
    if toadd_adv_embedding is not None:
        output_ = output_ + toadd_adv_embedding
    return output_


def my_backward_hook(module, input_, output_):
    global save_curr_embedding
    save_curr_embedding = output_[0]


def train(task_ids, model):
    tasks = [args.tasks[task_id] for task_id in task_ids]
    logger.info("start to train { task: %s, seq train type: %s }" % (tasks, args.seq_train_type))
    model_dir = get_model_dir(tasks)
    make_dir(model_dir)

    train_dataset = [TASK_DICT[t]["train"] for t in tasks]
    train_extra_data = []
    logger.info('extra training data size: {}'.format(len(train_extra_data)))

    if not model:
        model = MODEL_CLASS.from_pretrained('../gpt2-medium-pretrained/').cuda()
        model.resize_token_embeddings(len(TOKENIZER))
    model.train()
    for p in model.parameters():
        p.requires_grad = False
    print(model.transformer.wte.weight[0, :].shape)
    print(model.transformer.wte.weight[0, :])
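
    # Prefix-tuning parameters: an embedding table of `preseqlen` prefix vectors and a small MLP
    # (`control_trans`) mapping each prefix vector to per-layer key/value states. These are the only
    # trainable parameters; the GPT-2 weights stay frozen.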
    prefix_tokens = torch.arange(args.preseqlen).long()
    prefix_weight = nn.Embedding(args.preseqlen, MODEL_CONFIG.n_embd).requires_grad_(True).to(args.device_ids[0])
    # nn.Embedding.from_pretrained is a classmethod that returns a *new* module, so calling it on the
    # instance had no effect; copy the pretrained rows into the prefix weights instead.
    prefix_weight.weight.data.copy_(model.transformer.wte.weight[:args.preseqlen, :].data)
    control_trans = nn.Sequential(
        nn.Linear(MODEL_CONFIG.n_embd, args.mid_dim),  # 1024 * 512
        nn.Tanh(),
        nn.Linear(args.mid_dim, MODEL_CONFIG.n_layer * 2 * MODEL_CONFIG.n_embd)).requires_grad_(True).to(args.device_ids[0])
    print(prefix_weight.weight.shape)
    print(control_trans[0].weight.shape, control_trans[2].weight.shape)
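
    # Register a task-specific generation token and persist the updated tokenizer and config so that
    # evaluation can reload them from model_dir.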
    gen_token = get_gen_token(tasks[0])
    TOKENIZER.add_tokens([gen_token])
    TOKENIZER.save_pretrained(model_dir)
    SPECIAL_TOKENS[tasks[0]] = gen_token
    SPECIAL_TOKEN_IDS[tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
    logger.info('gen token = {} , gen token id = {}'.format(gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
    MODEL_CONFIG.vocab_size = len(TOKENIZER)
    MODEL_CONFIG.to_json_file(os.path.join(model_dir, CONFIG_NAME))
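
    # The per-token loss weight vector must grow with the vocabulary: append a weight of 1 for the
    # newly added generation token, then resize the (frozen) model embeddings to match.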
    global TOKENS_WEIGHT
    if len(TOKENIZER) != TOKENS_WEIGHT.shape[0]:
        TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))
    model.resize_token_embeddings(len(TOKENIZER))
    for p in model.parameters():
        p.requires_grad = False
    model = WrapModel(model)
    # model = DataParallelModel(WrapModel(model), args.device_ids)
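
    # Build the QA-formatted training set and its dataloader; the batch size is chosen so that an
    # epoch runs for roughly at least `min_n_steps` steps (subject to `min_batch_size`).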
    train_qadata = QADataset(train_dataset, "train", SPECIAL_TOKEN_IDS[tasks[0]], train_extra_data)
    max_train_batch_size = max(len(train_qadata) // args.min_n_steps, args.min_batch_size)
    train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)
    n_train_epochs = args.n_train_epochs[tasks[0]]
    n_train_optimization_steps = len(train_qadata) * n_train_epochs
    logger.info('len of train dataset: {} , max train batch size {} , num of opt steps: {}'.format(
        len(train_qadata), max_train_batch_size, n_train_optimization_steps))
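
    # Only the prefix parameters are optimized (the prefix embedding and the two linear layers of
    # control_trans); the language model itself receives no gradient updates.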
    param_optimizer = [prefix_weight.weight, control_trans[0].weight, control_trans[2].weight]
    print([param_optimizer[i].shape for i in range(len(param_optimizer))])
    optimizer = AdamW(param_optimizer, lr=args.learning_rate, eps=args.adam_epsilon)
    if not args.fp32:
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=None, dynamic_loss_scale=True,
                                   dynamic_loss_args={'scale_window': 100, 'min_scale': 1, 'delayed_shift': 2})
    scheduler = AnnealingLR(optimizer, start_lr=args.learning_rate, warmup_iter=int(args.n_warmup_ratio * len(train_qadata)),
                            num_iters=int(n_train_optimization_steps), decay_style=args.decay_style)
    train_loss_fct = DataParallelCriterion(CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT), args.device_ids)
    tot_n_steps = 0
    train_once = TrainStep(model, optimizer, scheduler)

    # Since pytorch_transformers does not support embedding input, we implement adversarial training
    # with hooks that inject the perturbation into the word embedding layer. This would be easier to
    # implement with the Transformers package.
    model.model.transformer.wte.weight.requires_grad = True
    model.model.transformer.wte.register_forward_hook(my_forward_hook)
    model.model.transformer.wte.register_backward_hook(my_backward_hook)
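
    # NOTE: with the newer HuggingFace `transformers` package the hooks above would not be needed,
    # since GPT-2 accepts embeddings directly; a rough (hypothetical) sketch:
    #
    #     embd = hf_model.transformer.wte(input_ids) + perturbation
    #     outputs = hf_model(inputs_embeds=embd, past_key_values=past_key_values)
    #     logits = outputs.logits
    #
    # `suffix_length` below appears to be the number of trailing template/answer tokens for each
    # task; those positions are excluded from the adversarial perturbation (see `now_mask` in the
    # PGD loop).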
    suffix_length = 0
    if tasks[0] == "sst":
        suffix_length = 8
    elif tasks[0] == "ag":
        suffix_length = 15
    elif tasks[0] == "snli":
        suffix_length = 11

    for ep in range(n_train_epochs):
        cum_loss, cum_qa_loss, cum_lm_loss, cur_n_inputs = 0, 0, 0, 0
        for n_steps, (cq, len_cq, cqa, _, Y, gen_X, gen_Y) in enumerate(train_dataloader):
            n_inputs = cqa[0].shape[0]
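
            # Expand the prefix over the batch and map it through control_trans to obtain per-layer
            # key/value states, reshaped to (n_layer * 2, bsz, n_head, preseqlen, head_dim) and split
            # into per-layer (key, value) pairs.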
            prefix_tokensi = prefix_tokens.unsqueeze(0).expand(cqa[0].shape[0], -1).to(args.device_ids[0])
            temp_control = prefix_weight(prefix_tokensi)  # (bsz, preseqlen, n_embd)
            past = control_trans(temp_control)  # (bsz, preseqlen, n_layer * 2 * n_embd)
            bsz, seqlen, _ = past.shape
            past = past.view(bsz, seqlen, MODEL_CONFIG.n_layer * 2,
                             MODEL_CONFIG.n_head, MODEL_CONFIG.n_embd // MODEL_CONFIG.n_head)
            past0 = past[0, :, :, :, :].clone().detach()
            past = past.permute([2, 0, 3, 1, 4]).split(2)

            cqa_ = cqa[0].to(args.device_ids[0])
            Y_ = Y[0].to(args.device_ids[0])
            gen_X_ = gen_X[0].to(args.device_ids[0])
            gen_Y_ = gen_Y[0].to(args.device_ids[0])
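
            # PGD attack in embedding space: start from a small random perturbation, then for 10 steps
            # take a normalized gradient-ascent step (alpha = 1.25) on the training loss and project the
            # total perturbation back onto an L2 ball of radius 5.0 (per word or per sentence, depending
            # on args.pgd_ball).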
            with torch.enable_grad():
                embd_ori = model.model.transformer.wte(cqa_)
                toadd_adv_embedding = 0.001 * torch.randn((cqa_.shape[0], cqa_.shape[1], MODEL_CONFIG.n_embd)).to(args.device_ids[0])
                embd_adv = embd_ori + toadd_adv_embedding
                for _ in range(10):  # PGD-10
                    loss = get_losses(model, cqa_, Y_, gen_X_, gen_Y_, train_loss_fct, past, toadd_adv_embedding=toadd_adv_embedding)[0]
                    model.zero_grad()
                    loss.backward(retain_graph=True)
                    global save_curr_embedding
                    # only perturb the content part during adv. training
                    now_mask = torch.tensor([[1 for _ in range(len_cq[0][ii] - suffix_length)] +
                                             [0 for _ in range(cqa_.shape[1] - len_cq[0][ii] + suffix_length)]
                                             for ii in range(n_inputs)]).cuda()
                    save_curr_embedding = (save_curr_embedding.permute(2, 0, 1) * now_mask).permute(1, 2, 0)
                    if args.pgd_ball == "word":
                        torch_small_constant = 1e-12 * torch.ones(cqa_.shape[0], cqa_.shape[1], 1).to(save_curr_embedding.dtype).to(save_curr_embedding.device)
                        grad_norm = torch.sqrt(torch.sum(save_curr_embedding * save_curr_embedding, dim=-1, keepdim=True))
                    else:
                        torch_small_constant = 1e-12 * torch.ones(cqa_.shape[0], 1, 1).to(save_curr_embedding.dtype).to(save_curr_embedding.device)
                        grad_norm = torch.sqrt(torch.sum(save_curr_embedding * save_curr_embedding, dim=(1, 2), keepdim=True))
                    grad_norm = torch.max(torch_small_constant, grad_norm)
                    save_curr_embedding = save_curr_embedding / grad_norm
                    embd_adv = embd_adv.detach() + 1.25 * save_curr_embedding.detach()  # alpha = 1.25
                    if args.pgd_ball == "word":
                        pert_norm = torch.sqrt(torch.sum((embd_adv - embd_ori) * (embd_adv - embd_ori), dim=-1, keepdim=True))
                    else:
                        pert_norm = torch.sqrt(torch.sum((embd_adv - embd_ori) * (embd_adv - embd_ori), dim=(1, 2), keepdim=True))
                    pert_norm = torch.max(torch_small_constant, pert_norm)
                    ratio = 5.0 / pert_norm  # eps = 5.0
                    ratio = torch.min(torch.ones_like(embd_ori), ratio)
                    toadd_adv_embedding = ratio * (embd_adv - embd_ori)
                    embd_adv = embd_ori + toadd_adv_embedding
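
            # Outer minimization: one optimizer step on the prefix parameters, using the loss computed
            # on the adversarially perturbed embeddings.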
            loss = get_losses(model, cqa_, Y_, gen_X_, gen_Y_, train_loss_fct, past, toadd_adv_embedding=toadd_adv_embedding)[0]
            train_once(loss, n_inputs)

            qa_loss = loss.item() * n_inputs
            cum_qa_loss += qa_loss
            cur_n_inputs += n_inputs
            if (n_steps + 1) % args.logging_steps == 0:
                logger.info('progress {:.3f} , lr {:.1E} , loss {:.3f} , avg batch size {:.1f}'.format(
                    ep + cur_n_inputs / len(train_qadata), scheduler.get_lr(), cum_qa_loss / cur_n_inputs,
                    cur_n_inputs / (n_steps + 1)
                ))
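
        # Save the prefix-generated past key/value tensor after each epoch; the filename encodes
        # preseqlen, the learning rate, and the epoch index.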
        torch.save(control_trans(prefix_weight(prefix_tokens.to(prefix_weight.weight.device))).cpu(),
                   os.path.join(model_dir, "p" + str(args.preseqlen) + "lr" + str(args.learning_rate) + SAVE_NAME + "stokens" + str(ep + 1)))
        tot_n_steps += (n_steps + 1)
        logger.info('epoch {}/{} done , tot steps {} , lr {:.1E} , loss {:.2f} , avg batch size {:.1f}'.format(
            ep + 1, n_train_epochs, tot_n_steps, scheduler.get_lr(), cum_qa_loss / cur_n_inputs, cur_n_inputs / (n_steps + 1)
        ))

    return model


if __name__ == '__main__':
    if not args.debug:
        logging.getLogger("pytorch_transformers").setLevel(logging.WARNING)
        logging.getLogger("pytorch_transformers.tokenization_utils").setLevel(logging.CRITICAL)

    make_dir(args.model_dir_root)
    init_logging(os.path.join(args.model_dir_root, 'log_train_adv_p{}_lr{}.txt'.format(args.preseqlen, args.learning_rate)))
    logger.info('args = {}'.format(str(args)))

    model = None
    for task_id in range(len(args.tasks)):
        model = train([task_id], model)