-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path step_4.py
41 lines (27 loc) · 979 Bytes
/
step_4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro import poutine
# Initial value handed to pyro.param("mu", ...) inside ``model``. Despite the
# name, this is not a prior distribution: it only sets the starting value of
# the learnable location parameter ``mu``.
PRIOR_BELIEF_IN_MU: float = 0.0
def model(data=None, n_obs=None):
    """Pyro model: observations ``y`` drawn i.i.d. from ``Normal(mu, 1)``.

    ``mu`` is a learnable scalar parameter (param site ``"mu"``) whose initial
    value is ``PRIOR_BELIEF_IN_MU``.

    Args:
        data: Optional observed tensor whose first dimension indexes
            observations. When given, site ``"y"`` is conditioned on it and
            the plate size is taken from ``data.shape[0]``; any ``n_obs``
            passed alongside it is ignored.
        n_obs: Number of observations to sample when ``data`` is None.

    Returns:
        The value at site ``"y"``: a fresh sample, or ``data`` when observed.

    Raises:
        ValueError: If neither ``data`` nor ``n_obs`` is provided.
    """
    if data is not None:
        # Observed data dictates how many draws the plate covers.
        n_obs = data.shape[0]
    elif n_obs is None:
        raise ValueError("Someone has gotta tell us how many observations there are")

    loc = pyro.param("mu", torch.tensor(PRIOR_BELIEF_IN_MU))
    # The plate declares the n_obs draws conditionally independent given mu.
    with pyro.plate("N", n_obs):
        return pyro.sample("y", dist.Normal(loc, 1), obs=data)
def main():
    """Draw one forward sample from ``model`` and print simple summaries.

    The model runs with ``data=None`` and ``n_obs=10``, so site ``"y"`` is
    sampled rather than observed. No RNG seed is set here, so the printed
    numbers differ from run to run.
    """
    trace = poutine.trace(model).get_trace(data=None, n_obs=10)

    # Trace.nodes maps each site name to a dict of site information; its
    # "value" entry holds the tensor produced at that site.
    y_values = trace.nodes["y"]["value"].detach().numpy()

    # (heading, value) pairs, printed in order on consecutive lines.
    report = (
        ("Our sampled vector is:", y_values),
        ("Mean of our sampled vector is:", np.mean(y_values)),
        ("The log_prob_sum of this sample is:", trace.log_prob_sum().item()),
    )
    for heading, value in report:
        print(heading)
        print(value)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()