-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_convlddmm.py
70 lines (64 loc) · 3.13 KB
/
run_convlddmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from torch.distributed import barrier
import os
from cmdline import *
from atlas import *
# Affine-standardised OASIS scans for the train and test splits.  Both splits
# use identical loader options; only the HDF5 file differs.
_std_ds_options = dict(crop=None, pooling=None, one_scan_per_subject=False)
oasis_ds_std = OASISDataset(
    h5path=f'{prefix}convaffinestd_{suffix}.h5', **_std_ds_options)
oasis_ds_test_std = OASISDataset(
    h5path=f'{prefix}convaffinestd_test_{suffix}.h5', **_std_ds_options)
del _std_ds_options

# Checkpoint from the deep-affine stage, saved as a 6-tuple.  Only
# I_deepaffine (the affine atlas image) is used by the code below; the other
# entries are unpacked so the tuple layout stays explicit.
# NOTE(review): if affine_net is a pickled nn.Module, PyTorch >= 2.6 (where
# torch.load defaults to weights_only=True) may need weights_only=False here
# -- confirm against the installed torch version.
deepaffinefile = f'{prefix}deepaffine_{suffix}.pth'
(I_deepaffine, affine_net, epoch_losses_deepaffine, full_losses_deepaffine,
 iter_losses_deepaffine, test_losses_deepaffine) = torch.load(
    deepaffinefile, map_location=loc)
I_deepaffine = I_deepaffine.to(loc)

# Fluid parameters and regularisation weight handed to lddmm_atlas for both
# the training and the test runs.
# NOTE(review): presumably the (alpha, beta, gamma) coefficients of the fluid
# operator -- confirm against lddmm_atlas.
fluid_params = [.1, 0, .01]
reg_weight = 1e4
# Only rank 0 writes, so concurrent ranks never race on the same file.
if rank == 0:
    torch.save(fluid_params, f'{prefix}fluidparams_{suffix}.pth')
# Conventional (optimisation-based) LDDMM atlas on the training set, started
# from the deep-affine atlas image.  The result is cached in convlddmmfile:
# on the first run every rank takes part in lddmm_atlas and rank 0 writes the
# cache; on later runs every rank reads the cached result instead.
convlddmmfile = f'{prefix}convlddmm_{suffix}.pth'
if not os.path.isfile(convlddmmfile):  # conventional lddmm atlas
    if rank == 0:  # avoid one duplicate log line per process
        print("Conventional LDDMM atlas building")
    # NOTE(review): I0 is moved to the bare 'cuda' device (the *current* CUDA
    # device), whereas the test-set call passes a tensor living on `loc`;
    # confirm the current device is set per rank before this point.
    res = lddmm_atlas(dataset=oasis_ds_std,
                      I0=I_deepaffine.clone().to('cuda'),
                      fluid_params=fluid_params,
                      learning_rate_pose=1e-3,
                      learning_rate_image=5e4,
                      reg_weight=reg_weight,
                      momentum_preconditioning=False,
                      batch_size=30,
                      num_epochs=500,
                      gpu=gpu,
                      world_size=args.world_size,
                      rank=rank)
    if rank == 0:
        torch.save(res, convlddmmfile)
else:
    # The loaded result must be bound to `res`; previously it was discarded,
    # so the unpacking below failed with NameError whenever the cache existed.
    res = torch.load(convlddmmfile, map_location='cpu')
# Keep ranks in step so no rank moves on before rank 0 has finished writing.
barrier()
# lddmm_atlas returns a 4-tuple; only its first entry (the atlas image) is
# needed here.
Ilddmm, _, _, _ = res
Ilddmm = Ilddmm.to(loc)
# On the test set, reuse the same atlas-building code but with a zero learning
# rate for the image, so the trained atlas Ilddmm stays fixed and only the
# pose/registration side of lddmm_atlas is optimised (presumably per-scan
# momenta -- confirm against lddmm_atlas).  The result is cached in
# convlddmmtestfile and the run is skipped if that file already exists.
convlddmmtestfile = f'{prefix}convlddmm_test_{suffix}.pth'
if not os.path.isfile(convlddmmtestfile):  # conventional lddmm atlas
    if rank == 0:  # avoid one duplicate log line per process
        print("Conventional LDDMM Test")
    res = lddmm_atlas(dataset=oasis_ds_test_std,
                      I0=Ilddmm,
                      fluid_params=fluid_params,
                      learning_rate_pose=1e-3,
                      learning_rate_image=0.0,  # atlas image is not updated
                      momentum_preconditioning=False,
                      reg_weight=reg_weight,
                      batch_size=30,
                      num_epochs=1,
                      lddmm_steps=500,
                      gpu=gpu,
                      world_size=args.world_size,
                      rank=rank)
    if rank == 0:
        torch.save(res, convlddmmtestfile)
    # `res` is guaranteed to be bound here; drop it to release its memory.
    del res
# The saved test result can be reloaded later with, e.g.:
#   Ilddmm, mom_lddmm, epoch_losses, iter_losses = torch.load(
#       convlddmmtestfile, map_location=loc)