-
Notifications
You must be signed in to change notification settings - Fork 11
/
pytest_deep.py
49 lines (44 loc) · 2.37 KB
/
pytest_deep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pytest
import numpy as np
import IVIMNET.simulations as sim
import IVIMNET.deep as deep
import time
class hyperparams:
    """Settings bundle passed to ``deep.checkarg`` and ``deep.learn_IVIM``.

    Attributes:
        fig: whether to plot results and intermediate steps.
        save_name: preset name handed to ``deep.net_pars`` and
            ``deep.train_pars`` ('orig' or 'optim', or 'optim_adsig' for
            in vivo data).
        net_pars: network configuration built by ``deep.net_pars``.
        train_pars: training configuration built by ``deep.train_pars``.
        fit: least-squares fit configuration built by ``deep.lsqfit``.
        sharp: in vivo option; not relevant to this simulation test.
        sim: simulation configuration built by ``deep.sim``.
        reps: presumably a repetition count -- TODO confirm its use in ``deep``.
    """

    def __init__(self):
        preset = 'optim'
        self.fig = False
        self.save_name = preset
        # The deep.* factories are called in the same order as before in case
        # any of them has side effects.
        self.net_pars = deep.net_pars(preset)
        self.train_pars = deep.train_pars(preset)
        self.fit = deep.lsqfit()
        self.sharp = True
        self.sim = deep.sim()
        self.reps = 1
def _simulate_signals(arg, snr, n_sims):
    """Simulate noisy IVIM signals at ``snr`` using the ranges stored in ``arg.sim``.

    ``arg.sim.range[0]`` holds the lower and ``arg.sim.range[1]`` the upper
    bounds for (D, f, D*), in that column order.
    Returns whatever ``sim.sim_signal`` returns: (signal, D, f, Dp).
    """
    lower, upper = arg.sim.range[0], arg.sim.range[1]
    return sim.sim_signal(snr, arg.sim.bvalues, sims=n_sims,
                          Dmin=lower[0], Dmax=upper[0],
                          fmin=lower[1], fmax=upper[1],
                          Dsmin=lower[2], Dsmax=upper[2],
                          rician=arg.sim.rician)


def test_NN():
    """Train IVIM-NET on simulated data; check accuracy and training time.

    The network is trained on signals simulated with SNR drawn from
    ``train_snr``. It is then evaluated at each SNR in ``eval_snrs``. For every
    evaluation SNR, column 1 of the ``sim.print_errors`` matrix (presumably
    the normalised RMSE per parameter -- confirm against IVIMNET.simulations)
    must stay below 110 % of the stored reference. Training must also finish
    within 600 s.
    """
    arg = deep.checkarg(hyperparams())
    arg.sim.repeats = 1
    arg.sim.sims = 5000000

    train_snr = (10, 120)
    signal_train, *_ = _simulate_signals(arg, train_snr, arg.sim.sims)

    # perf_counter is monotonic; time.time() can jump if the wall clock changes.
    start_time = time.perf_counter()
    net = deep.learn_IVIM(signal_train, arg.sim.bvalues, arg)
    elapsed_time = time.perf_counter() - start_time
    print('\ntime elapsed for training: {}\n'.format(elapsed_time))
    # The training set is large (arg.sim.sims samples); free it before evaluation.
    del signal_train

    eval_snrs = [15, 20, 25, 50, 100]
    matNN = np.zeros([len(eval_snrs), 3, 3])
    for idx, snr in enumerate(eval_snrs):
        signal, D, f, Dp = _simulate_signals(arg, snr, 30000)
        paramsNN = deep.predict_IVIM(signal, arg.sim.bvalues, net, arg)
        matNN[idx] = sim.print_errors(np.squeeze(D), np.squeeze(f),
                                      np.squeeze(Dp), paramsNN)

    # One row per evaluation SNR (same order as eval_snrs), one entry per
    # parameter; 1.1 allows a 10 % tolerance over the stored reference.
    ref = 1.1 * np.array([
        [[0.21870159], [0.27667806], [0.46173239]],
        [[0.18687773], [0.2458038], [0.40430529]],
        [[0.16957438], [0.22985733], [0.37065742]],
        [[0.14261743], [0.20649752], [0.31245713]],
        [[0.13493688], [0.20026352], [0.29356715]],
    ])
    np.testing.assert_array_less(matNN[:, :, 1:2], ref)
    assert elapsed_time < 600, (
        'training took {:.1f} s (limit 600 s)'.format(elapsed_time))