Skip to content

Commit

Permalink
Adding black + some minor changes in wect and the README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
rballeba committed Jul 17, 2024
1 parent efd2063 commit 08ab939
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 62 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ from dect.directions import generate_uniform_2d_directions

v = generate_uniform_2d_directions(num_thetas=64)

layer = ECTLayer(ECTConfig(), V=v)
layer = ECTLayer(ECTConfig(), v=v)

points_coordinates = torch.tensor(
[[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]], requires_grad=True
Expand Down
18 changes: 6 additions & 12 deletions dect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ def compute_ecc(
return segment_add_coo(ecc, index)


def compute_ect_points(
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
"""Computes the Euler Characteristic Transform of a batch of point clouds.
Parameters
Expand All @@ -127,9 +125,7 @@ def compute_ect_points(
return compute_ecc(nh, batch.batch, lin)


def compute_ect_edges(
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
"""Computes the Euler Characteristic Transform of a batch of graphs.
Parameters
Expand Down Expand Up @@ -162,9 +158,7 @@ def compute_ect_edges(
)


def compute_ect_faces(
batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor
):
def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
"""Computes the Euler Characteristic Transform of a batch of meshes.
Parameters
Expand Down Expand Up @@ -225,9 +219,9 @@ def __init__(self, config: ECTConfig, v=None):
super().__init__()
self.config = config
self.lin = nn.Parameter(
torch.linspace(
-config.radius, config.radius, config.bump_steps
).view(-1, 1, 1, 1),
torch.linspace(-config.radius, config.radius, config.bump_steps).view(
-1, 1, 1, 1
),
requires_grad=False,
)

Expand Down
98 changes: 49 additions & 49 deletions dect/wect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import nn
from torch_scatter import segment_add_coo

from ect import ECTConfig, Batch, normalize
from dect.ect import ECTConfig, Batch, normalize


def compute_wecc(
Expand All @@ -17,26 +17,26 @@ def compute_wecc(
):
"""Computes the weighted Euler Characteristic curve.
Parameters
----------
nh : torch.FloatTensor
The node heights, computed as the inner product of the node coordinates
x and the direction vector v.
index: torch.LongTensor
The index that indicates to which pointcloud a node height belongs. For
the node heights it is the same as the batch index, for the higher order
simplices it will have to be recomputed.
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
weight: torch.FloatTensor
The weight of the node, edge or face. It is the maximum of the node
weights for the edges and faces.
scale: torch.FloatTensor
A single number that scales the sigmoid function by multiplying the
sigmoid with the scale. With high (100>) values, the ect will resemble a
discrete ECT and with lower values it will smooth the ECT.
"""
Parameters
----------
nh : torch.FloatTensor
The node heights, computed as the inner product of the node coordinates
x and the direction vector v.
index: torch.LongTensor
The index that indicates to which pointcloud a node height belongs. For
the node heights it is the same as the batch index, for the higher order
simplices it will have to be recomputed.
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
weight: torch.FloatTensor
The weight of the node, edge or face. It is the maximum of the node
weights for the edges and faces.
scale: torch.FloatTensor
A single number that scales the sigmoid function by multiplying the
sigmoid with the scale. With high (100>) values, the ect will resemble a
discrete ECT and with lower values it will smooth the ECT.
"""
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view(
1, -1, 1
)
Expand All @@ -52,19 +52,19 @@ def compute_wect(
):
"""Computes the Weighted Euler Characteristic Transform of a batch of point clouds.
Parameters
----------
batch : Batch
A batch of data containing the node coordinates, batch index, edge_index, face, and
node weights.
v: torch.FloatTensor
The direction vector that contains the directions.
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
wect_type: str
The type of WECT to compute. Can be "points", "edges", or "faces".
"""
Parameters
----------
batch : Batch
A batch of data containing the node coordinates, batch index, edge_index, face, and
node weights.
v: torch.FloatTensor
The direction vector that contains the directions.
lin: torch.FloatTensor
The discretization of the interval [-1,1] each node height falls in this
range due to rescaling in normalizing the data.
wect_type: str
The type of WECT to compute. Can be "points", "edges", or "faces".
"""
nh = batch.x @ v
if wect_type in ["edges", "faces"]:
edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0)
Expand Down Expand Up @@ -132,21 +132,21 @@ def forward(self, batch: Batch):
"""Forward method for the ECT Layer.
Parameters
----------
batch : Batch
A batch of data containing the node coordinates, edges, faces,
batch index, and node_weights. It should follow the pytorch geometric conventions.
Returns
----------
wect: torch.FloatTensor
Returns the WECT of each data object in the batch. If the layer is
initialized with v of the shape [ndims,num_thetas], the returned WECT
has shape [batch,num_thetas,bump_steps]. In case the layer is
initialized with v of the form [n_channels, ndims, num_thetas] the
returned WECT has the shape [batch,n_channels,num_thetas,bump_steps]
"""
Parameters
----------
batch : Batch
A batch of data containing the node coordinates, edges, faces,
batch index, and node_weights. It should follow the pytorch geometric conventions.
Returns
----------
wect: torch.FloatTensor
Returns the WECT of each data object in the batch. If the layer is
initialized with v of the shape [ndims,num_thetas], the returned WECT
has shape [batch,num_thetas,bump_steps]. In case the layer is
initialized with v of the form [n_channels, ndims, num_thetas] the
returned WECT has the shape [batch,n_channels,num_thetas,bump_steps]
"""
# Movedim for geotorch
wect = compute_wect(
batch, self.v.movedim(-1, -2), self.lin, self.config.ect_type
Expand Down

0 comments on commit 08ab939

Please sign in to comment.