import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
################################################################################
# Original code from : https://github.com/Diego999/pyGAT
# Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
################################################################################
class SpecialSpmmFunction(torch.autograd.Function):  # Special function for only sparse region backpropagation
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b

class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)
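# Usage note (added for clarity): SpecialSpmm()(indices, values, shape, b) computes the same
# result as torch.sparse_coo_tensor(indices, values, shape) @ b, but the custom backward above
# only returns gradients for the edge values and the dense matrix b, which is all the
# attention layer below needs.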
################################################################################
class SpGraphAttentionLayer_v2(nn.Module):
    # Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903, with GATv2 added.
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer_v2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()
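
    # Note on the GATv2 modification in forward() below: the original GAT scores an edge as
    # LeakyReLU(a · [Wh_i || Wh_j]) (kept as the commented-out line), whereas GATv2
    # (Brody et al., "How Attentive are Graph Attention Networks?", https://arxiv.org/abs/2105.14491)
    # moves the attention vector a outside the nonlinearity, i.e. a · LeakyReLU([Wh_i || Wh_j]),
    # which yields dynamic rather than static attention.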
    def forward(self, input, adj):
        N = input.size()[0]
        edge = adj._indices()  # GAYAN: Modification. Replaced nonzero to get sparse edge list
        h = torch.mm(input, self.W)  # h: N x out
        assert not torch.isnan(h).any()
        # Self-attention on the nodes - shared attention mechanism; edge_h: 2*D x E
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))  # edge_e: E (original GAT)
        edge_e = torch.exp(-self.a.mm(self.leakyrelu(edge_h)).squeeze())  # GAYAN GATv2
        assert not torch.isnan(edge_e).any()
        # e_rowsum: N x 1
        e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1), device=adj.device))
        edge_e = self.dropout(edge_e)  # edge_e: E
        h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)  # h_prime: N x out
        assert not torch.isnan(h_prime).any()
        h_prime = h_prime.div(e_rowsum)  # h_prime: N x out
        assert not torch.isnan(h_prime).any()
        if self.concat:
            # if this layer is not the last layer
            return F.elu(h_prime)
        else:
            # if this layer is the last layer
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
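
# ------------------------------------------------------------------------------
# Minimal usage sketch (not from the original repo): runs one forward pass of
# SpGraphAttentionLayer_v2 on a tiny random graph. The node count, feature sizes,
# dropout and alpha values below are illustrative assumptions, not settings used
# by the original code.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, F_in, F_out = 5, 8, 4
    x = torch.randn(N, F_in)  # random node features
    # Edge list with a self-loop on every node (avoids empty rows in e_rowsum),
    # stored as a sparse adjacency matrix because forward() reads adj._indices().
    src = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])
    dst = torch.tensor([0, 1, 2, 3, 4, 1, 2, 3])
    idx = torch.stack([src, dst])
    adj = torch.sparse_coo_tensor(idx, torch.ones(idx.size(1)), (N, N)).coalesce()
    layer = SpGraphAttentionLayer_v2(F_in, F_out, dropout=0.2, alpha=0.2, concat=True)
    out = layer(x, adj)
    print(layer, out.shape)  # SpGraphAttentionLayer_v2 (8 -> 4), torch.Size([5, 4])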