aggregator.py
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import numpy


class Aggregator(nn.Module):
    """Base class; concrete aggregators override forward()."""
    def __init__(self, batch_size, dim, dropout, act, name=None):
        super(Aggregator, self).__init__()
        self.dropout = dropout
        self.act = act
        self.batch_size = batch_size
        self.dim = dim

    def forward(self):
        pass
class LocalAggregator(nn.Module):
    """Relation-aware attention over the local (session) adjacency matrix.

    adj holds integer relation ids; values 1-4 select one of four attention
    vectors, and any other value (e.g. 0, no edge) is masked out of the softmax.
    """
    def __init__(self, dim, alpha, dropout=0., name=None):
        super(LocalAggregator, self).__init__()
        self.dim = dim
        self.dropout = dropout

        # one attention vector per relation type
        self.a_0 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.a_1 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.a_2 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.a_3 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.bias = nn.Parameter(torch.Tensor(self.dim))  # defined but not used in forward

        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, hidden, adj, mask_item=None):
        h = hidden                      # (batch, N, dim)
        batch_size = h.shape[0]
        N = h.shape[1]

        # pairwise element-wise products h_i * h_j -> (batch, N, N, dim)
        a_input = (h.repeat(1, 1, N).view(batch_size, N * N, self.dim)
                   * h.repeat(1, N, 1)).view(batch_size, N, N, self.dim)

        # one attention score map per relation type, each (batch, N, N)
        e_0 = torch.matmul(a_input, self.a_0)
        e_1 = torch.matmul(a_input, self.a_1)
        e_2 = torch.matmul(a_input, self.a_2)
        e_3 = torch.matmul(a_input, self.a_3)

        e_0 = self.leakyrelu(e_0).squeeze(-1).view(batch_size, N, N)
        e_1 = self.leakyrelu(e_1).squeeze(-1).view(batch_size, N, N)
        e_2 = self.leakyrelu(e_2).squeeze(-1).view(batch_size, N, N)
        e_3 = self.leakyrelu(e_3).squeeze(-1).view(batch_size, N, N)

        # pick the score matching each edge's relation id; non-edges get a
        # large negative value so the softmax sends their weight to ~0
        mask = -9e15 * torch.ones_like(e_0)
        alpha = torch.where(adj.eq(1), e_0, mask)
        alpha = torch.where(adj.eq(2), e_1, alpha)
        alpha = torch.where(adj.eq(3), e_2, alpha)
        alpha = torch.where(adj.eq(4), e_3, alpha)
        alpha = torch.softmax(alpha, dim=-1)

        # weighted sum of neighbour representations
        output = torch.matmul(alpha, h)
        return output
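A minimal usage sketch for LocalAggregator (the shapes, sizes, and manual parameter initialization below are assumptions for illustration; in the original model the parameters are initialized by the enclosing network):

import torch

agg = LocalAggregator(dim=16, alpha=0.2)
for p in agg.parameters():
    torch.nn.init.uniform_(p, -0.1, 0.1)    # assumed init; normally done by the parent model

hidden = torch.randn(4, 5, 16)              # (batch, N items, dim)
adj = torch.randint(0, 5, (4, 5, 5))        # integer relation ids 0-4 per node pair
out = agg(hidden, adj)                       # -> (4, 5, 16)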
class GlobalAggregator(nn.Module):
    """Aggregates sampled global neighbours, optionally weighting them with
    attention conditioned on an extra (e.g. session-level) vector."""
    def __init__(self, dim, dropout, act=torch.relu, name=None):
        super(GlobalAggregator, self).__init__()
        self.dropout = dropout
        self.act = act
        self.dim = dim

        self.w_1 = nn.Parameter(torch.Tensor(self.dim + 1, self.dim))
        self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.w_3 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim))
        self.bias = nn.Parameter(torch.Tensor(self.dim))  # defined but not used in forward

    def forward(self, self_vectors, neighbor_vector, batch_size, masks,
                neighbor_weight, extra_vector=None):
        if extra_vector is not None:
            # attention over sampled neighbours, conditioned on extra_vector
            # and the neighbour co-occurrence weights
            alpha = torch.matmul(
                torch.cat([extra_vector.unsqueeze(2).repeat(1, 1, neighbor_vector.shape[2], 1) * neighbor_vector,
                           neighbor_weight.unsqueeze(-1)], -1),
                self.w_1).squeeze(-1)
            alpha = F.leaky_relu(alpha, negative_slope=0.2)
            alpha = torch.matmul(alpha, self.w_2).squeeze(-1)
            alpha = torch.softmax(alpha, -1).unsqueeze(-1)
            neighbor_vector = torch.sum(alpha * neighbor_vector, dim=-2)
        else:
            # no attention signal: plain mean over the sampled neighbours
            neighbor_vector = torch.mean(neighbor_vector, dim=2)
        # self_vectors = F.dropout(self_vectors, 0.5, training=self.training)

        # fuse self and aggregated neighbour representations
        output = torch.cat([self_vectors, neighbor_vector], -1)
        output = F.dropout(output, self.dropout, training=self.training)
        output = torch.matmul(output, self.w_3)
        output = output.view(batch_size, -1, self.dim)
        output = self.act(output)
        return output
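Similarly, a hedged sketch of how GlobalAggregator might be called in isolation (all shapes and the manual initialization are assumptions; note that masks is accepted but never used inside forward):

import torch

gagg = GlobalAggregator(dim=16, dropout=0.1)
for p in gagg.parameters():
    torch.nn.init.uniform_(p, -0.1, 0.1)    # assumed init; normally done by the parent model

batch, n_nodes, n_sample, dim = 4, 5, 3, 16
self_vectors = torch.randn(batch, n_nodes, dim)
neighbor_vector = torch.randn(batch, n_nodes, n_sample, dim)
neighbor_weight = torch.randn(batch, n_nodes, n_sample)
extra_vector = torch.randn(batch, n_nodes, dim)

out = gagg(self_vectors, neighbor_vector, batch, masks=None,
           neighbor_weight=neighbor_weight, extra_vector=extra_vector)
# out: (batch, n_nodes, dim) == (4, 5, 16)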