diff --git a/hloc/extract_features.py b/hloc/extract_features.py index e3da01ef..70bb5cb6 100644 --- a/hloc/extract_features.py +++ b/hloc/extract_features.py @@ -170,9 +170,10 @@ class ImageDataset(torch.utils.data.Dataset): 'interpolation': 'cv2_area', # pil_linear is more accurate but slower } - def __init__(self, root, conf, paths=None): + def __init__(self, root, conf, paths=None, mask_dir: Optional[Path]=None): self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) self.root = root + self.mask_dir = mask_dir if paths is None: paths = [] @@ -220,6 +221,10 @@ def __getitem__(self, idx): 'image': image, 'original_size': np.array(size), } + if self.mask_dir: + mask_path = self.mask_dir / f'{name}.png' + if mask_path.exists(): + data['mask'] = read_image(mask_path, True) return data def __len__(self): @@ -233,11 +238,12 @@ def main(conf: Dict, as_half: bool = True, image_list: Optional[Union[Path, List[str]]] = None, feature_path: Optional[Path] = None, - overwrite: bool = False) -> Path: + overwrite: bool = False, + mask_dir: Optional[Path] = None) -> Path: logger.info('Extracting local features with configuration:' f'\n{pprint.pformat(conf)}') - dataset = ImageDataset(image_dir, conf['preprocessing'], image_list) + dataset = ImageDataset(image_dir, conf['preprocessing'], image_list, mask_dir) if feature_path is None: feature_path = Path(export_dir, conf['output']+'.h5') feature_path.parent.mkdir(exist_ok=True, parents=True) @@ -268,6 +274,12 @@ def main(conf: Dict, pred['scales'] *= scales.mean() # add keypoint uncertainties scaled to the original resolution uncertainty = getattr(model, 'detection_noise', 1) * scales.mean() + if 'mask' in data: + mask = data['mask'][0] # cuz `batch_size == 1` + valid_keypoint = mask[pred['keypoints'][:, 1].astype('int'), pred['keypoints'][:, 0].astype('int')] + pred['keypoints'] = pred['keypoints'][valid_keypoint > 0] + pred['descriptors'] = pred['descriptors'][:, valid_keypoint > 0] + pred['scores'] = pred['scores'][valid_keypoint > 0] if as_half: for k in pred: @@ -307,5 +319,6 @@ def main(conf: Dict, parser.add_argument('--as_half', action='store_true') parser.add_argument('--image_list', type=Path) parser.add_argument('--feature_path', type=Path) + parser.add_argument('--mask_dir', type=Path) args = parser.parse_args() - main(confs[args.conf], args.image_dir, args.export_dir, args.as_half) + main(confs[args.conf], args.image_dir, args.export_dir, args.as_half, mask_dir=args.mask_dir) diff --git a/hloc/extractors/disk.py b/hloc/extractors/disk.py index dc04280c..7a10c88a 100644 --- a/hloc/extractors/disk.py +++ b/hloc/extractors/disk.py @@ -27,6 +27,6 @@ def _forward(self, data): ) return { 'keypoints': [f.keypoints for f in features], - 'keypoint_scores': [f.detection_scores for f in features], + 'scores': [f.detection_scores for f in features], 'descriptors': [f.descriptors.t() for f in features], }