-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraph.py
83 lines (55 loc) · 2.94 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
import torch.nn.functional as F
from utils.helpers import device
def gumbel_adjacency_matrix(node_embeddings, similarity_threshold, temperature, hard):
    """Sample an adjacency matrix from pairwise node-embedding similarity.

    Cosine similarities between all node pairs are squashed with a sigmoid,
    shifted by ``similarity_threshold`` and used as logits for
    ``F.gumbel_softmax`` over the last dimension, so each row (source node)
    receives a Gumbel-softmax sample over candidate neighbours. The diagonal
    is zeroed afterwards so no node links to itself.

    Args:
        node_embeddings: Tensor whose last two dims are
            (num_nodes, feature_dim); presumably (batch, num_nodes,
            feature_dim) as produced by ``StateToGraph.forward`` -- confirm
            for other callers.
        similarity_threshold: Scalar subtracted from every logit.
            NOTE(review): softmax is invariant to adding the same constant to
            every logit of a row, so this shift does not change the output of
            ``F.gumbel_softmax`` -- the threshold currently has no effect.
            Confirm whether a per-edge (binary) Gumbel decision was intended.
        temperature: Gumbel-softmax temperature ``tau``.
        hard: If True, each row is a straight-through one-hot sample;
            otherwise the soft probabilities are returned.

    Returns:
        For a (batch, num_nodes, feature_dim) input, a (batch, num_nodes,
        num_nodes) tensor with a zero diagonal. Because the diagonal is masked
        *after* sampling, a hard row whose sample landed on the node itself
        becomes all zeros (self-similarity is the largest cosine value, so
        this may be common).
    """
    norms = torch.linalg.vector_norm(node_embeddings, ord=2, dim=-1, keepdim=True)
    norm_matrix = norms @ norms.transpose(-2, -1)
    # Cosine similarity; the epsilon guards against all-zero embeddings
    # (possible after a ReLU).
    cosine = (node_embeddings @ node_embeddings.transpose(-2, -1)) / (norm_matrix + 1e-8)
    logits = torch.sigmoid(cosine) - similarity_threshold
    adjacency_matrix = F.gumbel_softmax(logits, tau=temperature, hard=hard)

    num_nodes = node_embeddings.shape[-2]
    # Build the mask on the embeddings' own device rather than a global one,
    # so the function works wherever the inputs live.
    off_diagonal = 1 - torch.eye(num_nodes, device=node_embeddings.device)
    return adjacency_matrix * off_diagonal.unsqueeze(0)
class StateToGraph(nn.Module):
    """Map flat state vectors to node embeddings plus a sampled graph.

    A linear layer followed by ReLU expands each state into ``num_nodes``
    embeddings of size ``node_feature_dim``; edges are then sampled from
    embedding similarity with ``gumbel_adjacency_matrix``.
    """

    def __init__(self, state_dim, node_feature_dim, num_nodes):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, num_nodes * node_feature_dim)
        self.num_nodes = num_nodes
        self.node_feature_dim = node_feature_dim

    def adjacency_to_edge_index(self, adj_matrix):
        """Convert a batch of adjacency matrices into per-graph edge lists.

        Args:
            adj_matrix: Tensor of shape (batch, n, n). Every non-zero entry
                ``adj_matrix[k, i, j]`` is treated as an edge i -> j of graph k.

        Returns:
            A tuple of ``batch`` int64 tensors, each of shape (2, E_k): row 0
            holds source indices and row 1 target indices, both local to that
            graph (0..n-1), listed in row-major order. A graph with no edges
            yields a (2, 0) tensor.

        Raises:
            ValueError: if ``adj_matrix`` is not 3-dimensional.
        """
        batch_size, _, _ = adj_matrix.shape
        mask = adj_matrix.bool()
        # nonzero() returns (graph, row, col) triples in row-major order on the
        # input's own device, matching boolean-mask indexing order.
        coords = mask.nonzero()
        edge_index = coords[:, 1:].T
        edge_counts = mask.reshape(batch_size, -1).sum(dim=-1)
        # .tolist() forces a host sync; unavoidable because split needs sizes.
        return torch.split(edge_index, edge_counts.tolist(), dim=-1)

    def forward(self, x, similarity_threshould, temperature, hard):
        """Embed states as nodes and sample a graph over them.

        Args:
            x: State tensor whose last dim is ``state_dim``; presumably
                (batch, state_dim) -- confirm against callers.
            similarity_threshould: Passed to ``gumbel_adjacency_matrix`` as
                ``similarity_threshold`` (parameter name is misspelt but kept
                so existing keyword callers keep working).
            temperature: Gumbel-softmax temperature.
            hard: Whether to draw hard (one-hot) samples.

        Returns:
            ``(node_emb, edge_index)`` where ``node_emb`` has shape
            (batch, num_nodes, node_feature_dim) and ``edge_index`` is the
            tuple returned by ``adjacency_to_edge_index``.
        """
        hidden = torch.relu(self.fc1(x))
        node_emb = hidden.view(-1, self.num_nodes, self.node_feature_dim)
        adjacency_matrix = gumbel_adjacency_matrix(
            node_emb,
            similarity_threshold=similarity_threshould,
            temperature=temperature,
            hard=hard,
        )
        # NOTE(review): only integer edge indices leave this method, so the
        # sampled adjacency values (and their gradients) are discarded. With
        # hard=False nearly every off-diagonal entry is non-zero, giving an
        # almost fully connected graph -- confirm this is intended.
        return node_emb, self.adjacency_to_edge_index(adjacency_matrix)
class GAT(nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer runs ``num_heads`` attention heads, giving features of
    width ``hidden_dim * num_heads`` (the input width of the second layer),
    followed by an ELU; the second layer maps to ``output_dim`` using
    ``GATConv``'s default head count. ``dropout`` is handed to both layers
    unchanged.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, dropout=0.6):
        super().__init__()
        concat_width = hidden_dim * num_heads
        self.gat1 = GATConv(input_dim, hidden_dim, heads=num_heads, dropout=dropout)
        self.gat2 = GATConv(concat_width, output_dim, dropout=dropout)

    def forward(self, g):
        """Run both attention layers over graph ``g``.

        ``g`` must expose ``x`` (node features) and ``edge_index``; presumably
        a torch_geometric ``Data``-like object -- confirm against callers.
        """
        features, edges = g.x, g.edge_index
        hidden = F.elu(self.gat1(features, edges))
        return self.gat2(hidden, edges)