Skip to content

Commit

Permalink
Test convolve mode in hospitaladmissionspy (#398)
Browse files Browse the repository at this point in the history
* testing convolve mode

* update tutorial to work with convolve mode valid

* update latent admissions test

* update DOW tutorial for convolve mode valid

* update hosp model tests

* create helper function for convolve and add tests

* forgot to run precommit earlier

* update test for model with DOW effect

* renaming helper function, add n_initialization_point

* Apply suggestions from code review

Co-authored-by: Dylan H. Morris <[email protected]>

* move helper function from metaclass to convolve.py

* uniformize starting point of all plots

* adopt new var names

* fix var names

* fix docstring

* update n_initialization_points

* Update pyrenew/convolve.py

* Update pyrenew/convolve.py

---------

Co-authored-by: Dylan H. Morris <[email protected]>
Co-authored-by: Damon Bayer <[email protected]>
  • Loading branch information
3 people authored Aug 26, 2024
1 parent 769712b commit 1a86104
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 137 deletions.
26 changes: 8 additions & 18 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
```

2. Next, we defined the model's components:
Expand All @@ -56,7 +56,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int_array
)
hosp_rate = metaclass.DistributionalRV(
Expand All @@ -77,14 +77,16 @@ from pyrenew.latent import (
# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
Expand Down Expand Up @@ -201,11 +203,7 @@ hosp_model.run(
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand Down Expand Up @@ -299,11 +297,7 @@ The new model with the day-of-the-week effect can be compared to the previous mo
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand All @@ -314,10 +308,6 @@ out = hosp_model.plot_posterior(
out_dow = hosp_model_dow.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```
41 changes: 15 additions & 26 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
# Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int[:5]
gen_int[:5], inf_hosp_int_array[:5]
# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)
axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int)
axs[1].plot(inf_hosp_int_array)
axs[1].set_title("Infection to hospital admission interval")
plt.show()
```
Expand All @@ -142,7 +142,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int_array
)
hosp_rate = metaclass.DistributionalRV(
Expand All @@ -168,14 +168,15 @@ from pyrenew.latent import (
# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
Expand Down Expand Up @@ -308,11 +309,7 @@ We can use the `Model` object's `plot_posterior` method to visualize the model f
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand Down Expand Up @@ -381,22 +378,14 @@ axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
```

We can look at individual draws from the posterior distribution of latent infections:

```{python}
# | label: fig-output-infections
# | fig-cap: Latent infections
out2 = hosp_model.plot_posterior(
var="all_latent_infections", ylab="Latent Infections"
)
```

We can also look at credible intervals for the posterior distribution of latent infections:

```{python}
# | label: fig-output-infections-distribution
# | fig-cap: Posterior Latent Infections
x_data = idata.posterior["all_latent_infections_dim_0"]
x_data = (
idata.posterior["all_latent_infections_dim_0"] - n_initialization_points
)
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6, 5))
Expand Down Expand Up @@ -499,7 +488,7 @@ def compute_eti(dataset, eti_prob):
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(),
idata.prior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
color="C0",
smooth=False,
Expand All @@ -508,7 +497,7 @@ az.plot_hdi(
)
az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(),
idata.prior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
color="C0",
smooth=False,
Expand All @@ -517,7 +506,7 @@ az.plot_hdi(
)
plt.scatter(
idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(),
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
Expand All @@ -533,7 +522,7 @@ And now we plot the posterior predictive distributions with a `{python} n_foreca
```{python}
# | label: fig-output-posterior-predictive-forecast
# | fig-cap: Posterior predictive admissions, including a forecast.
x_data = idata.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size()
x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
y_data = idata.posterior_predictive["negbinom_rv"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
Expand Down Expand Up @@ -564,7 +553,7 @@ plt.plot(
label="Median",
)
plt.scatter(
idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(),
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
Expand Down
42 changes: 42 additions & 0 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,45 @@ def _new_scanner(
return latest, (new_val, m_net1)

return _new_scanner


def compute_delay_ascertained_incidence(
p_observed_given_incident: ArrayLike,
latent_incidence: ArrayLike,
delay_incidence_to_observation_pmf: ArrayLike,
) -> ArrayLike:
"""
Computes incidences observed according
to a given observation rate and based
on a delay interval.
Parameters
----------
p_observed_given_incident: ArrayLike
The rate at which latent incident counts translate into observed counts.
For example, setting ``p_observed_given_incident=0.001``
when the incident counts are infections and the observed counts are
reported hospital admissions could be used to model disease and population
for which the probability of a latent infection leading to a reported
hospital admission is 0.001.
latent_incidence: ArrayLike
Incidence values based on the true underlying process.
delay_incidence_to_observation_pmf: ArrayLike
Probability mass function of delay interval from incidence to observation,
where the :math:`i`\th entry represents a delay of :math:`i`
time units, i.e. ``delay_incidence_to_observation_pmf[0]`` represents
the fraction of observations that are delayed 0 time unit,
``delay_incidence_to_observation_pmf[1]`` represents the fraction
that are delayed 1 time units, et cetera.
Returns
--------
ArrayLike
The predicted timeseries of delayed observations.
"""
delay_obs_incidence = jnp.convolve(
p_observed_given_incident * latent_incidence,
delay_incidence_to_observation_pmf,
mode="valid",
)
return delay_obs_incidence
9 changes: 5 additions & 4 deletions pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpyro

import pyrenew.arrayutils as au
from pyrenew.convolve import compute_delay_ascertained_incidence
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable, SampledValue

Expand Down Expand Up @@ -210,11 +211,11 @@ def sample(
*_,
) = self.infection_to_admission_interval_rv(**kwargs)

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate.value * latent_infections.value,
latent_hospital_admissions = compute_delay_ascertained_incidence(
infection_hosp_rate.value,
latent_infections.value,
infection_to_admission_interval.value,
mode="full",
)[: latent_infections.value.shape[0]]
)

# Applying the day of the week effect. For this we need to:
# 1. Get the day of the week effect
Expand Down
49 changes: 49 additions & 0 deletions test/test_incidence_observed_with_delay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# numpydoc ignore=GL08

import jax.numpy as jnp
import pytest
from numpy.testing import assert_array_equal

from pyrenew.convolve import compute_delay_ascertained_incidence


@pytest.mark.parametrize(
["obs_rate", "latent_incidence", "delay_interval", "expected_output"],
[
[
jnp.array([1.0]),
jnp.array([1.0, 2.0, 3.0]),
jnp.array([1.0]),
jnp.array([1.0, 2.0, 3.0]),
],
[
jnp.array([1.0, 0.1, 1.0]),
jnp.array([1.0, 2.0, 3.0]),
jnp.array([1.0]),
jnp.array([1.0, 0.2, 3.0]),
],
[
jnp.array([1.0]),
jnp.array([1.0, 2.0, 3.0]),
jnp.array([0.5, 0.5]),
jnp.array([1.5, 2.5]),
],
[
jnp.array([1.0]),
jnp.array([0, 2.0, 4.0]),
jnp.array([0.25, 0.5, 0.25]),
jnp.array([2]),
],
],
)
def test(obs_rate, latent_incidence, delay_interval, expected_output):
"""
Tests for helper function to compute
incidence observed with a delay
"""
result = compute_delay_ascertained_incidence(
obs_rate,
latent_incidence,
delay_interval,
)
assert_array_equal(result, expected_output)
Loading

0 comments on commit 1a86104

Please sign in to comment.