From dc01ec8c8587f5bfed3e8fe0f9c12621f1f28bd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iago=20Su=C3=A1rez?= Date: Mon, 12 Feb 2024 08:43:21 +0100 Subject: [PATCH] Fixing #37 allow_shifted=False in MotionBlur (#38) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixing https://github.com/cvg/glue-factory/issues/37 * Black formatting --------- Co-authored-by: Rémi Pautrat <32239569+rpautrat@users.noreply.github.com> Co-authored-by: Rémi Pautrat --- gluefactory/datasets/augmentations.py | 12 +++++++++--- gluefactory/datasets/eth3d.py | 1 + gluefactory/datasets/homographies.py | 2 +- gluefactory/datasets/hpatches.py | 3 ++- gluefactory/geometry/gt_generation.py | 6 +++--- gluefactory/models/cache_loader.py | 17 +++++++++++------ gluefactory/models/extractors/aliked.py | 8 +++++--- .../models/extractors/superpoint_open.py | 1 + gluefactory/models/lines/wireframe.py | 6 +++--- tests/test_integration.py | 8 +++++--- 10 files changed, 41 insertions(+), 23 deletions(-) diff --git a/gluefactory/datasets/augmentations.py b/gluefactory/datasets/augmentations.py index bd391294..7541787c 100644 --- a/gluefactory/datasets/augmentations.py +++ b/gluefactory/datasets/augmentations.py @@ -191,7 +191,13 @@ def _init(self, conf): [ A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")), A.MotionBlur( - **kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur") + **kwi( + blur, + p=0.2, + blur_limit=(3, 25), + allow_shifted=False, + n="motion_blur", + ) ), A.ISONoise(), A.ImageCompression(), @@ -222,14 +228,14 @@ def _init(self, conf): A.OneOf( [ A.Blur(blur_limit=(3, 9)), - A.MotionBlur(blur_limit=(3, 25)), + A.MotionBlur(blur_limit=(3, 25), allow_shifted=False), A.ISONoise(), A.ImageCompression(), ], p=0.1, ), A.Blur(p=0.1, blur_limit=(3, 9)), - A.MotionBlur(p=0.1, blur_limit=(3, 25)), + A.MotionBlur(p=0.1, blur_limit=(3, 25), allow_shifted=False), A.RandomBrightnessContrast( p=0.5, brightness_limit=(-0.4, 0.0), contrast_limit=(-0.3, 0.0) ), diff --git a/gluefactory/datasets/eth3d.py b/gluefactory/datasets/eth3d.py index 44fd73f8..953d775e 100644 --- a/gluefactory/datasets/eth3d.py +++ b/gluefactory/datasets/eth3d.py @@ -1,6 +1,7 @@ """ ETH3D multi-view benchmark, used for line matching evaluation. """ + import logging import os import shutil diff --git a/gluefactory/datasets/homographies.py b/gluefactory/datasets/homographies.py index 08f7563c..9db7f2fb 100644 --- a/gluefactory/datasets/homographies.py +++ b/gluefactory/datasets/homographies.py @@ -293,7 +293,7 @@ def visualize(args): images = [] for _, data in zip(range(args.num_items), loader): images.append( - (data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)) + [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)] ) plot_image_grid(images, dpi=args.dpi) plt.tight_layout() diff --git a/gluefactory/datasets/hpatches.py b/gluefactory/datasets/hpatches.py index baf4ac8e..cf4c7993 100644 --- a/gluefactory/datasets/hpatches.py +++ b/gluefactory/datasets/hpatches.py @@ -1,6 +1,7 @@ """ Simply load images from a folder or nested folders (does not have any split). """ + import argparse import logging import tarfile @@ -127,7 +128,7 @@ def visualize(args): images = [] for _, data in zip(range(args.num_items), loader): images.append( - (data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)) + [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)] ) plot_image_grid(images, dpi=args.dpi) plt.tight_layout() diff --git a/gluefactory/geometry/gt_generation.py b/gluefactory/geometry/gt_generation.py index 21390cd7..b80a7778 100644 --- a/gluefactory/geometry/gt_generation.py +++ b/gluefactory/geometry/gt_generation.py @@ -375,9 +375,9 @@ def gt_line_matches_from_pose_depth( all_in_batch = ( torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten() ) - positive[ - all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten() - ] = True + positive[all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()] = ( + True + ) m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long) m0.scatter_(-1, assignation[:, 0], assignation[:, 1]) diff --git a/gluefactory/models/cache_loader.py b/gluefactory/models/cache_loader.py index b345a997..837421ba 100644 --- a/gluefactory/models/cache_loader.py +++ b/gluefactory/models/cache_loader.py @@ -47,9 +47,11 @@ def pad_line_features(pred, seq_l: int = None): def recursive_load(grp, pkeys): return { - k: torch.from_numpy(grp[k].__array__()) - if isinstance(grp[k], h5py.Dataset) - else recursive_load(grp[k], list(grp.keys())) + k: ( + torch.from_numpy(grp[k].__array__()) + if isinstance(grp[k], h5py.Dataset) + else recursive_load(grp[k], list(grp.keys())) + ) for k in pkeys } @@ -108,9 +110,12 @@ def _forward(self, data): pred = recursive_load(grp, pkeys) if self.numeric_dtype is not None: pred = { - k: v - if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v) - else v.to(dtype=self.numeric_dtype) + k: ( + v + if not isinstance(v, torch.Tensor) + or not torch.is_floating_point(v) + else v.to(dtype=self.numeric_dtype) + ) for k, v in pred.items() } pred = batch_to_device(pred, device) diff --git a/gluefactory/models/extractors/aliked.py b/gluefactory/models/extractors/aliked.py index 80cd348a..254a434e 100644 --- a/gluefactory/models/extractors/aliked.py +++ b/gluefactory/models/extractors/aliked.py @@ -717,9 +717,11 @@ def _init(self, conf): radius=conf.nms_radius, top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints, scores_th=conf.detection_threshold, - n_limit=conf.max_num_keypoints - if conf.max_num_keypoints > 0 - else self.n_limit_max, + n_limit=( + conf.max_num_keypoints + if conf.max_num_keypoints > 0 + else self.n_limit_max + ), ) # load pretrained diff --git a/gluefactory/models/extractors/superpoint_open.py b/gluefactory/models/extractors/superpoint_open.py index 1f960407..434e0a1d 100644 --- a/gluefactory/models/extractors/superpoint_open.py +++ b/gluefactory/models/extractors/superpoint_open.py @@ -5,6 +5,7 @@ The implementation of this model and its trained weights are made available under the MIT license. """ + from collections import OrderedDict from types import SimpleNamespace diff --git a/gluefactory/models/lines/wireframe.py b/gluefactory/models/lines/wireframe.py index ac0d0b5a..8f541c6a 100644 --- a/gluefactory/models/lines/wireframe.py +++ b/gluefactory/models/lines/wireframe.py @@ -256,9 +256,9 @@ def _forward(self, data): associativity = torch.eye( len(all_points[-1]), dtype=torch.bool, device=device ) - associativity[ - : n_true_junctions[bs], : n_true_junctions[bs] - ] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]] + associativity[: n_true_junctions[bs], : n_true_junctions[bs]] = ( + line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]] + ) pl_associativity.append(associativity) all_points = torch.stack(all_points, dim=0) diff --git a/tests/test_integration.py b/tests/test_integration.py index e459ada5..3592cff1 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -35,9 +35,11 @@ def create_input_data(cv_img0, cv_img1, device): data = {"view0": ip(img0), "view1": ip(img1)} data = map_tensor( data, - lambda t: t[None].to(device) - if isinstance(t, Tensor) - else torch.from_numpy(t)[None].to(device), + lambda t: ( + t[None].to(device) + if isinstance(t, Tensor) + else torch.from_numpy(t)[None].to(device) + ), ) return data