-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdefault_motion_configs.py
72 lines (63 loc) · 2.28 KB
/
default_motion_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import ml_collections
import torch
def get_default_configs():
    """Build the default experiment configuration for the 'motion' dataset.

    Creates a fresh ``ml_collections.ConfigDict`` on every call, so callers
    can mutate the result without affecting other callers.

    Returns:
        ml_collections.ConfigDict: a nested config with the sub-dicts
        ``training``, ``sampling``, ``eval``, ``data``, ``model`` and
        ``optim``, plus the top-level entries ``seed`` and ``device``
        (``cuda:0`` when CUDA is available, otherwise ``cpu``).
    """
    config = ml_collections.ConfigDict()

    # training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 16
    training.snapshot_freq = 10000
    training.log_freq = 50
    training.eval_freq = 500
    # Store additional checkpoints for preemption in cloud computing environments.
    training.snapshot_freq_for_preemption = 5000
    # Produce samples at each snapshot.
    training.snapshot_sampling = True
    training.likelihood_weighting = False
    training.continuous = True
    training.reduce_mean = False

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.n_steps_each = 1
    sampling.noise_removal = True
    sampling.probability_flow = False
    sampling.snr = 0.075  # TODO: tune between 0.05 and 0.2

    # evaluation
    config.eval = evaluate = ml_collections.ConfigDict()
    evaluate.begin_ckpt = 50  # TODO: first checkpoint to be evaluated
    evaluate.end_ckpt = 100  # TODO: last checkpoint to be evaluated
    evaluate.batch_size = 512
    evaluate.enable_sampling = True
    evaluate.num_samples = 50000
    evaluate.enable_loss = True
    evaluate.enable_bpd = False
    # TODO: pass both the motion-free and the motion-affected data sets here.
    evaluate.bpd_dataset = 'test'

    # data
    config.data = data = ml_collections.ConfigDict()
    data.dataset = 'motion'
    # NOTE(review): machine-specific absolute path; override it per environment.
    data.rot_dir = '/media/mareike/Elements/Data/HeadSimulatedProjectionDataCQ500FanBeam/motion_free'
    data.image_size = 256
    data.random_flip = True
    data.uniform_dequantization = False
    data.centered = False
    data.num_channels = 1

    # model
    config.model = model = ml_collections.ConfigDict()
    model.sigma_max = 378  # TODO: maximum noise perturbation during training
    model.sigma_min = 0.01  # TODO: minimum noise perturbation during training
    model.num_scales = 2000
    model.beta_min = 0.1
    model.beta_max = 20.
    model.dropout = 0.
    model.embedding_type = 'fourier'

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 0
    optim.optimizer = 'Adam'
    optim.lr = 2e-4
    optim.beta1 = 0.9
    optim.eps = 1e-8
    optim.warmup = 5000
    optim.grad_clip = 1.

    config.seed = 42
    # Prefer the first GPU when one is available; fall back to CPU otherwise.
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    return config