Skip to content

Commit

Permalink
Updated documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Jul 4, 2024
1 parent 695e040 commit 046b26a
Showing 1 changed file with 66 additions and 3 deletions.
69 changes: 66 additions & 3 deletions dect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@
class ECTConfig:
"""
Configuration of the ECT Layer.
Parameters
----------
bump_steps : int
The number of steps to discretize the ECT into.
radius : float
The radius of the circle the directions lie on. Usually this is a bit
larger than the objects we wish to compute the ECT for, which in most
cases have radius 1. For now it defaults to 1 as well.
ect_type : str
The type of ECT we wish to compute. Can be "points" for point clouds,
"edges" for graphs or "faces" for meshes.
normalized: bool
Whether or not to normalize the ECT. Only work with ect_type set to
points and normalized the ECT to the interval [0,1].
fixed: bool
Option to keep the directions fixed or not. In case the directions are
learnable, we can use backpropagation to optimize over a set of
directions. See notebooks for examples.
"""

bump_steps: int = 32
Expand Down Expand Up @@ -192,7 +211,35 @@ def normalize(ect):


class ECTLayer(nn.Module):
"""Machine learning layer for computing the ECT."""
"""Machine learning layer for computing the ECT.
Parameters
----------
v: torch.FloatTensor
The direction vector that contains the directions. The shape of the
tensor v is either [ndims, num_thetas] or [n_channels, ndims, num_thetas].
config: ECTConfig
The configuration config of the ECT layer.
Config Parameters
----------
bump_steps : int
The number of steps to discretize the ECT into.
radius : float
The radius of the circle the directions lie on. Usually this is a bit
larger than the objects we wish to compute the ECT for, which in most
cases have radius 1. For now it defaults to 1 as well.
ect_type : str
The type of ECT we wish to compute. Can be "points" for point clouds,
"edges" for graphs or "faces" for meshes.
normalized: bool
Whether or not to normalize the ECT. Only work with ect_type set to
points and normalized the ECT to the interval [0,1].
fixed: bool
Option to keep the directions fixed or not. In case the directions are
learnable, we can use backpropagation to optimize over a set of
directions. See notebooks for examples.
"""

def __init__(self, config: ECTConfig, v=None):
super().__init__()
Expand All @@ -210,7 +257,6 @@ def __init__(self, config: ECTConfig, v=None):
v.unsqueeze(0)

# The set of directions is added
# TODO: Requires testing.
if config.fixed:
self.v = nn.Parameter(v.movedim(-1, -2), requires_grad=False)
else:
Expand All @@ -230,7 +276,24 @@ def __init__(self, config: ECTConfig, v=None):
self.compute_ect = compute_ect_faces

def forward(self, batch: Batch):
"""Forward method for the ECT Layer."""
"""Forward method for the ECT Layer.
Parameters
----------
batch : Batch
A batch of data containing the node coordinates, edges, faces and
batch index. It should follow the pytorch geometric conventions.
Returns
----------
ect: torch.FloatTensor
Returns the ECT of each data object in the batch. If the layer is
initialized with v of the shape [ndims,num_thetas], the returned ECT
has shape [batch,num_thetas,bump_steps]. In case the layer is
initialized with v of the form [n_channels, ndims, num_thetas] the
returned ECT has the shape [batch,n_channels,num_thetas,bump_steps]
"""
# Movedim for geotorch.
ect = self.compute_ect(batch, self.v.movedim(-1, -2), self.lin)
if self.config.normalized:
Expand Down

0 comments on commit 046b26a

Please sign in to comment.