From f652b4295f1cffb9904972f64209d11fd389e4dd Mon Sep 17 00:00:00 2001
From: Damon Bayer
Date: Wed, 29 Jan 2025 16:57:45 -0600
Subject: [PATCH] Squashed commit of the following:

commit 39c003dd52d911c3c25e7b71f80c9549507418b0
Author: Subekshya Bidari <37636707+sbidari@users.noreply.github.com>
Date:   Wed Jan 29 16:20:53 2025 -0500

    wastewater observation process (#310)

    * add ww obs process
    * pre-commit
    * update build pyrenew model
    * fix output assignment in test latent infection process
    * include some rvs as placeholders
    * change prior
    * use priors from stan model
    * fix typo
    * move data args to sample method
    * pre-commit
    * update prod_priors.py
    * sync
    * update generate_predictive.py
    * pre-commit
    * move model constant to prior as per dhm suggestion
    * make get_viral_trajectory a class method
    * remove outdated call to utils
    * output assignment
    * state -> population
    * code review suggestions
    * code review suggestions
    * update distributional variable sample call
    * Apply suggestions from code review

    Co-authored-by: Dylan H. Morris

    * pre-commit
    * Apply suggestions from code review

    Co-authored-by: Dylan H. Morris

    ---------

    Co-authored-by: Dylan H. Morris

commit df3648d21fbbbe4879b46f52fcb12239831f71e0
Author: Dylan H. Morris
Date:   Tue Jan 28 17:20:34 2025 +0000

    Improvements and bug fixes for epiweekly other hubverse tables and plots (#312)

---
 .pre-commit-config.yaml                     |   2 +-
 demos/ww_model/ww_model_demo.qmd            |   9 +-
 .../test_to_epiweekly_quantile_table.R      | 111 ++++++--
 pipelines/build_pyrenew_model.py            |  12 +-
 pipelines/generate_predictive.py            |   2 +-
 pipelines/priors/prod_priors.py             |  27 ++
 pyrenew_hew/pyrenew_hew_model.py            | 260 +++++++++++++++++-
 pyrenew_hew/utils.py                        |  54 ----
 pyrenew_hew/ww_site_level_dynamics_model.py | 102 ++++++-
 tests/test_latent_infection_process.py      |   2 +-
 tests/test_utils.py                         |  27 --
 11 files changed, 484 insertions(+), 124 deletions(-)
 delete mode 100644 pyrenew_hew/utils.py
 delete mode 100644 tests/test_utils.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d4ab5a48..7bf26f02 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
   #####
   # R
   - repo: https://github.com/lorenzwalthert/precommit
-    rev: v0.4.3.9001
+    rev: v0.4.3
     hooks:
       - id: style-files
       - id: lintr
diff --git a/demos/ww_model/ww_model_demo.qmd b/demos/ww_model/ww_model_demo.qmd
index 6747af6c..e2bbb532 100644
--- a/demos/ww_model/ww_model_demo.qmd
+++ b/demos/ww_model/ww_model_demo.qmd
@@ -16,7 +16,6 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF
 from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
 
 from pyrenew_hew.ww_site_level_dynamics_model import ww_site_level_dynamics_model
-from pyrenew_hew.utils import convert_to_logmean_log_sd
 import pyrenew_hew.plotting as plotting
 
 numpyro.set_host_device_count(4)
@@ -29,6 +28,14 @@ We will use the data used in the `wwinference` [vignette](https://github.com/CDC
 with open("data/fit/stan_data.json","r") as file:
     stan_data = json.load(file)
 
+# define functions called later
+def convert_to_logmean_log_sd(mean, sd):
+    logmean = jnp.log(
+        jnp.power(mean, 2) / jnp.sqrt(jnp.power(sd, 2) + jnp.power(mean, 2))
+    )
+    logsd = jnp.sqrt(jnp.log(1 + (jnp.power(sd, 2) / jnp.power(mean, 2))))
+    return logmean, logsd
+
 ```
 
 ```{python}
diff --git a/hewr/tests/testthat/test_to_epiweekly_quantile_table.R b/hewr/tests/testthat/test_to_epiweekly_quantile_table.R
index 15a6942b..b4600223 100644
--- a/hewr/tests/testthat/test_to_epiweekly_quantile_table.R
+++ b/hewr/tests/testthat/test_to_epiweekly_quantile_table.R
@@ -39,9 +39,20 @@ test_that("to_epiweekly_quantiles works as expected", {
   ) |>
     suppressMessages()
 
   expect_s3_class(result, "tbl_df")
-  expect_setequal(c(
-    "epiweek", "epiyear", "quantile_value", "quantile_level", "location"
-  ), colnames(result))
+  checkmate::expect_names(
+    colnames(result),
+    identical.to = c(
+      "epiweek",
+      "epiyear",
+      "quantile_value",
+      "quantile_level",
+      "location",
+      "source_samples"
+    )
+  )
+
+  expect_equal(draws_file_name, unique(result$source_samples))
+
   expect_gt(nrow(result), 0)
 }
@@ -127,7 +138,11 @@ test_that("to_epiweekly_quantiles handles missing forecast files", {
 
 # tests for `to_epiweekly_quantile_table`
 
-test_that("to_epiweekly_quantile_table handles multiple locations", {
+test_that(paste0(
+  "to_epiweekly_quantile_table ",
+  "handles multiple locations ",
+  "and multiple source files"
+), {
   batch_dir_name <- "covid-19_r_2024-12-14_f_2024-12-08_t_2024-12-14"
   tempdir <- withr::local_tempdir()
 
@@ -142,6 +157,17 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
     if (loc != "loc3") {
       disease_cols <- c(disease_cols, "prop_disease_ed_visits")
     }
+    create_tidy_forecast_data(
+      directory = loc_dir,
+      filename = "epiweekly_with_epiweekly_other_samples.parquet",
+      date_cols = seq(
+        lubridate::ymd("2024-12-08"), lubridate::ymd("2024-12-14"),
+        by = "week"
+      ),
+      disease_cols = disease_cols,
+      n_draw = 25,
+      with_epiweek = TRUE
+    )
 
     create_tidy_forecast_data(
       directory = loc_dir,
@@ -157,7 +183,10 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
   })
 
   ## should succeed despite loc3 not having valid draws with strict = FALSE
-  result_w_both_locations <- to_epiweekly_quantile_table(temp_batch_dir) |>
+  result_w_both_locations <-
+    to_epiweekly_quantile_table(temp_batch_dir,
+      epiweekly_other_locations = "loc1"
+    ) |>
     suppressMessages()
 
   ## should error if strict = TRUE because loc3 does not have
@@ -168,6 +197,44 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
     "did not find valid draws"
   )
 
+  ## should succeed with strict = TRUE if loc3 is excluded
+  alt_result_w_both_locations <- (
+    to_epiweekly_quantile_table(temp_batch_dir,
+      strict = TRUE,
+      exclude = "loc3"
+    )) |>
+    suppressMessages()
+
+  ## results should be equivalent for loc2,
+  ## but not for loc1
+  expect_equal(
+    result_w_both_locations |>
+      dplyr::filter(location == "loc2"),
+    alt_result_w_both_locations |>
+      dplyr::filter(location == "loc2")
+  )
+
+  ## check that one used epiweekly
+  ## other for loc1 while other used
+  ## default, resulting in different values
+  loc1_a <- result_w_both_locations |>
+    dplyr::filter(location == "loc1") |>
+    dplyr::pull(.data$value)
+  loc1_b <- alt_result_w_both_locations |>
+    dplyr::filter(location == "loc1") |>
+    dplyr::pull(.data$value)
+
+  ## length checks ensure that the
+  ## number of allowed equalities _could_
+  ## be reached if the vectors were mostly
+  ## or entirely identical
+  expect_gt(length(loc1_a), 10)
+  expect_gt(length(loc1_b), 10)
+  expect_lt(
+    sum(loc1_a == loc1_b),
+    5
+  )
+
   expect_s3_class(result_w_both_locations, "tbl_df")
   expect_gt(nrow(result_w_both_locations), 0)
   checkmate::expect_names(
@@ -181,20 +248,32 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
       "output_type",
       "output_type_id",
       "value",
-      "other_ed_visit_forecast"
+      "source_samples"
     )
   )
   expect_setequal(
-    c("loc1", "loc2"),
-    result_w_both_locations$location
+    result_w_both_locations$location,
+    c("loc1", "loc2")
+  )
+  expect_setequal(
+    alt_result_w_both_locations$location,
+    c("loc1", "loc2")
   )
 
-  expect_false("loc3" %in% result_w_both_locations$location)
-  result_w_one_location <- to_epiweekly_quantile_table(
-    model_batch_dir = temp_batch_dir,
-    exclude = "loc1"
-  ) |>
-    suppressMessages()
-  expect_true("loc2" %in% result_w_one_location$location)
-  expect_false("loc1" %in% result_w_one_location$location)
+  expect_setequal(
+    result_w_both_locations$source_samples,
+    c(
+      "epiweekly_samples",
+      "epiweekly_with_epiweekly_other_samples"
+    )
+  )
+
+  expect_setequal(
+    alt_result_w_both_locations$source_samples,
+    "epiweekly_samples"
+  )
+
+
+  expect_false("loc3" %in% result_w_both_locations$location)
+  expect_false("loc3" %in% alt_result_w_both_locations$location)
 })
diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py
index 80a94479..0019e29f 100644
--- a/pipelines/build_pyrenew_model.py
+++ b/pipelines/build_pyrenew_model.py
@@ -109,8 +109,16 @@ def build_model_from_dir(
         ihr_rv=priors["ihr_rv"],
     )
 
-    # placeholder
-    my_wastewater_obs_model = WastewaterObservationProcess()
+    my_wastewater_obs_model = WastewaterObservationProcess(
+        t_peak_rv=priors["t_peak_rv"],
+        duration_shed_after_peak_rv=priors["duration_shed_after_peak_rv"],
+        log10_genome_per_inf_ind_rv=priors["log10_genome_per_inf_ind_rv"],
+        mode_sigma_ww_site_rv=priors["mode_sigma_ww_site_rv"],
+        sd_log_sigma_ww_site_rv=priors["sd_log_sigma_ww_site_rv"],
+        mode_sd_ww_site_rv=priors["mode_sd_ww_site_rv"],
+        max_shed_interval=priors["max_shed_interval"],
+        ww_ml_produced_per_day=priors["ww_ml_produced_per_day"],
+    )
 
     my_model = PyrenewHEWModel(
         population_size=population_size,
diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py
index e205c49b..d9fea1e3 100644
--- a/pipelines/generate_predictive.py
+++ b/pipelines/generate_predictive.py
@@ -33,7 +33,7 @@ def generate_and_save_predictions(
         data=forecast_data,
         sample_ed_visits=True,
         sample_hospital_admissions=True,
-        sample_wastewater=True,
+        sample_wastewater=False,
     )
 
     idata = az.from_numpyro(
diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py
index 2d5c0138..c8b1beaa 100644
--- a/pipelines/priors/prod_priors.py
+++ b/pipelines/priors/prod_priors.py
@@ -93,3 +93,30 @@
 hosp_admit_neg_bin_concentration_rv = DistributionalVariable(
     "hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2)
 )
+
+t_peak_rv = DistributionalVariable("t_peak", dist.TruncatedNormal(5, 1, low=0))
+
+duration_shed_after_peak_rv = DistributionalVariable(
+    "durtion_shed_after_peak", dist.TruncatedNormal(12, 3, low=0)
+)
+
+log10_genome_per_inf_ind_rv = DistributionalVariable(
+    "log10_genome_per_inf_ind", dist.Normal(12, 2)
+)
+
+mode_sigma_ww_site_rv = DistributionalVariable(
+    "mode_sigma_ww_site",
+    dist.TruncatedNormal(1, 1, low=0),
+)
+
+sd_log_sigma_ww_site_rv = DistributionalVariable(
+    "sd_log_sigma_ww_site", dist.TruncatedNormal(0, 0.693, low=0)
+)
+
+mode_sd_ww_site_rv = DistributionalVariable(
+    "mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0)
+)
+
+# model constants related to wastewater obs process
+ww_ml_produced_per_day = 227000
+max_shed_interval = 26
diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py
index 37f0e408..c52ecfaf 100644
--- a/pyrenew_hew/pyrenew_hew_model.py
+++ b/pyrenew_hew/pyrenew_hew_model.py
@@ -1,6 +1,7 @@
 # numpydoc ignore=GL08
 
 import datetime
+import jax
 import jax.numpy as jnp
 import numpyro
 import numpyro.distributions as dist
@@ -254,7 +255,7 @@ def sample(self, n_days_post_init: int):
 
         numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt)
         numpyro.deterministic("latent_infections", latent_infections)
-        return latent_infections
+        return latent_infections, latent_infections_subpop
 
 
 class EDVisitObservationProcess(RandomVariable):
@@ -466,18 +467,232 @@ def sample(
 
 class WastewaterObservationProcess(RandomVariable):
     """
-    Placeholder for wastewater obs process
+    Observe and/or predict wastewater concentration
     """
 
-    def __init__(self) -> None:
-        pass
-
-    def sample(self):
-        pass
+    def __init__(
+        self,
+        t_peak_rv: RandomVariable,
+        duration_shed_after_peak_rv: RandomVariable,
+        log10_genome_per_inf_ind_rv: RandomVariable,
+        mode_sigma_ww_site_rv: RandomVariable,
+        sd_log_sigma_ww_site_rv: RandomVariable,
+        mode_sd_ww_site_rv: RandomVariable,
+        max_shed_interval: float,
+        ww_ml_produced_per_day: float,
+    ) -> None:
+        self.t_peak_rv = t_peak_rv
+        self.duration_shed_after_peak_rv = duration_shed_after_peak_rv
+        self.log10_genome_per_inf_ind_rv = log10_genome_per_inf_ind_rv
+        self.mode_sigma_ww_site_rv = mode_sigma_ww_site_rv
+        self.sd_log_sigma_ww_site_rv = sd_log_sigma_ww_site_rv
+        self.mode_sd_ww_site_rv = mode_sd_ww_site_rv
+        self.max_shed_interval = max_shed_interval
+        self.ww_ml_produced_per_day = ww_ml_produced_per_day
 
     def validate(self):
         pass
 
+    @staticmethod
+    def normed_shedding_cdf(
+        time: ArrayLike, t_p: float, t_d: float, log_base: float
+    ) -> ArrayLike:
+        """
+        calculates fraction of total fecal RNA shedding that has occurred
+        by a given time post infection.
+
+
+        Parameters
+        ----------
+        time: ArrayLike
+            Time points to calculate the CDF of viral shedding.
+        t_p : float
+            Time (in days) from infection to peak shedding.
+        t_d: float
+            Time (in days) from peak shedding to the end of shedding.
+        log_base: float
+            Log base used for the shedding kinetics function.
+
+
+        Returns
+        -------
+        ArrayLike
+            Normalized CDF values of viral shedding at each time point.
+        """
+        norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1)
+
+        def ad_pre(x):
+            return (
+                t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p)
+                - x
+            )
+
+        def ad_post(x):
+            return (
+                -t_d
+                / jnp.log(log_base)
+                * jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d)))
+                - x
+            )
+
+        return (
+            jnp.where(
+                time < t_p + t_d,
+                jnp.where(
+                    time < t_p,
+                    ad_pre(time) - ad_pre(0),
+                    ad_pre(t_p) - ad_pre(0) + ad_post(time) - ad_post(t_p),
+                ),
+                norm_const,
+            )
+            / norm_const
+        )
+
+    def get_viral_trajectory(
+        self,
+        tpeak: float,
+        duration_shed_after_peak: float,
+    ) -> ArrayLike:
+        """
+        Computes the probability mass function (PMF) of
+        daily viral shedding based on a normalized CDF.
+
+        Parameters
+        ----------
+        tpeak: float
+            Time (in days) from infection to peak viral load in shedding.
+        duration_shed_after_peak: float
+            Duration (in days) of detectable viral shedding after the peak.
+
+        Returns
+        -------
+        ArrayLike
+            Normalized daily viral shedding PMF
+        """
+        daily_shedding_pmf = self.normed_shedding_cdf(
+            jnp.arange(1, self.max_shed_interval),
+            tpeak,
+            duration_shed_after_peak,
+            10,
+        ) - self.normed_shedding_cdf(
+            jnp.arange(0, self.max_shed_interval - 1),
+            tpeak,
+            duration_shed_after_peak,
+            10,
+        )
+        return daily_shedding_pmf
+
+    def sample(
+        self,
+        latent_infections_subpop: ArrayLike,
+        data_observed: ArrayLike,
+        n_datapoints: int,
+        ww_uncensored: ArrayLike,
+        ww_censored: ArrayLike,
+        ww_sampled_lab_sites: ArrayLike,
+        ww_sampled_subpops: ArrayLike,
+        ww_sampled_times: ArrayLike,
+        ww_log_lod: ArrayLike,
+        lab_site_to_subpop_map: ArrayLike,
+        n_ww_lab_sites: int,
+        shedding_offset: float,
+        pop_fraction: ArrayLike,
+    ):
+        t_peak = self.t_peak_rv()
+        dur_shed = self.duration_shed_after_peak_rv()
+        viral_kinetics = self.get_viral_trajectory(t_peak, dur_shed)
+
+        def batch_colvolve_fn(m):
+            return jnp.convolve(m, viral_kinetics, mode="valid")
+
+        model_net_inf_ind_shedding = jax.vmap(
+            batch_colvolve_fn, in_axes=1, out_axes=1
+        )(jnp.atleast_2d(latent_infections_subpop))[-n_datapoints:, :]
+        numpyro.deterministic(
+            "model_net_inf_ind_shedding", model_net_inf_ind_shedding
+        )
+
+        log10_genome_per_inf_ind = self.log10_genome_per_inf_ind_rv()
+        expected_obs_viral_genomes = (
+            jnp.log(10) * log10_genome_per_inf_ind
+            + jnp.log(model_net_inf_ind_shedding + shedding_offset)
+            - jnp.log(self.ww_ml_produced_per_day)
+        )
+        numpyro.deterministic(
+            "expected_obs_viral_genomes", expected_obs_viral_genomes
+        )
+
+        mode_sigma_ww_site = self.mode_sigma_ww_site_rv()
+        sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv()
+        mode_sd_ww_site = self.mode_sd_ww_site_rv()
+
+        mode_ww_site_rv = DistributionalVariable(
+            "mode_ww_site",
+            dist.Normal(0, mode_sd_ww_site),
+            reparam=LocScaleReparam(0),
+        )  # lab-site specific variation
+
+        sigma_ww_site_rv = TransformedVariable(
+            "sigma_ww_site",
+            DistributionalVariable(
+                "log_sigma_ww_site",
+                dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site),
+                reparam=LocScaleReparam(0),
+            ),
+            transforms=transformation.ExpTransform(),
+        )
+
+        with numpyro.plate("n_ww_lab_sites", n_ww_lab_sites):
+            mode_ww_site = mode_ww_site_rv()
+            sigma_ww_site = sigma_ww_site_rv()
+
+        # multiply the expected observed genomes by the site-specific multiplier at that sampling time
+        expected_obs_log_v_site = (
+            expected_obs_viral_genomes[ww_sampled_times, ww_sampled_subpops]
+            + mode_ww_site[ww_sampled_lab_sites]
+        )
+
+        DistributionalVariable(
+            "log_conc_obs",
+            dist.Normal(
+                loc=expected_obs_log_v_site[ww_uncensored],
+                scale=sigma_ww_site[ww_sampled_lab_sites[ww_uncensored]],
+            ),
+        ).sample(
+            obs=(
+                data_observed[ww_uncensored]
+                if data_observed is not None
+                else None
+            ),
+        )
+
+        if ww_censored.shape[0] != 0:
+            log_cdf_values = dist.Normal(
+                loc=expected_obs_log_v_site[ww_censored],
+                scale=sigma_ww_site[ww_sampled_lab_sites[ww_censored]],
+            ).log_cdf(ww_log_lod[ww_censored])
+            numpyro.factor("log_prob_censored", log_cdf_values.sum())
+
+        # Predict site and population level wastewater concentrations
+        site_log_ww_conc = DistributionalVariable(
+            "site_log_ww_conc",
+            dist.Normal(
+                loc=expected_obs_viral_genomes[:, lab_site_to_subpop_map]
+                + mode_ww_site,
+                scale=sigma_ww_site,
+            ),
+        )()
+
+        population_latent_viral_genome_conc = jax.scipy.special.logsumexp(
+            expected_obs_viral_genomes, axis=1, b=pop_fraction
+        )
+        numpyro.deterministic(
+            "population_latent_viral_genome_conc",
+            population_latent_viral_genome_conc,
+        )
+
+        return site_log_ww_conc, population_latent_viral_genome_conc
+
 
 class PyrenewHEWModel(Model):  # numpydoc ignore=GL08
     def __init__(
@@ -505,8 +720,10 @@ def sample(
         sample_wastewater: bool = False,
     ) -> dict[str, ArrayLike]:  # numpydoc ignore=GL08
         n_init_days = self.latent_infection_process_rv.n_initialization_points
-        latent_infections = self.latent_infection_process_rv(
-            n_days_post_init=data.n_days_post_init,
+        latent_infections, latent_infections_subpop = (
+            self.latent_infection_process_rv(
+                n_days_post_init=data.n_days_post_init,
+            )
         )
         first_latent_infection_dow = (
             data.first_data_date_overall - datetime.timedelta(days=n_init_days)
@@ -514,7 +731,8 @@ def sample(
 
         observed_ed_visits = None
         observed_admissions = None
-        observed_wastewater = None
+        site_level_observed_wastewater = None
+        population_level_latent_wastewater = None
 
         iedr = None
 
@@ -537,10 +755,28 @@ def sample(
                 iedr=iedr,
             )
         if sample_wastewater:
-            observed_wastewater = self.wastewater_obs_process_rv()
+            (
+                site_level_observed_wastewater,
+                population_level_latent_wastewater,
+            ) = self.wastewater_obs_process_rv(
+                latent_infections=latent_infections,
+                latent_infections_subpop=latent_infections_subpop,
+                data_observed=data.data_observed_disease_wastewater,
+                n_datapoints=data.n_wastewater_datapoints,
+                ww_uncensored=None,  # placeholder
+                ww_censored=None,  # placeholder
+                ww_sampled_lab_sites=None,  # placeholder
+                ww_sampled_subpops=None,  # placeholder
+                ww_sampled_times=None,  # placeholder
+                ww_log_lod=None,  # placeholder
+                lab_site_to_subpop_map=None,  # placeholder
+                n_ww_lab_sites=None,  # placeholder
+                shedding_offset=1e-8,
+            )
 
         return {
             "ed_visits": observed_ed_visits,
             "hospital_admissions": observed_admissions,
-            "wasewater": observed_wastewater,
+            "site_level_wastewater_conc": site_level_observed_wastewater,
+            "population_level_latent_wastewater_conc": population_level_latent_wastewater,
         }
diff --git a/pyrenew_hew/utils.py b/pyrenew_hew/utils.py
deleted file mode 100644
index 006601a4..00000000
--- a/pyrenew_hew/utils.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import jax.numpy as jnp
-
-
-def convert_to_logmean_log_sd(mean, sd):
-    logmean = jnp.log(
-        jnp.power(mean, 2) / jnp.sqrt(jnp.power(sd, 2) + jnp.power(mean, 2))
-    )
-    logsd = jnp.sqrt(jnp.log(1 + (jnp.power(sd, 2) / jnp.power(mean, 2))))
-    return logmean, logsd
-
-
-def normed_shedding_cdf(
-    time: float, t_p: float, t_d: float, log_base: float
-) -> float:
-    """
-    fraction of total fecal RNA shedding that has occurred
-    by a given time post infection.
-    """
-    norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1)
-
-    def ad_pre(x):
-        return (
-            t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x
-        )
-
-    def ad_post(x):
-        return (
-            -t_d
-            / jnp.log(log_base)
-            * jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d)))
-            - x
-        )
-
-    return (
-        jnp.where(
-            time < t_p + t_d,
-            jnp.where(
-                time < t_p,
-                ad_pre(time) - ad_pre(0),
-                ad_pre(t_p) - ad_pre(0) + ad_post(time) - ad_post(t_p),
-            ),
-            norm_const,
-        )
-        / norm_const
-    )
-
-
-def get_vl_trajectory(tpeak, duration_shedding_after_peak, max_days):
-    daily_shedding_pmf = normed_shedding_cdf(
-        jnp.arange(1, max_days), tpeak, duration_shedding_after_peak, 10
-    ) - normed_shedding_cdf(
-        jnp.arange(0, max_days - 1), tpeak, duration_shedding_after_peak, 10
-    )
-    return daily_shedding_pmf
diff --git a/pyrenew_hew/ww_site_level_dynamics_model.py b/pyrenew_hew/ww_site_level_dynamics_model.py
index 23b32a47..5d3ee493 100644
--- a/pyrenew_hew/ww_site_level_dynamics_model.py
+++ b/pyrenew_hew/ww_site_level_dynamics_model.py
@@ -4,6 +4,7 @@
 import numpyro.distributions as dist
 import numpyro.distributions.transforms as transforms
 import pyrenew.transformation as transformation
+from jax.typing import ArrayLike
 from numpyro.infer.reparam import LocScaleReparam
 from pyrenew.arrayutils import tile_until_n
 from pyrenew.convolve import compute_delay_ascertained_incidence
@@ -18,8 +19,6 @@
 from pyrenew.process import ARProcess, DifferencedProcess
 from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
 
-from pyrenew_hew.utils import get_vl_trajectory
-
 
 class ww_site_level_dynamics_model(Model):  # numpydoc ignore=GL08
     def __init__(
@@ -350,15 +349,17 @@ def sample(
             gen_int=generation_interval_pmf,
         )
 
-        latent_infections_subpop = jnp.concat(
-            [
-                i0,
-                inf_with_feedback_proc_sample.post_initialization_infections,
-            ]
+        latent_infections_subpop = jnp.atleast_2d(
+            jnp.concat(
+                [
+                    i0,
+                    inf_with_feedback_proc_sample.post_initialization_infections,
+                ]
+            )
         )
 
         if self.n_subpops == 1:
-            latent_infections = latent_infections_subpop
+            latent_infections = jnp.squeeze(latent_infections_subpop)
         else:
             latent_infections = jnp.sum(
                 self.pop_fraction * latent_infections_subpop, axis=1
@@ -437,7 +438,7 @@ def sample(
         if self.include_ww:
             t_peak = self.t_peak_rv()
             dur_shed = self.dur_shed_after_peak_rv()
-            s = get_vl_trajectory(t_peak, dur_shed, self.max_shed_interval)
+            s = get_viral_trajectory(t_peak, dur_shed, self.max_shed_interval)
 
             def batch_colvolve_fn(m):
                 return jnp.convolve(m, s, mode="valid")
@@ -556,3 +557,86 @@ def batch_colvolve_fn(m):
             observed_hospital_admissions,
             site_ww_pred_log if self.include_ww else None,
         )
+
+
+def normed_shedding_cdf(
+    time: ArrayLike, t_p: float, t_d: float, log_base: float
+) -> ArrayLike:
+    """
+    calculates fraction of total fecal RNA shedding that has occurred
+    by a given time post infection.
+
+
+    Parameters
+    ----------
+    time: ArrayLike
+        Time points to calculate the CDF of viral shedding.
+    t_p : float
+        Time (in days) from infection to peak shedding.
+    t_d: float
+        Time (in days) from peak shedding to the end of shedding.
+    log_base: float
+        Log base used for the shedding kinetics function.
+
+
+    Returns
+    -------
+    ArrayLike
+        Normalized CDF values of viral shedding at each time point.
+    """
+    norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1)
+
+    def ad_pre(x):
+        return (
+            t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x
+        )
+
+    def ad_post(x):
+        return (
+            -t_d
+            / jnp.log(log_base)
+            * jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d)))
+            - x
+        )
+
+    return (
+        jnp.where(
+            time < t_p + t_d,
+            jnp.where(
+                time < t_p,
+                ad_pre(time) - ad_pre(0),
+                ad_pre(t_p) - ad_pre(0) + ad_post(time) - ad_post(t_p),
+            ),
+            norm_const,
+        )
+        / norm_const
+    )
+
+
+def get_viral_trajectory(
+    tpeak: float, duration_shedding_after_peak: float, max_days: int
+):
+    """
+    Computes the probability mass function (PMF) of
+    daily viral shedding based on a normalized CDF.
+
+    Parameters
+    ----------
+    tpeak: float
+        Time (in days) from infection to peak viral load in shedding.
+    duration_shedding_after_peak: float
+        Duration (in days) of detectable viral shedding after the peak.
+    max_days: int
+        Maximum number of days to calculate the shedding trajectory.
+
+    Returns
+    -------
+    ArrayLike
+        Normalized daily viral shedding PMF
+    """
+    daily_shedding_pmf = normed_shedding_cdf(
+        jnp.arange(1, max_days), tpeak, duration_shedding_after_peak, 10
+    ) - normed_shedding_cdf(
+        jnp.arange(0, max_days - 1), tpeak, duration_shedding_after_peak, 10
+    )
+    return daily_shedding_pmf
diff --git a/tests/test_latent_infection_process.py b/tests/test_latent_infection_process.py
index 695fe916..6ce0bf51 100644
--- a/tests/test_latent_infection_process.py
+++ b/tests/test_latent_infection_process.py
@@ -48,7 +48,7 @@ def test_LatentInfectionProcess():
     )
 
     with numpyro.handlers.seed(rng_seed=223):
-        latent_inf_w_hierarchical_effects = my_latent_infection_model(
+        latent_inf_w_hierarchical_effects, _ = my_latent_infection_model(
             n_days_post_init=n_days_post_init
         )
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
deleted file mode 100644
index 7763b3cf..00000000
--- a/tests/test_utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""
-Tests for the pyrenew-hew utils.
-"""
-
-import jax.numpy as jnp
-import pytest
-
-from pyrenew_hew.utils import convert_to_logmean_log_sd
-
-
-@pytest.mark.parametrize(
-    ["mean", "sd"],
-    [
-        [
-            jnp.array([10]),
-            jnp.array([0]),
-        ]
-    ],
-)
-def test_convert_to_logmean_log_sd_edge_case_zero_sd(mean, sd):
-    logmean, logsd = convert_to_logmean_log_sd(mean, sd)
-
-    expected_logmean = jnp.log(10.0)
-    expected_logsd = 0.0
-
-    assert jnp.isclose(logmean, expected_logmean)
-    assert jnp.isclose(logsd, expected_logsd)