Skip to content

Commit

Permalink
Update csu_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
allenanie authored Aug 10, 2018
1 parent d8139c6 commit 37956a0
Showing 1 changed file with 1 addition and 140 deletions.
141 changes: 1 addition & 140 deletions csu_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,146 +133,7 @@ def __init__(self, sigma_M, sigma_B, sigma_W, **kwargs):
class LSTM_w_M_Config(LSTMBaseConfig):
def __init__(self, beta, **kwargs):
super(LSTM_w_M_Config, self).__init__(beta=beta, m=True, **kwargs)


"""
Hierarchical ConvNet
"""


class ConvNetEncoder(nn.Module):
def __init__(self, config):
super(ConvNetEncoder, self).__init__()

self.word_emb_dim = config['word_emb_dim']
self.enc_lstm_dim = config['enc_lstm_dim']

self.convnet1 = nn.Sequential(
nn.Conv1d(self.word_emb_dim, 2 * self.enc_lstm_dim, kernel_size=3,
stride=1, padding=1),
nn.ReLU(inplace=True),
)
self.convnet2 = nn.Sequential(
nn.Conv1d(2 * self.enc_lstm_dim, 2 * self.enc_lstm_dim, kernel_size=3,
stride=1, padding=1),
nn.ReLU(inplace=True),
)
self.convnet3 = nn.Sequential(
nn.Conv1d(2 * self.enc_lstm_dim, 2 * self.enc_lstm_dim, kernel_size=3,
stride=1, padding=1),
nn.ReLU(inplace=True),
)
self.convnet4 = nn.Sequential(
nn.Conv1d(2 * self.enc_lstm_dim, 2 * self.enc_lstm_dim, kernel_size=3,
stride=1, padding=1),
nn.ReLU(inplace=True),
)

def forward(self, sent_tuple):
# sent_len: [max_len, ..., min_len] (batch)
# sent: Variable(seqlen x batch x worddim)

sent, sent_len = sent_tuple

sent = sent.transpose(0, 1).transpose(1, 2).contiguous()
# batch, nhid, seqlen)

sent = self.convnet1(sent)
u1 = torch.max(sent, 2)[0]

sent = self.convnet2(sent)
u2 = torch.max(sent, 2)[0]

sent = self.convnet3(sent)
u3 = torch.max(sent, 2)[0]

sent = self.convnet4(sent)
u4 = torch.max(sent, 2)[0]

emb = torch.cat((u1, u2, u3, u4), 1)

return emb


"""
Normal ConvNet
"""
class NormalConvNetEncoder(nn.Module):
def __init__(self, config):
super(NormalConvNetEncoder, self).__init__()
self.word_emb_dim = config['word_emb_dim']
self.enc_lstm_dim = config['enc_lstm_dim']
self.conv = nn.Conv2d(in_channels=1, out_channels=self.enc_lstm_dim, kernel_size=(3, self.word_emb_dim), stride=(1, self.word_emb_dim))

def encode(self, inputs):
output = inputs.transpose(0, 1).unsqueeze(1) # [batch_size, in_kernel, seq_length, embed_dim]
output = F.relu(self.conv(output)) # conv -> [batch_size, out_kernel, seq_length, 1]
output = output.squeeze(3).max(2)[0] # max_pool -> [batch_size, out_kernel]
return output

def forward(self, sent_tuple):
# sent_len: [max_len, ..., min_len] (batch)
# sent: Variable(seqlen x batch x worddim)
sent, sent_len = sent_tuple
emb = self.encode(sent)
return emb

"""
https://github.com/Shawn1993/cnn-text-classification-pytorch/blob/master/model.py
352 stars
"""
class CNN_Text_Encoder(nn.Module):
def __init__(self, config):
super(CNN_Text_Encoder, self).__init__()

self.word_emb_dim = config['word_emb_dim']

# V = args.embed_num
# D = args.embed_dim
# C = args.class_num
Ci = 1
Co = config['kernel_num'] # 100
Ks = config['kernel_sizes'] # '3,4,5'
# len(Ks)*Co

# self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks]
self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, self.word_emb_dim)) for K in Ks])
'''
self.conv13 = nn.Conv2d(Ci, Co, (3, D))
self.conv14 = nn.Conv2d(Ci, Co, (4, D))
self.conv15 = nn.Conv2d(Ci, Co, (5, D))
'''
# self.dropout = nn.Dropout(args.dropout)
# self.fc1 = nn.Linear(len(Ks) * Co, C)

def conv_and_pool(self, x, conv):
x = F.relu(conv(x)).squeeze(3) # (N, Co, W)
x = F.max_pool1d(x, x.size(2)).squeeze(2)
return x

def forward(self, x):
# x = self.embed(x) # (N, W, D)

x = x[0].transpose(0, 1).unsqueeze(1)
# x = x.unsqueeze(1) # (N, Ci, W, D)

x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks)

x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks)

x = torch.cat(x, 1)

'''
x1 = self.conv_and_pool(x,self.conv13) #(N,Co)
x2 = self.conv_and_pool(x,self.conv14) #(N,Co)
x3 = self.conv_and_pool(x,self.conv15) #(N,Co)
x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co)
'''
# x = self.dropout(x) # (N, len(Ks)*Co)
# logit = self.fc1(x) # (N, C)
return x




class Classifier(nn.Module):
def __init__(self, vocab, config):
Expand Down

0 comments on commit 37956a0

Please sign in to comment.