-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e0e9364
commit e1bfadc
Showing
1 changed file
with
43 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,118 +1,56 @@ | ||
""" Module for Loss class """ | ||
|
||
import logging | ||
from torch_geometric.nn import MessagePassing, InstanceNorm, radius_graph | ||
from torch_geometric.data import Data | ||
import torch | ||
from . import LabelTensor | ||
from torch_geometric.nn import radius_graph | ||
from torch_geometric.data import Data | ||
|
||
class Graph: | ||
""" | ||
PINA Graph managing the PyG Data class. | ||
""" | ||
def __init__(self, data): | ||
self.data = data | ||
|
||
@staticmethod | ||
def _build_triangulation(**kwargs): | ||
logging.debug("Creating graph with triangulation mode.") | ||
|
||
# check for mandatory arguments | ||
if "nodes_coordinates" not in kwargs: | ||
raise ValueError("Nodes coordinates must be provided in the kwargs.") | ||
if "nodes_data" not in kwargs: | ||
raise ValueError("Nodes data must be provided in the kwargs.") | ||
if "triangles" not in kwargs: | ||
raise ValueError("Triangles must be provided in the kwargs.") | ||
|
||
nodes_coordinates = kwargs["nodes_coordinates"] | ||
nodes_data = kwargs["nodes_data"] | ||
triangles = kwargs["triangles"] | ||
|
||
def __init__(self, x=None, pos=None, edge_index=None, edge_attr=None, **kwargs): | ||
if isinstance(x, torch.Tensor): | ||
self.size_x = x.size(0) | ||
|
||
if isinstance(pos, torch.Tensor): | ||
self.size_pos = pos.size(0) | ||
self.data = None | ||
if x is not None and pos is not None: | ||
self.build_graphs_list(x, pos, **kwargs) | ||
|
||
def less_first(a, b): | ||
return [a, b] if a < b else [b, a] | ||
|
||
list_of_edges = [] | ||
|
||
for triangle in triangles: | ||
for e1, e2 in [[0, 1], [1, 2], [2, 0]]: | ||
list_of_edges.append(less_first(triangle[e1],triangle[e2])) | ||
|
||
array_of_edges = torch.unique(torch.Tensor(list_of_edges), dim=0) # remove duplicates | ||
array_of_edges = array_of_edges.t().contiguous() | ||
print(array_of_edges) | ||
|
||
# list_of_lengths = [] | ||
|
||
# for p1,p2 in array_of_edges: | ||
# x1, y1 = tri.points[p1] | ||
# x2, y2 = tri.points[p2] | ||
# list_of_lengths.append((x1-x2)**2 + (y1-y2)**2) | ||
|
||
# array_of_lengths = np.sqrt(np.array(list_of_lengths)) | ||
|
||
# return array_of_edges, array_of_lengths | ||
def build_graphs_list(self, x, pos, method='radius', | ||
build_edge_attr=False, **kwargs): | ||
""" | ||
Build the graph from the node features and the node positions. | ||
""" | ||
if isinstance(x, list) and isinstance(pos, list): | ||
if len(x) != len(pos): | ||
raise ValueError("The number of node features and node positions" | ||
" must be the same.") | ||
if isinstance(x, (torch.Tensor, LabelTensor)) and isinstance( | ||
pos, list): | ||
x = [x] * len(pos) # Copy just the reference | ||
if isinstance(pos, (torch.Tensor, LabelTensor)): | ||
edge_idx = [self._build_edge_index(pos, method, **kwargs)] * len(x) | ||
else: | ||
edge_idx = [self._build_edge_index(p, method, **kwargs) for p in pos] | ||
if build_edge_attr is not None: | ||
edge_attr = [self._build_edge_attr(p, e) for p, e in zip(pos, edge_idx)] | ||
else: | ||
edge_attr = [None] * len(x) | ||
|
||
return Data( | ||
x=nodes_data, | ||
pos=nodes_coordinates.T, | ||
|
||
edge_index=array_of_edges, | ||
) | ||
graphs = [] | ||
for i in range(len(x)): | ||
graphs.append(Data(x=x[i], pos=pos[i], edge_index=edge_idx[i], | ||
edge_attr=edge_attr[i])) | ||
self.data = graphs | ||
|
||
@staticmethod | ||
def _build_radius(**kwargs): | ||
logging.debug("Creating graph with radius mode.") | ||
|
||
# check for mandatory arguments | ||
if "nodes_coordinates" not in kwargs: | ||
raise ValueError("Nodes coordinates must be provided in the kwargs.") | ||
if "nodes_data" not in kwargs: | ||
raise ValueError("Nodes data must be provided in the kwargs.") | ||
if "radius" not in kwargs: | ||
raise ValueError("Radius must be provided in the kwargs.") | ||
|
||
nodes_coordinates = kwargs["nodes_coordinates"] | ||
nodes_data = kwargs["nodes_data"] | ||
radius = kwargs["radius"] | ||
|
||
edges_data = kwargs.get("edge_data", None) | ||
loop = kwargs.get("loop", False) | ||
batch = kwargs.get("batch", None) | ||
|
||
logging.debug(f"radius: {radius}, loop: {loop}, " | ||
f"batch: {batch}") | ||
|
||
edge_index = radius_graph( | ||
x=nodes_coordinates.tensor, | ||
r=radius, | ||
loop=loop, | ||
batch=batch, | ||
) | ||
|
||
logging.debug(f"edge_index computed") | ||
return Data( | ||
x=nodes_data.tensor, | ||
pos=nodes_coordinates.tensor, | ||
edge_index=edge_index, | ||
edge_attr=edges_data, | ||
) | ||
def _build_edge_index(pos, method, **kwargs): | ||
if method == 'radius': | ||
return radius_graph(pos, **kwargs) | ||
else: | ||
raise ValueError("The method must be 'radius'.") | ||
|
||
@staticmethod | ||
def build(mode, **kwargs): | ||
""" | ||
Constructor for the `Graph` class. | ||
""" | ||
if mode == "radius": | ||
graph = Graph._build_radius(**kwargs) | ||
elif mode == "triangulation": | ||
graph = Graph._build_triangulation(**kwargs) | ||
else: | ||
raise ValueError(f"Mode {mode} not recognized") | ||
|
||
return Graph(graph) | ||
def _build_edge_attr(pos, edge_index,): | ||
return torch.norm((pos[edge_index[0]] - pos[edge_index[1]]), dim=-1) | ||
|
||
|
||
def __repr__(self): | ||
return f"Graph(data={self.data})" |