Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wastewater observation process #310

Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
272819f
add ww obs process
sbidari Jan 23, 2025
0106d2a
Merge branch 'main' into 300-update-wastewaterobservationprocess-clas…
sbidari Jan 23, 2025
7b7a239
pre-commit
sbidari Jan 23, 2025
de0f638
update build pyrenew model
sbidari Jan 23, 2025
1c5c249
fix output assignment in test latent infection process
sbidari Jan 23, 2025
6fba2fb
include some rvs as placeholders
sbidari Jan 23, 2025
c589679
change prior
sbidari Jan 23, 2025
3607397
use priors from stan model
sbidari Jan 23, 2025
c8f5a19
fix typo
sbidari Jan 23, 2025
bf32430
move data args to sample method
sbidari Jan 24, 2025
4025ebc
pre-commit
sbidari Jan 24, 2025
d50cdd1
update prod_priors.py
sbidari Jan 24, 2025
aaa30a5
sync
sbidari Jan 24, 2025
d8bcadd
update generate_predictive.py
sbidari Jan 27, 2025
c502005
pre-commit
sbidari Jan 27, 2025
d88cc3d
move model constant to prior as per dhm suggestion
sbidari Jan 27, 2025
380f285
make get_viral_trajectory a class method
sbidari Jan 28, 2025
b71b551
remove outdated call to utils
sbidari Jan 28, 2025
2941c31
Merge branch 'main' into 300-update-wastewaterobservationprocess-clas…
sbidari Jan 28, 2025
783425a
output assignment
sbidari Jan 28, 2025
45548c5
Merge branch '300-update-wastewaterobservationprocess-class-in-pyrene…
sbidari Jan 28, 2025
a5672cc
state -> population
sbidari Jan 28, 2025
ccee24e
code review suggestions
sbidari Jan 28, 2025
c31179c
code review suggestions
sbidari Jan 28, 2025
8f66b5f
update distributional variable sample call
sbidari Jan 28, 2025
1ec9038
Apply suggestions from code review
sbidari Jan 29, 2025
dbf1a82
Merge branch '300-update-wastewaterobservationprocess-class-in-pyrene…
sbidari Jan 29, 2025
1ceead2
pre-commit
sbidari Jan 29, 2025
7f1eb62
Apply suggestions from code review
sbidari Jan 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion demos/ww_model/ww_model_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF

from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
from pyrenew_hew.ww_site_level_dynamics_model import ww_site_level_dynamics_model
from pyrenew_hew.utils import convert_to_logmean_log_sd
import pyrenew_hew.plotting as plotting

numpyro.set_host_device_count(4)
Expand All @@ -29,6 +28,14 @@ We will use the data used in the `wwinference` [vignette](https://github.com/CDC
with open("data/fit/stan_data.json","r") as file:
stan_data = json.load(file)

# define functions called later
def convert_to_logmean_log_sd(mean, sd):
logmean = jnp.log(
jnp.power(mean, 2) / jnp.sqrt(jnp.power(sd, 2) + jnp.power(mean, 2))
)
logsd = jnp.sqrt(jnp.log(1 + (jnp.power(sd, 2) / jnp.power(mean, 2))))
return logmean, logsd

```

```{python}
Expand Down
12 changes: 10 additions & 2 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,16 @@
ihr_rv=priors["ihr_rv"],
)

# placeholder
my_wastewater_obs_model = WastewaterObservationProcess()
my_wastewater_obs_model = WastewaterObservationProcess(

Check warning on line 112 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L112

Added line #L112 was not covered by tests
t_peak_rv=priors["t_peak_rv"],
duration_shed_after_peak_rv=priors["duration_shed_after_peak_rv"],
log10_genome_per_inf_ind_rv=priors["log10_genome_per_inf_ind_rv"],
mode_sigma_ww_site_rv=priors["mode_sigma_ww_site_rv"],
sd_log_sigma_ww_site_rv=priors["sd_log_sigma_ww_site_rv"],
mode_sd_ww_site_rv=priors["mode_sd_ww_site_rv"],
max_shed_interval=priors["max_shed_interval"],
ww_ml_produced_per_day=priors["ww_ml_produced_per_day"],
)

my_model = PyrenewHEWModel(
population_size=population_size,
Expand Down
2 changes: 1 addition & 1 deletion pipelines/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def generate_and_save_predictions(
data=forecast_data,
sample_ed_visits=True,
sample_hospital_admissions=True,
sample_wastewater=True,
sample_wastewater=False,
)

idata = az.from_numpyro(
Expand Down
27 changes: 27 additions & 0 deletions pipelines/priors/prod_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,30 @@
hosp_admit_neg_bin_concentration_rv = DistributionalVariable(
"hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2)
)

t_peak_rv = DistributionalVariable("t_peak", dist.TruncatedNormal(5, 1, low=0))

duration_shed_after_peak_rv = DistributionalVariable(
"durtion_shed_after_peak", dist.TruncatedNormal(12, 3, low=0)
)

log10_genome_per_inf_ind_rv = DistributionalVariable(
"log10_genome_per_inf_ind", dist.Normal(12, 2)
)

mode_sigma_ww_site_rv = DistributionalVariable(
"mode_sigma_ww_site",
dist.TruncatedNormal(1, 1, low=0),
)

sd_log_sigma_ww_site_rv = DistributionalVariable(
"sd_log_sigma_ww_site", dist.TruncatedNormal(0, 0.693, low=0)
)

mode_sd_ww_site_rv = DistributionalVariable(
"mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0)
)

# model constants related to wastewater obs process
ww_ml_produced_per_day = 227000
max_shed_interval = 26
260 changes: 248 additions & 12 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# numpydoc ignore=GL08
import datetime

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
Expand Down Expand Up @@ -254,7 +255,7 @@
numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt)
numpyro.deterministic("latent_infections", latent_infections)

return latent_infections
return latent_infections, latent_infections_subpop


class EDVisitObservationProcess(RandomVariable):
Expand Down Expand Up @@ -466,18 +467,232 @@

class WastewaterObservationProcess(RandomVariable):
"""
Placeholder for wastewater obs process
Observe and/or predict wastewater concentration
"""

def __init__(self) -> None:
pass

def sample(self):
pass
def __init__(
self,
t_peak_rv: RandomVariable,
duration_shed_after_peak_rv: RandomVariable,
log10_genome_per_inf_ind_rv: RandomVariable,
mode_sigma_ww_site_rv: RandomVariable,
sd_log_sigma_ww_site_rv: RandomVariable,
mode_sd_ww_site_rv: RandomVariable,
max_shed_interval: float,
ww_ml_produced_per_day: float,
) -> None:
self.t_peak_rv = t_peak_rv
self.duration_shed_after_peak_rv = duration_shed_after_peak_rv
self.log10_genome_per_inf_ind_rv = log10_genome_per_inf_ind_rv
self.mode_sigma_ww_site_rv = mode_sigma_ww_site_rv
self.sd_log_sigma_ww_site_rv = sd_log_sigma_ww_site_rv
self.mode_sd_ww_site_rv = mode_sd_ww_site_rv
self.max_shed_interval = max_shed_interval
self.ww_ml_produced_per_day = ww_ml_produced_per_day

Check warning on line 491 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L484-L491

Added lines #L484 - L491 were not covered by tests

def validate(self):
pass

@staticmethod
def normed_shedding_cdf(
time: ArrayLike, t_p: float, t_d: float, log_base: float
) -> ArrayLike:
"""
calculates fraction of total fecal RNA shedding that has occurred
by a given time post infection.


Parameters
----------
time: ArrayLike
Time points to calculate the CDF of viral shedding.
t_p : float
Time (in days) from infection to peak shedding.
t_d: float
Time (in days) from peak shedding to the end of shedding.
log_base: float
Log base used for the shedding kinetics function.


Returns
-------
ArrayLike
Normalized CDF values of viral shedding at each time point.
"""
norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1)

Check warning on line 522 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L522

Added line #L522 was not covered by tests

def ad_pre(x):
return (

Check warning on line 525 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L524-L525

Added lines #L524 - L525 were not covered by tests
t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p)
- x
)

def ad_post(x):
return (

Check warning on line 531 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L530-L531

Added lines #L530 - L531 were not covered by tests
-t_d
/ jnp.log(log_base)
* jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d)))
- x
)

return (

Check warning on line 538 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L538

Added line #L538 was not covered by tests
jnp.where(
time < t_p + t_d,
jnp.where(
time < t_p,
ad_pre(time) - ad_pre(0),
ad_pre(t_p) - ad_pre(0) + ad_post(time) - ad_post(t_p),
),
norm_const,
)
/ norm_const
)

def get_viral_trajectory(
self,
tpeak: float,
duration_shed_after_peak: float,
) -> ArrayLike:
"""
Computes the probability mass function (PMF) of
daily viral shedding based on a normalized CDF.

Parameters
----------
tpeak: float
Time (in days) from infection to peak viral load in shedding.
duration_shed_after_peak: float
Duration (in days) of detectable viral shedding after the peak.

Returns
-------
ArrayLike
Normalized daily viral shedding PMF
"""
daily_shedding_pmf = self.normed_shedding_cdf(

Check warning on line 572 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L572

Added line #L572 was not covered by tests
jnp.arange(1, self.max_shed_interval),
tpeak,
duration_shed_after_peak,
10,
) - self.normed_shedding_cdf(
jnp.arange(0, self.max_shed_interval - 1),
tpeak,
duration_shed_after_peak,
10,
)
return daily_shedding_pmf

Check warning on line 583 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L583

Added line #L583 was not covered by tests

def sample(
self,
latent_infections_subpop: ArrayLike,
data_observed: ArrayLike,
n_datapoints: int,
ww_uncensored: ArrayLike,
ww_censored: ArrayLike,
ww_sampled_lab_sites: ArrayLike,
ww_sampled_subpops: ArrayLike,
ww_sampled_times: ArrayLike,
ww_log_lod: ArrayLike,
lab_site_to_subpop_map: ArrayLike,
n_ww_lab_sites: int,
shedding_offset: float,
pop_fraction: ArrayLike,
):
t_peak = self.t_peak_rv()
dur_shed = self.duration_shed_after_peak_rv()
viral_kinetics = self.get_viral_trajectory(t_peak, dur_shed)

Check warning on line 603 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L601-L603

Added lines #L601 - L603 were not covered by tests

def batch_colvolve_fn(m):
return jnp.convolve(m, viral_kinetics, mode="valid")

Check warning on line 606 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L605-L606

Added lines #L605 - L606 were not covered by tests

model_net_inf_ind_shedding = jax.vmap(

Check warning on line 608 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L608

Added line #L608 was not covered by tests
batch_colvolve_fn, in_axes=1, out_axes=1
)(jnp.atleast_2d(latent_infections_subpop))[-n_datapoints:, :]
numpyro.deterministic(

Check warning on line 611 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L611

Added line #L611 was not covered by tests
"model_net_inf_ind_shedding", model_net_inf_ind_shedding
)

log10_genome_per_inf_ind = self.log10_genome_per_inf_ind_rv()
expected_obs_viral_genomes = (

Check warning on line 616 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L615-L616

Added lines #L615 - L616 were not covered by tests
jnp.log(10) * log10_genome_per_inf_ind
+ jnp.log(model_net_inf_ind_shedding + shedding_offset)
- jnp.log(self.ww_ml_produced_per_day)
)
numpyro.deterministic(

Check warning on line 621 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L621

Added line #L621 was not covered by tests
"expected_obs_viral_genomes", expected_obs_viral_genomes
)

mode_sigma_ww_site = self.mode_sigma_ww_site_rv()
sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv()
mode_sd_ww_site = self.mode_sd_ww_site_rv()

Check warning on line 627 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L625-L627

Added lines #L625 - L627 were not covered by tests

mode_ww_site_rv = DistributionalVariable(

Check warning on line 629 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L629

Added line #L629 was not covered by tests
"mode_ww_site",
dist.Normal(0, mode_sd_ww_site),
reparam=LocScaleReparam(0),
) # lab-site specific variation

sigma_ww_site_rv = TransformedVariable(

Check warning on line 635 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L635

Added line #L635 was not covered by tests
"sigma_ww_site",
DistributionalVariable(
"log_sigma_ww_site",
dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site),
reparam=LocScaleReparam(0),
),
transforms=transformation.ExpTransform(),
)

with numpyro.plate("n_ww_lab_sites", n_ww_lab_sites):
mode_ww_site = mode_ww_site_rv()
sigma_ww_site = sigma_ww_site_rv()

Check warning on line 647 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L645-L647

Added lines #L645 - L647 were not covered by tests

# multiply the expected observed genomes by the site-specific multiplier at that sampling time
expected_obs_log_v_site = (

Check warning on line 650 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L650

Added line #L650 was not covered by tests
expected_obs_viral_genomes[ww_sampled_times, ww_sampled_subpops]
+ mode_ww_site[ww_sampled_lab_sites]
)

DistributionalVariable(

Check warning on line 655 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L655

Added line #L655 was not covered by tests
"log_conc_obs",
dist.Normal(
loc=expected_obs_log_v_site[ww_uncensored],
scale=sigma_ww_site[ww_sampled_lab_sites[ww_uncensored]],
),
).sample(
obs=(
data_observed[ww_uncensored]
if data_observed is not None
else None
),
)

if ww_censored.shape[0] != 0:
log_cdf_values = dist.Normal(

Check warning on line 670 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L669-L670

Added lines #L669 - L670 were not covered by tests
loc=expected_obs_log_v_site[ww_censored],
scale=sigma_ww_site[ww_sampled_lab_sites[ww_censored]],
).log_cdf(ww_log_lod[ww_censored])
numpyro.factor("log_prob_censored", log_cdf_values.sum())

Check warning on line 674 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L674

Added line #L674 was not covered by tests

# Predict site and population level wastewater concentrations
site_log_ww_conc = DistributionalVariable(

Check warning on line 677 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L677

Added line #L677 was not covered by tests
"site_log_ww_conc",
dist.Normal(
loc=expected_obs_viral_genomes[:, lab_site_to_subpop_map]
+ mode_ww_site,
scale=sigma_ww_site,
),
)()

population_latent_viral_genome_conc = jax.scipy.special.logsumexp(

Check warning on line 686 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L686

Added line #L686 was not covered by tests
expected_obs_viral_genomes, axis=1, b=pop_fraction
)
numpyro.deterministic(

Check warning on line 689 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L689

Added line #L689 was not covered by tests
"population_latent_viral_genome_conc",
population_latent_viral_genome_conc,
)

return site_log_ww_conc, population_latent_viral_genome_conc

Check warning on line 694 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L694

Added line #L694 was not covered by tests


class PyrenewHEWModel(Model): # numpydoc ignore=GL08
def __init__(
Expand Down Expand Up @@ -505,16 +720,19 @@
sample_wastewater: bool = False,
) -> dict[str, ArrayLike]: # numpydoc ignore=GL08
n_init_days = self.latent_infection_process_rv.n_initialization_points
latent_infections = self.latent_infection_process_rv(
n_days_post_init=data.n_days_post_init,
latent_infections, latent_infections_subpop = (

Check warning on line 723 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L723

Added line #L723 was not covered by tests
self.latent_infection_process_rv(
n_days_post_init=data.n_days_post_init,
)
)
first_latent_infection_dow = (
data.first_data_date_overall - datetime.timedelta(days=n_init_days)
).weekday()

observed_ed_visits = None
observed_admissions = None
observed_wastewater = None
site_level_observed_wastewater = None
population_level_observed_wastewater = None

Check warning on line 735 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L734-L735

Added lines #L734 - L735 were not covered by tests
sbidari marked this conversation as resolved.
Show resolved Hide resolved

iedr = None

Expand All @@ -537,10 +755,28 @@
iedr=iedr,
)
if sample_wastewater:
observed_wastewater = self.wastewater_obs_process_rv()
(

Check warning on line 758 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L758

Added line #L758 was not covered by tests
site_level_observed_wastewater,
population_level_observed_wastewater,
sbidari marked this conversation as resolved.
Show resolved Hide resolved
) = self.wastewater_obs_process_rv(
latent_infections=latent_infections,
sbidari marked this conversation as resolved.
Show resolved Hide resolved
latent_infections_subpop=latent_infections_subpop,
data_observed=data.data_observed_disease_wastewater,
n_datapoints=data.n_wastewater_datapoints,
ww_uncensored=None, # placeholder
ww_censored=None, # placeholder
ww_sampled_lab_sites=None, # placeholder
ww_sampled_subpops=None, # placeholder
ww_sampled_times=None, # placeholder
ww_log_lod=None, # placeholder
lab_site_to_subpop_map=None, # placeholder
n_ww_lab_sites=None, # placeholder
shedding_offset=1e-8,
)

return {
"ed_visits": observed_ed_visits,
"hospital_admissions": observed_admissions,
"wasewater": observed_wastewater,
"site_level_wastewater_conc": site_level_observed_wastewater,
"population_level_wastewater_conc": population_level_observed_wastewater,
sbidari marked this conversation as resolved.
Show resolved Hide resolved
}
Loading
Loading