Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test convolve mode in hospitaladmissions.py #398

Merged
merged 25 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
65b248e
testing convolve mode
sbidari Aug 20, 2024
8d66ef8
Merge branch 'main' of https://github.com/CDCgov/multisignal-epi-infe…
sbidari Aug 20, 2024
17107c2
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
3a54855
update tutorial to work with convolve mode valid
sbidari Aug 21, 2024
399250f
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
9a7cbb3
update latent admissions test
sbidari Aug 21, 2024
cbff93c
update DOW tutorial for convolve mode valid
sbidari Aug 21, 2024
41b070f
update hosp model tests
sbidari Aug 21, 2024
e9130ce
create helper function for convolve and add tests
sbidari Aug 21, 2024
4737288
forgot to run precommit earlier
sbidari Aug 21, 2024
eb9e168
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
e87d742
update test for model with DOW effect
sbidari Aug 21, 2024
999d124
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 385-inc…
sbidari Aug 22, 2024
b4c5ca2
renaming helper function, add n_initialization_point
sbidari Aug 22, 2024
6840243
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 385-inc…
sbidari Aug 22, 2024
ed19002
Apply suggestions from code review
sbidari Aug 22, 2024
0599cfc
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 22, 2024
9095259
move helper function from metaclass to convolve.py
sbidari Aug 23, 2024
c69d6cc
uniformize starting point of all plots
sbidari Aug 23, 2024
c28ec02
adopt new var names
sbidari Aug 23, 2024
4bc1736
fix var names
sbidari Aug 23, 2024
ea72333
fix docstring
sbidari Aug 23, 2024
5c4395f
update n_initialization_points
sbidari Aug 23, 2024
554301b
Update pyrenew/convolve.py
damonbayer Aug 26, 2024
3efcfad
Update pyrenew/convolve.py
damonbayer Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
```

2. Next, we defined the model's components:
Expand All @@ -56,7 +56,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = metaclass.DistributionalRV(
Expand All @@ -77,14 +77,16 @@ from pyrenew.latent import (

# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size)

I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
Expand Down Expand Up @@ -201,11 +203,7 @@ hosp_model.run(
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand Down Expand Up @@ -299,11 +297,7 @@ The new model with the day-of-the-week effect can be compared to the previous mo
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand All @@ -314,10 +308,6 @@ out = hosp_model.plot_posterior(
out_dow = hosp_model_dow.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```
41 changes: 15 additions & 26 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()

# Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int[:5]
gen_int[:5], inf_hosp_int_array[:5]

# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)

axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int)
axs[1].plot(inf_hosp_int_array)
axs[1].set_title("Infection to hospital admission interval")
plt.show()
```
Expand All @@ -142,7 +142,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = metaclass.DistributionalRV(
Expand All @@ -168,14 +168,15 @@ from pyrenew.latent import (

# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size)
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
Expand Down Expand Up @@ -308,11 +309,7 @@ We can use the `Model` object's `plot_posterior` method to visualize the model f
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand Down Expand Up @@ -381,22 +378,14 @@ axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
```

We can look at individual draws from the posterior distribution of latent infections:

```{python}
# | label: fig-output-infections
# | fig-cap: Latent infections
out2 = hosp_model.plot_posterior(
var="all_latent_infections", ylab="Latent Infections"
)
```

We can also look at credible intervals for the posterior distribution of latent infections:

```{python}
# | label: fig-output-infections-distribution
# | fig-cap: Posterior Latent Infections
x_data = idata.posterior["all_latent_infections_dim_0"]
x_data = (
idata.posterior["all_latent_infections_dim_0"] - n_initialization_points
)
y_data = idata.posterior["all_latent_infections"]

fig, axes = plt.subplots(figsize=(6, 5))
Expand Down Expand Up @@ -499,7 +488,7 @@ def compute_eti(dataset, eti_prob):

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(),
idata.prior_predictive["negbinom_rv_dim_0"],
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
color="C0",
smooth=False,
Expand All @@ -508,7 +497,7 @@ az.plot_hdi(
)

az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(),
idata.prior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
color="C0",
smooth=False,
Expand All @@ -517,7 +506,7 @@ az.plot_hdi(
)

plt.scatter(
idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(),
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
Expand All @@ -533,7 +522,7 @@ And now we plot the posterior predictive distributions with a `{python} n_foreca
```{python}
# | label: fig-output-posterior-predictive-forecast
# | fig-cap: Posterior predictive admissions, including a forecast.
x_data = idata.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size()
x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
y_data = idata.posterior_predictive["negbinom_rv"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
Expand Down Expand Up @@ -564,7 +553,7 @@ plt.plot(
label="Median",
)
plt.scatter(
idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(),
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
Expand Down
42 changes: 42 additions & 0 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,45 @@ def _new_scanner(
return latest, (new_val, m_net1)

return _new_scanner


def compute_delay_ascertained_incidence(
    p_observed_given_incident: ArrayLike,
    latent_incidence: ArrayLike,
    delay_incidence_to_observation_delay_pmf: ArrayLike,
) -> ArrayLike:
    """
    Computes incidences observed according
    to a given observation rate and based
    on a delay interval.

    Parameters
    ----------
    p_observed_given_incident: ArrayLike
        The rate at which latent incident counts translate into observed
        counts. For example, setting ``p_observed_given_incident=0.001``
        when the incident counts are infections and the observed counts are
        reported hospital admissions could be used to model a disease and
        population for which the probability of a (reported) hospital
        admission given infection is 0.001.
    latent_incidence: ArrayLike
        Incidence values based on the true underlying process.
    delay_incidence_to_observation_delay_pmf: ArrayLike
        Probability mass function of delay interval from incidence to
        observation, where the :math:`i^{th}` entry (0-indexed) represents a
        delay of :math:`1+i` time units, i.e.
        ``delay_incidence_to_observation_delay_pmf[0]`` represents
        the fraction of observations that are delayed 1 time unit,
        ``delay_incidence_to_observation_delay_pmf[1]`` represents the
        fraction that are delayed 2 time units, et cetera.

    Returns
    -------
    ArrayLike
        The predicted timeseries of delayed observations. Because the
        convolution uses ``mode="valid"``, only positions where the delay
        PMF fully overlaps the incidence series are returned: for an
        incidence series of length ``n`` and a PMF of length ``m <= n``
        the result has ``n - m + 1`` entries.
    """
    # Scale the latent incidence by the ascertainment probability first,
    # then spread it over time with the delay distribution.
    # mode="valid" drops the leading entries that would otherwise be
    # computed from an incomplete history of incidence.
    delay_obs_incidence = jnp.convolve(
        p_observed_given_incident * latent_incidence,
        delay_incidence_to_observation_delay_pmf,
        mode="valid",
    )
    return delay_obs_incidence
9 changes: 5 additions & 4 deletions pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpyro

import pyrenew.arrayutils as au
from pyrenew.convolve import compute_delay_ascertained_incidence
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable, SampledValue

Expand Down Expand Up @@ -210,11 +211,11 @@ def sample(
*_,
) = self.infection_to_admission_interval_rv(**kwargs)

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate.value * latent_infections.value,
latent_hospital_admissions = compute_delay_ascertained_incidence(
infection_hosp_rate.value,
latent_infections.value,
infection_to_admission_interval.value,
mode="full",
)[: latent_infections.value.shape[0]]
)

# Applying the day of the week effect. For this we need to:
# 1. Get the day of the week effect
Expand Down
49 changes: 49 additions & 0 deletions test/test_incidence_observed_with_delay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# numpydoc ignore=GL08

import jax.numpy as jnp
import pytest
from numpy.testing import assert_array_equal

from pyrenew.convolve import compute_delay_ascertained_incidence


@pytest.mark.parametrize(
    ["obs_rate", "latent_incidence", "delay_interval", "expected_output"],
    [
        # Unit rate and a one-point delay PMF: incidence passes through.
        (
            jnp.array([1.0]),
            jnp.array([1.0, 2.0, 3.0]),
            jnp.array([1.0]),
            jnp.array([1.0, 2.0, 3.0]),
        ),
        # Time-varying rate scales each incidence value elementwise.
        (
            jnp.array([1.0, 0.1, 1.0]),
            jnp.array([1.0, 2.0, 3.0]),
            jnp.array([1.0]),
            jnp.array([1.0, 0.2, 3.0]),
        ),
        # Two-point delay PMF: output is one entry shorter than the input.
        (
            jnp.array([1.0]),
            jnp.array([1.0, 2.0, 3.0]),
            jnp.array([0.5, 0.5]),
            jnp.array([1.5, 2.5]),
        ),
        # PMF as long as the incidence series: a single output value.
        (
            jnp.array([1.0]),
            jnp.array([0, 2.0, 4.0]),
            jnp.array([0.25, 0.5, 0.25]),
            jnp.array([2]),
        ),
    ],
)
def test(obs_rate, latent_incidence, delay_interval, expected_output):
    """
    Check that the delay-ascertained incidence helper
    returns the expected values for each parametrized
    combination of rate, incidence and delay PMF.
    """
    observed = compute_delay_ascertained_incidence(
        obs_rate,
        latent_incidence,
        delay_interval,
    )
    assert_array_equal(observed, expected_output)
Loading