From efd2063f6c058b86afeaf18076aea592b6d32f64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Ballester?= Date: Wed, 17 Jul 2024 18:08:37 +0200 Subject: [PATCH 1/6] Updating requirements, some typos of ect, and created the weighted version of ect in wect.py --- dect/ect.py | 40 +++++------- dect/wect.py | 156 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- 3 files changed, 174 insertions(+), 25 deletions(-) create mode 100644 dect/wect.py diff --git a/dect/ect.py b/dect/ect.py index 0fabf76..363506d 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -61,12 +61,16 @@ class Batch: The indices of the points that span a face in the simplicial complex. Conforms to pytorch_geometric standards. Shape has to be of the form [3,num_faces]. + node_weights: torch.FloatTensor + The weights of the nodes in the complex. The shape has to be + [num_nodes,]. """ x: torch.FloatTensor batch: torch.LongTensor edge_index: torch.LongTensor | None = None face: torch.LongTensor | None = None + node_weights: torch.FloatTensor | None = None def compute_ecc( @@ -89,9 +93,6 @@ def compute_ecc( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] scale: torch.FloatTensor A single number that scales the sigmoid function by multiplying the sigmoid with the scale. With high (100>) values, the ect will resemble a @@ -121,16 +122,13 @@ def compute_ect_points( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] """ nh = batch.x @ v return compute_ecc(nh, batch.batch, lin) def compute_ect_edges( - data: Batch, v: torch.FloatTensor, lin: torch.FloatTensor + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor ): """Computes the Euler Characteristic Transform of a batch of graphs. @@ -144,23 +142,20 @@ def compute_ect_edges( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] """ # Compute the node heigths - nh = data.x @ v + nh = batch.x @ v # Perform a lookup with the edge indices on node heights, this replaces the # node index with its node height and then compute the maximum over the # columns to compute the edge height. - eh, _ = nh[data.edge_index].max(dim=0) + eh, _ = nh[batch.edge_index].max(dim=0) # Compute which batch an edge belongs to. We take the first index of the # edge (or faces) and do a lookup on the batch index of that node in the # batch indices of the nodes. - batch_index_nodes = data.batch - batch_index_edges = data.batch[data.edge_index[0]] + batch_index_nodes = batch.batch + batch_index_edges = batch.batch[batch.edge_index[0]] return compute_ecc(nh, batch_index_nodes, lin) - compute_ecc( eh, batch_index_edges, lin @@ -168,7 +163,7 @@ def compute_ect_edges( def compute_ect_faces( - data: Batch, v: torch.FloatTensor, lin: torch.FloatTensor + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor ): """Computes the Euler Characteristic Transform of a batch of meshes. @@ -182,27 +177,24 @@ def compute_ect_faces( lin: torch.FloatTensor The discretization of the interval [-1,1] each node height falls in this range due to rescaling in normalizing the data. - out: torch.FloatTensor - The shape of the resulting tensor after summation. It has to be of the - shape [num_discretization_steps, batch_size, num_thetas] """ # Compute the node heigths - nh = data.x @ v + nh = batch.x @ v # Perform a lookup with the edge indices on node heights, this replaces the # node index with its node height and then compute the maximum over the # columns to compute the edge height. - eh, _ = nh[data.edge_index].max(dim=0) + eh, _ = nh[batch.edge_index].max(dim=0) # Do the same thing for the faces. - fh, _ = nh[data.face].max(dim=0) + fh, _ = nh[batch.face].max(dim=0) # Compute which batch an edge belongs to. We take the first index of the # edge (or faces) and do a lookup on the batch index of that node in the # batch indices of the nodes. - batch_index_nodes = data.batch - batch_index_edges = data.batch[data.edge_index[0]] - batch_index_faces = data.batch[data.face[0]] + batch_index_nodes = batch.batch + batch_index_edges = batch.batch[batch.edge_index[0]] + batch_index_faces = batch.batch[batch.face[0]] return ( compute_ecc(nh, batch_index_nodes, lin) diff --git a/dect/wect.py b/dect/wect.py new file mode 100644 index 0000000..7498e12 --- /dev/null +++ b/dect/wect.py @@ -0,0 +1,156 @@ +from typing import Literal + +import geotorch +import torch +from torch import nn +from torch_scatter import segment_add_coo + +from ect import ECTConfig, Batch, normalize + + +def compute_wecc( + nh: torch.FloatTensor, + index: torch.LongTensor, + lin: torch.FloatTensor, + weight: torch.FloatTensor, + scale: float = 500, +): + """Computes the weighted Euler Characteristic curve. + + Parameters + ---------- + nh : torch.FloatTensor + The node heights, computed as the inner product of the node coordinates + x and the direction vector v. + index: torch.LongTensor + The index that indicates to which pointcloud a node height belongs. For + the node heights it is the same as the batch index, for the higher order + simplices it will have to be recomputed. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + weight: torch.FloatTensor + The weight of the node, edge or face. It is the maximum of the node + weights for the edges and faces. + scale: torch.FloatTensor + A single number that scales the sigmoid function by multiplying the + sigmoid with the scale. With high (100>) values, the ect will resemble a + discrete ECT and with lower values it will smooth the ECT. + """ + ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view( + 1, -1, 1 + ) + ecc = ecc.movedim(0, 2).movedim(0, 1) + return segment_add_coo(ecc, index) + + +def compute_wect( + batch: Batch, + v: torch.FloatTensor, + lin: torch.FloatTensor, + wect_type: Literal["points"] | Literal["edges"] | Literal["faces"], +): + """Computes the Weighted Euler Characteristic Transform of a batch of point clouds. + + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, batch index, edge_index, face, and + node weights. + v: torch.FloatTensor + The direction vector that contains the directions. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + wect_type: str + The type of WECT to compute. Can be "points", "edges", or "faces". + """ + nh = batch.x @ v + if wect_type in ["edges", "faces"]: + edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0) + eh, _ = nh[batch.edge_index].min(dim=0) + if wect_type == "faces": + face_weights, _ = batch.node_weights[batch.face].max(axis=0) + fh, _ = nh[batch.face].min(dim=0) + + if wect_type == "points": + return compute_wecc(nh, batch.batch, lin, batch.node_weights) + if wect_type == "edges": + # noinspection PyUnboundLocalVariable + return compute_wecc(nh, batch.batch, lin, batch.node_weights) - compute_wecc( + eh, batch.batch[batch.edge_index[0]], lin, edge_weights + ) + if wect_type == "faces": + # noinspection PyUnboundLocalVariable + return ( + compute_wecc(nh, batch.batch, lin, batch.node_weights) + - compute_wecc(eh, batch.batch[batch.edge_index[0]], lin, edge_weights) + + compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights) + ) + raise ValueError(f"Invalid wect_type: {wect_type}") + + +class WECTLayer(nn.Module): + """Machine learning layer for computing the WECT (Weighted ECT). + + Parameters + ---------- + v: torch.FloatTensor + The direction vector that contains the directions. The shape of the + tensor v is either [ndims, num_thetas] or [n_channels, ndims, num_thetas]. + config : ECTConfig + The configuration config of the WECT layer. + """ + + def __init__(self, config: ECTConfig, v=None): + super().__init__() + self.config = config + self.lin = nn.Parameter( + torch.linspace(-config.radius, config.radius, config.bump_steps).view( + -1, 1, 1, 1 + ), + requires_grad=False, + ) + + # If provided with one set of directions. + # For backwards compatibility. + if v.ndim == 2: + v.unsqueeze(0) + + # The set of directions is added + if config.fixed: + self.v = nn.Parameter(v.movedim(-1, -2), requires_grad=False) + else: + self.v = nn.Parameter(torch.zeros_like(v.movedim(-1, -2))) + geotorch.constraints.sphere(self, "v", radius=config.radius) + # Since geotorch randomizes the vector during initialization, we + # assign the values after registering it with spherical constraints. + # See Geotorch documentation for examples. + self.v = v.movedim(-1, -2) + + def forward(self, batch: Batch): + """Forward method for the ECT Layer. + + + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, edges, faces, + batch index, and node_weights. It should follow the pytorch geometric conventions. + + Returns + ---------- + wect: torch.FloatTensor + Returns the WECT of each data object in the batch. If the layer is + initialized with v of the shape [ndims,num_thetas], the returned WECT + has shape [batch,num_thetas,bump_steps]. In case the layer is + initialized with v of the form [n_channels, ndims, num_thetas] the + returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] + """ + # Movedim for geotorch + wect = compute_wect( + batch, self.v.movedim(-1, -2), self.lin, self.config.ect_type + ) + if self.config.normalized: + return normalize(wect) + return wect.squeeze() diff --git a/requirements.txt b/requirements.txt index a9e681e..0da2fc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ torch -torch-scatter \ No newline at end of file +torch-scatter +geotorch \ No newline at end of file From 08ab93944297cb41df60754932df29642ef8e3f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Ballester?= Date: Wed, 17 Jul 2024 19:28:05 +0200 Subject: [PATCH 2/6] Adding black + some minor changes in wect and the README.md --- README.md | 2 +- dect/ect.py | 18 ++++------ dect/wect.py | 98 ++++++++++++++++++++++++++-------------------------- 3 files changed, 56 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index 6f2bdec..8c9ac64 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ from dect.directions import generate_uniform_2d_directions v = generate_uniform_2d_directions(num_thetas=64) -layer = ECTLayer(ECTConfig(), V=v) +layer = ECTLayer(ECTConfig(), v=v) points_coordinates = torch.tensor( [[0.5, 0.0], [-0.5, 0.0], [0.5, 0.5]], requires_grad=True diff --git a/dect/ect.py b/dect/ect.py index 363506d..81edeb2 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -108,9 +108,7 @@ def compute_ecc( return segment_add_coo(ecc, index) -def compute_ect_points( - batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor -): +def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): """Computes the Euler Characteristic Transform of a batch of point clouds. Parameters @@ -127,9 +125,7 @@ def compute_ect_points( return compute_ecc(nh, batch.batch, lin) -def compute_ect_edges( - batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor -): +def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): """Computes the Euler Characteristic Transform of a batch of graphs. Parameters @@ -162,9 +158,7 @@ def compute_ect_edges( ) -def compute_ect_faces( - batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor -): +def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): """Computes the Euler Characteristic Transform of a batch of meshes. Parameters @@ -225,9 +219,9 @@ def __init__(self, config: ECTConfig, v=None): super().__init__() self.config = config self.lin = nn.Parameter( - torch.linspace( - -config.radius, config.radius, config.bump_steps - ).view(-1, 1, 1, 1), + torch.linspace(-config.radius, config.radius, config.bump_steps).view( + -1, 1, 1, 1 + ), requires_grad=False, ) diff --git a/dect/wect.py b/dect/wect.py index 7498e12..b2a188c 100644 --- a/dect/wect.py +++ b/dect/wect.py @@ -5,7 +5,7 @@ from torch import nn from torch_scatter import segment_add_coo -from ect import ECTConfig, Batch, normalize +from dect.ect import ECTConfig, Batch, normalize def compute_wecc( @@ -17,26 +17,26 @@ def compute_wecc( ): """Computes the weighted Euler Characteristic curve. - Parameters - ---------- - nh : torch.FloatTensor - The node heights, computed as the inner product of the node coordinates - x and the direction vector v. - index: torch.LongTensor - The index that indicates to which pointcloud a node height belongs. For - the node heights it is the same as the batch index, for the higher order - simplices it will have to be recomputed. - lin: torch.FloatTensor - The discretization of the interval [-1,1] each node height falls in this - range due to rescaling in normalizing the data. - weight: torch.FloatTensor - The weight of the node, edge or face. It is the maximum of the node - weights for the edges and faces. - scale: torch.FloatTensor - A single number that scales the sigmoid function by multiplying the - sigmoid with the scale. With high (100>) values, the ect will resemble a - discrete ECT and with lower values it will smooth the ECT. - """ + Parameters + ---------- + nh : torch.FloatTensor + The node heights, computed as the inner product of the node coordinates + x and the direction vector v. + index: torch.LongTensor + The index that indicates to which pointcloud a node height belongs. For + the node heights it is the same as the batch index, for the higher order + simplices it will have to be recomputed. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + weight: torch.FloatTensor + The weight of the node, edge or face. It is the maximum of the node + weights for the edges and faces. + scale: torch.FloatTensor + A single number that scales the sigmoid function by multiplying the + sigmoid with the scale. With high (100>) values, the ect will resemble a + discrete ECT and with lower values it will smooth the ECT. + """ ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view( 1, -1, 1 ) @@ -52,19 +52,19 @@ def compute_wect( ): """Computes the Weighted Euler Characteristic Transform of a batch of point clouds. - Parameters - ---------- - batch : Batch - A batch of data containing the node coordinates, batch index, edge_index, face, and - node weights. - v: torch.FloatTensor - The direction vector that contains the directions. - lin: torch.FloatTensor - The discretization of the interval [-1,1] each node height falls in this - range due to rescaling in normalizing the data. - wect_type: str - The type of WECT to compute. Can be "points", "edges", or "faces". - """ + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, batch index, edge_index, face, and + node weights. + v: torch.FloatTensor + The direction vector that contains the directions. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + wect_type: str + The type of WECT to compute. Can be "points", "edges", or "faces". + """ nh = batch.x @ v if wect_type in ["edges", "faces"]: edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0) @@ -132,21 +132,21 @@ def forward(self, batch: Batch): """Forward method for the ECT Layer. - Parameters - ---------- - batch : Batch - A batch of data containing the node coordinates, edges, faces, - batch index, and node_weights. It should follow the pytorch geometric conventions. - - Returns - ---------- - wect: torch.FloatTensor - Returns the WECT of each data object in the batch. If the layer is - initialized with v of the shape [ndims,num_thetas], the returned WECT - has shape [batch,num_thetas,bump_steps]. In case the layer is - initialized with v of the form [n_channels, ndims, num_thetas] the - returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] - """ + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, edges, faces, + batch index, and node_weights. It should follow the pytorch geometric conventions. + + Returns + ---------- + wect: torch.FloatTensor + Returns the WECT of each data object in the batch. If the layer is + initialized with v of the shape [ndims,num_thetas], the returned WECT + has shape [batch,num_thetas,bump_steps]. In case the layer is + initialized with v of the form [n_channels, ndims, num_thetas] the + returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] + """ # Movedim for geotorch wect = compute_wect( batch, self.v.movedim(-1, -2), self.lin, self.config.ect_type From d42812e11ea8435e79346b16604d59acc65a48aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Ballester?= Date: Thu, 18 Jul 2024 07:13:40 +0200 Subject: [PATCH 3/6] Some minor documentation changes. --- dect/ect.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dect/ect.py b/dect/ect.py index 81edeb2..dda992c 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -60,9 +60,10 @@ class Batch: face: The indices of the points that span a face in the simplicial complex. Conforms to pytorch_geometric standards. Shape has to be of the form - [3,num_faces]. + [3,num_faces] or [4, num_faces], depending on the type of complex + (simplicial or cubical). node_weights: torch.FloatTensor - The weights of the nodes in the complex. The shape has to be + Optional weights for the nodes in the complex. The shape has to be [num_nodes,]. """ From c8107e9c4b67f6921b1a969731fbc20cfdf92928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rub=C3=A9n=20Ballester?= Date: Thu, 18 Jul 2024 07:30:31 +0200 Subject: [PATCH 4/6] Black with line length constraint applied. --- dect/ect.py | 18 ++++++++++++------ dect/wect.py | 14 +++++++++----- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/dect/ect.py b/dect/ect.py index dda992c..1248a1f 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -109,7 +109,9 @@ def compute_ecc( return segment_add_coo(ecc, index) -def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): +def compute_ect_points( + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor +): """Computes the Euler Characteristic Transform of a batch of point clouds. Parameters @@ -126,7 +128,9 @@ def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTenso return compute_ecc(nh, batch.batch, lin) -def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): +def compute_ect_edges( + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor +): """Computes the Euler Characteristic Transform of a batch of graphs. Parameters @@ -159,7 +163,9 @@ def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor ) -def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): +def compute_ect_faces( + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor +): """Computes the Euler Characteristic Transform of a batch of meshes. Parameters @@ -220,9 +226,9 @@ def __init__(self, config: ECTConfig, v=None): super().__init__() self.config = config self.lin = nn.Parameter( - torch.linspace(-config.radius, config.radius, config.bump_steps).view( - -1, 1, 1, 1 - ), + torch.linspace( + -config.radius, config.radius, config.bump_steps + ).view(-1, 1, 1, 1), requires_grad=False, ) diff --git a/dect/wect.py b/dect/wect.py index b2a188c..8e0eb40 100644 --- a/dect/wect.py +++ b/dect/wect.py @@ -77,14 +77,18 @@ def compute_wect( return compute_wecc(nh, batch.batch, lin, batch.node_weights) if wect_type == "edges": # noinspection PyUnboundLocalVariable - return compute_wecc(nh, batch.batch, lin, batch.node_weights) - compute_wecc( + return compute_wecc( + nh, batch.batch, lin, batch.node_weights + ) - compute_wecc( eh, batch.batch[batch.edge_index[0]], lin, edge_weights ) if wect_type == "faces": # noinspection PyUnboundLocalVariable return ( compute_wecc(nh, batch.batch, lin, batch.node_weights) - - compute_wecc(eh, batch.batch[batch.edge_index[0]], lin, edge_weights) + - compute_wecc( + eh, batch.batch[batch.edge_index[0]], lin, edge_weights + ) + compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights) ) raise ValueError(f"Invalid wect_type: {wect_type}") @@ -106,9 +110,9 @@ def __init__(self, config: ECTConfig, v=None): super().__init__() self.config = config self.lin = nn.Parameter( - torch.linspace(-config.radius, config.radius, config.bump_steps).view( - -1, 1, 1, 1 - ), + torch.linspace( + -config.radius, config.radius, config.bump_steps + ).view(-1, 1, 1, 1), requires_grad=False, ) From c110676b3a2cbb043a7a2ed144aa9ec26f023712 Mon Sep 17 00:00:00 2001 From: ErnstRoell Date: Fri, 19 Jul 2024 10:29:39 +0200 Subject: [PATCH 5/6] Reformatted the comments to line length 80. --- dect/wect.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/dect/wect.py b/dect/wect.py index 8e0eb40..67dd526 100644 --- a/dect/wect.py +++ b/dect/wect.py @@ -15,7 +15,7 @@ def compute_wecc( weight: torch.FloatTensor, scale: float = 500, ): - """Computes the weighted Euler Characteristic curve. + """Computes the Weighted Euler Characteristic Curve. Parameters ---------- @@ -50,13 +50,15 @@ def compute_wect( lin: torch.FloatTensor, wect_type: Literal["points"] | Literal["edges"] | Literal["faces"], ): - """Computes the Weighted Euler Characteristic Transform of a batch of point clouds. + """ + Computes the Weighted Euler Characteristic Transform of a batch of point + clouds. Parameters ---------- batch : Batch - A batch of data containing the node coordinates, batch index, edge_index, face, and - node weights. + A batch of data containing the node coordinates, batch index, + edge_index, face, and node weights. v: torch.FloatTensor The direction vector that contains the directions. lin: torch.FloatTensor @@ -101,9 +103,10 @@ class WECTLayer(nn.Module): ---------- v: torch.FloatTensor The direction vector that contains the directions. The shape of the - tensor v is either [ndims, num_thetas] or [n_channels, ndims, num_thetas]. + tensor v is either [ndims, num_thetas] or [n_channels, ndims, + num_thetas]. config : ECTConfig - The configuration config of the WECT layer. + The configuration object of the WECT layer. """ def __init__(self, config: ECTConfig, v=None): @@ -133,21 +136,22 @@ def __init__(self, config: ECTConfig, v=None): self.v = v.movedim(-1, -2) def forward(self, batch: Batch): - """Forward method for the ECT Layer. + """Forward method for the WECT Layer. Parameters ---------- batch : Batch - A batch of data containing the node coordinates, edges, faces, - batch index, and node_weights. It should follow the pytorch geometric conventions. + A batch of data containing the node coordinates, edges, faces, batch + index, and node_weights. It should follow the pytorch geometric + conventions. Returns ---------- wect: torch.FloatTensor Returns the WECT of each data object in the batch. If the layer is - initialized with v of the shape [ndims,num_thetas], the returned WECT - has shape [batch,num_thetas,bump_steps]. In case the layer is + initialized with v of the shape [ndims,num_thetas], the returned + WECT has shape [batch,num_thetas,bump_steps]. In case the layer is initialized with v of the form [n_channels, ndims, num_thetas] the returned WECT has the shape [batch,n_channels,num_thetas,bump_steps] """ From a0f7638a2ab81ab4601667bdee69228005eefcd8 Mon Sep 17 00:00:00 2001 From: ErnstRoell Date: Fri, 19 Jul 2024 10:31:48 +0200 Subject: [PATCH 6/6] More reformatting in dect.ect. --- dect/ect.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dect/ect.py b/dect/ect.py index 1248a1f..2f14f9d 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -80,7 +80,7 @@ def compute_ecc( lin: torch.FloatTensor, scale: float = 100, ) -> torch.FloatTensor: - """Computes the Euler Characteristic curve. + """Computes the Euler Characteristic Curve. Parameters ---------- @@ -216,7 +216,8 @@ class ECTLayer(nn.Module): ---------- v: torch.FloatTensor The direction vector that contains the directions. The shape of the - tensor v is either [ndims, num_thetas] or [n_channels, ndims, num_thetas]. + tensor v is either [ndims, num_thetas] or [n_channels, ndims, + num_thetas]. config: ECTConfig The configuration config of the ECT layer.