Skip to content

Commit

Permalink
Switch to kornia DISK (#291)
Browse files Browse the repository at this point in the history
* Switch to kornia DISK
* Bump kornia minimum version
  • Loading branch information
skydes authored Jul 17, 2023
1 parent a828176 commit 90733da
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 69 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,3 @@
[submodule "third_party/r2d2"]
path = third_party/r2d2
url = https://github.com/naver/r2d2.git
[submodule "third_party/disk"]
path = third_party/disk
url = https://github.com/cvlab-epfl/disk.git
79 changes: 15 additions & 64 deletions hloc/extractors/disk.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,32 @@
import sys
from pathlib import Path
from functools import partial
import torch
import torch.nn.functional as F
import kornia

from ..utils.base_model import BaseModel

disk_path = Path(__file__).parent / "../../third_party/disk"
sys.path.append(str(disk_path))
from disk import DISK as _DISK # noqa E402


class DISK(BaseModel):
default_conf = {
'model_name': 'depth-save.pth',
'weights': 'depth',
'max_keypoints': None,
'desc_dim': 128,
'mode': 'nms',
'nms_window_size': 5,
'detection_threshold': 0.0,
'pad_if_not_divisible': True,
}
required_inputs = ['image']

def _init(self, conf):
self.model = _DISK(window=8, desc_dim=conf['desc_dim'])

state_dict = torch.load(
disk_path / conf['model_name'], map_location='cpu')
if 'extractor' in state_dict:
weights = state_dict['extractor']
elif 'disk' in state_dict:
weights = state_dict['disk']
else:
raise KeyError('Incompatible weight file!')
self.model.load_state_dict(weights)

if conf['mode'] == 'nms':
self.extract = partial(
self.model.features,
kind='nms',
window_size=conf['nms_window_size'],
cutoff=0.,
n=conf['max_keypoints']
)
elif conf['mode'] == 'rng':
self.extract = partial(self.model.features, kind='rng')
else:
raise KeyError(
f'mode must be `nms` or `rng`, got `{conf["mode"]}`')
self.model = kornia.feature.DISK.from_pretrained(conf['weights'])

def _forward(self, data):
image = data['image']
# make sure that the dimensions of the image are multiple of 16
orig_h, orig_w = image.shape[-2:]
new_h = round(orig_h / 16) * 16
new_w = round(orig_w / 16) * 16
image = F.pad(image, (0, new_w - orig_w, 0, new_h - orig_h))

batched_features = self.extract(image)

assert(len(batched_features) == 1)
features = batched_features[0]

# filter points detected in the padded areas
kpts = features.kp
valid = torch.all(kpts <= kpts.new_tensor([orig_w, orig_h]) - 1, 1)
kpts = kpts[valid]
descriptors = features.desc[valid]
scores = features.kp_logp[valid]

# order the keypoints
indices = torch.argsort(scores, descending=True)
kpts = kpts[indices]
descriptors = descriptors[indices]
scores = scores[indices]

features = self.model(
image,
n=self.conf['max_keypoints'],
window_size=self.conf['nms_window_size'],
score_threshold=self.conf['detection_threshold'],
pad_if_not_divisible=self.conf['pad_if_not_divisible'],
)
return {
'keypoints': kpts[None],
'descriptors': descriptors.t()[None],
'scores': scores[None],
'keypoints': [f.keypoints for f in features],
'keypoint_scores': [f.detection_scores for f in features],
'descriptors': [f.descriptors.t() for f in features],
}
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ plotly
scipy
h5py
pycolmap>=0.3.0
kornia>=0.6.7
kornia>=0.6.11
gdown
lightglue @ git+https://github.com/cvg/LightGlue
1 change: 0 additions & 1 deletion third_party/disk
Submodule disk deleted from eafa0e

0 comments on commit 90733da

Please sign in to comment.