Skip to content

Commit

Permalink
mdoel description
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Oct 18, 2024
1 parent 21826b7 commit b12a2c9
Showing 1 changed file with 19 additions and 26 deletions.
45 changes: 19 additions & 26 deletions docs/source/dynamical_multi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"source": [
"## Overview: Composing Hierarchical Bayesian Models with ODEs\n",
"\n",
"In our previous tutorial on causal reasoning with continuous time dynamical systems we showed how embedding differential equation solvers in ChiRho allows us to do the following: (i) fit dynamical systems parameters using Pyro's support for variational inference and (ii) reason about (uncertain) interventions representing various policy decisions. In this tutorial we expand on that early epidemiological example by composing the same simple susceptible, infected, recovered (SIR) model with hierarchical priors over dynamical systems parameters for each of several distinct geographic location (we will use three locations in our running example). The key insight here is that the same Bayesian modeling motifs for pooling statistical information between distinct strata in standard Bayesian multilevel regression modeling (see Chapter 5 of Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B. (1995). *Bayesian data analysis*. Chapman and Hall/CRC with subsequent editions and improvements) can be used when the regression equations are swapped out with mechanistic models in the form of differential equations.\n",
"In our previous tutorial on causal reasoning with continuous time dynamical systems we showed how embedding differential equation solvers in ChiRho allows us to do the following: (i) fit dynamical systems parameters using Pyro's support for variational inference and (ii) reason about (uncertain) interventions representing various policy decisions. In this tutorial we expand on that early epidemiological example by composing the same simple susceptible, infected, recovered (SIR) model with hierarchical priors over dynamical systems parameters for each of several distinct geographic location (we will use three locations in our running example). The key insight here is that the same Bayesian modeling motifs for pooling statistical information between distinct strata in standard Bayesian multilevel regression modeling (see Chapter 5 of Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B. (1995). *Bayesian data analysis*. Chapman and Hall/CRC with subsequent editions and improvements) can be used when the regression equations are swapped out with mechanistic models in the form of differential equations. We will also see that such effects carry over to the predicted effects of interventions. \n",
"\n",
"### Background: causal reasoning in dynamical systems\n",
"\n",
Expand Down Expand Up @@ -123,9 +123,18 @@
"## Causal probabilistic program"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just as in the previous tutorial, we define the differential equation model as a `pyro.nn.PyroModule` as follows, where the forward method is a function from states `X` to the time derivatives of the states, `dX`. Fortunately, we can use the exact same implementation for the stratified example here, taking advantage of PyTorch's tensor broadcasting semantics.\n",
"\n",
"Also, we assume we only make observations in one of the locations, which - intuitively - should decrease our uncertaintly about that location more than about other locations, while still allowing us learn something about them, in virtue of the location-specific parameters coming from the same general distributions. Conceptually, `single_observation_model()` takes a trajectory already produced by a simulation, and generates a sample of Poisson-distributed observations at the first location."
]
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-18T18:46:29.357796Z",
Expand All @@ -148,25 +157,24 @@
" return dX\n",
" \n",
"def sir_observation_model(X: State[torch.Tensor]) -> None:\n",
" # We don't observe the number of susceptible individuals directly.\n",
" \n",
" # Note: Here we set the event_dim to 1 if the last dimension of X[\"I\"] is > 1, as the sir_observation_model\n",
" # can be used for both single and multi-dimensional observations.\n",
" event_dim = 1 if X[\"I\"].shape and X[\"I\"].shape[-1] > 1 else 0\n",
" pyro.sample(\"I_obs\", dist.Poisson(X[\"I\"]).to_event(event_dim)) # noisy number of infected actually observed\n",
" pyro.sample(\"R_obs\", dist.Poisson(X[\"R\"]).to_event(event_dim)) # noisy number of recovered actually observed\n",
"\n",
"def single_observation_model(X: State[torch.Tensor]) -> None:\n",
" # In this example we only take noisy measurements of a single town corresponding to the first index in the state tensors.\n",
" # In this example we only take noisy measurements of a single town corresponding to\n",
" # the first index in the state tensors (0 in the second-last dimension, the last dimension is time).\n",
" first_X = {k: v[..., 0, :] for k, v in X.items()}\n",
" return sir_observation_model(first_X)\n"
" return sir_observation_model(first_X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use this model definition in a stratified setting, we simply extend the tensor dimensions of the `init_state` as follows."
"To use this model definition in a stratified setting, we simply extend the tensor dimensions of the `init_state` as follows. We generate the ground truth trajectories, as well as trajectories logged at observation times, passed on to the `single_observation_model()` to generate synthetic data used further on."
]
},
{
Expand All @@ -177,15 +185,16 @@
"source": [
"n_strata = 3\n",
"\n",
"# Assume that in each town there is initially a population of 99 thousand people that are susceptible, 1 thousand infected, and 0 recovered.\n",
"# Assume that in each town there is initially a population of 99 thousand people that are susceptible,\n",
"# 1 thousand infected, and 0 recovered.\n",
"init_state = dict(S=torch.ones(n_strata) * 99, I=torch.ones(n_strata), R=torch.zeros(n_strata))\n",
"start_time = torch.tensor(0.0)\n",
"end_time = torch.tensor(6.0)\n",
"step_size = torch.tensor(0.1)\n",
"logging_times = torch.arange(start_time, end_time, step_size)\n",
"\n",
"\n",
"# We now simulate from the SIR model. Notice that the true parameters are similar to each other, but not exactly the same.\n",
"# We now simulate from the SIR model. Notice that the true parameters are similar to each other,\n",
"# but not exactly the same.\n",
"beta_true = torch.tensor([0.03, 0.04, 0.035])\n",
"gamma_true = torch.tensor([0.4, 0.385, 0.405])\n",
"sir_true = SIRDynamics(beta_true, gamma_true)\n",
Expand Down Expand Up @@ -260,7 +269,6 @@
" test_end_time, color=\"black\", linestyle=\":\"\n",
" )\n",
"\n",
"\n",
"def plot_sir_data(n_strata, colors,\n",
" sir_traj = None, logging_times = None, \n",
" sir_data = None, obs_logging_times = None, \n",
Expand All @@ -284,11 +292,6 @@
"\n",
" \n",
" for j, key in enumerate([\"S\", \"I\", \"R\"]):\n",
" # if true_traj[key].ndim > 2:\n",
" # reshaped_val = true_traj[key].squeeze().view([n_strata, num_samples,true_traj[key].shape[-1]])\n",
" # else:\n",
" # reshaped_val = true_traj[key]\n",
"\n",
" SIR_uncertainty_plot(true_logging_times, true_traj[key][i, :], color=\"black\", ax=ax[i, j], linestyle=\"dashed\")\n",
" \n",
" if plot_true_peak:\n",
Expand All @@ -300,12 +303,7 @@
" for i in range(n_strata):\n",
" \n",
" for j, key in enumerate([\"S\", \"I\", \"R\"]):\n",
" # if sir_traj[key].ndim > 2:\n",
" # _sir_traj[key] = sir_traj[key].squeeze().view([n_strata, num_samples,sir_traj[key].shape[-1]])\n",
" #reshaped_val = sir_traj[key].squeeze().view([n_strata, num_samples,sir_traj[key].shape[-1]])\n",
" SIR_uncertainty_plot(logging_times, sir_traj[key][...,0, i, :], color=colors[key], ax=ax[i, j])\n",
" # SIR_uncertainty_plot(logging_times, _sir_traj[\"I\"][i, :], color=colors[\"I\"], ax=ax[i, 1])\n",
" # SIR_uncertainty_plot(logging_times, _sir_traj[\"R\"][i, :], color=colors[\"R\"], ax=ax[i, 2])\n",
"\n",
" # Set x-axis labels\n",
" ax[i, 0].set_xlabel(\"Time (months)\")\n",
Expand All @@ -318,11 +316,6 @@
" ax[i, 1].set_title(\"Infected\")\n",
" ax[i, 2].set_title(\"Recovered\")\n",
"\n",
" \n",
" #ax[i, 1].set_ylabel(f\"Town {i}\")\n",
" #ax[i, 2].set_ylabel(f\"Town {i}\")\n",
"\n",
"\n",
" ax_right_2 = ax[i, 2].twinx() \n",
" ax_right_2.set_ylabel(f\"Town {i}\", rotation=270, labelpad=15)\n",
" ax_right_2.yaxis.set_label_position(\"right\")\n",
Expand Down

0 comments on commit b12a2c9

Please sign in to comment.