diff --git a/dect/ect.py b/dect/ect.py index dda992c..1248a1f 100644 --- a/dect/ect.py +++ b/dect/ect.py @@ -109,7 +109,9 @@ def compute_ecc( return segment_add_coo(ecc, index) -def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): +def compute_ect_points( + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor +): """Computes the Euler Characteristic Transform of a batch of point clouds. Parameters @@ -126,7 +128,9 @@ def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTenso return compute_ecc(nh, batch.batch, lin) -def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): +def compute_ect_edges( + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor +): """Computes the Euler Characteristic Transform of a batch of graphs. Parameters @@ -159,7 +163,9 @@ def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor ) -def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor): +def compute_ect_faces( + batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor +): """Computes the Euler Characteristic Transform of a batch of meshes. Parameters @@ -220,9 +226,9 @@ def __init__(self, config: ECTConfig, v=None): super().__init__() self.config = config self.lin = nn.Parameter( - torch.linspace(-config.radius, config.radius, config.bump_steps).view( - -1, 1, 1, 1 - ), + torch.linspace( + -config.radius, config.radius, config.bump_steps + ).view(-1, 1, 1, 1), requires_grad=False, ) diff --git a/dect/wect.py b/dect/wect.py index b2a188c..8e0eb40 100644 --- a/dect/wect.py +++ b/dect/wect.py @@ -77,14 +77,18 @@ def compute_wect( return compute_wecc(nh, batch.batch, lin, batch.node_weights) if wect_type == "edges": # noinspection PyUnboundLocalVariable - return compute_wecc(nh, batch.batch, lin, batch.node_weights) - compute_wecc( + return compute_wecc( + nh, batch.batch, lin, batch.node_weights + ) - compute_wecc( eh, batch.batch[batch.edge_index[0]], lin, edge_weights ) if wect_type == "faces": # noinspection PyUnboundLocalVariable return ( compute_wecc(nh, batch.batch, lin, batch.node_weights) - - compute_wecc(eh, batch.batch[batch.edge_index[0]], lin, edge_weights) + - compute_wecc( + eh, batch.batch[batch.edge_index[0]], lin, edge_weights + ) + compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights) ) raise ValueError(f"Invalid wect_type: {wect_type}") @@ -106,9 +110,9 @@ def __init__(self, config: ECTConfig, v=None): super().__init__() self.config = config self.lin = nn.Parameter( - torch.linspace(-config.radius, config.radius, config.bump_steps).view( - -1, 1, 1, 1 - ), + torch.linspace( + -config.radius, config.radius, config.bump_steps + ).view(-1, 1, 1, 1), requires_grad=False, )