Merge pull request #78 from cvg/dev - v1.1
- **[BREAKING]** improved structure of the SfM folders (triangulation and reconstruction), see [#76](https://github.com/cvg/Hierarchical-Localization/pull/76)
- Support for image retrieval (NetVLAD, DIR) and more local features (SIFT, R2D2)
- Support for more datasets: Aachen v1.1, Extended CMU Seasons, RobotCar Seasons, 4Seasons, Cambridge Landmarks, 7-Scenes
- Simplified pipeline and API
- Spatial matcher
- Support for arbitrary paths of features and matches
- Support for matching multiple feature files together
skydes authored Jul 17, 2021
2 parents e64814c + bd268b4 commit 91f40bf
Showing 63 changed files with 207,970 additions and 19,332 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,2 +1,5 @@
__pycache__
*.pyc
*.egg-info
.ipynb_checkpoints
outputs/
6 changes: 6 additions & 0 deletions .gitmodules
@@ -5,3 +5,9 @@
	path = third_party/SuperGluePretrainedNetwork
	url = https://github.com/skydes/SuperGluePretrainedNetwork.git
	branch = fix-memory
[submodule "third_party/deep-image-retrieval"]
	path = third_party/deep-image-retrieval
	url = https://github.com/naver/deep-image-retrieval.git
[submodule "third_party/r2d2"]
	path = third_party/r2d2
	url = https://github.com/naver/r2d2.git
47 changes: 38 additions & 9 deletions README.md
@@ -1,6 +1,6 @@
# hloc - the hierarchical localization toolbox

This is `hloc`, a modular toolbox for state-of-the-art 6-DoF visual localization. It implements [Hierarchical Localization](https://arxiv.org/abs/1812.03506), leveraging image retrieval and feature matching, and is fast, accurate, and scalable. This codebase won the indoor/outdoor [localization challenge at CVPR 2020](https://sites.google.com/view/vislocslamcvpr2020/home), in combination with [SuperGlue](https://psarlin.com/superglue/), our graph neural network for feature matching.
This is `hloc`, a modular toolbox for state-of-the-art 6-DoF visual localization. It implements [Hierarchical Localization](https://arxiv.org/abs/1812.03506), leveraging image retrieval and feature matching, and is fast, accurate, and scalable. This codebase won the indoor/outdoor localization challenges at [CVPR 2020](https://sites.google.com/view/vislocslamcvpr2020/home) and [ECCV 2020](https://sites.google.com/view/ltvl2020/), in combination with [SuperGlue](https://psarlin.com/superglue/), our graph neural network for feature matching.

With `hloc`, you can:

@@ -34,8 +34,6 @@ docker run -it --rm -p 8888:8888 hloc:latest # for GPU support, add `--runtime=nvidia`
jupyter notebook --ip 0.0.0.0 --port 8888 --no-browser --allow-root
```



## General pipeline

The toolbox is composed of scripts, which roughly perform the following steps:
@@ -57,6 +55,7 @@ Structure of the toolbox:
- `hloc/*.py` : top-level scripts
- `hloc/extractors/` : interfaces for feature extractors
- `hloc/matchers/` : interfaces for feature matchers
- `hloc/pipelines/` : entire pipelines for multiple datasets
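
The scripts can also be called from Python. Below is a minimal sketch of how they chain together; the `extract_features.main` signature is taken from this diff, while the `pairs_from_retrieval.main` and `match_features.main` calls are assumptions about the v1.1 API, and all paths are hypothetical:

```python
from pathlib import Path
from hloc import extract_features, match_features, pairs_from_retrieval

images = Path('datasets/aachen/images/')
outputs = Path('outputs/aachen/')

# Global descriptors for retrieval, then image pairs derived from them.
retrieval_path = extract_features.main(
    extract_features.confs['netvlad'], images, outputs)
pairs = outputs / 'pairs-netvlad.txt'
pairs_from_retrieval.main(retrieval_path, pairs, num_matched=20)

# Local features, matched across the retrieved pairs.
feature_path = extract_features.main(
    extract_features.confs['superpoint_aachen'], images, outputs)
match_path = match_features.main(
    match_features.confs['superglue'], pairs, feature_path.stem, outputs)
```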

## Tasks

@@ -84,7 +83,11 @@ We show in [`pipeline_SfM.ipynb`](https://nbviewer.jupyter.org/github/cvg/Hierar

## Results

`hloc` currently supports [SuperPoint](https://arxiv.org/abs/1712.07629) and [D2-Net](https://arxiv.org/abs/1905.03561) local feature extractors; and [SuperGlue](https://arxiv.org/abs/1911.11763) and Nearest Neighbor matchers. Using [NetVLAD](https://arxiv.org/abs/1511.07247) for retrieval, we obtain the following best results:
- Supported local feature extractors: [SuperPoint](https://arxiv.org/abs/1712.07629), [D2-Net](https://arxiv.org/abs/1905.03561), [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf), and [R2D2](https://arxiv.org/abs/1906.06195).
- Supported feature matchers: [SuperGlue](https://arxiv.org/abs/1911.11763) and the Nearest Neighbor matcher with ratio test and/or mutual check.
- Supported image retrieval: [NetVLAD](https://arxiv.org/abs/1511.07247) and [AP-GeM/DIR](https://github.com/naver/deep-image-retrieval).

Using [NetVLAD](https://arxiv.org/abs/1511.07247) for retrieval, we obtain the following best results:

| Methods | Aachen day | Aachen night | Retrieval |
| ------------------------------------------------------------ | ------------------ | ------------------ | -------------- |
@@ -101,6 +104,10 @@

Check out [visuallocalization.net/benchmark](https://www.visuallocalization.net/benchmark) for more details and additional baselines.

## Supported datasets

We provide in [`hloc/pipelines/`](./hloc/pipelines) scripts to run the reconstruction and the localization on the following datasets: Aachen Day-Night (v1.0 and v1.1), InLoc, Extended CMU Seasons, RobotCar Seasons, 4Seasons, Cambridge Landmarks, and 7-Scenes.
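
Each pipeline is a regular script. A hedged sketch of a programmatic run follows; the module path `hloc.pipelines.Aachen.pipeline` is an assumption based on the new `hloc/pipelines/` folder, so check the corresponding `pipeline.py` for the actual layout and flags:

```python
import runpy

# Equivalent to `python -m hloc.pipelines.Aachen.pipeline` from the repo root;
# the module path is an assumption, see hloc/pipelines/ for the actual layout.
runpy.run_module('hloc.pipelines.Aachen.pipeline', run_name='__main__')
```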

## BibTeX Citation

If you report any of the above results in a publication, or use any of the tools provided here, please consider citing both the [Hierarchical Localization](https://arxiv.org/abs/1812.03506) and [SuperGlue](https://arxiv.org/abs/1911.11763) papers:
@@ -162,17 +169,39 @@ In a match file, each key corresponds to the string `path0.replace('/', '-')+'_'+path1.replace('/', '-')`
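
For instance, the matches of one pair could be read back as follows (a hedged sketch: the file name and image names are hypothetical, and the `matches0` dataset name follows the SuperGlue output convention):

```python
import h5py

name0, name1 = 'db/0001.jpg', 'query/0002.jpg'  # hypothetical pair
pair_key = name0.replace('/', '-') + '_' + name1.replace('/', '-')

with h5py.File('outputs/matches.h5', 'r') as fd:
    # matches0[i] is the index of the keypoint in name1 matched with
    # keypoint i of name0, or -1 if it is unmatched.
    matches0 = fd[pair_key]['matches0'][()]
```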
<details>
<summary>[Click to expand]</summary>

For now `hloc` does not have an interface for image retrieval. You will need to export the global descriptors into an HDF5 file, in which each key corresponds to the relative path of an image w.r.t. the dataset root, and contains a dataset `global_descriptor` with size D. You can then export the image pairs with [`hloc/pairs_from_retrieval.py`](hloc/pairs_from_retrieval.py).
`hloc` also provides an interface for image retrieval via `hloc/extract_features.py`. As with local features, simply add a new interface to [`hloc/extractors/`](hloc/extractors/). Alternatively, export the global descriptors into an HDF5 file, in which each key corresponds to the relative path of an image w.r.t. the dataset root and contains a dataset `global_descriptor` with size D. You can then export the image pairs with [`hloc/pairs_from_retrieval.py`](hloc/pairs_from_retrieval.py) (see the sketch below).
</details>
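
The expected file layout can be produced with a few lines of `h5py` (a hedged sketch: the file name, image names, and descriptor dimension D=4096 are hypothetical, and the descriptors are random stand-ins):

```python
import h5py
import numpy as np

names = ['db/0001.jpg', 'query/0002.jpg']  # relative to the dataset root
with h5py.File('outputs/global-feats-custom.h5', 'w') as fd:
    for name in names:
        desc = np.random.rand(4096).astype(np.float32)  # D = 4096
        fd.create_group(name).create_dataset('global_descriptor', data=desc)
```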

## Versions

<details>
<summary>dev branch</summary>

Continuously adds new features.
</details>

<details>
<summary>v1.1 (July 2021)</summary>

- **Breaking**: improved structure of the SfM folders (triangulation and reconstruction), see [#76](https://github.com/cvg/Hierarchical-Localization/pull/76)
- Support for image retrieval (NetVLAD, DIR) and more local features (SIFT, R2D2)
- Support for more datasets: Aachen v1.1, Extended CMU Seasons, RobotCar Seasons, 4Seasons, Cambridge Landmarks, 7-Scenes
- Simplified pipeline and API
- Spatial matcher
</details>

<details>
<summary>v1.0 (July 2020)</summary>

Initial public version.
</details>

## Contributions welcome!

External contributions are very much welcome. This is a non-exhaustive list of features that might be valuable additions:

- [ ] more localization datasets (RobotCar Seasons, CMU Seasons, Aachen v1.1, Cambridge Landmarks, 7Scenes)
- [ ] covisibility clustering for InLoc
- [ ] visualization of the raw predictions (features and matches)
- [ ] interfaces for image retrieval (e.g. [DIR](https://github.com/almazan/deep-image-retrieval), [NetVLAD](https://github.com/uzh-rpg/netvlad_tf_open))
- [ ] other local features
- [ ] other local features or image retrieval

Created and maintained by [Paul-Edouard Sarlin](https://psarlin.com/).
Created and maintained by [Paul-Edouard Sarlin](https://psarlin.com/) with the help of many.
14,280 changes: 7,377 additions & 6,903 deletions doc/depth_aachen.svg
16,200 changes: 8,312 additions & 7,888 deletions doc/loc_aachen.svg
8,846 changes: 4,650 additions & 4,196 deletions doc/loc_inloc.svg
1 change: 0 additions & 1 deletion hloc/colmap_from_nvm.py
@@ -18,7 +18,6 @@ def recover_database_images_and_ids(database_path):
    for name, image_id, camera_id in ret:
        images[name] = image_id
        cameras[name] = camera_id

    db.close()
    logging.info(
        f'Found {len(images)} images and {len(cameras)} cameras in database.')
138 changes: 103 additions & 35 deletions hloc/extract_features.py
@@ -8,10 +8,13 @@
import numpy as np
from tqdm import tqdm
import pprint
import collections.abc as collections

from . import extractors
from .utils.base_model import dynamic_load
from .utils.tools import map_tensor
from .utils.parsers import parse_image_lists
from .utils.io import read_image, list_h5_names


'''
@@ -34,6 +37,21 @@
            'resize_max': 1024,
        },
    },
    # Resize images to 1600px even if they are originally smaller.
    # Improves the keypoint localization if the images are of good quality.
    'superpoint_max': {
        'output': 'feats-superpoint-n4096-rmax1600',
        'model': {
            'name': 'superpoint',
            'nms_radius': 3,
            'max_keypoints': 4096,
        },
        'preprocessing': {
            'grayscale': True,
            'resize_max': 1600,
            'resize_force': True,
        },
    },
    'superpoint_inloc': {
        'output': 'feats-superpoint-n4096-r1600',
        'model': {
@@ -57,6 +75,34 @@
            'resize_max': 1600,
        },
    },
    'sift': {
        'output': 'feats-sift',
        'model': {
            'name': 'sift'
        },
        'preprocessing': {
            'grayscale': True,
            'resize_max': 1600,
        },
    },
    'dir': {
        'output': 'global-feats-dir',
        'model': {
            'name': 'dir',
        },
        'preprocessing': {
            'resize_max': 1024,
        },
    },
    'netvlad': {
        'output': 'global-feats-netvlad',
        'model': {
            'name': 'netvlad',
        },
        'preprocessing': {
            'resize_max': 1024,
        },
    },
}
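
# Illustrative sketch, not part of the upstream diff: a custom configuration
# follows the same schema as the entries above. The values below are
# hypothetical; the model 'name' must match a module in hloc/extractors/.
custom_conf = {
    'output': 'feats-superpoint-n2048-r1024',
    'model': {
        'name': 'superpoint',
        'nms_radius': 4,
        'max_keypoints': 2048,
    },
    'preprocessing': {
        'grayscale': True,
        'resize_max': 1024,
    },
}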


@@ -65,37 +111,45 @@ class ImageDataset(torch.utils.data.Dataset):
        'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'],
        'grayscale': False,
        'resize_max': None,
        'resize_force': False,
    }

    def __init__(self, root, conf):
    def __init__(self, root, conf, paths=None):
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        self.root = root

        self.paths = []
        for g in conf.globs:
            self.paths += list(Path(root).glob('**/'+g))
        if len(self.paths) == 0:
            raise ValueError(f'Could not find any image in root: {root}.')
        self.paths = sorted(list(set(self.paths)))
        self.paths = [i.relative_to(root) for i in self.paths]
        logging.info(f'Found {len(self.paths)} images in root {root}.')
        if paths is None:
            paths = []
            for g in conf.globs:
                paths += list(Path(root).glob('**/'+g))
            if len(paths) == 0:
                raise ValueError(f'Could not find any image in root: {root}.')
            paths = sorted(list(set(paths)))
            self.names = [i.relative_to(root).as_posix() for i in paths]
            logging.info(f'Found {len(self.names)} images in root {root}.')
        else:
            if isinstance(paths, (Path, str)):
                self.names = parse_image_lists(paths)
            elif isinstance(paths, collections.Iterable):
                self.names = [p.as_posix() if isinstance(p, Path) else p
                              for p in paths]
            else:
                raise ValueError(f'Unknown format for path argument {paths}.')

            for name in self.names:
                if not (root / name).exists():
                    raise ValueError(
                        f'Image {name} does not exist in root: {root}.')

    def __getitem__(self, idx):
        path = self.paths[idx]
        if self.conf.grayscale:
            mode = cv2.IMREAD_GRAYSCALE
        else:
            mode = cv2.IMREAD_COLOR
        image = cv2.imread(str(self.root / path), mode)
        if not self.conf.grayscale:
            image = image[:, :, ::-1]  # BGR to RGB
        if image is None:
            raise ValueError(f'Cannot read image {str(path)}.')
        name = self.names[idx]
        image = read_image(self.root / name, self.conf.grayscale)
        image = image.astype(np.float32)
        size = image.shape[:2][::-1]
        w, h = size

        if self.conf.resize_max and max(w, h) > self.conf.resize_max:
        if self.conf.resize_max and (self.conf.resize_force
                                     or max(w, h) > self.conf.resize_max):
            scale = self.conf.resize_max / max(h, w)
            h_new, w_new = int(round(h*scale)), int(round(w*scale))
            image = cv2.resize(
@@ -108,33 +162,43 @@ def __getitem__(self, idx):
        image = image / 255.

        data = {
            'name': path.as_posix(),
            'name': name,
            'image': image,
            'original_size': np.array(size),
        }
        return data

    def __len__(self):
        return len(self.paths)
        return len(self.names)


@torch.no_grad()
def main(conf, image_dir, export_dir, as_half=False):
def main(conf, image_dir, export_dir=None, as_half=False,
         image_list=None, feature_path=None):
    logging.info('Extracting local features with configuration:'
                 f'\n{pprint.pformat(conf)}')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Model = dynamic_load(extractors, conf['model']['name'])
    model = Model(conf['model']).eval().to(device)

    loader = ImageDataset(image_dir, conf['preprocessing'])
    loader = ImageDataset(image_dir, conf['preprocessing'], image_list)
    loader = torch.utils.data.DataLoader(loader, num_workers=1)

    feature_path = Path(export_dir, conf['output']+'.h5')
    if feature_path is None:
        feature_path = Path(export_dir, conf['output']+'.h5')
    feature_path.parent.mkdir(exist_ok=True, parents=True)
    feature_file = h5py.File(str(feature_path), 'a')
    skip_names = set(list_h5_names(feature_path)
                     if feature_path.exists() else ())
    if set(loader.dataset.names).issubset(set(skip_names)):
        logging.info('Skipping the extraction.')
        return feature_path

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Model = dynamic_load(extractors, conf['model']['name'])
    model = Model(conf['model']).eval().to(device)

    for data in tqdm(loader):
        name = data['name'][0]  # remove batch dimension
        if name in skip_names:
            continue

        pred = model(map_tensor(data, lambda x: x.to(device)))
        pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

@@ -150,14 +214,15 @@ def main(conf, image_dir, export_dir, as_half=False):
                if (dt == np.float32) and (dt != np.float16):
                    pred[k] = pred[k].astype(np.float16)

        grp = feature_file.create_group(data['name'][0])
        for k, v in pred.items():
            grp.create_dataset(k, data=v)
        with h5py.File(str(feature_path), 'a') as fd:
            grp = fd.create_group(name)
            for k, v in pred.items():
                grp.create_dataset(k, data=v)

        del pred

    feature_file.close()
    logging.info('Finished exporting features.')
    return feature_path


if __name__ == '__main__':
@@ -166,5 +231,8 @@ def main(conf, image_dir, export_dir, as_half=False):
    parser.add_argument('--export_dir', type=Path, required=True)
    parser.add_argument('--conf', type=str, default='superpoint_aachen',
                        choices=list(confs.keys()))
    parser.add_argument('--as_half', action='store_true')
    parser.add_argument('--image_list', type=Path)
    parser.add_argument('--feature_path', type=Path)
    args = parser.parse_args()
    main(confs[args.conf], args.image_dir, args.export_dir)
    main(confs[args.conf], args.image_dir, args.export_dir, args.as_half,
         args.image_list, args.feature_path)
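
# Illustrative invocation, not part of the upstream diff (paths hypothetical):
#   python -m hloc.extract_features --image_dir datasets/aachen/images \
#       --export_dir outputs/aachen --conf netvlad --as_half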
10 changes: 3 additions & 7 deletions hloc/extractors/d2net.py
@@ -1,7 +1,6 @@
import sys
from pathlib import Path
import subprocess
import logging
import torch

from ..utils.base_model import BaseModel
@@ -15,22 +14,19 @@
class D2Net(BaseModel):
    default_conf = {
        'model_name': 'd2_tf.pth',
        'checkpoint_dir': d2net_path / 'models',
        'use_relu': True,
        'multiscale': False,
    }
    required_inputs = ['image']

    def _init(self, conf):
        model_file = d2net_path / 'models' / conf['model_name']
        model_file = conf['checkpoint_dir'] / conf['model_name']
        if not model_file.exists():
            model_file.parent.mkdir(exist_ok=True)
            cmd = ['wget', 'https://dsmn.ml/files/d2-net/'+conf['model_name'],
                   '-O', str(model_file)]
            ret = subprocess.call(cmd)
            if ret != 0:
                logging.warning(
                    f'Cannot download the D2-Net model with `{cmd}`.')
                exit(ret)
            subprocess.run(cmd, check=True)

        self.net = _D2Net(
            model_file=model_file,
