Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding WECT and fixing some minor typos in ECT/Readme. #7

Merged
merged 6 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ from dect.directions import generate_uniform_2d_directions

v = generate_uniform_2d_directions(num_thetas=64)

layer = ECTLayer(ECTConfig(), V=v)
layer = ECTLayer(ECTConfig(), v=v)

points_coordinates = torch.tensor(
[[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]], requires_grad=True
Expand Down
48 changes: 21 additions & 27 deletions dect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,18 @@ class Batch:
face:
The indices of the points that span a face in the simplicial complex.
Conforms to pytorch_geometric standards. Shape has to be of the form
[3,num_faces].
[3,num_faces] or [4, num_faces], depending on the type of complex
(simplicial or cubical).
node_weights: torch.FloatTensor
Optional weights for the nodes in the complex. The shape has to be
[num_nodes,].
"""

x: torch.FloatTensor
batch: torch.LongTensor
edge_index: torch.LongTensor | None = None
face: torch.LongTensor | None = None
node_weights: torch.FloatTensor | None = None


def compute_ecc(
Expand All @@ -75,7 +80,7 @@ def compute_ecc(
lin: torch.FloatTensor,
scale: float = 100,
) -> torch.FloatTensor:
"""Computes the Euler Characteristic curve.
"""Computes the Euler Characteristic Curve.

Parameters
----------
Expand All @@ -89,9 +94,6 @@ def compute_ecc(
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
out: torch.FloatTensor
The shape of the resulting tensor after summation. It has to be of the
shape [num_discretization_steps, batch_size, num_thetas]
scale: torch.FloatTensor
A single number that scales the sigmoid function by multiplying the
sigmoid with the scale. With high (100>) values, the ect will resemble a
Expand Down Expand Up @@ -121,16 +123,13 @@ def compute_ect_points(
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
out: torch.FloatTensor
The shape of the resulting tensor after summation. It has to be of the
shape [num_discretization_steps, batch_size, num_thetas]
"""
nh = batch.x @ v
return compute_ecc(nh, batch.batch, lin)


def compute_ect_edges(
data: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
"""Computes the Euler Characteristic Transform of a batch of graphs.

Expand All @@ -144,31 +143,28 @@ def compute_ect_edges(
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
out: torch.FloatTensor
The shape of the resulting tensor after summation. It has to be of the
shape [num_discretization_steps, batch_size, num_thetas]
"""
# Compute the node heigths
nh = data.x @ v
nh = batch.x @ v

# Perform a lookup with the edge indices on node heights, this replaces the
# node index with its node height and then compute the maximum over the
# columns to compute the edge height.
eh, _ = nh[data.edge_index].max(dim=0)
eh, _ = nh[batch.edge_index].max(dim=0)

# Compute which batch an edge belongs to. We take the first index of the
# edge (or faces) and do a lookup on the batch index of that node in the
# batch indices of the nodes.
batch_index_nodes = data.batch
batch_index_edges = data.batch[data.edge_index[0]]
batch_index_nodes = batch.batch
batch_index_edges = batch.batch[batch.edge_index[0]]

return compute_ecc(nh, batch_index_nodes, lin) - compute_ecc(
eh, batch_index_edges, lin
)


def compute_ect_faces(
data: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
"""Computes the Euler Characteristic Transform of a batch of meshes.

Expand All @@ -182,27 +178,24 @@ def compute_ect_faces(
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
out: torch.FloatTensor
The shape of the resulting tensor after summation. It has to be of the
shape [num_discretization_steps, batch_size, num_thetas]
"""
# Compute the node heigths
nh = data.x @ v
nh = batch.x @ v

# Perform a lookup with the edge indices on node heights, this replaces the
# node index with its node height and then compute the maximum over the
# columns to compute the edge height.
eh, _ = nh[data.edge_index].max(dim=0)
eh, _ = nh[batch.edge_index].max(dim=0)

# Do the same thing for the faces.
fh, _ = nh[data.face].max(dim=0)
fh, _ = nh[batch.face].max(dim=0)

# Compute which batch an edge belongs to. We take the first index of the
# edge (or faces) and do a lookup on the batch index of that node in the
# batch indices of the nodes.
batch_index_nodes = data.batch
batch_index_edges = data.batch[data.edge_index[0]]
batch_index_faces = data.batch[data.face[0]]
batch_index_nodes = batch.batch
batch_index_edges = batch.batch[batch.edge_index[0]]
batch_index_faces = batch.batch[batch.face[0]]

return (
compute_ecc(nh, batch_index_nodes, lin)
Expand All @@ -223,7 +216,8 @@ class ECTLayer(nn.Module):
----------
v: torch.FloatTensor
The direction vector that contains the directions. The shape of the
tensor v is either [ndims, num_thetas] or [n_channels, ndims, num_thetas].
tensor v is either [ndims, num_thetas] or [n_channels, ndims,
num_thetas].
config: ECTConfig
The configuration config of the ECT layer.

Expand Down
164 changes: 164 additions & 0 deletions dect/wect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Literal

import geotorch
import torch
from torch import nn
from torch_scatter import segment_add_coo

from dect.ect import ECTConfig, Batch, normalize


def compute_wecc(
nh: torch.FloatTensor,
index: torch.LongTensor,
lin: torch.FloatTensor,
weight: torch.FloatTensor,
scale: float = 500,
):
"""Computes the Weighted Euler Characteristic Curve.

Parameters
----------
nh : torch.FloatTensor
The node heights, computed as the inner product of the node coordinates
x and the direction vector v.
index: torch.LongTensor
The index that indicates to which pointcloud a node height belongs. For
the node heights it is the same as the batch index, for the higher order
simplices it will have to be recomputed.
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
weight: torch.FloatTensor
The weight of the node, edge or face. It is the maximum of the node
weights for the edges and faces.
scale: torch.FloatTensor
A single number that scales the sigmoid function by multiplying the
sigmoid with the scale. With high (100>) values, the ect will resemble a
discrete ECT and with lower values it will smooth the ECT.
"""
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view(
1, -1, 1
)
ecc = ecc.movedim(0, 2).movedim(0, 1)
return segment_add_coo(ecc, index)


def compute_wect(
batch: Batch,
v: torch.FloatTensor,
lin: torch.FloatTensor,
wect_type: Literal["points"] | Literal["edges"] | Literal["faces"],
):
"""
Computes the Weighted Euler Characteristic Transform of a batch of point
clouds.

Parameters
----------
batch : Batch
A batch of data containing the node coordinates, batch index,
edge_index, face, and node weights.
v: torch.FloatTensor
The direction vector that contains the directions.
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
wect_type: str
The type of WECT to compute. Can be "points", "edges", or "faces".
"""
nh = batch.x @ v
if wect_type in ["edges", "faces"]:
edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0)
eh, _ = nh[batch.edge_index].min(dim=0)
if wect_type == "faces":
face_weights, _ = batch.node_weights[batch.face].max(axis=0)
fh, _ = nh[batch.face].min(dim=0)

if wect_type == "points":
return compute_wecc(nh, batch.batch, lin, batch.node_weights)
if wect_type == "edges":
# noinspection PyUnboundLocalVariable
return compute_wecc(
nh, batch.batch, lin, batch.node_weights
) - compute_wecc(
eh, batch.batch[batch.edge_index[0]], lin, edge_weights
)
if wect_type == "faces":
# noinspection PyUnboundLocalVariable
return (
compute_wecc(nh, batch.batch, lin, batch.node_weights)
- compute_wecc(
eh, batch.batch[batch.edge_index[0]], lin, edge_weights
)
+ compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights)
)
raise ValueError(f"Invalid wect_type: {wect_type}")


class WECTLayer(nn.Module):
"""Machine learning layer for computing the WECT (Weighted ECT).

Parameters
----------
v: torch.FloatTensor
The direction vector that contains the directions. The shape of the
tensor v is either [ndims, num_thetas] or [n_channels, ndims,
num_thetas].
config : ECTConfig
The configuration object of the WECT layer.
"""

def __init__(self, config: ECTConfig, v=None):
super().__init__()
self.config = config
self.lin = nn.Parameter(
torch.linspace(
-config.radius, config.radius, config.bump_steps
).view(-1, 1, 1, 1),
requires_grad=False,
)

# If provided with one set of directions.
# For backwards compatibility.
if v.ndim == 2:
v.unsqueeze(0)

# The set of directions is added
if config.fixed:
self.v = nn.Parameter(v.movedim(-1, -2), requires_grad=False)
else:
self.v = nn.Parameter(torch.zeros_like(v.movedim(-1, -2)))
geotorch.constraints.sphere(self, "v", radius=config.radius)
# Since geotorch randomizes the vector during initialization, we
# assign the values after registering it with spherical constraints.
# See Geotorch documentation for examples.
self.v = v.movedim(-1, -2)

def forward(self, batch: Batch):
"""Forward method for the WECT Layer.


Parameters
----------
batch : Batch
A batch of data containing the node coordinates, edges, faces, batch
index, and node_weights. It should follow the pytorch geometric
conventions.

Returns
----------
wect: torch.FloatTensor
Returns the WECT of each data object in the batch. If the layer is
initialized with v of the shape [ndims,num_thetas], the returned
WECT has shape [batch,num_thetas,bump_steps]. In case the layer is
initialized with v of the form [n_channels, ndims, num_thetas] the
returned WECT has the shape [batch,n_channels,num_thetas,bump_steps]
"""
# Movedim for geotorch
wect = compute_wect(
batch, self.v.movedim(-1, -2), self.lin, self.config.ect_type
)
if self.config.normalized:
return normalize(wect)
return wect.squeeze()
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
torch
torch-scatter
torch-scatter
geotorch