model_pre.py
import dgl
import torch
import torch.nn.functional as F
import numpy as np
import gensim
from attention_diffusion import GATNet
from dgl.nn.pytorch.glob import WeightAndSum
class Model(torch.nn.Module):
    def __init__(self,
                 num_hidden,
                 num_layers,
                 num_heads,
                 k,
                 alpha,
                 vocab,
                 n_gram,
                 drop_out,
                 class_num,
                 num_feats,
                 max_length=350,
                 cuda=True,
                 ):
        super(Model, self).__init__()
        self.is_cuda = cuda
        self.vocab = vocab
        self.num_feats = num_feats
        # Trainable word embeddings, initialized from pre-trained GloVe vectors.
        self.node_hidden = torch.nn.Embedding(len(vocab), num_feats)
        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec('/content/glove.6B.300d.txt')))
        self.node_hidden.weight.requires_grad = True
        self.len_vocab = len(vocab)
        self.ngram = n_gram
        self.max_length = max_length
        # Attention-diffusion GAT applied to the per-document word graphs.
        self.gatnet = GATNet(class_num, class_num, class_num, num_layers, k, alpha, num_heads, merge='mean')
        self.dropout = torch.nn.Dropout(p=drop_out)
        self.activation = torch.nn.ReLU()
        self.linear1 = torch.nn.Linear(num_feats, num_hidden, bias=True)
        self.bn1 = torch.nn.BatchNorm1d(num_hidden)
        self.linear2 = torch.nn.Linear(num_hidden, class_num, bias=True)
        self.bn2 = torch.nn.BatchNorm1d(class_num)
        # Weighted-sum readout that pools node features into a graph-level vector.
        self.weight_and_sum = WeightAndSum(class_num)
    def reset(self):
        # Re-initialize the trainable linear layers with Xavier initialization.
        gain = torch.nn.init.calculate_gain("relu")
        torch.nn.init.xavier_normal_(self.linear1.weight, gain=gain)
        torch.nn.init.xavier_normal_(self.linear2.weight, gain=gain)
    def load_word2vec(self, word2vec_file):
        """Build the embedding matrix for the vocabulary from pre-trained vectors."""
        # Note: a raw GloVe file has no word2vec header line; it may need conversion
        # (or, in recent gensim versions, the no_header option) for this call to succeed.
        model = gensim.models.KeyedVectors.load_word2vec_format(word2vec_file)
        embedding_matrix = []
        for word in self.vocab:
            try:
                embedding_matrix.append(model[word])
            except KeyError:
                # Out-of-vocabulary words get a small random vector.
                embedding_matrix.append(np.random.uniform(-0.1, 0.1, self.num_feats))
        return np.array(embedding_matrix)
    def add_seq_edges(self, doc_ids: list, old_to_new: dict):
        """Connect each word to its neighbours within an n-gram sliding window."""
        edges = []
        for index, src_word_old in enumerate(doc_ids):
            src = old_to_new[src_word_old]
            # Window of size 2 * n_gram + 1 centred on the current position
            # (the i == index case also adds a self-loop).
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(doc_ids))):
                dst = old_to_new[doc_ids[i]]
                edges.append([src, dst])
        return edges
    def seq_to_graph(self, doc_ids: list) -> dgl.DGLGraph:
        """Build one DGL graph for a single document of word ids."""
        # Truncate overly long documents.
        if len(doc_ids) > self.max_length:
            doc_ids = doc_ids[:self.max_length]
        # One node per distinct word; map original word ids to local node ids.
        local_vocab = list(set(doc_ids))
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))
        local_vocab = torch.tensor(local_vocab)
        if self.is_cuda:
            local_vocab = local_vocab.cuda()
        sub_graph = dgl.DGLGraph().to('cuda' if self.is_cuda else 'cpu')
        sub_graph.add_nodes(len(local_vocab))
        # Node features are the (trainable) word embeddings.
        sub_graph.ndata['k'] = self.node_hidden(local_vocab)
        # Sliding-window edges between neighbouring words.
        srcs, dsts = zip(*self.add_seq_edges(doc_ids, old_to_new))
        sub_graph.add_edges(srcs, dsts)
        return sub_graph
    def forward(self, doc_ids):
        # Build one graph per document and batch them into a single DGL graph.
        sub_graphs = [self.seq_to_graph(doc) for doc in doc_ids]
        batch_graph = dgl.batch(sub_graphs)
        # Project the word embeddings down to class_num features.
        batch_f = self.dropout(batch_graph.ndata['k'])
        batch_f = self.activation(self.linear1(batch_f))
        batch_f = self.linear2(self.dropout(batch_f))
        # Attention diffusion over the batched graph, then weighted-sum readout
        # to obtain one vector per document.
        h1 = self.gatnet(batch_graph, batch_f)
        h1 = self.weight_and_sum(batch_graph, h1)
        return h1
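

# --- Usage sketch (illustration only; not part of the original file) ---
# A minimal sketch of how this Model might be driven, assuming a tiny placeholder
# vocabulary and hyperparameter values, and assuming the GloVe file hard-coded in
# __init__ ('/content/glove.6B.300d.txt') is available. The expected output shape
# also assumes GATNet keeps the class_num feature size it is configured with above.
if __name__ == "__main__":
    toy_vocab = {"the": 0, "quick": 1, "brown": 2, "fox": 3}
    model = Model(num_hidden=96,
                  num_layers=2,
                  num_heads=8,
                  k=4,
                  alpha=0.1,
                  vocab=toy_vocab,
                  n_gram=3,
                  drop_out=0.5,
                  class_num=2,
                  num_feats=300,
                  cuda=torch.cuda.is_available())
    # Each document is a list of vocabulary indices; forward() builds one
    # sliding-window graph per document and returns one vector per document.
    logits = model([[0, 1, 2, 3], [3, 2, 1, 0]])
    print(logits.shape)  # expected: (2, class_num)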