-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathBert_CLS.py
70 lines (58 loc) · 3.09 KB
/
Bert_CLS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from transformers import BertModel,BertConfig
import torch.nn as nn
import torch
from torch.autograd import Variable
class SentenceClassffier(nn.Module):
    """BERT-based sentence classifier with several pooling strategies.

    Pooling options:
      - 'cls':            use the last layer's [CLS] token embedding.
      - 'pooler':         use BERT's pooler output.
      - 'last-avg':       feed the full last hidden layer through a BiLSTM
                          and take the final time step.
      - 'first-last-avg': sum of the first and last hidden layers, fed
                          through the same BiLSTM, final time step.

    Args:
        pretrain_model: name or path of the pretrained BERT checkpoint.
        cache_dir: cache directory for downloaded checkpoints.
        pooling: one of the strategies listed above (default 'first-last-avg').
        label_number: number of output classes (default 2).
    """

    def __init__(self, pretrain_model, cache_dir, pooling='first-last-avg', label_number=2):
        super(SentenceClassffier, self).__init__()
        self.config = BertConfig.from_pretrained(pretrain_model, cache_dir=cache_dir)
        self.bert = BertModel.from_pretrained(pretrain_model, config=self.config, cache_dir=cache_dir)
        self.pooling = pooling
        # NOTE(review): a LeakyReLU is applied to the classifier output before
        # CrossEntropyLoss; CE expects raw logits, so this is unusual — confirm
        # it is intentional before changing, as trained weights depend on it.
        self.act = nn.LeakyReLU()
        self.label_number = label_number
        self.fn = nn.Linear(self.config.hidden_size, self.label_number)
        self.loss = nn.CrossEntropyLoss()
        # Half-sized hidden per direction so the BiLSTM output matches
        # hidden_size and feeds self.fn directly.
        self.hidden_dim = self.config.hidden_size // 2
        # BiLSTM over token embeddings to capture contextual information
        # (a CNN for local features was also considered — see original TODO).
        self.lstm = nn.LSTM(self.config.hidden_size, self.hidden_dim,
                            bias=True, bidirectional=True, batch_first=True)

    def forward(self, text, attention_mask, token_type_ids, labels=None):
        """Run the classifier.

        Args:
            text: input token ids, shape (batch, seq_len).
            attention_mask: attention mask, shape (batch, seq_len).
            token_type_ids: segment ids, shape (batch, seq_len).
            labels: optional class indices, shape (batch,).

        Returns:
            predictions of shape (batch, label_number) when ``labels`` is None,
            otherwise a ``(loss, predictions)`` tuple.

        Raises:
            ValueError: if ``self.pooling`` is not a known strategy.
        """
        out = self.bert(text, attention_mask=attention_mask,
                        token_type_ids=token_type_ids, output_hidden_states=True)
        if self.pooling == 'cls':
            logits = out.last_hidden_state[:, 0]  # [batch, hidden]
        elif self.pooling == 'pooler':
            logits = out.pooler_output  # [batch, hidden]
        elif self.pooling == 'last-avg':
            last = out.last_hidden_state  # [batch, seq, hidden]
            hidden = self.rand_init_hidden(last.shape[0], device=last.device)
            logits, _ = self.lstm(last, hidden)
            logits = logits[:, -1]  # final time step of the BiLSTM
        elif self.pooling == 'first-last-avg':
            # Combine the first and last transformer layers element-wise.
            first = out.hidden_states[1]
            last = out.hidden_states[-1]
            features = first + last
            hidden = self.rand_init_hidden(features.shape[0], device=features.device)
            logits, _ = self.lstm(features, hidden)
            logits = logits[:, -1]
        else:
            # Previously an unknown strategy fell through to a NameError on
            # `logits`; fail loudly with the actual cause instead.
            raise ValueError('unknown pooling strategy: %r' % self.pooling)
        pred = self.act(self.fn(logits))
        if labels is None:
            return pred
        loss = self.loss(pred, labels)
        return (loss, pred)

    def rand_init_hidden(self, batch_size, device=None):
        """Build random (h_0, c_0) states for the BiLSTM.

        Args:
            batch_size: current batch size.
            device: device to allocate the states on; defaults to the
                model's parameter device. (The original always allocated on
                CPU via the deprecated ``Variable`` wrapper, which crashed
                with a device mismatch once the model was moved to GPU.)

        Returns:
            Tuple of two tensors, each of shape (2, batch_size, hidden_dim)
            — one per LSTM direction.

        NOTE(review): random (rather than zero) initial states make outputs
        non-deterministic at inference time — confirm this is intentional.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (2, batch_size, self.hidden_dim)
        return torch.randn(*shape, device=device), torch.randn(*shape, device=device)