-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path SInetwork_example1.py
114 lines (82 loc) · 2.92 KB
/
SInetwork_example1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 7 19:09:46 2022
@author: Cecilia
"""
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 13 10:31:19 2022
@author: Cecilia
"""
import torch
from nflows import transforms, distributions, flows, nn
from DIS import DIS
from models.SInetwork import SInetworkModel
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import MultivariateNormal
from SInetworkLikelihood import SInetworkLikelihood
from LikelihoodBased import ImpSamp
from scipy import stats
from time import time
plt.ion()
torch.manual_seed(111)

# --- Synthetic data -------------------------------------------------------
# Number of nodes in the SI network; also passed to the model below.
n_nodes = 5
# Length of the model's input vector: 2 + n + n(n-1)/2, i.e. 17 for 5 nodes.
# NOTE(review): presumably 2 scalar parameters, one input per node and one per
# unordered node pair -- confirm against SInetworkModel.
ninputs = 2 + n_nodes + n_nodes * (n_nodes - 1) // 2

# Synthetic observations: a single observation made of five lists of node
# indices.
# NOTE(review): each inner list looks like the infected set at successive time
# steps -- confirm against SInetworkModel.
obs = [[[0], [0, 1, 3], [0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]]
# Standard-normal CDF of two fixed latent values (probabilities in (0, 1)).
# NOTE(review): presumably the true infection/contact probabilities behind
# `obs`; neither value is used anywhere else in this script.
p_inf = torch.distributions.Normal(0., 1.).cdf(torch.tensor(-0.1000))
p_con = torch.distributions.Normal(0., 1.).cdf(torch.tensor(-0.4151))

# --- Model for analysis ---------------------------------------------------
model = SInetworkModel(observations=obs, n_nodes=n_nodes, n_inputs=ninputs)

# --- Normalising flow used as the DIS proposal ----------------------------
base_dist = distributions.StandardNormal(shape=[ninputs])
transform = transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
    features=ninputs,
    hidden_features=20,
    num_bins=5,
    tails="linear",
    tail_bound=10.,
    num_blocks=3,
)
approx_dist = flows.Flow(transform, base_dist)
# Adam with its default hyperparameters over all flow parameters.
optimizer = torch.optim.Adam(approx_dist.parameters())

# --- Run the analysis -----------------------------------------------------
# ESS target given to DIS and reused as the stopping threshold below, so the
# two values cannot drift apart.
ess_target = 250
dis2 = DIS(model, approx_dist, optimizer,
           importance_sample_size=5000, ess_target=ess_target, max_weight=0.1)
dis2.pretrain(initial_target=model.prior, goal=0.75, report_every=10)
# Train one iteration at a time until dis2.eps has reached 0 and dis2.ess is at
# least the target.
# NOTE(review): there is no iteration cap, so this never terminates if those
# conditions are never met.
while dis2.eps > 0. or dis2.ess < ess_target:
    dis2.train(iterations=1)
# --- Importance sampling with the trained DIS proposal --------------------
nsamp = 10000
# Proposal draws for both importance-sampling runs (DIS and likelihood-based).
n_proposal_draws = 10 * nsamp

is_start_time = time()
with torch.no_grad():
    weighted_params = dis2.get_sample(n_proposal_draws)
    # NOTE(review): presumably re-weights the draws for eps = 0.
    weighted_params.update_epsilon(0.0)
is_end_time = time()
# Elapsed wall-clock time, in minutes.
is_time = (is_end_time - is_start_time) / 60.
print(f'Time for IS using DIS proposal {is_time:.1f} mins')

# Resample nsamp draws (detached from autograd) and keep the first two
# outputs of convert_inputs: the infection and contact quantities.
params = weighted_params.sample(nsamp).detach()
converted_params = model.convert_inputs(params)
sel_infection, sel_contact = converted_params[:2]

# --- Likelihood-based analysis for comparison -----------------------------
# Beta(1, 1) is uniform on [0, 1].
# NOTE(review): the 5 is presumably the number of network nodes, and the two
# priors presumably belong to the infection and contact probabilities.
uniform_priors = (stats.beta(1, 1), stats.beta(1, 1))
modlik = SInetworkLikelihood(obs[0], 5, *uniform_priors)
IS = ImpSamp(modlik, S=n_proposal_draws, size=nsamp)
sample = IS.sample()
# Overlay the likelihood-based IS draws (`sample`) and the DIS draws, with one
# figure per probability.
# plt.figure() opens a fresh figure each time. With interactive mode on,
# plt.show() neither blocks nor closes the current figure. Without the new
# figure, the second pair of histograms would be drawn onto the first
# figure's axes and would replace its title.
plt.figure()
plt.hist(sample[:, 0], density=True, alpha=0.6, label='IS')
plt.hist(sel_infection.numpy(), density=True, alpha=0.6, label='DIS')
plt.title('Probability of infection')
plt.legend()
plt.show()

plt.figure()
plt.hist(sample[:, 1], density=True, alpha=0.6, label='IS')
plt.hist(sel_contact.numpy(), density=True, alpha=0.6, label='DIS')
plt.title('Probability of contact')
plt.legend()
plt.show()
# 2x2 summary figure: likelihood-based IS on the top row, DIS on the bottom
# row; column 0 is infection probability, column 1 is contact probability.
_, ax = plt.subplots(2, 2)
is_hist_kw = dict(density=True, edgecolor='C0', alpha=0.5)
dis_hist_kw = dict(density=True, color='red', edgecolor='red', alpha=0.5)
for col, dis_draws in enumerate((sel_infection, sel_contact)):
    ax[0, col].hist(sample[:, col], **is_hist_kw)
    ax[1, col].hist(dis_draws.numpy(), **dis_hist_kw)

ax[0, 0].set_ylabel('IS')
ax[1, 0].set_ylabel('DIS')
ax[1, 0].set_xlabel('prob. infection')
ax[1, 1].set_xlabel('prob. contact')
plt.savefig("SI_ex1_post.pdf")