-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathbert.py
624 lines (495 loc) · 23.7 KB
/
bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
import torch
import math
import torch.nn.functional as F
import numpy as np
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.parameter import Parameter
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from .utils import ConvexSampler
# Name -> activation module, looked up via ``args.activation`` by the models below.
# The module instances are shared by every model that uses them (they hold no parameters).
activation_map = dict(relu=nn.ReLU(), tanh=nn.Tanh())
class BERT_DOC(BertPreTrainedModel):
    """BERT classifier with per-class sigmoid scores (DOC-style one-vs-rest head).

    Features are the mean of the last-layer token states passed through
    ``dense -> dropout -> activation``. Class scores are
    ``sigmoid(dropout(classifier(features)))``.
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, centroids=None):
        """Encode a batch and return features, a loss, or (features, scores).

        Returns:
            The pooled features if ``feature_ext``. Otherwise, when
            ``mode == 'train'``, ``loss_fct(scores, one_hot(labels))`` with float
            one-hot targets. For any other mode, ``(features, scores)``.
            ``centroids`` is accepted for interface compatibility and ignored.
        """
        last_hidden = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1]
        features = self.activation(self.dropout(self.dense(last_hidden.mean(dim=1))))
        # Dropout is applied to the raw logits before the sigmoid, so in train
        # mode the scores themselves are perturbed.
        scores = torch.sigmoid(self.dropout(self.classifier(features)))

        if feature_ext:
            return features
        if mode != 'train':
            return features, scores
        one_hot_targets = F.one_hot(labels, num_classes=self.num_labels).float()
        return loss_fct(scores, one_hot_targets)
class BERT(BertPreTrainedModel):
    """Plain BERT classifier over mean-pooled last-layer token states.

    Features are ``dropout(activation(dense(mean_tokens)))``. Logits come from a
    linear layer with ``args.num_labels`` outputs.
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, centroids=None):
        """Return features (``feature_ext``), ``loss_fct(logits, labels)`` in train
        mode, or ``(features, logits)`` otherwise. ``centroids`` is ignored."""
        last_hidden = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1]
        features = self.dropout(self.activation(self.dense(last_hidden.mean(dim=1))))
        logits = self.classifier(features)

        if feature_ext:
            return features
        return loss_fct(logits, labels) if mode == 'train' else (features, logits)
class BERT_Norm(BertPreTrainedModel):
    """BERT with a cosine classifier: L2-normalized features against L2-normalized
    class weights, followed by a softmax over classes."""

    def __init__(self, config, args):
        super().__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        # Not used by forward(); kept so existing checkpoints still load.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()
        # Class prototypes of shape (num_labels, feat_dim); feat_dim presumably
        # equals the BERT hidden size, since features are matched against it.
        self.weight = Parameter(torch.FloatTensor(args.num_labels, args.feat_dim).to(args.device))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, device=None, head=None):
        """Return normalized features (``feature_ext``), ``loss_fct(probs, labels)``
        in train mode, or ``(features, probs)`` otherwise.

        NOTE(review): ``probs`` are already softmax-ed; if ``loss_fct`` is
        ``CrossEntropyLoss`` the softmax is applied twice — confirm intent.
        ``device`` and ``head`` are accepted but unused.
        """
        last_hidden = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1]
        features = F.normalize(self.dropout(last_hidden.mean(dim=1)))
        cosine = F.linear(features, F.normalize(self.weight))
        probs = F.softmax(cosine, dim=1)

        if feature_ext:
            return features
        if mode == 'train':
            return loss_fct(probs, labels)
        return features, probs
class BERT_K_1_way(BertPreTrainedModel):
    """BERT classifier with ``num_labels + 1`` outputs (K known classes + 1 extra).

    Outside test mode, the pooled features and labels are passed through
    ``ConvexSampler``. Presumably it adds synthetic samples for the extra class;
    see ``ConvexSampler`` to confirm. Training logits are divided by the
    temperature ``args.temp`` before the loss.
    """

    def __init__(self, config, args):
        super(BERT_K_1_way, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sampler = ConvexSampler(args)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels + 1)
        self.t = args.temp  # temperature dividing the logits in the training loss
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):
        """Encode a batch.

        Returns:
            The features if ``feature_ext``. Otherwise, when ``mode == 'train'``,
            ``loss_fct(logits / t, labels)``. For any other mode,
            ``(pooled_output, logits, labels)``, where ``labels`` is whatever the
            sampler returned (left unchanged in ``'test'`` mode).
        """
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        # Compare by value. The previous `mode is not 'test'` tested object
        # identity, which is not guaranteed for equal strings built at runtime,
        # so the sampler could run even in test mode.
        if mode != 'test':
            pooled_output, labels = self.sampler(pooled_output, labels, mode=mode)
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        if feature_ext:
            return pooled_output
        if mode == 'train':
            return loss_fct(torch.div(logits, self.t), labels)
        return pooled_output, logits, labels
class BERT_SEG(BertPreTrainedModel):
    """SEG-style BERT model that scores features against learnable class means.

    Class scores are a ``p_y``-weighted normalization of ``exp(-0.5 * ||x - mu_c||^2)``
    over the rows ``mu_c`` of ``self.means``. Training adds a margin factor
    ``(1 + alpha)`` on the true class's distance. It also adds a generative term,
    weighted by ``lambda_``, that pulls each feature towards its label's mean.
    """
    def __init__(self, config, args):
        super(BERT_SEG, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()
        self.alpha = args.alpha  # margin added to the true-class scale factor in training
        self.lambda_ = args.lambda_  # weight of the generative (mean-pulling) loss term
        # One learnable mean per class, shape (num_labels, feat_dim). It is
        # multiplied against the hidden_size features below, so this assumes
        # feat_dim == hidden_size.
        # NOTE(review): the hard-coded .cuda() makes this model unusable on CPU-only hosts.
        self.means = nn.Parameter(torch.randn(self.num_labels, args.feat_dim).cuda())
        nn.init.xavier_uniform_(self.means, gain=math.sqrt(2.0))
    def forward(self, input_ids = None, token_type_ids = None, attention_mask=None , labels = None,
                feature_ext = False, mode = None, device=None, p_y = None, class_emb=None, loss_fct=None):
        """Return features (``feature_ext``), the SEG training loss (``mode == 'train'``),
        or ``(pooled_output, logits)`` otherwise.

        ``p_y`` is required unless ``feature_ext`` is set. It is expanded to
        [bsz, num_labels]; presumably a per-class prior, so confirm with the caller.
        ``class_emb``, if given, replaces ``self.means`` in the generative loss only.
        ``loss_fct`` is unused.
        """
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        if feature_ext:
            return pooled_output
        else:
            batch_size = pooled_output.shape[0]
            # -0.5 * ||x - mu||^2 for every (sample, class) pair, expanded as
            # XX - 2*XY + YY -> [bsz, num_labels]
            XY = torch.matmul(pooled_output, torch.transpose(self.means, 0, 1))
            XX = torch.sum(pooled_output ** 2, dim=1, keepdim=True)
            YY = torch.sum(torch.transpose(self.means, 0, 1)**2, dim=0, keepdim=True)
            neg_sqr_dist = - 0.5 * (XX - 2.0 * XY + YY)
            # Weight by p_y and normalize over classes to get the scores.
            # NOTE(review): exp() of large negative distances can underflow to 0,
            # making these ratios 0/0 (NaN); a log-sum-exp form would be safer.
            p_y = p_y.expand_as(neg_sqr_dist).to(device) # [bsz, n_c_seen]
            dist_exp = torch.exp(neg_sqr_dist)
            dist_exp_py = p_y.mul(dist_exp)
            dist_exp_sum = torch.sum(dist_exp_py, dim=1, keepdim=True) # [bsz, n_c_seen] -> [bsz, 1]
            logits = dist_exp_py / dist_exp_sum # [bsz, n_c_seen]
            if mode == 'train':
                # Scale factor K is 1 everywhere except (1 + alpha) at each sample's label.
                labels_reshped = labels.view(labels.size()[0], -1) # [bsz] -> [bsz, 1]
                ALPHA = torch.zeros(batch_size, self.num_labels).to(device).scatter_(1, labels_reshped, self.alpha) # margin
                K = ALPHA + torch.ones([batch_size, self.num_labels]).to(device)
                # Margin-scaled, p_y-weighted normalized likelihood -> [bsz, num_labels]
                dist_margin = torch.mul(neg_sqr_dist, K)
                dist_margin_exp = torch.exp(dist_margin)
                dist_margin_exp_py = p_y.mul(dist_margin_exp)
                dist_exp_sum_margin = torch.sum(dist_margin_exp_py, dim=1, keepdim=True)
                likelihood = dist_margin_exp_py / dist_exp_sum_margin
                # NOTE(review): this sums log-likelihood over every class column,
                # not only the label column that a margin cross-entropy would
                # gather — confirm this is intended.
                loss_ce = - likelihood.log().sum() / batch_size
                # Generative term: half the summed squared distance from each feature
                # to its label's mean, averaged over the batch.
                means = self.means if class_emb is None else class_emb
                means_batch = torch.index_select(means, dim=0, index=labels)
                loss_gen = (torch.sum((pooled_output - means_batch)**2) / 2) * (1. / batch_size)
                # Total loss.
                loss = loss_ce + self.lambda_ * loss_gen
                return loss
            else:
                return pooled_output, logits
class CosNorm_Classifier(nn.Module):
    """Cosine-normalized linear classifier with a squashed input norm.

    Each input row ``x`` becomes ``x / (1 + ||x||)``: the direction of ``x`` with
    norm ``||x|| / (1 + ||x||)``, which lies in [0, 1). Each weight row is
    L2-normalized, and the dot products are multiplied by ``scale``.

    Args:
        in_dims: input feature size.
        out_dims: number of outputs (classes).
        scale: multiplier applied to the logits.
        device: device the weight is moved to at construction (``Tensor.to``).
    """

    def __init__(self, in_dims, out_dims, scale=64, device=None):
        super(CosNorm_Classifier, self).__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.scale = scale
        self.weight = Parameter(torch.Tensor(out_dims, in_dims).to(device))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight uniformly from [-1/sqrt(in_dims), 1/sqrt(in_dims)]."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, *args):
        """Return ``scale * squash(input) @ normalize(weight).T``, shape (batch, out_dims).

        Extra positional arguments are accepted and ignored.
        """
        norm_x = torch.norm(input, 2, 1, keepdim=True)
        # (||x|| / (1 + ||x||)) * (x / ||x||) simplifies to x / (1 + ||x||). The
        # simplified form gives 0 instead of NaN (0/0) for an all-zero row.
        ex = input / (1 + norm_x)
        ew = self.weight / torch.norm(self.weight, 2, 1, keepdim=True)
        return torch.mm(self.scale * ex, ew.t())
class BERT_Disaware(BertPreTrainedModel):
    """BERT with distance-aware feature scaling and a cosine-normalized classifier.

    Each feature is scaled by ``exp(d2 - d1)``, where ``d1`` and ``d2`` are its
    distances to the nearest and second-nearest centroids. The scaled feature is
    then classified by ``CosNorm_Classifier``.
    """

    def __init__(self, config, args):
        super(BERT_Disaware, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()
        self.cosnorm_classifier = CosNorm_Classifier(
            config.hidden_size, args.num_labels, args.scale, args.device)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, centroids=None, dist_infos=None):
        """Encode a batch.

        ``centroids`` is required unless ``feature_ext`` is set. It is assumed to be
        (num_labels, hidden_size), since each feature is compared against
        ``num_labels`` rows — TODO confirm with the caller. ``dist_infos`` is unused.

        Returns:
            The features if ``feature_ext``. Otherwise, when ``mode == 'train'``,
            ``loss_fct(logits, labels)``. For any other mode,
            ``(pooled_output, logits)``; previously only ``'eval'`` returned this and
            other modes fell through to ``None``.
        """
        # Pass by keyword: BertModel.forward's positional order is
        # (input_ids, attention_mask, token_type_ids), so the previous positional
        # call swapped the attention mask and the segment ids.
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        pooled_output = self.dense(outputs.hidden_states[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        if feature_ext:
            return pooled_output

        x = pooled_output
        batch_size, feat_size = x.shape[0], x.shape[1]
        # Distance from every sample to every centroid: [batch, num_labels]
        f_expand = x.unsqueeze(1).expand(-1, self.num_labels, -1)
        centroids_expand = centroids.unsqueeze(0).expand(batch_size, -1, -1)
        dist_cur = torch.norm(f_expand - centroids_expand, 2, 2)
        _, labels_nn = torch.sort(dist_cur, 1)
        nearest_centers = centroids[labels_nn[:, 0]]
        second_nearest_centers = centroids[labels_nn[:, 1]]
        dist_denominator = torch.norm(x - nearest_centers, 2, 1)
        dist_numerator = torch.norm(x - second_nearest_centers, 2, 1)
        # d2 >= d1 by construction, so every scale factor is >= 1.
        reachability = torch.exp(dist_numerator - dist_denominator)
        reachability = reachability.unsqueeze(1).expand(-1, feat_size)
        logits = self.cosnorm_classifier(reachability * pooled_output)
        if mode == 'train':
            return loss_fct(logits, labels)
        return pooled_output, logits
class BERT_MDF_Pretrain(nn.Module):
    """Masked-LM BERT with a linear classifier on the last-layer [CLS] state.

    Used for MDF pre-training: ``forward`` gives classification logits, and
    ``mlmForward`` gives the masked-LM loss.
    """

    def __init__(self, args):
        super().__init__()
        self.num_labels = args.num_labels
        self.bert = BertForMaskedLM.from_pretrained(args.pretrained_bert_model)
        self.dropout = nn.Dropout(0.1)
        # Assumes args.feat_dim equals the BERT hidden size.
        self.classifier = nn.Linear(args.feat_dim, args.num_labels)

    def forward(self, X):
        """Classify the batch ``X`` (keyword inputs for the BERT model).

        Returns a dict with ``"logits"`` (from the dropped-out [CLS] state) and
        ``"hidden_states"`` (the [CLS] state before dropout).
        """
        last_layer = self.bert(**X, output_hidden_states=True).hidden_states[-1]
        cls_embedding = last_layer[:, 0]
        logits = self.classifier(self.dropout(cls_embedding))
        return {"logits": logits, "hidden_states": cls_embedding}

    def mlmForward(self, X, Y=None):
        """Return the masked-LM loss of ``X`` against the label ids ``Y``."""
        return self.bert(**X, labels=Y).loss

    def loss_ce(self, logits, Y):
        """Return the mean cross-entropy of ``logits`` against class ids ``Y``."""
        criterion = nn.CrossEntropyLoss()
        return criterion(logits, Y)
class BERT_MDF(BertPreTrainedModel):
    """BERT with a 2-way linear head on the pooler output.

    ``forward`` returns ``(logits, *rest)``, where ``rest`` is the BERT output after
    its first two entries. With ``output_hidden_states=True`` this presumably
    starts with the hidden states.
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        # Binary head; assumes args.feat_dim equals the BERT hidden size.
        self.classifier = nn.Linear(args.feat_dim, 2)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Return ``(logits,) + bert_outputs[2:]``.

        ``inputs_embeds`` is deliberately not forwarded to BERT (the model
        complained when it was), and ``labels`` is unused.
        """
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_hidden_states=True,
        )
        pooled = bert_outputs[1]
        logits = self.classifier(self.dropout(pooled))
        extras = bert_outputs[2:]  # hidden states / attentions, when present
        return (logits,) + extras
class BertClassificationHead(nn.Module):
    """Two-layer tanh MLP producing ``config.num_labels - 1`` logits.

    Dropout is applied before each linear layer.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, config.num_labels - 1)

    def forward(self, feature):
        """Map ``feature`` (..., hidden_size) to logits (..., num_labels - 1)."""
        hidden = torch.tanh(self.dense(self.dropout(feature)))
        return self.out_proj(self.dropout(hidden))
class BertContrastiveHead(nn.Module):
    """Two-layer tanh MLP projecting features to a ``hidden_size`` contrastive space.

    Dropout is applied before each linear layer.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, width)

    def forward(self, feature):
        """Map ``feature`` (..., hidden_size) to a projection of the same size."""
        hidden = torch.tanh(self.dense(self.dropout(feature)))
        return self.out_proj(self.dropout(hidden))
class BERT_KNNCL(nn.Module):
    """KNN-contrastive BERT classifier with a momentum key encoder and a feature queue.

    ``encoder_q`` / ``contrastive_liner_q`` are trained by backprop.
    ``encoder_k`` / ``contrastive_liner_k`` track them by an exponential moving
    average with momentum ``m``. Keys from the momentum branch are written into a
    fixed-size queue of (feature, label) pairs, which supplies the positives and
    negatives for the contrastive loss.
    """

    def __init__(self, args):
        super(BERT_KNNCL, self).__init__()
        self.number_labels = args.anum_labels
        config = AutoConfig.from_pretrained(
            args.bert_model,
            num_labels=self.number_labels,
        )
        self.encoder_q = BertModel.from_pretrained(args.bert_model, config=config)
        self.encoder_k = BertModel.from_pretrained(args.bert_model, config=config)
        self.classifier_liner = BertClassificationHead(config)
        self.contrastive_liner_q = BertContrastiveHead(config)
        self.contrastive_liner_k = BertContrastiveHead(config)
        self.m = 0.999  # EMA momentum of the key branch
        self.T = args.temperature  # contrastive logit temperature
        self.init_weights()
        self.contrastive_rate_in_training = args.contrastive_rate_in_training
        # Queue of K (label, feature) pairs; starts with random labels and features.
        self.K = args.queue_size
        self.register_buffer("label_queue", torch.randint(0, self.number_labels, [self.K]))  # (K,)
        self.register_buffer("feature_queue", torch.randn(self.K, config.hidden_size))  # (K, hidden)
        # Each row is one stored feature, so normalize along dim=1. The previous
        # dim=0 normalized across queue entries, unlike the l2-normalized keys
        # that are enqueued later.
        self.feature_queue = torch.nn.functional.normalize(self.feature_queue, dim=1)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))  # (1,)
        self.top_k = args.top_k  # max positives used per query
        self.update_num = args.positive_num  # positive samples per query label
        # Optional (reported by some experiments to help): train only the encoder
        # layers, freezing everything else.
        # params_to_train = ["layer." + str(i) for i in range(0, 12)]
        # for name, param in self.encoder_q.named_parameters():
        #     param.requires_grad_(False)
        #     for term in params_to_train:
        #         if term in name:
        #             param.requires_grad_(True)

    def _dequeue_and_enqueue(self, keys, label):
        """Write ``keys``/``label`` into the queue at ``queue_ptr``.

        If the batch would run past the end of the queue, the overflow is dropped
        (not wrapped around) and the pointer wraps to 0.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        if ptr + batch_size > self.K:
            batch_size = self.K - ptr
            keys = keys[: batch_size]
            label = label[: batch_size]
        # Replace the oldest entries at ptr (dequeue and enqueue).
        self.feature_queue[ptr: ptr + batch_size, :] = keys
        self.label_queue[ptr: ptr + batch_size] = label
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def select_pos_neg_sample(self, liner_q, label_q):
        """Build temperature-scaled contrastive logits against the queue.

        For each query, the up-to-``top_k`` most similar same-label queue entries
        are its positives. Every positive becomes its own row: ``[pos, negs...]``,
        where ``negs`` are the query's ``neg_min`` most similar different-label
        entries. ``neg_min`` is the smallest negative count over the batch.

        Returns:
            Tensor of shape (bsz * k, 1 + neg_min), with the positive in column 0,
            or ``None`` if some query has no positive or no negative in the queue.
        """
        label_queue = self.label_queue.clone().detach()      # (K,)
        feature_queue = self.feature_queue.clone().detach()  # (K, hidden)
        # 1. Similarity of every query with every queued key -> (bsz, K).
        #    Broadcast matmul instead of materializing a (bsz, K, hidden) copy.
        cos_sim = torch.matmul(liner_q, feature_queue.t())
        # 2. Positives share the query's label; everything else is a negative.
        pos_mask_index = label_queue.unsqueeze(0) == label_q.unsqueeze(1)  # (bsz, K)
        neg_mask_index = ~pos_mask_index
        # 3. Hide the other group with -inf so topk only ranks the intended entries.
        #    (Previously the positive top-k ran over the unmasked similarities, so
        #    negatives could be picked as "positives".)
        neg_inf = float('-inf')
        pos_candidates = cos_sim.masked_fill(neg_mask_index, neg_inf)
        neg_candidates = cos_sim.masked_fill(pos_mask_index, neg_inf)
        pos_min = int(pos_mask_index.sum(dim=-1).min())
        if pos_min == 0:
            return None
        neg_min = int(neg_mask_index.sum(dim=-1).min())
        if neg_min == 0:
            return None
        # 4. Nearest positives, one per output row: (bsz * k, 1)
        k = min(pos_min, self.top_k)
        pos_sample, _ = pos_candidates.topk(k, dim=-1)
        pos_sample = pos_sample.contiguous().view([-1, 1])
        # 5. Each positive row of a query reuses that query's negatives: (bsz * k, neg_min)
        neg_sample, _ = neg_candidates.topk(neg_min, dim=-1)
        neg_sample = neg_sample.repeat([1, k]).view([-1, neg_min])
        logits_con = torch.cat([pos_sample, neg_sample], dim=-1)
        return logits_con / self.T

    def init_weights(self):
        """Start the key projection head as a copy of the query projection head."""
        with torch.no_grad():
            for param_q, param_k in zip(self.contrastive_liner_q.parameters(),
                                        self.contrastive_liner_k.parameters()):
                # Copy values instead of aliasing the query tensor's storage.
                param_k.data.copy_(param_q.data)

    def update_encoder_k(self):
        """EMA update: k = m * k + (1 - m) * q for the encoder and the projection head."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.contrastive_liner_q.parameters(), self.contrastive_liner_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    def reshape_dict(self, batch):
        """Flatten every tensor in ``batch`` to 2-D ``[-1, last_dim]``, in place, and return it."""
        for k, v in batch.items():
            shape = v.shape
            batch[k] = v.view([-1, shape[-1]])
        return batch

    def l2norm(self, x: torch.Tensor):
        """Divide each vector (last dim) by its L2 norm; an all-zero vector yields NaN."""
        norm = torch.pow(x, 2).sum(dim=-1, keepdim=True).sqrt()
        x = torch.div(x, norm)
        return x

    def forward_no_multi_v2(self,
                            query,
                            positive_sample=None,
                            negative_sample=None,
                            ):
        """Training step: refresh the momentum branch and the queue, then compute the loss.

        ``query`` must contain ``'labels'``. The key is popped, so the caller's
        dict is mutated. ``positive_sample`` holds encoder inputs that are
        flattened to 2-D. Its rows are labeled by repeating each query label
        ``positive_num`` times, which assumes that many positives per query —
        TODO confirm. ``negative_sample`` is unused.
        """
        labels = query["labels"]
        labels = labels.view(-1)
        with torch.no_grad():
            self.update_encoder_k()
            update_sample = self.reshape_dict(positive_sample)
            bert_output_p = self.encoder_k(**update_sample)
            update_keys = bert_output_p[1]
            update_keys = self.contrastive_liner_k(update_keys)
            update_keys = self.l2norm(update_keys)
            tmp_labels = labels.unsqueeze(-1)
            tmp_labels = tmp_labels.repeat([1, self.update_num])
            tmp_labels = tmp_labels.view(-1)
            self._dequeue_and_enqueue(update_keys, tmp_labels)
        query.pop('labels')
        bert_output_q = self.encoder_q(**query)
        q = bert_output_q[1]
        liner_q = self.contrastive_liner_q(q)
        liner_q = self.l2norm(liner_q)
        logits_cls = self.classifier_liner(q)
        if self.number_labels == 1:
            loss_fct = MSELoss()
            loss_cls = loss_fct(logits_cls.view(-1), labels)
        else:
            loss_fct = CrossEntropyLoss()
            loss_cls = loss_fct(logits_cls.view(-1, self.number_labels - 1), labels)
        logits_con = self.select_pos_neg_sample(liner_q, labels)
        if logits_con is not None:
            # The positive is always column 0. Creating the targets on the logits'
            # device (instead of .cuda()) keeps this runnable on any device.
            labels_con = torch.zeros(logits_con.shape[0], dtype=torch.long, device=logits_con.device)
            loss_fct = CrossEntropyLoss()
            loss_con = loss_fct(logits_con, labels_con)
            loss = loss_con * self.contrastive_rate_in_training + \
                loss_cls * (1 - self.contrastive_rate_in_training)
        else:
            loss = loss_cls
        return SequenceClassifierOutput(
            loss=loss,
        )

    def forward(self,
                query,  # batch_size * max_length
                mode,
                positive_sample=None,  # batch_size * max_length
                negative_sample=None,  # batch_size * sample_num * max_length
                ):
        """Dispatch on ``mode``.

        - ``'train'``: a ``SequenceClassifierOutput`` holding the loss.
        - ``'validation'``: ``(predicted_ids, true_ids)`` as Python lists.
        - ``'test'``: ``(probs, seq_embed)``.

        All modes pop ``'labels'`` from ``query``.

        Raises:
            ValueError: for any other mode.
        """
        if mode == 'train':
            return self.forward_no_multi_v2(query=query, positive_sample=positive_sample,
                                            negative_sample=negative_sample)
        elif mode == 'validation':
            labels = query['labels']
            query.pop('labels')
            seq_embed = self.encoder_q(**query)[1]
            logits_cls = self.classifier_liner(seq_embed)
            probs = torch.softmax(logits_cls, dim=1)
            return torch.argmax(probs, dim=1).tolist(), labels.cpu().numpy().tolist()
        elif mode == 'test':
            query.pop('labels')
            seq_embed = self.encoder_q(**query)[1]
            logits_cls = self.classifier_liner(seq_embed)
            probs = torch.softmax(logits_cls, dim=1)
            return probs, seq_embed
        else:
            raise ValueError("undefined mode")