-
Notifications
You must be signed in to change notification settings - Fork 0
/
matching.py
113 lines (110 loc) · 3.99 KB
/
matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.functional import mse_loss
from torch.utils.data import Dataset, DataLoader
from utils import tqdm
import h5py
import numpy as np
import random
import lagomorph as lm
def affine_matching(I,
                    J,
                    A=None,
                    T=None,
                    affine_steps=100,
                    reg_weightA=1e2,
                    reg_weightT=1e1,
                    learning_rate_A=1e-4,
                    learning_rate_T=1e-2,
                    progress_bar=True):
    """Match image I to J via an affine transform, by gradient descent.

    Optimizes a per-image affine perturbation ``A`` (applied as ``A + I3``)
    and translation ``T`` to minimize ``MSE(deform(I), J)`` plus quadratic
    penalties on ``A`` and ``T``.

    Args:
        I: moving image batch; its device and dtype determine where A and T
           are allocated. Assumed shape (N, C, ...) with N = I.shape[0] —
           TODO confirm against lm.affine_interp's contract.
        J: fixed (target) image batch; moved to I's device.
        A: optional initial (N, 3, 3) affine perturbation; zeros if None.
        T: optional initial (N, 3) translation; zeros if None.
        affine_steps: number of gradient-descent iterations.
        reg_weightA: weight of the quadratic penalty on A.
        reg_weightT: weight of the quadratic penalty on T.
        learning_rate_A: step size for A.
        learning_rate_T: step size for T.
        progress_bar: if True, wrap the iteration range in tqdm.

    Returns:
        Tuple ``(A, T, losses)``: the detached optimized parameters and the
        per-iteration loss values as a list of Python floats.
    """
    if A is None:
        A = torch.zeros((I.shape[0], 3, 3), dtype=I.dtype).to(I.device)
    if T is None:
        T = torch.zeros((I.shape[0], 3), dtype=I.dtype).to(I.device)
    J = J.to(I.device)
    losses = []
    # Only A and T are optimized; the images are fixed inputs.
    I.requires_grad_(False)
    J.requires_grad_(False)
    steps = range(affine_steps)
    # Identity matrix added to A so A parameterizes a perturbation of identity.
    eye = torch.eye(3).view(1, 3, 3).type(I.dtype).to(I.device)
    if progress_bar:
        steps = tqdm(steps)
    for mit in steps:
        A.requires_grad_(True)
        T.requires_grad_(True)
        # Zero each gradient independently (the original required BOTH to be
        # non-None before zeroing either).
        if A.grad is not None:
            A.grad.detach_()
            A.grad.zero_()
        if T.grad is not None:
            T.grad.detach_()
            T.grad.zero_()
        Idef = lm.affine_interp(I, A + eye, T)
        # Quadratic (ridge) penalties pulling A and T toward zero.
        # BUG FIX: the original computed mse_loss(A, A) and mse_loss(T, T),
        # which are identically zero with zero gradient, so reg_weightA and
        # reg_weightT had no effect at all.
        regtermA = A.pow(2).mean()
        regtermT = T.pow(2).mean()
        loss = mse_loss(Idef, J) + .5 * reg_weightA * regtermA + .5 * reg_weightT * regtermT
        loss.backward()
        with torch.no_grad():
            losses.append(loss.detach())
            # Plain gradient-descent step. Uses the modern add_ signature;
            # add_(scalar, tensor) is deprecated in current PyTorch.
            A.add_(A.grad, alpha=-learning_rate_A)
            T.add_(T.grad, alpha=-learning_rate_T)
    return A.detach(), T.detach(), [l.item() for l in losses]
def lddmm_matching(I,
                   J,
                   m=None,
                   lddmm_steps=1000,
                   lddmm_integration_steps=10,
                   reg_weight=1e-1,
                   learning_rate_pose=2e-2,
                   fluid_params=[1.0, .1, .01],
                   progress_bar=True
                   ):
    """Match image I to J via LDDMM, by gradient descent on the momentum.

    Shoots a geodesic from initial vector momentum ``m`` under a fluid
    metric, deforms ``I`` by the resulting map, and minimizes
    ``MSE(deform(I), J) + reg_weight * <sharp(m), m>``.

    Args:
        I: moving image batch; its device/dtype determine where m is
           allocated. Assumed shape (N, C, ...) — TODO confirm against
           lm.interp's contract.
        J: fixed (target) image batch; moved to I's device.
        m: optional initial momentum of shape (N, 3, *spatial); zeros shaped
           like I's spatial dims if None. May live on a different grid than
           I, in which case the deformation is regridded to I's shape.
        lddmm_steps: number of gradient-descent iterations.
        lddmm_integration_steps: integration steps for the exponential map.
        reg_weight: weight of the kinetic-energy regularizer.
        learning_rate_pose: step size for the momentum update.
        fluid_params: fluid-metric coefficients passed to lm.FluidMetric.
            NOTE(review): mutable default — safe only because it is never
            mutated here; confirm lm.FluidMetric does not mutate it.
        progress_bar: if True, wrap the iteration range in tqdm.

    Returns:
        Tuple ``(m, losses, matchterms, regterms)``: the detached optimized
        momentum, and per-iteration total loss / match term / reg term values
        as lists of Python floats.
    """
    if m is None:
        defsh = [I.shape[0], 3] + list(I.shape[2:])
        m = torch.zeros(defsh, dtype=I.dtype).to(I.device)
    # Regrid the deformation only when m's grid differs from I's.
    # BUG FIX: the original guard was `if do_regridding is not None:`, which
    # is always True for a bool, so lm.regrid ran even when the shapes
    # already matched.
    do_regridding = m.shape[2:] != I.shape[2:]
    J = J.to(I.device)
    matchterms = []
    regterms = []
    losses = []
    metric = lm.FluidMetric(fluid_params)
    m.requires_grad_()
    pb = range(lddmm_steps)
    if progress_bar:
        pb = tqdm(pb)
    for mit in pb:
        if m.grad is not None:
            m.grad.detach_()
            m.grad.zero_()
        m.requires_grad_()
        # Geodesic shooting: integrate m into a deformation map h.
        h = lm.expmap(metric, m, num_steps=lddmm_integration_steps)
        if do_regridding:
            h = lm.regrid(h, shape=I.shape[2:], displacement=True)
        Idef = lm.interp(I, h)
        # Kinetic-energy regularizer <sharp(m), m> (mean over elements).
        regterm = (metric.sharp(m) * m).mean()
        matchterm = mse_loss(Idef, J)
        matchterms.append(matchterm.detach().item())
        regterms.append(regterm.detach().item())
        loss = matchterm + reg_weight * regterm
        loss.backward()
        with torch.no_grad():
            losses.append(loss.detach())
            # Riemannian gradient: lower the index of the Euclidean gradient.
            p = metric.flat(m.grad).detach()
            if torch.isnan(losses[-1]).item():
                print(f"loss is NaN at iter {mit}")
                break
            # Gradient-descent step (modern add_ signature; the deprecated
            # add_(scalar, tensor) form was used originally).
            m.add_(p, alpha=-learning_rate_pose)
    return m.detach(), [l.item() for l in losses], matchterms, regterms