diff --git a/README.md b/README.md index 6f2bdec..8c9ac64 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ from dect.directions import generate_uniform_2d_directions v = generate_uniform_2d_directions(num_thetas=64) -layer = ECTLayer(ECTConfig(), V=v) +layer = ECTLayer(ECTConfig(), v=v) points_coordinates = torch.tensor( [[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]], requires_grad=True diff --git a/dect/ect.py b/dect/ect.py index 0fabf76..2f14f9d 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -60,13 +60,18 @@ class Batch: face: The indices of the points that span a face in the simplicial complex. Conforms to pytorch_geometric standards. Shape has to be of the form - [3,num_faces]. + [3,num_faces] or [4, num_faces], depending on the type of complex + (simplicial or cubical). + node_weights: torch.FloatTensor + Optional weights for the nodes in the complex. The shape has to be + [num_nodes,]. """ x: torch.FloatTensor batch: torch.LongTensor edge_index: torch.LongTensor | None = None face: torch.LongTensor | None = None + node_weights: torch.FloatTensor | None = None def compute_ecc( @@ -75,7 +80,7 @@ def compute_ecc( lin: torch.FloatTensor, scale: float = 100, ) -> torch.FloatTensor: - """Computes the Euler Characteristic curve. + """Computes the Euler Characteristic Curve. Parameters ---------- @@ -89,9 +94,6 @@ def compute_ecc( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] scale: torch.FloatTensor A single number that scales the sigmoid function by multiplying the sigmoid with the scale. With high (100>) values, the ect will resemble a @@ -121,16 +123,13 @@ def compute_ect_points( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] """ nh = batch.x @ v return compute_ecc(nh, batch.batch, lin) def compute_ect_edges( - data: Batch, v: torch.FloatTensor, lin: torch.FloatTensor + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor ): """Computes the Euler Characteristic Transform of a batch of graphs. @@ -144,23 +143,20 @@ def compute_ect_edges( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] """ # Compute the node heigths - nh = data.x @ v + nh = batch.x @ v # Perform a lookup with the edge indices on node heights, this replaces the # node index with its node height and then compute the maximum over the # columns to compute the edge height. - eh, _ = nh[data.edge_index].max(dim=0) + eh, _ = nh[batch.edge_index].max(dim=0) # Compute which batch an edge belongs to. We take the first index of the # edge (or faces) and do a lookup on the batch index of that node in the # batch indices of the nodes. - batch_index_nodes = data.batch - batch_index_edges = data.batch[data.edge_index[0]] + batch_index_nodes = batch.batch + batch_index_edges = batch.batch[batch.edge_index[0]] return compute_ecc(nh, batch_index_nodes, lin) - compute_ecc( eh, batch_index_edges, lin @@ -168,7 +164,7 @@ def compute_ect_edges( def compute_ect_faces( - data: Batch, v: torch.FloatTensor, lin: torch.FloatTensor + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor ): """Computes the Euler Characteristic Transform of a batch of meshes. @@ -182,27 +178,24 @@ def compute_ect_faces( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] """ # Compute the node heigths - nh = data.x @ v + nh = batch.x @ v # Perform a lookup with the edge indices on node heights, this replaces the # node index with its node height and then compute the maximum over the # columns to compute the edge height. - eh, _ = nh[data.edge_index].max(dim=0) + eh, _ = nh[batch.edge_index].max(dim=0) # Do the same thing for the faces. - fh, _ = nh[data.face].max(dim=0) + fh, _ = nh[batch.face].max(dim=0) # Compute which batch an edge belongs to. We take the first index of the # edge (or faces) and do a lookup on the batch index of that node in the # batch indices of the nodes. - batch_index_nodes = data.batch - batch_index_edges = data.batch[data.edge_index[0]] - batch_index_faces = data.batch[data.face[0]] + batch_index_nodes = batch.batch + batch_index_edges = batch.batch[batch.edge_index[0]] + batch_index_faces = batch.batch[batch.face[0]] return ( compute_ecc(nh, batch_index_nodes, lin) @@ -223,7 +216,8 @@ class ECTLayer(nn.Module): ---------- v: torch.FloatTensor The direction vector that contains the directions. The shape of the - tensor v is either [ndims, num_thetas] or [n_channels, ndims, num_thetas]. + tensor v is either [ndims, num_thetas] or [n_channels, ndims, + num_thetas]. config: ECTConfig The configuration config of the ECT layer. diff --git a/dect/wect.py b/dect/wect.py new file mode 100644 index 0000000..67dd526 --- /dev/null +++ b/dect/wect.py @@ -0,0 +1,164 @@ +from typing import Literal + +import geotorch +import torch +from torch import nn +from torch_scatter import segment_add_coo + +from dect.ect import ECTConfig, Batch, normalize + + +def compute_wecc( + nh: torch.FloatTensor, + index: torch.LongTensor, + lin: torch.FloatTensor, + weight: torch.FloatTensor, + scale: float = 500, +): + """Computes the Weighted Euler Characteristic Curve. + + Parameters + ---------- + nh : torch.FloatTensor + The node heights, computed as the inner product of the node coordinates + x and the direction vector v. + index: torch.LongTensor + The index that indicates to which pointcloud a node height belongs. For + the node heights it is the same as the batch index, for the higher order + simplices it will have to be recomputed. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + weight: torch.FloatTensor + The weight of the node, edge or face. It is the maximum of the node + weights for the edges and faces. + scale: torch.FloatTensor + A single number that scales the sigmoid function by multiplying the + sigmoid with the scale. With high (100>) values, the ect will resemble a + discrete ECT and with lower values it will smooth the ECT. + """ + ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view( + 1, -1, 1 + ) + ecc = ecc.movedim(0, 2).movedim(0, 1) + return segment_add_coo(ecc, index) + + +def compute_wect( + batch: Batch, + v: torch.FloatTensor, + lin: torch.FloatTensor, + wect_type: Literal["points"] | Literal["edges"] | Literal["faces"], +): + """ + Computes the Weighted Euler Characteristic Transform of a batch of point + clouds. + + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, batch index, + edge_index, face, and node weights. + v: torch.FloatTensor + The direction vector that contains the directions. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + wect_type: str + The type of WECT to compute. Can be "points", "edges", or "faces". + """ + nh = batch.x @ v + if wect_type in ["edges", "faces"]: + edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0) + eh, _ = nh[batch.edge_index].min(dim=0) + if wect_type == "faces": + face_weights, _ = batch.node_weights[batch.face].max(axis=0) + fh, _ = nh[batch.face].min(dim=0) + + if wect_type == "points": + return compute_wecc(nh, batch.batch, lin, batch.node_weights) + if wect_type == "edges": + # noinspection PyUnboundLocalVariable + return compute_wecc( + nh, batch.batch, lin, batch.node_weights + ) - compute_wecc( + eh, batch.batch[batch.edge_index[0]], lin, edge_weights + ) + if wect_type == "faces": + # noinspection PyUnboundLocalVariable + return ( + compute_wecc(nh, batch.batch, lin, batch.node_weights) + - compute_wecc( + eh, batch.batch[batch.edge_index[0]], lin, edge_weights + ) + + compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights) + ) + raise ValueError(f"Invalid wect_type: {wect_type}") + + +class WECTLayer(nn.Module): + """Machine learning layer for computing the WECT (Weighted ECT). + + Parameters + ---------- + v: torch.FloatTensor + The direction vector that contains the directions. The shape of the + tensor v is either [ndims, num_thetas] or [n_channels, ndims, + num_thetas]. + config : ECTConfig + The configuration object of the WECT layer. + """ + + def __init__(self, config: ECTConfig, v=None): + super().__init__() + self.config = config + self.lin = nn.Parameter( + torch.linspace( + -config.radius, config.radius, config.bump_steps + ).view(-1, 1, 1, 1), + requires_grad=False, + ) + + # If provided with one set of directions. + # For backwards compatibility. + if v.ndim == 2: + v.unsqueeze(0) + + # The set of directions is added + if config.fixed: + self.v = nn.Parameter(v.movedim(-1, -2), requires_grad=False) + else: + self.v = nn.Parameter(torch.zeros_like(v.movedim(-1, -2))) + geotorch.constraints.sphere(self, "v", radius=config.radius) + # Since geotorch randomizes the vector during initialization, we + # assign the values after registering it with spherical constraints. + # See Geotorch documentation for examples. + self.v = v.movedim(-1, -2) + + def forward(self, batch: Batch): + """Forward method for the WECT Layer. + + + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, edges, faces, batch + index, and node_weights. It should follow the pytorch geometric + conventions. + + Returns + ---------- + wect: torch.FloatTensor + Returns the WECT of each data object in the batch. If the layer is + initialized with v of the shape [ndims,num_thetas], the returned + WECT has shape [batch,num_thetas,bump_steps]. In case the layer is + initialized with v of the form [n_channels, ndims, num_thetas] the + returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] + """ + # Movedim for geotorch + wect = compute_wect( + batch, self.v.movedim(-1, -2), self.lin, self.config.ect_type + ) + if self.config.normalized: + return normalize(wect) + return wect.squeeze() diff --git a/requirements.txt b/requirements.txt index a9e681e..0da2fc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ torch -torch-scatter \ No newline at end of file +torch-scatter +geotorch \ No newline at end of file