Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test convolve mode in hospitaladmissionspy #398

Merged
merged 25 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
65b248e
testing convolve mode
sbidari Aug 20, 2024
8d66ef8
Merge branch 'main' of https://github.com/CDCgov/multisignal-epi-infe…
sbidari Aug 20, 2024
17107c2
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
3a54855
update tutorial to work with convolve mode valid
sbidari Aug 21, 2024
399250f
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
9a7cbb3
update latent admissions test
sbidari Aug 21, 2024
cbff93c
update DOW tutorial for convolve mode valid
sbidari Aug 21, 2024
41b070f
update hosp model tests
sbidari Aug 21, 2024
e9130ce
create helper function for convolve and add tests
sbidari Aug 21, 2024
4737288
forgot to run precommit earlier
sbidari Aug 21, 2024
eb9e168
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
e87d742
update test for model with DOW effect
sbidari Aug 21, 2024
999d124
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 385-inc…
sbidari Aug 22, 2024
b4c5ca2
renaming helper function, add n_initialization_point
sbidari Aug 22, 2024
6840243
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 385-inc…
sbidari Aug 22, 2024
ed19002
Apply suggestions from code review
sbidari Aug 22, 2024
0599cfc
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 22, 2024
9095259
move helper function from metaclass to convolve.py
sbidari Aug 23, 2024
c69d6cc
uniformize starting point of all plots
sbidari Aug 23, 2024
c28ec02
adopt new var names
sbidari Aug 23, 2024
4bc1736
fix var names
sbidari Aug 23, 2024
ea72333
fix docstring
sbidari Aug 23, 2024
5c4395f
update n_initialization_points
sbidari Aug 23, 2024
554301b
Update pyrenew/convolve.py
damonbayer Aug 26, 2024
3efcfad
Update pyrenew/convolve.py
damonbayer Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
```

2. Next, we defined the model's components:
Expand All @@ -56,7 +56,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = metaclass.DistributionalRV(
Expand All @@ -77,14 +77,16 @@ from pyrenew.latent import (

# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size)

I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
Expand Down Expand Up @@ -201,11 +203,7 @@ hosp_model.run(
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand Down Expand Up @@ -299,11 +297,7 @@ The new model with the day-of-the-week effect can be compared to the previous mo
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand All @@ -314,10 +308,6 @@ out = hosp_model.plot_posterior(
out_dow = hosp_model_dow.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```
27 changes: 12 additions & 15 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()

# Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int[:5]
gen_int[:5], inf_hosp_int_array[:5]

# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)

axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int)
axs[1].plot(inf_hosp_int_array)
axs[1].set_title("Infection to hospital admission interval")
plt.show()
```
Expand All @@ -142,7 +142,7 @@ import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = metaclass.DistributionalRV(
Expand All @@ -168,14 +168,15 @@ from pyrenew.latent import (

# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size)
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
Expand Down Expand Up @@ -308,11 +309,7 @@ We can use the `Model` object's `plot_posterior` method to visualize the model f
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
obs_signal=daily_hosp_admits.astype(float),
)
```

Expand Down Expand Up @@ -499,7 +496,7 @@ def compute_eti(dataset, eti_prob):

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(),
idata.prior_predictive["negbinom_rv_dim_0"],
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
color="C0",
smooth=False,
Expand All @@ -508,7 +505,7 @@ az.plot_hdi(
)

az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(),
idata.prior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
color="C0",
smooth=False,
Expand All @@ -517,7 +514,7 @@ az.plot_hdi(
)

plt.scatter(
idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(),
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
Expand All @@ -533,7 +530,7 @@ And now we plot the posterior predictive distributions with a `{python} n_foreca
```{python}
# | label: fig-output-posterior-predictive-forecast
# | fig-cap: Posterior predictive admissions, including a forecast.
x_data = idata.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size()
x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
y_data = idata.posterior_predictive["negbinom_rv"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
Expand Down Expand Up @@ -564,7 +561,7 @@ plt.plot(
label="Median",
)
plt.scatter(
idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(),
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
Expand Down
14 changes: 9 additions & 5 deletions pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

import pyrenew.arrayutils as au
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.metaclass import (
RandomVariable,
SampledValue,
compute_delay_ascertained_incidence,
)


class HospitalAdmissionsSample(NamedTuple):
Expand Down Expand Up @@ -210,11 +214,11 @@ def sample(
*_,
) = self.infection_to_admission_interval_rv(**kwargs)

latent_hospital_admissions = jnp.convolve(
infection_hosp_rate.value * latent_infections.value,
latent_hospital_admissions = compute_delay_ascertained_incidence(
infection_hosp_rate.value,
latent_infections.value,
infection_to_admission_interval.value,
mode="full",
)[: latent_infections.value.shape[0]]
)

# Applying the day of the week effect. For this we need to:
# 1. Get the day of the week effect
Expand Down
33 changes: 33 additions & 0 deletions pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Callable, NamedTuple, Self, get_type_hints

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -125,6 +126,38 @@ def _assert_sample_and_rtype(
return None


def compute_delay_ascertained_incidence(
sbidari marked this conversation as resolved.
Show resolved Hide resolved
incidence_to_observation_rate: ArrayLike,
latent_incidence: ArrayLike,
incidence_to_observation_delay_interval: ArrayLike,
) -> ArrayLike:
"""
Computes incidences observed according
to a given observation rate and based
on a delay interval.

Parameters
----------
incidence_to_observation_rate: ArrayLike
The rate at which latent incidences are observed.
sbidari marked this conversation as resolved.
Show resolved Hide resolved
latent_incidence: ArrayLike
Incidence values based on the true underlying process.
incidence_to_observation_delay_interval: ArrayLike
Pmf of delay interval between incidence to observation.
sbidari marked this conversation as resolved.
Show resolved Hide resolved

Returns
--------
ArrayLike
The incidence after the observation delay.
sbidari marked this conversation as resolved.
Show resolved Hide resolved
"""
delay_obs_incidence = jnp.convolve(
incidence_to_observation_rate * latent_incidence,
incidence_to_observation_delay_interval,
mode="valid",
)
return delay_obs_incidence


class SampledValue(NamedTuple):
"""
A container for a value sampled from a RandomVariable.
Expand Down
49 changes: 49 additions & 0 deletions test/test_incidence_observed_with_delay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# numpydoc ignore=GL08

import jax.numpy as jnp
import pytest
from numpy.testing import assert_array_equal

from pyrenew.metaclass import compute_delay_ascertained_incidence


@pytest.mark.parametrize(
["obs_rate", "latent_incidence", "delay_interval", "expected_output"],
[
[
jnp.array([1.0]),
jnp.array([1.0, 2.0, 3.0]),
jnp.array([1.0]),
jnp.array([1.0, 2.0, 3.0]),
],
[
jnp.array([1.0, 0.1, 1.0]),
jnp.array([1.0, 2.0, 3.0]),
jnp.array([1.0]),
jnp.array([1.0, 0.2, 3.0]),
],
[
jnp.array([1.0]),
jnp.array([1.0, 2.0, 3.0]),
jnp.array([0.5, 0.5]),
jnp.array([1.5, 2.5]),
],
[
jnp.array([1.0]),
jnp.array([0, 2.0, 4.0]),
jnp.array([0.25, 0.5, 0.25]),
jnp.array([2]),
],
],
)
def test(obs_rate, latent_incidence, delay_interval, expected_output):
"""
Tests for helper function to compute
incidence observed with a delay
"""
result = compute_delay_ascertained_incidence(
obs_rate,
latent_incidence,
delay_interval,
)
assert_array_equal(result, expected_output)
61 changes: 34 additions & 27 deletions test/test_latent_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,35 @@ def test_admissions_sample():
# Generating Rt and Infections to compute the hospital admissions

rt = SimpleRt()
n_steps = 30

with numpyro.handlers.seed(rng_seed=223):
sim_rt = rt(n=30)[0].value
sim_rt = rt(n=n_steps)[0].value

gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1])
i0 = 10 * jnp.ones_like(gen_int)

inf_hosp_int_array = jnp.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0.25,
0.5,
0.1,
0.1,
0.05,
]
)
i0 = 10 * jnp.ones_like(inf_hosp_int_array)
inf1 = Infections()

with numpyro.handlers.seed(rng_seed=223):
Expand All @@ -37,28 +59,7 @@ def test_admissions_sample():
# Testing the hospital admissions
inf_hosp = DeterministicPMF(
name="inf_hosp",
value=jnp.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0.25,
0.5,
0.1,
0.1,
0.05,
]
),
value=inf_hosp_int_array,
)

hosp1 = HospitalAdmissions(
Expand All @@ -69,10 +70,16 @@ def test_admissions_sample():
)

with numpyro.handlers.seed(rng_seed=223):
sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0])
sim_hosp_1 = hosp1(
latent_infections=SampledValue(
value=jnp.hstack(
[i0, inf_sampled1.post_initialization_infections.value]
)
)
)

testing.assert_array_less(
sim_hosp_1.latent_hospital_admissions.value,
sim_hosp_1.latent_hospital_admissions.value[-n_steps:],
inf_sampled1[0].value,
)
inf_hosp2 = jnp.ones(30)
Expand Down
Loading