From e121c3bffefc08d22b99ca4496e966d2a19e96f9 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 29 Aug 2024 13:34:32 -0400 Subject: [PATCH 01/50] update stan_data file --- notebooks/data/fit/stan_data.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/data/fit/stan_data.json b/notebooks/data/fit/stan_data.json index 95fe121c..c0903df9 100644 --- a/notebooks/data/fit/stan_data.json +++ b/notebooks/data/fit/stan_data.json @@ -356,10 +356,10 @@ "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], "initial_growth_prior_mean": 0, "initial_growth_prior_sd": 0.01, - "sigma_ww_site_prior_mean_mean": 1, - "sigma_ww_site_prior_mean_sd": 1, - "sigma_ww_site_prior_sd_mean": 0, - "sigma_ww_site_prior_sd_sd": 1, + "mode_sigma_ww_site_prior_mode": 1, + "mode_sigma_ww_site_prior_sd": 1, + "sd_log_sigma_ww_site_prior_mode": 0, + "sd_log_sigma_ww_site_prior_sd": 0.693, "eta_sd_sd": 0.01, "sigma_i0_prior_mode": 0, "sigma_i0_prior_sd": 0.5, From 6571b863c605a48385ce175264569b08afa9ecf1 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 29 Aug 2024 17:53:07 -0400 Subject: [PATCH 02/50] first pass site-level model implementation --- .../site_level_dynamics_model.py | 277 ++++++++++++++++++ pyrenew_covid_wastewater/utils.py | 26 ++ 2 files changed, 303 insertions(+) create mode 100644 pyrenew_covid_wastewater/site_level_dynamics_model.py create mode 100644 pyrenew_covid_wastewater/utils.py diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py new file mode 100644 index 00000000..d26e2425 --- /dev/null +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -0,0 +1,277 @@ +import jax +import jax.numpy as jnp +import numpy as np +import numpyro +import numpyro.distributions as dist +import numpyro.distributions.transforms as transforms +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +from pyrenew.process import ARProcess, RtWeeklyDiffARProcess 
+import pyrenew.transformation as transformation +from pyrenew.latent import ( + InfectionInitializationProcess, + InitializeInfectionsExponentialGrowth, + InfectionsWithFeedback, +) + +from pyrenew.observation import NegativeBinomialObservation +from pyrenew.arrayutils import tile_until_n +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + + +with numpyro.handlers.seed(rng_seed=223): + eta_sd = eta_sd_rv()[0].value + autoreg_rt = autoreg_rt_rv()[0].value + log_r_mu_intercept = log_r_mu_intercept_rv()[0].value + +autoreg_rt_det_rv = DeterministicVariable("autoreg_rt_det", autoreg_rt) + +init_rate_of_change_rv = DistributionalVariable( + "init_rate_of_change", + dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), +) + +with numpyro.handlers.seed(rng_seed=223): + init_rate_of_change = init_rate_of_change_rv()[0].value + +rt_proc = RtWeeklyDiffARProcess( + name="rtu_state_weekly_diff", + offset=0, + log_rt_rv=DeterministicVariable( + name="log_rt", + value=jnp.array( + [ + log_r_mu_intercept, + log_r_mu_intercept + init_rate_of_change, + ] + ), + ), + autoreg_rv=autoreg_rt_det_rv, + periodic_diff_sd_rv=DeterministicVariable( + name="periodic_diff_sd", value=jnp.array(eta_sd) + ), +) + +with numpyro.handlers.seed(rng_seed=223): + rtu = rt_proc.sample( + duration=n_datapoints + ) # log_r_mu_t_in_weeks in stan - not log anymore and not weekly either + + +with numpyro.handlers.seed(rng_seed=223): + t_peak = t_peak_rv() + viral_peak = viral_peak_rv() + dur_shed = dur_shed_rv() + +s = get_vl_trajectory(t_peak[0].value, viral_peak[0].value, dur_shed[0].value, gt_max) + + +# Site-level Rt, to be repeated for each site +r_site_t = jnp.zeros((n_subpops, obs_time + horizon_time)) +new_i_site_matrix = jnp.zeros((n_subpops, n_datapoints + n_initialization_points)) +model_log_v_ot = jnp.zeros((n_subpops, obs_time + horizon_time)) + +for i in range(n_subpops): + with numpyro.handlers.seed(rng_seed=223): + autoreg_rt_site = 
autoreg_rt_site_rv() + sigma_rt = sigma_rt_rv() + + rtu_site_ar_init_rv = DistributionalVariable( + "rtu_site_ar_init", + dist.Normal( + 0, + sigma_rt[0].value / jnp.sqrt(1 - jnp.pow(autoreg_rt_site[0].value, 2)), + ), + ) + + rtu_site_ar_proc = ARProcess(noise_rv_name="rtu_ar_proc") + + with numpyro.handlers.seed(rng_seed=223): + rtu_site_ar_init = rtu_site_ar_init_rv() + rtu_site_ar_weekly = rtu_site_ar_proc( + n=n_weeks, + init_vals=rtu_site_ar_init[0].value, + autoreg=autoreg_rt_site[0].value, + noise_sd=sigma_rt[0].value, + ) + + rtu_site_ar = jnp.repeat( + transformation.ExpTransform()(rtu_site_ar_weekly[0].value), repeats=7 + )[:n_datapoints] + + rtu_site = ( + rtu_site_ar + rtu.rt.value + ) # this reults in more sensible values but it should be as below? + # rtu_site = rtu_site_ar*rtu.rt.value + + # Site level disease dynamic estimates! + with numpyro.handlers.seed(rng_seed=223): + i0_over_n = i0_over_n_rv() + sigma_i0 = sigma_i0_rv() + eta_i0 = eta_i0_rv() + initial_growth = initialization_rate_rv() + eta_growth = eta_growth_rv() + sigma_growth = sigma_growth_rv() + + # Calculate infection and adjusted Rt for each sight using site-level i0 `i0_site_over_n` and initialization rate `growth_site` + # These are computed as a vector in stan code, but iid implementation is probably better for using numpyro.plate + + # site level growth rate + growth_site = initial_growth[0].value + eta_growth[0].value * sigma_growth[0].value + + growth_site_rv = DeterministicVariable("growth_site_rv", jnp.array(growth_site)) + + # site-level initial per capita infection incidence + i0_site_over_n = jax.nn.sigmoid( + transforms.logit(i0_over_n[0].value) + eta_i0[0].value * sigma_i0[0].value + ) + + i0_site_over_n_rv = DeterministicVariable( + "i0_site_over_n_rv", jnp.array(i0_site_over_n) + ) + + infection_initialization_process = InfectionInitializationProcess( + "I0_initialization", + i0_site_over_n_rv, + InitializeInfectionsExponentialGrowth( + n_initialization_points, + 
growth_site_rv, + t_pre_init=i0_t_offset, + ), + t_unit=1, + ) + + with numpyro.handlers.seed(rng_seed=223): + generation_interval_pmf = generation_interval_pmf_rv() + i0 = infection_initialization_process() + inf_with_feedback_proc_sample = inf_with_feedback_proc.sample( + Rt=rtu_site, + I0=i0[0].value, + gen_int=generation_interval_pmf[0].value, + ) + + new_i_site = jnp.concat( + [ + i0[0].value, + inf_with_feedback_proc_sample.post_initialization_infections.value, + ] + ) + r_site_t = r_site_t.at[i, :].set(inf_with_feedback_proc_sample.rt.value) + new_i_site_matrix = new_i_site_matrix.at[i, :].set(new_i_site) + + # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) + model_net_i = jnp.convolve(new_i_site, s, mode="valid")[-n_datapoints:] + + with numpyro.handlers.seed(rng_seed=223): + log10_g = log10_g_rv() + + # expected observed viral genomes/mL at all observed and forecasted times + # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop + model_log_v_ot_site = ( + jnp.log(10) * log10_g[0].value + + jnp.log(model_net_i[: (obs_time + horizon_time)] + 1e-8) + - jnp.log(ww_ml_produced_per_day) + ) + model_log_v_ot = model_log_v_ot.at[i, :].set(model_log_v_ot_site) + +state_inf_per_capita = jnp.sum(pop_fraction_reshaped * new_i_site_matrix, axis=0) +# Hospital admission component + + +# p_hosp_w is std_normal - weekly random walk for IHR + +with numpyro.handlers.seed(rng_seed=223): + p_hosp_mean = p_hosp_mean_rv() + p_hosp_w_sd = p_hosp_w_sd_rv() + autoreg_p_hosp = autoreg_p_hosp_rv() + +p_hosp_ar_proc = ARProcess("p_hosp") + +p_hosp_ar_init_rv = DistributionalVariable( + "p_hosp_ar_init", + dist.Normal( + 0, + p_hosp_w_sd[0].value / jnp.sqrt(1 - jnp.pow(autoreg_p_hosp[0].value, 2)), + ), +) + +with numpyro.handlers.seed(rng_seed=223): + p_hosp_ar_init = p_hosp_ar_init_rv() + p_hosp_ar = p_hosp_ar_proc.sample( + n=n_weeks, + autoreg=autoreg_p_hosp[0].value, + init_vals=p_hosp_ar_init[0].value, + 
noise_sd=p_hosp_w_sd[0].value, + ) + +ihr = jnp.repeat( + transformation.SigmoidTransform()(p_hosp_ar[0].value + p_hosp_mean[0].value), + repeats=7, +)[:n_datapoints] + + +with numpyro.handlers.seed(rng_seed=223): + hosp_wday_effect_raw = hosp_wday_effect_rv()[0].value + inf_to_hosp = inf_to_hosp_rv()[0].value + +hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) + +potential_latent_hospital_admissions = jnp.convolve( + state_inf_per_capita, + inf_to_hosp, + mode="valid", +)[-n_datapoints:] + +latent_hospital_admissions = ( + potential_latent_hospital_admissions * ihr * hosp_wday_effect * state_pop +) + + +with numpyro.handlers.seed(rng_seed=223): + mode_sigma_ww_site = mode_sigma_ww_site_rv()[0].value + sd_log_sigma_ww_site = sd_log_sigma_ww_site_rv()[0].value + eta_log_sigma_ww_site = eta_log_sigma_ww_site_rv()[0].value + ww_site_mod_raw = ww_site_mod_raw_rv()[0].value + ww_site_mod_sd = ww_site_mod_sd_rv()[0].value + + +# These are the true expected genomes at the site level before observation error +# (which is at the lab-site level) +exp_obs_log_v_true = model_log_v_ot[ww_sampled_sites, ww_sampled_times] + +# modify by lab-site specific variation (multiplier!) 
+ww_site_mod = ww_site_mod_raw * ww_site_mod_sd + +# LHS log transformed obs genomes per person-day, RHS multiplies the expected observed +# genomes by the site-specific multiplier at that sampling time +exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[ww_sampled_lab_sites] + +sigma_ww_site = jnp.exp( + jnp.log(mode_sigma_ww_site) + sd_log_sigma_ww_site * eta_log_sigma_ww_site +) + +g = jnp.power(log10_g[0].value, 10) # Estimated genomes shed per infected individual + + +log_conc_obs_rv = numpyro.sample( + "log_conc", + dist.Normal( + loc=exp_obs_log_v[ww_uncensored], + scale=sigma_ww_site[ww_sampled_lab_sites[ww_uncensored]], + ), + obs=data_observed_log_conc[ww_uncensored], +) + +if ww_censored.shape[0] != 0: + log_cdf_values = dist.Normal( + loc=exp_obs_log_v[ww_censored], + scale=sigma_ww_site[ww_sampled_lab_sites[ww_censored]], + ).log_cdf(ww_log_lod[ww_censored]) + + numpyro.factor("log_prob_censored", log_cdf_values.sum()) + + +with numpyro.handlers.seed(rng_seed=223): + observed_hospital_admissions = hospital_admission_obs_rv( + mu=latent_hospital_admissions, + obs=data_observed_hospital_admissions, + ) diff --git a/pyrenew_covid_wastewater/utils.py b/pyrenew_covid_wastewater/utils.py new file mode 100644 index 00000000..831ccdaf --- /dev/null +++ b/pyrenew_covid_wastewater/utils.py @@ -0,0 +1,26 @@ +# helper function + +import jax.numpy as jnp + + +def convert_to_logmean_log_sd(mean, sd): + logmean = jnp.log( + jnp.power(mean, 2) / jnp.sqrt(jnp.power(sd, 2) + jnp.power(mean, 2)) + ) + logsd = jnp.sqrt(jnp.log(1 + (jnp.power(sd, 2) / jnp.power(mean, 2)))) + return logmean, logsd + + +def get_vl_trajectory(tpeak, viral_peak, duration_shedding, n): + s = jnp.zeros(n) + growth = viral_peak / tpeak + wane = viral_peak / (duration_shedding - tpeak) + + t = jnp.arange(n) + s = jnp.where(t <= tpeak, jnp.power(10, growth * t), s) + + s = jnp.where(t > tpeak, jnp.maximum(0, viral_peak + wane * tpeak - wane * t), s) + s = jnp.where(t > tpeak, jnp.power(10, s), s) + + 
s = s / jnp.sum(s) + return s From ffa82082f4f904e853f2367e0e4baf7b95edcdde Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 29 Aug 2024 18:04:49 -0400 Subject: [PATCH 03/50] format and clean up --- .../site_level_dynamics_model.py | 611 +++++++++++------- pyrenew_covid_wastewater/utils.py | 4 +- 2 files changed, 375 insertions(+), 240 deletions(-) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index d26e2425..2a933638 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -1,277 +1,410 @@ import jax import jax.numpy as jnp -import numpy as np import numpyro import numpyro.distributions as dist import numpyro.distributions.transforms as transforms -from pyrenew.deterministic import DeterministicVariable, DeterministicPMF -from pyrenew.process import ARProcess, RtWeeklyDiffARProcess import pyrenew.transformation as transformation +from pyrenew.arrayutils import tile_until_n +from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import ( InfectionInitializationProcess, - InitializeInfectionsExponentialGrowth, InfectionsWithFeedback, + InitializeInfectionsExponentialGrowth, ) +from pyrenew.metaclass import Model +from pyrenew.process import ARProcess, RtWeeklyDiffARProcess +from pyrenew.randomvariable import DistributionalVariable + +from pyrenew_covid_wastewater.utils import get_vl_trajectory + + +class hosp_only_ww_model(Model): # numpydoc ignore=GL08 + def __init__( + self, + state_pop, + n_subpops, + n_initialization_points, + gt_max, + i0_t_offset, + log_r_mu_intercept_rv, + autoreg_rt_rv, + eta_sd_rv, + t_peak_rv, + viral_peak_rv, + dur_shed_rv, + autoreg_rt_site_rv, + sigma_rt_rv, + i0_over_n_rv, + sigma_i0_rv, + eta_i0_rv, + initialization_rate_rv, + eta_growth_rv, + sigma_growth_rv, + generation_interval_pmf_rv, + infection_feedback_strength_rv, + infection_feedback_pmf_rv, + 
p_hosp_mean_rv, + p_hosp_w_sd_rv, + autoreg_p_hosp_rv, + hosp_wday_effect_rv, + inf_to_hosp_rv, + log10_g_rv, + mode_sigma_ww_site_rv, + sd_log_sigma_ww_site_rv, + eta_log_sigma_ww_site_rv, + ww_site_mod_raw_rv, + ww_site_mod_sd_rv, + hospital_admission_obs_rv, + ww_ml_produced_per_day, + pop_fraction_reshaped, + ww_uncensored, + ww_censored, + ww_sampled_lab_sites, + ww_sampled_sites, + ww_sampled_times, + ww_log_lod, + ): # numpydoc ignore=GL08 + self.state_pop = state_pop + self.n_subpops = (n_subpops,) + self.n_initialization_points = n_initialization_points + self.gt_max = gt_max + self.i0_t_offset = i0_t_offset + self.log_r_mu_intercept_rv = log_r_mu_intercept_rv + self.autoreg_rt_rv = autoreg_rt_rv + self.eta_sd_rv = eta_sd_rv + self.t_peak_rv = t_peak_rv + self.viral_peak_rv = viral_peak_rv + self.dur_shed_rv = dur_shed_rv + self.autoreg_rt_site_rv = autoreg_rt_site_rv + self.sigma_rt_rv = sigma_rt_rv + self.i0_over_n_rv = i0_over_n_rv + self.sigma_i0_rv = sigma_i0_rv + self.eta_i0_rv = eta_i0_rv + self.initial_growth_rv = initialization_rate_rv + self.eta_growth_rv = eta_growth_rv + self.sigma_growth_rv = sigma_growth_rv + self.generation_interval_pmf_rv = generation_interval_pmf_rv + self.p_hosp_mean_rv = p_hosp_mean_rv + self.p_hosp_w_sd_rv = p_hosp_w_sd_rv + self.autoreg_p_hosp_rv = autoreg_p_hosp_rv + self.hosp_wday_effect_rv = hosp_wday_effect_rv + self.inf_to_hosp_rv = inf_to_hosp_rv + self.log10_g_rv = log10_g_rv + self.mode_sigma_ww_site_rv = mode_sigma_ww_site_rv + self.sd_log_sigma_ww_site_rv = sd_log_sigma_ww_site_rv + self.eta_log_sigma_ww_site_rv = eta_log_sigma_ww_site_rv + self.ww_site_mod_raw_rv = ww_site_mod_raw_rv + self.ww_site_mod_sd_rv = ww_site_mod_sd_rv + self.hospital_admission_obs_rv = hospital_admission_obs_rv + self.ww_ml_produced_per_day = ww_ml_produced_per_day + self.pop_fraction_reshaped = pop_fraction_reshaped + self.ww_uncensored = ww_uncensored + self.ww_censored = ww_censored + self.ww_sampled_lab_sites = 
ww_sampled_lab_sites + self.ww_sampled_sites = ww_sampled_sites + self.ww_sampled_times = ww_sampled_times + self.ww_log_lod = ww_log_lod + self.inf_with_feedback_proc = InfectionsWithFeedback( + infection_feedback_strength=infection_feedback_strength_rv, + infection_feedback_pmf=infection_feedback_pmf_rv, + ) + return None + + def validate(self): # numpydoc ignore=GL08 + return None + + def sample( + self, + n_datapoints=None, + data_observed_hospital_admissions=None, + data_observed_log_conc=None, + ): # numpydoc ignore=GL08 + if n_datapoints is None and data_observed_hospital_admissions is None: + raise ValueError( + "Either n_datapoints or data_observed_hosp_admissions " + "must be passed." + ) + elif ( + n_datapoints is not None + and data_observed_hospital_admissions is not None + ): + raise ValueError( + "Cannot pass both n_datapoints and data_observed_hospital_admissions." + ) + elif n_datapoints is None: + n_datapoints = len(data_observed_hospital_admissions) + else: + n_datapoints = n_datapoints + + n_weeks = n_datapoints // 7 + 1 + + eta_sd = self.eta_sd_rv()[0].value + autoreg_rt = self.autoreg_rt_rv()[0].value + log_r_mu_intercept = self.log_r_mu_intercept_rv()[0].value + + autoreg_rt_det_rv = DeterministicVariable("autoreg_rt_det", autoreg_rt) + init_rate_of_change_rv = DistributionalVariable( + "init_rate_of_change", + dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), + ) -from pyrenew.observation import NegativeBinomialObservation -from pyrenew.arrayutils import tile_until_n -from pyrenew.randomvariable import DistributionalVariable, TransformedVariable - + init_rate_of_change = init_rate_of_change_rv()[0].value + + rt_proc = RtWeeklyDiffARProcess( + name="rtu_state_weekly_diff", + offset=0, + log_rt_rv=DeterministicVariable( + name="log_rt", + value=jnp.array( + [ + log_r_mu_intercept, + log_r_mu_intercept + init_rate_of_change, + ] + ), + ), + autoreg_rv=autoreg_rt_det_rv, + periodic_diff_sd_rv=DeterministicVariable( + 
name="periodic_diff_sd", value=jnp.array(eta_sd) + ), + ) -with numpyro.handlers.seed(rng_seed=223): - eta_sd = eta_sd_rv()[0].value - autoreg_rt = autoreg_rt_rv()[0].value - log_r_mu_intercept = log_r_mu_intercept_rv()[0].value + rtu = rt_proc.sample( + duration=n_datapoints + ) # log_r_mu_t_in_weeks in stan - not log anymore and not weekly either -autoreg_rt_det_rv = DeterministicVariable("autoreg_rt_det", autoreg_rt) + t_peak = self.t_peak_rv()[0].value + viral_peak = self.viral_peak_rv()[0].value + dur_shed = self.dur_shed_rv()[0].value -init_rate_of_change_rv = DistributionalVariable( - "init_rate_of_change", - dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), -) + s = get_vl_trajectory(t_peak, viral_peak, dur_shed, self.gt_max) -with numpyro.handlers.seed(rng_seed=223): - init_rate_of_change = init_rate_of_change_rv()[0].value - -rt_proc = RtWeeklyDiffARProcess( - name="rtu_state_weekly_diff", - offset=0, - log_rt_rv=DeterministicVariable( - name="log_rt", - value=jnp.array( - [ - log_r_mu_intercept, - log_r_mu_intercept + init_rate_of_change, + # Site-level Rt, to be repeated for each site + r_site_t = jnp.zeros((self.n_subpops, n_datapoints)) + new_i_site_matrix = jnp.zeros( + (self.n_subpops, n_datapoints + self.n_initialization_points) + ) + model_log_v_ot = jnp.zeros((self.n_subpops, n_datapoints)) + + for i in range(self.n_subpops): + autoreg_rt_site = self.autoreg_rt_site_rv()[0].value + sigma_rt = self.sigma_rt_rv()[0].value + + rtu_site_ar_init_rv = DistributionalVariable( + "rtu_site_ar_init", + dist.Normal( + 0, + sigma_rt[0].value + / jnp.sqrt(1 - jnp.pow(autoreg_rt_site[0].value, 2)), + ), + ) + + rtu_site_ar_proc = ARProcess(noise_rv_name="rtu_ar_proc") + + rtu_site_ar_init = rtu_site_ar_init_rv()[0].value + rtu_site_ar_weekly = rtu_site_ar_proc( + n=n_weeks, + init_vals=rtu_site_ar_init, + autoreg=autoreg_rt_site, + noise_sd=sigma_rt, + ) + + rtu_site_ar = jnp.repeat( + transformation.ExpTransform()(rtu_site_ar_weekly[0].value), + 
repeats=7, + )[:n_datapoints] + + rtu_site = ( + rtu_site_ar + rtu.rt.value + ) # this reults in more sensible values but it should be as below? + # rtu_site = rtu_site_ar*rtu.rt.value + + # Site level disease dynamic estimates! + i0_over_n = self.i0_over_n_rv() + sigma_i0 = self.sigma_i0_rv() + eta_i0 = self.eta_i0_rv() + initial_growth = self.initialization_rate_rv() + eta_growth = self.eta_growth_rv() + sigma_growth = self.sigma_growth_rv() + + # Calculate infection and adjusted Rt for each sight using site-level i0 `i0_site_over_n` and initialization rate `growth_site` + # These are computed as a vector in stan code, but iid implementation is probably better for using numpyro.plate + + # site level growth rate + growth_site = ( + initial_growth[0].value + + eta_growth[0].value * sigma_growth[0].value + ) + + growth_site_rv = DeterministicVariable( + "growth_site_rv", jnp.array(growth_site) + ) + + # site-level initial per capita infection incidence + i0_site_over_n = jax.nn.sigmoid( + transforms.logit(i0_over_n[0].value) + + eta_i0[0].value * sigma_i0[0].value + ) + + i0_site_over_n_rv = DeterministicVariable( + "i0_site_over_n_rv", jnp.array(i0_site_over_n) + ) + + infection_initialization_process = InfectionInitializationProcess( + "I0_initialization", + i0_site_over_n_rv, + InitializeInfectionsExponentialGrowth( + self.n_initialization_points, + growth_site_rv, + t_pre_init=self.i0_t_offset, + ), + t_unit=1, + ) + + generation_interval_pmf = self.generation_interval_pmf_rv() + i0 = infection_initialization_process() + + inf_with_feedback_proc_sample = self.inf_with_feedback_proc.sample( + Rt=rtu_site, + I0=i0[0].value, + gen_int=generation_interval_pmf[0].value, + ) + + new_i_site = jnp.concat( + [ + i0[0].value, + inf_with_feedback_proc_sample.post_initialization_infections.value, + ] + ) + r_site_t = r_site_t.at[i, :].set( + inf_with_feedback_proc_sample.rt.value + ) + new_i_site_matrix = new_i_site_matrix.at[i, :].set(new_i_site) + + # number of net 
infected individuals shedding on each day (sum of individuals in dift stages of infection) + model_net_i = jnp.convolve(new_i_site, s, mode="valid")[ + -n_datapoints: ] - ), - ), - autoreg_rv=autoreg_rt_det_rv, - periodic_diff_sd_rv=DeterministicVariable( - name="periodic_diff_sd", value=jnp.array(eta_sd) - ), -) - -with numpyro.handlers.seed(rng_seed=223): - rtu = rt_proc.sample( - duration=n_datapoints - ) # log_r_mu_t_in_weeks in stan - not log anymore and not weekly either - -with numpyro.handlers.seed(rng_seed=223): - t_peak = t_peak_rv() - viral_peak = viral_peak_rv() - dur_shed = dur_shed_rv() + log10_g = self.log10_g_rv() -s = get_vl_trajectory(t_peak[0].value, viral_peak[0].value, dur_shed[0].value, gt_max) + # expected observed viral genomes/mL at all observed and forecasted times + # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop + model_log_v_ot_site = ( + jnp.log(10) * log10_g[0].value + + jnp.log(model_net_i[:(n_datapoints)] + 1e-8) + - jnp.log(self.ww_ml_produced_per_day) + ) + model_log_v_ot = model_log_v_ot.at[i, :].set(model_log_v_ot_site) + state_inf_per_capita = jnp.sum( + self.pop_fraction_reshaped * new_i_site_matrix, axis=0 + ) + # Hospital admission component -# Site-level Rt, to be repeated for each site -r_site_t = jnp.zeros((n_subpops, obs_time + horizon_time)) -new_i_site_matrix = jnp.zeros((n_subpops, n_datapoints + n_initialization_points)) -model_log_v_ot = jnp.zeros((n_subpops, obs_time + horizon_time)) + # p_hosp_w is std_normal - weekly random walk for IHR -for i in range(n_subpops): - with numpyro.handlers.seed(rng_seed=223): - autoreg_rt_site = autoreg_rt_site_rv() - sigma_rt = sigma_rt_rv() + p_hosp_mean = self.p_hosp_mean_rv() + p_hosp_w_sd = self.p_hosp_w_sd_rv() + autoreg_p_hosp = self.autoreg_p_hosp_rv() - rtu_site_ar_init_rv = DistributionalVariable( - "rtu_site_ar_init", - dist.Normal( - 0, - sigma_rt[0].value / jnp.sqrt(1 - jnp.pow(autoreg_rt_site[0].value, 2)), - ), - ) + p_hosp_ar_proc = 
ARProcess("p_hosp") - rtu_site_ar_proc = ARProcess(noise_rv_name="rtu_ar_proc") + p_hosp_ar_init_rv = DistributionalVariable( + "p_hosp_ar_init", + dist.Normal( + 0, + p_hosp_w_sd[0].value + / jnp.sqrt(1 - jnp.pow(autoreg_p_hosp[0].value, 2)), + ), + ) - with numpyro.handlers.seed(rng_seed=223): - rtu_site_ar_init = rtu_site_ar_init_rv() - rtu_site_ar_weekly = rtu_site_ar_proc( + p_hosp_ar_init = p_hosp_ar_init_rv() + p_hosp_ar = p_hosp_ar_proc.sample( n=n_weeks, - init_vals=rtu_site_ar_init[0].value, - autoreg=autoreg_rt_site[0].value, - noise_sd=sigma_rt[0].value, + autoreg=autoreg_p_hosp[0].value, + init_vals=p_hosp_ar_init[0].value, + noise_sd=p_hosp_w_sd[0].value, ) - rtu_site_ar = jnp.repeat( - transformation.ExpTransform()(rtu_site_ar_weekly[0].value), repeats=7 - )[:n_datapoints] - - rtu_site = ( - rtu_site_ar + rtu.rt.value - ) # this reults in more sensible values but it should be as below? - # rtu_site = rtu_site_ar*rtu.rt.value - - # Site level disease dynamic estimates! - with numpyro.handlers.seed(rng_seed=223): - i0_over_n = i0_over_n_rv() - sigma_i0 = sigma_i0_rv() - eta_i0 = eta_i0_rv() - initial_growth = initialization_rate_rv() - eta_growth = eta_growth_rv() - sigma_growth = sigma_growth_rv() - - # Calculate infection and adjusted Rt for each sight using site-level i0 `i0_site_over_n` and initialization rate `growth_site` - # These are computed as a vector in stan code, but iid implementation is probably better for using numpyro.plate - - # site level growth rate - growth_site = initial_growth[0].value + eta_growth[0].value * sigma_growth[0].value - - growth_site_rv = DeterministicVariable("growth_site_rv", jnp.array(growth_site)) - - # site-level initial per capita infection incidence - i0_site_over_n = jax.nn.sigmoid( - transforms.logit(i0_over_n[0].value) + eta_i0[0].value * sigma_i0[0].value - ) - - i0_site_over_n_rv = DeterministicVariable( - "i0_site_over_n_rv", jnp.array(i0_site_over_n) - ) - - infection_initialization_process = 
InfectionInitializationProcess( - "I0_initialization", - i0_site_over_n_rv, - InitializeInfectionsExponentialGrowth( - n_initialization_points, - growth_site_rv, - t_pre_init=i0_t_offset, - ), - t_unit=1, - ) - - with numpyro.handlers.seed(rng_seed=223): - generation_interval_pmf = generation_interval_pmf_rv() - i0 = infection_initialization_process() - inf_with_feedback_proc_sample = inf_with_feedback_proc.sample( - Rt=rtu_site, - I0=i0[0].value, - gen_int=generation_interval_pmf[0].value, + ihr = jnp.repeat( + transformation.SigmoidTransform()( + p_hosp_ar[0].value + p_hosp_mean[0].value + ), + repeats=7, + )[:n_datapoints] + + hosp_wday_effect_raw = self.hosp_wday_effect_rv()[0].value + inf_to_hosp = self.inf_to_hosp_rv()[0].value + + hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) + + potential_latent_hospital_admissions = jnp.convolve( + state_inf_per_capita, + inf_to_hosp, + mode="valid", + )[-n_datapoints:] + + latent_hospital_admissions = ( + potential_latent_hospital_admissions + * ihr + * hosp_wday_effect + * self.state_pop ) - new_i_site = jnp.concat( - [ - i0[0].value, - inf_with_feedback_proc_sample.post_initialization_infections.value, - ] - ) - r_site_t = r_site_t.at[i, :].set(inf_with_feedback_proc_sample.rt.value) - new_i_site_matrix = new_i_site_matrix.at[i, :].set(new_i_site) - - # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) - model_net_i = jnp.convolve(new_i_site, s, mode="valid")[-n_datapoints:] - - with numpyro.handlers.seed(rng_seed=223): - log10_g = log10_g_rv() - - # expected observed viral genomes/mL at all observed and forecasted times - # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop - model_log_v_ot_site = ( - jnp.log(10) * log10_g[0].value - + jnp.log(model_net_i[: (obs_time + horizon_time)] + 1e-8) - - jnp.log(ww_ml_produced_per_day) - ) - model_log_v_ot = model_log_v_ot.at[i, :].set(model_log_v_ot_site) - -state_inf_per_capita = 
jnp.sum(pop_fraction_reshaped * new_i_site_matrix, axis=0) -# Hospital admission component - - -# p_hosp_w is std_normal - weekly random walk for IHR + mode_sigma_ww_site = self.mode_sigma_ww_site_rv()[0].value + sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv()[0].value + eta_log_sigma_ww_site = self.eta_log_sigma_ww_site_rv()[0].value + ww_site_mod_raw = self.ww_site_mod_raw_rv()[0].value + ww_site_mod_sd = self.ww_site_mod_sd_rv()[0].value -with numpyro.handlers.seed(rng_seed=223): - p_hosp_mean = p_hosp_mean_rv() - p_hosp_w_sd = p_hosp_w_sd_rv() - autoreg_p_hosp = autoreg_p_hosp_rv() - -p_hosp_ar_proc = ARProcess("p_hosp") - -p_hosp_ar_init_rv = DistributionalVariable( - "p_hosp_ar_init", - dist.Normal( - 0, - p_hosp_w_sd[0].value / jnp.sqrt(1 - jnp.pow(autoreg_p_hosp[0].value, 2)), - ), -) - -with numpyro.handlers.seed(rng_seed=223): - p_hosp_ar_init = p_hosp_ar_init_rv() - p_hosp_ar = p_hosp_ar_proc.sample( - n=n_weeks, - autoreg=autoreg_p_hosp[0].value, - init_vals=p_hosp_ar_init[0].value, - noise_sd=p_hosp_w_sd[0].value, - ) - -ihr = jnp.repeat( - transformation.SigmoidTransform()(p_hosp_ar[0].value + p_hosp_mean[0].value), - repeats=7, -)[:n_datapoints] - - -with numpyro.handlers.seed(rng_seed=223): - hosp_wday_effect_raw = hosp_wday_effect_rv()[0].value - inf_to_hosp = inf_to_hosp_rv()[0].value - -hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - -potential_latent_hospital_admissions = jnp.convolve( - state_inf_per_capita, - inf_to_hosp, - mode="valid", -)[-n_datapoints:] - -latent_hospital_admissions = ( - potential_latent_hospital_admissions * ihr * hosp_wday_effect * state_pop -) - - -with numpyro.handlers.seed(rng_seed=223): - mode_sigma_ww_site = mode_sigma_ww_site_rv()[0].value - sd_log_sigma_ww_site = sd_log_sigma_ww_site_rv()[0].value - eta_log_sigma_ww_site = eta_log_sigma_ww_site_rv()[0].value - ww_site_mod_raw = ww_site_mod_raw_rv()[0].value - ww_site_mod_sd = ww_site_mod_sd_rv()[0].value - - -# These are the true 
expected genomes at the site level before observation error -# (which is at the lab-site level) -exp_obs_log_v_true = model_log_v_ot[ww_sampled_sites, ww_sampled_times] - -# modify by lab-site specific variation (multiplier!) -ww_site_mod = ww_site_mod_raw * ww_site_mod_sd - -# LHS log transformed obs genomes per person-day, RHS multiplies the expected observed -# genomes by the site-specific multiplier at that sampling time -exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[ww_sampled_lab_sites] + # These are the true expected genomes at the site level before observation error + # (which is at the lab-site level) + exp_obs_log_v_true = model_log_v_ot[ + self.ww_sampled_sites, self.ww_sampled_times + ] -sigma_ww_site = jnp.exp( - jnp.log(mode_sigma_ww_site) + sd_log_sigma_ww_site * eta_log_sigma_ww_site -) + # modify by lab-site specific variation (multiplier!) + ww_site_mod = ww_site_mod_raw * ww_site_mod_sd -g = jnp.power(log10_g[0].value, 10) # Estimated genomes shed per infected individual + # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed + # genomes by the site-specific multiplier at that sampling time + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) + sigma_ww_site = jnp.exp( + jnp.log(mode_sigma_ww_site) + + sd_log_sigma_ww_site * eta_log_sigma_ww_site + ) -log_conc_obs_rv = numpyro.sample( - "log_conc", - dist.Normal( - loc=exp_obs_log_v[ww_uncensored], - scale=sigma_ww_site[ww_sampled_lab_sites[ww_uncensored]], - ), - obs=data_observed_log_conc[ww_uncensored], -) + # g = jnp.power( + # log10_g[0].value, 10 + # ) # Estimated genomes shed per infected individual + + log_conc_obs = numpyro.sample( + "log_conc", + dist.Normal( + loc=exp_obs_log_v[self.ww_uncensored], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_uncensored] + ], + ), + obs=data_observed_log_conc[self.ww_uncensored], + ) -if ww_censored.shape[0] != 0: - log_cdf_values = dist.Normal( - 
loc=exp_obs_log_v[ww_censored], - scale=sigma_ww_site[ww_sampled_lab_sites[ww_censored]], - ).log_cdf(ww_log_lod[ww_censored]) + if self.ww_censored.shape[0] != 0: + log_cdf_values = dist.Normal( + loc=exp_obs_log_v[self.ww_censored], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_censored] + ], + ).log_cdf(self.ww_log_lod[self.ww_censored]) - numpyro.factor("log_prob_censored", log_cdf_values.sum()) + numpyro.factor("log_prob_censored", log_cdf_values.sum()) + observed_hospital_admissions = self.hospital_admission_obs_rv( + mu=latent_hospital_admissions, + obs=data_observed_hospital_admissions, + ) -with numpyro.handlers.seed(rng_seed=223): - observed_hospital_admissions = hospital_admission_obs_rv( - mu=latent_hospital_admissions, - obs=data_observed_hospital_admissions, - ) + return (observed_hospital_admissions, log_conc_obs) diff --git a/pyrenew_covid_wastewater/utils.py b/pyrenew_covid_wastewater/utils.py index 831ccdaf..3a84f396 100644 --- a/pyrenew_covid_wastewater/utils.py +++ b/pyrenew_covid_wastewater/utils.py @@ -19,7 +19,9 @@ def get_vl_trajectory(tpeak, viral_peak, duration_shedding, n): t = jnp.arange(n) s = jnp.where(t <= tpeak, jnp.power(10, growth * t), s) - s = jnp.where(t > tpeak, jnp.maximum(0, viral_peak + wane * tpeak - wane * t), s) + s = jnp.where( + t > tpeak, jnp.maximum(0, viral_peak + wane * tpeak - wane * t), s + ) s = jnp.where(t > tpeak, jnp.power(10, s), s) s = s / jnp.sum(s) From df7ae334a3313f94cc1039b32ea8cc7c9357bd90 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 29 Aug 2024 23:32:21 -0400 Subject: [PATCH 04/50] sync changes --- .../site_level_dynamics_model.py | 53 ++++++++----------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 2a933638..8ddf6cde 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -14,11 
+14,11 @@ from pyrenew.metaclass import Model from pyrenew.process import ARProcess, RtWeeklyDiffARProcess from pyrenew.randomvariable import DistributionalVariable - +from pyrenew.observation import NegativeBinomialObservation from pyrenew_covid_wastewater.utils import get_vl_trajectory -class hosp_only_ww_model(Model): # numpydoc ignore=GL08 +class ww_site_level_dynamics_model(Model): # numpydoc ignore=GL08 def __init__( self, state_pop, @@ -54,7 +54,7 @@ def __init__( eta_log_sigma_ww_site_rv, ww_site_mod_raw_rv, ww_site_mod_sd_rv, - hospital_admission_obs_rv, + phi_rv, ww_ml_produced_per_day, pop_fraction_reshaped, ww_uncensored, @@ -65,7 +65,7 @@ def __init__( ww_log_lod, ): # numpydoc ignore=GL08 self.state_pop = state_pop - self.n_subpops = (n_subpops,) + self.n_subpops = n_subpops self.n_initialization_points = n_initialization_points self.gt_max = gt_max self.i0_t_offset = i0_t_offset @@ -95,7 +95,7 @@ def __init__( self.eta_log_sigma_ww_site_rv = eta_log_sigma_ww_site_rv self.ww_site_mod_raw_rv = ww_site_mod_raw_rv self.ww_site_mod_sd_rv = ww_site_mod_sd_rv - self.hospital_admission_obs_rv = hospital_admission_obs_rv + self.phi_rv = phi_rv self.ww_ml_produced_per_day = ww_ml_produced_per_day self.pop_fraction_reshaped = pop_fraction_reshaped self.ww_uncensored = ww_uncensored @@ -104,6 +104,7 @@ def __init__( self.ww_sampled_sites = ww_sampled_sites self.ww_sampled_times = ww_sampled_times self.ww_log_lod = ww_log_lod + self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, infection_feedback_pmf=infection_feedback_pmf_rv, @@ -124,10 +125,7 @@ def sample( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." ) - elif ( - n_datapoints is not None - and data_observed_hospital_admissions is not None - ): + elif n_datapoints is not None and data_observed_hospital_admissions is not None: raise ValueError( "Cannot pass both n_datapoints and data_observed_hospital_admissions." 
) @@ -193,8 +191,7 @@ def sample( "rtu_site_ar_init", dist.Normal( 0, - sigma_rt[0].value - / jnp.sqrt(1 - jnp.pow(autoreg_rt_site[0].value, 2)), + sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_site, 2)), ), ) @@ -222,7 +219,7 @@ def sample( i0_over_n = self.i0_over_n_rv() sigma_i0 = self.sigma_i0_rv() eta_i0 = self.eta_i0_rv() - initial_growth = self.initialization_rate_rv() + initial_growth = self.initial_growth_rv() eta_growth = self.eta_growth_rv() sigma_growth = self.sigma_growth_rv() @@ -231,8 +228,7 @@ def sample( # site level growth rate growth_site = ( - initial_growth[0].value - + eta_growth[0].value * sigma_growth[0].value + initial_growth[0].value + eta_growth[0].value * sigma_growth[0].value ) growth_site_rv = DeterministicVariable( @@ -275,15 +271,11 @@ def sample( inf_with_feedback_proc_sample.post_initialization_infections.value, ] ) - r_site_t = r_site_t.at[i, :].set( - inf_with_feedback_proc_sample.rt.value - ) + r_site_t = r_site_t.at[i, :].set(inf_with_feedback_proc_sample.rt.value) new_i_site_matrix = new_i_site_matrix.at[i, :].set(new_i_site) # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) - model_net_i = jnp.convolve(new_i_site, s, mode="valid")[ - -n_datapoints: - ] + model_net_i = jnp.convolve(new_i_site, s, mode="valid")[-n_datapoints:] log10_g = self.log10_g_rv() @@ -368,13 +360,10 @@ def sample( # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed # genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = ( - exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] - ) + exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] sigma_ww_site = jnp.exp( - jnp.log(mode_sigma_ww_site) - + sd_log_sigma_ww_site * eta_log_sigma_ww_site + jnp.log(mode_sigma_ww_site) + sd_log_sigma_ww_site * eta_log_sigma_ww_site ) # g = jnp.power( @@ -385,9 +374,7 @@ def sample( "log_conc", dist.Normal( 
loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_uncensored] - ], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], ), obs=data_observed_log_conc[self.ww_uncensored], ) @@ -395,14 +382,16 @@ def sample( if self.ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_censored] - ], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) - observed_hospital_admissions = self.hospital_admission_obs_rv( + hospital_admission_obs_rv = NegativeBinomialObservation( + "observed_hospital_admissions", concentration_rv=self.phi_rv + ) + + observed_hospital_admissions = hospital_admission_obs_rv( mu=latent_hospital_admissions, obs=data_observed_hospital_admissions, ) From 4263287ec19329539ddce3bda9f25b15600de8b6 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 3 Sep 2024 12:54:21 -0400 Subject: [PATCH 05/50] create rvs --- notebooks/site_level_ww_model_demo.qmd | 322 ++++++++++++++++++ .../site_level_dynamics_model.py | 6 +- 2 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 notebooks/site_level_ww_model_demo.qmd diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd new file mode 100644 index 00000000..c5f4fc08 --- /dev/null +++ b/notebooks/site_level_ww_model_demo.qmd @@ -0,0 +1,322 @@ +--- +jupyter: python3 +--- + +```{python} +import json + +import numpyro +import numpyro.distributions as dist +import numpyro.distributions.transforms as transforms +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +from pyrenew.latent import ( + InfectionsWithFeedback +) + +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable +from pyrenew_covid_wastewater.site_level_dynamics_model import 
ww_site_level_dynamics_model + +numpyro.set_host_device_count(1) +``` + +```{python} +with open("data/fit/stan_data.json","r") as file: + stan_data = json.load(file) + +#helper function +from pyrenew_covid_wastewater.utils import * +``` + +```{python} +gt_max = stan_data["gt_max"] #lower=1 +hosp_delay_max = stan_data["hosp_delay_max"] +n_initialization_points = max(gt_max, hosp_delay_max) -1 +i0_t_offset = 0 # check this later + +# maximum time index for the hospital admissions (max number of days we could have observations) +obs_time = stan_data["ot"] +horizon_time = stan_data["ht"] #horizon time (nowcast + forecast time) +n_weeks = stan_data["n_weeks"] +unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW +#n_datapoints = obs_time+horizon_time + +n_subpops = stan_data["n_subpops"] #number of WW sites +state_pop = stan_data["state_pop"] +subpop_size = stan_data["subpop_size"] +norm_pop = stan_data["norm_pop"] +pop_fraction = jnp.array(subpop_size)/norm_pop +pop_fraction_reshaped = pop_fraction[:, jnp.newaxis] + +#mL of ww produced per person per day +ww_ml_produced_per_day = stan_data["mwpd"] +n_ww_lab_sites = stan_data["n_ww_lab_sites"] +ww_log_lod =jnp.array(stan_data["ww_log_lod"]) # The limit of detection in that site at that time point +lab_site_to_site_map = stan_data["lab_site_to_site_map"] # which lab sites correspond to which sites + +n_censored = stan_data["n_censored"] +n_uncensored = stan_data["n_uncensored"] +ww_censored = jnp.array(stan_data["ww_censored"]) #times that the WW data is below the LOD +ww_uncensored = jnp.array(stan_data["ww_uncensored"]) #time that WW data is above LOD + +obs_ww_days = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) + +ww_sampled_sites = stan_data["ww_sampled_sites"] # vector of unique sites in order of the sampled times +ww_sampled_times = stan_data["ww_sampled_times"] # a list of all of the days on which WW is sampled + +ww_sampled_lab_sites = 
jnp.array(stan_data["ww_sampled_lab_sites"]) + +data_observed_log_conc = jnp.array(stan_data["log_conc"]) + +data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) + +``` + +```{python} + +# State-leve R(t) AR + RW implementation: + +eta_sd_sd = stan_data["eta_sd_sd"] +eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)) + +autoreg_rt_a = stan_data["autoreg_rt_a"] +autoreg_rt_b = stan_data["autoreg_rt_b"] +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) + +r_prior_mean = stan_data["r_prior_mean"] +r_prior_sd = stan_data["r_prior_sd"] +r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) + +# log of state level mean R(t) in weeks +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +# viral shedding parameters +viral_shedding_pars = stan_data["viral_shedding_pars"] + +t_peak_mean = viral_shedding_pars[0] +t_peak_sd = viral_shedding_pars[1] +viral_peak_mean = viral_shedding_pars[2] +viral_peak_sd = viral_shedding_pars[3] +dur_shed_mean = viral_shedding_pars[4] +dur_shed_sd = viral_shedding_pars[5] + +t_peak_rv = DistributionalVariable( + "t_peak", dist.Normal(t_peak_mean, t_peak_sd) +) + +viral_peak_rv = DistributionalVariable( + "viral_peak", dist.Normal(viral_peak_mean, viral_peak_sd) +) + +dur_shed_rv = DistributionalVariable( + "dur_shed", dist.Normal(dur_shed_mean, dur_shed_sd) +) + +# Infection +infection_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"] +infection_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"] +infection_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(infection_feedback_prior_logmean, infection_feedback_prior_logsd), + ), + transforms=transforms.AffineTransform(loc=0, scale=-1), +) + +infection_feedback_pmf = stan_data["infection_feedback_pmf"] +infection_feedback_pmf_rv 
= DeterministicPMF( + "infection_feedback_pmf", jnp.array(infection_feedback_pmf) +) + +# generation interval distribution +generation_interval = stan_data["generation_interval"] +generation_interval_pmf_rv = DeterministicPMF( + "generation_interval_pmf", jnp.array(generation_interval) +) + +autoreg_rt_site_a = stan_data["autoreg_rt_site_a"] +autoreg_rt_site_b = stan_data["autoreg_rt_site_b"] +autoreg_rt_site_rv = DistributionalVariable( + "autoreg_rt_site",dist.Beta(autoreg_rt_site_a, autoreg_rt_site_b) + ) + +sigma_rt_prior = stan_data["sigma_rt_prior"] +sigma_rt_rv = DistributionalVariable( + "sigma_rt", dist.Normal(0,sigma_rt_prior) +) + +i0_over_n_prior_a = stan_data["i0_over_n_prior_a"] +i0_over_n_prior_b = stan_data["i0_over_n_prior_b"] +i0_over_n_rv = DistributionalVariable( + "i0_over_n_rv", dist.Beta(i0_over_n_prior_a, i0_over_n_prior_b) +) + +initial_growth_prior_mean = stan_data["initial_growth_prior_mean"] +initial_growth_prior_sd = stan_data["initial_growth_prior_sd"] +initialization_rate_rv = DistributionalVariable( + "rate", + dist.TruncatedNormal( + loc=initial_growth_prior_mean, + scale=initial_growth_prior_sd, + low=-1, + high=1, + ), +) + +sigma_i0_prior_mode = stan_data["sigma_i0_prior_mode"] +sigma_i0_prior_sd =stan_data["sigma_i0_prior_sd"] +# stdev between logit state and site initial per capita infection incidence +sigma_i0_rv = DistributionalVariable( + "sigma_i0", dist.Normal(sigma_i0_prior_mode, sigma_i0_prior_sd) +) + +#z-score on logit scale of state initial per capita infection incidence relative to state value +eta_i0_rv = DistributionalVariable( + "eta_i0", dist.Normal(0,1) #.expand([n_subpops]) +) + +#vector[n_subpops] eta_growth; +eta_growth_rv = DistributionalVariable( + "eta_growth", dist.Normal(0,1) #.expand([n_subpops]) +) + +sigma_growth_rv = DistributionalVariable( + "sigma_growth", dist.Normal(0,0.05) +) + +p_hosp_prior_mean = stan_data["p_hosp_prior_mean"] +p_hosp_sd_logit = stan_data["p_hosp_sd_logit"] +p_hosp_mean_rv = 
DistributionalVariable( + "p_hosp_mean", + dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit), +) # logit scale + +p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"] +p_hosp_w_sd_rv = DistributionalVariable( + "p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0) +) + +autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"] +autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"] +autoreg_p_hosp_rv = DistributionalVariable( + "autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b) +) + +hosp_wday_effect_rv = TransformedVariable( + "hosp_wday_effect", + DistributionalVariable( + "hosp_wday_effect_raw", + dist.Dirichlet(jnp.array(stan_data["hosp_wday_effect_prior_alpha"])), + ), + transforms.AffineTransform(loc=0, scale=7), +) + +inf_to_hosp_rv = DeterministicVariable( + "inf_to_hosp", jnp.array(stan_data["inf_to_hosp"]) +) + +inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"] +inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"] +phi_rv = TransformedVariable( + "phi", + DistributionalVariable( + "inv_sqrt_phi", + dist.TruncatedNormal( + loc=inv_sqrt_phi_prior_mean, + scale=inv_sqrt_phi_prior_sd, + low=1 / jnp.sqrt(5000), + ), + ), + transforms=transforms.PowerTransform(-2), +) + +log10_g_prior_mean = stan_data["log10_g_prior_mean"] # mean log10 of number of genomes per infected individual +log10_g_prior_sd = stan_data["log10_g_prior_sd"] +log10_g_rv = DistributionalVariable( + "log10_g", dist.Normal(log10_g_prior_mean, log10_g_prior_sd) +) + +mode_sigma_ww_site_prior_mode = stan_data["mode_sigma_ww_site_prior_mode"] +mode_sigma_ww_site_prior_sd = stan_data["mode_sigma_ww_site_prior_sd"] +mode_sigma_ww_site_rv = DistributionalVariable( + "mode_sigma_ww_site", dist.Normal(mode_sigma_ww_site_prior_mode,mode_sigma_ww_site_prior_sd) +) + +sd_log_sigma_ww_site_prior_mode = stan_data["sd_log_sigma_ww_site_prior_mode"] +sd_log_sigma_ww_site_prior_sd = stan_data["sd_log_sigma_ww_site_prior_sd"] +sd_log_sigma_ww_site_rv = 
DistributionalVariable( + "sd_log_sigma_ww_site", dist.TruncatedNormal(sd_log_sigma_ww_site_prior_mode, sd_log_sigma_ww_site_prior_sd,low=0) +) + +eta_log_sigma_ww_site_rv = DistributionalVariable( + "eta_log_sigma_ww_site", dist.Normal(0,1).expand([n_ww_lab_sites]) +) + +ww_site_mod_raw_rv = DistributionalVariable( + "ww_site_mod_raw", dist.Normal(0,1).expand([n_ww_lab_sites]) +) + +ww_site_mod_sd_sd = stan_data["ww_site_mod_sd_sd"] +ww_site_mod_sd_rv = DistributionalVariable( + "ww_site_mod_sd", dist.TruncatedNormal(0,ww_site_mod_sd_sd,low=0) +) +``` + +```{python} +my_model = ww_site_level_dynamics_model( + state_pop, + n_subpops, + n_initialization_points, + gt_max, + i0_t_offset, + log_r_mu_intercept_rv, + autoreg_rt_rv, + eta_sd_rv, + t_peak_rv, + viral_peak_rv, + dur_shed_rv, + autoreg_rt_site_rv, + sigma_rt_rv, + i0_over_n_rv, + sigma_i0_rv, + eta_i0_rv, + initialization_rate_rv, + eta_growth_rv, + sigma_growth_rv, + generation_interval_pmf_rv, + infection_feedback_strength_rv, + infection_feedback_pmf_rv, + p_hosp_mean_rv, + p_hosp_w_sd_rv, + autoreg_p_hosp_rv, + hosp_wday_effect_rv, + inf_to_hosp_rv, + log10_g_rv, + mode_sigma_ww_site_rv, + sd_log_sigma_ww_site_rv, + eta_log_sigma_ww_site_rv, + ww_site_mod_raw_rv, + ww_site_mod_sd_rv, + phi_rv, + ww_ml_produced_per_day, + pop_fraction_reshaped, + ww_uncensored, + ww_censored, + ww_sampled_lab_sites, + ww_sampled_sites, + ww_sampled_times, + ww_log_lod, +) +``` + +```{python} +prior_pred = my_model.prior_predictive( + n_datapoints = len(data_observed_hospital_admissions), + numpyro_predictive_args={"num_samples":10} +) +``` + + + diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 8ddf6cde..7cc867fe 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -376,7 +376,11 @@ def sample( loc=exp_obs_log_v[self.ww_uncensored], 
scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], ), - obs=data_observed_log_conc[self.ww_uncensored], + obs=( + data_observed_log_conc[self.ww_uncensored] + if data_observed_log_conc is not None + else None + ), ) if self.ww_censored.shape[0] != 0: From 77932853a34a3dfb8df3f1cfcacdeb5182deef9d Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 5 Sep 2024 16:41:20 -0400 Subject: [PATCH 06/50] sync --- pyrenew_covid_wastewater/site_level_dynamics_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 7cc867fe..2d5f68ce 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -371,7 +371,7 @@ def sample( # ) # Estimated genomes shed per infected individual log_conc_obs = numpyro.sample( - "log_conc", + "log_conc_uncensored", dist.Normal( loc=exp_obs_log_v[self.ww_uncensored], scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], @@ -389,7 +389,7 @@ def sample( scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], ).log_cdf(self.ww_log_lod[self.ww_censored]) - numpyro.factor("log_prob_censored", log_cdf_values.sum()) + numpyro.factor("prob_conc_censored", log_cdf_values.sum()) hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv From 927f18ba0eb697990e9c414a997acc0ff3e8e227 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 16 Sep 2024 14:43:17 -0400 Subject: [PATCH 07/50] synchronize recent pyrenew changes --- notebooks/site_level_ww_model_demo.qmd | 322 ------------------ .../site_level_dynamics_model.py | 112 +++--- 2 files changed, 50 insertions(+), 384 deletions(-) delete mode 100644 notebooks/site_level_ww_model_demo.qmd diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd deleted file mode 100644 index 
c5f4fc08..00000000 --- a/notebooks/site_level_ww_model_demo.qmd +++ /dev/null @@ -1,322 +0,0 @@ ---- -jupyter: python3 ---- - -```{python} -import json - -import numpyro -import numpyro.distributions as dist -import numpyro.distributions.transforms as transforms -from pyrenew.deterministic import DeterministicVariable, DeterministicPMF -from pyrenew.latent import ( - InfectionsWithFeedback -) - -from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew_covid_wastewater.site_level_dynamics_model import ww_site_level_dynamics_model - -numpyro.set_host_device_count(1) -``` - -```{python} -with open("data/fit/stan_data.json","r") as file: - stan_data = json.load(file) - -#helper function -from pyrenew_covid_wastewater.utils import * -``` - -```{python} -gt_max = stan_data["gt_max"] #lower=1 -hosp_delay_max = stan_data["hosp_delay_max"] -n_initialization_points = max(gt_max, hosp_delay_max) -1 -i0_t_offset = 0 # check this later - -# maximum time index for the hospital admissions (max number of days we could have observations) -obs_time = stan_data["ot"] -horizon_time = stan_data["ht"] #horizon time (nowcast + forecast time) -n_weeks = stan_data["n_weeks"] -unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW -#n_datapoints = obs_time+horizon_time - -n_subpops = stan_data["n_subpops"] #number of WW sites -state_pop = stan_data["state_pop"] -subpop_size = stan_data["subpop_size"] -norm_pop = stan_data["norm_pop"] -pop_fraction = jnp.array(subpop_size)/norm_pop -pop_fraction_reshaped = pop_fraction[:, jnp.newaxis] - -#mL of ww produced per person per day -ww_ml_produced_per_day = stan_data["mwpd"] -n_ww_lab_sites = stan_data["n_ww_lab_sites"] -ww_log_lod =jnp.array(stan_data["ww_log_lod"]) # The limit of detection in that site at that time point -lab_site_to_site_map = stan_data["lab_site_to_site_map"] # which lab sites correspond to which sites - -n_censored = stan_data["n_censored"] -n_uncensored = 
stan_data["n_uncensored"] -ww_censored = jnp.array(stan_data["ww_censored"]) #times that the WW data is below the LOD -ww_uncensored = jnp.array(stan_data["ww_uncensored"]) #time that WW data is above LOD - -obs_ww_days = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) - -ww_sampled_sites = stan_data["ww_sampled_sites"] # vector of unique sites in order of the sampled times -ww_sampled_times = stan_data["ww_sampled_times"] # a list of all of the days on which WW is sampled - -ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) - -data_observed_log_conc = jnp.array(stan_data["log_conc"]) - -data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) - -``` - -```{python} - -# State-leve R(t) AR + RW implementation: - -eta_sd_sd = stan_data["eta_sd_sd"] -eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)) - -autoreg_rt_a = stan_data["autoreg_rt_a"] -autoreg_rt_b = stan_data["autoreg_rt_b"] -autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) - -r_prior_mean = stan_data["r_prior_mean"] -r_prior_sd = stan_data["r_prior_sd"] -r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) - -# log of state level mean R(t) in weeks -log_r_mu_intercept_rv = DistributionalVariable( - "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) -) - -# viral shedding parameters -viral_shedding_pars = stan_data["viral_shedding_pars"] - -t_peak_mean = viral_shedding_pars[0] -t_peak_sd = viral_shedding_pars[1] -viral_peak_mean = viral_shedding_pars[2] -viral_peak_sd = viral_shedding_pars[3] -dur_shed_mean = viral_shedding_pars[4] -dur_shed_sd = viral_shedding_pars[5] - -t_peak_rv = DistributionalVariable( - "t_peak", dist.Normal(t_peak_mean, t_peak_sd) -) - -viral_peak_rv = DistributionalVariable( - "viral_peak", dist.Normal(viral_peak_mean, viral_peak_sd) -) - -dur_shed_rv = DistributionalVariable( - "dur_shed", dist.Normal(dur_shed_mean, dur_shed_sd) 
-) - -# Infection -infection_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"] -infection_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"] -infection_feedback_strength_rv = TransformedVariable( - "inf_feedback", - DistributionalVariable( - "inf_feedback_raw", - dist.LogNormal(infection_feedback_prior_logmean, infection_feedback_prior_logsd), - ), - transforms=transforms.AffineTransform(loc=0, scale=-1), -) - -infection_feedback_pmf = stan_data["infection_feedback_pmf"] -infection_feedback_pmf_rv = DeterministicPMF( - "infection_feedback_pmf", jnp.array(infection_feedback_pmf) -) - -# generation interval distribution -generation_interval = stan_data["generation_interval"] -generation_interval_pmf_rv = DeterministicPMF( - "generation_interval_pmf", jnp.array(generation_interval) -) - -autoreg_rt_site_a = stan_data["autoreg_rt_site_a"] -autoreg_rt_site_b = stan_data["autoreg_rt_site_b"] -autoreg_rt_site_rv = DistributionalVariable( - "autoreg_rt_site",dist.Beta(autoreg_rt_site_a, autoreg_rt_site_b) - ) - -sigma_rt_prior = stan_data["sigma_rt_prior"] -sigma_rt_rv = DistributionalVariable( - "sigma_rt", dist.Normal(0,sigma_rt_prior) -) - -i0_over_n_prior_a = stan_data["i0_over_n_prior_a"] -i0_over_n_prior_b = stan_data["i0_over_n_prior_b"] -i0_over_n_rv = DistributionalVariable( - "i0_over_n_rv", dist.Beta(i0_over_n_prior_a, i0_over_n_prior_b) -) - -initial_growth_prior_mean = stan_data["initial_growth_prior_mean"] -initial_growth_prior_sd = stan_data["initial_growth_prior_sd"] -initialization_rate_rv = DistributionalVariable( - "rate", - dist.TruncatedNormal( - loc=initial_growth_prior_mean, - scale=initial_growth_prior_sd, - low=-1, - high=1, - ), -) - -sigma_i0_prior_mode = stan_data["sigma_i0_prior_mode"] -sigma_i0_prior_sd =stan_data["sigma_i0_prior_sd"] -# stdev between logit state and site initial per capita infection incidence -sigma_i0_rv = DistributionalVariable( - "sigma_i0", dist.Normal(sigma_i0_prior_mode, sigma_i0_prior_sd) -) - 
-#z-score on logit scale of state initial per capita infection incidence relative to state value -eta_i0_rv = DistributionalVariable( - "eta_i0", dist.Normal(0,1) #.expand([n_subpops]) -) - -#vector[n_subpops] eta_growth; -eta_growth_rv = DistributionalVariable( - "eta_growth", dist.Normal(0,1) #.expand([n_subpops]) -) - -sigma_growth_rv = DistributionalVariable( - "sigma_growth", dist.Normal(0,0.05) -) - -p_hosp_prior_mean = stan_data["p_hosp_prior_mean"] -p_hosp_sd_logit = stan_data["p_hosp_sd_logit"] -p_hosp_mean_rv = DistributionalVariable( - "p_hosp_mean", - dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit), -) # logit scale - -p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"] -p_hosp_w_sd_rv = DistributionalVariable( - "p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0) -) - -autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"] -autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"] -autoreg_p_hosp_rv = DistributionalVariable( - "autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b) -) - -hosp_wday_effect_rv = TransformedVariable( - "hosp_wday_effect", - DistributionalVariable( - "hosp_wday_effect_raw", - dist.Dirichlet(jnp.array(stan_data["hosp_wday_effect_prior_alpha"])), - ), - transforms.AffineTransform(loc=0, scale=7), -) - -inf_to_hosp_rv = DeterministicVariable( - "inf_to_hosp", jnp.array(stan_data["inf_to_hosp"]) -) - -inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"] -inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"] -phi_rv = TransformedVariable( - "phi", - DistributionalVariable( - "inv_sqrt_phi", - dist.TruncatedNormal( - loc=inv_sqrt_phi_prior_mean, - scale=inv_sqrt_phi_prior_sd, - low=1 / jnp.sqrt(5000), - ), - ), - transforms=transforms.PowerTransform(-2), -) - -log10_g_prior_mean = stan_data["log10_g_prior_mean"] # mean log10 of number of genomes per infected individual -log10_g_prior_sd = stan_data["log10_g_prior_sd"] -log10_g_rv = DistributionalVariable( - "log10_g", 
dist.Normal(log10_g_prior_mean, log10_g_prior_sd) -) - -mode_sigma_ww_site_prior_mode = stan_data["mode_sigma_ww_site_prior_mode"] -mode_sigma_ww_site_prior_sd = stan_data["mode_sigma_ww_site_prior_sd"] -mode_sigma_ww_site_rv = DistributionalVariable( - "mode_sigma_ww_site", dist.Normal(mode_sigma_ww_site_prior_mode,mode_sigma_ww_site_prior_sd) -) - -sd_log_sigma_ww_site_prior_mode = stan_data["sd_log_sigma_ww_site_prior_mode"] -sd_log_sigma_ww_site_prior_sd = stan_data["sd_log_sigma_ww_site_prior_sd"] -sd_log_sigma_ww_site_rv = DistributionalVariable( - "sd_log_sigma_ww_site", dist.TruncatedNormal(sd_log_sigma_ww_site_prior_mode, sd_log_sigma_ww_site_prior_sd,low=0) -) - -eta_log_sigma_ww_site_rv = DistributionalVariable( - "eta_log_sigma_ww_site", dist.Normal(0,1).expand([n_ww_lab_sites]) -) - -ww_site_mod_raw_rv = DistributionalVariable( - "ww_site_mod_raw", dist.Normal(0,1).expand([n_ww_lab_sites]) -) - -ww_site_mod_sd_sd = stan_data["ww_site_mod_sd_sd"] -ww_site_mod_sd_rv = DistributionalVariable( - "ww_site_mod_sd", dist.TruncatedNormal(0,ww_site_mod_sd_sd,low=0) -) -``` - -```{python} -my_model = ww_site_level_dynamics_model( - state_pop, - n_subpops, - n_initialization_points, - gt_max, - i0_t_offset, - log_r_mu_intercept_rv, - autoreg_rt_rv, - eta_sd_rv, - t_peak_rv, - viral_peak_rv, - dur_shed_rv, - autoreg_rt_site_rv, - sigma_rt_rv, - i0_over_n_rv, - sigma_i0_rv, - eta_i0_rv, - initialization_rate_rv, - eta_growth_rv, - sigma_growth_rv, - generation_interval_pmf_rv, - infection_feedback_strength_rv, - infection_feedback_pmf_rv, - p_hosp_mean_rv, - p_hosp_w_sd_rv, - autoreg_p_hosp_rv, - hosp_wday_effect_rv, - inf_to_hosp_rv, - log10_g_rv, - mode_sigma_ww_site_rv, - sd_log_sigma_ww_site_rv, - eta_log_sigma_ww_site_rv, - ww_site_mod_raw_rv, - ww_site_mod_sd_rv, - phi_rv, - ww_ml_produced_per_day, - pop_fraction_reshaped, - ww_uncensored, - ww_censored, - ww_sampled_lab_sites, - ww_sampled_sites, - ww_sampled_times, - ww_log_lod, -) -``` - -```{python} 
-prior_pred = my_model.prior_predictive( - n_datapoints = len(data_observed_hospital_admissions), - numpyro_predictive_args={"num_samples":10} -) -``` - - - diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 2d5f68ce..52e64160 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -12,7 +12,7 @@ InitializeInfectionsExponentialGrowth, ) from pyrenew.metaclass import Model -from pyrenew.process import ARProcess, RtWeeklyDiffARProcess +from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable from pyrenew.observation import NegativeBinomialObservation from pyrenew_covid_wastewater.utils import get_vl_trajectory @@ -109,6 +109,14 @@ def __init__( infection_feedback_strength=infection_feedback_strength_rv, infection_feedback_pmf=infection_feedback_pmf_rv, ) + + self.ar_diff_rt = DifferencedProcess( + fundamental_process=ARProcess( + noise_rv_name="rtu_weekly_diff_first_diff_ar_process_noise" + ), + differencing_order=1, + ) + return None def validate(self): # numpydoc ignore=GL08 @@ -134,45 +142,30 @@ def sample( else: n_datapoints = n_datapoints - n_weeks = n_datapoints // 7 + 1 + n_weeks_post_init = n_datapoints // 7 + 1 - eta_sd = self.eta_sd_rv()[0].value - autoreg_rt = self.autoreg_rt_rv()[0].value - log_r_mu_intercept = self.log_r_mu_intercept_rv()[0].value + eta_sd = self.eta_sd_rv() + autoreg_rt = self.autoreg_rt_rv() + log_r_mu_intercept = self.log_r_mu_intercept_rv() - autoreg_rt_det_rv = DeterministicVariable("autoreg_rt_det", autoreg_rt) - init_rate_of_change_rv = DistributionalVariable( - "init_rate_of_change", + rt_init_rate_of_change_rv = DistributionalVariable( + "rt_init_rate_of_change", dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), ) - init_rate_of_change = init_rate_of_change_rv()[0].value - - rt_proc = RtWeeklyDiffARProcess( - 
name="rtu_state_weekly_diff", - offset=0, - log_rt_rv=DeterministicVariable( - name="log_rt", - value=jnp.array( - [ - log_r_mu_intercept, - log_r_mu_intercept + init_rate_of_change, - ] - ), - ), - autoreg_rv=autoreg_rt_det_rv, - periodic_diff_sd_rv=DeterministicVariable( - name="periodic_diff_sd", value=jnp.array(eta_sd) - ), - ) + rt_init_rate_of_change = rt_init_rate_of_change_rv() - rtu = rt_proc.sample( - duration=n_datapoints - ) # log_r_mu_t_in_weeks in stan - not log anymore and not weekly either + log_rtu_weekly = self.ar_diff_rt( + n=n_weeks_post_init, + init_vals=jnp.array(log_r_mu_intercept), + autoreg=jnp.array(autoreg_rt), + noise_sd=jnp.array(eta_sd), + fundamental_process_init_vals=jnp.array(rt_init_rate_of_change), + ) - t_peak = self.t_peak_rv()[0].value - viral_peak = self.viral_peak_rv()[0].value - dur_shed = self.dur_shed_rv()[0].value + t_peak = self.t_peak_rv() + viral_peak = self.viral_peak_rv() + dur_shed = self.dur_shed_rv() s = get_vl_trajectory(t_peak, viral_peak, dur_shed, self.gt_max) @@ -184,8 +177,8 @@ def sample( model_log_v_ot = jnp.zeros((self.n_subpops, n_datapoints)) for i in range(self.n_subpops): - autoreg_rt_site = self.autoreg_rt_site_rv()[0].value - sigma_rt = self.sigma_rt_rv()[0].value + autoreg_rt_site = self.autoreg_rt_site_rv() + sigma_rt = self.sigma_rt_rv() rtu_site_ar_init_rv = DistributionalVariable( "rtu_site_ar_init", @@ -197,16 +190,17 @@ def sample( rtu_site_ar_proc = ARProcess(noise_rv_name="rtu_ar_proc") - rtu_site_ar_init = rtu_site_ar_init_rv()[0].value + rtu_site_ar_init = rtu_site_ar_init_rv() rtu_site_ar_weekly = rtu_site_ar_proc( - n=n_weeks, + n=n_weeks_post_init, init_vals=rtu_site_ar_init, autoreg=autoreg_rt_site, noise_sd=sigma_rt, ) + # need some changes here on combining the two vals rtu_site_ar = jnp.repeat( - transformation.ExpTransform()(rtu_site_ar_weekly[0].value), + transformation.ExpTransform()(rtu_site_ar_weekly), repeats=7, )[:n_datapoints] @@ -227,9 +221,7 @@ def sample( # These are 
computed as a vector in stan code, but iid implementation is probably better for using numpyro.plate # site level growth rate - growth_site = ( - initial_growth[0].value + eta_growth[0].value * sigma_growth[0].value - ) + growth_site = initial_growth + eta_growth * sigma_growth growth_site_rv = DeterministicVariable( "growth_site_rv", jnp.array(growth_site) @@ -237,8 +229,7 @@ def sample( # site-level initial per capita infection incidence i0_site_over_n = jax.nn.sigmoid( - transforms.logit(i0_over_n[0].value) - + eta_i0[0].value * sigma_i0[0].value + transforms.logit(i0_over_n) + eta_i0 * sigma_i0 ) i0_site_over_n_rv = DeterministicVariable( @@ -261,13 +252,13 @@ def sample( inf_with_feedback_proc_sample = self.inf_with_feedback_proc.sample( Rt=rtu_site, - I0=i0[0].value, - gen_int=generation_interval_pmf[0].value, + I0=i0, + gen_int=generation_interval_pmf, ) new_i_site = jnp.concat( [ - i0[0].value, + i0, inf_with_feedback_proc_sample.post_initialization_infections.value, ] ) @@ -282,7 +273,7 @@ def sample( # expected observed viral genomes/mL at all observed and forecasted times # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop model_log_v_ot_site = ( - jnp.log(10) * log10_g[0].value + jnp.log(10) * log10_g + jnp.log(model_net_i[:(n_datapoints)] + 1e-8) - jnp.log(self.ww_ml_produced_per_day) ) @@ -305,28 +296,25 @@ def sample( "p_hosp_ar_init", dist.Normal( 0, - p_hosp_w_sd[0].value - / jnp.sqrt(1 - jnp.pow(autoreg_p_hosp[0].value, 2)), + p_hosp_w_sd / jnp.sqrt(1 - jnp.pow(autoreg_p_hosp, 2)), ), ) p_hosp_ar_init = p_hosp_ar_init_rv() p_hosp_ar = p_hosp_ar_proc.sample( n=n_weeks, - autoreg=autoreg_p_hosp[0].value, - init_vals=p_hosp_ar_init[0].value, - noise_sd=p_hosp_w_sd[0].value, + autoreg=autoreg_p_hosp, + init_vals=p_hosp_ar_init, + noise_sd=p_hosp_w_sd, ) ihr = jnp.repeat( - transformation.SigmoidTransform()( - p_hosp_ar[0].value + p_hosp_mean[0].value - ), + transformation.SigmoidTransform()(p_hosp_ar + p_hosp_mean), repeats=7, 
)[:n_datapoints] - hosp_wday_effect_raw = self.hosp_wday_effect_rv()[0].value - inf_to_hosp = self.inf_to_hosp_rv()[0].value + hosp_wday_effect_raw = self.hosp_wday_effect_rv() + inf_to_hosp = self.inf_to_hosp_rv() hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) @@ -343,11 +331,11 @@ def sample( * self.state_pop ) - mode_sigma_ww_site = self.mode_sigma_ww_site_rv()[0].value - sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv()[0].value - eta_log_sigma_ww_site = self.eta_log_sigma_ww_site_rv()[0].value - ww_site_mod_raw = self.ww_site_mod_raw_rv()[0].value - ww_site_mod_sd = self.ww_site_mod_sd_rv()[0].value + mode_sigma_ww_site = self.mode_sigma_ww_site_rv() + sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() + eta_log_sigma_ww_site = self.eta_log_sigma_ww_site_rv() + ww_site_mod_raw = self.ww_site_mod_raw_rv() + ww_site_mod_sd = self.ww_site_mod_sd_rv() # These are the true expected genomes at the site level before observation error # (which is at the lab-site level) From d27928cc1f28cf91cf36198d50ee657311e1b278 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 23 Sep 2024 12:04:24 -0400 Subject: [PATCH 08/50] sync changes --- .../site_level_dynamics_model.py | 294 ++++++++++-------- pyrenew_covid_wastewater/utils.py | 4 +- 2 files changed, 163 insertions(+), 135 deletions(-) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 52e64160..331eb3c5 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -6,15 +6,17 @@ import pyrenew.transformation as transformation from pyrenew.arrayutils import tile_until_n from pyrenew.deterministic import DeterministicVariable +from pyrenew.distributions import CensoredNormal from pyrenew.latent import ( InfectionInitializationProcess, InfectionsWithFeedback, InitializeInfectionsExponentialGrowth, ) from pyrenew.metaclass import Model +from 
pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable -from pyrenew.observation import NegativeBinomialObservation + from pyrenew_covid_wastewater.utils import get_vl_trajectory @@ -23,6 +25,7 @@ def __init__( self, state_pop, n_subpops, + unobs_time, n_initialization_points, gt_max, i0_t_offset, @@ -34,12 +37,12 @@ def __init__( dur_shed_rv, autoreg_rt_site_rv, sigma_rt_rv, - i0_over_n_rv, - sigma_i0_rv, - eta_i0_rv, - initialization_rate_rv, - eta_growth_rv, - sigma_growth_rv, + i_first_obs_over_n_rv, + sigma_i_first_obs_rv, + eta_i_first_obs_rv, + sigma_initial_exp_growth_rate_rv, + eta_initial_exp_growth_rate_rv, + mean_initial_exp_growth_rate_rv, generation_interval_pmf_rv, infection_feedback_strength_rv, infection_feedback_pmf_rv, @@ -56,16 +59,18 @@ def __init__( ww_site_mod_sd_rv, phi_rv, ww_ml_produced_per_day, - pop_fraction_reshaped, + pop_fraction, ww_uncensored, ww_censored, ww_sampled_lab_sites, ww_sampled_sites, ww_sampled_times, ww_log_lod, + lab_site_to_site_map, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops + self.unobs_time = unobs_time self.n_initialization_points = n_initialization_points self.gt_max = gt_max self.i0_t_offset = i0_t_offset @@ -77,12 +82,14 @@ def __init__( self.dur_shed_rv = dur_shed_rv self.autoreg_rt_site_rv = autoreg_rt_site_rv self.sigma_rt_rv = sigma_rt_rv - self.i0_over_n_rv = i0_over_n_rv - self.sigma_i0_rv = sigma_i0_rv - self.eta_i0_rv = eta_i0_rv - self.initial_growth_rv = initialization_rate_rv - self.eta_growth_rv = eta_growth_rv - self.sigma_growth_rv = sigma_growth_rv + self.i_first_obs_over_n_rv = i_first_obs_over_n_rv + self.sigma_i_first_obs_rv = sigma_i_first_obs_rv + self.eta_i_first_obs_rv = eta_i_first_obs_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) + self.eta_initial_exp_growth_rate_rv = 
eta_initial_exp_growth_rate_rv + self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.p_hosp_mean_rv = p_hosp_mean_rv self.p_hosp_w_sd_rv = p_hosp_w_sd_rv @@ -97,13 +104,14 @@ def __init__( self.ww_site_mod_sd_rv = ww_site_mod_sd_rv self.phi_rv = phi_rv self.ww_ml_produced_per_day = ww_ml_produced_per_day - self.pop_fraction_reshaped = pop_fraction_reshaped + self.pop_fraction = pop_fraction self.ww_uncensored = ww_uncensored self.ww_censored = ww_censored self.ww_sampled_lab_sites = ww_sampled_lab_sites self.ww_sampled_sites = ww_sampled_sites self.ww_sampled_times = ww_sampled_times self.ww_log_lod = ww_log_lod + self.lab_site_to_site_map = lab_site_to_site_map self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, @@ -111,9 +119,7 @@ def __init__( ) self.ar_diff_rt = DifferencedProcess( - fundamental_process=ARProcess( - noise_rv_name="rtu_weekly_diff_first_diff_ar_process_noise" - ), + fundamental_process=ARProcess(), differencing_order=1, ) @@ -133,7 +139,10 @@ def sample( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." ) - elif n_datapoints is not None and data_observed_hospital_admissions is not None: + elif ( + n_datapoints is not None + and data_observed_hospital_admissions is not None + ): raise ValueError( "Cannot pass both n_datapoints and data_observed_hospital_admissions." 
) @@ -156,6 +165,7 @@ def sample( rt_init_rate_of_change = rt_init_rate_of_change_rv() log_rtu_weekly = self.ar_diff_rt( + noise_name="rtu_weekly_diff_first_diff_ar_process_noise", n=n_weeks_post_init, init_vals=jnp.array(log_r_mu_intercept), autoreg=jnp.array(autoreg_rt), @@ -169,17 +179,25 @@ def sample( s = get_vl_trajectory(t_peak, viral_peak, dur_shed, self.gt_max) - # Site-level Rt, to be repeated for each site - r_site_t = jnp.zeros((self.n_subpops, n_datapoints)) - new_i_site_matrix = jnp.zeros( - (self.n_subpops, n_datapoints + self.n_initialization_points) - ) - model_log_v_ot = jnp.zeros((self.n_subpops, n_datapoints)) + with numpyro.plate("n_subpops", self.n_subpops): + i_first_obs_over_n_site = jax.nn.sigmoid( + transforms.logit(self.i_first_obs_over_n_rv()) + + self.sigma_i_first_obs_rv() * self.eta_i_first_obs_rv() + ) + + initial_exp_growth_rate_site = ( + self.mean_initial_exp_growth_rate_rv() + + self.sigma_initial_exp_growth_rate_rv() + * self.eta_initial_exp_growth_rate_rv() + ) + + log_i0_site = ( + jnp.log(i_first_obs_over_n_site) + - self.unobs_time * initial_exp_growth_rate_site + ) - for i in range(self.n_subpops): autoreg_rt_site = self.autoreg_rt_site_rv() sigma_rt = self.sigma_rt_rv() - rtu_site_ar_init_rv = DistributionalVariable( "rtu_site_ar_init", dist.Normal( @@ -187,111 +205,82 @@ def sample( sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_site, 2)), ), ) - - rtu_site_ar_proc = ARProcess(noise_rv_name="rtu_ar_proc") - rtu_site_ar_init = rtu_site_ar_init_rv() - rtu_site_ar_weekly = rtu_site_ar_proc( - n=n_weeks_post_init, - init_vals=rtu_site_ar_init, - autoreg=autoreg_rt_site, - noise_sd=sigma_rt, - ) - # need some changes here on combining the two vals - rtu_site_ar = jnp.repeat( - transformation.ExpTransform()(rtu_site_ar_weekly), - repeats=7, - )[:n_datapoints] - - rtu_site = ( - rtu_site_ar + rtu.rt.value - ) # this reults in more sensible values but it should be as below? 
- # rtu_site = rtu_site_ar*rtu.rt.value - - # Site level disease dynamic estimates! - i0_over_n = self.i0_over_n_rv() - sigma_i0 = self.sigma_i0_rv() - eta_i0 = self.eta_i0_rv() - initial_growth = self.initial_growth_rv() - eta_growth = self.eta_growth_rv() - sigma_growth = self.sigma_growth_rv() - - # Calculate infection and adjusted Rt for each sight using site-level i0 `i0_site_over_n` and initialization rate `growth_site` - # These are computed as a vector in stan code, but iid implementation is probably better for using numpyro.plate - - # site level growth rate - growth_site = initial_growth + eta_growth * sigma_growth - - growth_site_rv = DeterministicVariable( - "growth_site_rv", jnp.array(growth_site) - ) + rtu_site_ar_proc = ARProcess() + rtu_site_ar_weekly = rtu_site_ar_proc( + noise_name="rtu_ar_proc", + n=n_weeks_post_init, + init_vals=rtu_site_ar_init[jnp.newaxis], + autoreg=autoreg_rt_site[jnp.newaxis], + noise_sd=sigma_rt, + ) - # site-level initial per capita infection incidence - i0_site_over_n = jax.nn.sigmoid( - transforms.logit(i0_over_n) + eta_i0 * sigma_i0 - ) + rtu_site = jnp.repeat( + jnp.exp(rtu_site_ar_weekly + log_rtu_weekly[:, jnp.newaxis]), + repeats=7, + axis=0, + )[:n_datapoints, :] - i0_site_over_n_rv = DeterministicVariable( - "i0_site_over_n_rv", jnp.array(i0_site_over_n) - ) + i0_site_rv = DeterministicVariable("log_i0_site", jnp.exp(log_i0_site)) + initial_exp_growth_rate_site_rv = DeterministicVariable( + "initial_exp_growth_rate_site", initial_exp_growth_rate_site + ) - infection_initialization_process = InfectionInitializationProcess( - "I0_initialization", - i0_site_over_n_rv, - InitializeInfectionsExponentialGrowth( - self.n_initialization_points, - growth_site_rv, - t_pre_init=self.i0_t_offset, - ), - t_unit=1, - ) + infection_initialization_process = InfectionInitializationProcess( + "I0_initialization", + i0_site_rv, + InitializeInfectionsExponentialGrowth( + self.n_initialization_points, + 
initial_exp_growth_rate_site_rv, + t_pre_init=self.i0_t_offset, + ), + ) - generation_interval_pmf = self.generation_interval_pmf_rv() - i0 = infection_initialization_process() + generation_interval_pmf = self.generation_interval_pmf_rv() + i0 = infection_initialization_process() + with numpyro.plate("n_subpops", self.n_subpops): inf_with_feedback_proc_sample = self.inf_with_feedback_proc.sample( Rt=rtu_site, I0=i0, gen_int=generation_interval_pmf, ) - new_i_site = jnp.concat( - [ - i0, - inf_with_feedback_proc_sample.post_initialization_infections.value, - ] - ) - r_site_t = r_site_t.at[i, :].set(inf_with_feedback_proc_sample.rt.value) - new_i_site_matrix = new_i_site_matrix.at[i, :].set(new_i_site) + new_i_site = jnp.concat( + [ + i0, + inf_with_feedback_proc_sample.post_initialization_infections, + ] + ) + r_site_t = inf_with_feedback_proc_sample.rt - # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) - model_net_i = jnp.convolve(new_i_site, s, mode="valid")[-n_datapoints:] + numpyro.deterministic("r_site_t", r_site_t) - log10_g = self.log10_g_rv() + state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_site, axis=1) - # expected observed viral genomes/mL at all observed and forecasted times - # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop - model_log_v_ot_site = ( - jnp.log(10) * log10_g - + jnp.log(model_net_i[:(n_datapoints)] + 1e-8) - - jnp.log(self.ww_ml_produced_per_day) - ) - model_log_v_ot = model_log_v_ot.at[i, :].set(model_log_v_ot_site) + # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) + batch_colvolve_fn = lambda m: jnp.convolve(m, s, mode="valid") + model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( + new_i_site + )[-n_datapoints:, :] - state_inf_per_capita = jnp.sum( - self.pop_fraction_reshaped * new_i_site_matrix, axis=0 + log10_g = self.log10_g_rv() + + # expected observed viral genomes/mL at all 
observed and forecasted times + # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop + model_log_v_ot = ( + jnp.log(10) * log10_g + + jnp.log(model_net_i[:n_datapoints, :] + 1e-8) + - jnp.log(self.ww_ml_produced_per_day) ) - # Hospital admission component - # p_hosp_w is std_normal - weekly random walk for IHR + # Hospital admission component p_hosp_mean = self.p_hosp_mean_rv() p_hosp_w_sd = self.p_hosp_w_sd_rv() autoreg_p_hosp = self.autoreg_p_hosp_rv() - p_hosp_ar_proc = ARProcess("p_hosp") - p_hosp_ar_init_rv = DistributionalVariable( "p_hosp_ar_init", dist.Normal( @@ -301,8 +290,10 @@ def sample( ) p_hosp_ar_init = p_hosp_ar_init_rv() + p_hosp_ar_proc = ARProcess() p_hosp_ar = p_hosp_ar_proc.sample( - n=n_weeks, + noise_name="p_hosp", + n=n_weeks_post_init, autoreg=autoreg_p_hosp, init_vals=p_hosp_ar_init, noise_sd=p_hosp_w_sd, @@ -337,6 +328,8 @@ def sample( ww_site_mod_raw = self.ww_site_mod_raw_rv() ww_site_mod_sd = self.ww_site_mod_sd_rv() + # Observations at the site level (genomes/person/day) are: + # get a vector of genomes/person/day on the days WW was measured # These are the true expected genomes at the site level before observation error # (which is at the lab-site level) exp_obs_log_v_true = model_log_v_ot[ @@ -348,37 +341,19 @@ def sample( # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed # genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) sigma_ww_site = jnp.exp( - jnp.log(mode_sigma_ww_site) + sd_log_sigma_ww_site * eta_log_sigma_ww_site + jnp.log(mode_sigma_ww_site) + + sd_log_sigma_ww_site * eta_log_sigma_ww_site ) # g = jnp.power( - # log10_g[0].value, 10 + # log10_g, 10 # ) # Estimated genomes shed per infected individual - log_conc_obs = numpyro.sample( - "log_conc_uncensored", - dist.Normal( - 
loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], - ), - obs=( - data_observed_log_conc[self.ww_uncensored] - if data_observed_log_conc is not None - else None - ), - ) - - if self.ww_censored.shape[0] != 0: - log_cdf_values = dist.Normal( - loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], - ).log_cdf(self.ww_log_lod[self.ww_censored]) - - numpyro.factor("prob_conc_censored", log_cdf_values.sum()) - hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv ) @@ -388,4 +363,55 @@ def sample( obs=data_observed_hospital_admissions, ) - return (observed_hospital_admissions, log_conc_obs) + log_conc_obs = numpyro.sample( + "log_conc_obs", + CensoredNormal( + loc=exp_obs_log_v, + scale=sigma_ww_site[self.ww_sampled_lab_sites], + lower_limit=self.ww_log_lod, + ), + obs=data_observed_log_conc, + ) + + ww_pred = numpyro.sample( + "site_ww_pred", + dist.Normal( + loc=model_log_v_ot[:, self.lab_site_to_site_map] + ww_site_mod, + scale=sigma_ww_site, + ), + ) + + state_model_net_i = jnp.convolve( + state_inf_per_capita, s, mode="valid" + )[-n_datapoints:] + state_log_c = ( + jnp.log(10) * log10_g + + jnp.log(state_model_net_i[:n_datapoints] + 1e-8) + - jnp.log(self.ww_ml_produced_per_day) + ) + exp_state_ww_conc = jnp.exp(state_log_c) + + state_rt = ( + state_inf_per_capita[-n_datapoints:] + / jnp.convolve( + state_inf_per_capita, + jnp.hstack( + (jnp.array([0]), jnp.array(generation_interval_pmf)) + ), + mode="valid", + )[-n_datapoints:] + ) + + numpyro.deterministic("ww_pred", ww_pred) + numpyro.deterministic("exp_state_ww_conc", exp_state_ww_conc) + numpyro.deterministic("state_rt", state_rt) + + return ( + latent_hospital_admissions, + observed_hospital_admissions, + ww_pred, + exp_state_ww_conc, + state_rt, + r_site_t, + rtu_site, + ) diff --git a/pyrenew_covid_wastewater/utils.py 
b/pyrenew_covid_wastewater/utils.py index 9f8b5013..545239c4 100644 --- a/pyrenew_covid_wastewater/utils.py +++ b/pyrenew_covid_wastewater/utils.py @@ -17,7 +17,9 @@ def get_vl_trajectory(tpeak, viral_peak, duration_shedding, n): t = jnp.arange(n) s = jnp.where(t <= tpeak, jnp.power(10, growth * t), s) - s = jnp.where(t > tpeak, jnp.maximum(0, viral_peak + wane * tpeak - wane * t), s) + s = jnp.where( + t > tpeak, jnp.maximum(0, viral_peak + wane * tpeak - wane * t), s + ) s = jnp.where(t > tpeak, jnp.power(10, s), s) s = s / jnp.sum(s) From 8afc6915779a3b14b71a07c1db813bccb694d244 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 23 Sep 2024 12:07:05 -0400 Subject: [PATCH 09/50] update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 432ccb18..06283c09 100644 --- a/.gitignore +++ b/.gitignore @@ -122,6 +122,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +*.ipynb # IPython profile_default/ From dd0d89944e33c6144a40c719ac22a6809fea9917 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 30 Sep 2024 10:03:06 -0400 Subject: [PATCH 10/50] use helper function --- .../site_level_dynamics_model.py | 61 +++++++++---------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 331eb3c5..4cafa1a0 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -17,6 +17,8 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable +from pyrenew.convolve import compute_delay_ascertained_incidence + from pyrenew_covid_wastewater.utils import get_vl_trajectory @@ -85,9 +87,7 @@ def __init__( self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv self.eta_i_first_obs_rv = eta_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = ( 
- sigma_initial_exp_growth_rate_rv - ) + self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv self.eta_initial_exp_growth_rate_rv = eta_initial_exp_growth_rate_rv self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv @@ -139,10 +139,7 @@ def sample( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." ) - elif ( - n_datapoints is not None - and data_observed_hospital_admissions is not None - ): + elif n_datapoints is not None and data_observed_hospital_admissions is not None: raise ValueError( "Cannot pass both n_datapoints and data_observed_hospital_admissions." ) @@ -222,7 +219,9 @@ def sample( axis=0, )[:n_datapoints, :] - i0_site_rv = DeterministicVariable("log_i0_site", jnp.exp(log_i0_site)) + numpyro.deterministic("rtu_site", rtu_site) + + i0_site_rv = DeterministicVariable("i0_site", jnp.exp(log_i0_site)) initial_exp_growth_rate_site_rv = DeterministicVariable( "initial_exp_growth_rate_site", initial_exp_growth_rate_site ) @@ -254,29 +253,29 @@ def sample( ] ) r_site_t = inf_with_feedback_proc_sample.rt - numpyro.deterministic("r_site_t", r_site_t) state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_site, axis=1) + numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) batch_colvolve_fn = lambda m: jnp.convolve(m, s, mode="valid") - model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - new_i_site - )[-n_datapoints:, :] + model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)(new_i_site)[ + -n_datapoints:, : + ] + numpyro.deterministic("model_net_i", model_net_i) log10_g = self.log10_g_rv() # expected observed viral genomes/mL at all observed and forecasted times - # [n_subpops, ot + ht] model_log_v_ot aka do it for all subpop model_log_v_ot = ( jnp.log(10) * log10_g + 
jnp.log(model_net_i[:n_datapoints, :] + 1e-8) - jnp.log(self.ww_ml_produced_per_day) ) + numpyro.deterministic("model_log_v_ot", model_log_v_ot) # Hospital admission component - p_hosp_mean = self.p_hosp_mean_rv() p_hosp_w_sd = self.p_hosp_w_sd_rv() autoreg_p_hosp = self.autoreg_p_hosp_rv() @@ -309,10 +308,10 @@ def sample( hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = jnp.convolve( - state_inf_per_capita, - inf_to_hosp, - mode="valid", + potential_latent_hospital_admissions = compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, )[-n_datapoints:] latent_hospital_admissions = ( @@ -341,13 +340,10 @@ def sample( # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed # genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = ( - exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] - ) + exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] sigma_ww_site = jnp.exp( - jnp.log(mode_sigma_ww_site) - + sd_log_sigma_ww_site * eta_log_sigma_ww_site + jnp.log(mode_sigma_ww_site) + sd_log_sigma_ww_site * eta_log_sigma_ww_site ) # g = jnp.power( @@ -381,29 +377,29 @@ def sample( ), ) - state_model_net_i = jnp.convolve( - state_inf_per_capita, s, mode="valid" - )[-n_datapoints:] + state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ + -n_datapoints: + ] + numpyro.deterministic("state_model_net_i", state_model_net_i) + state_log_c = ( jnp.log(10) * log10_g + jnp.log(state_model_net_i[:n_datapoints] + 1e-8) - jnp.log(self.ww_ml_produced_per_day) ) + numpyro.deterministic("state_log_c", state_log_c) + exp_state_ww_conc = jnp.exp(state_log_c) + numpyro.deterministic("exp_state_ww_conc", exp_state_ww_conc) state_rt = ( state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack( - 
(jnp.array([0]), jnp.array(generation_interval_pmf)) - ), + jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), mode="valid", )[-n_datapoints:] ) - - numpyro.deterministic("ww_pred", ww_pred) - numpyro.deterministic("exp_state_ww_conc", exp_state_ww_conc) numpyro.deterministic("state_rt", state_rt) return ( @@ -414,4 +410,5 @@ def sample( state_rt, r_site_t, rtu_site, + state_inf_per_capita, ) From b77c6b8efd5d99fb3f8368f9de8cd486054c9bc0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 30 Sep 2024 11:28:41 -0400 Subject: [PATCH 11/50] fix precommit checks --- .../site_level_dynamics_model.py | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 4cafa1a0..7606e5d2 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -5,6 +5,7 @@ import numpyro.distributions.transforms as transforms import pyrenew.transformation as transformation from pyrenew.arrayutils import tile_until_n +from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable from pyrenew.distributions import CensoredNormal from pyrenew.latent import ( @@ -17,8 +18,6 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable -from pyrenew.convolve import compute_delay_ascertained_incidence - from pyrenew_covid_wastewater.utils import get_vl_trajectory @@ -87,7 +86,9 @@ def __init__( self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv self.eta_i_first_obs_rv = eta_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.eta_initial_exp_growth_rate_rv = eta_initial_exp_growth_rate_rv 
self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv @@ -139,7 +140,10 @@ def sample( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." ) - elif n_datapoints is not None and data_observed_hospital_admissions is not None: + elif ( + n_datapoints is not None + and data_observed_hospital_admissions is not None + ): raise ValueError( "Cannot pass both n_datapoints and data_observed_hospital_admissions." ) @@ -259,10 +263,12 @@ def sample( numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) - batch_colvolve_fn = lambda m: jnp.convolve(m, s, mode="valid") - model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)(new_i_site)[ - -n_datapoints:, : - ] + def batch_colvolve_fn(m): + return jnp.convolve(m, s, mode="valid") + + model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( + new_i_site + )[-n_datapoints:, :] numpyro.deterministic("model_net_i", model_net_i) log10_g = self.log10_g_rv() @@ -308,11 +314,13 @@ def sample( hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] + potential_latent_hospital_admissions = ( + compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] + ) latent_hospital_admissions = ( potential_latent_hospital_admissions @@ -340,10 +348,13 @@ def sample( # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed # genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = exp_obs_log_v_true + 
ww_site_mod[self.ww_sampled_lab_sites] + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) sigma_ww_site = jnp.exp( - jnp.log(mode_sigma_ww_site) + sd_log_sigma_ww_site * eta_log_sigma_ww_site + jnp.log(mode_sigma_ww_site) + + sd_log_sigma_ww_site * eta_log_sigma_ww_site ) # g = jnp.power( @@ -359,7 +370,7 @@ def sample( obs=data_observed_hospital_admissions, ) - log_conc_obs = numpyro.sample( + numpyro.sample( "log_conc_obs", CensoredNormal( loc=exp_obs_log_v, @@ -377,9 +388,9 @@ def sample( ), ) - state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ - -n_datapoints: - ] + state_model_net_i = jnp.convolve( + state_inf_per_capita, s, mode="valid" + )[-n_datapoints:] numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -396,7 +407,9 @@ def sample( state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), + jnp.hstack( + (jnp.array([0]), jnp.array(generation_interval_pmf)) + ), mode="valid", )[-n_datapoints:] ) From 0ddaf6d418766a36c75a2219544861d35b701023 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 30 Sep 2024 13:36:33 -0400 Subject: [PATCH 12/50] add .qmd demo file --- .gitignore | 1 + notebooks/site_level_ww_model_demo.qmd | 397 ++++++++++++++++++ .../site_level_dynamics_model.py | 5 +- 3 files changed, 401 insertions(+), 2 deletions(-) create mode 100644 notebooks/site_level_ww_model_demo.qmd diff --git a/.gitignore b/.gitignore index 06283c09..ec4bf65a 100644 --- a/.gitignore +++ b/.gitignore @@ -388,3 +388,4 @@ poetry.lock notebooks/*_files/ notebooks/*.md +notebooks/*.quarto_ipynb diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd new file mode 100644 index 00000000..33ef2e84 --- /dev/null +++ b/notebooks/site_level_ww_model_demo.qmd @@ -0,0 +1,397 @@ +--- +jupyter: python3 +--- + +```{python} +import json + +import numpyro +import 
numpyro.distributions as dist +import numpyro.distributions.transforms as transforms +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable +from pyrenew_covid_wastewater.site_level_dynamics_model import ww_site_level_dynamics_model +import jax + +numpyro.set_host_device_count(4) +``` + +```{python} +with open("data/fit/stan_data.json","r") as file: + stan_data = json.load(file) + +#helper function +from pyrenew_covid_wastewater.utils import * +``` + +```{python} +gt_max = stan_data["gt_max"] #lower=1 +hosp_delay_max = stan_data["hosp_delay_max"] +n_initialization_points = max(gt_max, hosp_delay_max) +i0_t_offset = 0 # check this later + +# maximum time index for the hospital admissions (max number of days we could have observations) +obs_time = stan_data["ot"] +horizon_time = stan_data["ht"] #horizon time (nowcast + forecast time) +n_weeks = stan_data["n_weeks"] +unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW +#n_datapoints = obs_time+horizon_time + +n_subpops = stan_data["n_subpops"] #number of WW sites +state_pop = stan_data["state_pop"] +subpop_size = stan_data["subpop_size"] +norm_pop = stan_data["norm_pop"] +pop_fraction = jnp.array(subpop_size)/norm_pop + +#mL of ww produced per person per day +ww_ml_produced_per_day = stan_data["mwpd"] +n_ww_lab_sites = stan_data["n_ww_lab_sites"] +ww_log_lod =jnp.array(stan_data["ww_log_lod"]) # The limit of detection in that site at that time point +lab_site_to_site_map = jnp.array(stan_data["lab_site_to_site_map"]) # which lab sites correspond to which sites + +n_censored = stan_data["n_censored"] +n_uncensored = stan_data["n_uncensored"] +ww_censored = jnp.array(stan_data["ww_censored"]) #times that the WW data is below the LOD +ww_uncensored = jnp.array(stan_data["ww_uncensored"]) #time that WW data is above LOD + +obs_ww_days = stan_data["owt"] #number of days of observed WW 
(should be roughly ot/7) + +ww_sampled_sites = jnp.array(stan_data["ww_sampled_sites"]) # vector of unique sites in order of the sampled times +ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) # a list of all of the days on which WW is sampled +ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) + +data_observed_log_conc = jnp.array(stan_data["log_conc"]) +data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) + +``` + +```{python} +# State-leve R(t) AR + RW implementation: + +eta_sd_sd = stan_data["eta_sd_sd"] +eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)) + +autoreg_rt_a = stan_data["autoreg_rt_a"] +autoreg_rt_b = stan_data["autoreg_rt_b"] +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) + +r_prior_mean = stan_data["r_prior_mean"] +r_prior_sd = stan_data["r_prior_sd"] +r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) + +# log of state level mean R(t) in weeks +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +``` + +```{python} +# viral shedding parameters +viral_shedding_pars = stan_data["viral_shedding_pars"] + +t_peak_mean = viral_shedding_pars[0] +t_peak_sd = viral_shedding_pars[1] +viral_peak_mean = viral_shedding_pars[2] +viral_peak_sd = viral_shedding_pars[3] +dur_shed_mean = viral_shedding_pars[4] +dur_shed_sd = viral_shedding_pars[5] + +t_peak_rv = DistributionalVariable( + "t_peak", dist.TruncatedNormal(t_peak_mean, t_peak_sd,low=0) +) + +viral_peak_rv = DistributionalVariable( + "viral_peak", dist.Normal(viral_peak_mean, viral_peak_sd) +) + +dur_shed_rv = DistributionalVariable( + "dur_shed", dist.TruncatedNormal(dur_shed_mean, dur_shed_sd, low=0) +) +``` + +```{python} +# Infection and site-level dynamics +infection_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"] +infection_feedback_prior_logsd = 
stan_data["inf_feedback_prior_logsd"] +infection_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(infection_feedback_prior_logmean, infection_feedback_prior_logsd), + ), + transforms=transforms.AffineTransform(loc=0, scale=-1), +) + +infection_feedback_pmf = stan_data["infection_feedback_pmf"] +infection_feedback_pmf_rv = DeterministicPMF( + "infection_feedback_pmf", jnp.array(infection_feedback_pmf) +) + +# generation interval distribution +generation_interval = stan_data["generation_interval"] +generation_interval_pmf_rv = DeterministicPMF( + "generation_interval_pmf", jnp.array(generation_interval) +) + +autoreg_rt_site_a = stan_data["autoreg_rt_site_a"] +autoreg_rt_site_b = stan_data["autoreg_rt_site_b"] +autoreg_rt_site_rv = DistributionalVariable( + "autoreg_rt_site",dist.Beta(autoreg_rt_site_a, autoreg_rt_site_b) + ) + +sigma_rt_prior = stan_data["sigma_rt_prior"] +sigma_rt_rv = DistributionalVariable( + "sigma_rt", dist.TruncatedNormal(0,sigma_rt_prior,low=0) +) + +i_first_obs_over_n_prior_a = stan_data["i_first_obs_over_n_prior_a"] +i_first_obs_over_n_prior_b = stan_data["i_first_obs_over_n_prior_b"] +i_first_obs_over_n_rv = DistributionalVariable( + "i_first_obs_over_n", dist.Beta(i_first_obs_over_n_prior_a, i_first_obs_over_n_prior_b) +) + +sigma_i_first_obs_prior_mode = stan_data["sigma_i_first_obs_prior_mode"] +sigma_i_first_obs_prior_sd = stan_data["sigma_i_first_obs_prior_sd"] +sigma_i_first_obs_rv = DistributionalVariable( + "sigma_i_first_obs", dist.TruncatedNormal( + sigma_i_first_obs_prior_mode,sigma_i_first_obs_prior_sd,low=0 + ) +) + +eta_i_first_obs_rv = DistributionalVariable( + "eta_i_first_obs", dist.Normal(0,1) +) + +eta_initial_exp_growth_rate_rv = DistributionalVariable( + "eta_initial_exp_growth_rate", dist.Normal(0,1) +) + +sigma_initial_exp_growth_rate_prior_mode = stan_data["sigma_initial_exp_growth_rate_prior_mode"] +sigma_initial_exp_growth_rate_prior_sd 
= stan_data["sigma_initial_exp_growth_rate_prior_sd"] +sigma_initial_exp_growth_rate_rv = DistributionalVariable( + "sigma_initial_exp_growth_rate", dist.TruncatedNormal( + sigma_initial_exp_growth_rate_prior_mode,sigma_initial_exp_growth_rate_prior_sd,low=0 + ) +) + +mean_initial_exp_growth_rate_prior_mean = stan_data["mean_initial_exp_growth_rate_prior_mean"] +mean_initial_exp_growth_rate_prior_sd = stan_data["mean_initial_exp_growth_rate_prior_sd"] +# mean_initial_exp_growth_rate_rv = DistributionalVariable( +# "mean_initial_exp_growth_rate", dist.Normal( +# mean_initial_exp_growth_rate_prior_mean, mean_initial_exp_growth_rate_prior_sd +# ) +# ) + +#stan code uses normal distribution but hosp_only model uses TruncatedNormal +mean_initial_exp_growth_rate_rv = DistributionalVariable( + "mean_initial_exp_growth_rate", dist.TruncatedNormal( + loc=mean_initial_exp_growth_rate_prior_mean, + scale=mean_initial_exp_growth_rate_prior_sd, + low=-1, + high=1, + ) +) + +``` + +```{python} +p_hosp_prior_mean = stan_data["p_hosp_prior_mean"] +p_hosp_sd_logit = stan_data["p_hosp_sd_logit"] +p_hosp_mean_rv = DistributionalVariable( + "p_hosp_mean", + dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit), +) # logit scale + +p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"] +p_hosp_w_sd_rv = DistributionalVariable( + "p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0) +) + +autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"] +autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"] +autoreg_p_hosp_rv = DistributionalVariable( + "autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b) +) + +hosp_wday_effect_rv = TransformedVariable( + "hosp_wday_effect", + DistributionalVariable( + "hosp_wday_effect_raw", + dist.Dirichlet(jnp.array(stan_data["hosp_wday_effect_prior_alpha"])), + ), + transforms.AffineTransform(loc=0, scale=7), +) + +inf_to_hosp_rv = DeterministicVariable( + "inf_to_hosp", jnp.array(stan_data["inf_to_hosp"]) +) + +inv_sqrt_phi_prior_mean = 
stan_data["inv_sqrt_phi_prior_mean"] +inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"] + +phi_rv = TransformedVariable( + "phi", + DistributionalVariable( + "inv_sqrt_phi", + dist.TruncatedNormal( + loc=inv_sqrt_phi_prior_mean, + scale=inv_sqrt_phi_prior_sd, + low=1 / jnp.sqrt(5000), + ), + ), + transforms=transforms.PowerTransform(-2), +) + +``` + +```{python} +log10_g_prior_mean = stan_data["log10_g_prior_mean"] # mean log10 of number of genomes per infected individual +log10_g_prior_sd = stan_data["log10_g_prior_sd"] +log10_g_rv = DistributionalVariable( + "log10_g", dist.Normal(log10_g_prior_mean, log10_g_prior_sd) +) + +mode_sigma_ww_site_prior_mode = stan_data["mode_sigma_ww_site_prior_mode"] +mode_sigma_ww_site_prior_sd = stan_data["mode_sigma_ww_site_prior_sd"] +mode_sigma_ww_site_rv = DistributionalVariable( + "mode_sigma_ww_site", dist.Normal(mode_sigma_ww_site_prior_mode,mode_sigma_ww_site_prior_sd) +) + +sd_log_sigma_ww_site_prior_mode = stan_data["sd_log_sigma_ww_site_prior_mode"] +sd_log_sigma_ww_site_prior_sd = stan_data["sd_log_sigma_ww_site_prior_sd"] +sd_log_sigma_ww_site_rv = DistributionalVariable( + "sd_log_sigma_ww_site", dist.TruncatedNormal(sd_log_sigma_ww_site_prior_mode, sd_log_sigma_ww_site_prior_sd,low=0) +) + +eta_log_sigma_ww_site_rv = DistributionalVariable( + "eta_log_sigma_ww_site", dist.Normal(0,1).expand([n_ww_lab_sites]) +) + +ww_site_mod_raw_rv = DistributionalVariable( + "ww_site_mod_raw", dist.Normal(0,1).expand([n_ww_lab_sites]) +) + +ww_site_mod_sd_sd = stan_data["ww_site_mod_sd_sd"] +ww_site_mod_sd_rv = DistributionalVariable( + "ww_site_mod_sd", dist.TruncatedNormal(0,ww_site_mod_sd_sd,low=0) +) +``` + +```{python} +my_model = ww_site_level_dynamics_model( + state_pop, + n_subpops, + unobs_time, + n_initialization_points, + gt_max, + i0_t_offset, + log_r_mu_intercept_rv, + autoreg_rt_rv, + eta_sd_rv, + t_peak_rv, + viral_peak_rv, + dur_shed_rv, + autoreg_rt_site_rv, + sigma_rt_rv, + i_first_obs_over_n_rv, + 
sigma_i_first_obs_rv, + eta_i_first_obs_rv, + sigma_initial_exp_growth_rate_rv, + eta_initial_exp_growth_rate_rv, + mean_initial_exp_growth_rate_rv, + generation_interval_pmf_rv, + infection_feedback_strength_rv, + infection_feedback_pmf_rv, + p_hosp_mean_rv, + p_hosp_w_sd_rv, + autoreg_p_hosp_rv, + hosp_wday_effect_rv, + inf_to_hosp_rv, + log10_g_rv, + mode_sigma_ww_site_rv, + sd_log_sigma_ww_site_rv, + eta_log_sigma_ww_site_rv, + ww_site_mod_raw_rv, + ww_site_mod_sd_rv, + phi_rv, + ww_ml_produced_per_day, + pop_fraction, + ww_uncensored, + ww_censored, + ww_sampled_lab_sites, + ww_sampled_sites, + ww_sampled_times, + ww_log_lod, + lab_site_to_site_map +) +``` + +```{python} +with numpyro.handlers.seed(rng_seed=9): + test_model_sample = my_model.sample(n_datapoints=50) +``` + +```{python} +latent_hospital_admissions = test_model_sample[0] +observed_hospital_admissions =test_model_sample[1] +ww_pred = test_model_sample[2] +exp_state_ww_conc=test_model_sample[3] +state_rt = test_model_sample[4] +r_site_t = test_model_sample[5] +rtu_site = test_model_sample[6] +state_inf_per_capita = test_model_sample[7] +``` + +```{python} +state_rt +``` + +```{python} +observed_hospital_admissions +``` + +```{python} +state_inf_per_capita*state_pop +``` + +```{python} +# n_forecast_days = 35 + +# prior_predictive = my_model.prior_predictive( +# n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days, +# numpyro_predictive_args={"num_samples": 200}, +# ) +``` + +```{python} +my_model.run( + num_warmup=1000, + num_samples=500, + rng_key=jax.random.key(10), + data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_log_conc=data_observed_log_conc, + mcmc_args=dict(num_chains=2, progress_bar=False), +) +``` + +```{python} +import arviz as az + +idata = az.from_numpyro( + my_model.mcmc +) +``` + +```{python} +idata.posterior.r_site_t +``` + +```{python} +idata.posterior.log_i0_site +``` + +```{python} +#summary = az.summary(idata) +``` diff 
--git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 7606e5d2..610e48c4 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -184,18 +184,19 @@ def sample( i_first_obs_over_n_site = jax.nn.sigmoid( transforms.logit(self.i_first_obs_over_n_rv()) + self.sigma_i_first_obs_rv() * self.eta_i_first_obs_rv() - ) + ) # per capita infection incidence at the first observed time initial_exp_growth_rate_site = ( self.mean_initial_exp_growth_rate_rv() + self.sigma_initial_exp_growth_rate_rv() * self.eta_initial_exp_growth_rate_rv() - ) + ) # site level unobserved period growth rate log_i0_site = ( jnp.log(i_first_obs_over_n_site) - self.unobs_time * initial_exp_growth_rate_site ) + numpyro.deterministic("log_i0_site", log_i0_site) autoreg_rt_site = self.autoreg_rt_site_rv() sigma_rt = self.sigma_rt_rv() From 604a13155cae6aeb8920884ae4a17e0f5c04c5f0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 30 Sep 2024 15:16:49 -0400 Subject: [PATCH 13/50] test i0 values --- notebooks/site_level_ww_model_demo.qmd | 36 +++++++------------ .../site_level_dynamics_model.py | 18 +++++----- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index 33ef2e84..ba353711 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -329,31 +329,21 @@ my_model = ww_site_level_dynamics_model( ``` ```{python} -with numpyro.handlers.seed(rng_seed=9): - test_model_sample = my_model.sample(n_datapoints=50) +# with numpyro.handlers.seed(rng_seed=242): +# test_model_sample = my_model.sample(n_datapoints=50) ``` - -```{python} -latent_hospital_admissions = test_model_sample[0] -observed_hospital_admissions =test_model_sample[1] -ww_pred = test_model_sample[2] -exp_state_ww_conc=test_model_sample[3] -state_rt = test_model_sample[4] 
-r_site_t = test_model_sample[5] -rtu_site = test_model_sample[6] -state_inf_per_capita = test_model_sample[7] -``` - -```{python} -state_rt -``` - -```{python} -observed_hospital_admissions -``` - ```{python} -state_inf_per_capita*state_pop +n_runs = 10 # Number of times to run the sampling +base_key = jax.random.PRNGKey(242) +samples = [] +for i in range(n_runs): + key = jax.random.split(base_key, n_runs)[i] + with numpyro.handlers.seed(rng_seed=key): + sample = my_model.sample(n_datapoints=50) + samples.append(sample) + +# Only samples returned is I0 +print(samples) ``` ```{python} diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 610e48c4..57f2a436 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -417,12 +417,14 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_rt", state_rt) return ( - latent_hospital_admissions, - observed_hospital_admissions, - ww_pred, - exp_state_ww_conc, - state_rt, - r_site_t, - rtu_site, - state_inf_per_capita, + # latent_hospital_admissions, + # observed_hospital_admissions, + # ww_pred, + # exp_state_ww_conc, + # state_rt, + # r_site_t, + # rtu_site, + # state_inf_per_capita, + # log_i0_site, + i0, ) From c09dfe2f018c0faa40271054efc69ad5985a642b Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 30 Sep 2024 19:51:37 -0400 Subject: [PATCH 14/50] record variables for debug --- pyrenew_covid_wastewater/site_level_dynamics_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 57f2a436..a56254d5 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -191,6 +191,9 @@ def sample( + self.sigma_initial_exp_growth_rate_rv() * self.eta_initial_exp_growth_rate_rv() ) # site level 
unobserved period growth rate + numpyro.deterministic( + "initial_exp_growth_rate_site", initial_exp_growth_rate_site + ) log_i0_site = ( jnp.log(i_first_obs_over_n_site) @@ -329,6 +332,9 @@ def batch_colvolve_fn(m): * hosp_wday_effect * self.state_pop ) + numpyro.deterministic( + "latent_hospital_admissions", latent_hospital_admissions + ) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() @@ -388,6 +394,7 @@ def batch_colvolve_fn(m): scale=sigma_ww_site, ), ) + numpyro.deterministic("ww_pred", ww_pred) state_model_net_i = jnp.convolve( state_inf_per_capita, s, mode="valid" From ce3ff89b66dd766b92af1eaea94a72c76db0b884 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 2 Oct 2024 15:18:44 -0400 Subject: [PATCH 15/50] print trace for debugging and change variables incorrectly specifies as vector to scalar --- notebooks/site_level_ww_model_demo.qmd | 65 +++++--------- .../site_level_dynamics_model.py | 90 ++++++++++--------- 2 files changed, 73 insertions(+), 82 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index ba353711..cfa85e24 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -178,12 +178,6 @@ sigma_initial_exp_growth_rate_rv = DistributionalVariable( mean_initial_exp_growth_rate_prior_mean = stan_data["mean_initial_exp_growth_rate_prior_mean"] mean_initial_exp_growth_rate_prior_sd = stan_data["mean_initial_exp_growth_rate_prior_sd"] -# mean_initial_exp_growth_rate_rv = DistributionalVariable( -# "mean_initial_exp_growth_rate", dist.Normal( -# mean_initial_exp_growth_rate_prior_mean, mean_initial_exp_growth_rate_prior_sd -# ) -# ) - #stan code uses normal distribution but hosp_only model uses TruncatedNormal mean_initial_exp_growth_rate_rv = DistributionalVariable( "mean_initial_exp_growth_rate", dist.TruncatedNormal( @@ -266,11 +260,11 @@ sd_log_sigma_ww_site_rv = DistributionalVariable( ) 
eta_log_sigma_ww_site_rv = DistributionalVariable( - "eta_log_sigma_ww_site", dist.Normal(0,1).expand([n_ww_lab_sites]) + "eta_log_sigma_ww_site", dist.Normal(0,1) ) ww_site_mod_raw_rv = DistributionalVariable( - "ww_site_mod_raw", dist.Normal(0,1).expand([n_ww_lab_sites]) + "ww_site_mod_raw", dist.Normal(0,1) ) ww_site_mod_sd_sd = stan_data["ww_site_mod_sd_sd"] @@ -283,6 +277,7 @@ ww_site_mod_sd_rv = DistributionalVariable( my_model = ww_site_level_dynamics_model( state_pop, n_subpops, + n_ww_lab_sites, unobs_time, n_initialization_points, gt_max, @@ -329,21 +324,8 @@ my_model = ww_site_level_dynamics_model( ``` ```{python} -# with numpyro.handlers.seed(rng_seed=242): -# test_model_sample = my_model.sample(n_datapoints=50) -``` -```{python} -n_runs = 10 # Number of times to run the sampling -base_key = jax.random.PRNGKey(242) -samples = [] -for i in range(n_runs): - key = jax.random.split(base_key, n_runs)[i] - with numpyro.handlers.seed(rng_seed=key): - sample = my_model.sample(n_datapoints=50) - samples.append(sample) - -# Only samples returned is I0 -print(samples) +with numpyro.handlers.seed(rng_seed=242): + test_model_sample = my_model.sample(n_datapoints=50) ``` ```{python} @@ -356,32 +338,33 @@ print(samples) ``` ```{python} +from numpyro.infer.initialization import init_to_sample my_model.run( - num_warmup=1000, - num_samples=500, - rng_key=jax.random.key(10), + num_warmup=100, + num_samples=100, + rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, data_observed_log_conc=data_observed_log_conc, - mcmc_args=dict(num_chains=2, progress_bar=False), + mcmc_args=dict(num_chains=1), + nuts_args=dict(init_strategy=init_to_sample) ) ``` ```{python} -import arviz as az - -idata = az.from_numpyro( - my_model.mcmc -) -``` - -```{python} -idata.posterior.r_site_t -``` - -```{python} -idata.posterior.log_i0_site +with numpyro.handlers.trace() as tr: + my_model.run( + num_warmup=100, + num_samples=100, + 
rng_key=jax.random.key(223), + data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_log_conc=data_observed_log_conc, + mcmc_args=dict(num_chains=1), + nuts_args=dict(init_strategy=init_to_sample) + ) ``` ```{python} -#summary = az.summary(idata) +# Print trace of the random variables +for site in tr.values(): + print(site['name'], site['value']) ``` diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index a56254d5..c70bcd75 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -26,6 +26,7 @@ def __init__( self, state_pop, n_subpops, + n_ww_lab_sites, unobs_time, n_initialization_points, gt_max, @@ -71,6 +72,7 @@ def __init__( ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops + self.n_ww_lab_sites = n_ww_lab_sites self.unobs_time = unobs_time self.n_initialization_points = n_initialization_points self.gt_max = gt_max @@ -173,6 +175,7 @@ def sample( noise_sd=jnp.array(eta_sd), fundamental_process_init_vals=jnp.array(rt_init_rate_of_change), ) + numpyro.deterministic("log_rtu_weekly", log_rtu_weekly) t_peak = self.t_peak_rv() viral_peak = self.viral_peak_rv() @@ -181,35 +184,46 @@ def sample( s = get_vl_trajectory(t_peak, viral_peak, dur_shed, self.gt_max) with numpyro.plate("n_subpops", self.n_subpops): - i_first_obs_over_n_site = jax.nn.sigmoid( - transforms.logit(self.i_first_obs_over_n_rv()) - + self.sigma_i_first_obs_rv() * self.eta_i_first_obs_rv() - ) # per capita infection incidence at the first observed time + eta_i_first_obs = self.eta_i_first_obs_rv() + eta_initial_exp_growth_rate = self.eta_initial_exp_growth_rate_rv() - initial_exp_growth_rate_site = ( + i_first_obs_over_n_site = jax.nn.sigmoid( + transforms.logit(self.i_first_obs_over_n_rv()) + + self.sigma_i_first_obs_rv() * eta_i_first_obs + ) # per capita infection incidence at the first 
observed time + numpyro.deterministic( + "i_first_obs_over_n_site", i_first_obs_over_n_site + ) + + initial_exp_growth_rate_site = jnp.clip( + ( self.mean_initial_exp_growth_rate_rv() + self.sigma_initial_exp_growth_rate_rv() - * self.eta_initial_exp_growth_rate_rv() - ) # site level unobserved period growth rate - numpyro.deterministic( - "initial_exp_growth_rate_site", initial_exp_growth_rate_site - ) + * eta_initial_exp_growth_rate + ), + a_min=-0.005, + a_max=0.005, + ) # site level unobserved period growth rate + numpyro.deterministic( + "initial_exp_growth_rate_site", initial_exp_growth_rate_site + ) - log_i0_site = ( - jnp.log(i_first_obs_over_n_site) - - self.unobs_time * initial_exp_growth_rate_site - ) - numpyro.deterministic("log_i0_site", log_i0_site) - - autoreg_rt_site = self.autoreg_rt_site_rv() - sigma_rt = self.sigma_rt_rv() - rtu_site_ar_init_rv = DistributionalVariable( - "rtu_site_ar_init", - dist.Normal( - 0, - sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_site, 2)), - ), - ) + log_i0_site = ( + jnp.log(i_first_obs_over_n_site) + - self.unobs_time * initial_exp_growth_rate_site + ) + numpyro.deterministic("log_i0_site", log_i0_site) + + autoreg_rt_site = self.autoreg_rt_site_rv() + sigma_rt = self.sigma_rt_rv() + rtu_site_ar_init_rv = DistributionalVariable( + "rtu_site_ar_init", + dist.Normal( + 0, + sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_site, 2)), + ), + ) + with numpyro.plate("n_subpops", self.n_subpops): rtu_site_ar_init = rtu_site_ar_init_rv() rtu_site_ar_proc = ARProcess() @@ -221,6 +235,8 @@ def sample( noise_sd=sigma_rt, ) + numpyro.deterministic("rtu_site_ar_weekly", rtu_site_ar_weekly) + rtu_site = jnp.repeat( jnp.exp(rtu_site_ar_weekly + log_rtu_weekly[:, jnp.newaxis]), repeats=7, @@ -246,6 +262,7 @@ def sample( generation_interval_pmf = self.generation_interval_pmf_rv() i0 = infection_initialization_process() + numpyro.deterministic("i0", i0) with numpyro.plate("n_subpops", self.n_subpops): inf_with_feedback_proc_sample = 
self.inf_with_feedback_proc.sample( @@ -301,7 +318,7 @@ def batch_colvolve_fn(m): p_hosp_ar_init = p_hosp_ar_init_rv() p_hosp_ar_proc = ARProcess() p_hosp_ar = p_hosp_ar_proc.sample( - noise_name="p_hosp", + noise_name="p_hosp_noise", n=n_weeks_post_init, autoreg=autoreg_p_hosp, init_vals=p_hosp_ar_init, @@ -312,6 +329,7 @@ def batch_colvolve_fn(m): transformation.SigmoidTransform()(p_hosp_ar + p_hosp_mean), repeats=7, )[:n_datapoints] + numpyro.deterministic("ihr", ihr) hosp_wday_effect_raw = self.hosp_wday_effect_rv() inf_to_hosp = self.inf_to_hosp_rv() @@ -338,8 +356,10 @@ def batch_colvolve_fn(m): mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() - eta_log_sigma_ww_site = self.eta_log_sigma_ww_site_rv() - ww_site_mod_raw = self.ww_site_mod_raw_rv() + + with numpyro.plate("n_ww_lab_sites", self.n_ww_lab_sites): + eta_log_sigma_ww_site = self.eta_log_sigma_ww_site_rv() + ww_site_mod_raw = self.ww_site_mod_raw_rv() ww_site_mod_sd = self.ww_site_mod_sd_rv() # Observations at the site level (genomes/person/day) are: @@ -394,7 +414,6 @@ def batch_colvolve_fn(m): scale=sigma_ww_site, ), ) - numpyro.deterministic("ww_pred", ww_pred) state_model_net_i = jnp.convolve( state_inf_per_capita, s, mode="valid" @@ -423,15 +442,4 @@ def batch_colvolve_fn(m): ) numpyro.deterministic("state_rt", state_rt) - return ( - # latent_hospital_admissions, - # observed_hospital_admissions, - # ww_pred, - # exp_state_ww_conc, - # state_rt, - # r_site_t, - # rtu_site, - # state_inf_per_capita, - # log_i0_site, - i0, - ) + return latent_hospital_admissions From 4acb257214197d3d1b311a10c19b9de8d335f7fa Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:56:17 -0600 Subject: [PATCH 16/50] use LocScaleParam for non-centered parametrization (#24) * use LocScaleParam for non centered parametrization * Apply suggestions from code review Co-authored-by: Dylan H. 
Morris --------- Co-authored-by: Dylan H. Morris --- notebooks/site_level_ww_model_demo.qmd | 44 ++++++++-------- .../site_level_dynamics_model.py | 50 +++++++++---------- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index cfa85e24..d74a29f3 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -160,14 +160,6 @@ sigma_i_first_obs_rv = DistributionalVariable( ) ) -eta_i_first_obs_rv = DistributionalVariable( - "eta_i_first_obs", dist.Normal(0,1) -) - -eta_initial_exp_growth_rate_rv = DistributionalVariable( - "eta_initial_exp_growth_rate", dist.Normal(0,1) -) - sigma_initial_exp_growth_rate_prior_mode = stan_data["sigma_initial_exp_growth_rate_prior_mode"] sigma_initial_exp_growth_rate_prior_sd = stan_data["sigma_initial_exp_growth_rate_prior_sd"] sigma_initial_exp_growth_rate_rv = DistributionalVariable( @@ -292,9 +284,7 @@ my_model = ww_site_level_dynamics_model( sigma_rt_rv, i_first_obs_over_n_rv, sigma_i_first_obs_rv, - eta_i_first_obs_rv, sigma_initial_exp_growth_rate_rv, - eta_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, generation_interval_pmf_rv, infection_feedback_strength_rv, @@ -338,20 +328,8 @@ with numpyro.handlers.seed(rng_seed=242): ``` ```{python} -from numpyro.infer.initialization import init_to_sample -my_model.run( - num_warmup=100, - num_samples=100, - rng_key=jax.random.key(223), - data_observed_hospital_admissions=data_observed_hospital_admissions, - data_observed_log_conc=data_observed_log_conc, - mcmc_args=dict(num_chains=1), - nuts_args=dict(init_strategy=init_to_sample) -) -``` - -```{python} -with numpyro.handlers.trace() as tr: +try: + from numpyro.infer.initialization import init_to_sample my_model.run( num_warmup=100, num_samples=100, @@ -361,6 +339,24 @@ with numpyro.handlers.trace() as tr: mcmc_args=dict(num_chains=1), nuts_args=dict(init_strategy=init_to_sample) ) +except 
RuntimeError as e: + print(f"RuntimeError occurred: {e}") +``` + +```{python} +try: + with numpyro.handlers.trace() as tr: + my_model.run( + num_warmup=100, + num_samples=100, + rng_key=jax.random.key(223), + data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_log_conc=data_observed_log_conc, + mcmc_args=dict(num_chains=1), + nuts_args=dict(init_strategy=init_to_sample) + ) +except AssertionError as e: + print(f"AssertionError occurred: {e}") ``` ```{python} diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index c70bcd75..3b6c5bd2 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -4,6 +4,7 @@ import numpyro.distributions as dist import numpyro.distributions.transforms as transforms import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam from pyrenew.arrayutils import tile_until_n from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable @@ -16,7 +17,7 @@ from pyrenew.metaclass import Model from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess -from pyrenew.randomvariable import DistributionalVariable +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_covid_wastewater.utils import get_vl_trajectory @@ -41,9 +42,7 @@ def __init__( sigma_rt_rv, i_first_obs_over_n_rv, sigma_i_first_obs_rv, - eta_i_first_obs_rv, sigma_initial_exp_growth_rate_rv, - eta_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, generation_interval_pmf_rv, infection_feedback_strength_rv, @@ -87,11 +86,9 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.eta_i_first_obs_rv = eta_i_first_obs_rv 
self.sigma_initial_exp_growth_rate_rv = ( sigma_initial_exp_growth_rate_rv ) - self.eta_initial_exp_growth_rate_rv = eta_initial_exp_growth_rate_rv self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.p_hosp_mean_rv = p_hosp_mean_rv @@ -183,31 +180,34 @@ def sample( s = get_vl_trajectory(t_peak, viral_peak, dur_shed, self.gt_max) - with numpyro.plate("n_subpops", self.n_subpops): - eta_i_first_obs = self.eta_i_first_obs_rv() - eta_initial_exp_growth_rate = self.eta_initial_exp_growth_rate_rv() - - i_first_obs_over_n_site = jax.nn.sigmoid( - transforms.logit(self.i_first_obs_over_n_rv()) - + self.sigma_i_first_obs_rv() * eta_i_first_obs - ) # per capita infection incidence at the first observed time - numpyro.deterministic( - "i_first_obs_over_n_site", i_first_obs_over_n_site + mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() + sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() + initial_exp_growth_rate_site_rv = DistributionalVariable( + "initial_exp_growth_rate_site", + dist.Normal( + mean_initial_exp_growth_rate, sigma_initial_exp_growth_rate + ), + reparam=LocScaleReparam(0), ) - initial_exp_growth_rate_site = jnp.clip( - ( - self.mean_initial_exp_growth_rate_rv() - + self.sigma_initial_exp_growth_rate_rv() - * eta_initial_exp_growth_rate + i_first_obs_over_n = self.i_first_obs_over_n_rv() + sigma_i_first_obs = self.sigma_i_first_obs_rv() + i_first_obs_over_n_site_rv = TransformedVariable( + "i_first_obs_over_n_site", + DistributionalVariable( + "i_first_obs_over_n_site_raw", + dist.Normal( + transforms.logit(i_first_obs_over_n), sigma_i_first_obs + ), + reparam=LocScaleReparam(0), ), - a_min=-0.005, - a_max=0.005, - ) # site level unobserved period growth rate - numpyro.deterministic( - "initial_exp_growth_rate_site", initial_exp_growth_rate_site + transforms=transforms.SigmoidTransform(), ) + with numpyro.plate("n_subpops", 
self.n_subpops): + initial_exp_growth_rate_site = initial_exp_growth_rate_site_rv() + i_first_obs_over_n_site = i_first_obs_over_n_site_rv() + log_i0_site = ( jnp.log(i_first_obs_over_n_site) - self.unobs_time * initial_exp_growth_rate_site From af060dd8fc1887ffa5dbd9af25d9b6430e3eb5c8 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 2 Oct 2024 20:14:01 -0400 Subject: [PATCH 17/50] further use of LocScaleReparam --- notebooks/site_level_ww_model_demo.qmd | 10 ---- .../site_level_dynamics_model.py | 46 ++++++++++--------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index d74a29f3..7b7f276d 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -251,14 +251,6 @@ sd_log_sigma_ww_site_rv = DistributionalVariable( "sd_log_sigma_ww_site", dist.TruncatedNormal(sd_log_sigma_ww_site_prior_mode, sd_log_sigma_ww_site_prior_sd,low=0) ) -eta_log_sigma_ww_site_rv = DistributionalVariable( - "eta_log_sigma_ww_site", dist.Normal(0,1) -) - -ww_site_mod_raw_rv = DistributionalVariable( - "ww_site_mod_raw", dist.Normal(0,1) -) - ww_site_mod_sd_sd = stan_data["ww_site_mod_sd_sd"] ww_site_mod_sd_rv = DistributionalVariable( "ww_site_mod_sd", dist.TruncatedNormal(0,ww_site_mod_sd_sd,low=0) @@ -297,8 +289,6 @@ my_model = ww_site_level_dynamics_model( log10_g_rv, mode_sigma_ww_site_rv, sd_log_sigma_ww_site_rv, - eta_log_sigma_ww_site_rv, - ww_site_mod_raw_rv, ww_site_mod_sd_rv, phi_rv, ww_ml_produced_per_day, diff --git a/pyrenew_covid_wastewater/site_level_dynamics_model.py b/pyrenew_covid_wastewater/site_level_dynamics_model.py index 3b6c5bd2..7829b688 100644 --- a/pyrenew_covid_wastewater/site_level_dynamics_model.py +++ b/pyrenew_covid_wastewater/site_level_dynamics_model.py @@ -55,8 +55,6 @@ def __init__( log10_g_rv, mode_sigma_ww_site_rv, sd_log_sigma_ww_site_rv, - eta_log_sigma_ww_site_rv, - ww_site_mod_raw_rv, ww_site_mod_sd_rv, phi_rv, 
ww_ml_produced_per_day, @@ -99,8 +97,6 @@ def __init__( self.log10_g_rv = log10_g_rv self.mode_sigma_ww_site_rv = mode_sigma_ww_site_rv self.sd_log_sigma_ww_site_rv = sd_log_sigma_ww_site_rv - self.eta_log_sigma_ww_site_rv = eta_log_sigma_ww_site_rv - self.ww_site_mod_raw_rv = ww_site_mod_raw_rv self.ww_site_mod_sd_rv = ww_site_mod_sd_rv self.phi_rv = phi_rv self.ww_ml_produced_per_day = ww_ml_produced_per_day @@ -283,7 +279,7 @@ def sample( state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_site, axis=1) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) - # number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) + # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) def batch_colvolve_fn(m): return jnp.convolve(m, s, mode="valid") @@ -356,11 +352,27 @@ def batch_colvolve_fn(m): mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() + ww_site_mod_sd = self.ww_site_mod_sd_rv() + + ww_site_mod_rv = DistributionalVariable( + "ww_site_mod", + dist.Normal(0, ww_site_mod_sd), + reparam=LocScaleReparam(0), + ) # lab-site specific variation + + sigma_ww_site_rv = TransformedVariable( + "sigma_ww_site", + DistributionalVariable( + "log_sigma_ww_site", + dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site), + reparam=LocScaleReparam(0), + ), + transforms=transforms.ExpTransform(), + ) with numpyro.plate("n_ww_lab_sites", self.n_ww_lab_sites): - eta_log_sigma_ww_site = self.eta_log_sigma_ww_site_rv() - ww_site_mod_raw = self.ww_site_mod_raw_rv() - ww_site_mod_sd = self.ww_site_mod_sd_rv() + ww_site_mod = ww_site_mod_rv() + sigma_ww_site = sigma_ww_site_rv() # Observations at the site level (genomes/person/day) are: # get a vector of genomes/person/day on the days WW was measured @@ -370,24 +382,12 @@ def batch_colvolve_fn(m): self.ww_sampled_sites, self.ww_sampled_times ] - # modify by 
lab-site specific variation (multiplier!) - ww_site_mod = ww_site_mod_raw * ww_site_mod_sd - # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed # genomes by the site-specific multiplier at that sampling time exp_obs_log_v = ( exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] ) - sigma_ww_site = jnp.exp( - jnp.log(mode_sigma_ww_site) - + sd_log_sigma_ww_site * eta_log_sigma_ww_site - ) - - # g = jnp.power( - # log10_g, 10 - # ) # Estimated genomes shed per infected individual - hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv ) @@ -442,4 +442,8 @@ def batch_colvolve_fn(m): ) numpyro.deterministic("state_rt", state_rt) - return latent_hospital_admissions + return ( + latent_hospital_admissions, + observed_hospital_admissions, + ww_pred, + ) From b0d264b421634b12ce4b0a77ceefe0f9e7ceeba0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 3 Oct 2024 16:32:19 -0400 Subject: [PATCH 18/50] rename pyrenew-hew --- pyrenew_hew/site_level_dynamics_model.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 7829b688..333e861b 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -19,7 +19,7 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew_covid_wastewater.utils import get_vl_trajectory +from pyrenew_hew.utils import get_vl_trajectory class ww_site_level_dynamics_model(Model): # numpydoc ignore=GL08 @@ -178,12 +178,17 @@ def sample( mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() - initial_exp_growth_rate_site_rv = DistributionalVariable( - "initial_exp_growth_rate_site", - dist.Normal( - 
mean_initial_exp_growth_rate, sigma_initial_exp_growth_rate + initial_exp_growth_rate_site_rv = TransformedVariable( + "clipped_initial_exp_growth_rate_site", + DistributionalVariable( + "initial_exp_growth_rate_site_raw", + dist.Normal( + mean_initial_exp_growth_rate, + sigma_initial_exp_growth_rate, + ), + reparam=LocScaleReparam(0), ), - reparam=LocScaleReparam(0), + transforms=lambda x: jnp.clip(x, -0.01, 0.01), ) i_first_obs_over_n = self.i_first_obs_over_n_rv() @@ -204,6 +209,10 @@ def sample( initial_exp_growth_rate_site = initial_exp_growth_rate_site_rv() i_first_obs_over_n_site = i_first_obs_over_n_site_rv() + numpyro.deterministic( + "initial_exp_growth_rate_site", initial_exp_growth_rate_site + ) + log_i0_site = ( jnp.log(i_first_obs_over_n_site) - self.unobs_time * initial_exp_growth_rate_site From 66cb48c0d8ef7c97dd16b087d7311e195d25e864 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 3 Oct 2024 19:03:11 -0400 Subject: [PATCH 19/50] name changes --- notebooks/site_level_ww_model_demo.qmd | 2 +- pyrenew_hew/site_level_dynamics_model.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index 7b7f276d..87dbe6c7 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -22,7 +22,7 @@ with open("data/fit/stan_data.json","r") as file: stan_data = json.load(file) #helper function -from pyrenew_covid_wastewater.utils import * +from pyrenew_hew.utils import * ``` ```{python} diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 333e861b..e96080a3 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -269,12 +269,11 @@ def sample( i0 = infection_initialization_process() numpyro.deterministic("i0", i0) - with numpyro.plate("n_subpops", self.n_subpops): - inf_with_feedback_proc_sample = 
self.inf_with_feedback_proc.sample( - Rt=rtu_site, - I0=i0, - gen_int=generation_interval_pmf, - ) + inf_with_feedback_proc_sample = self.inf_with_feedback_proc.sample( + Rt=rtu_site, + I0=i0, + gen_int=generation_interval_pmf, + ) new_i_site = jnp.concat( [ From 2eb3c1ca78211df7d1e63c16683f22e6c46a9026 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 4 Oct 2024 15:04:33 -0400 Subject: [PATCH 20/50] use init_values initialization --- notebooks/data/fit/stan_data.json | 26 +++++----- notebooks/site_level_ww_model_demo.qmd | 16 +++--- pyrenew_hew/initialization.py | 72 ++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 19 deletions(-) create mode 100644 pyrenew_hew/initialization.py diff --git a/notebooks/data/fit/stan_data.json b/notebooks/data/fit/stan_data.json index 025816d8..6eb8e1e3 100644 --- a/notebooks/data/fit/stan_data.json +++ b/notebooks/data/fit/stan_data.json @@ -7,10 +7,10 @@ "n_subpops": 5, "n_ww_sites": 4, "n_ww_lab_sites": 5, - "owt": 88, + "owt": 98, "oht": 90, "n_censored": 0, - "n_uncensored": 88, + "n_uncensored": 98, "uot": 50, "ht": 35, "n_weeks": 18, @@ -322,17 +322,17 @@ "generation_interval": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], "ts": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], "state_pop": 3000000.0, - "subpop_size": [400000.0, 200000.0, 100000.0, 50000.0, 2250000.0], + "subpop_size": [50000.0, 100000.0, 200000.0, 400000.0, 2250000.0], "norm_pop": 3000000.0, - "ww_sampled_times": [2, 5, 6, 6, 8, 9, 11, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 18, 18, 18, 19, 20, 21, 22, 23, 23, 25, 26, 27, 29, 29, 29, 31, 32, 32, 33, 33, 34, 36, 36, 37, 37, 37, 39, 42, 42, 42, 43, 45, 45, 46, 47, 48, 51, 53, 58, 58, 59, 59, 63, 63, 64, 65, 65, 67, 70, 70, 73, 73, 74, 75, 
76, 76, 76, 78, 80, 81, 82, 83, 83, 84, 87, 89, 91, 92, 93, 93, 95], + "ww_sampled_times": [2, 5, 6, 6, 6, 8, 9, 11, 12, 12, 13, 13, 14, 14, 14, 15, 15, 18, 18, 18, 19, 20, 21, 21, 22, 23, 23, 25, 26, 27, 27, 29, 29, 29, 31, 31, 32, 32, 33, 33, 34, 36, 37, 37, 37, 39, 40, 42, 42, 42, 43, 45, 45, 46, 47, 47, 48, 50, 51, 53, 58, 58, 59, 59, 60, 62, 63, 64, 65, 65, 67, 69, 70, 70, 73, 73, 74, 75, 75, 76, 76, 76, 78, 79, 81, 82, 83, 83, 86, 87, 87, 88, 89, 89, 92, 92, 92, 92], "hosp_times": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], - "ww_sampled_lab_sites": [1, 5, 4, 5, 3, 2, 4, 1, 3, 2, 3, 5, 1, 3, 5, 3, 5, 2, 4, 5, 2, 3, 1, 2, 3, 5, 4, 3, 3, 2, 3, 4, 5, 1, 3, 3, 5, 5, 3, 4, 2, 3, 4, 2, 1, 3, 4, 4, 1, 5, 3, 2, 1, 4, 4, 2, 4, 1, 2, 1, 2, 4, 1, 3, 1, 1, 3, 4, 5, 1, 3, 2, 3, 4, 5, 4, 5, 5, 2, 4, 4, 4, 4, 3, 2, 3, 5, 1], - "ww_log_lod": [5.09434727489065, 4.9806950154474, 4.73771588167502, 4.9806950154474, 5.2940513994166, 4.76390186342314, 4.73771588167502, 5.09434727489065, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.9806950154474, 5.09434727489065, 5.2940513994166, 4.9806950154474, 5.2940513994166, 4.9806950154474, 4.76390186342314, 4.73771588167502, 4.9806950154474, 4.76390186342314, 5.2940513994166, 5.09434727489065, 4.76390186342314, 5.2940513994166, 4.9806950154474, 4.73771588167502, 5.2940513994166, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.9806950154474, 5.09434727489065, 5.2940513994166, 5.2940513994166, 4.9806950154474, 4.9806950154474, 5.2940513994166, 4.73771588167502, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.76390186342314, 5.09434727489065, 5.2940513994166, 4.73771588167502, 4.73771588167502, 
5.09434727489065, 4.9806950154474, 5.2940513994166, 4.76390186342314, 5.09434727489065, 4.73771588167502, 4.73771588167502, 4.76390186342314, 4.73771588167502, 5.09434727489065, 4.76390186342314, 5.09434727489065, 4.76390186342314, 4.73771588167502, 5.09434727489065, 5.2940513994166, 5.09434727489065, 5.09434727489065, 5.2940513994166, 4.73771588167502, 4.9806950154474, 5.09434727489065, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.9806950154474, 4.73771588167502, 4.9806950154474, 4.9806950154474, 4.76390186342314, 4.73771588167502, 4.73771588167502, 4.73771588167502, 4.73771588167502, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.9806950154474, 5.09434727489065], + "ww_sampled_lab_sites": [5, 1, 1, 2, 4, 2, 3, 2, 3, 5, 1, 3, 1, 3, 4, 1, 3, 1, 2, 5, 5, 2, 4, 5, 3, 1, 2, 2, 3, 2, 4, 2, 3, 5, 1, 4, 2, 5, 1, 2, 1, 2, 2, 3, 4, 5, 4, 2, 3, 5, 2, 1, 4, 3, 3, 4, 5, 4, 1, 2, 2, 5, 4, 5, 4, 4, 5, 2, 2, 5, 5, 4, 2, 4, 1, 2, 4, 3, 4, 1, 2, 5, 1, 4, 1, 4, 2, 3, 3, 2, 4, 4, 3, 5, 1, 2, 4, 5], + "ww_log_lod": [4.93514594555074, 5.10629923852651, 5.10629923852651, 4.88222110274807, 4.86366790424887, 4.88222110274807, 5.0120320880869, 4.88222110274807, 5.0120320880869, 4.93514594555074, 5.10629923852651, 5.0120320880869, 5.10629923852651, 5.0120320880869, 4.86366790424887, 5.10629923852651, 5.0120320880869, 5.10629923852651, 4.88222110274807, 4.93514594555074, 4.93514594555074, 4.88222110274807, 4.86366790424887, 4.93514594555074, 5.0120320880869, 5.10629923852651, 4.88222110274807, 4.88222110274807, 5.0120320880869, 4.88222110274807, 4.86366790424887, 4.88222110274807, 5.0120320880869, 4.93514594555074, 5.10629923852651, 4.86366790424887, 4.88222110274807, 4.93514594555074, 5.10629923852651, 4.88222110274807, 5.10629923852651, 4.88222110274807, 4.88222110274807, 5.0120320880869, 4.86366790424887, 4.93514594555074, 4.86366790424887, 4.88222110274807, 5.0120320880869, 4.93514594555074, 4.88222110274807, 5.10629923852651, 4.86366790424887, 
5.0120320880869, 5.0120320880869, 4.86366790424887, 4.93514594555074, 4.86366790424887, 5.10629923852651, 4.88222110274807, 4.88222110274807, 4.93514594555074, 4.86366790424887, 4.93514594555074, 4.86366790424887, 4.86366790424887, 4.93514594555074, 4.88222110274807, 4.88222110274807, 4.93514594555074, 4.93514594555074, 4.86366790424887, 4.88222110274807, 4.86366790424887, 5.10629923852651, 4.88222110274807, 4.86366790424887, 5.0120320880869, 4.86366790424887, 5.10629923852651, 4.88222110274807, 4.93514594555074, 5.10629923852651, 4.86366790424887, 5.10629923852651, 4.86366790424887, 4.88222110274807, 5.0120320880869, 5.0120320880869, 4.88222110274807, 4.86366790424887, 4.86366790424887, 5.0120320880869, 4.93514594555074, 5.10629923852651, 4.88222110274807, 4.86366790424887, 4.93514594555074], "ww_censored": [], - "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88], - "hosp": [6, 6, 7, 10, 10, 12, 12, 10, 8, 15, 8, 9, 9, 13, 17, 7, 12, 10, 13, 10, 12, 15, 19, 19, 22, 17, 19, 14, 17, 19, 18, 13, 24, 21, 35, 26, 25, 30, 26, 20, 29, 38, 35, 41, 30, 37, 35, 46, 38, 23, 38, 22, 28, 23, 31, 19, 23, 17, 23, 26, 17, 17, 12, 13, 9, 22, 12, 13, 17, 14, 12, 6, 10, 10, 4, 12, 9, 8, 9, 8, 6, 13, 7, 8, 13, 9, 9, 17, 7, 10], + "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98], + "hosp": [11, 13, 20, 5, 16, 10, 11, 19, 12, 16, 19, 
17, 19, 17, 17, 25, 27, 16, 24, 11, 32, 24, 20, 18, 25, 26, 29, 26, 26, 35, 39, 36, 29, 37, 32, 35, 41, 37, 35, 39, 43, 40, 56, 42, 39, 43, 39, 28, 35, 43, 38, 42, 29, 38, 33, 43, 34, 20, 33, 18, 26, 20, 22, 19, 19, 19, 18, 25, 17, 17, 17, 15, 11, 7, 20, 11, 11, 15, 12, 11, 5, 10, 11, 4, 12, 9, 9, 9, 8, 8], "day_of_week": [5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3], - "log_conc": [7.84933755868547, 8.75878317201939, 7.94731683160414, 7.84271554481825, 6.84314706058328, 8.49580667918879, 8.70688802330958, 8.3080964891881, 7.51476199898784, 8.73263730516131, 7.67705894528627, 9.16138870419256, 8.2291748710282, 7.57842684442178, 8.7658468204119, 7.59564541757592, 8.6485267724375, 8.73847850883661, 8.53496849868392, 7.81189950945825, 8.72816004875312, 7.77632957957527, 8.35751634986049, 8.61237112259723, 7.14724348782869, 8.98381918138397, 9.71929374032385, 7.73223113044046, 7.64136587504144, 8.93727178027946, 7.69319599090821, 9.3539099812084, 9.77803780431265, 8.83676433943342, 7.80095273126195, 7.88428742699397, 10.6861051151009, 10.7177204634667, 7.86033075211836, 9.44031060901259, 9.19820314756664, 8.10945148438765, 9.47368587127031, 9.30021922706583, 8.71406234479695, 7.82242543443078, 8.73443519589195, 9.246306907259, 8.61154444152258, 10.8332932698813, 7.39925321859574, 9.06066397101092, 8.61830748102478, 8.86917291106784, 9.02943162748827, 8.42736799456162, 8.11764762377314, 8.03266723037298, 8.41720674557318, 7.98228105459503, 8.42950370265189, 8.47947015286844, 7.81419539933152, 6.95672231924945, 7.74818768373832, 7.66575220799484, 6.86777849754671, 7.76297112083073, 6.95229085580734, 7.56030545976231, 6.64452132423733, 
7.65706249921245, 6.80982040079249, 8.28375427477922, 10.3920632090468, 8.17535318491482, 8.62625413957733, 7.21373104269779, 8.15256350454411, 8.38397288981646, 8.40418626057656, 8.23933140611506, 7.91321568315416, 6.89107975822161, 8.02468424809018, 6.87101220865236, 8.34645251673327, 7.62822976068377], + "log_conc": [6.3844908130805, 8.14451035738225, 7.99770763595962, 7.82998259689739, 6.4065742348394, 6.66439791225434, 8.7544308668928, 7.85805962848062, 8.83441333698154, 6.72297492958816, 9.03854995609113, 8.96151408570252, 8.36088219460206, 8.74052156354291, 6.72719658379387, 8.29289727032086, 8.91861790218074, 8.0702280675275, 7.4389112076145, 6.77983024511269, 6.82772608306932, 7.63018989860758, 6.91335626813194, 7.09903364058461, 9.38597761451942, 8.91079250143596, 7.68316477525713, 7.57159825566626, 8.91341077022231, 7.54119768468027, 7.21178099441415, 8.11151634353993, 9.65835698208185, 7.11541423675457, 8.72931323625662, 7.04281800242487, 7.93292336587937, 7.3541408399261, 7.77058309665516, 7.3238749648738, 9.1836263651675, 6.86834211205752, 7.29747920526473, 9.65397333190831, 7.31810228083491, 7.44987795716377, 7.25459751550647, 7.50930808782469, 9.71349028240529, 7.34513810510385, 7.13629588019881, 8.48713363300135, 7.12633309236834, 9.86632530967862, 9.54003707593803, 7.27115055045907, 7.5915296718635, 7.18740430517153, 8.27103835450306, 7.99170154819369, 6.92106789787305, 6.78053997955787, 6.54298893005428, 6.7809936962329, 6.59354464399851, 6.49279888377357, 6.71885536982598, 7.34430803998769, 6.73894402478992, 6.68879749345191, 6.77580972113133, 6.30907304588734, 6.53560785949508, 6.22760064816145, 6.56990430875933, 5.9789011134782, 6.34157202783156, 9.26255500277877, 6.1085637533118, 7.53719251660252, 6.44209701868898, 6.24226969865518, 7.33953416813322, 5.83502199331505, 8.14983630746274, 6.01446037613056, 6.01526356530096, 8.93301317095421, 9.48910229840805, 6.57203742858315, 6.16434064390406, 6.09536847734072, 9.34256082891041, 
6.21179364557822, 7.90828047588131, 6.18426243118251, 5.89866720955479, 6.16994023799605], "compute_likelihood": 1, "include_ww": 1, "include_hosp": 1, @@ -351,8 +351,8 @@ "r_prior_sd": 1, "log10_g_prior_mean": 12, "log10_g_prior_sd": 2, - "i_first_obs_over_n_prior_a": 1.0015, - "i_first_obs_over_n_prior_b": 5.9985, + "i_first_obs_over_n_prior_a": 1.00204761904762, + "i_first_obs_over_n_prior_b": 5.99795238095238, "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], "mean_initial_exp_growth_rate_prior_mean": 0, "mean_initial_exp_growth_rate_prior_sd": 0.01, @@ -374,6 +374,6 @@ "sigma_rt_prior": 0.1, "log_phi_g_prior_mean": -2.302585, "log_phi_g_prior_sd": 5, - "ww_sampled_sites": [1, 4, 3, 4, 2, 1, 3, 1, 2, 1, 2, 4, 1, 2, 4, 2, 4, 1, 3, 4, 1, 2, 1, 1, 2, 4, 3, 2, 2, 1, 2, 3, 4, 1, 2, 2, 4, 4, 2, 3, 1, 2, 3, 1, 1, 2, 3, 3, 1, 4, 2, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1, 3, 1, 2, 1, 1, 2, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, 4, 1, 3, 3, 3, 3, 2, 1, 2, 4, 1], - "lab_site_to_site_map": [1, 1, 2, 3, 4] + "ww_sampled_sites": [4, 1, 1, 2, 4, 2, 3, 2, 3, 4, 1, 3, 1, 3, 4, 1, 3, 1, 2, 4, 4, 2, 4, 4, 3, 1, 2, 2, 3, 2, 4, 2, 3, 4, 1, 4, 2, 4, 1, 2, 1, 2, 2, 3, 4, 4, 4, 2, 3, 4, 2, 1, 4, 3, 3, 4, 4, 4, 1, 2, 2, 4, 4, 4, 4, 4, 4, 2, 2, 4, 4, 4, 2, 4, 1, 2, 4, 3, 4, 1, 2, 4, 1, 4, 1, 4, 2, 3, 3, 2, 4, 4, 3, 4, 1, 2, 4, 4], + "lab_site_to_site_map": [1, 2, 3, 4, 4] } diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index 87dbe6c7..46bf1030 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -11,7 +11,8 @@ import numpyro.distributions.transforms as transforms from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew_covid_wastewater.site_level_dynamics_model import ww_site_level_dynamics_model +from pyrenew_hew.site_level_dynamics_model import ww_site_level_dynamics_model +from 
pyrenew_hew.initialization import get_initialization import jax numpyro.set_host_device_count(4) @@ -82,7 +83,7 @@ r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) # log of state level mean R(t) in weeks log_r_mu_intercept_rv = DistributionalVariable( - "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) + "log_r_mu_intercept", dist.Normal(r_logmean, r_logsd) ) ``` @@ -192,7 +193,7 @@ p_hosp_mean_rv = DistributionalVariable( p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"] p_hosp_w_sd_rv = DistributionalVariable( - "p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0) + "p_hosp_w_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0) ) autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"] @@ -316,10 +317,13 @@ with numpyro.handlers.seed(rng_seed=242): # numpyro_predictive_args={"num_samples": 200}, # ) ``` +```{python} +init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0)) +``` ```{python} +from numpyro.infer.initialization import init_to_sample, init_to_value try: - from numpyro.infer.initialization import init_to_sample my_model.run( num_warmup=100, num_samples=100, @@ -327,7 +331,7 @@ try: data_observed_hospital_admissions=data_observed_hospital_admissions, data_observed_log_conc=data_observed_log_conc, mcmc_args=dict(num_chains=1), - nuts_args=dict(init_strategy=init_to_sample) + nuts_args=dict(init_strategy=init_to_value(values=init_vals)) ) except RuntimeError as e: print(f"RuntimeError occurred: {e}") @@ -343,7 +347,7 @@ try: data_observed_hospital_admissions=data_observed_hospital_admissions, data_observed_log_conc=data_observed_log_conc, mcmc_args=dict(num_chains=1), - nuts_args=dict(init_strategy=init_to_sample) + nuts_args=dict(init_strategy=init_to_value(values=init_vals)) ) except AssertionError as e: print(f"AssertionError occurred: {e}") diff --git a/pyrenew_hew/initialization.py b/pyrenew_hew/initialization.py new file mode 100644 index 00000000..9a8aac04 --- /dev/null +++ 
b/pyrenew_hew/initialization.py @@ -0,0 +1,72 @@ +import jax +import jax.numpy as jnp +import numpyro.distributions as dist +import numpy as np +from pyrenew_hew.utils import convert_to_logmean_log_sd + + +def get_initialization(stan_data, stdev, rng_key): + i_first_obs_est = np.mean(stan_data["hosp"][:7]) / stan_data["p_hosp_prior_mean"] + logit_i_frac_est = jax.scipy.special.logit(i_first_obs_est / stan_data["state_pop"]) + + init_vals = { + "eta_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), + "autoreg_rt": jnp.abs( + dist.Normal( + stan_data["autoreg_rt_a"] + / (stan_data["autoreg_rt_a"] + stan_data["autoreg_rt_b"]), + 0.05, + ).sample(rng_key) + ), + "log_r_mu_intercept": dist.Normal( + convert_to_logmean_log_sd(1, stdev)[0], + convert_to_logmean_log_sd(1, stdev)[1], + ).sample(rng_key), + "sigma_rt": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), + "autoreg_rt_site": jnp.abs(dist.Normal(0.5, 0.05).sample(rng_key)), + "sigma_i_first_obs": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), + "sigma_initial_exp_growth_rate": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), + "i_first_obs_over_n": jax.nn.sigmoid( + dist.Normal(logit_i_frac_est, 0.05).sample(rng_key) + ), + "mean_initial_exp_growth_rate": dist.Normal(0, stdev).sample(rng_key), + "inv_sqrt_phi": 1 / jnp.sqrt(200) + + dist.Normal(1 / 10000, 1 / 10000).sample(rng_key), + "mode_sigma_ww_site": jnp.abs( + dist.Normal( + stan_data["mode_sigma_ww_site_prior_mode"], + stdev * stan_data["mode_sigma_ww_site_prior_sd"], + ).sample(rng_key) + ), + "sd_log_sigma_ww_site": jnp.abs( + dist.Normal( + stan_data["sd_log_sigma_ww_site_prior_mode"], + stdev * stan_data["sd_log_sigma_ww_site_prior_sd"], + ).sample(rng_key) + ), + "p_hosp_mean": dist.Normal( + jax.scipy.special.logit(stan_data["p_hosp_prior_mean"]), stdev + ).sample(rng_key), + "p_hosp_w_sd": jnp.abs(dist.Normal(0.01, 0.001).sample(rng_key)), + "autoreg_p_hosp": jnp.abs(dist.Normal(1 / 100, 0.001).sample(rng_key)), + "t_peak": dist.Normal( + 
stan_data["viral_shedding_pars"][0], + stdev * stan_data["viral_shedding_pars"][1], + ).sample(rng_key), + "viral_peak": dist.Normal( + stan_data["viral_shedding_pars"][2], + stdev * stan_data["viral_shedding_pars"][3], + ).sample(rng_key), + "dur_shed": dist.Normal( + stan_data["viral_shedding_pars"][4], + stdev * stan_data["viral_shedding_pars"][5], + ).sample(rng_key), + "log10_g": dist.Normal(stan_data["log10_g_prior_mean"], 0.5).sample(rng_key), + "ww_site_mod_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), + "hosp_wday_effect_raw": jax.nn.softmax( + jnp.abs(dist.Normal(1 / 7, stdev).expand([7]).sample(rng_key)) + ), + "inf_feedback_raw": jnp.abs(dist.Normal(500, 20).sample(rng_key)), + } + + return init_vals From 2ac9a84751cde3b0da887ae1567d5b7b9d41b23b Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 7 Oct 2024 15:55:30 -0400 Subject: [PATCH 21/50] change dur_shed --- notebooks/site_level_ww_model_demo.qmd | 9 +++++---- pyrenew_hew/initialization.py | 19 ++++++++++++++----- pyrenew_hew/site_level_dynamics_model.py | 14 ++++++++------ pyrenew_hew/utils.py | 17 +++++------------ 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index 46bf1030..2dcbc6b6 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -83,7 +83,7 @@ r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) # log of state level mean R(t) in weeks log_r_mu_intercept_rv = DistributionalVariable( - "log_r_mu_intercept", dist.Normal(r_logmean, r_logsd) + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) ) ``` @@ -107,9 +107,10 @@ viral_peak_rv = DistributionalVariable( "viral_peak", dist.Normal(viral_peak_mean, viral_peak_sd) ) -dur_shed_rv = DistributionalVariable( - "dur_shed", dist.TruncatedNormal(dur_shed_mean, dur_shed_sd, low=0) +dur_shed_after_peak_rv = DistributionalVariable( + "dur_shed_after_peak", 
dist.TruncatedNormal(dur_shed_mean, dur_shed_sd, low=0) ) +max_shed_interval = dur_shed_mean + 3*dur_shed_sd ``` ```{python} @@ -265,7 +266,7 @@ my_model = ww_site_level_dynamics_model( n_ww_lab_sites, unobs_time, n_initialization_points, - gt_max, + max_shed_interval, i0_t_offset, log_r_mu_intercept_rv, autoreg_rt_rv, diff --git a/pyrenew_hew/initialization.py b/pyrenew_hew/initialization.py index 9a8aac04..640cd487 100644 --- a/pyrenew_hew/initialization.py +++ b/pyrenew_hew/initialization.py @@ -1,13 +1,18 @@ import jax import jax.numpy as jnp -import numpyro.distributions as dist import numpy as np +import numpyro.distributions as dist + from pyrenew_hew.utils import convert_to_logmean_log_sd def get_initialization(stan_data, stdev, rng_key): - i_first_obs_est = np.mean(stan_data["hosp"][:7]) / stan_data["p_hosp_prior_mean"] - logit_i_frac_est = jax.scipy.special.logit(i_first_obs_est / stan_data["state_pop"]) + i_first_obs_est = ( + np.mean(stan_data["hosp"][:7]) / stan_data["p_hosp_prior_mean"] + ) + logit_i_frac_est = jax.scipy.special.logit( + i_first_obs_est / stan_data["state_pop"] + ) init_vals = { "eta_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), @@ -25,7 +30,9 @@ def get_initialization(stan_data, stdev, rng_key): "sigma_rt": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), "autoreg_rt_site": jnp.abs(dist.Normal(0.5, 0.05).sample(rng_key)), "sigma_i_first_obs": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), - "sigma_initial_exp_growth_rate": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), + "sigma_initial_exp_growth_rate": jnp.abs( + dist.Normal(0, stdev).sample(rng_key) + ), "i_first_obs_over_n": jax.nn.sigmoid( dist.Normal(logit_i_frac_est, 0.05).sample(rng_key) ), @@ -61,7 +68,9 @@ def get_initialization(stan_data, stdev, rng_key): stan_data["viral_shedding_pars"][4], stdev * stan_data["viral_shedding_pars"][5], ).sample(rng_key), - "log10_g": dist.Normal(stan_data["log10_g_prior_mean"], 0.5).sample(rng_key), + "log10_g": 
dist.Normal(stan_data["log10_g_prior_mean"], 0.5).sample( + rng_key + ), "ww_site_mod_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), "hosp_wday_effect_raw": jax.nn.softmax( jnp.abs(dist.Normal(1 / 7, stdev).expand([7]).sample(rng_key)) diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index e96080a3..0bc9ab37 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -30,14 +30,14 @@ def __init__( n_ww_lab_sites, unobs_time, n_initialization_points, - gt_max, + max_shed_interval, i0_t_offset, log_r_mu_intercept_rv, autoreg_rt_rv, eta_sd_rv, t_peak_rv, viral_peak_rv, - dur_shed_rv, + dur_shed_after_peak_rv, autoreg_rt_site_rv, sigma_rt_rv, i_first_obs_over_n_rv, @@ -72,14 +72,14 @@ def __init__( self.n_ww_lab_sites = n_ww_lab_sites self.unobs_time = unobs_time self.n_initialization_points = n_initialization_points - self.gt_max = gt_max + self.max_shed_interval = max_shed_interval self.i0_t_offset = i0_t_offset self.log_r_mu_intercept_rv = log_r_mu_intercept_rv self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.t_peak_rv = t_peak_rv self.viral_peak_rv = viral_peak_rv - self.dur_shed_rv = dur_shed_rv + self.dur_shed_after_peak_rv = dur_shed_after_peak_rv self.autoreg_rt_site_rv = autoreg_rt_site_rv self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv @@ -172,9 +172,11 @@ def sample( t_peak = self.t_peak_rv() viral_peak = self.viral_peak_rv() - dur_shed = self.dur_shed_rv() + dur_shed = self.dur_shed_after_peak_rv() - s = get_vl_trajectory(t_peak, viral_peak, dur_shed, self.gt_max) + s = get_vl_trajectory( + t_peak, viral_peak, dur_shed, self.max_shed_interval + ) mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() diff --git a/pyrenew_hew/utils.py b/pyrenew_hew/utils.py index 545239c4..60a0a5ff 100644 --- a/pyrenew_hew/utils.py +++ 
b/pyrenew_hew/utils.py @@ -9,18 +9,11 @@ def convert_to_logmean_log_sd(mean, sd): return logmean, logsd -def get_vl_trajectory(tpeak, viral_peak, duration_shedding, n): - s = jnp.zeros(n) +def get_vl_trajectory(tpeak, viral_peak, duration_shedding_after_peak, n): growth = viral_peak / tpeak - wane = viral_peak / (duration_shedding - tpeak) - + wane = viral_peak / duration_shedding_after_peak t = jnp.arange(n) - s = jnp.where(t <= tpeak, jnp.power(10, growth * t), s) - - s = jnp.where( - t > tpeak, jnp.maximum(0, viral_peak + wane * tpeak - wane * t), s + s = 10 ** jnp.where( + t <= tpeak, growth * t, viral_peak + wane * (tpeak - t) ) - s = jnp.where(t > tpeak, jnp.power(10, s), s) - - s = s / jnp.sum(s) - return s + return s / jnp.sum(s) From b4100fae4784724ab0eca20200b3f5bd71b92620 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 8 Oct 2024 12:00:36 -0400 Subject: [PATCH 22/50] changing shedding kinetics function --- pyrenew_hew/site_level_dynamics_model.py | 8 +--- pyrenew_hew/utils.py | 47 ++++++++++++++++++++---- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 0bc9ab37..57722538 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -36,7 +36,6 @@ def __init__( autoreg_rt_rv, eta_sd_rv, t_peak_rv, - viral_peak_rv, dur_shed_after_peak_rv, autoreg_rt_site_rv, sigma_rt_rv, @@ -78,7 +77,6 @@ def __init__( self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.t_peak_rv = t_peak_rv - self.viral_peak_rv = viral_peak_rv self.dur_shed_after_peak_rv = dur_shed_after_peak_rv self.autoreg_rt_site_rv = autoreg_rt_site_rv self.sigma_rt_rv = sigma_rt_rv @@ -171,12 +169,10 @@ def sample( numpyro.deterministic("log_rtu_weekly", log_rtu_weekly) t_peak = self.t_peak_rv() - viral_peak = self.viral_peak_rv() + # viral_peak = self.viral_peak_rv() dur_shed = self.dur_shed_after_peak_rv() - s = get_vl_trajectory( - 
t_peak, viral_peak, dur_shed, self.max_shed_interval - ) + s = get_vl_trajectory(t_peak, dur_shed, self.max_shed_interval) mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() diff --git a/pyrenew_hew/utils.py b/pyrenew_hew/utils.py index 60a0a5ff..10305930 100644 --- a/pyrenew_hew/utils.py +++ b/pyrenew_hew/utils.py @@ -9,11 +9,44 @@ def convert_to_logmean_log_sd(mean, sd): return logmean, logsd -def get_vl_trajectory(tpeak, viral_peak, duration_shedding_after_peak, n): - growth = viral_peak / tpeak - wane = viral_peak / duration_shedding_after_peak - t = jnp.arange(n) - s = 10 ** jnp.where( - t <= tpeak, growth * t, viral_peak + wane * (tpeak - t) +def normed_shedding_cdf( + time: float, t_p: float, t_d: float, log_base: float +) -> float: + """ + fraction of total fecal RNA shedding that has occurred + by a given time post infection. + """ + norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1) + ad_pre = ( + lambda x: t_p + / jnp.log(log_base) + * jnp.exp(jnp.log(log_base) * x / t_p) + - x ) - return s / jnp.sum(s) + ad_post = ( + lambda x: -t_d + / jnp.log(log_base) + * jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d))) + - x + ) + return ( + jnp.where( + time < t_p + t_d, + jnp.where( + time < t_p, + ad_pre(time) - ad_pre(0), + ad_pre(t_p) - ad_pre(0) + ad_post(time) - ad_post(t_p), + ), + norm_const, + ) + / norm_const + ) + + +def get_vl_trajectory(tpeak, duration_shedding_after_peak, max_days): + daily_shedding_pmf = normed_shedding_cdf( + jnp.arange(1, max_days), tpeak, duration_shedding_after_peak, 10 + ) - normed_shedding_cdf( + jnp.arange(0, max_days - 1), tpeak, duration_shedding_after_peak, 10 + ) + return daily_shedding_pmf From 4a35e82332cf16d61be27d445bda9d2078fe0786 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 8 Oct 2024 14:58:26 -0400 Subject: [PATCH 23/50] update stan_data --- notebooks/data/fit/stan_data.json | 40 +++++++++++-------- 
notebooks/data/fit_hosp_only/stan_data.json | 44 ++++++++++++--------- notebooks/wwinference.Rmd | 23 +++++------ 3 files changed, 60 insertions(+), 47 deletions(-) diff --git a/notebooks/data/fit/stan_data.json b/notebooks/data/fit/stan_data.json index 6eb8e1e3..8e728800 100644 --- a/notebooks/data/fit/stan_data.json +++ b/notebooks/data/fit/stan_data.json @@ -1,7 +1,7 @@ { "gt_max": 15, "hosp_delay_max": 55, - "inf_to_hosp": [0.0, 0.00469384736487552, 0.0145200073436112, 0.0278627741704387, 0.0423656492135518, 0.0558071445014868, 0.0665713169684116, 0.0737925805176124, 0.0772854627892072, 0.0773666390616176, 0.0746515449009948, 0.0698761436052596, 0.0637663813017696, 0.0569581929821651, 0.0499600186601535, 0.0431457477049282, 0.0367662806214046, 0.0309702535668237, 0.0258273785539499, 0.0213504646948306, 0.0175141661880584, 0.0142698211023571, 0.0115565159519833, 0.00930888979824423, 0.00746229206759214, 0.00595605679409682, 0.00473519993107751, 0.00375117728281842, 0.00296198928038098, 0.00233187862772459, 0.00183079868293457, 0.00143377454057296, 0.00107076258525208, 0.000773006742366448, 0.000539573690886396, 0.000364177599116743, 0.000237727628685579, 0.000150157714457011, 9.18283319498657e-05, 5.44079947589853e-05, 3.12548818921464e-05, 1.74202619730274e-05, 9.42698047424712e-06, 4.95614149002087e-06, 2.53275674485913e-06, 1.25854819834554e-06, 6.08116579596933e-07, 2.85572858589747e-07, 1.30129404249734e-07, 5.73280599448306e-08, 2.4219376577964e-08, 9.6316861194457e-09, 3.43804936850951e-09, 9.34806280366887e-10, 0.0], + "inf_to_hosp": [0.0, 0.00469384736487552, 0.0145200073436112, 0.0278627741704387, 0.0423656492135518, 0.0558071445014868, 0.0665713169684116, 0.0737925805176124, 0.0772854627892072, 0.0773666390616176, 0.0746515449009948, 0.0698761436052596, 0.0637663813017696, 0.0569581929821651, 0.0499600186601535, 0.0431457477049282, 0.0367662806214046, 0.0309702535668238, 0.0258273785539499, 0.0213504646948306, 0.0175141661880584, 0.0142698211023571, 
0.0115565159519833, 0.00930888979824423, 0.00746229206759215, 0.00595605679409682, 0.00473519993107751, 0.00375117728281842, 0.00296198928038098, 0.00233187862772459, 0.00183079868293457, 0.00143377454057296, 0.00107076258525208, 0.000773006742366448, 0.000539573690886396, 0.000364177599116743, 0.000237727628685578, 0.000150157714457011, 9.18283319498657e-05, 5.44079947589853e-05, 3.12548818921464e-05, 1.74202619730274e-05, 9.42698047424712e-06, 4.95614149002087e-06, 2.53275674485913e-06, 1.25854819834554e-06, 6.08116579596933e-07, 2.85572858589747e-07, 1.30129404249734e-07, 5.73280599448305e-08, 2.4219376577964e-08, 9.6316861194457e-09, 3.43804936850951e-09, 9.34806280366887e-10, 0.0], "mwpd": 227000.0, "ot": 90, "n_subpops": 5, @@ -9,8 +9,8 @@ "n_ww_lab_sites": 5, "owt": 98, "oht": 90, - "n_censored": 0, - "n_uncensored": 98, + "n_censored": 1, + "n_uncensored": 97, "uot": 50, "ht": 35, "n_weeks": 18, @@ -322,17 +322,19 @@ "generation_interval": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], "ts": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], "state_pop": 3000000.0, - "subpop_size": [50000.0, 100000.0, 200000.0, 400000.0, 2250000.0], + "subpop_size": [2250000.0, 400000.0, 200000.0, 100000.0, 50000.0], "norm_pop": 3000000.0, - "ww_sampled_times": [2, 5, 6, 6, 6, 8, 9, 11, 12, 12, 13, 13, 14, 14, 14, 15, 15, 18, 18, 18, 19, 20, 21, 21, 22, 23, 23, 25, 26, 27, 27, 29, 29, 29, 31, 31, 32, 32, 33, 33, 34, 36, 37, 37, 37, 39, 40, 42, 42, 42, 43, 45, 45, 46, 47, 47, 48, 50, 51, 53, 58, 58, 59, 59, 60, 62, 63, 64, 65, 65, 67, 69, 70, 70, 73, 73, 74, 75, 75, 76, 76, 76, 78, 79, 81, 82, 83, 83, 86, 87, 87, 88, 89, 89, 92, 92, 92, 92], + "ww_sampled_times": [2, 2, 2, 5, 6, 6, 8, 9, 11, 12, 12, 12, 13, 14, 14, 14, 
15, 15, 16, 18, 18, 18, 19, 20, 21, 22, 22, 23, 23, 25, 26, 26, 27, 29, 29, 29, 31, 32, 32, 33, 33, 34, 36, 36, 37, 37, 39, 40, 42, 42, 42, 43, 45, 46, 46, 47, 48, 51, 52, 53, 56, 57, 58, 58, 59, 59, 63, 63, 64, 65, 65, 67, 70, 70, 73, 74, 75, 76, 76, 77, 78, 80, 81, 83, 83, 84, 86, 87, 89, 89, 90, 90, 92, 92, 92, 94, 94, 94], "hosp_times": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], - "ww_sampled_lab_sites": [5, 1, 1, 2, 4, 2, 3, 2, 3, 5, 1, 3, 1, 3, 4, 1, 3, 1, 2, 5, 5, 2, 4, 5, 3, 1, 2, 2, 3, 2, 4, 2, 3, 5, 1, 4, 2, 5, 1, 2, 1, 2, 2, 3, 4, 5, 4, 2, 3, 5, 2, 1, 4, 3, 3, 4, 5, 4, 1, 2, 2, 5, 4, 5, 4, 4, 5, 2, 2, 5, 5, 4, 2, 4, 1, 2, 4, 3, 4, 1, 2, 5, 1, 4, 1, 4, 2, 3, 3, 2, 4, 4, 3, 5, 1, 2, 4, 5], - "ww_log_lod": [4.93514594555074, 5.10629923852651, 5.10629923852651, 4.88222110274807, 4.86366790424887, 4.88222110274807, 5.0120320880869, 4.88222110274807, 5.0120320880869, 4.93514594555074, 5.10629923852651, 5.0120320880869, 5.10629923852651, 5.0120320880869, 4.86366790424887, 5.10629923852651, 5.0120320880869, 5.10629923852651, 4.88222110274807, 4.93514594555074, 4.93514594555074, 4.88222110274807, 4.86366790424887, 4.93514594555074, 5.0120320880869, 5.10629923852651, 4.88222110274807, 4.88222110274807, 5.0120320880869, 4.88222110274807, 4.86366790424887, 4.88222110274807, 5.0120320880869, 4.93514594555074, 5.10629923852651, 4.86366790424887, 4.88222110274807, 4.93514594555074, 5.10629923852651, 4.88222110274807, 5.10629923852651, 4.88222110274807, 4.88222110274807, 5.0120320880869, 4.86366790424887, 4.93514594555074, 4.86366790424887, 4.88222110274807, 5.0120320880869, 4.93514594555074, 4.88222110274807, 5.10629923852651, 4.86366790424887, 5.0120320880869, 
5.0120320880869, 4.86366790424887, 4.93514594555074, 4.86366790424887, 5.10629923852651, 4.88222110274807, 4.88222110274807, 4.93514594555074, 4.86366790424887, 4.93514594555074, 4.86366790424887, 4.86366790424887, 4.93514594555074, 4.88222110274807, 4.88222110274807, 4.93514594555074, 4.93514594555074, 4.86366790424887, 4.88222110274807, 4.86366790424887, 5.10629923852651, 4.88222110274807, 4.86366790424887, 5.0120320880869, 4.86366790424887, 5.10629923852651, 4.88222110274807, 4.93514594555074, 5.10629923852651, 4.86366790424887, 5.10629923852651, 4.86366790424887, 4.88222110274807, 5.0120320880869, 5.0120320880869, 4.88222110274807, 4.86366790424887, 4.86366790424887, 5.0120320880869, 4.93514594555074, 5.10629923852651, 4.88222110274807, 4.86366790424887, 4.93514594555074], - "ww_censored": [], - "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98], - "hosp": [11, 13, 20, 5, 16, 10, 11, 19, 12, 16, 19, 17, 19, 17, 17, 25, 27, 16, 24, 11, 32, 24, 20, 18, 25, 26, 29, 26, 26, 35, 39, 36, 29, 37, 32, 35, 41, 37, 35, 39, 43, 40, 56, 42, 39, 43, 39, 28, 35, 43, 38, 42, 29, 38, 33, 43, 34, 20, 33, 18, 26, 20, 22, 19, 19, 19, 18, 25, 17, 17, 17, 15, 11, 7, 20, 11, 11, 15, 12, 11, 5, 10, 11, 4, 12, 9, 9, 9, 8, 8], + "ww_sampled_subpops": [2, 4, 5, 4, 3, 4, 2, 2, 3, 2, 2, 5, 4, 2, 2, 4, 2, 4, 5, 2, 3, 4, 2, 2, 2, 2, 5, 2, 4, 3, 2, 5, 3, 2, 2, 3, 4, 2, 3, 3, 4, 4, 2, 3, 2, 3, 2, 5, 2, 2, 3, 3, 4, 2, 5, 2, 2, 4, 5, 4, 5, 5, 2, 4, 2, 5, 2, 5, 3, 2, 2, 2, 2, 2, 4, 2, 2, 2, 4, 4, 4, 3, 4, 2, 3, 3, 2, 3, 2, 4, 4, 5, 2, 4, 5, 2, 4, 5], + "lab_site_to_subpop_map": [2, 2, 3, 4, 5], + "ww_sampled_lab_sites": [1, 4, 5, 4, 3, 4, 2, 2, 3, 1, 
2, 5, 4, 1, 2, 4, 2, 4, 5, 2, 3, 4, 1, 2, 1, 2, 5, 2, 4, 3, 2, 5, 3, 1, 2, 3, 4, 1, 3, 3, 4, 4, 2, 3, 2, 3, 1, 5, 1, 2, 3, 3, 4, 2, 5, 2, 1, 4, 5, 4, 5, 5, 1, 4, 1, 5, 1, 5, 3, 1, 2, 1, 1, 2, 4, 1, 2, 2, 4, 4, 4, 3, 4, 2, 3, 3, 1, 3, 1, 4, 4, 5, 2, 4, 5, 2, 4, 5], + "ww_log_lod": [4.81786386350134, 4.93524984908241, 4.78269939060126, 4.93524984908241, 5.01370230665543, 4.93524984908241, 5.14825526105204, 5.14825526105204, 5.01370230665543, 4.81786386350134, 5.14825526105204, 4.78269939060126, 4.93524984908241, 4.81786386350134, 5.14825526105204, 4.93524984908241, 5.14825526105204, 4.93524984908241, 4.78269939060126, 5.14825526105204, 5.01370230665543, 4.93524984908241, 4.81786386350134, 5.14825526105204, 4.81786386350134, 5.14825526105204, 4.78269939060126, 5.14825526105204, 4.93524984908241, 5.01370230665543, 5.14825526105204, 4.78269939060126, 5.01370230665543, 4.81786386350134, 5.14825526105204, 5.01370230665543, 4.93524984908241, 4.81786386350134, 5.01370230665543, 5.01370230665543, 4.93524984908241, 4.93524984908241, 5.14825526105204, 5.01370230665543, 5.14825526105204, 5.01370230665543, 4.81786386350134, 4.78269939060126, 4.81786386350134, 5.14825526105204, 5.01370230665543, 5.01370230665543, 4.93524984908241, 5.14825526105204, 4.78269939060126, 5.14825526105204, 4.81786386350134, 4.93524984908241, 4.78269939060126, 4.93524984908241, 4.78269939060126, 4.78269939060126, 4.81786386350134, 4.93524984908241, 4.81786386350134, 4.78269939060126, 4.81786386350134, 4.78269939060126, 5.01370230665543, 4.81786386350134, 5.14825526105204, 4.81786386350134, 4.81786386350134, 5.14825526105204, 4.93524984908241, 4.81786386350134, 5.14825526105204, 5.14825526105204, 4.93524984908241, 4.93524984908241, 4.93524984908241, 5.01370230665543, 4.93524984908241, 5.14825526105204, 5.01370230665543, 5.01370230665543, 4.81786386350134, 5.01370230665543, 4.81786386350134, 4.93524984908241, 4.93524984908241, 4.78269939060126, 5.14825526105204, 4.93524984908241, 4.78269939060126, 
5.14825526105204, 4.93524984908241, 4.78269939060126], + "ww_censored": [66], + "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98], + "hosp": [14, 15, 11, 22, 13, 17, 26, 19, 21, 21, 18, 32, 22, 29, 15, 38, 30, 26, 26, 36, 37, 41, 38, 37, 50, 57, 56, 46, 55, 49, 53, 58, 58, 59, 63, 68, 66, 90, 71, 74, 90, 79, 66, 76, 91, 79, 94, 78, 93, 82, 100, 92, 61, 93, 60, 71, 62, 72, 53, 55, 49, 59, 68, 46, 50, 45, 36, 32, 56, 37, 41, 41, 35, 45, 37, 33, 27, 33, 34, 34, 38, 27, 43, 30, 33, 34, 35, 33, 35, 40], "day_of_week": [5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3], - "log_conc": [6.3844908130805, 8.14451035738225, 7.99770763595962, 7.82998259689739, 6.4065742348394, 6.66439791225434, 8.7544308668928, 7.85805962848062, 8.83441333698154, 6.72297492958816, 9.03854995609113, 8.96151408570252, 8.36088219460206, 8.74052156354291, 6.72719658379387, 8.29289727032086, 8.91861790218074, 8.0702280675275, 7.4389112076145, 6.77983024511269, 6.82772608306932, 7.63018989860758, 6.91335626813194, 7.09903364058461, 9.38597761451942, 8.91079250143596, 7.68316477525713, 7.57159825566626, 8.91341077022231, 7.54119768468027, 7.21178099441415, 8.11151634353993, 9.65835698208185, 7.11541423675457, 8.72931323625662, 7.04281800242487, 7.93292336587937, 7.3541408399261, 7.77058309665516, 7.3238749648738, 9.1836263651675, 
6.86834211205752, 7.29747920526473, 9.65397333190831, 7.31810228083491, 7.44987795716377, 7.25459751550647, 7.50930808782469, 9.71349028240529, 7.34513810510385, 7.13629588019881, 8.48713363300135, 7.12633309236834, 9.86632530967862, 9.54003707593803, 7.27115055045907, 7.5915296718635, 7.18740430517153, 8.27103835450306, 7.99170154819369, 6.92106789787305, 6.78053997955787, 6.54298893005428, 6.7809936962329, 6.59354464399851, 6.49279888377357, 6.71885536982598, 7.34430803998769, 6.73894402478992, 6.68879749345191, 6.77580972113133, 6.30907304588734, 6.53560785949508, 6.22760064816145, 6.56990430875933, 5.9789011134782, 6.34157202783156, 9.26255500277877, 6.1085637533118, 7.53719251660252, 6.44209701868898, 6.24226969865518, 7.33953416813322, 5.83502199331505, 8.14983630746274, 6.01446037613056, 6.01526356530096, 8.93301317095421, 9.48910229840805, 6.57203742858315, 6.16434064390406, 6.09536847734072, 9.34256082891041, 6.21179364557822, 7.90828047588131, 6.18426243118251, 5.89866720955479, 6.16994023799605], + "log_conc": [7.66802724479397, 8.21650318956816, 6.4512477164068, 7.75785732441969, 9.04472643003972, 7.8876093738573, 8.85695271080419, 8.77448654283081, 9.25474422044025, 8.48197724221017, 9.18487615016851, 7.73586744786872, 9.15190271402333, 8.57063681753772, 9.22494958414535, 8.8811814139552, 8.95303168179479, 8.87479731488632, 7.04631257099227, 9.17451866263841, 9.81185744840712, 8.34692842110394, 8.56700168572782, 9.35764330055156, 8.64805057574179, 9.45079742597308, 7.30867303574344, 9.41472892758601, 9.12723564522376, 9.63427957851573, 9.4877830585957, 9.02899614473423, 9.86293199017148, 8.81668715131234, 9.35443062233028, 9.86957090190018, 8.95533551342666, 9.10840763554723, 9.89336595991182, 10.2035349739338, 8.51721998551749, 8.50763772111075, 9.68146528671227, 10.1267878546749, 9.82586519456221, 10.4562946689248, 9.39359538798059, 9.06912617169831, 9.13017806227903, 9.87195987699652, 10.1617885270518, 9.92070098259861, 9.04498782443653, 
9.99736049469332, 7.8122356539209, 9.90761719216578, 9.19954542861679, 9.20908242423643, 7.53800674845928, 9.23810876627582, 8.74766338564374, 5.68405424571368, 8.94327722754565, 8.13815511398509, 8.72901851354369, 2.39134969530063, 8.55874357420902, 7.99413714656807, 8.70447685985134, 8.7648016157833, 9.28255076496634, 8.71171479812706, 8.80114398496407, 9.28818292020013, 7.6967588326828, 8.58440395540122, 8.89885558938604, 9.12890628468619, 7.639038061726, 7.70375332709297, 8.03272855574736, 8.59715668661873, 7.85031287265023, 8.95008137695999, 9.11122924030913, 9.12994052023167, 8.37562782560992, 8.46419130893256, 8.46685643123332, 8.15758459191837, 7.81253430699302, 6.6648633811922, 9.23344255196987, 8.06975580003644, 6.09213112399348, 8.97077401961493, 7.26054982586712, 8.02898834978338], "compute_likelihood": 1, "include_ww": 1, "include_hosp": 1, @@ -341,8 +343,8 @@ "viral_shedding_pars": [5.0, 1.0, 5.1, 0.5, 17.0, 3.0], "autoreg_rt_a": 2, "autoreg_rt_b": 40, - "autoreg_rt_site_a": 1, - "autoreg_rt_site_b": 4, + "autoreg_rt_subpop_a": 1, + "autoreg_rt_subpop_b": 4, "autoreg_p_hosp_a": 1, "autoreg_p_hosp_b": 100, "inv_sqrt_phi_prior_mean": 0.1, @@ -351,8 +353,8 @@ "r_prior_sd": 1, "log10_g_prior_mean": 12, "log10_g_prior_sd": 2, - "i_first_obs_over_n_prior_a": 1.00204761904762, - "i_first_obs_over_n_prior_b": 5.99795238095238, + "i_first_obs_over_n_prior_a": 1.00280952380952, + "i_first_obs_over_n_prior_b": 5.99719047619048, "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], "mean_initial_exp_growth_rate_prior_mean": 0, "mean_initial_exp_growth_rate_prior_sd": 0.01, @@ -374,6 +376,10 @@ "sigma_rt_prior": 0.1, "log_phi_g_prior_mean": -2.302585, "log_phi_g_prior_sd": 5, - "ww_sampled_sites": [4, 1, 1, 2, 4, 2, 3, 2, 3, 4, 1, 3, 1, 3, 4, 1, 3, 1, 2, 4, 4, 2, 4, 4, 3, 1, 2, 2, 3, 2, 4, 2, 3, 4, 1, 4, 2, 4, 1, 2, 1, 2, 2, 3, 4, 4, 4, 2, 3, 4, 2, 1, 4, 3, 3, 4, 4, 4, 1, 2, 2, 4, 4, 4, 4, 4, 4, 2, 2, 4, 4, 4, 2, 4, 1, 2, 4, 3, 4, 1, 2, 4, 1, 4, 1, 4, 2, 3, 3, 2, 
4, 4, 3, 4, 1, 2, 4, 4], - "lab_site_to_site_map": [1, 2, 3, 4, 4] + "offset_ref_log_r_t_prior_mean": 0, + "offset_ref_log_r_t_prior_sd": 0.2, + "offset_ref_logit_i_first_obs_prior_mean": 0, + "offset_ref_logit_i_first_obs_prior_sd": 0.25, + "offset_ref_initial_exp_growth_rate_prior_mean": 0, + "offset_ref_initial_exp_growth_rate_prior_sd": 0.025 } diff --git a/notebooks/data/fit_hosp_only/stan_data.json b/notebooks/data/fit_hosp_only/stan_data.json index 35902d09..e320b347 100644 --- a/notebooks/data/fit_hosp_only/stan_data.json +++ b/notebooks/data/fit_hosp_only/stan_data.json @@ -1,16 +1,16 @@ { "gt_max": 15, "hosp_delay_max": 55, - "inf_to_hosp": [0.0, 0.00469384736487552, 0.0145200073436112, 0.0278627741704387, 0.0423656492135518, 0.0558071445014868, 0.0665713169684116, 0.0737925805176124, 0.0772854627892072, 0.0773666390616176, 0.0746515449009948, 0.0698761436052596, 0.0637663813017696, 0.0569581929821651, 0.0499600186601535, 0.0431457477049282, 0.0367662806214046, 0.0309702535668237, 0.0258273785539499, 0.0213504646948306, 0.0175141661880584, 0.0142698211023571, 0.0115565159519833, 0.00930888979824423, 0.00746229206759214, 0.00595605679409682, 0.00473519993107751, 0.00375117728281842, 0.00296198928038098, 0.00233187862772459, 0.00183079868293457, 0.00143377454057296, 0.00107076258525208, 0.000773006742366448, 0.000539573690886396, 0.000364177599116743, 0.000237727628685579, 0.000150157714457011, 9.18283319498657e-05, 5.44079947589853e-05, 3.12548818921464e-05, 1.74202619730274e-05, 9.42698047424712e-06, 4.95614149002087e-06, 2.53275674485913e-06, 1.25854819834554e-06, 6.08116579596933e-07, 2.85572858589747e-07, 1.30129404249734e-07, 5.73280599448306e-08, 2.4219376577964e-08, 9.6316861194457e-09, 3.43804936850951e-09, 9.34806280366887e-10, 0.0], + "inf_to_hosp": [0.0, 0.00469384736487552, 0.0145200073436112, 0.0278627741704387, 0.0423656492135518, 0.0558071445014868, 0.0665713169684116, 0.0737925805176124, 0.0772854627892072, 0.0773666390616176, 
0.0746515449009948, 0.0698761436052596, 0.0637663813017696, 0.0569581929821651, 0.0499600186601535, 0.0431457477049282, 0.0367662806214046, 0.0309702535668238, 0.0258273785539499, 0.0213504646948306, 0.0175141661880584, 0.0142698211023571, 0.0115565159519833, 0.00930888979824423, 0.00746229206759215, 0.00595605679409682, 0.00473519993107751, 0.00375117728281842, 0.00296198928038098, 0.00233187862772459, 0.00183079868293457, 0.00143377454057296, 0.00107076258525208, 0.000773006742366448, 0.000539573690886396, 0.000364177599116743, 0.000237727628685578, 0.000150157714457011, 9.18283319498657e-05, 5.44079947589853e-05, 3.12548818921464e-05, 1.74202619730274e-05, 9.42698047424712e-06, 4.95614149002087e-06, 2.53275674485913e-06, 1.25854819834554e-06, 6.08116579596933e-07, 2.85572858589747e-07, 1.30129404249734e-07, 5.73280599448305e-08, 2.4219376577964e-08, 9.6316861194457e-09, 3.43804936850951e-09, 9.34806280366887e-10, 0.0], "mwpd": 227000.0, "ot": 90, - "n_subpops": 5, - "n_ww_sites": 4, - "n_ww_lab_sites": 5, - "owt": 88, + "n_subpops": 1, + "n_ww_sites": 0.0, + "n_ww_lab_sites": 0, + "owt": 0, "oht": 90, "n_censored": 0, - "n_uncensored": 88, + "n_uncensored": 0, "uot": 50, "ht": 35, "n_weeks": 18, @@ -322,17 +322,19 @@ "generation_interval": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], "ts": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], "state_pop": 3000000.0, - "subpop_size": [400000.0, 200000.0, 100000.0, 50000.0, 2250000.0], + "subpop_size": [3000000.0], "norm_pop": 3000000.0, - "ww_sampled_times": [2, 5, 6, 6, 8, 9, 11, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 18, 18, 18, 19, 20, 21, 22, 23, 23, 25, 26, 27, 29, 29, 29, 31, 32, 32, 33, 33, 34, 36, 36, 37, 37, 37, 39, 42, 42, 42, 43, 45, 45, 46, 47, 
48, 51, 53, 58, 58, 59, 59, 63, 63, 64, 65, 65, 67, 70, 70, 73, 73, 74, 75, 76, 76, 76, 78, 80, 81, 82, 83, 83, 84, 87, 89, 91, 92, 93, 93, 95], + "ww_sampled_times": [], "hosp_times": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], - "ww_sampled_lab_sites": [1, 5, 4, 5, 3, 2, 4, 1, 3, 2, 3, 5, 1, 3, 5, 3, 5, 2, 4, 5, 2, 3, 1, 2, 3, 5, 4, 3, 3, 2, 3, 4, 5, 1, 3, 3, 5, 5, 3, 4, 2, 3, 4, 2, 1, 3, 4, 4, 1, 5, 3, 2, 1, 4, 4, 2, 4, 1, 2, 1, 2, 4, 1, 3, 1, 1, 3, 4, 5, 1, 3, 2, 3, 4, 5, 4, 5, 5, 2, 4, 4, 4, 4, 3, 2, 3, 5, 1], - "ww_log_lod": [5.09434727489065, 4.9806950154474, 4.73771588167502, 4.9806950154474, 5.2940513994166, 4.76390186342314, 4.73771588167502, 5.09434727489065, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.9806950154474, 5.09434727489065, 5.2940513994166, 4.9806950154474, 5.2940513994166, 4.9806950154474, 4.76390186342314, 4.73771588167502, 4.9806950154474, 4.76390186342314, 5.2940513994166, 5.09434727489065, 4.76390186342314, 5.2940513994166, 4.9806950154474, 4.73771588167502, 5.2940513994166, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.9806950154474, 5.09434727489065, 5.2940513994166, 5.2940513994166, 4.9806950154474, 4.9806950154474, 5.2940513994166, 4.73771588167502, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.76390186342314, 5.09434727489065, 5.2940513994166, 4.73771588167502, 4.73771588167502, 5.09434727489065, 4.9806950154474, 5.2940513994166, 4.76390186342314, 5.09434727489065, 4.73771588167502, 4.73771588167502, 4.76390186342314, 4.73771588167502, 5.09434727489065, 4.76390186342314, 5.09434727489065, 4.76390186342314, 4.73771588167502, 5.09434727489065, 5.2940513994166, 5.09434727489065, 
5.09434727489065, 5.2940513994166, 4.73771588167502, 4.9806950154474, 5.09434727489065, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.9806950154474, 4.73771588167502, 4.9806950154474, 4.9806950154474, 4.76390186342314, 4.73771588167502, 4.73771588167502, 4.73771588167502, 4.73771588167502, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.9806950154474, 5.09434727489065], + "ww_sampled_subpops": [], + "lab_site_to_subpop_map": [], + "ww_sampled_lab_sites": [], + "ww_log_lod": [], "ww_censored": [], - "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88], - "hosp": [6, 6, 7, 10, 10, 12, 12, 10, 8, 15, 8, 9, 9, 13, 17, 7, 12, 10, 13, 10, 12, 15, 19, 19, 22, 17, 19, 14, 17, 19, 18, 13, 24, 21, 35, 26, 25, 30, 26, 20, 29, 38, 35, 41, 30, 37, 35, 46, 38, 23, 38, 22, 28, 23, 31, 19, 23, 17, 23, 26, 17, 17, 12, 13, 9, 22, 12, 13, 17, 14, 12, 6, 10, 10, 4, 12, 9, 8, 9, 8, 6, 13, 7, 8, 13, 9, 9, 17, 7, 10], + "ww_uncensored": [], + "hosp": [14, 15, 11, 22, 13, 17, 26, 19, 21, 21, 18, 32, 22, 29, 15, 38, 30, 26, 26, 36, 37, 41, 38, 37, 50, 57, 56, 46, 55, 49, 53, 58, 58, 59, 63, 68, 66, 90, 71, 74, 90, 79, 66, 76, 91, 79, 94, 78, 93, 82, 100, 92, 61, 93, 60, 71, 62, 72, 53, 55, 49, 59, 68, 46, 50, 45, 36, 32, 56, 37, 41, 41, 35, 45, 37, 33, 27, 33, 34, 34, 38, 27, 43, 30, 33, 34, 35, 33, 35, 40], "day_of_week": [5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 
6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3], - "log_conc": [7.84933755868547, 8.75878317201939, 7.94731683160414, 7.84271554481825, 6.84314706058328, 8.49580667918879, 8.70688802330958, 8.3080964891881, 7.51476199898784, 8.73263730516131, 7.67705894528627, 9.16138870419256, 8.2291748710282, 7.57842684442178, 8.7658468204119, 7.59564541757592, 8.6485267724375, 8.73847850883661, 8.53496849868392, 7.81189950945825, 8.72816004875312, 7.77632957957527, 8.35751634986049, 8.61237112259723, 7.14724348782869, 8.98381918138397, 9.71929374032385, 7.73223113044046, 7.64136587504144, 8.93727178027946, 7.69319599090821, 9.3539099812084, 9.77803780431265, 8.83676433943342, 7.80095273126195, 7.88428742699397, 10.6861051151009, 10.7177204634667, 7.86033075211836, 9.44031060901259, 9.19820314756664, 8.10945148438765, 9.47368587127031, 9.30021922706583, 8.71406234479695, 7.82242543443078, 8.73443519589195, 9.246306907259, 8.61154444152258, 10.8332932698813, 7.39925321859574, 9.06066397101092, 8.61830748102478, 8.86917291106784, 9.02943162748827, 8.42736799456162, 8.11764762377314, 8.03266723037298, 8.41720674557318, 7.98228105459503, 8.42950370265189, 8.47947015286844, 7.81419539933152, 6.95672231924945, 7.74818768373832, 7.66575220799484, 6.86777849754671, 7.76297112083073, 6.95229085580734, 7.56030545976231, 6.64452132423733, 7.65706249921245, 6.80982040079249, 8.28375427477922, 10.3920632090468, 8.17535318491482, 8.62625413957733, 7.21373104269779, 8.15256350454411, 8.38397288981646, 8.40418626057656, 8.23933140611506, 7.91321568315416, 6.89107975822161, 8.02468424809018, 6.87101220865236, 8.34645251673327, 7.62822976068377], + "log_conc": [], "compute_likelihood": 1, "include_ww": 0, "include_hosp": 1, @@ -341,8 +343,8 @@ "viral_shedding_pars": [5.0, 1.0, 5.1, 0.5, 17.0, 3.0], "autoreg_rt_a": 2, "autoreg_rt_b": 40, - "autoreg_rt_site_a": 1, - "autoreg_rt_site_b": 4, + "autoreg_rt_subpop_a": 1, + "autoreg_rt_subpop_b": 4, "autoreg_p_hosp_a": 1, "autoreg_p_hosp_b": 100, 
"inv_sqrt_phi_prior_mean": 0.1, @@ -351,8 +353,8 @@ "r_prior_sd": 1, "log10_g_prior_mean": 12, "log10_g_prior_sd": 2, - "i_first_obs_over_n_prior_a": 1.0015, - "i_first_obs_over_n_prior_b": 5.9985, + "i_first_obs_over_n_prior_a": 1.00280952380952, + "i_first_obs_over_n_prior_b": 5.99719047619048, "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], "mean_initial_exp_growth_rate_prior_mean": 0, "mean_initial_exp_growth_rate_prior_sd": 0.01, @@ -374,6 +376,10 @@ "sigma_rt_prior": 0.0001, "log_phi_g_prior_mean": -2.302585, "log_phi_g_prior_sd": 5, - "ww_sampled_sites": [1, 4, 3, 4, 2, 1, 3, 1, 2, 1, 2, 4, 1, 2, 4, 2, 4, 1, 3, 4, 1, 2, 1, 1, 2, 4, 3, 2, 2, 1, 2, 3, 4, 1, 2, 2, 4, 4, 2, 3, 1, 2, 3, 1, 1, 2, 3, 3, 1, 4, 2, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1, 3, 1, 2, 1, 1, 2, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, 4, 1, 3, 3, 3, 3, 2, 1, 2, 4, 1], - "lab_site_to_site_map": [1, 1, 2, 3, 4] + "offset_ref_log_r_t_prior_mean": 0, + "offset_ref_log_r_t_prior_sd": 0.2, + "offset_ref_logit_i_first_obs_prior_mean": 0, + "offset_ref_logit_i_first_obs_prior_sd": 0.25, + "offset_ref_initial_exp_growth_rate_prior_mean": 0, + "offset_ref_initial_exp_growth_rate_prior_sd": 0.025 } diff --git a/notebooks/wwinference.Rmd b/notebooks/wwinference.Rmd index 2af43b94..6713e8cb 100644 --- a/notebooks/wwinference.Rmd +++ b/notebooks/wwinference.Rmd @@ -377,7 +377,7 @@ state-level R(t) estimate. 
We can generate this directly on the output of `wwinference()` using: ```{r extracting-draws} -draws_df <- get_draws_df(ww_fit) +draws_df <- get_draws(ww_fit) cat( "Variables in dataframe: ", @@ -396,9 +396,12 @@ Rather than using S3 methods supplied for `wwinference()`, the elements in the This is demonstrated below: ```{r extracting-draws-explicit} -draws_df_explicit <- get_draws_df( +draws_explicit <- get_draws( x = ww_fit$raw_input_data$input_ww_data, count_data = ww_fit$raw_input_data$input_count_data, + date_time_spine = ww_fit$raw_input_data$date_time_spine, + site_subpop_spine = ww_fit$raw_input_data$site_subpop_spine, + lab_site_subpop_spine = ww_fit$raw_input_data$lab_site_subpop_spine, stan_data_list = ww_fit$stan_data_list, fit_obj = ww_fit$fit ) @@ -413,23 +416,22 @@ visualize data that was below the LOD (even though the fit incorporated them via the censored observation process.) ```{r generating-figures, out.width='100%'} -draws_df <- get_draws_df(ww_fit) plot_hosp <- get_plot_forecasted_counts( - draws = draws_df, + draws = draws$predicted_counts, count_data_eval = hosp_data_eval, count_data_eval_col_name = "daily_hosp_admits_for_eval", forecast_date = forecast_date ) plot_hosp -plot_ww <- get_plot_ww_conc(draws_df, forecast_date) +plot_ww <- get_plot_ww_conc(draws$predicted_ww, forecast_date) plot_ww -plot_state_rt <- get_plot_global_rt(draws_df, forecast_date) +plot_state_rt <- get_plot_global_rt(draws$global_rt, forecast_date) plot_state_rt -plot_subpop_rt <- get_plot_subpop_rt(draws_df, forecast_date) +plot_subpop_rt <- get_plot_subpop_rt(draws$subpop_rt, forecast_date) plot_subpop_rt ``` @@ -520,14 +522,13 @@ fit_hosp_only <- wwinference::wwinference( ``` ```{r plot-hosp-only, out.width='100%'} -draws_df_hosp_only <- get_draws_df(fit_hosp_only) -plot_hosp_hosp_only <- get_plot_forecasted_counts( - draws = draws_df_hosp_only, +draws_hosp_only <- get_draws(fit_hosp_only) +plot(draws_hosp_only, + what = "predicted_counts", count_data_eval = 
hosp_data_eval, count_data_eval_col_name = "daily_hosp_admits_for_eval", forecast_date = forecast_date ) -plot_hosp_hosp_only ``` ```{r copy results} From 764a42ce4a02fb9d76f57cb1c985a0e330ac825e Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 8 Oct 2024 14:59:45 -0400 Subject: [PATCH 24/50] update priors --- notebooks/site_level_ww_model_demo.qmd | 51 ++++++++++++++++++++------ 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index 2dcbc6b6..16ed4c55 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -39,7 +39,7 @@ n_weeks = stan_data["n_weeks"] unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW #n_datapoints = obs_time+horizon_time -n_subpops = stan_data["n_subpops"] #number of WW sites +n_subpops = stan_data["n_subpops"] #number of modeled subpops state_pop = stan_data["state_pop"] subpop_size = stan_data["subpop_size"] norm_pop = stan_data["norm_pop"] @@ -58,10 +58,12 @@ ww_uncensored = jnp.array(stan_data["ww_uncensored"]) #time that WW data is abo obs_ww_days = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) -ww_sampled_sites = jnp.array(stan_data["ww_sampled_sites"]) # vector of unique sites in order of the sampled times -ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) # a list of all of the days on which WW is sampled +ww_sampled_subpop = jnp.array(stan_data["ww_sampled_subpop"]) # vector of unique sites in order of the sampled times +ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) # a list of all of the days on which WW is sampled, mapped to corresponding subpops ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) +#vector mapping the subpops to lab-site combos +#lab_site_to_subpop_map = stan_data["lab_site_to_subpop_map"] data_observed_log_conc = jnp.array(stan_data["log_conc"]) data_observed_hospital_admissions 
= jnp.array(stan_data["hosp"]) @@ -75,19 +77,46 @@ eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, autoreg_rt_a = stan_data["autoreg_rt_a"] autoreg_rt_b = stan_data["autoreg_rt_b"] -autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) #autoregressive coefficient for AR process on first differences in log R(t) r_prior_mean = stan_data["r_prior_mean"] r_prior_sd = stan_data["r_prior_sd"] r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) # log of state level mean R(t) in weeks -log_r_mu_intercept_rv = DistributionalVariable( - "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +log_r_t_first_obs_rv = DistributionalVariable( + "log_r_t_first_obs", dist.Normal(r_logmean, r_logsd) ) ``` +```{python} + +offset_ref_log_r_t_prior_mean=stan_data["offset_ref_log_r_t_prior_mean"] +offset_ref_log_r_t_prior_sd = stan_data["offset_ref_log_r_t_prior_sd"] +offset_ref_log_r_t_rv = DistributionalVariable( + "offset_ref_log_r_t", dist.Normal( + offset_ref_log_r_t_prior_mean, offset_ref_log_r_t_prior_sd + ) +) + +offset_ref_logit_i_first_obs_prior_mean = stan_data["offset_ref_logit_i_first_obs_prior_mean"] +offset_ref_logit_i_first_obs_prior_sd = stan_data["offset_ref_logit_i_first_obs_prior_sd"] +offset_ref_logit_i_first_obs_rv = DistributionalVariable( + "offset_ref_logit_i_first_obs", dist.Normal( + offset_ref_logit_i_first_obs_prior_mean, offset_ref_logit_i_first_obs_prior_sd + ) +) + +offset_ref_initial_exp_growth_rate_prior_mean = stan_data["offset_ref_initial_exp_growth_rate_prior_mean"] +offset_ref_initial_exp_growth_rate_prior_sd = stan_data["offset_ref_initial_exp_growth_rate_prior_sd"] +offset_ref_initial_exp_growth_rate_rv = DistributionalVariable( + "offset_ref_initial_exp_growth_rate", dist.Normal( + offset_ref_initial_exp_growth_rate_prior_mean, 
offset_ref_initial_exp_growth_rate_prior_sd + ) +) +``` + ```{python} # viral shedding parameters viral_shedding_pars = stan_data["viral_shedding_pars"] @@ -137,16 +166,16 @@ generation_interval_pmf_rv = DeterministicPMF( "generation_interval_pmf", jnp.array(generation_interval) ) -autoreg_rt_site_a = stan_data["autoreg_rt_site_a"] -autoreg_rt_site_b = stan_data["autoreg_rt_site_b"] -autoreg_rt_site_rv = DistributionalVariable( - "autoreg_rt_site",dist.Beta(autoreg_rt_site_a, autoreg_rt_site_b) +autoreg_rt_subpop_a = stan_data["autoreg_rt_subpop_a"] +autoreg_rt_subpop_b = stan_data["autoreg_rt_subpop_b"] +autoreg_rt_subpop_rv = DistributionalVariable( + "autoreg_rt_subpop",dist.Beta(autoreg_rt_subpop_a, autoreg_rt_subpop_b) ) sigma_rt_prior = stan_data["sigma_rt_prior"] sigma_rt_rv = DistributionalVariable( "sigma_rt", dist.TruncatedNormal(0,sigma_rt_prior,low=0) -) +)# magnitude of subpopulation level R(t) heterogeneity i_first_obs_over_n_prior_a = stan_data["i_first_obs_over_n_prior_a"] i_first_obs_over_n_prior_b = stan_data["i_first_obs_over_n_prior_b"] From 8d72aa50947f3a4131344251bc3ee0861a8da060 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 9 Oct 2024 11:53:41 -0400 Subject: [PATCH 25/50] model rewrite w ref subpop --- notebooks/site_level_ww_model_demo.qmd | 28 ++- notebooks/wwinference.Rmd | 1 - pyrenew_hew/site_level_dynamics_model.py | 233 ++++++++++++++--------- 3 files changed, 160 insertions(+), 102 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index 16ed4c55..f2e804a9 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -37,7 +37,6 @@ obs_time = stan_data["ot"] horizon_time = stan_data["ht"] #horizon time (nowcast + forecast time) n_weeks = stan_data["n_weeks"] unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW -#n_datapoints = obs_time+horizon_time n_subpops = stan_data["n_subpops"] #number of modeled 
subpops state_pop = stan_data["state_pop"] @@ -49,7 +48,6 @@ pop_fraction = jnp.array(subpop_size)/norm_pop ww_ml_produced_per_day = stan_data["mwpd"] n_ww_lab_sites = stan_data["n_ww_lab_sites"] ww_log_lod =jnp.array(stan_data["ww_log_lod"]) # The limit of detection in that site at that time point -lab_site_to_site_map = jnp.array(stan_data["lab_site_to_site_map"]) # which lab sites correspond to which sites n_censored = stan_data["n_censored"] n_uncensored = stan_data["n_uncensored"] @@ -58,12 +56,12 @@ ww_uncensored = jnp.array(stan_data["ww_uncensored"]) #time that WW data is abo obs_ww_days = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) -ww_sampled_subpop = jnp.array(stan_data["ww_sampled_subpop"]) # vector of unique sites in order of the sampled times +ww_sampled_subpops = jnp.array(stan_data["ww_sampled_subpops"]) # vector of unique sites in order of the sampled times ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) # a list of all of the days on which WW is sampled, mapped to corresponding subpops ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) #vector mapping the subpops to lab-site combos +lab_site_to_subpop_map = stan_data["lab_site_to_subpop_map"] -#lab_site_to_subpop_map = stan_data["lab_site_to_subpop_map"] data_observed_log_conc = jnp.array(stan_data["log_conc"]) data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) @@ -77,7 +75,9 @@ eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, autoreg_rt_a = stan_data["autoreg_rt_a"] autoreg_rt_b = stan_data["autoreg_rt_b"] -autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) #autoregressive coefficient for AR process on first differences in log R(t) +autoreg_rt_rv = DistributionalVariable( + "autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b) + ) #autoregressive coefficient for AR process on first differences in log R(t) r_prior_mean = stan_data["r_prior_mean"] r_prior_sd 
= stan_data["r_prior_sd"] @@ -201,13 +201,12 @@ sigma_initial_exp_growth_rate_rv = DistributionalVariable( mean_initial_exp_growth_rate_prior_mean = stan_data["mean_initial_exp_growth_rate_prior_mean"] mean_initial_exp_growth_rate_prior_sd = stan_data["mean_initial_exp_growth_rate_prior_sd"] -#stan code uses normal distribution but hosp_only model uses TruncatedNormal mean_initial_exp_growth_rate_rv = DistributionalVariable( "mean_initial_exp_growth_rate", dist.TruncatedNormal( loc=mean_initial_exp_growth_rate_prior_mean, scale=mean_initial_exp_growth_rate_prior_sd, - low=-1, - high=1, + low=-0.001, + high=0.001, ) ) @@ -297,13 +296,12 @@ my_model = ww_site_level_dynamics_model( n_initialization_points, max_shed_interval, i0_t_offset, - log_r_mu_intercept_rv, + log_r_t_first_obs_rv, autoreg_rt_rv, eta_sd_rv, t_peak_rv, - viral_peak_rv, - dur_shed_rv, - autoreg_rt_site_rv, + dur_shed_after_peak_rv, + autoreg_rt_subpop_rv, sigma_rt_rv, i_first_obs_over_n_rv, sigma_i_first_obs_rv, @@ -324,13 +322,11 @@ my_model = ww_site_level_dynamics_model( phi_rv, ww_ml_produced_per_day, pop_fraction, - ww_uncensored, - ww_censored, ww_sampled_lab_sites, - ww_sampled_sites, + ww_sampled_subpops, ww_sampled_times, ww_log_lod, - lab_site_to_site_map + lab_site_to_subpop_map ) ``` diff --git a/notebooks/wwinference.Rmd b/notebooks/wwinference.Rmd index 6713e8cb..4ce1ef5f 100644 --- a/notebooks/wwinference.Rmd +++ b/notebooks/wwinference.Rmd @@ -416,7 +416,6 @@ visualize data that was below the LOD (even though the fit incorporated them via the censored observation process.) 
```{r generating-figures, out.width='100%'} - plot_hosp <- get_plot_forecasted_counts( draws = draws$predicted_counts, count_data_eval = hosp_data_eval, diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 57722538..7ae22796 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -32,12 +32,12 @@ def __init__( n_initialization_points, max_shed_interval, i0_t_offset, - log_r_mu_intercept_rv, + log_r_t_first_obs_rv, autoreg_rt_rv, eta_sd_rv, t_peak_rv, dur_shed_after_peak_rv, - autoreg_rt_site_rv, + autoreg_rt_subpop_rv, sigma_rt_rv, i_first_obs_over_n_rv, sigma_i_first_obs_rv, @@ -58,13 +58,11 @@ def __init__( phi_rv, ww_ml_produced_per_day, pop_fraction, - ww_uncensored, - ww_censored, ww_sampled_lab_sites, - ww_sampled_sites, + ww_sampled_subpops, ww_sampled_times, ww_log_lod, - lab_site_to_site_map, + lab_site_to_subpop_map, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops @@ -73,12 +71,12 @@ def __init__( self.n_initialization_points = n_initialization_points self.max_shed_interval = max_shed_interval self.i0_t_offset = i0_t_offset - self.log_r_mu_intercept_rv = log_r_mu_intercept_rv + self.log_r_mu_intercept_rv = (log_r_t_first_obs_rv,) self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.t_peak_rv = t_peak_rv self.dur_shed_after_peak_rv = dur_shed_after_peak_rv - self.autoreg_rt_site_rv = autoreg_rt_site_rv + self.autoreg_rt_site_rv = autoreg_rt_subpop_rv self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv @@ -99,13 +97,11 @@ def __init__( self.phi_rv = phi_rv self.ww_ml_produced_per_day = ww_ml_produced_per_day self.pop_fraction = pop_fraction - self.ww_uncensored = ww_uncensored - self.ww_censored = ww_censored self.ww_sampled_lab_sites = ww_sampled_lab_sites - self.ww_sampled_sites = ww_sampled_sites + self.ww_sampled_sites = 
ww_sampled_subpops self.ww_sampled_times = ww_sampled_times self.ww_log_lod = ww_log_lod - self.lab_site_to_site_map = lab_site_to_site_map + self.lab_site_to_site_map = lab_site_to_subpop_map self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, @@ -149,7 +145,7 @@ def sample( eta_sd = self.eta_sd_rv() autoreg_rt = self.autoreg_rt_rv() - log_r_mu_intercept = self.log_r_mu_intercept_rv() + log_r_t_first_obs = self.log_r_t_first_obs_rv() rt_init_rate_of_change_rv = DistributionalVariable( "rt_init_rate_of_change", @@ -158,15 +154,15 @@ def sample( rt_init_rate_of_change = rt_init_rate_of_change_rv() - log_rtu_weekly = self.ar_diff_rt( + log_r_t_in_weeks = self.ar_diff_rt( noise_name="rtu_weekly_diff_first_diff_ar_process_noise", n=n_weeks_post_init, - init_vals=jnp.array(log_r_mu_intercept), + init_vals=jnp.array(log_r_t_first_obs), autoreg=jnp.array(autoreg_rt), noise_sd=jnp.array(eta_sd), fundamental_process_init_vals=jnp.array(rt_init_rate_of_change), ) - numpyro.deterministic("log_rtu_weekly", log_rtu_weekly) + numpyro.deterministic("log_r_t_in_weeks", log_r_t_in_weeks) t_peak = self.t_peak_rv() # viral_peak = self.viral_peak_rv() @@ -174,91 +170,155 @@ def sample( s = get_vl_trajectory(t_peak, dur_shed, self.max_shed_interval) + i_first_obs_over_n = self.i_first_obs_over_n_rv() + offset_ref_logit_i_first_obs = self.offset_ref_logit_i_first_obs_rv() + mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() - sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() - initial_exp_growth_rate_site_rv = TransformedVariable( - "clipped_initial_exp_growth_rate_site", - DistributionalVariable( - "initial_exp_growth_rate_site_raw", - dist.Normal( - mean_initial_exp_growth_rate, - sigma_initial_exp_growth_rate, - ), - reparam=LocScaleReparam(0), - ), - transforms=lambda x: jnp.clip(x, -0.01, 0.01), + offset_ref_initial_exp_growth_rate = ( + 
self.offset_ref_initial_exp_growth_rate_rv() ) - i_first_obs_over_n = self.i_first_obs_over_n_rv() - sigma_i_first_obs = self.sigma_i_first_obs_rv() - i_first_obs_over_n_site_rv = TransformedVariable( - "i_first_obs_over_n_site", - DistributionalVariable( - "i_first_obs_over_n_site_raw", - dist.Normal( - transforms.logit(i_first_obs_over_n), sigma_i_first_obs - ), - reparam=LocScaleReparam(0), - ), - transforms=transforms.SigmoidTransform(), + i_first_obs_over_n_ref_subpop = transforms.SigmoidTransform()( + transforms.logit(i_first_obs_over_n) + + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) + ) + initial_exp_growth_rate_ref_subpop = ( + mean_initial_exp_growth_rate + + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + ) ) - with numpyro.plate("n_subpops", self.n_subpops): - initial_exp_growth_rate_site = initial_exp_growth_rate_site_rv() - i_first_obs_over_n_site = i_first_obs_over_n_site_rv() + if self.n_subpops > 1: + sigma_i_first_obs = self.sigma_i_first_obs_rv() + i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( + "i_first_obs_over_n_non_ref_subpop", + DistributionalVariable( + "i_first_obs_over_n_non_ref_subpop_raw", + dist.Normal( + transforms.logit(i_first_obs_over_n), sigma_i_first_obs + ), + reparam=LocScaleReparam(0), + ), + transforms=transforms.SigmoidTransform(), + ) + sigma_initial_exp_growth_rate = ( + self.sigma_initial_exp_growth_rate_rv() + ) + initial_exp_growth_rate_non_ref_subpop_rv = TransformedVariable( + "clipped_initial_exp_growth_rate_non_ref_subpop", + DistributionalVariable( + "initial_exp_growth_rate_non_ref_subpop_raw", + dist.Normal( + mean_initial_exp_growth_rate, + sigma_initial_exp_growth_rate, + ), + reparam=LocScaleReparam(0), + ), + transforms=lambda x: jnp.clip(x, -0.01, 0.01), + ) + + with numpyro.plate("n_subpops", self.n_subpops - 1): + initial_exp_growth_rate_non_ref_subpop = ( + initial_exp_growth_rate_non_ref_subpop_rv() + ) + i_first_obs_over_n_non_ref_subpop = ( + 
i_first_obs_over_n_non_ref_subpop_rv() + ) + + i_first_obs_over_n_subpop = jnp.hstack( + [ + i_first_obs_over_n_ref_subpop, + i_first_obs_over_n_non_ref_subpop, + ] + ) + initial_exp_growth_rate_subpop = jnp.hstack( + [ + initial_exp_growth_rate_ref_subpop, + initial_exp_growth_rate_non_ref_subpop, + ] + ) + else: + i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop + initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop numpyro.deterministic( - "initial_exp_growth_rate_site", initial_exp_growth_rate_site + "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop ) - - log_i0_site = ( - jnp.log(i_first_obs_over_n_site) - - self.unobs_time * initial_exp_growth_rate_site + numpyro.deterministic( + "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) - numpyro.deterministic("log_i0_site", log_i0_site) - autoreg_rt_site = self.autoreg_rt_site_rv() - sigma_rt = self.sigma_rt_rv() - rtu_site_ar_init_rv = DistributionalVariable( - "rtu_site_ar_init", - dist.Normal( - 0, - sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_site, 2)), - ), + log_i0_subpop = ( + jnp.log(i_first_obs_over_n_subpop) + - self.unobs_time * initial_exp_growth_rate_subpop ) - with numpyro.plate("n_subpops", self.n_subpops): - rtu_site_ar_init = rtu_site_ar_init_rv() + numpyro.deterministic("log_i0_subpop", log_i0_subpop) - rtu_site_ar_proc = ARProcess() - rtu_site_ar_weekly = rtu_site_ar_proc( - noise_name="rtu_ar_proc", - n=n_weeks_post_init, - init_vals=rtu_site_ar_init[jnp.newaxis], - autoreg=autoreg_rt_site[jnp.newaxis], - noise_sd=sigma_rt, + offset_ref_log_r_t = self.offset_ref_log_r_t_rv() + log_rtu_ref_subpop_in_week = log_r_t_in_weeks + jnp.where( + self.n_subpops > 1, offset_ref_log_r_t, 0 ) - numpyro.deterministic("rtu_site_ar_weekly", rtu_site_ar_weekly) + if self.n_subpops > 1: + autoreg_rt_subpop = self.autoreg_rt_subpop_rv() + sigma_rt = self.sigma_rt_rv() + rtu_site_ar_init_rv = DistributionalVariable( + "rtu_site_ar_init", + dist.Normal( + 0, + sigma_rt / 
jnp.sqrt(1 - jnp.pow(autoreg_rt_subpop, 2)), + ), + ) + with numpyro.plate("n_subpops", self.n_subpops - 1): + rtu_site_ar_init = rtu_site_ar_init_rv() + + rtu_subpop_ar_proc = ARProcess() + rtu_subpop_ar_weekly = rtu_subpop_ar_proc( + noise_name="rtu_ar_proc", + n=n_weeks_post_init, + init_vals=rtu_site_ar_init[jnp.newaxis], + autoreg=autoreg_rt_subpop[jnp.newaxis], + noise_sd=sigma_rt, + ) - rtu_site = jnp.repeat( - jnp.exp(rtu_site_ar_weekly + log_rtu_weekly[:, jnp.newaxis]), - repeats=7, - axis=0, - )[:n_datapoints, :] + numpyro.deterministic("rtu_subpop_ar_weekly", rtu_subpop_ar_weekly) + log_rtu_non_ref_subpop_in_week = ( + rtu_subpop_ar_weekly + log_r_t_in_weeks[:, jnp.newaxis] + ) + log_rtu_subpop_in_week = jnp.concat( + [ + log_rtu_ref_subpop_in_week[:, jnp.newaxis], + log_rtu_non_ref_subpop_in_week, + ], + axis=1, + ) + else: + log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] + + rtu_subpop = jnp.squeeze( + jnp.repeat( + jnp.exp(log_rtu_subpop_in_week), + repeats=7, + axis=0, + )[:n_datapoints, :] + ) - numpyro.deterministic("rtu_site", rtu_site) + numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_site_rv = DeterministicVariable("i0_site", jnp.exp(log_i0_site)) - initial_exp_growth_rate_site_rv = DeterministicVariable( - "initial_exp_growth_rate_site", initial_exp_growth_rate_site + i0_subpop_rv = DeterministicVariable( + "i0_subpop", jnp.exp(log_i0_subpop) + ) + initial_exp_growth_rate_subpop_rv = DeterministicVariable( + "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) infection_initialization_process = InfectionInitializationProcess( "I0_initialization", - i0_site_rv, + i0_subpop_rv, InitializeInfectionsExponentialGrowth( self.n_initialization_points, - initial_exp_growth_rate_site_rv, + initial_exp_growth_rate_subpop_rv, t_pre_init=self.i0_t_offset, ), ) @@ -268,21 +328,23 @@ def sample( numpyro.deterministic("i0", i0) inf_with_feedback_proc_sample = self.inf_with_feedback_proc.sample( - Rt=rtu_site, + 
Rt=rtu_subpop, I0=i0, gen_int=generation_interval_pmf, ) - new_i_site = jnp.concat( + new_i_subpop = jnp.concat( [ i0, inf_with_feedback_proc_sample.post_initialization_infections, ] ) - r_site_t = inf_with_feedback_proc_sample.rt - numpyro.deterministic("r_site_t", r_site_t) + r_subpop_t = inf_with_feedback_proc_sample.rt + numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_site, axis=1) + state_inf_per_capita = jnp.sum( + self.pop_fraction * new_i_subpop, axis=1 + ) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) @@ -290,7 +352,7 @@ def batch_colvolve_fn(m): return jnp.convolve(m, s, mode="valid") model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - new_i_site + new_i_subpop )[-n_datapoints:, :] numpyro.deterministic("model_net_i", model_net_i) @@ -416,7 +478,8 @@ def batch_colvolve_fn(m): ww_pred = numpyro.sample( "site_ww_pred", dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_site_map] + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + + ww_site_mod, scale=sigma_ww_site, ), ) From 09f97326c4fe70fa589a1a8ea2538be1ce970ceb Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 10 Oct 2024 12:42:24 -0400 Subject: [PATCH 26/50] indexing error fixes --- notebooks/site_level_ww_model_demo.qmd | 58 ++++---- pyrenew_hew/site_level_dynamics_model.py | 165 +++++++++++++---------- 2 files changed, 125 insertions(+), 98 deletions(-) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_ww_model_demo.qmd index f2e804a9..c0aacae7 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_ww_model_demo.qmd @@ -13,6 +13,7 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_hew.site_level_dynamics_model 
import ww_site_level_dynamics_model from pyrenew_hew.initialization import get_initialization +from numpyro.infer.initialization import init_to_sample, init_to_value import jax numpyro.set_host_device_count(4) @@ -32,10 +33,12 @@ hosp_delay_max = stan_data["hosp_delay_max"] n_initialization_points = max(gt_max, hosp_delay_max) i0_t_offset = 0 # check this later -# maximum time index for the hospital admissions (max number of days we could have observations) -obs_time = stan_data["ot"] + +obs_time = stan_data["ot"]# maximum time index for the hospital admissions (max number of days we could have observations) +obs_ww_time = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) +obs_hosp_time = stan_data["oht"] # number of days that we have hospital admissions observations + horizon_time = stan_data["ht"] #horizon time (nowcast + forecast time) -n_weeks = stan_data["n_weeks"] unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW n_subpops = stan_data["n_subpops"] #number of modeled subpops @@ -51,16 +54,15 @@ ww_log_lod =jnp.array(stan_data["ww_log_lod"]) # The limit of detection in that n_censored = stan_data["n_censored"] n_uncensored = stan_data["n_uncensored"] -ww_censored = jnp.array(stan_data["ww_censored"]) #times that the WW data is below the LOD -ww_uncensored = jnp.array(stan_data["ww_uncensored"]) #time that WW data is above LOD +ww_censored = jnp.array(stan_data["ww_censored"])-1 #times that the WW data is below the LOD +ww_uncensored = jnp.array(stan_data["ww_uncensored"])-1 #time that WW data is above LOD -obs_ww_days = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) +ww_sampled_subpops = jnp.array(stan_data["ww_sampled_subpops"]) -1 # vector of unique sites in order of the sampled times +ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) -1 # a list of all of the days on which WW is sampled, mapped to corresponding subpops +ww_sampled_lab_sites = 
jnp.array(stan_data["ww_sampled_lab_sites"]) - 1 +hosp_times = jnp.array(stan_data["hosp_times"]) -1 -ww_sampled_subpops = jnp.array(stan_data["ww_sampled_subpops"]) # vector of unique sites in order of the sampled times -ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) # a list of all of the days on which WW is sampled, mapped to corresponding subpops -ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) -#vector mapping the subpops to lab-site combos -lab_site_to_subpop_map = stan_data["lab_site_to_subpop_map"] +lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #vector mapping the subpops to lab-site combos data_observed_log_conc = jnp.array(stan_data["log_conc"]) data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) @@ -75,9 +77,7 @@ eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, autoreg_rt_a = stan_data["autoreg_rt_a"] autoreg_rt_b = stan_data["autoreg_rt_b"] -autoreg_rt_rv = DistributionalVariable( - "autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b) - ) #autoregressive coefficient for AR process on first differences in log R(t) +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)) #autoregressive coefficient for AR process on first differences in log R(t) r_prior_mean = stan_data["r_prior_mean"] r_prior_sd = stan_data["r_prior_sd"] @@ -87,7 +87,6 @@ r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) log_r_t_first_obs_rv = DistributionalVariable( "log_r_t_first_obs", dist.Normal(r_logmean, r_logsd) ) - ``` ```{python} @@ -111,8 +110,10 @@ offset_ref_logit_i_first_obs_rv = DistributionalVariable( offset_ref_initial_exp_growth_rate_prior_mean = stan_data["offset_ref_initial_exp_growth_rate_prior_mean"] offset_ref_initial_exp_growth_rate_prior_sd = stan_data["offset_ref_initial_exp_growth_rate_prior_sd"] offset_ref_initial_exp_growth_rate_rv = DistributionalVariable( - "offset_ref_initial_exp_growth_rate", 
dist.Normal( - offset_ref_initial_exp_growth_rate_prior_mean, offset_ref_initial_exp_growth_rate_prior_sd + "offset_ref_initial_exp_growth_rate", dist.TruncatedNormal( + offset_ref_initial_exp_growth_rate_prior_mean, offset_ref_initial_exp_growth_rate_prior_sd, + low=-0.01, + high=0.01, ) ) ``` @@ -137,7 +138,7 @@ viral_peak_rv = DistributionalVariable( ) dur_shed_after_peak_rv = DistributionalVariable( - "dur_shed_after_peak", dist.TruncatedNormal(dur_shed_mean, dur_shed_sd, low=0) + "dur_shed_after_peak", dist.TruncatedNormal(dur_shed_mean-t_peak_mean, jnp.sqrt(dur_shed_sd**2+t_peak_sd**2), low=0) ) max_shed_interval = dur_shed_mean + 3*dur_shed_sd ``` @@ -205,11 +206,10 @@ mean_initial_exp_growth_rate_rv = DistributionalVariable( "mean_initial_exp_growth_rate", dist.TruncatedNormal( loc=mean_initial_exp_growth_rate_prior_mean, scale=mean_initial_exp_growth_rate_prior_sd, - low=-0.001, - high=0.001, + low=-0.01, + high=0.01, ) ) - ``` ```{python} @@ -259,7 +259,6 @@ phi_rv = TransformedVariable( ), transforms=transforms.PowerTransform(-2), ) - ``` ```{python} @@ -272,7 +271,7 @@ log10_g_rv = DistributionalVariable( mode_sigma_ww_site_prior_mode = stan_data["mode_sigma_ww_site_prior_mode"] mode_sigma_ww_site_prior_sd = stan_data["mode_sigma_ww_site_prior_sd"] mode_sigma_ww_site_rv = DistributionalVariable( - "mode_sigma_ww_site", dist.Normal(mode_sigma_ww_site_prior_mode,mode_sigma_ww_site_prior_sd) + "mode_sigma_ww_site", dist.TruncatedNormal(mode_sigma_ww_site_prior_mode,mode_sigma_ww_site_prior_sd,low=0) ) sd_log_sigma_ww_site_prior_mode = stan_data["sd_log_sigma_ww_site_prior_mode"] @@ -307,6 +306,9 @@ my_model = ww_site_level_dynamics_model( sigma_i_first_obs_rv, sigma_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, + offset_ref_logit_i_first_obs_rv, + offset_ref_initial_exp_growth_rate_rv, + offset_ref_log_r_t_rv, generation_interval_pmf_rv, infection_feedback_strength_rv, infection_feedback_pmf_rv, @@ -322,17 +324,21 @@ my_model = 
ww_site_level_dynamics_model( phi_rv, ww_ml_produced_per_day, pop_fraction, + ww_uncensored, + ww_censored, ww_sampled_lab_sites, ww_sampled_subpops, ww_sampled_times, ww_log_lod, - lab_site_to_subpop_map + lab_site_to_subpop_map, + hosp_times ) ``` ```{python} -with numpyro.handlers.seed(rng_seed=242): - test_model_sample = my_model.sample(n_datapoints=50) +# with numpyro.handlers.seed(rng_seed=242): +# test_model_sample = my_model.sample(n_datapoints=50) +# test_model_sample ``` ```{python} diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 7ae22796..2b93713f 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -43,6 +43,9 @@ def __init__( sigma_i_first_obs_rv, sigma_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, + offset_ref_logit_i_first_obs_rv, + offset_ref_initial_exp_growth_rate_rv, + offset_ref_log_r_t_rv, generation_interval_pmf_rv, infection_feedback_strength_rv, infection_feedback_pmf_rv, @@ -58,11 +61,14 @@ def __init__( phi_rv, ww_ml_produced_per_day, pop_fraction, + ww_uncensored, + ww_censored, ww_sampled_lab_sites, ww_sampled_subpops, ww_sampled_times, ww_log_lod, lab_site_to_subpop_map, + hosp_times, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops @@ -71,12 +77,12 @@ def __init__( self.n_initialization_points = n_initialization_points self.max_shed_interval = max_shed_interval self.i0_t_offset = i0_t_offset - self.log_r_mu_intercept_rv = (log_r_t_first_obs_rv,) + self.log_r_t_first_obs_rv = log_r_t_first_obs_rv self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.t_peak_rv = t_peak_rv self.dur_shed_after_peak_rv = dur_shed_after_peak_rv - self.autoreg_rt_site_rv = autoreg_rt_subpop_rv + self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv @@ -84,6 +90,11 @@ def 
__init__( sigma_initial_exp_growth_rate_rv ) self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv + self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv + self.offset_ref_initial_exp_growth_rate_rv = ( + offset_ref_initial_exp_growth_rate_rv + ) + self.offset_ref_log_r_t_rv = offset_ref_log_r_t_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.p_hosp_mean_rv = p_hosp_mean_rv self.p_hosp_w_sd_rv = p_hosp_w_sd_rv @@ -97,11 +108,14 @@ def __init__( self.phi_rv = phi_rv self.ww_ml_produced_per_day = ww_ml_produced_per_day self.pop_fraction = pop_fraction + self.ww_uncensored = ww_uncensored + self.ww_censored = ww_censored self.ww_sampled_lab_sites = ww_sampled_lab_sites - self.ww_sampled_sites = ww_sampled_subpops + self.ww_sampled_subpops = ww_sampled_subpops self.ww_sampled_times = ww_sampled_times self.ww_log_lod = ww_log_lod - self.lab_site_to_site_map = lab_site_to_subpop_map + self.lab_site_to_subpop_map = lab_site_to_subpop_map + self.hosp_times = hosp_times self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, @@ -124,22 +138,19 @@ def sample( data_observed_hospital_admissions=None, data_observed_log_conc=None, ): # numpydoc ignore=GL08 - if n_datapoints is None and data_observed_hospital_admissions is None: - raise ValueError( - "Either n_datapoints or data_observed_hosp_admissions " - "must be passed." - ) - elif ( - n_datapoints is not None - and data_observed_hospital_admissions is not None + if ( + data_observed_hospital_admissions is None + and data_observed_log_conc is None ): - raise ValueError( - "Cannot pass both n_datapoints and data_observed_hospital_admissions." 
- ) - elif n_datapoints is None: - n_datapoints = len(data_observed_hospital_admissions) - else: n_datapoints = n_datapoints + else: + n_datapoints = 94 + # int( + # max( + # max(self.ww_sampled_times) + 1, + # len(data_observed_hospital_admissions), + # ) + # ) n_weeks_post_init = n_datapoints // 7 + 1 @@ -189,6 +200,11 @@ def sample( ) ) + offset_ref_log_r_t = self.offset_ref_log_r_t_rv() + log_rtu_ref_subpop_in_week = log_r_t_in_weeks + jnp.where( + self.n_subpops > 1, offset_ref_log_r_t, 0 + ) + if self.n_subpops > 1: sigma_i_first_obs = self.sigma_i_first_obs_rv() i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( @@ -217,6 +233,23 @@ def sample( ), transforms=lambda x: jnp.clip(x, -0.01, 0.01), ) + # initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( + # "initial_exp_growth_rate_non_ref_subpop_raw", + # dist.Normal( + # mean_initial_exp_growth_rate, + # sigma_initial_exp_growth_rate, + # ), + # reparam=LocScaleReparam(0), + # ) + autoreg_rt_subpop = self.autoreg_rt_subpop_rv() + sigma_rt = self.sigma_rt_rv() + rtu_subpop_ar_init_rv = DistributionalVariable( + "rtu_subpop_ar_init", + dist.Normal( + 0, + sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_subpop, 2)), + ), + ) with numpyro.plate("n_subpops", self.n_subpops - 1): initial_exp_growth_rate_non_ref_subpop = ( @@ -225,6 +258,7 @@ def sample( i_first_obs_over_n_non_ref_subpop = ( i_first_obs_over_n_non_ref_subpop_rv() ) + rtu_subpop_ar_init = rtu_subpop_ar_init_rv() i_first_obs_over_n_subpop = jnp.hstack( [ @@ -238,50 +272,15 @@ def sample( initial_exp_growth_rate_non_ref_subpop, ] ) - else: - i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop - initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop - - numpyro.deterministic( - "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop - ) - numpyro.deterministic( - "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop - ) - - log_i0_subpop = ( - jnp.log(i_first_obs_over_n_subpop) - - self.unobs_time * 
initial_exp_growth_rate_subpop - ) - numpyro.deterministic("log_i0_subpop", log_i0_subpop) - - offset_ref_log_r_t = self.offset_ref_log_r_t_rv() - log_rtu_ref_subpop_in_week = log_r_t_in_weeks + jnp.where( - self.n_subpops > 1, offset_ref_log_r_t, 0 - ) - - if self.n_subpops > 1: - autoreg_rt_subpop = self.autoreg_rt_subpop_rv() - sigma_rt = self.sigma_rt_rv() - rtu_site_ar_init_rv = DistributionalVariable( - "rtu_site_ar_init", - dist.Normal( - 0, - sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_subpop, 2)), - ), - ) - with numpyro.plate("n_subpops", self.n_subpops - 1): - rtu_site_ar_init = rtu_site_ar_init_rv() rtu_subpop_ar_proc = ARProcess() rtu_subpop_ar_weekly = rtu_subpop_ar_proc( noise_name="rtu_ar_proc", n=n_weeks_post_init, - init_vals=rtu_site_ar_init[jnp.newaxis], + init_vals=rtu_subpop_ar_init[jnp.newaxis], autoreg=autoreg_rt_subpop[jnp.newaxis], noise_sd=sigma_rt, ) - numpyro.deterministic("rtu_subpop_ar_weekly", rtu_subpop_ar_weekly) log_rtu_non_ref_subpop_in_week = ( rtu_subpop_ar_weekly + log_r_t_in_weeks[:, jnp.newaxis] @@ -294,8 +293,23 @@ def sample( axis=1, ) else: + i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop + initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] + numpyro.deterministic( + "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop + ) + numpyro.deterministic( + "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop + ) + + log_i0_subpop = ( + jnp.log(i_first_obs_over_n_subpop) + - self.unobs_time * initial_exp_growth_rate_subpop + ) + numpyro.deterministic("log_i0_subpop", log_i0_subpop) + rtu_subpop = jnp.squeeze( jnp.repeat( jnp.exp(log_rtu_subpop_in_week), @@ -303,7 +317,6 @@ def sample( axis=0, )[:n_datapoints, :] ) - numpyro.deterministic("rtu_subpop", rtu_subpop) i0_subpop_rv = DeterministicVariable( @@ -361,7 +374,7 @@ def batch_colvolve_fn(m): # expected observed viral genomes/mL at all observed and forecasted times 
model_log_v_ot = ( jnp.log(10) * log10_g - + jnp.log(model_net_i[:n_datapoints, :] + 1e-8) + + jnp.log(model_net_i + 1e-8) - jnp.log(self.ww_ml_produced_per_day) ) numpyro.deterministic("model_log_v_ot", model_log_v_ot) @@ -442,16 +455,12 @@ def batch_colvolve_fn(m): ww_site_mod = ww_site_mod_rv() sigma_ww_site = sigma_ww_site_rv() - # Observations at the site level (genomes/person/day) are: - # get a vector of genomes/person/day on the days WW was measured - # These are the true expected genomes at the site level before observation error - # (which is at the lab-site level) + # expected observations at each site in log scale exp_obs_log_v_true = model_log_v_ot[ - self.ww_sampled_sites, self.ww_sampled_times + self.ww_sampled_times, self.ww_sampled_subpops ] - # LHS log transformed obs genomes per person-day, RHS multiplies the expected observed - # genomes by the site-specific multiplier at that sampling time + # multiplies the expected observed genomes by the site-specific multiplier at that sampling time exp_obs_log_v = ( exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] ) @@ -461,22 +470,34 @@ def batch_colvolve_fn(m): ) observed_hospital_admissions = hospital_admission_obs_rv( - mu=latent_hospital_admissions, + mu=latent_hospital_admissions[self.hosp_times], obs=data_observed_hospital_admissions, ) numpyro.sample( "log_conc_obs", CensoredNormal( - loc=exp_obs_log_v, - scale=sigma_ww_site[self.ww_sampled_lab_sites], - lower_limit=self.ww_log_lod, + loc=exp_obs_log_v[self.ww_uncensored], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_uncensored] + ], + lower_limit=self.ww_log_lod[self.ww_uncensored], ), - obs=data_observed_log_conc, + obs=data_observed_log_conc[self.ww_uncensored], ) - ww_pred = numpyro.sample( - "site_ww_pred", + # numpyro.sample( + # "log_conc_obs", + # CensoredNormal( + # loc=exp_obs_log_v, + # scale=sigma_ww_site[self.ww_sampled_lab_sites], + # lower_limit=self.ww_log_lod, + # ), + # obs=data_observed_log_conc, + # ) + 
+ ww_pred_log = numpyro.sample( + "site_ww_pred_log", dist.Normal( loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, @@ -491,7 +512,7 @@ def batch_colvolve_fn(m): state_log_c = ( jnp.log(10) * log10_g - + jnp.log(state_model_net_i[:n_datapoints] + 1e-8) + + jnp.log(state_model_net_i + 1e-8) - jnp.log(self.ww_ml_produced_per_day) ) numpyro.deterministic("state_log_c", state_log_c) @@ -514,5 +535,5 @@ def batch_colvolve_fn(m): return ( latent_hospital_admissions, observed_hospital_admissions, - ww_pred, + ww_pred_log, ) From 4bd1dc0baa3456b144158e2ce2e3830f7fb0cdcf Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 10 Oct 2024 12:43:52 -0400 Subject: [PATCH 27/50] initialization --- pyrenew_hew/initialization.py | 32 +++++++++++--- pyrenew_hew/site_level_dynamics_model.py | 54 +++++++++--------------- 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/pyrenew_hew/initialization.py b/pyrenew_hew/initialization.py index 640cd487..57a968ac 100644 --- a/pyrenew_hew/initialization.py +++ b/pyrenew_hew/initialization.py @@ -15,6 +15,28 @@ def get_initialization(stan_data, stdev, rng_key): ) init_vals = { + "offset_ref_log_r_t": ( + dist.Normal( + stan_data["offset_ref_log_r_t_prior_mean"], stdev + ).sample(rng_key) + if stan_data["n_subpops"] > 1 + else None + ), + "offset_ref_logit_i_first_obs": ( + dist.Normal( + stan_data["offset_ref_logit_i_first_obs_prior_mean"], stdev + ).sample(rng_key) + if stan_data["n_subpops"] > 1 + else None + ), + "offset_ref_initial_exp_growth_rate": ( + dist.Normal( + stan_data["offset_ref_initial_exp_growth_rate_prior_mean"], + stdev, + ).sample(rng_key) + if stan_data["n_subpops"] > 1 + else None + ), "eta_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), "autoreg_rt": jnp.abs( dist.Normal( @@ -23,12 +45,12 @@ def get_initialization(stan_data, stdev, rng_key): 0.05, ).sample(rng_key) ), - "log_r_mu_intercept": dist.Normal( + "log_r_t_first_obs": dist.Normal( convert_to_logmean_log_sd(1, stdev)[0], 
convert_to_logmean_log_sd(1, stdev)[1], ).sample(rng_key), "sigma_rt": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), - "autoreg_rt_site": jnp.abs(dist.Normal(0.5, 0.05).sample(rng_key)), + "autoreg_rt_subpop": jnp.abs(dist.Normal(0.5, 0.05).sample(rng_key)), "sigma_i_first_obs": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), "sigma_initial_exp_growth_rate": jnp.abs( dist.Normal(0, stdev).sample(rng_key) @@ -60,11 +82,7 @@ def get_initialization(stan_data, stdev, rng_key): stan_data["viral_shedding_pars"][0], stdev * stan_data["viral_shedding_pars"][1], ).sample(rng_key), - "viral_peak": dist.Normal( - stan_data["viral_shedding_pars"][2], - stdev * stan_data["viral_shedding_pars"][3], - ).sample(rng_key), - "dur_shed": dist.Normal( + "dur_shed_after_peak": dist.Normal( stan_data["viral_shedding_pars"][4], stdev * stan_data["viral_shedding_pars"][5], ).sample(rng_key), diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 2b93713f..e21a0f8a 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -144,13 +144,10 @@ def sample( ): n_datapoints = n_datapoints else: - n_datapoints = 94 - # int( - # max( - # max(self.ww_sampled_times) + 1, - # len(data_observed_hospital_admissions), - # ) - # ) + n_datapoints = max( + len(data_observed_log_conc), + len(data_observed_hospital_admissions), + ) n_weeks_post_init = n_datapoints // 7 + 1 @@ -233,14 +230,7 @@ def sample( ), transforms=lambda x: jnp.clip(x, -0.01, 0.01), ) - # initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( - # "initial_exp_growth_rate_non_ref_subpop_raw", - # dist.Normal( - # mean_initial_exp_growth_rate, - # sigma_initial_exp_growth_rate, - # ), - # reparam=LocScaleReparam(0), - # ) + autoreg_rt_subpop = self.autoreg_rt_subpop_rv() sigma_rt = self.sigma_rt_rv() rtu_subpop_ar_init_rv = DistributionalVariable( @@ -474,28 +464,26 @@ def batch_colvolve_fn(m): 
obs=data_observed_hospital_admissions, ) - numpyro.sample( - "log_conc_obs", - CensoredNormal( - loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_uncensored] - ], - lower_limit=self.ww_log_lod[self.ww_uncensored], - ), - obs=data_observed_log_conc[self.ww_uncensored], - ) - # numpyro.sample( # "log_conc_obs", # CensoredNormal( - # loc=exp_obs_log_v, - # scale=sigma_ww_site[self.ww_sampled_lab_sites], - # lower_limit=self.ww_log_lod, + # loc=exp_obs_log_v[self.ww_uncensored], + # scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], + # lower_limit=self.ww_log_lod[self.ww_uncensored], # ), - # obs=data_observed_log_conc, + # obs=data_observed_log_conc[self.ww_uncensored], # ) + numpyro.sample( + "log_conc_obs", + CensoredNormal( + loc=exp_obs_log_v, + scale=sigma_ww_site[self.ww_sampled_lab_sites], + lower_limit=self.ww_log_lod, + ), + obs=data_observed_log_conc, + ) + ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( @@ -517,8 +505,8 @@ def batch_colvolve_fn(m): ) numpyro.deterministic("state_log_c", state_log_c) - exp_state_ww_conc = jnp.exp(state_log_c) - numpyro.deterministic("exp_state_ww_conc", exp_state_ww_conc) + expected_state_ww_conc = jnp.exp(state_log_c) + numpyro.deterministic("expected_state_ww_conc", expected_state_ww_conc) state_rt = ( state_inf_per_capita[-n_datapoints:] From 29919234a43d6d7a6c67c246182abb2c4693fb05 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 10 Oct 2024 17:39:17 -0400 Subject: [PATCH 28/50] replace censorednormal dist --- pyrenew_hew/site_level_dynamics_model.py | 101 +++++++++-------------- 1 file changed, 38 insertions(+), 63 deletions(-) diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index e21a0f8a..a7529caf 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -86,9 +86,7 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = 
i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = ( - sigma_initial_exp_growth_rate_rv - ) + self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( @@ -138,12 +136,10 @@ def sample( data_observed_hospital_admissions=None, data_observed_log_conc=None, ): # numpydoc ignore=GL08 - if ( - data_observed_hospital_admissions is None - and data_observed_log_conc is None - ): + if data_observed_hospital_admissions is None and data_observed_log_conc is None: n_datapoints = n_datapoints else: + # n_datapoints = len(data_observed_hospital_admissions) n_datapoints = max( len(data_observed_log_conc), len(data_observed_hospital_admissions), @@ -190,11 +186,8 @@ def sample( transforms.logit(i_first_obs_over_n) + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) ) - initial_exp_growth_rate_ref_subpop = ( - mean_initial_exp_growth_rate - + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 - ) + initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 ) offset_ref_log_r_t = self.offset_ref_log_r_t_rv() @@ -215,9 +208,7 @@ def sample( ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = ( - self.sigma_initial_exp_growth_rate_rv() - ) + sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() initial_exp_growth_rate_non_ref_subpop_rv = TransformedVariable( "clipped_initial_exp_growth_rate_non_ref_subpop", DistributionalVariable( @@ -287,9 +278,7 @@ def sample( initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] - numpyro.deterministic( - "i_first_obs_over_n_subpop", 
i_first_obs_over_n_subpop - ) + numpyro.deterministic("i_first_obs_over_n_subpop", i_first_obs_over_n_subpop) numpyro.deterministic( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -309,9 +298,7 @@ def sample( ) numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_subpop_rv = DeterministicVariable( - "i0_subpop", jnp.exp(log_i0_subpop) - ) + i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -345,18 +332,16 @@ def sample( r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum( - self.pop_fraction * new_i_subpop, axis=1 - ) + state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_subpop, axis=1) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) def batch_colvolve_fn(m): return jnp.convolve(m, s, mode="valid") - model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - new_i_subpop - )[-n_datapoints:, :] + model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)(new_i_subpop)[ + -n_datapoints:, : + ] numpyro.deterministic("model_net_i", model_net_i) log10_g = self.log10_g_rv() @@ -403,13 +388,11 @@ def batch_colvolve_fn(m): hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = ( - compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] - ) + potential_latent_hospital_admissions = compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] latent_hospital_admissions = ( 
potential_latent_hospital_admissions @@ -417,9 +400,7 @@ def batch_colvolve_fn(m): * hosp_wday_effect * self.state_pop ) - numpyro.deterministic( - "latent_hospital_admissions", latent_hospital_admissions - ) + numpyro.deterministic("latent_hospital_admissions", latent_hospital_admissions) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() @@ -451,9 +432,7 @@ def batch_colvolve_fn(m): ] # multiplies the expected observed genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = ( - exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] - ) + exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv @@ -464,38 +443,36 @@ def batch_colvolve_fn(m): obs=data_observed_hospital_admissions, ) - # numpyro.sample( - # "log_conc_obs", - # CensoredNormal( - # loc=exp_obs_log_v[self.ww_uncensored], - # scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], - # lower_limit=self.ww_log_lod[self.ww_uncensored], - # ), - # obs=data_observed_log_conc[self.ww_uncensored], - # ) - numpyro.sample( "log_conc_obs", - CensoredNormal( - loc=exp_obs_log_v, - scale=sigma_ww_site[self.ww_sampled_lab_sites], - lower_limit=self.ww_log_lod, + dist.Normal( + loc=exp_obs_log_v[self.ww_uncensored], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], + ), + obs=( + data_observed_log_conc[self.ww_uncensored] + if data_observed_log_conc is not None + else None ), - obs=data_observed_log_conc, ) + if self.ww_censored.shape[0] != 0: + log_cdf_values = dist.Normal( + loc=exp_obs_log_v[self.ww_censored], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], + ).log_cdf(self.ww_log_lod[self.ww_censored]) + numpyro.factor("log_prob_censored", log_cdf_values.sum()) ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( - loc=model_log_v_ot[:, 
self.lab_site_to_subpop_map] - + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, scale=sigma_ww_site, ), ) - state_model_net_i = jnp.convolve( - state_inf_per_capita, s, mode="valid" - )[-n_datapoints:] + state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ + -n_datapoints: + ] numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -512,9 +489,7 @@ def batch_colvolve_fn(m): state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack( - (jnp.array([0]), jnp.array(generation_interval_pmf)) - ), + jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), mode="valid", )[-n_datapoints:] ) From ac8737edb3c49f337eb9613a0fec129fbd663735 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 10 Oct 2024 23:26:37 -0400 Subject: [PATCH 29/50] add plots, clean up demo --- ...qmd => site_level_dynamics_model_demo.qmd} | 94 ++++++++++--------- pyrenew_hew/plotting.py | 13 +-- pyrenew_hew/site_level_dynamics_model.py | 14 +-- 3 files changed, 63 insertions(+), 58 deletions(-) rename notebooks/{site_level_ww_model_demo.qmd => site_level_dynamics_model_demo.qmd} (88%) diff --git a/notebooks/site_level_ww_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd similarity index 88% rename from notebooks/site_level_ww_model_demo.qmd rename to notebooks/site_level_dynamics_model_demo.qmd index c0aacae7..3ae6373a 100644 --- a/notebooks/site_level_ww_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -13,9 +13,10 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_hew.site_level_dynamics_model import ww_site_level_dynamics_model from pyrenew_hew.initialization import get_initialization -from numpyro.infer.initialization import init_to_sample, init_to_value import jax - +from numpyro.infer.initialization import init_to_sample, init_to_value 
+import pyrenew_hew.plotting as plotting +import matplotlib.pyplot as plt numpyro.set_host_device_count(4) ``` @@ -66,7 +67,6 @@ lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #ve data_observed_log_conc = jnp.array(stan_data["log_conc"]) data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) - ``` ```{python} @@ -90,7 +90,6 @@ log_r_t_first_obs_rv = DistributionalVariable( ``` ```{python} - offset_ref_log_r_t_prior_mean=stan_data["offset_ref_log_r_t_prior_mean"] offset_ref_log_r_t_prior_sd = stan_data["offset_ref_log_r_t_prior_sd"] offset_ref_log_r_t_rv = DistributionalVariable( @@ -336,57 +335,68 @@ my_model = ww_site_level_dynamics_model( ``` ```{python} -# with numpyro.handlers.seed(rng_seed=242): -# test_model_sample = my_model.sample(n_datapoints=50) -# test_model_sample +n_forecast_days = 35 + +prior_predictive = my_model.prior_predictive( + n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days, + numpyro_predictive_args={"num_samples": 100}, +) ``` ```{python} -# n_forecast_days = 35 +init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0)) +my_model.run( + num_warmup=500, + num_samples=100, + rng_key=jax.random.key(223), + data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_log_conc=data_observed_log_conc, + mcmc_args=dict(num_chains=4), + nuts_args=dict(init_strategy=init_to_value(values=init_vals)) +) +``` -# prior_predictive = my_model.prior_predictive( -# n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days, -# numpyro_predictive_args={"num_samples": 200}, -# ) +```{python} +posterior_predictive = my_model.posterior_predictive( + n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days +) ``` + +We can use initial values set in the stan code by using `init_strategy = init_to_value` and use the following to generate initial values +`init_vals = 
get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0))` + ```{python} -init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0)) +import arviz as az +idata = az.from_numpyro( + my_model.mcmc, + posterior_predictive=posterior_predictive, + prior=prior_predictive, +) ``` ```{python} -from numpyro.infer.initialization import init_to_sample, init_to_value -try: - my_model.run( - num_warmup=100, - num_samples=100, - rng_key=jax.random.key(223), - data_observed_hospital_admissions=data_observed_hospital_admissions, - data_observed_log_conc=data_observed_log_conc, - mcmc_args=dict(num_chains=1), - nuts_args=dict(init_strategy=init_to_value(values=init_vals)) - ) -except RuntimeError as e: - print(f"RuntimeError occurred: {e}") +plotting.plot_predictive(idata, prior=True) ``` ```{python} -try: - with numpyro.handlers.trace() as tr: - my_model.run( - num_warmup=100, - num_samples=100, - rng_key=jax.random.key(223), - data_observed_hospital_admissions=data_observed_hospital_admissions, - data_observed_log_conc=data_observed_log_conc, - mcmc_args=dict(num_chains=1), - nuts_args=dict(init_strategy=init_to_value(values=init_vals)) - ) -except AssertionError as e: - print(f"AssertionError occurred: {e}") +plotting.plot_predictive(idata) +``` + +```{python} +plotting.plot_posterior(idata,'state_rt') ``` ```{python} -# Print trace of the random variables -for site in tr.values(): - print(site['name'], site['value']) +for i in range(n_subpops): + plotting.plot_posterior(idata, 'r_subpop_t', dim_1=i) ``` + +```{python} +for i in range(n_ww_lab_sites): + plotting.plot_posterior(idata, 'site_ww_pred_log', dim_1=i) +``` + +```{python} +diagnostic_stats_summary = az.summary(idata) +``` + diff --git a/pyrenew_hew/plotting.py b/pyrenew_hew/plotting.py index c94d5547..0f3b9449 100644 --- a/pyrenew_hew/plotting.py +++ b/pyrenew_hew/plotting.py @@ -9,9 +9,13 @@ def compute_eti(dataset, eti_prob): return eti_bdry.values.T -def plot_posterior(idata, 
name): +def plot_posterior(idata, name, dim_1=None): x_data = idata.posterior[f"{name}_dim_0"] - y_data = idata.posterior[name] + y_data = ( + idata.posterior[name] + if dim_1 is None + else idata.posterior[name].isel({f"{name}_dim_1": dim_1}) + ) fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( x_data, @@ -45,14 +49,11 @@ def plot_posterior(idata, name): axes.set_title(name, fontsize=10) axes.set_xlabel("Time", fontsize=10) axes.set_ylabel(name, fontsize=10) - return fig def plot_predictive(idata, prior=False): prior_or_post_text = "Prior" if prior else "Posterior" - predictive_obj = ( - idata.prior_predictive if prior else idata.posterior_predictive - ) + predictive_obj = idata.prior_predictive if prior else idata.posterior_predictive x_data = predictive_obj["observed_hospital_admissions_dim_0"] y_data = predictive_obj["observed_hospital_admissions"] diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index a7529caf..71ea1650 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -139,11 +139,7 @@ def sample( if data_observed_hospital_admissions is None and data_observed_log_conc is None: n_datapoints = n_datapoints else: - # n_datapoints = len(data_observed_hospital_admissions) - n_datapoints = max( - len(data_observed_log_conc), - len(data_observed_hospital_admissions), - ) + n_datapoints = len(data_observed_hospital_admissions) n_weeks_post_init = n_datapoints // 7 + 1 @@ -437,9 +433,8 @@ def batch_colvolve_fn(m): hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv ) - observed_hospital_admissions = hospital_admission_obs_rv( - mu=latent_hospital_admissions[self.hosp_times], + mu=latent_hospital_admissions, obs=data_observed_hospital_admissions, ) @@ -462,7 +457,7 @@ def batch_colvolve_fn(m): ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) - 
ww_pred_log = numpyro.sample( + site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, @@ -496,7 +491,6 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_rt", state_rt) return ( - latent_hospital_admissions, observed_hospital_admissions, - ww_pred_log, + site_ww_pred_log, ) From 115479ea0b2b0dadc9e4aceaaef51370b4839df7 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 11 Oct 2024 11:17:33 -0400 Subject: [PATCH 30/50] clean up --- notebooks/site_level_dynamics_model_demo.qmd | 68 ++++++++++++++++++-- pyrenew_hew/site_level_dynamics_model.py | 8 ++- 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index 3ae6373a..55cbdf8b 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -62,7 +62,7 @@ ww_sampled_subpops = jnp.array(stan_data["ww_sampled_subpops"]) -1 # vector of ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) -1 # a list of all of the days on which WW is sampled, mapped to corresponding subpops ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) - 1 hosp_times = jnp.array(stan_data["hosp_times"]) -1 - +max_ww_sampled_days = max(stan_data["ww_sampled_times"]) lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #vector mapping the subpops to lab-site combos data_observed_log_conc = jnp.array(stan_data["log_conc"]) @@ -330,7 +330,8 @@ my_model = ww_site_level_dynamics_model( ww_sampled_times, ww_log_lod, lab_site_to_subpop_map, - hosp_times + hosp_times, + max_ww_sampled_days, ) ``` @@ -338,13 +339,12 @@ my_model = ww_site_level_dynamics_model( n_forecast_days = 35 prior_predictive = my_model.prior_predictive( - n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days, + n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + 
n_forecast_days, numpyro_predictive_args={"num_samples": 100}, ) ``` ```{python} -init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0)) my_model.run( num_warmup=500, num_samples=100, @@ -352,13 +352,13 @@ my_model.run( data_observed_hospital_admissions=data_observed_hospital_admissions, data_observed_log_conc=data_observed_log_conc, mcmc_args=dict(num_chains=4), - nuts_args=dict(init_strategy=init_to_value(values=init_vals)) + nuts_args=dict(init_strategy=init_to_sample) ) ``` ```{python} posterior_predictive = my_model.posterior_predictive( - n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days + n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, ) ``` @@ -374,6 +374,10 @@ idata = az.from_numpyro( ) ``` +```{python} +idata +``` + ```{python} plotting.plot_predictive(idata, prior=True) ``` @@ -386,6 +390,45 @@ plotting.plot_predictive(idata) plotting.plot_posterior(idata,'state_rt') ``` +```{python} +x_data = idata.posterior_predictive['state_rt_dim_0'] +y_data = idata.posterior_predictive['state_rt'] + +def compute_eti(dataset, eti_prob): + eti_bdry = dataset.quantile( + ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") + ) + return eti_bdry.values.T +fig, axes = plt.subplots(figsize=(6, 5)) +az.plot_hdi( + x_data, + hdi_data=compute_eti(y_data, 0.9), + color="C0", + smooth=False, + fill_kwargs={"alpha": 0.3}, + ax=axes, +) + +az.plot_hdi( + x_data, + hdi_data=compute_eti(y_data, 0.5), + color="C0", + smooth=False, + fill_kwargs={"alpha": 0.6}, + ax=axes, +) + +# Add median of the posterior to the figure +median_ts = y_data.median(dim=["chain", "draw"]) + +plt.plot( + x_data, + median_ts, + color="C0", + label="Median", +) +``` + ```{python} for i in range(n_subpops): plotting.plot_posterior(idata, 'r_subpop_t', dim_1=i) @@ -400,3 +443,16 @@ for i in range(n_ww_lab_sites): diagnostic_stats_summary = az.summary(idata) ``` +```{python} 
+max(diagnostic_stats_summary['r_hat']) +``` + +```{python} +min(diagnostic_stats_summary['ess_bulk']) +``` + +```{python} +min(diagnostic_stats_summary['ess_tail']) +``` + + diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 71ea1650..7fd53773 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -69,6 +69,7 @@ def __init__( ww_log_lod, lab_site_to_subpop_map, hosp_times, + max_ww_sampled_days, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops @@ -114,6 +115,7 @@ def __init__( self.ww_log_lod = ww_log_lod self.lab_site_to_subpop_map = lab_site_to_subpop_map self.hosp_times = hosp_times + self.max_ww_sampled_days = max_ww_sampled_days self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, @@ -139,7 +141,9 @@ def sample( if data_observed_hospital_admissions is None and data_observed_log_conc is None: n_datapoints = n_datapoints else: - n_datapoints = len(data_observed_hospital_admissions) + n_datapoints = max( + len(data_observed_hospital_admissions), self.max_ww_sampled_days + ) n_weeks_post_init = n_datapoints // 7 + 1 @@ -434,7 +438,7 @@ def batch_colvolve_fn(m): "observed_hospital_admissions", concentration_rv=self.phi_rv ) observed_hospital_admissions = hospital_admission_obs_rv( - mu=latent_hospital_admissions, + mu=latent_hospital_admissions[self.hosp_times], obs=data_observed_hospital_admissions, ) From 73e433cebdfe6832e77493a09c8f5390e72016fd Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 11 Oct 2024 11:34:08 -0400 Subject: [PATCH 31/50] pre-commit changes --- notebooks/site_level_dynamics_model_demo.qmd | 4 +- pyrenew_hew/plotting.py | 4 +- pyrenew_hew/site_level_dynamics_model.py | 83 +++++++++++++------- pyrenew_hew/utils.py | 26 +++--- 4 files changed, 74 insertions(+), 43 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd 
b/notebooks/site_level_dynamics_model_demo.qmd index 55cbdf8b..769d058f 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -362,7 +362,7 @@ posterior_predictive = my_model.posterior_predictive( ) ``` -We can use initial values set in the stan code by using `init_strategy = init_to_value` and use the following to generate initial values +We can use initial values set in the stan code by using `init_strategy = init_to_value` and use the following to generate initial values `init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0))` ```{python} @@ -454,5 +454,3 @@ min(diagnostic_stats_summary['ess_bulk']) ```{python} min(diagnostic_stats_summary['ess_tail']) ``` - - diff --git a/pyrenew_hew/plotting.py b/pyrenew_hew/plotting.py index 0f3b9449..69cdbf10 100644 --- a/pyrenew_hew/plotting.py +++ b/pyrenew_hew/plotting.py @@ -53,7 +53,9 @@ def plot_posterior(idata, name, dim_1=None): def plot_predictive(idata, prior=False): prior_or_post_text = "Prior" if prior else "Posterior" - predictive_obj = idata.prior_predictive if prior else idata.posterior_predictive + predictive_obj = ( + idata.prior_predictive if prior else idata.posterior_predictive + ) x_data = predictive_obj["observed_hospital_admissions_dim_0"] y_data = predictive_obj["observed_hospital_admissions"] diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 7fd53773..73c4b497 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -8,7 +8,6 @@ from pyrenew.arrayutils import tile_until_n from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable -from pyrenew.distributions import CensoredNormal from pyrenew.latent import ( InfectionInitializationProcess, InfectionsWithFeedback, @@ -87,7 +86,9 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = 
i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( @@ -138,11 +139,15 @@ def sample( data_observed_hospital_admissions=None, data_observed_log_conc=None, ): # numpydoc ignore=GL08 - if data_observed_hospital_admissions is None and data_observed_log_conc is None: + if ( + data_observed_hospital_admissions is None + and data_observed_log_conc is None + ): n_datapoints = n_datapoints else: n_datapoints = max( - len(data_observed_hospital_admissions), self.max_ww_sampled_days + len(data_observed_hospital_admissions), + self.max_ww_sampled_days, ) n_weeks_post_init = n_datapoints // 7 + 1 @@ -186,8 +191,11 @@ def sample( transforms.logit(i_first_obs_over_n) + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) ) - initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + initial_exp_growth_rate_ref_subpop = ( + mean_initial_exp_growth_rate + + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + ) ) offset_ref_log_r_t = self.offset_ref_log_r_t_rv() @@ -208,7 +216,9 @@ def sample( ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() + sigma_initial_exp_growth_rate = ( + self.sigma_initial_exp_growth_rate_rv() + ) initial_exp_growth_rate_non_ref_subpop_rv = TransformedVariable( "clipped_initial_exp_growth_rate_non_ref_subpop", DistributionalVariable( @@ -278,7 +288,9 @@ def sample( initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] - 
numpyro.deterministic("i_first_obs_over_n_subpop", i_first_obs_over_n_subpop) + numpyro.deterministic( + "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop + ) numpyro.deterministic( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -298,7 +310,9 @@ def sample( ) numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) + i0_subpop_rv = DeterministicVariable( + "i0_subpop", jnp.exp(log_i0_subpop) + ) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -332,16 +346,18 @@ def sample( r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_subpop, axis=1) + state_inf_per_capita = jnp.sum( + self.pop_fraction * new_i_subpop, axis=1 + ) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) def batch_colvolve_fn(m): return jnp.convolve(m, s, mode="valid") - model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)(new_i_subpop)[ - -n_datapoints:, : - ] + model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( + new_i_subpop + )[-n_datapoints:, :] numpyro.deterministic("model_net_i", model_net_i) log10_g = self.log10_g_rv() @@ -388,11 +404,13 @@ def batch_colvolve_fn(m): hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] + potential_latent_hospital_admissions = ( + compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] + ) 
latent_hospital_admissions = ( potential_latent_hospital_admissions @@ -400,7 +418,9 @@ def batch_colvolve_fn(m): * hosp_wday_effect * self.state_pop ) - numpyro.deterministic("latent_hospital_admissions", latent_hospital_admissions) + numpyro.deterministic( + "latent_hospital_admissions", latent_hospital_admissions + ) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() @@ -432,7 +452,9 @@ def batch_colvolve_fn(m): ] # multiplies the expected observed genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv @@ -446,7 +468,9 @@ def batch_colvolve_fn(m): "log_conc_obs", dist.Normal( loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_uncensored] + ], ), obs=( data_observed_log_conc[self.ww_uncensored] @@ -457,21 +481,24 @@ def batch_colvolve_fn(m): if self.ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_censored] + ], ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + + ww_site_mod, scale=sigma_ww_site, ), ) - state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ - -n_datapoints: - ] + state_model_net_i = jnp.convolve( + state_inf_per_capita, s, mode="valid" + )[-n_datapoints:] 
numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -488,7 +515,9 @@ def batch_colvolve_fn(m): state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), + jnp.hstack( + (jnp.array([0]), jnp.array(generation_interval_pmf)) + ), mode="valid", )[-n_datapoints:] ) diff --git a/pyrenew_hew/utils.py b/pyrenew_hew/utils.py index 10305930..006601a4 100644 --- a/pyrenew_hew/utils.py +++ b/pyrenew_hew/utils.py @@ -17,18 +17,20 @@ def normed_shedding_cdf( by a given time post infection. """ norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1) - ad_pre = ( - lambda x: t_p - / jnp.log(log_base) - * jnp.exp(jnp.log(log_base) * x / t_p) - - x - ) - ad_post = ( - lambda x: -t_d - / jnp.log(log_base) - * jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d))) - - x - ) + + def ad_pre(x): + return ( + t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x + ) + + def ad_post(x): + return ( + -t_d + / jnp.log(log_base) + * jnp.exp(jnp.log(log_base) * (1 - ((x - t_p) / t_d))) + - x + ) + return ( jnp.where( time < t_p + t_d, From 851d233ad6566f2226da0ed9c1b7f4088db0c868 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 11 Oct 2024 19:05:52 -0400 Subject: [PATCH 32/50] add hosp only fit compatibility --- notebooks/site_level_dynamics_model_demo.qmd | 163 ++++++--- pyrenew_hew/site_level_dynamics_model.py | 327 +++++++++++-------- 2 files changed, 296 insertions(+), 194 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index 769d058f..9a795dbe 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -32,7 +32,7 @@ from pyrenew_hew.utils import * gt_max = stan_data["gt_max"] #lower=1 hosp_delay_max = stan_data["hosp_delay_max"] n_initialization_points = max(gt_max, hosp_delay_max) -i0_t_offset = 0 # check this later +i0_t_offset = 
0 obs_time = stan_data["ot"]# maximum time index for the hospital admissions (max number of days we could have observations) @@ -132,10 +132,6 @@ t_peak_rv = DistributionalVariable( "t_peak", dist.TruncatedNormal(t_peak_mean, t_peak_sd,low=0) ) -viral_peak_rv = DistributionalVariable( - "viral_peak", dist.Normal(viral_peak_mean, viral_peak_sd) -) - dur_shed_after_peak_rv = DistributionalVariable( "dur_shed_after_peak", dist.TruncatedNormal(dur_shed_mean-t_peak_mean, jnp.sqrt(dur_shed_sd**2+t_peak_sd**2), low=0) ) @@ -287,51 +283,52 @@ ww_site_mod_sd_rv = DistributionalVariable( ```{python} my_model = ww_site_level_dynamics_model( - state_pop, - n_subpops, - n_ww_lab_sites, - unobs_time, - n_initialization_points, - max_shed_interval, - i0_t_offset, - log_r_t_first_obs_rv, - autoreg_rt_rv, - eta_sd_rv, - t_peak_rv, - dur_shed_after_peak_rv, - autoreg_rt_subpop_rv, - sigma_rt_rv, - i_first_obs_over_n_rv, - sigma_i_first_obs_rv, - sigma_initial_exp_growth_rate_rv, - mean_initial_exp_growth_rate_rv, - offset_ref_logit_i_first_obs_rv, - offset_ref_initial_exp_growth_rate_rv, - offset_ref_log_r_t_rv, - generation_interval_pmf_rv, - infection_feedback_strength_rv, - infection_feedback_pmf_rv, - p_hosp_mean_rv, - p_hosp_w_sd_rv, - autoreg_p_hosp_rv, - hosp_wday_effect_rv, - inf_to_hosp_rv, - log10_g_rv, - mode_sigma_ww_site_rv, - sd_log_sigma_ww_site_rv, - ww_site_mod_sd_rv, - phi_rv, - ww_ml_produced_per_day, - pop_fraction, - ww_uncensored, - ww_censored, - ww_sampled_lab_sites, - ww_sampled_subpops, - ww_sampled_times, - ww_log_lod, - lab_site_to_subpop_map, - hosp_times, - max_ww_sampled_days, + state_pop, + unobs_time, + n_initialization_points, + i0_t_offset, + log_r_t_first_obs_rv, + autoreg_rt_rv, + eta_sd_rv, + autoreg_rt_subpop_rv, + sigma_rt_rv, + i_first_obs_over_n_rv, + sigma_i_first_obs_rv, + sigma_initial_exp_growth_rate_rv, + mean_initial_exp_growth_rate_rv, + offset_ref_logit_i_first_obs_rv, + offset_ref_initial_exp_growth_rate_rv, + offset_ref_log_r_t_rv, 
+ generation_interval_pmf_rv, + infection_feedback_strength_rv, + infection_feedback_pmf_rv, + p_hosp_mean_rv, + p_hosp_w_sd_rv, + autoreg_p_hosp_rv, + hosp_wday_effect_rv, + inf_to_hosp_rv, + phi_rv, + hosp_times, + pop_fraction, + n_subpops, + t_peak_rv, + dur_shed_after_peak_rv, + n_ww_lab_sites, + max_shed_interval, + log10_g_rv, + mode_sigma_ww_site_rv, + sd_log_sigma_ww_site_rv, + ww_site_mod_sd_rv, + ww_ml_produced_per_day, + ww_uncensored, + ww_censored, + ww_sampled_lab_sites, + ww_sampled_subpops, + ww_sampled_times, + ww_log_lod, + lab_site_to_subpop_map, + max_ww_sampled_days, + include_ww=1 ) ``` @@ -454,3 +451,69 @@ min(diagnostic_stats_summary['ess_bulk']) ```{python} min(diagnostic_stats_summary['ess_tail']) ``` + +We can fit the model using only hospital admissions data +```{python} +my_model_hosp_only_fit = ww_site_level_dynamics_model( + state_pop, + unobs_time, + n_initialization_points, + i0_t_offset, + log_r_t_first_obs_rv, + autoreg_rt_rv, + eta_sd_rv, + autoreg_rt_subpop_rv, + sigma_rt_rv, + i_first_obs_over_n_rv, + sigma_i_first_obs_rv, + sigma_initial_exp_growth_rate_rv, + mean_initial_exp_growth_rate_rv, + offset_ref_logit_i_first_obs_rv, + offset_ref_initial_exp_growth_rate_rv, + offset_ref_log_r_t_rv, + generation_interval_pmf_rv, + infection_feedback_strength_rv, + infection_feedback_pmf_rv, + p_hosp_mean_rv, + p_hosp_w_sd_rv, + autoreg_p_hosp_rv, + hosp_wday_effect_rv, + inf_to_hosp_rv, + phi_rv, + hosp_times, + pop_fraction=1, + n_subpops=1, + include_ww=0, +) +``` + +```{python} +my_model_hosp_only_fit.run( + num_warmup=500, + num_samples=100, + rng_key=jax.random.key(223), + data_observed_hospital_admissions=data_observed_hospital_admissions, + mcmc_args=dict(num_chains=4), + nuts_args=dict(init_strategy=init_to_sample) +) +``` + +```{python} +posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive( + n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days, +) +``` + +```{python} +idata_hosp_only 
= az.from_numpyro( + my_model_hosp_only_fit.mcmc, + posterior_predictive=posterior_predictive_hosp_only, +) +``` +```{python} +plotting.plot_predictive(idata_hosp_only) +``` + +```{python} +plotting.plot_posterior(idata_hosp_only,'r_subpop_t') +``` diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 73c4b497..64d74f40 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -25,17 +25,12 @@ class ww_site_level_dynamics_model(Model): # numpydoc ignore=GL08 def __init__( self, state_pop, - n_subpops, - n_ww_lab_sites, unobs_time, n_initialization_points, - max_shed_interval, i0_t_offset, log_r_t_first_obs_rv, autoreg_rt_rv, eta_sd_rv, - t_peak_rv, - dur_shed_after_peak_rv, autoreg_rt_subpop_rv, sigma_rt_rv, i_first_obs_over_n_rv, @@ -53,22 +48,28 @@ def __init__( autoreg_p_hosp_rv, hosp_wday_effect_rv, inf_to_hosp_rv, - log10_g_rv, - mode_sigma_ww_site_rv, - sd_log_sigma_ww_site_rv, - ww_site_mod_sd_rv, phi_rv, - ww_ml_produced_per_day, - pop_fraction, - ww_uncensored, - ww_censored, - ww_sampled_lab_sites, - ww_sampled_subpops, - ww_sampled_times, - ww_log_lod, - lab_site_to_subpop_map, hosp_times, - max_ww_sampled_days, + pop_fraction, + n_subpops, + t_peak_rv=None, + dur_shed_after_peak_rv=None, + n_ww_lab_sites=None, + max_shed_interval=None, + log10_g_rv=None, + mode_sigma_ww_site_rv=None, + sd_log_sigma_ww_site_rv=None, + ww_site_mod_sd_rv=None, + ww_ml_produced_per_day=None, + ww_uncensored=None, + ww_censored=None, + ww_sampled_lab_sites=None, + ww_sampled_subpops=None, + ww_sampled_times=None, + ww_log_lod=None, + lab_site_to_subpop_map=None, + max_ww_sampled_days=None, + include_ww=1, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops @@ -117,6 +118,7 @@ def __init__( self.lab_site_to_subpop_map = lab_site_to_subpop_map self.hosp_times = hosp_times self.max_ww_sampled_days = max_ww_sampled_days + self.include_ww = include_ww 
self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, @@ -139,16 +141,45 @@ def sample( data_observed_hospital_admissions=None, data_observed_log_conc=None, ): # numpydoc ignore=GL08 - if ( - data_observed_hospital_admissions is None - and data_observed_log_conc is None - ): - n_datapoints = n_datapoints + if n_datapoints is None: + if ( + data_observed_hospital_admissions is None + and data_observed_log_conc is None + ): + raise ValueError( + "Either n_datapoints or data_observed_hosp_admissions " + "must be passed." + ) + elif ( + data_observed_hospital_admissions is None + and data_observed_log_conc is not None + ): + raise ValueError( + "Either n_datapoints or data_observed_hosp_admissions " + "must be passed." + ) + elif ( + data_observed_hospital_admissions is not None + and data_observed_log_conc is None + ): + n_datapoints = len(data_observed_hospital_admissions) + else: + n_datapoints = max( + len(data_observed_hospital_admissions), + self.max_ww_sampled_days, + ) else: - n_datapoints = max( - len(data_observed_hospital_admissions), - self.max_ww_sampled_days, - ) + if ( + data_observed_hospital_admissions is not None + or data_observed_log_conc is not None + ): + raise ValueError( + "Cannot pass both n_datapoints and " + "data_observed_hospital_admissions " + "or data_observed_log_conc" + ) + else: + n_datapoints = n_datapoints n_weeks_post_init = n_datapoints // 7 + 1 @@ -173,12 +204,6 @@ def sample( ) numpyro.deterministic("log_r_t_in_weeks", log_r_t_in_weeks) - t_peak = self.t_peak_rv() - # viral_peak = self.viral_peak_rv() - dur_shed = self.dur_shed_after_peak_rv() - - s = get_vl_trajectory(t_peak, dur_shed, self.max_shed_interval) - i_first_obs_over_n = self.i_first_obs_over_n_rv() offset_ref_logit_i_first_obs = self.offset_ref_logit_i_first_obs_rv() @@ -337,12 +362,17 @@ def sample( gen_int=generation_interval_pmf, ) - new_i_subpop = jnp.concat( - [ - i0, - 
inf_with_feedback_proc_sample.post_initialization_infections, - ] + new_i_subpop = jnp.atleast_2d( + jnp.concat( + [ + i0, + inf_with_feedback_proc_sample.post_initialization_infections, + ] + ) ) + if new_i_subpop.shape[0] == 1: + new_i_subpop = new_i_subpop.T + r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) @@ -351,25 +381,6 @@ def sample( ) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) - # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) - def batch_colvolve_fn(m): - return jnp.convolve(m, s, mode="valid") - - model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - new_i_subpop - )[-n_datapoints:, :] - numpyro.deterministic("model_net_i", model_net_i) - - log10_g = self.log10_g_rv() - - # expected observed viral genomes/mL at all observed and forecasted times - model_log_v_ot = ( - jnp.log(10) * log10_g - + jnp.log(model_net_i + 1e-8) - - jnp.log(self.ww_ml_produced_per_day) - ) - numpyro.deterministic("model_log_v_ot", model_log_v_ot) - # Hospital admission component p_hosp_mean = self.p_hosp_mean_rv() p_hosp_w_sd = self.p_hosp_w_sd_rv() @@ -422,40 +433,6 @@ def batch_colvolve_fn(m): "latent_hospital_admissions", latent_hospital_admissions ) - mode_sigma_ww_site = self.mode_sigma_ww_site_rv() - sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() - ww_site_mod_sd = self.ww_site_mod_sd_rv() - - ww_site_mod_rv = DistributionalVariable( - "ww_site_mod", - dist.Normal(0, ww_site_mod_sd), - reparam=LocScaleReparam(0), - ) # lab-site specific variation - - sigma_ww_site_rv = TransformedVariable( - "sigma_ww_site", - DistributionalVariable( - "log_sigma_ww_site", - dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site), - reparam=LocScaleReparam(0), - ), - transforms=transforms.ExpTransform(), - ) - - with numpyro.plate("n_ww_lab_sites", self.n_ww_lab_sites): - ww_site_mod = ww_site_mod_rv() - sigma_ww_site = 
sigma_ww_site_rv() - - # expected observations at each site in log scale - exp_obs_log_v_true = model_log_v_ot[ - self.ww_sampled_times, self.ww_sampled_subpops - ] - - # multiplies the expected observed genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = ( - exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] - ) - hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv ) @@ -464,66 +441,128 @@ def batch_colvolve_fn(m): obs=data_observed_hospital_admissions, ) - numpyro.sample( - "log_conc_obs", - dist.Normal( - loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_uncensored] - ], - ), - obs=( - data_observed_log_conc[self.ww_uncensored] - if data_observed_log_conc is not None - else None - ), - ) - if self.ww_censored.shape[0] != 0: - log_cdf_values = dist.Normal( - loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_censored] - ], - ).log_cdf(self.ww_log_lod[self.ww_censored]) - numpyro.factor("log_prob_censored", log_cdf_values.sum()) + if self.include_ww: + t_peak = self.t_peak_rv() + dur_shed = self.dur_shed_after_peak_rv() + s = get_vl_trajectory(t_peak, dur_shed, self.max_shed_interval) - site_ww_pred_log = numpyro.sample( - "site_ww_pred_log", - dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_subpop_map] - + ww_site_mod, - scale=sigma_ww_site, - ), - ) + # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) + def batch_colvolve_fn(m): + return jnp.convolve(m, s, mode="valid") - state_model_net_i = jnp.convolve( - state_inf_per_capita, s, mode="valid" - )[-n_datapoints:] - numpyro.deterministic("state_model_net_i", state_model_net_i) + model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( + new_i_subpop + )[-n_datapoints:, :] + numpyro.deterministic("model_net_i", model_net_i) - state_log_c = ( - 
jnp.log(10) * log10_g - + jnp.log(state_model_net_i + 1e-8) - - jnp.log(self.ww_ml_produced_per_day) - ) - numpyro.deterministic("state_log_c", state_log_c) + log10_g = self.log10_g_rv() - expected_state_ww_conc = jnp.exp(state_log_c) - numpyro.deterministic("expected_state_ww_conc", expected_state_ww_conc) + # expected observed viral genomes/mL at all observed and forecasted times + model_log_v_ot = ( + jnp.log(10) * log10_g + + jnp.log(model_net_i + 1e-8) + - jnp.log(self.ww_ml_produced_per_day) + ) + numpyro.deterministic("model_log_v_ot", model_log_v_ot) - state_rt = ( - state_inf_per_capita[-n_datapoints:] - / jnp.convolve( - state_inf_per_capita, - jnp.hstack( - (jnp.array([0]), jnp.array(generation_interval_pmf)) + mode_sigma_ww_site = self.mode_sigma_ww_site_rv() + sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() + ww_site_mod_sd = self.ww_site_mod_sd_rv() + + ww_site_mod_rv = DistributionalVariable( + "ww_site_mod", + dist.Normal(0, ww_site_mod_sd), + reparam=LocScaleReparam(0), + ) # lab-site specific variation + + sigma_ww_site_rv = TransformedVariable( + "sigma_ww_site", + DistributionalVariable( + "log_sigma_ww_site", + dist.Normal( + jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site + ), + reparam=LocScaleReparam(0), ), - mode="valid", + transforms=transforms.ExpTransform(), + ) + + with numpyro.plate("n_ww_lab_sites", self.n_ww_lab_sites): + ww_site_mod = ww_site_mod_rv() + sigma_ww_site = sigma_ww_site_rv() + + # expected observations at each site in log scale + exp_obs_log_v_true = model_log_v_ot[ + self.ww_sampled_times, self.ww_sampled_subpops + ] + + # multiplies the expected observed genomes by the site-specific multiplier at that sampling time + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) + + numpyro.sample( + "log_conc_obs", + dist.Normal( + loc=exp_obs_log_v[self.ww_uncensored], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_uncensored] + ], + ), + obs=( + 
data_observed_log_conc[self.ww_uncensored] + if data_observed_log_conc is not None + else None + ), + ) + if self.ww_censored.shape[0] != 0: + log_cdf_values = dist.Normal( + loc=exp_obs_log_v[self.ww_censored], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_censored] + ], + ).log_cdf(self.ww_log_lod[self.ww_censored]) + numpyro.factor("log_prob_censored", log_cdf_values.sum()) + + site_ww_pred_log = numpyro.sample( + "site_ww_pred_log", + dist.Normal( + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + + ww_site_mod, + scale=sigma_ww_site, + ), + ) + + state_model_net_i = jnp.convolve( + state_inf_per_capita, s, mode="valid" )[-n_datapoints:] - ) - numpyro.deterministic("state_rt", state_rt) + numpyro.deterministic("state_model_net_i", state_model_net_i) + + state_log_c = ( + jnp.log(10) * log10_g + + jnp.log(state_model_net_i + 1e-8) + - jnp.log(self.ww_ml_produced_per_day) + ) + numpyro.deterministic("state_log_c", state_log_c) + + expected_state_ww_conc = jnp.exp(state_log_c) + numpyro.deterministic( + "expected_state_ww_conc", expected_state_ww_conc + ) + + state_rt = ( + state_inf_per_capita[-n_datapoints:] + / jnp.convolve( + state_inf_per_capita, + jnp.hstack( + (jnp.array([0]), jnp.array(generation_interval_pmf)) + ), + mode="valid", + )[-n_datapoints:] + ) + numpyro.deterministic("state_rt", state_rt) return ( observed_hospital_admissions, - site_ww_pred_log, + site_ww_pred_log if self.include_ww else None, ) From 431e578e9bb830d7272f7f0669b6e3da511e342e Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 11 Oct 2024 19:37:55 -0400 Subject: [PATCH 33/50] clean up --- notebooks/site_level_dynamics_model_demo.qmd | 72 ++++++++-------- pyrenew_hew/site_level_dynamics_model.py | 86 ++++++-------------- 2 files changed, 63 insertions(+), 95 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index 9a795dbe..915d41b5 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ 
b/notebooks/site_level_dynamics_model_demo.qmd @@ -1,10 +1,14 @@ --- -jupyter: python3 +title: "Replicating model fitting hospital admissions and waste water data from ww-inference-model" +format: gfm +engine: jupyter --- ```{python} +# | label: setup import json - +import jax +import jax.numpy as jnp import numpyro import numpyro.distributions as dist import numpyro.distributions.transforms as transforms @@ -13,45 +17,40 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_hew.site_level_dynamics_model import ww_site_level_dynamics_model from pyrenew_hew.initialization import get_initialization -import jax from numpyro.infer.initialization import init_to_sample, init_to_value +from pyrenew_hew.utils import convert_to_logmean_log_sd, get_vl_trajectory import pyrenew_hew.plotting as plotting import matplotlib.pyplot as plt + numpyro.set_host_device_count(4) ``` +We will use the data used in the `wwinference` [vignette](https://github.com/CDCgov/ww-inference-model/blob/main/vignettes/wwinference.Rmd) in the [`ww-inference-model` project](https://github.com/CDCgov/ww-inference-model). This data is generated by running `notebooks/wwinference.Rmd`, which replicates the original vignette and saves the relevant data in `notebooks/data/fit/stan_data.json`. 
+ ```{python} +# | label: Import stan-data with open("data/fit/stan_data.json","r") as file: stan_data = json.load(file) -#helper function -from pyrenew_hew.utils import * ``` ```{python} +# | label: assign variables gt_max = stan_data["gt_max"] #lower=1 hosp_delay_max = stan_data["hosp_delay_max"] n_initialization_points = max(gt_max, hosp_delay_max) i0_t_offset = 0 - -obs_time = stan_data["ot"]# maximum time index for the hospital admissions (max number of days we could have observations) -obs_ww_time = stan_data["owt"] #number of days of observed WW (should be roughly ot/7) -obs_hosp_time = stan_data["oht"] # number of days that we have hospital admissions observations - -horizon_time = stan_data["ht"] #horizon time (nowcast + forecast time) -unobs_time = stan_data["uot"] #unobserved time before we observe hospital admissions/ WW - -n_subpops = stan_data["n_subpops"] #number of modeled subpops +unobs_time = stan_data["uot"] #unobserved time +n_subpops = stan_data["n_subpops"] state_pop = stan_data["state_pop"] subpop_size = stan_data["subpop_size"] norm_pop = stan_data["norm_pop"] pop_fraction = jnp.array(subpop_size)/norm_pop -#mL of ww produced per person per day ww_ml_produced_per_day = stan_data["mwpd"] n_ww_lab_sites = stan_data["n_ww_lab_sites"] -ww_log_lod =jnp.array(stan_data["ww_log_lod"]) # The limit of detection in that site at that time point +ww_log_lod =jnp.array(stan_data["ww_log_lod"]) n_censored = stan_data["n_censored"] n_uncensored = stan_data["n_uncensored"] @@ -59,7 +58,7 @@ ww_censored = jnp.array(stan_data["ww_censored"])-1 #times that the WW data is ww_uncensored = jnp.array(stan_data["ww_uncensored"])-1 #time that WW data is above LOD ww_sampled_subpops = jnp.array(stan_data["ww_sampled_subpops"]) -1 # vector of unique sites in order of the sampled times -ww_sampled_times = jnp.array(stan_data["ww_sampled_times"]) -1 # a list of all of the days on which WW is sampled, mapped to corresponding subpops +ww_sampled_times = 
jnp.array(stan_data["ww_sampled_times"]) -1 # a list of all of the days on which WW is sampled ww_sampled_lab_sites = jnp.array(stan_data["ww_sampled_lab_sites"]) - 1 hosp_times = jnp.array(stan_data["hosp_times"]) -1 max_ww_sampled_days = max(stan_data["ww_sampled_times"]) @@ -70,8 +69,7 @@ data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) ``` ```{python} -# State-leve R(t) AR + RW implementation: - +# | label: set priors eta_sd_sd = stan_data["eta_sd_sd"] eta_sd_rv = DistributionalVariable("eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)) @@ -83,13 +81,10 @@ r_prior_mean = stan_data["r_prior_mean"] r_prior_sd = stan_data["r_prior_sd"] r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) -# log of state level mean R(t) in weeks log_r_t_first_obs_rv = DistributionalVariable( "log_r_t_first_obs", dist.Normal(r_logmean, r_logsd) ) -``` -```{python} offset_ref_log_r_t_prior_mean=stan_data["offset_ref_log_r_t_prior_mean"] offset_ref_log_r_t_prior_sd = stan_data["offset_ref_log_r_t_prior_sd"] offset_ref_log_r_t_rv = DistributionalVariable( @@ -118,13 +113,11 @@ offset_ref_initial_exp_growth_rate_rv = DistributionalVariable( ``` ```{python} -# viral shedding parameters +# | label: viral shedding parameters viral_shedding_pars = stan_data["viral_shedding_pars"] t_peak_mean = viral_shedding_pars[0] t_peak_sd = viral_shedding_pars[1] -viral_peak_mean = viral_shedding_pars[2] -viral_peak_sd = viral_shedding_pars[3] dur_shed_mean = viral_shedding_pars[4] dur_shed_sd = viral_shedding_pars[5] @@ -139,7 +132,6 @@ max_shed_interval = dur_shed_mean + 3*dur_shed_sd ``` ```{python} -# Infection and site-level dynamics infection_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"] infection_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"] infection_feedback_strength_rv = TransformedVariable( @@ -156,7 +148,6 @@ infection_feedback_pmf_rv = DeterministicPMF( "infection_feedback_pmf", jnp.array(infection_feedback_pmf) ) -# 
generation interval distribution generation_interval = stan_data["generation_interval"] generation_interval_pmf_rv = DeterministicPMF( "generation_interval_pmf", jnp.array(generation_interval) @@ -282,6 +273,7 @@ ww_site_mod_sd_rv = DistributionalVariable( ``` ```{python} +# | label: create model my_model = ww_site_level_dynamics_model( state_pop, unobs_time, @@ -332,7 +324,9 @@ my_model = ww_site_level_dynamics_model( ) ``` +Check that we can simulate from the prior predictive ```{python} +# | label: prior n_forecast_days = 35 prior_predictive = my_model.prior_predictive( @@ -341,7 +335,9 @@ prior_predictive = my_model.prior_predictive( ) ``` +Fit the model to observed hospital admissions and wastewater data ```{python} +# | label: model fit my_model.run( num_warmup=500, num_samples=100, @@ -353,7 +349,9 @@ my_model.run( ) ``` +Simulate the posterior predictive distribution ```{python} +# | label: posterior predictive posterior_predictive = my_model.posterior_predictive( n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, ) @@ -363,6 +361,7 @@ We can use initial values set in the stan code by using `init_strategy = init_to `init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0))` ```{python} +# | label: create inference data object import arviz as az idata = az.from_numpyro( my_model.mcmc, @@ -376,18 +375,22 @@ idata ``` ```{python} +# | label: plot prior preditive plotting.plot_predictive(idata, prior=True) ``` ```{python} +# | label: plot posterior preditive plotting.plot_predictive(idata) ``` ```{python} +# | label: plot posterior rt plotting.plot_posterior(idata,'state_rt') ``` ```{python} +# | label: plot posterior preditive rt x_data = idata.posterior_predictive['state_rt_dim_0'] y_data = idata.posterior_predictive['state_rt'] @@ -427,16 +430,19 @@ plt.plot( ``` ```{python} +# | label: plot posterior rt by subpop for i in range(n_subpops): plotting.plot_posterior(idata, 'r_subpop_t', 
dim_1=i) ``` ```{python} +# | label: plot posterior ww conconcentration for i in range(n_ww_lab_sites): plotting.plot_posterior(idata, 'site_ww_pred_log', dim_1=i) ``` ```{python} +# | label: diagnostic stats diagnostic_stats_summary = az.summary(idata) ``` @@ -444,16 +450,9 @@ diagnostic_stats_summary = az.summary(idata) max(diagnostic_stats_summary['r_hat']) ``` +We fit the model using only hospital admissions data ```{python} -min(diagnostic_stats_summary['ess_bulk']) -``` - -```{python} -min(diagnostic_stats_summary['ess_tail']) -``` - -We can fit the model using only hospital admissions data -```{python} +# | label: create hosp only model my_model_hosp_only_fit = ww_site_level_dynamics_model( state_pop, unobs_time, @@ -488,6 +487,7 @@ my_model_hosp_only_fit = ww_site_level_dynamics_model( ``` ```{python} +# | label: hosp only fit my_model_hosp_only_fit.run( num_warmup=500, num_samples=100, diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 64d74f40..d671ece2 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -69,7 +69,7 @@ def __init__( ww_log_lod=None, lab_site_to_subpop_map=None, max_ww_sampled_days=None, - include_ww=1, + include_ww=0, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops @@ -87,9 +87,7 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = ( - sigma_initial_exp_growth_rate_rv - ) + self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( @@ -216,11 +214,8 @@ def sample( transforms.logit(i_first_obs_over_n) + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) ) - 
initial_exp_growth_rate_ref_subpop = ( - mean_initial_exp_growth_rate - + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 - ) + initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 ) offset_ref_log_r_t = self.offset_ref_log_r_t_rv() @@ -241,9 +236,7 @@ def sample( ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = ( - self.sigma_initial_exp_growth_rate_rv() - ) + sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() initial_exp_growth_rate_non_ref_subpop_rv = TransformedVariable( "clipped_initial_exp_growth_rate_non_ref_subpop", DistributionalVariable( @@ -313,9 +306,7 @@ def sample( initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] - numpyro.deterministic( - "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop - ) + numpyro.deterministic("i_first_obs_over_n_subpop", i_first_obs_over_n_subpop) numpyro.deterministic( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -335,9 +326,7 @@ def sample( ) numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_subpop_rv = DeterministicVariable( - "i0_subpop", jnp.exp(log_i0_subpop) - ) + i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -376,9 +365,7 @@ def sample( r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum( - self.pop_fraction * new_i_subpop, axis=1 - ) + state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_subpop, axis=1) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # Hospital admission component @@ -415,13 +402,11 @@ def sample( hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - 
potential_latent_hospital_admissions = ( - compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] - ) + potential_latent_hospital_admissions = compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] latent_hospital_admissions = ( potential_latent_hospital_admissions @@ -429,9 +414,7 @@ def sample( * hosp_wday_effect * self.state_pop ) - numpyro.deterministic( - "latent_hospital_admissions", latent_hospital_admissions - ) + numpyro.deterministic("latent_hospital_admissions", latent_hospital_admissions) hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv @@ -441,12 +424,12 @@ def sample( obs=data_observed_hospital_admissions, ) + # wastewater component if self.include_ww: t_peak = self.t_peak_rv() dur_shed = self.dur_shed_after_peak_rv() s = get_vl_trajectory(t_peak, dur_shed, self.max_shed_interval) - # number of net infected individuals shedding on each day (sum of individuals in diff stages of infection) def batch_colvolve_fn(m): return jnp.convolve(m, s, mode="valid") @@ -456,13 +439,11 @@ def batch_colvolve_fn(m): numpyro.deterministic("model_net_i", model_net_i) log10_g = self.log10_g_rv() - - # expected observed viral genomes/mL at all observed and forecasted times model_log_v_ot = ( jnp.log(10) * log10_g + jnp.log(model_net_i + 1e-8) - jnp.log(self.ww_ml_produced_per_day) - ) + ) # expected observed viral genomes/mL at all observed and forecasted times numpyro.deterministic("model_log_v_ot", model_log_v_ot) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() @@ -479,9 +460,7 @@ def batch_colvolve_fn(m): "sigma_ww_site", DistributionalVariable( "log_sigma_ww_site", - dist.Normal( - jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site - ), + 
dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site), reparam=LocScaleReparam(0), ), transforms=transforms.ExpTransform(), @@ -496,18 +475,14 @@ def batch_colvolve_fn(m): self.ww_sampled_times, self.ww_sampled_subpops ] - # multiplies the expected observed genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = ( - exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] - ) + # multiply the expected observed genomes by the site-specific multiplier at that sampling time + exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] numpyro.sample( "log_conc_obs", dist.Normal( loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_uncensored] - ], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], ), obs=( data_observed_log_conc[self.ww_uncensored] @@ -518,24 +493,21 @@ def batch_colvolve_fn(m): if self.ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_censored] - ], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_subpop_map] - + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, scale=sigma_ww_site, ), ) - state_model_net_i = jnp.convolve( - state_inf_per_capita, s, mode="valid" - )[-n_datapoints:] + state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ + -n_datapoints: + ] numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -546,17 +518,13 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_log_c", state_log_c) expected_state_ww_conc = jnp.exp(state_log_c) - numpyro.deterministic( - "expected_state_ww_conc", expected_state_ww_conc - ) + 
numpyro.deterministic("expected_state_ww_conc", expected_state_ww_conc) state_rt = ( state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack( - (jnp.array([0]), jnp.array(generation_interval_pmf)) - ), + jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), mode="valid", )[-n_datapoints:] ) From 33d3f650a29b3284107089f561d0128b13de01e4 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 11 Oct 2024 19:39:26 -0400 Subject: [PATCH 34/50] pre commit changhes --- notebooks/site_level_dynamics_model_demo.qmd | 6 +- pyrenew_hew/site_level_dynamics_model.py | 76 ++++++++++++++------ 2 files changed, 56 insertions(+), 26 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index 915d41b5..515c8667 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -50,7 +50,7 @@ pop_fraction = jnp.array(subpop_size)/norm_pop ww_ml_produced_per_day = stan_data["mwpd"] n_ww_lab_sites = stan_data["n_ww_lab_sites"] -ww_log_lod =jnp.array(stan_data["ww_log_lod"]) +ww_log_lod =jnp.array(stan_data["ww_log_lod"]) n_censored = stan_data["n_censored"] n_uncensored = stan_data["n_uncensored"] @@ -324,7 +324,7 @@ my_model = ww_site_level_dynamics_model( ) ``` -Check that we can simulate from the prior predictive +Check that we can simulate from the prior predictive ```{python} # | label: prior n_forecast_days = 35 @@ -361,7 +361,7 @@ We can use initial values set in the stan code by using `init_strategy = init_to `init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0))` ```{python} -# | label: create inference data object +# | label: create inference data object import arviz as az idata = az.from_numpyro( my_model.mcmc, diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index d671ece2..2c295879 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ 
b/pyrenew_hew/site_level_dynamics_model.py @@ -87,7 +87,9 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( @@ -214,8 +216,11 @@ def sample( transforms.logit(i_first_obs_over_n) + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) ) - initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + initial_exp_growth_rate_ref_subpop = ( + mean_initial_exp_growth_rate + + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + ) ) offset_ref_log_r_t = self.offset_ref_log_r_t_rv() @@ -236,7 +241,9 @@ def sample( ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() + sigma_initial_exp_growth_rate = ( + self.sigma_initial_exp_growth_rate_rv() + ) initial_exp_growth_rate_non_ref_subpop_rv = TransformedVariable( "clipped_initial_exp_growth_rate_non_ref_subpop", DistributionalVariable( @@ -306,7 +313,9 @@ def sample( initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] - numpyro.deterministic("i_first_obs_over_n_subpop", i_first_obs_over_n_subpop) + numpyro.deterministic( + "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop + ) numpyro.deterministic( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -326,7 +335,9 @@ def sample( ) numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) + 
i0_subpop_rv = DeterministicVariable( + "i0_subpop", jnp.exp(log_i0_subpop) + ) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -365,7 +376,9 @@ def sample( r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_subpop, axis=1) + state_inf_per_capita = jnp.sum( + self.pop_fraction * new_i_subpop, axis=1 + ) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # Hospital admission component @@ -402,11 +415,13 @@ def sample( hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] + potential_latent_hospital_admissions = ( + compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] + ) latent_hospital_admissions = ( potential_latent_hospital_admissions @@ -414,7 +429,9 @@ def sample( * hosp_wday_effect * self.state_pop ) - numpyro.deterministic("latent_hospital_admissions", latent_hospital_admissions) + numpyro.deterministic( + "latent_hospital_admissions", latent_hospital_admissions + ) hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv @@ -460,7 +477,9 @@ def batch_colvolve_fn(m): "sigma_ww_site", DistributionalVariable( "log_sigma_ww_site", - dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site), + dist.Normal( + jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site + ), reparam=LocScaleReparam(0), ), transforms=transforms.ExpTransform(), @@ -476,13 +495,17 @@ def batch_colvolve_fn(m): ] # multiply the expected observed genomes by the site-specific 
multiplier at that sampling time - exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) numpyro.sample( "log_conc_obs", dist.Normal( loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_uncensored] + ], ), obs=( data_observed_log_conc[self.ww_uncensored] @@ -493,21 +516,24 @@ def batch_colvolve_fn(m): if self.ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_censored] + ], ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + + ww_site_mod, scale=sigma_ww_site, ), ) - state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ - -n_datapoints: - ] + state_model_net_i = jnp.convolve( + state_inf_per_capita, s, mode="valid" + )[-n_datapoints:] numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -518,13 +544,17 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_log_c", state_log_c) expected_state_ww_conc = jnp.exp(state_log_c) - numpyro.deterministic("expected_state_ww_conc", expected_state_ww_conc) + numpyro.deterministic( + "expected_state_ww_conc", expected_state_ww_conc + ) state_rt = ( state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), + jnp.hstack( + (jnp.array([0]), jnp.array(generation_interval_pmf)) + ), mode="valid", )[-n_datapoints:] ) From 259e895ce55a0248dca4ffd590dc418dfb172dc9 Mon Sep 17 
00:00:00 2001 From: sbidari Date: Tue, 15 Oct 2024 16:32:42 -0400 Subject: [PATCH 35/50] fix forecasting --- notebooks/site_level_dynamics_model_demo.qmd | 57 ++++---------------- pyrenew_hew/plotting.py | 20 ++++--- pyrenew_hew/site_level_dynamics_model.py | 38 ++++++++----- 3 files changed, 45 insertions(+), 70 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index 515c8667..fb968d1a 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -320,7 +320,7 @@ my_model = ww_site_level_dynamics_model( ww_log_lod, lab_site_to_subpop_map, max_ww_sampled_days, - include_ww=1 + include_ww=True ) ``` @@ -353,7 +353,7 @@ Simulate the posterior predictive distribution ```{python} # | label: posterior predictive posterior_predictive = my_model.posterior_predictive( - n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, + n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, forecast=True ) ``` @@ -370,18 +370,14 @@ idata = az.from_numpyro( ) ``` -```{python} -idata -``` - ```{python} # | label: plot prior preditive -plotting.plot_predictive(idata, prior=True) +plotting.plot_predictive(idata, "observed_hospital_admissions", prior=True) ``` ```{python} # | label: plot posterior preditive -plotting.plot_predictive(idata) +plotting.plot_predictive(idata,"pred_hospital_admissions") ``` ```{python} @@ -390,43 +386,8 @@ plotting.plot_posterior(idata,'state_rt') ``` ```{python} -# | label: plot posterior preditive rt -x_data = idata.posterior_predictive['state_rt_dim_0'] -y_data = idata.posterior_predictive['state_rt'] - -def compute_eti(dataset, eti_prob): - eti_bdry = dataset.quantile( - ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") - ) - return eti_bdry.values.T -fig, axes = plt.subplots(figsize=(6, 5)) -az.plot_hdi( - x_data, - 
hdi_data=compute_eti(y_data, 0.9), - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.3}, - ax=axes, -) - -az.plot_hdi( - x_data, - hdi_data=compute_eti(y_data, 0.5), - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.6}, - ax=axes, -) - -# Add median of the posterior to the figure -median_ts = y_data.median(dim=["chain", "draw"]) - -plt.plot( - x_data, - median_ts, - color="C0", - label="Median", -) +# | label: plot posterior predictive rt +plotting.plot_predictive(idata,'state_rt') ``` ```{python} @@ -482,7 +443,7 @@ my_model_hosp_only_fit = ww_site_level_dynamics_model( hosp_times, pop_fraction=1, n_subpops=1, - include_ww=0, + include_ww=False, ) ``` @@ -500,7 +461,7 @@ my_model_hosp_only_fit.run( ```{python} posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive( - n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days, + n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,forecast=True ) ``` @@ -511,7 +472,7 @@ idata_hosp_only = az.from_numpyro( ) ``` ```{python} -plotting.plot_predictive(idata_hosp_only) +plotting.plot_predictive(idata_hosp_only,name='pred_hospital_admissions') ``` ```{python} diff --git a/pyrenew_hew/plotting.py b/pyrenew_hew/plotting.py index 69cdbf10..21304825 100644 --- a/pyrenew_hew/plotting.py +++ b/pyrenew_hew/plotting.py @@ -51,14 +51,14 @@ def plot_posterior(idata, name, dim_1=None): axes.set_ylabel(name, fontsize=10) -def plot_predictive(idata, prior=False): +def plot_predictive(idata, name="observed_hospital_admissions", prior=False): prior_or_post_text = "Prior" if prior else "Posterior" predictive_obj = ( idata.prior_predictive if prior else idata.posterior_predictive ) - x_data = predictive_obj["observed_hospital_admissions_dim_0"] - y_data = predictive_obj["observed_hospital_admissions"] + x_data = predictive_obj[f"{name}_dim_0"] + y_data = predictive_obj[name] fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( @@ -88,11 +88,15 @@ def plot_predictive(idata, 
prior=False): color="C0", label="Median", ) - plt.scatter( - idata.observed_data["observed_hospital_admissions_dim_0"], - idata.observed_data["observed_hospital_admissions"], - color="black", - ) + if ( + name == "observed_hospital_admissions" + or name == "pred_hospital_admissions" + ): + plt.scatter( + idata.observed_data["observed_hospital_admissions_dim_0"], + idata.observed_data["observed_hospital_admissions"], + color="black", + ) axes.legend() axes.set_title(f"{prior_or_post_text} Predictive Admissions", fontsize=10) axes.set_xlabel("Time", fontsize=10) diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 2c295879..f7fba15e 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -14,7 +14,6 @@ InitializeInfectionsExponentialGrowth, ) from pyrenew.metaclass import Model -from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -69,7 +68,7 @@ def __init__( ww_log_lod=None, lab_site_to_subpop_map=None, max_ww_sampled_days=None, - include_ww=0, + include_ww=False, ): # numpydoc ignore=GL08 self.state_pop = state_pop self.n_subpops = n_subpops @@ -129,7 +128,6 @@ def __init__( fundamental_process=ARProcess(), differencing_order=1, ) - return None def validate(self): # numpydoc ignore=GL08 @@ -140,6 +138,7 @@ def sample( n_datapoints=None, data_observed_hospital_admissions=None, data_observed_log_conc=None, + forecast=False, ): # numpydoc ignore=GL08 if n_datapoints is None: if ( @@ -228,7 +227,11 @@ def sample( self.n_subpops > 1, offset_ref_log_r_t, 0 ) - if self.n_subpops > 1: + if self.n_subpops == 1: + i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop + initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop + log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] + else: sigma_i_first_obs 
= self.sigma_i_first_obs_rv() i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( "i_first_obs_over_n_non_ref_subpop", @@ -308,10 +311,6 @@ def sample( ], axis=1, ) - else: - i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop - initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop - log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] numpyro.deterministic( "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop @@ -433,14 +432,24 @@ def sample( "latent_hospital_admissions", latent_hospital_admissions ) - hospital_admission_obs_rv = NegativeBinomialObservation( - "observed_hospital_admissions", concentration_rv=self.phi_rv - ) - observed_hospital_admissions = hospital_admission_obs_rv( - mu=latent_hospital_admissions[self.hosp_times], + phi = self.phi_rv() + numpyro.sample( + "observed_hospital_admissions", + dist.NegativeBinomial2( + mean=latent_hospital_admissions[self.hosp_times], + concentration=phi, + ), obs=data_observed_hospital_admissions, ) + if forecast: + pred_hospital_admissions = numpyro.sample( + "pred_hospital_admissions", + dist.NegativeBinomial2( + mean=latent_hospital_admissions, concentration=phi + ).mask(False), + ) + # wastewater component if self.include_ww: t_peak = self.t_peak_rv() @@ -522,6 +531,7 @@ def batch_colvolve_fn(m): ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) + # if forecast: site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( @@ -561,6 +571,6 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_rt", state_rt) return ( - observed_hospital_admissions, + pred_hospital_admissions if forecast else None, site_ww_pred_log if self.include_ww else None, ) From b0278d4b265d7abb670caf85e8ebb970820ac28d Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 15 Oct 2024 16:54:28 -0400 Subject: [PATCH 36/50] clean up wwinference.Rmd --- notebooks/wwinference.Rmd | 27 +++++---------------------- 1 file changed, 5 insertions(+), 
22 deletions(-) diff --git a/notebooks/wwinference.Rmd b/notebooks/wwinference.Rmd index 4ce1ef5f..ea478bed 100644 --- a/notebooks/wwinference.Rmd +++ b/notebooks/wwinference.Rmd @@ -177,7 +177,7 @@ hosp_data_preprocessed <- wwinference::preprocess_count_data( We'll make some plots of the data just to make sure it looks like what we'd expect: -```{r time-series-fig, out.width='100%'} +```{r wastewater-time-series-fig, out.width='100%'} ggplot(ww_data_preprocessed) + geom_point( aes( @@ -199,7 +199,8 @@ ggplot(ww_data_preprocessed) + ylab("Genome copies/mL") + ggtitle("Lab-site level wastewater concentration") + theme_bw() - +``` +```{r hosp-time-series-fig, out.width='100%'} ggplot(hosp_data_preprocessed) + # Plot the hospital admissions data that we will evaluate against in white geom_point( @@ -375,28 +376,10 @@ observed hospital admissions and wastewater concentrations, as well as the latent variables of interest including the site-level R(t) estimates and the state-level R(t) estimate. -We can generate this directly on the output of `wwinference()` using: -```{r extracting-draws} -draws_df <- get_draws(ww_fit) - -cat( - "Variables in dataframe: ", - sprintf("%s", paste(unique(draws_df$name), collapse = ", ")) -) -``` -Note that by default the `get_draws_df()` function will return a tidy long -dataframe with all of the posterior draws joined to applicable data for each of -the included variables. To examine a particular variable (e.g. `"predicted counts"` for posterior -predicted hospital admissions), filter the data frame based on the `name` column. - -### Using explicit passed arguments rather than S3 methods - -Rather than using S3 methods supplied for `wwinference()`, the elements in the -`wwinference_fit` object can also be used directly to create this dataframe. 
-This is demonstrated below: +We can generate this using the `ww_fit` object as demonstrated below: ```{r extracting-draws-explicit} -draws_explicit <- get_draws( +draws <- get_draws( x = ww_fit$raw_input_data$input_ww_data, count_data = ww_fit$raw_input_data$input_count_data, date_time_spine = ww_fit$raw_input_data$date_time_spine, From cc00bd9aaa4abb1c5561b99fdf1844c72f6cf78d Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 16 Oct 2024 11:08:53 -0400 Subject: [PATCH 37/50] code review suggestions --- notebooks/site_level_dynamics_model_demo.qmd | 3 - pyrenew_hew/initialization.py | 99 -------------------- pyrenew_hew/site_level_dynamics_model.py | 14 +-- 3 files changed, 8 insertions(+), 108 deletions(-) delete mode 100644 pyrenew_hew/initialization.py diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index fb968d1a..bb0650c4 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -357,9 +357,6 @@ posterior_predictive = my_model.posterior_predictive( ) ``` -We can use initial values set in the stan code by using `init_strategy = init_to_value` and use the following to generate initial values -`init_vals = get_initialization(stan_data,stdev=0.01,rng_key=jax.random.PRNGKey(0))` - ```{python} # | label: create inference data object import arviz as az diff --git a/pyrenew_hew/initialization.py b/pyrenew_hew/initialization.py deleted file mode 100644 index 57a968ac..00000000 --- a/pyrenew_hew/initialization.py +++ /dev/null @@ -1,99 +0,0 @@ -import jax -import jax.numpy as jnp -import numpy as np -import numpyro.distributions as dist - -from pyrenew_hew.utils import convert_to_logmean_log_sd - - -def get_initialization(stan_data, stdev, rng_key): - i_first_obs_est = ( - np.mean(stan_data["hosp"][:7]) / stan_data["p_hosp_prior_mean"] - ) - logit_i_frac_est = jax.scipy.special.logit( - i_first_obs_est / stan_data["state_pop"] - ) - - init_vals = { - 
"offset_ref_log_r_t": ( - dist.Normal( - stan_data["offset_ref_log_r_t_prior_mean"], stdev - ).sample(rng_key) - if stan_data["n_subpops"] > 1 - else None - ), - "offset_ref_logit_i_first_obs": ( - dist.Normal( - stan_data["offset_ref_logit_i_first_obs_prior_mean"], stdev - ).sample(rng_key) - if stan_data["n_subpops"] > 1 - else None - ), - "offset_ref_initial_exp_growth_rate": ( - dist.Normal( - stan_data["offset_ref_initial_exp_growth_rate_prior_mean"], - stdev, - ).sample(rng_key) - if stan_data["n_subpops"] > 1 - else None - ), - "eta_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), - "autoreg_rt": jnp.abs( - dist.Normal( - stan_data["autoreg_rt_a"] - / (stan_data["autoreg_rt_a"] + stan_data["autoreg_rt_b"]), - 0.05, - ).sample(rng_key) - ), - "log_r_t_first_obs": dist.Normal( - convert_to_logmean_log_sd(1, stdev)[0], - convert_to_logmean_log_sd(1, stdev)[1], - ).sample(rng_key), - "sigma_rt": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), - "autoreg_rt_subpop": jnp.abs(dist.Normal(0.5, 0.05).sample(rng_key)), - "sigma_i_first_obs": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), - "sigma_initial_exp_growth_rate": jnp.abs( - dist.Normal(0, stdev).sample(rng_key) - ), - "i_first_obs_over_n": jax.nn.sigmoid( - dist.Normal(logit_i_frac_est, 0.05).sample(rng_key) - ), - "mean_initial_exp_growth_rate": dist.Normal(0, stdev).sample(rng_key), - "inv_sqrt_phi": 1 / jnp.sqrt(200) - + dist.Normal(1 / 10000, 1 / 10000).sample(rng_key), - "mode_sigma_ww_site": jnp.abs( - dist.Normal( - stan_data["mode_sigma_ww_site_prior_mode"], - stdev * stan_data["mode_sigma_ww_site_prior_sd"], - ).sample(rng_key) - ), - "sd_log_sigma_ww_site": jnp.abs( - dist.Normal( - stan_data["sd_log_sigma_ww_site_prior_mode"], - stdev * stan_data["sd_log_sigma_ww_site_prior_sd"], - ).sample(rng_key) - ), - "p_hosp_mean": dist.Normal( - jax.scipy.special.logit(stan_data["p_hosp_prior_mean"]), stdev - ).sample(rng_key), - "p_hosp_w_sd": jnp.abs(dist.Normal(0.01, 0.001).sample(rng_key)), - 
"autoreg_p_hosp": jnp.abs(dist.Normal(1 / 100, 0.001).sample(rng_key)), - "t_peak": dist.Normal( - stan_data["viral_shedding_pars"][0], - stdev * stan_data["viral_shedding_pars"][1], - ).sample(rng_key), - "dur_shed_after_peak": dist.Normal( - stan_data["viral_shedding_pars"][4], - stdev * stan_data["viral_shedding_pars"][5], - ).sample(rng_key), - "log10_g": dist.Normal(stan_data["log10_g_prior_mean"], 0.5).sample( - rng_key - ), - "ww_site_mod_sd": jnp.abs(dist.Normal(0, stdev).sample(rng_key)), - "hosp_wday_effect_raw": jax.nn.softmax( - jnp.abs(dist.Normal(1 / 7, stdev).expand([7]).sample(rng_key)) - ), - "inf_feedback_raw": jnp.abs(dist.Normal(500, 20).sample(rng_key)), - } - - return init_vals diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index f7fba15e..e7fc8663 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -140,11 +140,13 @@ def sample( data_observed_log_conc=None, forecast=False, ): # numpydoc ignore=GL08 - if n_datapoints is None: + if ( + n_datapoints is None + ): # calculate model calibration period based on data if ( data_observed_hospital_admissions is None and data_observed_log_conc is None - ): + ): # no data for calibration raise ValueError( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." @@ -152,7 +154,7 @@ def sample( elif ( data_observed_hospital_admissions is None and data_observed_log_conc is not None - ): + ): # does not support fitting to just wastewater data raise ValueError( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." 
@@ -160,9 +162,9 @@ def sample( elif ( data_observed_hospital_admissions is not None and data_observed_log_conc is None - ): + ): # only fit hosp admissions data n_datapoints = len(data_observed_hospital_admissions) - else: + else: # both hosp admisssions and ww data provided n_datapoints = max( len(data_observed_hospital_admissions), self.max_ww_sampled_days, @@ -180,7 +182,7 @@ def sample( else: n_datapoints = n_datapoints - n_weeks_post_init = n_datapoints // 7 + 1 + n_weeks_post_init = -((-n_datapoints) // 7) # n_datapoints // 7 + 1 eta_sd = self.eta_sd_rv() autoreg_rt = self.autoreg_rt_rv() From 595dbb8578de1ff46d46cd1d60571be2235fa314 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 16 Oct 2024 12:18:30 -0400 Subject: [PATCH 38/50] make some model args optional for hosp only fit --- notebooks/site_level_dynamics_model_demo.qmd | 12 ++++-------- notebooks/wwinference.Rmd | 4 ---- pyrenew_hew/site_level_dynamics_model.py | 8 ++++---- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index bb0650c4..dbdfd505 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -282,11 +282,7 @@ my_model = ww_site_level_dynamics_model( log_r_t_first_obs_rv, autoreg_rt_rv, eta_sd_rv, - autoreg_rt_subpop_rv, - sigma_rt_rv, i_first_obs_over_n_rv, - sigma_i_first_obs_rv, - sigma_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, offset_ref_logit_i_first_obs_rv, offset_ref_initial_exp_growth_rate_rv, @@ -303,6 +299,10 @@ my_model = ww_site_level_dynamics_model( hosp_times, pop_fraction, n_subpops, + autoreg_rt_subpop_rv, + sigma_rt_rv, + sigma_i_first_obs_rv, + sigma_initial_exp_growth_rate_rv, t_peak_rv, dur_shed_after_peak_rv, n_ww_lab_sites, @@ -419,11 +419,7 @@ my_model_hosp_only_fit = ww_site_level_dynamics_model( log_r_t_first_obs_rv, autoreg_rt_rv, eta_sd_rv, - autoreg_rt_subpop_rv, - sigma_rt_rv, 
i_first_obs_over_n_rv, - sigma_i_first_obs_rv, - sigma_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, offset_ref_logit_i_first_obs_rv, offset_ref_initial_exp_growth_rate_rv, diff --git a/notebooks/wwinference.Rmd b/notebooks/wwinference.Rmd index ea478bed..7964fb36 100644 --- a/notebooks/wwinference.Rmd +++ b/notebooks/wwinference.Rmd @@ -481,10 +481,6 @@ rely on the admissions only model if there are covergence or known data issues with the wastewater data. ```{r fit-hosp-only, warning=FALSE, message=FALSE} -params$sigma_rt_prior <- 0.0001 # This model contains a single "site," so we -# decrease the prior variance to make the "site" Rt nearly the same as the -# "global" Rt. - fit_hosp_only <- wwinference::wwinference( ww_data = ww_data_to_fit, count_data = hosp_data_preprocessed, diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index e7fc8663..4f60c664 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -30,11 +30,7 @@ def __init__( log_r_t_first_obs_rv, autoreg_rt_rv, eta_sd_rv, - autoreg_rt_subpop_rv, - sigma_rt_rv, i_first_obs_over_n_rv, - sigma_i_first_obs_rv, - sigma_initial_exp_growth_rate_rv, mean_initial_exp_growth_rate_rv, offset_ref_logit_i_first_obs_rv, offset_ref_initial_exp_growth_rate_rv, @@ -51,6 +47,10 @@ def __init__( hosp_times, pop_fraction, n_subpops, + autoreg_rt_subpop_rv=None, + sigma_rt_rv=None, + sigma_i_first_obs_rv=None, + sigma_initial_exp_growth_rate_rv=None, t_peak_rv=None, dur_shed_after_peak_rv=None, n_ww_lab_sites=None, From e2a778588cbc0d907b4c7b1a3e0989da210765b2 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 16 Oct 2024 12:50:46 -0400 Subject: [PATCH 39/50] more cleanup --- notebooks/site_level_dynamics_model_demo.qmd | 8 ++------ pyrenew_hew/site_level_dynamics_model.py | 16 ++++++---------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd 
b/notebooks/site_level_dynamics_model_demo.qmd index dbdfd505..c15f5cb5 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -16,8 +16,6 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_hew.site_level_dynamics_model import ww_site_level_dynamics_model -from pyrenew_hew.initialization import get_initialization -from numpyro.infer.initialization import init_to_sample, init_to_value from pyrenew_hew.utils import convert_to_logmean_log_sd, get_vl_trajectory import pyrenew_hew.plotting as plotting import matplotlib.pyplot as plt @@ -344,8 +342,7 @@ my_model.run( rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, data_observed_log_conc=data_observed_log_conc, - mcmc_args=dict(num_chains=4), - nuts_args=dict(init_strategy=init_to_sample) + mcmc_args=dict(num_chains=4) ) ``` @@ -447,8 +444,7 @@ my_model_hosp_only_fit.run( num_samples=100, rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, - mcmc_args=dict(num_chains=4), - nuts_args=dict(init_strategy=init_to_sample) + mcmc_args=dict(num_chains=4) ) ``` diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index 4f60c664..d1bb19db 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -249,17 +249,13 @@ def sample( sigma_initial_exp_growth_rate = ( self.sigma_initial_exp_growth_rate_rv() ) - initial_exp_growth_rate_non_ref_subpop_rv = TransformedVariable( - "clipped_initial_exp_growth_rate_non_ref_subpop", - DistributionalVariable( - "initial_exp_growth_rate_non_ref_subpop_raw", - dist.Normal( - mean_initial_exp_growth_rate, - sigma_initial_exp_growth_rate, - ), - reparam=LocScaleReparam(0), + initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( + 
"initial_exp_growth_rate_non_ref_subpop_raw", + dist.Normal( + mean_initial_exp_growth_rate, + sigma_initial_exp_growth_rate, ), - transforms=lambda x: jnp.clip(x, -0.01, 0.01), + reparam=LocScaleReparam(0), ) autoreg_rt_subpop = self.autoreg_rt_subpop_rv() From 244baab3b19df39a83bd8aac2043194d071810e8 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 16 Oct 2024 12:53:53 -0400 Subject: [PATCH 40/50] remove extra imports --- notebooks/site_level_dynamics_model_demo.qmd | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index c15f5cb5..d17fbb5e 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -16,9 +16,8 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew_hew.site_level_dynamics_model import ww_site_level_dynamics_model -from pyrenew_hew.utils import convert_to_logmean_log_sd, get_vl_trajectory +from pyrenew_hew.utils import convert_to_logmean_log_sd import pyrenew_hew.plotting as plotting -import matplotlib.pyplot as plt numpyro.set_host_device_count(4) ``` From 26eea3abe216620b27f5d676594c106c17d88a3c Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 16 Oct 2024 13:17:20 -0400 Subject: [PATCH 41/50] update plotting.py for predictive plots --- pyrenew_hew/plotting.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pyrenew_hew/plotting.py b/pyrenew_hew/plotting.py index 21304825..cc4b4452 100644 --- a/pyrenew_hew/plotting.py +++ b/pyrenew_hew/plotting.py @@ -51,14 +51,20 @@ def plot_posterior(idata, name, dim_1=None): axes.set_ylabel(name, fontsize=10) -def plot_predictive(idata, name="observed_hospital_admissions", prior=False): +def plot_predictive( + idata, name="observed_hospital_admissions", dim_1=None, prior=False +): prior_or_post_text = 
"Prior" if prior else "Posterior" predictive_obj = ( idata.prior_predictive if prior else idata.posterior_predictive ) x_data = predictive_obj[f"{name}_dim_0"] - y_data = predictive_obj[name] + y_data = ( + predictive_obj[name] + if dim_1 is None + else predictive_obj[name].isel({f"{name}_dim_1": dim_1}) + ) fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( @@ -98,6 +104,6 @@ def plot_predictive(idata, name="observed_hospital_admissions", prior=False): color="black", ) axes.legend() - axes.set_title(f"{prior_or_post_text} Predictive Admissions", fontsize=10) + axes.set_title(f"{prior_or_post_text} {name}", fontsize=10) axes.set_xlabel("Time", fontsize=10) - axes.set_ylabel("Hospital Admissions", fontsize=10) + axes.set_ylabel(f"{name}", fontsize=10) From 22e8410f762ca8334f91955f6f07c6282da046a8 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 17 Oct 2024 09:57:24 -0400 Subject: [PATCH 42/50] update predictive sampling --- notebooks/site_level_dynamics_model_demo.qmd | 23 ++++++-------- pyrenew_hew/site_level_dynamics_model.py | 32 +++++++++----------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/notebooks/site_level_dynamics_model_demo.qmd index d17fbb5e..9f2fe6f4 100644 --- a/notebooks/site_level_dynamics_model_demo.qmd +++ b/notebooks/site_level_dynamics_model_demo.qmd @@ -336,7 +336,7 @@ Fit the model to observed hospital admissions and wastewater data ```{python} # | label: model fit my_model.run( - num_warmup=500, + num_warmup=100, num_samples=100, rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, @@ -349,7 +349,7 @@ Simulate the posterior predictive distribution ```{python} # | label: posterior predictive posterior_predictive = my_model.posterior_predictive( - n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, forecast=True + n_datapoints= max(len(data_observed_hospital_admissions), 
max_ww_sampled_days) + n_forecast_days, is_predictive=True ) ``` @@ -370,12 +370,7 @@ plotting.plot_predictive(idata, "observed_hospital_admissions", prior=True) ```{python} # | label: plot posterior preditive -plotting.plot_predictive(idata,"pred_hospital_admissions") -``` - -```{python} -# | label: plot posterior rt -plotting.plot_posterior(idata,'state_rt') +plotting.plot_predictive(idata,"observed_hospital_admissions") ``` ```{python} @@ -384,9 +379,9 @@ plotting.plot_predictive(idata,'state_rt') ``` ```{python} -# | label: plot posterior rt by subpop +# | label: plot posterior predictive rt by subpop for i in range(n_subpops): - plotting.plot_posterior(idata, 'r_subpop_t', dim_1=i) + plotting.plot_predictive(idata, 'r_subpop_t', dim_1=i) ``` ```{python} @@ -439,7 +434,7 @@ my_model_hosp_only_fit = ww_site_level_dynamics_model( ```{python} # | label: hosp only fit my_model_hosp_only_fit.run( - num_warmup=500, + num_warmup=100, num_samples=100, rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, @@ -449,7 +444,7 @@ my_model_hosp_only_fit.run( ```{python} posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive( - n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,forecast=True + n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,is_predictive=True ) ``` @@ -460,9 +455,9 @@ idata_hosp_only = az.from_numpyro( ) ``` ```{python} -plotting.plot_predictive(idata_hosp_only,name='pred_hospital_admissions') +plotting.plot_predictive(idata_hosp_only,name='observed_hospital_admissions') ``` ```{python} -plotting.plot_posterior(idata_hosp_only,'r_subpop_t') +plotting.plot_predictive(idata_hosp_only,'r_subpop_t') ``` diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/site_level_dynamics_model.py index d1bb19db..50fbe7d7 100644 --- a/pyrenew_hew/site_level_dynamics_model.py +++ b/pyrenew_hew/site_level_dynamics_model.py @@ -14,6 +14,7 @@ 
InitializeInfectionsExponentialGrowth, ) from pyrenew.metaclass import Model +from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -138,7 +139,7 @@ def sample( n_datapoints=None, data_observed_hospital_admissions=None, data_observed_log_conc=None, - forecast=False, + is_predictive=False, ): # numpydoc ignore=GL08 if ( n_datapoints is None @@ -430,23 +431,19 @@ def sample( "latent_hospital_admissions", latent_hospital_admissions ) - phi = self.phi_rv() - numpyro.sample( - "observed_hospital_admissions", - dist.NegativeBinomial2( - mean=latent_hospital_admissions[self.hosp_times], - concentration=phi, - ), - obs=data_observed_hospital_admissions, + hospital_admission_obs_rv = NegativeBinomialObservation( + "observed_hospital_admissions", concentration_rv=self.phi_rv ) - if forecast: - pred_hospital_admissions = numpyro.sample( - "pred_hospital_admissions", - dist.NegativeBinomial2( - mean=latent_hospital_admissions, concentration=phi - ).mask(False), - ) + if not is_predictive: + mu_obs_hosp = latent_hospital_admissions[self.hosp_times] + else: + mu_obs_hosp = latent_hospital_admissions + + observed_hospital_admissions = hospital_admission_obs_rv( + mu=mu_obs_hosp, + obs=data_observed_hospital_admissions, + ) # wastewater component if self.include_ww: @@ -529,7 +526,6 @@ def batch_colvolve_fn(m): ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) - # if forecast: site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( @@ -569,6 +565,6 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_rt", state_rt) return ( - pred_hospital_admissions if forecast else None, + observed_hospital_admissions, site_ww_pred_log if self.include_ww else None, ) From a20309e4c5fd537a79cabb4dc7f3a33d3c712db5 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 17 Oct 2024 11:35:33 
-0400 Subject: [PATCH 43/50] revert .gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index d4a78120..721ed3c1 100644 --- a/.gitignore +++ b/.gitignore @@ -122,7 +122,6 @@ target/ # Jupyter Notebook .ipynb_checkpoints -*.ipynb # IPython profile_default/ @@ -388,6 +387,5 @@ poetry.lock notebooks/*_files/ notebooks/*.md -notebooks/*.quarto_ipynb nssp_demo/private_data/* From e32c9a79d242e42f86b64e67d8af93dfdef62519 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 22 Nov 2024 00:06:58 -0500 Subject: [PATCH 44/50] Merge branch 'main' of https://github.com/CDCgov/pyrenew-hew into site_level_dynamics_ref_subpop --- .ContainerBuildRprofile | 19 + .Rprofile | 2 - .containerignore | 3 + .github/workflows/containers.yaml | 121 + .github/workflows/r-cmd-check.yaml | 29 + .gitignore | 14 +- .pre-commit-config.yaml | 8 +- Containerfile | 17 + Containerfile.dependencies | 11 + Makefile | 30 + .../data/fit/stan_data.json | 0 .../data/fit_hosp_only/stan_data.json | 0 .../hosp_only_ww_model.qmd | 8 +- .../hosp_only_ww_model}/model_comp.qmd | 0 .../site_level_dynamics_model_demo.qmd | 0 .../hosp_only_ww_model}/wwinference.Rmd | 0 hewr/.Rbuildignore | 1 + hewr/DESCRIPTION | 48 + hewr/LICENSE.md | 194 ++ hewr/NAMESPACE | 8 + hewr/R/directory_utils.R | 130 + hewr/R/hewr-package.R | 7 + hewr/R/to_epiweekly_quantile_table.R | 155 ++ hewr/man/disease_map_lower.Rd | 18 + hewr/man/get_all_model_batch_dirs.Rd | 24 + hewr/man/hewr-package.Rd | 15 + hewr/man/parse_model_batch_dir_path.Rd | 23 + hewr/man/parse_model_run_dir_path.Rd | 22 + hewr/man/to_epiweekly_quantile_table.Rd | 24 + hewr/man/to_epiweekly_quantiles.Rd | 30 + hewr/tests/testthat.R | 12 + hewr/tests/testthat/test_directory_utils.R | 177 ++ nssp_demo/README.md | 82 - nssp_demo/fit_model.py | 43 - nssp_demo/generate_predictive.py | 71 - nssp_demo/post_process.R | 159 -- nssp_demo/prep_data.R | 164 -- pipelines/batch/setup_eval_job.py | 224 ++ pipelines/batch/setup_pool.py | 81 + 
pipelines/batch/setup_prod_job.py | 228 ++ pipelines/batch/setup_test_prod_job.py | 98 + {nssp_demo => pipelines}/build_model.py | 52 +- pipelines/collate_plots.py | 319 +++ pipelines/collate_score_tables.R | 175 ++ pipelines/convert_inferencedata_to_parquet.R | 95 + pipelines/create_hubverse_table.R | 49 + pipelines/default_priors.py | 68 + pipelines/diagnostic_report/custom.scss | 8 + pipelines/diagnostic_report/render_website.R | 66 + pipelines/diagnostic_report/template.qmd | 79 + pipelines/fit_model.py | 103 + pipelines/forecast_state.py | 419 +++ pipelines/generate_predictive.py | 72 + .../iteration_helpers/loop_fit.sh | 2 +- .../loop_generate_predictive.sh | 2 +- .../iteration_helpers/loop_postprocess.sh | 18 + pipelines/iteration_helpers/loop_score.sh | 17 + pipelines/make_observed_data_table.py | 72 + pipelines/postprocess_scoring.R | 406 +++ pipelines/postprocess_state_forecast.R | 275 ++ pipelines/prep_data.py | 347 +++ pipelines/pull_state_timeseries.py | 146 + pipelines/save_eval_data.py | 47 + pipelines/score_forecast.R | 245 ++ pipelines/tests/README.md | 10 + .../model_runs/TD/data.csv | 181 ++ .../model_runs/TD/data_for_model_fit.json | 381 +++ .../model_runs/TD/eval_data.tsv | 57 + .../model_runs/TD}/priors.py | 33 +- pipelines/tests/test_run.sh | 36 + pipelines/timeseries_forecasts.R | 259 ++ pipelines/utils.py | 135 + pyproject.toml | 5 + pyrenew-hew.Rproj | 16 - pyrenew_hew/hosp_only_ww_model.py | 19 +- renv.lock | 2342 ----------------- renv/.gitignore | 7 - renv/activate.R | 1220 --------- renv/settings.json | 19 - 79 files changed, 5910 insertions(+), 4192 deletions(-) create mode 100644 .ContainerBuildRprofile delete mode 100644 .Rprofile create mode 100644 .containerignore create mode 100644 .github/workflows/containers.yaml create mode 100644 .github/workflows/r-cmd-check.yaml create mode 100644 Containerfile create mode 100644 Containerfile.dependencies create mode 100644 Makefile rename {notebooks => 
demos/hosp_only_ww_model}/data/fit/stan_data.json (100%) rename {notebooks => demos/hosp_only_ww_model}/data/fit_hosp_only/stan_data.json (100%) rename {notebooks => demos/hosp_only_ww_model}/hosp_only_ww_model.qmd (90%) rename {notebooks => demos/hosp_only_ww_model}/model_comp.qmd (100%) rename {notebooks => demos/hosp_only_ww_model}/site_level_dynamics_model_demo.qmd (100%) rename {notebooks => demos/hosp_only_ww_model}/wwinference.Rmd (100%) create mode 100644 hewr/.Rbuildignore create mode 100644 hewr/DESCRIPTION create mode 100644 hewr/LICENSE.md create mode 100644 hewr/NAMESPACE create mode 100644 hewr/R/directory_utils.R create mode 100644 hewr/R/hewr-package.R create mode 100644 hewr/R/to_epiweekly_quantile_table.R create mode 100644 hewr/man/disease_map_lower.Rd create mode 100644 hewr/man/get_all_model_batch_dirs.Rd create mode 100644 hewr/man/hewr-package.Rd create mode 100644 hewr/man/parse_model_batch_dir_path.Rd create mode 100644 hewr/man/parse_model_run_dir_path.Rd create mode 100644 hewr/man/to_epiweekly_quantile_table.Rd create mode 100644 hewr/man/to_epiweekly_quantiles.Rd create mode 100644 hewr/tests/testthat.R create mode 100644 hewr/tests/testthat/test_directory_utils.R delete mode 100644 nssp_demo/README.md delete mode 100644 nssp_demo/fit_model.py delete mode 100644 nssp_demo/generate_predictive.py delete mode 100644 nssp_demo/post_process.R delete mode 100644 nssp_demo/prep_data.R create mode 100644 pipelines/batch/setup_eval_job.py create mode 100644 pipelines/batch/setup_pool.py create mode 100644 pipelines/batch/setup_prod_job.py create mode 100644 pipelines/batch/setup_test_prod_job.py rename {nssp_demo => pipelines}/build_model.py (61%) create mode 100644 pipelines/collate_plots.py create mode 100644 pipelines/collate_score_tables.R create mode 100644 pipelines/convert_inferencedata_to_parquet.R create mode 100644 pipelines/create_hubverse_table.R create mode 100644 pipelines/default_priors.py create mode 100644 
pipelines/diagnostic_report/custom.scss create mode 100644 pipelines/diagnostic_report/render_website.R create mode 100644 pipelines/diagnostic_report/template.qmd create mode 100644 pipelines/fit_model.py create mode 100644 pipelines/forecast_state.py create mode 100644 pipelines/generate_predictive.py rename nssp_demo/fit_all_models.sh => pipelines/iteration_helpers/loop_fit.sh (89%) rename nssp_demo/generate_all_predictive.sh => pipelines/iteration_helpers/loop_generate_predictive.sh (83%) create mode 100755 pipelines/iteration_helpers/loop_postprocess.sh create mode 100755 pipelines/iteration_helpers/loop_score.sh create mode 100644 pipelines/make_observed_data_table.py create mode 100644 pipelines/postprocess_scoring.R create mode 100644 pipelines/postprocess_state_forecast.R create mode 100644 pipelines/prep_data.py create mode 100644 pipelines/pull_state_timeseries.py create mode 100644 pipelines/save_eval_data.py create mode 100644 pipelines/score_forecast.R create mode 100644 pipelines/tests/README.md create mode 100644 pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data.csv create mode 100644 pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data_for_model_fit.json create mode 100644 pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/eval_data.tsv rename {nssp_demo => pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD}/priors.py (64%) create mode 100644 pipelines/tests/test_run.sh create mode 100644 pipelines/timeseries_forecasts.R create mode 100644 pipelines/utils.py delete mode 100644 pyrenew-hew.Rproj delete mode 100644 renv.lock delete mode 100644 renv/.gitignore delete mode 100644 renv/activate.R delete mode 100644 renv/settings.json diff --git a/.ContainerBuildRprofile b/.ContainerBuildRprofile new file mode 100644 index 00000000..d922b091 --- /dev/null +++ b/.ContainerBuildRprofile @@ -0,0 +1,19 @@ +options( + HTTPUserAgent = sprintf( + 
"R/%s R (%s)", + getRversion(), + paste( + getRversion(), + R.version["platform"], R.version["arch"], + R.version["os"] + ) + ), + ## use Posit package manager to get + ## precompiled binaries where possible + repos = c( + RSPM = "https://packagemanager.posit.co/cran/__linux__/bookworm/latest" + ), + renv.config.pak.enabled = TRUE +) + +cat(".Rprofile for container loaded successfully\n") diff --git a/.Rprofile b/.Rprofile deleted file mode 100644 index d659a4f7..00000000 --- a/.Rprofile +++ /dev/null @@ -1,2 +0,0 @@ -source("renv/activate.R") -source("~/.Rprofile") diff --git a/.containerignore b/.containerignore new file mode 100644 index 00000000..ada296cf --- /dev/null +++ b/.containerignore @@ -0,0 +1,3 @@ +Containerfile +nssp_demo/private_data +notebooks diff --git a/.github/workflows/containers.yaml b/.github/workflows/containers.yaml new file mode 100644 index 00000000..0ba0d5ba --- /dev/null +++ b/.github/workflows/containers.yaml @@ -0,0 +1,121 @@ +name: Create Docker Image + +on: + push: + branches: [main] + pull_request: + workflow_dispatch: + +env: + REGISTRY: cfaprdbatchcr.azurecr.io + IMAGE_NAME: pyrenew-hew + PYRENEW_VERSION: v0.1.1 + +jobs: + + build-dependencies-image: + runs-on: cfa-cdcgov + name: Build dependencies image + + outputs: + tag: ${{ steps.image-tag.outputs.tag }} + commit-msg: ${{ steps.commit-message.outputs.message }} + + steps: + + ######################################################################### + # Retrieving the commit message + # We need to ensure we are checking out the commit sha that triggered the + # workflow, not the PR's head sha. This is because the PR's head sha may + # be a merge commit, which will not have the commit message we need. 
+ ######################################################################### + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} + + - name: Getting the commit message + id: commit-message + run: echo "message=$(git log -1 --pretty=%s HEAD)" >> $GITHUB_OUTPUT + + - name: Checking out the latest (may be merge if PR) + uses: actions/checkout@v4 + + # From: https://stackoverflow.com/a/58035262/2097171 + - name: Extract branch name + shell: bash + run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT + id: branch-name + + ######################################################################### + # Getting the tag + # The tag will be used for both the docker image and the batch pool + ######################################################################### + - name: Figure out tag (either latest if it is main or the branch name) + id: image-tag + run: | + if [ "${{ steps.branch-name.outputs.branch }}" = "main" ]; then + echo "tag=latest" >> $GITHUB_OUTPUT + else + echo "tag=${{ steps.branch-name.outputs.branch }}" >> $GITHUB_OUTPUT + fi + + - name: Check cache for base image + uses: actions/cache@v4 + id: cache + with: + key: docker-dependencies-${{ runner.os }}-${{ hashFiles('./Containerfile.dependencies') }}-${{ steps.image-tag.outputs.tag }} + lookup-only: true + path: + ./Containerfile.dependencies + + - name: Login to the Container Registry + if: steps.cache.outputs.cache-hit != 'true' + uses: docker/login-action@v3 + with: + registry: "cfaprdbatchcr.azurecr.io" + username: "cfaprdbatchcr" + password: ${{ secrets.CFAPRDBATCHCR_REGISTRY_PASSWORD }} + + - name: Build and push + if: steps.cache.outputs.cache-hit != 'true' + uses: docker/build-push-action@v6 + with: + push: true + no-cache: true + tags: | + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-dependencies:${{ steps.image-tag.outputs.tag }} + file: ./Containerfile.dependencies + build-args: | + PYRENEW_VERSION=${{ 
env.PYRENEW_VERSION }} + + build-pipeline-image: + + name: Build pipeline image + + needs: build-dependencies-image + runs-on: cfa-cdcgov + + outputs: + tag: ${{ needs.build-dependencies-image.outputs.tag }} + commit-msg: ${{ needs.build-dependencies-image.outputs.commit-msg }} + + steps: + + - name: Login to the Container Registry + uses: docker/login-action@v3 + with: + registry: "cfaprdbatchcr.azurecr.io" + username: "cfaprdbatchcr" + password: ${{ secrets.CFAPRDBATCHCR_REGISTRY_PASSWORD }} + + - name: Build and push model pipeline image for Azure batch + id: build_and_push_model_image + uses: docker/build-push-action@v6 + with: + push: true # This can be toggled manually for tweaking. + tags: | + ${{ env.REGISTRY}}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }} + file: ./Containerfile + build-args: | + TAG=${{ needs.build-dependencies-image.outputs.tag }} diff --git a/.github/workflows/r-cmd-check.yaml b/.github/workflows/r-cmd-check.yaml new file mode 100644 index 00000000..338ed5bc --- /dev/null +++ b/.github/workflows/r-cmd-check.yaml @@ -0,0 +1,29 @@ +name: R CMD check hewr + +on: + pull_request: + push: + branches: [main] + +jobs: + check-hewr: + strategy: + matrix: + r-version: ["4.4.0", "release"] + os: [windows-latest, ubuntu-latest, macos-latest] + runs-on: ${{matrix.os}} + steps: + - uses: actions/checkout@v4 + - uses: r-lib/actions/setup-r@v2 + with: + r-version: ${{matrix.r-version}} + use-public-rspm: true + - name: "Set up dependencies for hewr" + uses: r-lib/actions/setup-r-dependencies@v2 + with: + working-directory: hewr + needs: check + - name: "Check hewr package" + uses: r-lib/actions/check-r-package@v2 + with: + working-directory: hewr diff --git a/.gitignore b/.gitignore index 721ed3c1..9f92efd7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ *.bin *.xls *.xlsx +*.rds +*.pickle +*.nc # Documents *.doc @@ -212,6 +215,8 @@ cython_debug/ # R # https://github.com/github/gitignore/blob/main/R.gitignore 
+.Rprofile + # History files .Rhistory .Rapp.history @@ -234,6 +239,7 @@ cython_debug/ # RStudio files .Rproj.user/ +*.Rproj # produced vignettes vignettes/*.html @@ -388,4 +394,10 @@ poetry.lock notebooks/*_files/ notebooks/*.md -nssp_demo/private_data/* +private_data/* +*_files/ +.vscode/settings.json + +# Test data exceptions to the general data exclusion +!pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data.csv +!pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/eval_data.tsv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8fa785f2..d4ab5a48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ repos: ##### # Basic file cleanliness - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-yaml @@ -13,7 +13,7 @@ repos: ##### # Python - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.8 + rev: v0.7.1 hooks: # Sort imports - id: ruff @@ -26,13 +26,13 @@ repos: ##### # R - repo: https://github.com/lorenzwalthert/precommit - rev: v0.4.3 + rev: v0.4.3.9001 hooks: - id: style-files - id: lintr # Secrets - repo: https://github.com/Yelp/detect-secrets - rev: v1.4.0 + rev: v1.5.0 hooks: - id: detect-secrets args: ["--baseline", ".secrets.baseline"] diff --git a/Containerfile b/Containerfile new file mode 100644 index 00000000..f627bf32 --- /dev/null +++ b/Containerfile @@ -0,0 +1,17 @@ +ARG TAG=latest + +FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG} + +COPY ./hewr /pyrenew-hew/hewr + +WORKDIR /pyrenew-hew + +COPY .ContainerBuildRprofile .Rprofile + +RUN Rscript -e "install.packages('pak')" +RUN Rscript -e "pak::pkg_install('cmu-delphi/epiprocess@main')" +RUN Rscript -e "pak::pkg_install('cmu-delphi/epipredict@main')" +RUN Rscript -e "pak::local_install('hewr')" + +COPY . . +RUN pip install --root-user-action=ignore . 
diff --git a/Containerfile.dependencies b/Containerfile.dependencies new file mode 100644 index 00000000..f8ea9501 --- /dev/null +++ b/Containerfile.dependencies @@ -0,0 +1,11 @@ +FROM python:3.13 + +ARG PYRENEW_VERSION=v0.1.1 + +RUN apt-get update +RUN apt-get install -y r-base +RUN apt-get install -y cmake +RUN pip install --root-user-action=ignore -U pip +RUN pip install --root-user-action=ignore git+https://github.com/cdcgov/pyrenew.git@$PYRENEW_VERSION + +CMD ["bash"] diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..9c45f4ef --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +.PHONY: container_build container_tag acr_login container_push dep_container_build dep_container_tag + +ENGINE := docker +DEP_CONTAINER_NAME := pyrenew-hew-dependencies +DEP_CONTAINERFILE := Containerfile.dependencies +DEP_CONTAINER_REMOTE_NAME := $(ACR_TAG_PREFIX)$(DEP_CONTAINER_NAME):latest +CONTAINER_NAME := pyrenew-hew +CONTAINERFILE := Containerfile +CONTAINER_REMOTE_NAME := $(ACR_TAG_PREFIX)$(CONTAINER_NAME):latest + +dep_container_build: + $(ENGINE) build . -t $(DEP_CONTAINER_NAME) -f $(DEP_CONTAINERFILE) + +dep_container_tag: + $(ENGINE) tag $(DEP_CONTAINER_NAME) $(DEP_CONTAINER_REMOTE_NAME) + +container_build: acr_login + $(ENGINE) build . 
-t $(CONTAINER_NAME) -f $(CONTAINERFILE) + +container_tag: + $(ENGINE) tag $(CONTAINER_NAME) $(CONTAINER_REMOTE_NAME) + +acr_login: + az acr login -n $(AZURE_CONTAINER_REGISTRY_ACCOUNT) + +dep_container_push: dep_container_tag acr_login + $(ENGINE) push $(DEP_CONTAINER_NAME) + +container_push: container_tag acr_login + $(ENGINE) push $(CONTAINER_REMOTE_NAME) diff --git a/notebooks/data/fit/stan_data.json b/demos/hosp_only_ww_model/data/fit/stan_data.json similarity index 100% rename from notebooks/data/fit/stan_data.json rename to demos/hosp_only_ww_model/data/fit/stan_data.json diff --git a/notebooks/data/fit_hosp_only/stan_data.json b/demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json similarity index 100% rename from notebooks/data/fit_hosp_only/stan_data.json rename to demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json diff --git a/notebooks/hosp_only_ww_model.qmd b/demos/hosp_only_ww_model/hosp_only_ww_model.qmd similarity index 90% rename from notebooks/hosp_only_ww_model.qmd rename to demos/hosp_only_ww_model/hosp_only_ww_model.qmd index 33e42602..778aca0e 100644 --- a/notebooks/hosp_only_ww_model.qmd +++ b/demos/hosp_only_ww_model/hosp_only_ww_model.qmd @@ -35,7 +35,7 @@ We begin by loading the Stan data, converting it the correct inputs for our mode ```{python} # | label: create model -my_hosp_only_ww_model, data_observed_hospital_admissions = ( +my_hosp_only_ww_model, data_observed_disease_hospital_admissions = ( create_hosp_only_ww_model_from_stan_data( "data/fit_hosp_only/stan_data.json" ) @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive n_forecast_days = 35 prior_predictive = my_hosp_only_ww_model.prior_predictive( - n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days, + n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_days, numpyro_predictive_args={"num_samples": 200}, ) ``` @@ -64,7 +64,7 @@ my_hosp_only_ww_model.run( num_warmup=500, num_samples=500, 
rng_key=jax.random.key(200), - data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions, mcmc_args=dict(num_chains=4, progress_bar=False), nuts_args=dict(find_heuristic_step_size=True), ) @@ -75,7 +75,7 @@ Create the posterior predictive and forecast: ```{python} # | label: posterior predictive posterior_predictive = my_hosp_only_ww_model.posterior_predictive( - n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days + n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_days ) ``` diff --git a/notebooks/model_comp.qmd b/demos/hosp_only_ww_model/model_comp.qmd similarity index 100% rename from notebooks/model_comp.qmd rename to demos/hosp_only_ww_model/model_comp.qmd diff --git a/notebooks/site_level_dynamics_model_demo.qmd b/demos/hosp_only_ww_model/site_level_dynamics_model_demo.qmd similarity index 100% rename from notebooks/site_level_dynamics_model_demo.qmd rename to demos/hosp_only_ww_model/site_level_dynamics_model_demo.qmd diff --git a/notebooks/wwinference.Rmd b/demos/hosp_only_ww_model/wwinference.Rmd similarity index 100% rename from notebooks/wwinference.Rmd rename to demos/hosp_only_ww_model/wwinference.Rmd diff --git a/hewr/.Rbuildignore b/hewr/.Rbuildignore new file mode 100644 index 00000000..5163d0b5 --- /dev/null +++ b/hewr/.Rbuildignore @@ -0,0 +1 @@ +^LICENSE\.md$ diff --git a/hewr/DESCRIPTION b/hewr/DESCRIPTION new file mode 100644 index 00000000..16c93e43 --- /dev/null +++ b/hewr/DESCRIPTION @@ -0,0 +1,48 @@ +Package: hewr +Title: What the Package Does (One Line, Title Case) +Version: 0.0.0.9000 +Authors@R: + person("First", "Last", , "first.last@example.com", role = c("aut", "cre"), + comment = c(ORCID = "YOUR-ORCID-ID")) +Description: What the package does (one paragraph). 
+License: Apache License (>= 2) +Encoding: UTF-8 +Roxygen: list(markdown = TRUE) +RoxygenNote: 7.3.2 +Imports: + argparser, + arrow, + cowplot, + dplyr, + DT, + fable, + feasts, + forcats, + forecasttools (>= 0.0.0.9000), + fs, + ggplot2, + glue, + here, + htmltools, + jsonlite, + lubridate, + purrr, + quarto, + readr, + reticulate, + rlang, + scales, + scoringutils (>= 2.0.0), + stringr, + tibble, + tidybayes, + tidyr, + urca +Remotes: + https://github.com/cdcgov/forecasttools +Suggests: + rcmdcheck, + testthat (>= 3.0.0), + withr +Config/testthat/edition: 3 +Config/Needs/check: rcmdcheck, testthat diff --git a/hewr/LICENSE.md b/hewr/LICENSE.md new file mode 100644 index 00000000..b62a9b5f --- /dev/null +++ b/hewr/LICENSE.md @@ -0,0 +1,194 @@ +Apache License +============== + +_Version 2.0, January 2004_ +_<>_ + +### Terms and Conditions for use, reproduction, and distribution + +#### 1. Definitions + +“License” shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +“Licensor” shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +“Legal Entity” shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, “control” means **(i)** the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or **(ii)** ownership of fifty percent (50%) or more of the +outstanding shares, or **(iii)** beneficial ownership of such entity. + +“You” (or “Your”) shall mean an individual or Legal Entity exercising +permissions granted by this License. + +“Source” form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. 
+ +“Object” form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +“Work” shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +“Derivative Works” shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +“Contribution” shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. 
For the purposes of this definition, +“submitted” means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as “Not a Contribution.” + +“Contributor” shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +#### 2. Grant of Copyright License + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +#### 3. Grant of Patent License + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. 
If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +#### 4. Redistribution + +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +* **(a)** You must give any other recipients of the Work or Derivative Works a copy of +this License; and +* **(b)** You must cause any modified files to carry prominent notices stating that You +changed the files; and +* **(c)** You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +* **(d)** If the Work includes a “NOTICE” text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. 
You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. + +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +#### 5. Submission of Contributions + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +#### 6. Trademarks + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +#### 7. Disclaimer of Warranty + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +#### 8. 
Limitation of Liability + +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +#### 9. Accepting Warranty or Additional Liability + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +_END OF TERMS AND CONDITIONS_ + +### APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets `[]` replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same “printed page” as the copyright notice for easier identification within +third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/hewr/NAMESPACE b/hewr/NAMESPACE new file mode 100644 index 00000000..86a36e69 --- /dev/null +++ b/hewr/NAMESPACE @@ -0,0 +1,8 @@ +# Generated by roxygen2: do not edit by hand + +export(get_all_model_batch_dirs) +export(parse_model_batch_dir_path) +export(parse_model_run_dir_path) +export(to_epiweekly_quantile_table) +export(to_epiweekly_quantiles) +importFrom(rlang,.data) diff --git a/hewr/R/directory_utils.R b/hewr/R/directory_utils.R new file mode 100644 index 00000000..d31b8fa0 --- /dev/null +++ b/hewr/R/directory_utils.R @@ -0,0 +1,130 @@ +#' Utilities for handling and parsing directory names +#' based on pyrenew-hew pipeline conventions. + +disease_map_lower <- c( + "covid-19" = "COVID-19", + "influenza" = "Influenza" +) + +#' Parse model batch directory name. +#' +#' Parse the name of a model batch directory +#' (i.e. a directory representing a single +#' report date and disease pair, but potentially +#' with fits for multiple locations), returning +#' a named list of quantities of interest. +#' +#' @param model_batch_dir_path Path to the model batch +#' directory to parse. Will parse only the basename. +#' @return A list of quantities: `disease`, `report_date`, +#' `first_training_date`, and `last_training_date`. 
+#' @export +parse_model_batch_dir_path <- function(model_batch_dir_path) { + pattern <- "(.+)_r_(.+)_f_(.+)_t_(.+)" + model_batch_dir_name <- fs::path_file(model_batch_dir_path) + matches <- stringr::str_match( + model_batch_dir_name, + pattern + ) + + if (any(is.na(matches))) { + stop( + "Invalid format for model batch directory name; ", + "could not parse. Expected ", + "'_r__f__t_", + "'." + ) + } + + result <- list( + disease = disease_map_lower[matches[2]] |> unname(), + # disease_map_lower + # is a named vector + # but we want disease + # just to be a string + report_date = lubridate::ymd(matches[3], quiet = TRUE), + first_training_date = lubridate::ymd(matches[4], quiet = TRUE), + last_training_date = lubridate::ymd(matches[5], quiet = TRUE) + ) + + if (any(is.na(result))) { + stop( + "Could not parse extracted disease and/or date ", + "values expected 'disease' to be one of 'covid-19' ", + "and 'influenza' and all dates to be valid dates in ", + "YYYY-MM-DD format. Got: ", + glue::glue( + "disease: {matches[2]}, ", + "report_date: {matches[3]}, ", + "first_training_date: {matches[4]}, ", + "last_training_date: {matches[5]}." + ) + ) + } + + return(result) +} + +#' Parse model run directory path. +#' +#' Parse path to a model run directory +#' (i.e. a directory representing a run for a +#' particular location, disease, and reference +#' date, and extract key quantities of interest. +#' +#' @param model_run_dir_path Path to parse. +#' @return A list of parsed attributes: +#' `location`, `disease`, `report_date`, +#' `first_training_date`, and `last_training_date`. +#' +#' @export +parse_model_run_dir_path <- function(model_run_dir_path) { + batch_dir <- model_run_dir_path |> + fs::path_dir() |> + fs::path_dir() |> + fs::path_file() + + location <- fs::path_file(model_run_dir_path) + + return(c( + location = location, + parse_model_batch_dir_path(batch_dir) + )) +} + + +#' Get forecast directories. 
+#' +#' Get all the subdirectories within a parent directory +#' that match the pattern for a forecast run for a +#' given disease and optionally a given report date. +#' +#' @param dir_of_batch_dirs Directory in which to look for +#' "model batch" directories, each of which represents an +#' individual forecast date / pathogen / dataset combination. +#' @param diseases Names of the diseases to match, as a vector of strings, +#' or a single disease as a string. +#' @return A vector of paths to the forecast subdirectories. +#' @export +get_all_model_batch_dirs <- function(dir_of_batch_dirs, + diseases) { + # disease names are lowercase by convention + match_patterns <- stringr::str_c(tolower(diseases), + "_r", + collapse = "|" + ) + + dirs <- tibble::tibble( + dir_path = fs::dir_ls( + dir_of_batch_dirs, + type = "directory" + ) + ) |> + dplyr::filter(stringr::str_starts( + fs::path_file(.data$dir_path), + !!match_patterns + )) |> + dplyr::pull(.data$dir_path) + + return(dirs) +} diff --git a/hewr/R/hewr-package.R b/hewr/R/hewr-package.R new file mode 100644 index 00000000..52c0c07b --- /dev/null +++ b/hewr/R/hewr-package.R @@ -0,0 +1,7 @@ +#' @keywords internal +"_PACKAGE" + +## usethis namespace: start +#' @importFrom rlang .data +## usethis namespace: end +NULL diff --git a/hewr/R/to_epiweekly_quantile_table.R b/hewr/R/to_epiweekly_quantile_table.R new file mode 100644 index 00000000..202428ef --- /dev/null +++ b/hewr/R/to_epiweekly_quantile_table.R @@ -0,0 +1,155 @@ +#' Read in daily forecast draws from a model run directory +#' and output a set of epiweekly quantiles, as a +#' [`tibbble`][tibble::tibble()]. +#' +#' @param model_run_dir Path to a directory containing +#' forecast draws to process, whose basename is the forecasted +#' location. +#' @param report_date Report date for which to generate epiweekly quantiles. 
+#' @param max_lookback_days How many days before the report date +#' to look back when generating epiweekly quantiles (determines how +#' many negative epiweekly forecast horizons (i.e. nowcast/backcast) +#' quantiles will be generated. +#' @return A [`tibble`][tibble::tibble()] of quantiles. +#' @export +to_epiweekly_quantiles <- function(model_run_dir, + report_date, + max_lookback_days) { + message(glue::glue("Processing {model_run_dir}...")) + draws_path <- fs::path(model_run_dir, + "forecast_samples", + ext = "parquet" + ) + location <- fs::path_file(model_run_dir) + + draws <- arrow::read_parquet(draws_path) |> + dplyr::filter(.data$date >= lubridate::ymd(!!report_date) - + lubridate::days(!!max_lookback_days)) + + if (nrow(draws) < 1) { + return(NULL) + } + + epiweekly_disease_draws <- draws |> + dplyr::filter( + .data$disease == "Disease" + ) |> + forecasttools::daily_to_epiweekly( + date_col = "date", + value_col = ".value", + id_cols = ".draw", + weekly_value_name = "epiweekly_disease", + strict = TRUE + ) + + epiweekly_total_draws <- draws |> + dplyr::filter(.data$disease == "Other") |> + forecasttools::daily_to_epiweekly( + date_col = "date", + value_col = ".value", + id_cols = ".draw", + weekly_value_name = "epiweekly_total", + strict = TRUE + ) + + epiweekly_prop_draws <- dplyr::inner_join( + epiweekly_disease_draws, + epiweekly_total_draws, + by = c( + "epiweek", + "epiyear", + ".draw" + ) + ) |> + dplyr::mutate( + "epiweekly_proportion" = + .data$epiweekly_disease / .data$epiweekly_total + ) + + + epiweekly_quantiles <- epiweekly_prop_draws |> + forecasttools::trajectories_to_quantiles( + timepoint_cols = c("epiweek", "epiyear"), + value_col = "epiweekly_proportion" + ) |> + dplyr::mutate( + "location" = !!location + ) + + message(glue::glue("Done processing {model_run_dir}")) + return(epiweekly_quantiles) +} + +#' Create an epiweekly hubverse-format forecast quantile table +#' from a model batch directory containing forecasts +#' for multiple 
locations as daily MCMC draws. +#' +#' @param model_batch_dir Model batch directory containing +#' the individual location forecast directories +#' ("model run directories") to process. Name should be in the format +#' `{disease}_r_{reference_date}_f_{first_data_date}_t_{last_data_date}`. +#' @param exclude Locations to exclude, if any, as a list of strings. +#' Default `NULL` (exclude nothing). +#' +#' @export +to_epiweekly_quantile_table <- function(model_batch_dir, + exclude = NULL) { + locations_to_process <- fs::dir_ls(model_batch_dir, + type = "directory" + ) + + if (!is.null(exclude)) { + locations_to_process <- locations_to_process[ + !(fs::path_file(locations_to_process) %in% exclude) + ] + } + + batch_params <- hewr::parse_model_batch_dir_path( + model_batch_dir + ) + report_date <- batch_params$report_date + disease <- batch_params$disease + disease_abbr <- dplyr::case_when( + disease == "Influenza" ~ "flu", + disease == "COVID-19" ~ "covid", + TRUE ~ disease + ) + + report_epiweek <- lubridate::epiweek(report_date) + report_epiyear <- lubridate::epiyear(report_date) + report_epiweek_end <- forecasttools::epiweek_to_date( + report_epiweek, + report_epiyear, + day_of_week = 7 + ) + + hubverse_table <- purrr::map( + locations_to_process, + \(x) { + to_epiweekly_quantiles( + x, + report_date = report_date, + max_lookback_days = 8 + ) + } + ## ensures we get the full -1 horizon but do not + ## waste time quantilizing draws that will not be + ## included in the final table. 
+ ) |> + dplyr::bind_rows() |> + forecasttools::get_hubverse_table( + report_epiweek_end, + target_name = + glue::glue("wk inc {disease_abbr} prop ed visits") + ) |> + dplyr::arrange( + .data$target, + .data$output_type, + .data$location, + .data$reference_date, + .data$horizon, + .data$output_type_id + ) + + return(hubverse_table) +} diff --git a/hewr/man/disease_map_lower.Rd b/hewr/man/disease_map_lower.Rd new file mode 100644 index 00000000..63137973 --- /dev/null +++ b/hewr/man/disease_map_lower.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/directory_utils.R +\docType{data} +\name{disease_map_lower} +\alias{disease_map_lower} +\title{Utilities for handling and parsing directory names +based on pyrenew-hew pipeline conventions.} +\format{ +An object of class \code{character} of length 2. +} +\usage{ +disease_map_lower +} +\description{ +Utilities for handling and parsing directory names +based on pyrenew-hew pipeline conventions. +} +\keyword{datasets} diff --git a/hewr/man/get_all_model_batch_dirs.Rd b/hewr/man/get_all_model_batch_dirs.Rd new file mode 100644 index 00000000..4997d907 --- /dev/null +++ b/hewr/man/get_all_model_batch_dirs.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/directory_utils.R +\name{get_all_model_batch_dirs} +\alias{get_all_model_batch_dirs} +\title{Get forecast directories.} +\usage{ +get_all_model_batch_dirs(dir_of_batch_dirs, diseases) +} +\arguments{ +\item{dir_of_batch_dirs}{Directory in which to look for +"model batch" directories, each of which represents an +individual forecast date / pathogen / dataset combination.} + +\item{diseases}{Names of the diseases to match, as a vector of strings, +or a single disease as a string.} +} +\value{ +A vector of paths to the forecast subdirectories. 
+} +\description{ +Get all the subdirectories within a parent directory +that match the pattern for a forecast run for a +given disease and optionally a given report date. +} diff --git a/hewr/man/hewr-package.Rd b/hewr/man/hewr-package.Rd new file mode 100644 index 00000000..3c6304f1 --- /dev/null +++ b/hewr/man/hewr-package.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/hewr-package.R +\docType{package} +\name{hewr-package} +\alias{hewr} +\alias{hewr-package} +\title{hewr: What the Package Does (One Line, Title Case)} +\description{ +What the package does (one paragraph). +} +\author{ +\strong{Maintainer}: First Last \email{first.last@example.com} (\href{https://orcid.org/YOUR-ORCID-ID}{ORCID}) + +} +\keyword{internal} diff --git a/hewr/man/parse_model_batch_dir_path.Rd b/hewr/man/parse_model_batch_dir_path.Rd new file mode 100644 index 00000000..a247b597 --- /dev/null +++ b/hewr/man/parse_model_batch_dir_path.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/directory_utils.R +\name{parse_model_batch_dir_path} +\alias{parse_model_batch_dir_path} +\title{Parse model batch directory name.} +\usage{ +parse_model_batch_dir_path(model_batch_dir_path) +} +\arguments{ +\item{model_batch_dir_path}{Path to the model batch +directory to parse. Will parse only the basename.} +} +\value{ +A list of quantities: \code{disease}, \code{report_date}, +\code{first_training_date}, and \code{last_training_date}. +} +\description{ +Parse the name of a model batch directory +(i.e. a directory representing a single +report date and disease pair, but potentially +with fits for multiple locations), returning +a named list of quantities of interest. 
+} diff --git a/hewr/man/parse_model_run_dir_path.Rd b/hewr/man/parse_model_run_dir_path.Rd new file mode 100644 index 00000000..4795c6a3 --- /dev/null +++ b/hewr/man/parse_model_run_dir_path.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/directory_utils.R +\name{parse_model_run_dir_path} +\alias{parse_model_run_dir_path} +\title{Parse model run directory path.} +\usage{ +parse_model_run_dir_path(model_run_dir_path) +} +\arguments{ +\item{model_run_dir_path}{Path to parse.} +} +\value{ +A list of parsed attributes: +\code{location}, \code{disease}, \code{report_date}, +\code{first_training_date}, and \code{last_training_date}. +} +\description{ +Parse path to a model run directory +(i.e. a directory representing a run for a +particular location, disease, and reference +date, and extract key quantities of interest. +} diff --git a/hewr/man/to_epiweekly_quantile_table.Rd b/hewr/man/to_epiweekly_quantile_table.Rd new file mode 100644 index 00000000..56d017c1 --- /dev/null +++ b/hewr/man/to_epiweekly_quantile_table.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/to_epiweekly_quantile_table.R +\name{to_epiweekly_quantile_table} +\alias{to_epiweekly_quantile_table} +\title{Create an epiweekly hubverse-format forecast quantile table +from a model batch directory containing forecasts +for multiple locations as daily MCMC draws.} +\usage{ +to_epiweekly_quantile_table(model_batch_dir, exclude = NULL) +} +\arguments{ +\item{model_batch_dir}{Model batch directory containing +the individual location forecast directories +("model run directories") to process. Name should be in the format +\verb{\{disease\}_r_\{reference_date\}_f_\{first_data_date\}_t_\{last_data_date\}}.} + +\item{exclude}{Locations to exclude, if any, as a list of strings. 
+Default \code{NULL} (exclude nothing).} +} +\description{ +Create an epiweekly hubverse-format forecast quantile table +from a model batch directory containing forecasts +for multiple locations as daily MCMC draws. +} diff --git a/hewr/man/to_epiweekly_quantiles.Rd b/hewr/man/to_epiweekly_quantiles.Rd new file mode 100644 index 00000000..47738ad5 --- /dev/null +++ b/hewr/man/to_epiweekly_quantiles.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/to_epiweekly_quantile_table.R +\name{to_epiweekly_quantiles} +\alias{to_epiweekly_quantiles} +\title{Read in daily forecast draws from a model run directory +and output a set of epiweekly quantiles, as a +\code{\link[tibble:tibble]{tibbble}}.} +\usage{ +to_epiweekly_quantiles(model_run_dir, report_date, max_lookback_days) +} +\arguments{ +\item{model_run_dir}{Path to a directory containing +forecast draws to process, whose basename is the forecasted +location.} + +\item{report_date}{Report date for which to generate epiweekly quantiles.} + +\item{max_lookback_days}{How many days before the report date +to look back when generating epiweekly quantiles (determines how +many negative epiweekly forecast horizons (i.e. nowcast/backcast) +quantiles will be generated.} +} +\value{ +A \code{\link[tibble:tibble]{tibble}} of quantiles. +} +\description{ +Read in daily forecast draws from a model run directory +and output a set of epiweekly quantiles, as a +\code{\link[tibble:tibble]{tibbble}}. +} diff --git a/hewr/tests/testthat.R b/hewr/tests/testthat.R new file mode 100644 index 00000000..588af7fd --- /dev/null +++ b/hewr/tests/testthat.R @@ -0,0 +1,12 @@ +# This file is part of the standard setup for testthat. +# It is recommended that you do not modify it. +# +# Where should you do additional test configuration? 
+# Learn more about the roles of various files in: +# * https://r-pkgs.org/testing-design.html#sec-tests-files-overview +# * https://testthat.r-lib.org/articles/special-files.html + +library(testthat) +library(hewr) + +test_check("hewr") diff --git a/hewr/tests/testthat/test_directory_utils.R b/hewr/tests/testthat/test_directory_utils.R new file mode 100644 index 00000000..5ae9c1fa --- /dev/null +++ b/hewr/tests/testthat/test_directory_utils.R @@ -0,0 +1,177 @@ +valid_model_batch_dirs <- list( + list( + dirname = "covid-19_r_2024-02-03_f_2021-04-01_t_2024-01-23", + expected = list( + disease = "COVID-19", + report_date = lubridate::ymd("2024-02-03"), + first_training_date = lubridate::ymd("2021-04-1"), + last_training_date = lubridate::ymd("2024-01-23") + ) + ), + list( + dirname = "influenza_r_2022-12-11_f_2021-02-05_t_2027-12-30", + expected = list( + disease = "Influenza", + report_date = lubridate::ymd("2022-12-11"), + first_training_date = lubridate::ymd("2021-02-5"), + last_training_date = lubridate::ymd("2027-12-30") + ) + ) +) + +invalid_model_batch_dirs <- c( + "qcovid-19_r_2024-02-03_f_2021-04-01_t_2024-01-23", + "influenza_r_2022-12-33_f_2021-02-05_t_2027-12-30" +) + +to_valid_run_dir <- function(valid_batch_dir_entry, location) { + x <- valid_batch_dir_entry + x$dirpath <- fs::path(x$dirname, "model_runs", location) + x$expected <- c( + location = location, + x$expected + ) + return(x) +} + +valid_model_run_dirs <- c( + lapply( + valid_model_batch_dirs, to_valid_run_dir, + location = "ME" + ), + lapply( + valid_model_batch_dirs, to_valid_run_dir, + location = "US" + ) +) + + +test_that("parse_model_batch_dir_path() works as expected.", { + for (valid_pair in valid_model_batch_dirs) { + ## should work with base dirnames that are valid + expect_equal( + parse_model_batch_dir_path(valid_pair$dirname), + valid_pair$expected + ) + + ## should work identically with a full path rather + ## than just base dir + also_valid <- fs::path("this", "is", "a", "test", 
valid_pair$dirname) + expect_equal( + parse_model_batch_dir_path(also_valid), + valid_pair$expected + ) + + ## should error if the terminal directory is not + ## what is to be parsed + not_valid <- fs::path(valid_pair$dirname, "test") + expect_error( + { + parse_model_batch_dir_path(not_valid) + }, + regex = "Invalid format for model batch directory name" + ) + } + + ## should error if entries cannot be parsed as what is expected + + for (invalid_entry in invalid_model_batch_dirs) { + expect_error( + { + parse_model_batch_dir_path(invalid_entry) + }, + regex = "Could not parse extracted disease and/or date values" + ) + } +}) + +test_that("parse_model_run_dir_path() works as expected.", { + for (valid_pair in valid_model_run_dirs) { + expect_equal( + parse_model_run_dir_path(valid_pair$dirpath), + valid_pair$expected + ) + + ## should work identically with a longer path + expect_equal( + parse_model_run_dir_path(fs::path( + "this", "is", "a", "test", + valid_pair$dirpath + )), + valid_pair$expected + ) + + ## should fail if there is additional terminal pathing + expect_error( + { + parse_model_run_dir_path(fs::path(valid_pair$dirpath, "test")) + }, + regex = "Invalid format for model batch directory name" + ) + } +}) + +test_that("get_all_model_batch_dirs() returns expected output.", { + withr::with_tempdir({ + ## create some directories + valid_covid <- c( + "covid-19_r_2024-02-01_f_2021-01-01_t_2024-01-31", + "covid-19_r" + ) + valid_flu <- c( + "influenza_r_2022-11-12_f_2022-11-01_t_2022_11_10", + "influenza_r" + ) + valid_dirs <- c(valid_flu, valid_covid) + + invalid_dirs <- c( + "this_is_not_valid", + "covid19_r", + "covid-19-r", + "influenza-r", + "influnza_r", + "covid-19", + "influenza" + ) + + invalid_files <- c( + "covid-19_r.txt", + "influenza_r.txt" + ) + fs::dir_create(c(valid_dirs, invalid_dirs)) + fs::file_create(invalid_files) + expected_all_files <- c( + valid_dirs, + invalid_dirs, + invalid_files + ) + + result_all <- fs::dir_ls(".") |> 
fs::path_file() + + result_valid <- get_all_model_batch_dirs( + ".", + c("COVID-19", "Influenza") + ) + + result_valid_alt <- get_all_model_batch_dirs( + ".", + c("Influenza", "COVID-19") + ) + + result_valid_covid <- get_all_model_batch_dirs( + ".", + "COVID-19" + ) + + result_valid_flu <- get_all_model_batch_dirs( + ".", + "Influenza" + ) + + expect_setequal(result_all, expected_all_files) + expect_setequal(result_valid, c(valid_flu, valid_covid)) + expect_setequal(result_valid_alt, c(valid_flu, valid_covid)) + expect_setequal(result_valid_covid, valid_covid) + expect_setequal(result_valid_flu, valid_flu) + }) +}) diff --git a/nssp_demo/README.md b/nssp_demo/README.md deleted file mode 100644 index c48742d6..00000000 --- a/nssp_demo/README.md +++ /dev/null @@ -1,82 +0,0 @@ -# NSSP Demo Workflow - -## 1. Prepare data - - -### Now - -`prep_data.R` reads in a `private_data/report_date.parquet` (nssp data) -and `private_data/prod.parquet` (parameter estimates) from disk. -It provides a function `prep_data` -that takes the arguments: `disease`, `report_date`, `min_reference_date`, -`max_reference_date`, `last_training_date`, `state_abb`. - -To create a dataframe (for plotting) and a `data_for_model_fit` list (data that is -read in the model fitting step). - -The function `prep_and_save_data` has the same arguments as `prep_data` and -saves the results in `private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}/{state_abb}`. - -`disease`, `report_date`, `min_reference_date`, `max_reference_date`, and `last_training_date` are all specified in the script. - -The script uses `purrr::walk` to save data for each `state_abb`. - -### In the future - -`disease`, `report_date`, `min_reference_date`, `max_reference_date`, and -`last_training_date`, and `state_abb` should be specified as command line -arguments. - -The path to `report_date.parquet` and `prod.parquet` should be specified as -command line arguments. 
- -Eventually, `report_date.parquet` and `prod.parquet` should be read from azure -blob storage. - -## 2. Fitting the model - -### Now - -Models are fit by calling `python fit_model.py --model_dir MODEL_DIR` from the -command line, where MODEL_DIR is of the form `private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}/{state_abb}` - -Results are saved as a pickle file in `private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}/{state_abb}/posterior_samples.pickle` - - -## 2. Creating forecasts - -Forecasts are created by calling `python generate_predictive.py --model_dir MODEL_DIR` from the -command line, where MODEL_DIR is of the form `private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}/{state_abb}` - -The results are saved as a csv in -`private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}/{state_abb}/inference_data.csv` -Results are also saved as a netCDF file in -`private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}/{state_abb}/inference_data.nc` - -### In the future - -Exported file formates may change (see github issues.) - -## 4. Post-processing - -### Now - -Non-converging chains should be pruned here. - -`post_process.R` contains a function `make_forecast_fig` that takes `model_dir` -as an argument. It creates a forecast plot. - -The script uses `purrr::map` and `purrr::pwalk` to create and save forecast plots for -every sub-directory in -`private_data/{str_to_lower(disease)}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}`. - -It then uses `pdfunnite` to save a combined pdf. - -### In the future - -More plots and diagnoistics should be added. 
-There should be some intermediate script that prepares model output for -plotting, which can be run on each model in parallel. - -Then there can be one final script to create the figures and other diagnostics, -which may involve combining data from multiple model fits. diff --git a/nssp_demo/fit_model.py b/nssp_demo/fit_model.py deleted file mode 100644 index 4650d2dc..00000000 --- a/nssp_demo/fit_model.py +++ /dev/null @@ -1,43 +0,0 @@ -import argparse -import pickle -from pathlib import Path - -import jax -import numpyro - -n_chains = 4 -numpyro.set_host_device_count(n_chains) -from build_model import build_model_from_dir # noqa: E402 - -parser = argparse.ArgumentParser( - description="Fit the hospital-only wastewater model." -) -parser.add_argument( - "--model_dir", - type=str, - required=True, - help="Path to the model directory containing the data.", -) -args = parser.parse_args() -model_dir = Path(args.model_dir) - -my_model, data_observed_hospital_admissions, right_truncation_offset = ( - build_model_from_dir(model_dir) -) -my_model.run( - num_warmup=500, - num_samples=500, - rng_key=jax.random.key(200), - data_observed_hospital_admissions=data_observed_hospital_admissions, - right_truncation_offset=right_truncation_offset, - mcmc_args=dict(num_chains=n_chains, progress_bar=True), - nuts_args=dict(find_heuristic_step_size=True), -) - -my_model.mcmc.sampler = None - -with open( - model_dir / "posterior_samples.pickle", - "wb", -) as file: - pickle.dump(my_model.mcmc, file) diff --git a/nssp_demo/generate_predictive.py b/nssp_demo/generate_predictive.py deleted file mode 100644 index 04464c9a..00000000 --- a/nssp_demo/generate_predictive.py +++ /dev/null @@ -1,71 +0,0 @@ -import argparse -import pickle -from pathlib import Path - -import arviz as az -import numpyro - -n_chains = 4 -numpyro.set_host_device_count(n_chains) -from build_model import build_model_from_dir # noqa: E402 - -parser = argparse.ArgumentParser( - description="Fit the hospital-only 
wastewater model." -) -parser.add_argument( - "--model_dir", - type=str, - required=True, - help="Path to the model directory containing the data.", -) -parser.add_argument( - "--n_forecast_points", - type=int, - default=0, - help="Number of time points to forecast", -) -args = parser.parse_args() -model_dir = Path(args.model_dir) -n_forecast_points = args.n_forecast_points -my_model, data_observed_hospital_admissions, right_truncation_offset = ( - build_model_from_dir(model_dir) -) - -my_model._init_model(1, 1) -fresh_sampler = my_model.mcmc.sampler - -with open( - model_dir / "posterior_samples.pickle", - "rb", -) as file: - my_model.mcmc = pickle.load(file) - -my_model.mcmc.sampler = fresh_sampler - -# prior_predictive = my_model.prior_predictive( -# numpyro_predictive_args={ -# "num_samples": my_model.mcmc.num_samples * my_model.mcmc.num_chains, -# "batch_ndims":1 -# }, -# n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points, -# ) -# need to figure out a way to generate these as distinct chains, so that the result of the to_datarame method is more compact - -posterior_predictive = my_model.posterior_predictive( - n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points -) - -idata = az.from_numpyro( - my_model.mcmc, - # prior=prior_predictive, - posterior_predictive=posterior_predictive, -) - -idata.to_dataframe().to_csv(model_dir / "inference_data.csv", index=False) - -# Save one netcdf for reloading -idata.to_netcdf(model_dir / "inference_data.nc") - -# R cannot read netcdf files with groups, so we split them into separate files. 
-for group in idata._groups_all: - idata[group].to_netcdf(model_dir / f"inference_data_{group}.nc") diff --git a/nssp_demo/post_process.R b/nssp_demo/post_process.R deleted file mode 100644 index 5cc8b814..00000000 --- a/nssp_demo/post_process.R +++ /dev/null @@ -1,159 +0,0 @@ -library(tidyverse) -library(tidybayes) -library(fs) -library(cowplot) -library(glue) -library(scales) -library(here) - -theme_set(theme_minimal_grid()) - -disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu") - -make_forecast_fig <- function(model_dir) { - disease_name_raw <- base_dir %>% - path_file() %>% - str_extract("^.+(?=_r_)") - - state_abb <- model_dir %>% - path_split() %>% - pluck(1) %>% - tail(1) - - - data_path <- path(model_dir, "data", ext = "csv") - inference_data_path <- path(model_dir, "inference_data", - ext = "csv" - ) - - - dat <- read_csv(data_path) %>% - arrange(date) %>% - mutate(time = row_number() - 1) %>% - rename(.value = COVID_ED_admissions) - - last_training_date <- dat %>% - filter(data_type == "train") %>% - pull(date) %>% - max() - - last_data_date <- dat %>% - pull(date) %>% - max() - - arviz_split <- function(x) { - x %>% - select(-distribution) %>% - split(f = as.factor(x$distribution)) - } - - pyrenew_samples <- - read_csv(inference_data_path) %>% - rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |> - rename( - .chain = chain, - .iteration = draw - ) |> - mutate(across(c(.chain, .iteration), \(x) as.integer(x + 1))) |> - mutate( - .draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration), - .after = .iteration - ) |> - pivot_longer(-starts_with("."), - names_sep = ", ", - names_to = c("distribution", "name") - ) |> - arviz_split() |> - map(\(x) pivot_wider(x, names_from = name) |> tidy_draws()) - - - hosp_ci <- - pyrenew_samples$posterior_predictive %>% - gather_draws(observed_hospital_admissions[time]) %>% - median_qi(.width = c(0.5, 0.8, 0.95)) %>% - mutate(date = min(dat$date) + time) - - - - 
forecast_plot <- - ggplot(mapping = aes(date, .value)) + - geom_lineribbon( - data = hosp_ci, - mapping = aes(ymin = .lower, ymax = .upper), - color = "#08519c", key_glyph = draw_key_rect, step = "mid" - ) + - geom_point(mapping = aes(shape = data_type), data = dat) + - scale_y_continuous("Emergency Department Admissions") + - scale_x_date("Date") + - scale_fill_brewer( - name = "Credible Interval Width", - labels = ~ percent(as.numeric(.)) - ) + - scale_shape_discrete("Data Type", labels = str_to_title) + - geom_vline(xintercept = last_training_date, linetype = "dashed") + - annotate( - geom = "text", - x = last_training_date, - y = -Inf, - label = "Fit Period ←\n", - hjust = "right", - vjust = "bottom" - ) + - annotate( - geom = "text", - x = last_training_date, - y = -Inf, label = "→ Forecast Period\n", - hjust = "left", - vjust = "bottom", - ) + - ggtitle( - glue( - "{disease_name_formatter[disease_name_raw]} ", - "NSSP-based forecast for {state_abb}" - ), - subtitle = glue("as of {last_data_date}") - ) + - theme(legend.position = "bottom") - - forecast_plot -} - - -base_dir <- path(here( - "nssp_demo", - "private_data", - "covid-19_r_2024-10-10_f_2024-04-12_l_2024-10-09_t_2024-10-05" -)) - - -forecast_fig_tbl <- - tibble(base_model_dir = dir_ls(base_dir)) %>% - filter( - path(base_model_dir, "inference_data", ext = "csv") %>% - file_exists() - ) %>% - mutate(forecast_fig = map(base_model_dir, make_forecast_fig)) %>% - mutate(figure_path = path(base_model_dir, "forecast_plot", ext = "pdf")) - -pwalk( - forecast_fig_tbl %>% select(forecast_fig, figure_path), - function(forecast_fig, figure_path) { - save_plot( - filename = figure_path, - plot = forecast_fig, - device = cairo_pdf, base_height = 6 - ) - } -) - -str_c(forecast_fig_tbl$figure_path, collapse = " ") %>% - str_c( - path(base_dir, - glue("{path_file(base_dir)}_all_forecasts"), - ext = "pdf" - ), - sep = " " - ) %>% - system2("pdfunite", args = .) 
- -setdiff(usa::state.abb, path_file(forecast_fig_tbl$base_model_dir)) diff --git a/nssp_demo/prep_data.R b/nssp_demo/prep_data.R deleted file mode 100644 index daf6d4bf..00000000 --- a/nssp_demo/prep_data.R +++ /dev/null @@ -1,164 +0,0 @@ -# Read commnad line args -library(tidyverse) -library(CFAEpiNow2Pipeline) -library(usa) -library(fs) -library(arrow) -library(here) -library(glue) -library(jsonlite) - - -prep_data <- function(disease = c("COVID-19", "Influenza", "test"), - report_date = today(), - min_reference_date = "2000-01-01", - max_reference_date = "3000-01-01", - last_training_date = max_reference_date, - state_abb = "US") { - prepped_data <- read_data( - data_path = here( - path("nssp_demo", "private_data", report_date, ext = "parquet") - ), - disease = disease, - state_abb = state_abb, - report_date = report_date, - max_reference_date = max_reference_date, - min_reference_date = min_reference_date - ) %>% - as_tibble() %>% - select(date = reference_date, COVID_ED_admissions = confirm) %>% - mutate(data_type = if_else(date <= last_training_date, "train", "test")) - - train_ed_admissions <- prepped_data %>% - filter(data_type == "train") %>% - pull(COVID_ED_admissions) - - test_ed_admissions <- prepped_data %>% - filter(data_type == "test") %>% - pull(COVID_ED_admissions) - - - state_pop <- - usa::facts %>% - left_join(select(usa::states, abb, name)) %>% - filter(abb == state_abb) %>% - pull(population) - - nnh_estimates <- read_parquet( - here(path("nssp_demo", - "private_data", - "prod", - ext = "parquet" - )) - ) - - generation_interval_pmf <- - nnh_estimates %>% - filter( - is.na(geo_value), - disease == !!disease, - parameter == "generation_interval" - ) %>% - pull(value) %>% - pluck(1) - - - delay_pmf <- - nnh_estimates %>% - filter( - is.na(geo_value), - disease == !!disease, - parameter == "delay" - ) %>% - pull(value) %>% - pluck(1) - - right_truncation_pmf <- - nnh_estimates %>% - filter( - geo_value == state_abb, - disease == !!disease, - 
parameter == "right_truncation" - ) %>% - pull(value) %>% - pluck(1) - - - list( - prepped_date = prepped_data, - data_for_model_fit = list( - inf_to_hosp_pmf = delay_pmf, - generation_interval_pmf = generation_interval_pmf, - right_truncation_pmf = right_truncation_pmf, - data_observed_hospital_admissions = train_ed_admissions, - test_ed_admissions = test_ed_admissions, - state_pop = state_pop - ) - ) -} - - -prep_and_save_data <- function(disease, - report_date, - min_reference_date, - max_reference_date, - last_training_date, - state_abb) { - # prep data - dat <- prep_data( - disease = disease, - report_date = report_date, - min_reference_date = min_reference_date, - max_reference_date = max_reference_date, - last_training_date = last_training_date, - state_abb = state_abb - ) - - actual_first_date <- min(dat$prepped_date$date) - actual_last_date <- max(dat$prepped_date$date) - dat$data_for_model_fit$right_truncation_offset <- as.integer( - as_date(report_date) - - as_date(last_training_date) - ) - # could be off by 1 - - - # Create folders - model_folder_name <- glue(paste0( - "{str_to_lower(disease)}_", - "r_{report_date}_", - "f_{actual_first_date}_", - "l_{actual_last_date}_", - "t_{last_training_date}" - )) - model_folder <- here("nssp_demo", "private_data", model_folder_name) - dir_create(model_folder) - - data_folder <- path(model_folder, state_abb) - dir_create(data_folder) - - - # save state_pop and ed_visits in a single json - write_json( - x = dat$data_for_model_fit, - path = path(data_folder, "data_for_model_fit", ext = "json"), - auto_unbox = TRUE - ) - - # save whole dataset with forecast indicators as a csv - write_csv(dat$prepped_date, file = path(data_folder, "data", ext = "csv")) -} - -walk( - setdiff(usa::state.abb, "PR"), - \(x) { - prep_and_save_data( - disease = "Influenza", - report_date = "2024-10-10", - min_reference_date = "2000-01-01", - max_reference_date = "3000-01-01", - last_training_date = "2024-10-05", - state_abb = x - ) - } -) 
diff --git a/pipelines/batch/setup_eval_job.py b/pipelines/batch/setup_eval_job.py new file mode 100644 index 00000000..e5446a19 --- /dev/null +++ b/pipelines/batch/setup_eval_job.py @@ -0,0 +1,224 @@ +""" +Set up a multi-location, multi-date, +potentially multi-disease end to end +retrospective evaluation run for pyrenew-hew +on Azure Batch. +""" + +import argparse +import datetime +import itertools + +import polars as pl +from azure.batch import models +from azuretools.auth import EnvCredentialHandler +from azuretools.client import get_batch_service_client +from azuretools.job import create_job_if_not_exists +from azuretools.task import get_container_settings, get_task_config + + +def main( + job_id: str, + pool_id: str, + diseases: str, + container_image_name: str = "pyrenew-hew", + container_image_version: str = "latest", + excluded_locations: list[str] = [ + "AS", + "GU", + "MO", + "MP", + "PR", + "UM", + "VI", + "WY", + ], +) -> None: + """ + job_id + Name for the Batch job. + + pool_id + Azure Batch pool on which to run the job. + + diseases + Name(s) of disease(s) to run as part of the job, + as a whitespace-separated string. Supported + values are 'COVID-19' and 'Influenza'. + + container_image_name: + Name of the container to use for the job. + This container should exist within the Azure + Container Registry account associated to + the job. Default 'pyrenew-hew'. + The container registry account name and endpoint + will be obtained from local environment variables + via a :class``azuretools.auth.EnvCredentialHandler`. + + container_image_version + Version of the container to use. Default 'latest'. + + excluded_locations + List of two letter USPS location abbreviations to + exclude from the job. Defaults to locations for which + we typically do not have available NSSP ED visit data: + ``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``. 
+ + Returns + ------- + None + """ + supported_diseases = ["COVID-19", "Influenza"] + + disease_list = diseases.split() + invalid_diseases = set(disease_list) - set(supported_diseases) + if invalid_diseases: + raise ValueError( + f"Unsupported diseases: {', '.join(invalid_diseases)}; " + f"supported diseases are: {', '.join(supported_diseases)}" + ) + + creds = EnvCredentialHandler() + client = get_batch_service_client(creds) + job = models.JobAddParameter( + id=job_id, + pool_info=models.PoolInformation(pool_id=pool_id), + ) + create_job_if_not_exists(client, job, verbose=True) + + container_image = ( + f"{creds.azure_container_registry_account}." + f"{creds.azure_container_registry_domain}/" + f"{container_image_name}:{container_image_version}" + ) + container_settings = get_container_settings( + container_image, + working_directory="containerImageDefault", + mount_pairs=[ + { + "source": "nssp-etl", + "target": "/pyrenew-hew/nssp-etl", + }, + { + "source": "nssp-archival-vintages", + "target": "/pyrenew-hew/nssp-archival-vintages", + }, + { + "source": "prod-param-estimates", + "target": "/pyrenew-hew/params", + }, + { + "source": "pyrenew-hew-prod-output", + "target": "/pyrenew-hew/output", + }, + { + "source": "pyrenew-hew-config", + "target": "/pyrenew-hew/config", + }, + ], + ) + + base_call = ( + "/bin/bash -c '" + "python pipelines/forecast_state.py " + "--disease {disease} " + "--state {state} " + "--n-training-days 365 " + "--n-warmup 1000 " + "--n-samples 500 " + "--facility-level-nssp-data-dir nssp-etl/gold " + "--state-level-nssp-data-dir " + "nssp-archival-vintages/gold " + "--param-data-dir params " + "--output-data-dir output " + "--priors-path config/eval_priors.py " + "--report-date {report_date:%Y-%m-%d} " + "--exclude-last-n-days 2 " + "--score " + "--eval-data-path " + "nssp-archival-vintages/latest_comprehensive.parquet" + "'" + ) + + locations = pl.read_csv( + "https://www2.census.gov/geo/docs/reference/state.txt", separator="|" + ) + + 
all_locations = ( + locations.filter(~pl.col("STUSAB").is_in(excluded_locations)) + .get_column("STUSAB") + .to_list() + ) + ["US"] + + report_dates = [ + datetime.date(2023, 10, 11) + datetime.timedelta(weeks=x) + for x in range(30) + ] + + for disease, report_date, loc in itertools.product( + disease_list, report_dates, all_locations + ): + task = get_task_config( + f"{job_id}-{loc}-{disease}-{report_date}", + base_call=base_call.format( + state=loc, + disease=disease, + report_date=report_date, + ), + container_settings=container_settings, + ) + client.task.add(job_id, task) + + return None + + +parser = argparse.ArgumentParser() + +parser.add_argument("job_id", type=str, help="Name for the Azure batch job") +parser.add_argument( + "pool_id", + type=str, + help=("Name of the Azure batch pool on which to run the job"), +) +parser.add_argument( + "diseases", + type=str, + help=( + "Name(s) of disease(s) to run as part of the job, " + "as a whitespace-separated string. Supported " + "values are 'COVID-19' and 'Influenza'." + ), +) + +parser.add_argument( + "--container-image-name", + type=str, + help="Name of the container to use for the job.", + default="pyrenew-hew", +) + +parser.add_argument( + "--container-image-version", + type=str, + help="Version of the container to use for the job.", + default="latest", +) + +parser.add_argument( + "--excluded-locations", + type=str, + help=( + "Two-letter USPS location abbreviations to " + "exclude from the job, as a whitespace-separated " + "string. Defaults to a set of locations for which " + "we typically do not have available NSSP ED visit " + "data: 'AS GU MO MP PR UM VI WY'." 
+ ), + default="AS GU MO MP PR UM VI WY", +) + + +if __name__ == "__main__": + args = parser.parse_args() + args.excluded_locations = args.excluded_locations.split() + main(**vars(args)) diff --git a/pipelines/batch/setup_pool.py b/pipelines/batch/setup_pool.py new file mode 100644 index 00000000..23d8a8c1 --- /dev/null +++ b/pipelines/batch/setup_pool.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +import argparse + +import azuretools.defaults as d +from azure.mgmt.batch import models +from azuretools import blob +from azuretools.auth import EnvCredentialHandler, get_compute_node_id_reference +from azuretools.client import get_batch_management_client + + +def main(pool_name: str) -> None: + """ + Set up a pool with a given name + and default configuration. + + Parameters + ---------- + pool_name + name for the pool + + Returns + ------- + None + """ + + creds = EnvCredentialHandler() + client = get_batch_management_client(creds) + node_id_ref = get_compute_node_id_reference() + pool_config = d.get_default_pool_config( + pool_name=pool_name, + subnet_id=creds.azure_subnet_id, + user_assigned_identity=creds.azure_user_assigned_identity, + ) + + pool_config.mount_configuration = blob.get_node_mount_config( + storage_containers=[ + "nssp-etl", + "nssp-archival-vintages", + "prod-param-estimates", + "pyrenew-hew-prod-output", + "pyrenew-hew-config", + "pyrenew-test-output", + ], + account_names=creds.azure_blob_storage_account, + identity_references=node_id_ref, + ) + + ( + pool_config.deployment_configuration.virtual_machine_configuration.container_configuration + ) = models.ContainerConfiguration( + type="dockerCompatible", + container_image_names=[ + "https://cfaprdbatchcr.azurecr.io/pyrenew-hew:latest" + ], + container_registries=[creds.azure_container_registry], + ) + + client.pool.create( + resource_group_name=creds.azure_resource_group_name, + account_name=creds.azure_batch_account, + pool_name=pool_name, + parameters=pool_config, + ) + + +if __name__ == 
"__main__": + parser = argparse.ArgumentParser( + description=( + "Set up an Azure batch pool using the azuretools defaults" + ) + ) + parser.add_argument( + "pool_name", + type=str, + help="A name for the pool", + ) + + parsed = vars(parser.parse_args()) + + main(parsed["pool_name"]) diff --git a/pipelines/batch/setup_prod_job.py b/pipelines/batch/setup_prod_job.py new file mode 100644 index 00000000..c06b9cec --- /dev/null +++ b/pipelines/batch/setup_prod_job.py @@ -0,0 +1,228 @@ +""" +Set up a multi-location, multi-date, +potentially multi-disease end to end +retrospective evaluation run for pyrenew-hew +on Azure Batch. +""" + +import argparse +import itertools + +import polars as pl +from azure.batch import models +from azuretools.auth import EnvCredentialHandler +from azuretools.client import get_batch_service_client +from azuretools.job import create_job_if_not_exists +from azuretools.task import get_container_settings, get_task_config + + +def main( + job_id: str, + pool_id: str, + diseases: str | list[str], + container_image_name: str = "pyrenew-hew", + container_image_version: str = "latest", + excluded_locations: list[str] = [ + "AS", + "GU", + "MO", + "MP", + "PR", + "UM", + "VI", + "WY", + ], + test: bool = False, +) -> None: + """ + job_id + Name for the Batch job. + + pool_id + Azure Batch pool on which to run the job. + + diseases + Name(s) of disease(s) to run as part of the job, + as a single string (one disease) or a list of strings. + Supported values are 'COVID-19' and 'Influenza'. + + container_image_name: + Name of the container to use for the job. + This container should exist within the Azure + Container Registry account associated to + the job. Default 'pyrenew-hew'. + The container registry account name and enpoint + will be obtained from local environm variables + via a :class``azuretools.auth.EnvCredentialHandler`. + + container_image_version + Version of the container to use. Default 'latest'. 
+ + excluded_locations + List of two letter USPS location abbreviations to + exclude from the job. Defaults to locations for which + we typically do not have available NSSP ED visit data: + ``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``. + + Returns + ------- + None + """ + supported_diseases = ["COVID-19", "Influenza"] + + disease_list = diseases + + invalid_diseases = set(disease_list) - set(supported_diseases) + if invalid_diseases: + raise ValueError( + f"Unsupported diseases: {', '.join(invalid_diseases)}; " + f"supported diseases are: {', '.join(supported_diseases)}" + ) + + pyrenew_hew_output_container = ( + "pyrenew-test-output" if test else "pyrenew-hew-prod-output" + ) + n_warmup = 200 if test else 1000 + n_samples = 200 if test else 500 + + creds = EnvCredentialHandler() + client = get_batch_service_client(creds) + job = models.JobAddParameter( + id=job_id, + pool_info=models.PoolInformation(pool_id=pool_id), + ) + create_job_if_not_exists(client, job, verbose=True) + + container_image = ( + f"{creds.azure_container_registry_account}." 
+ f"{creds.azure_container_registry_domain}/" + f"{container_image_name}:{container_image_version}" + ) + container_settings = get_container_settings( + container_image, + working_directory="containerImageDefault", + mount_pairs=[ + { + "source": "nssp-etl", + "target": "/pyrenew-hew/nssp-etl", + }, + { + "source": "nssp-archival-vintages", + "target": "/pyrenew-hew/nssp-archival-vintages", + }, + { + "source": "prod-param-estimates", + "target": "/pyrenew-hew/params", + }, + { + "source": pyrenew_hew_output_container, + "target": "/pyrenew-hew/output", + }, + { + "source": "pyrenew-hew-config", + "target": "/pyrenew-hew/config", + }, + ], + ) + + base_call = ( + "/bin/bash -c '" + "python pipelines/forecast_state.py " + "--disease {disease} " + "--state {state} " + "--n-training-days 90 " + "--n-warmup {n_warmup} " + "--n-samples {n_samples} " + "--facility-level-nssp-data-dir nssp-etl/gold " + "--state-level-nssp-data-dir " + "nssp-archival-vintages/gold " + "--param-data-dir params " + "--output-data-dir output " + "--priors-path config/prod_priors.py " + "--report-date {report_date} " + "--exclude-last-n-days 5 " + "--no-score " + "--eval-data-path " + "nssp-archival-vintages/latest_comprehensive.parquet" + "'" + ) + + # to be replaced by forecasttools-py table + locations = pl.read_csv( + "https://www2.census.gov/geo/docs/reference/state.txt", separator="|" + ) + + all_locations = [ + loc + for loc in locations.get_column("STUSAB").to_list() + ["US"] + if loc not in excluded_locations + ] + + for disease, state in itertools.product(disease_list, all_locations): + task = get_task_config( + f"{job_id}-{state}-{disease}-prod", + base_call=base_call.format( + state=state, + disease=disease, + report_date="latest", + n_warmup=n_warmup, + n_samples=n_samples, + ), + container_settings=container_settings, + ) + client.task.add(job_id, task) + + return None + + +parser = argparse.ArgumentParser() + +parser.add_argument("job_id", type=str, help="Name for the Azure 
batch job") +parser.add_argument( + "pool_id", + type=str, + help=("Name of the Azure batch pool on which to run the job"), +) +parser.add_argument( + "diseases", + type=str, + help=( + "Name(s) of disease(s) to run as part of the job, " + "as a whitespace-separated string. Supported " + "values are 'COVID-19' and 'Influenza'." + ), +) + +parser.add_argument( + "--container-image-name", + type=str, + help="Name of the container to use for the job.", + default="pyrenew-hew", +) + +parser.add_argument( + "--container-image-version", + type=str, + help="Version of the container to use for the job.", + default="latest", +) + +parser.add_argument( + "--excluded-locations", + type=str, + help=( + "Two-letter USPS location abbreviations to " + "exclude from the job, as a whitespace-separated " + "string. Defaults to a set of locations for which " + "we typically do not have available NSSP ED visit " + "data: 'AS GU MO MP PR UM VI WY'." + ), + default="AS GU MO MP PR UM VI WY", +) + + +if __name__ == "__main__": + args = parser.parse_args() + args.diseases = args.diseases.split() + args.excluded_locations = args.excluded_locations.split() + main(**vars(args)) diff --git a/pipelines/batch/setup_test_prod_job.py b/pipelines/batch/setup_test_prod_job.py new file mode 100644 index 00000000..987139e4 --- /dev/null +++ b/pipelines/batch/setup_test_prod_job.py @@ -0,0 +1,98 @@ +""" +Set up a multi-location, multi-date, +potentially multi-disease end to end +retrospective evaluation run for pyrenew-hew +on Azure Batch. 
+""" + +import argparse +import os +from datetime import datetime, timezone +from pathlib import Path + +from pygit2 import Repository +from setup_prod_job import main + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Test production pipeline on small subset of locations" + ) + parser.add_argument( + "--tag", + type=str, + help="The tag name to use for the container image version", + default=Path(Repository(os.getcwd()).head.name).stem, + ) + + args = parser.parse_args() + + tag = args.tag + print(f"Using tag {tag}") + current_datetime = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%SZ") + tag = Path(Repository(os.getcwd()).head.name).stem + + locs_to_exclude = [ # keep CA, MN, SD, and US + "AS", + "GU", + "MO", + "MP", + "PR", + "UM", + "VI", + "WY", + "AK", + "AL", + "AR", + "AZ", + "CO", + "CT", + "DC", + "DE", + "FL", + "GA", + "HI", + "IA", + "ID", + "IL", + "IN", + "KS", + "KY", + "LA", + "MA", + "MD", + "ME", + "MI", + "MS", + "MT", + "NC", + "ND", + "NE", + "NH", + "NJ", + "NM", + "NV", + "NY", + "OH", + "OK", + "OR", + "PA", + "RI", + "SC", + "TN", + "TX", + "UT", + "VA", + "VT", + "WA", + "WI", + "WV", + ] + main( + job_id=f"pyrenew-hew-test-{current_datetime}", + pool_id="pyrenew-pool", + diseases=["COVID-19", "Influenza"], + container_image_name="pyrenew-hew", + container_image_version=tag, + excluded_locations=locs_to_exclude, + test=True, + ) diff --git a/nssp_demo/build_model.py b/pipelines/build_model.py similarity index 61% rename from nssp_demo/build_model.py rename to pipelines/build_model.py index 78099ed2..148bddbd 100644 --- a/nssp_demo/build_model.py +++ b/pipelines/build_model.py @@ -1,22 +1,7 @@ import json +import runpy import jax.numpy as jnp - -# load priors -# have to run this from the right directory -from priors import ( # noqa: E402 - autoreg_p_hosp_rv, - autoreg_rt_rv, - eta_sd_rv, - hosp_wday_effect_rv, - i0_first_obs_n_rv, - inf_feedback_strength_rv, - initialization_rate_rv, - 
log_r_mu_intercept_rv, - p_hosp_mean_rv, - p_hosp_w_sd_rv, - phi_rv, -) from pyrenew.deterministic import DeterministicVariable from pyrenew_hew.hosp_only_ww_model import hosp_only_ww_model @@ -24,6 +9,7 @@ def build_model_from_dir(model_dir): data_path = model_dir / "data_for_model_fit.json" + prior_path = model_dir / "priors.py" with open( data_path, @@ -45,8 +31,8 @@ def build_model_from_dir(model_dir): jnp.array(model_data["generation_interval_pmf"]), ) # check if off by 1 or reversed - data_observed_hospital_admissions = jnp.array( - model_data["data_observed_hospital_admissions"] + data_observed_disease_hospital_admissions = jnp.array( + model_data["data_observed_disease_hospital_admissions"] ) state_pop = jnp.array(model_data["state_pop"]) @@ -62,26 +48,32 @@ def build_model_from_dir(model_dir): - 1 ) + priors = runpy.run_path(str(prior_path)) + right_truncation_offset = model_data["right_truncation_offset"] my_model = hosp_only_ww_model( state_pop=state_pop, - i0_first_obs_n_rv=i0_first_obs_n_rv, - initialization_rate_rv=initialization_rate_rv, - log_r_mu_intercept_rv=log_r_mu_intercept_rv, - autoreg_rt_rv=autoreg_rt_rv, - eta_sd_rv=eta_sd_rv, # sd of random walk for ar process, + i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], + initialization_rate_rv=priors["initialization_rate_rv"], + log_r_mu_intercept_rv=priors["log_r_mu_intercept_rv"], + autoreg_rt_rv=priors["autoreg_rt_rv"], + eta_sd_rv=priors["eta_sd_rv"], # sd of random walk for ar process, generation_interval_pmf_rv=generation_interval_pmf_rv, - infection_feedback_strength_rv=inf_feedback_strength_rv, + infection_feedback_strength_rv=priors["inf_feedback_strength_rv"], infection_feedback_pmf_rv=infection_feedback_pmf_rv, - p_hosp_mean_rv=p_hosp_mean_rv, - p_hosp_w_sd_rv=p_hosp_w_sd_rv, - autoreg_p_hosp_rv=autoreg_p_hosp_rv, - hosp_wday_effect_rv=hosp_wday_effect_rv, + p_hosp_mean_rv=priors["p_ed_visit_mean_rv"], + p_hosp_w_sd_rv=priors["p_ed_visit_w_sd_rv"], + 
autoreg_p_hosp_rv=priors["autoreg_p_ed_visit_rv"], + hosp_wday_effect_rv=priors["ed_visit_wday_effect_rv"], inf_to_hosp_rv=inf_to_hosp_rv, - phi_rv=phi_rv, + phi_rv=priors["phi_rv"], right_truncation_pmf_rv=right_truncation_pmf_rv, n_initialization_points=uot, ) - return my_model, data_observed_hospital_admissions, right_truncation_offset + return ( + my_model, + data_observed_disease_hospital_admissions, + right_truncation_offset, + ) diff --git a/pipelines/collate_plots.py b/pipelines/collate_plots.py new file mode 100644 index 00000000..8e4b39bf --- /dev/null +++ b/pipelines/collate_plots.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import os +from pathlib import Path + +from pypdf import PdfWriter +from utils import ensure_listlike, get_all_forecast_dirs + + +def merge_pdfs_and_save( + to_merge: list[Path | str], output_path: Path | str +) -> None: + """ + Merge an ordered list of PDF files and + write the result to a designated output path. + + Parameters + ---------- + to_merge + List of paths to the PDF files to merge. + + output_path + Where to write the merged PDF. + + Returns + ------- + None + """ + pdf_writer = PdfWriter() + for pdf_file in to_merge: + pdf_writer.append(pdf_file) + pdf_writer.write(output_path) + + return None + + +def merge_pdfs_from_subdirs( + base_dir: str | Path, + file_name: str, + save_dir: str | Path = None, + output_file_name: str = None, + subdirs_only: list[str] = None, + subdir_pattern="*", +) -> None: + """ + Find matching PDF files from a set of + subdirectories of a base directory + and merge them, writing the resulting + merged PDF file the base directory. + + Parameters + ---------- + base_dir + The base directory in which to save + the resultant merged PDF file. + + file_name + Name of the files to merge. Must be an + exact match. + + save_dir + Directory in which to save the merged PDF. + If ``None``, use a "figures" directory in the parent directory of ``base_dir``. 
+ Default ``None``. + + output_file_name + Name for the merged PDF file, which will be + saved within ``base_dir``. If ``None``, + use ``file_name``. Default ``None``. + + subdirs_only + Explicit list of subdirs to process. If + provided, process only subdirs found + within the ``base_dir`` that are named + in this list (and match the ``subdir_pattern``). + If ``None``, process all subdirs (provided + they match the ``subdir_pattern``). + Default ``None``. + + subdir_pattern + Unix-shell style wildcard pattern that + subdirectories must match to be included. + Default ``'*'`` (match everything). + See documentation for :func:`fnmatch.fnmatch` + for details. + + Returns + ------- + None + """ + + if save_dir is None: + save_dir = Path(base_dir).parent / "figures" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + subdirs = [ + f.name for f in Path(base_dir).glob(subdir_pattern) if f.is_dir() + ] + + if subdirs_only is not None: + subdirs = [s for s in subdirs if s in subdirs_only] + + to_merge = [ + Path(base_dir, subdir, file_name) + for subdir in subdirs + if os.path.exists(Path(base_dir, subdir, file_name)) + ] + + if output_file_name is None: + output_file_name = file_name + + if len(to_merge) > 0: + merge_pdfs_and_save(to_merge, Path(save_dir, output_file_name)) + + return None + + +def process_dir( + base_dir: Path | str, + target_filenames: str | list[str], + save_dir: Path | str = None, + file_prefix: str = "", + subdirs_only: list[str] = None, +) -> None: + """ + Merge groups of PDFs from the subdirectories of + a given base directory, saving the resulting + merged PDFs in the base directory. + + Parameters + ---------- + base_dir + Path to the base directory in which to look + + target_filenames + One or more PDFs filenames to look for in the + subdirectories and merge. + + save_dir + Directory in which to save the merged PDFs. + If ``None``, use a "figures" directory in the parent directory of ``base_dir``. Default ``None``. 
+ + file_prefix + Prefix to append to the names in `target_filenames` + when naming the merged files. + + subdirs_only + Only look for files to merge in these specific + named subdirectories. If ``None``, look in all + subdirectories of ``base_dir``. Default ``None``. + """ + if save_dir is None: + save_dir = Path(base_dir).parent / "figures" + + for file_name in ensure_listlike(target_filenames): + merge_pdfs_from_subdirs( + base_dir, + file_name, + save_dir, + output_file_name=file_prefix + file_name, + subdirs_only=subdirs_only, + ) + + +def collate_from_all_subdirs( + model_base_dir: str | Path, + disease: str, + target_filenames: str | list[str], + save_dir: str | Path = None, +) -> None: + """ + Collate target plots for a given disease + from a given base directory. + + Parameters + ---------- + model_base_dir + Path to the base directory in whose subdirectories + the script will look for PDFs to merge. + + disease + Name of the target disease. Merged PDFs will be named + with the disease as a prefix. + + target_filenames + One or more PDFs filenames to look for in the + subdirectories and merge. + + save_dir + Directory in which to save the merged PDFs. + If ``None``, use a "figures" directory in the parent directory of ``model_base_dir``. Default ``None``. + + Returns + ------- + None + """ + if save_dir is None: + save_dir = Path(model_base_dir).parent / "figures" + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + target_filenames = ensure_listlike(target_filenames) + + forecast_dirs = get_all_forecast_dirs(model_base_dir, diseases=disease) + + # first collate locations for a given date + logger.info( + "Collating plots across locations, by forecast date. " + f"{len(forecast_dirs)} dates to process." 
+ ) + for f_dir in forecast_dirs: + logger.info(f"Collating plots from {f_dir}") + process_dir( + base_dir=Path(model_base_dir, f_dir), + target_filenames=target_filenames, + save_dir=save_dir, + ) + logger.info("Done collating across locations by date.") + + # then collate dates, adding the disease name + # as a prefix for disambiguation since the + # top-level directory may contain forecasts + # for multiple diseases. + logger.info("Collating plots from forecast date directories...") + process_dir( + base_dir=model_base_dir, + target_filenames=target_filenames, + save_dir=save_dir, + file_prefix=f"{disease}_", + subdirs_only=forecast_dirs, + ) + + logger.info("Done collating plots from forecast date directories.") + + return None + + +def main( + dir_of_forecast_dirs: str | Path, + single_forecast_dir: str | Path, + target_filenames: list[str], + disease: str = None, +) -> None: + if not ((dir_of_forecast_dirs is None) ^ (single_forecast_dir is None)): + raise ValueError( + "Must provide exactly one of " + "'dir_of_forecast_dirs' (to process multiple " + "groups of forecasts) or " + "'single_forecast_dir' " + "(to process a single set of forecasts" + ) + elif dir_of_forecast_dirs is not None: + if disease is None: + raise ValueError( + "'disease' must not be None when collating plots " + "from multiple forecast subdirectories" + ) + collate_from_all_subdirs( + dir_of_forecast_dirs, disease, target_filenames + ) + elif single_forecast_dir is not None: + process_dir(single_forecast_dir, target_filenames) + return None + + +parser = argparse.ArgumentParser( + description=("Collate forecast plots from subdirectories into single PDFs") +) + +parser.add_argument( + "--dir-of-forecast-dirs", + type=Path, + help=( + "Base directory containing subdirectories that represent " + "individual forecast dates, each of which in turn has " + "subdirectories that represent individual location forecasts." 
+ ), + default=None, +) + +parser.add_argument( + "--single-forecast-dir", + type=Path, + help="Path to a single directory to process", + default=None, +) + +parser.add_argument( + "--disease", + type=str, + help="Name of the disease for which to collate plots.", + default=None, +) + +parser.add_argument( + "--target-filenames", + type=str, + default=( + "Disease_forecast_plot.pdf Other_forecast_plot.pdf " + "prop_disease_ed_visits_forecast_plot.pdf " + "Disease_forecast_plot_log.pdf Other_forecast_plot_log.pdf " + "prop_disease_ed_visits_forecast_plot_log.pdf" + ), + help=( + "Exact filenames of PDF files to find and merge, including " + "the file extension but without the directory path, as " + "a whitespace-separated string" + ), +) + +if __name__ == "__main__": + args = parser.parse_args() + args.target_filenames = args.target_filenames.split() + main(**vars(args)) diff --git a/pipelines/collate_score_tables.R b/pipelines/collate_score_tables.R new file mode 100644 index 00000000..9a6aaf3c --- /dev/null +++ b/pipelines/collate_score_tables.R @@ -0,0 +1,175 @@ +script_packages <- c( + "data.table", + "argparser", + "stringr" +) + +# load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + + +# Read `score_table.rds` from one model run directory and add the +# metadata parsed from the directory path as columns of both its +# quantile and sample score tables. Warns and returns NULL when the +# score table file does not exist. +process_loc_date_score_table <- function(model_run_dir) { + table_path <- fs::path(model_run_dir, + "score_table", + ext = "rds" + ) + parsed <- hewr::parse_model_run_dir_path(model_run_dir) + + if (!(fs::file_exists(table_path))) { + # only interpolate names that exist in this scope; glue() + # errors on an undefined name, which would turn this warning + # into a crash. 
+ warning(glue::glue( + "No `score_table.rds` found ", + "in directory {model_run_dir}" + )) + return(NULL) + } + score_table <- readr::read_rds(table_path) + + ## add parsed metadata to both quantile and sample + ## score tables + for (x in names(parsed)) { + score_table$quantile_scores[[x]] <- parsed[[x]] + score_table$sample_scores[[x]] <- parsed[[x]] + } + + return(score_table) +} + + + +bind_tables <-
function(list_of_table_pairs) { + sample_metrics <- purrr::map( + list_of_table_pairs, + \(x) { + attr(x$sample_scores, "metrics") + } + ) |> + unlist() |> + unique() + quantile_metrics <- purrr::map( + list_of_table_pairs, + \(x) { + attr(x$quantile_scores, "metrics") + } + ) |> + unlist() |> + unique() + + sample_scores <- purrr::map( + list_of_table_pairs, "sample_scores" + ) |> + data.table::rbindlist(fill = TRUE) + + quantile_scores <- purrr::map( + list_of_table_pairs, "quantile_scores" + ) |> + data.table::rbindlist(fill = TRUE) + + attr(sample_scores, "metrics") <- sample_metrics + attr(quantile_scores, "metrics") <- quantile_metrics + + + return(list( + sample_scores = sample_scores, + quantile_scores = quantile_scores + )) +} + +# Collate the per-location score tables found under the +# "model_runs" subdirectory of `model_run_dir` into one pair of +# tables, optionally saving the result inside `model_run_dir`. +collate_scores_for_date <- function(model_run_dir, + score_file_name = "score_table", + score_file_ext = "rds", + save = FALSE) { + message(glue::glue("Processing scores from {model_run_dir}...")) + # the second positional argument of fs::dir_ls() is `all`, not a + # path component, so the subdirectory must be joined explicitly. + locations_to_process <- fs::dir_ls( + fs::path(model_run_dir, "model_runs"), + type = "directory" + ) + date_score_table <- purrr::map( + locations_to_process, + process_loc_date_score_table + ) |> + bind_tables() + + if (save) { + save_path <- fs::path(model_run_dir, + score_file_name, + ext = score_file_ext + ) + message(glue::glue("Saving score table to {save_path}...")) + readr::write_rds(date_score_table, save_path) + } + message(glue::glue("Done processing scores for {model_run_dir}.")) + return(date_score_table) +} + + +# Collate score tables across all forecast dates for one disease. 
+# The combined table is written to `score_file_save_path` when that +# is not NULL; `save_batch_scores` controls whether each per-date +# table is also saved in its own directory. +collate_all_score_tables <- function(model_base_dir, + disease, + score_file_save_path = NULL, + save_batch_scores = FALSE) { + date_dirs_to_process <- hewr::get_all_model_batch_dirs( + model_base_dir, + diseases = disease + ) + + # collate scores across locations for each date + # NOTE(review): assumes the entries returned above are usable + # as directory paths as-is -- confirm against hewr. + date_score_tables <- purrr::map( + date_dirs_to_process, + \(x) { + collate_scores_for_date( + x, + save = save_batch_scores + ) + } + ) + + # get all dates, annotate, and combine + message( + "Combining date-specific score tables ", + "to create a full score table..."
+ ) + + full_score_table <- bind_tables(date_score_tables) + + if (!is.null(score_file_save_path)) { + message(glue::glue(paste0( + "Saving full score table to ", + "{score_file_save_path}..." + ))) + readr::write_rds(full_score_table, score_file_save_path) + } + + message("Done creating full score table.") + + return(full_score_table) +} + + +p <- arg_parser( + "Collate tables of scores into a single table across locations and dates." +) |> + add_argument( + "model_base_dir", + help = paste0( + "Base directory containing subdirectories that represent ", + "individual forecast dates, each of which in turn has ", + "subdirectories that represent individual location forecasts." + ) + ) |> + add_argument( + "disease", + help = paste0( + "Name of the disease for which to collate scores." + ) + ) + +argv <- parse_args(p) + +collate_all_score_tables( + argv$model_base_dir, + argv$disease, + score_file_save_path = fs::path( + argv$model_base_dir, + glue::glue("{argv$disease}_score_table"), + ext = "rds" + ), + save_batch_scores = TRUE +) diff --git a/pipelines/convert_inferencedata_to_parquet.R b/pipelines/convert_inferencedata_to_parquet.R new file mode 100644 index 00000000..39f4054a --- /dev/null +++ b/pipelines/convert_inferencedata_to_parquet.R @@ -0,0 +1,95 @@ +script_packages <- c( + "argparser", + "arrow", + "dplyr", + "forecasttools", + "fs", + "ggplot2", + "lubridate", + "readr", + "scoringutils", + "stringr", + "tidyr" +) + +## load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + + +tidy_and_save_mcmc <- function(model_run_dir, + file_name_prefix = "", + filter_bad_chains, + good_chain_tol) { + inference_data_path <- path(model_run_dir, "inference_data", ext = "csv") + + tidy_inference_data <- inference_data_path |> + read_csv(show_col_types = FALSE) |> + inferencedata_to_tidy_draws() + + if (filter_bad_chains) { + good_chains <- + 
tibble::deframe(tidy_inference_data)$log_likelihood |> + pivot_longer(-starts_with(".")) |>
+ group_by(.iteration, .chain) |> + summarize(value = sum(value), .groups = "drop") |> + group_by(.chain) |> + summarize(value = mean(value)) |> + # keep chains whose mean summed log likelihood is within + # `good_chain_tol` of the best chain's + filter(value >= max(value) - good_chain_tol) |> + pull(.chain) + } else { + good_chains <- + tibble::deframe(tidy_inference_data)$log_likelihood$.chain |> + unique() + } + + tidy_inference_data <- tidy_inference_data |> + mutate( + data = purrr::map(data, \(x) filter(x, .chain %in% good_chains)) + ) + + save_dir <- path(model_run_dir, "mcmc_tidy") + dir_create(save_dir) + + purrr::pwalk(tidy_inference_data, .f = function(group_name, data) { + write_parquet(data, path(save_dir, + str_c(file_name_prefix, group_name), + ext = "parquet" + )) + }) +} + + +p <- arg_parser("Tidy InferenceData to Parquet files") |> + add_argument( + "model_run_dir", + help = "Directory containing the model data and output.", + ) |> + add_argument( + "--no-filter-bad-chains", + help = paste0( + "By default, tidy_and_save_mcmc.R filters ", + "any bad chains from the samples. Set this flag ", + "to retain them" + ), + flag = TRUE + ) |> + add_argument( + "--good-chain-tol", + help = "Tolerance level for determining good chains.", + default = 2L + ) + +argv <- parse_args(p) +model_run_dir <- path(argv$model_run_dir) +filter_bad_chains <- !argv$no_filter_bad_chains +good_chain_tol <- argv$good_chain_tol + +tidy_and_save_mcmc(model_run_dir, + file_name_prefix = "pyrenew_", + filter_bad_chains, + good_chain_tol +) diff --git a/pipelines/create_hubverse_table.R b/pipelines/create_hubverse_table.R new file mode 100644 index 00000000..db724c39 --- /dev/null +++ b/pipelines/create_hubverse_table.R @@ -0,0 +1,49 @@ +#!/usr/bin/env Rscript + + +#' Create a hubverse table from model output, using +#' utilities from `hewr`. 
+#' +#' @param model_batch_dir Model batch directory from which +#' to create a hubverse table +#' @param output_path path to save the table as a tsv +#' @param exclude Locations to exclude, as a vector of strings.
+#' @return Nothing, saving the table as a side effect. +main <- function(model_batch_dir, + output_path, + exclude = NULL) { + hewr::to_epiweekly_quantile_table( + model_batch_dir, + exclude = exclude + ) |> + readr::write_tsv(output_path) +} + + +p <- argparser::arg_parser( + "Create a hubverse table from location specific forecast draws." +) |> + argparser::add_argument( + "model_batch_dir", + help = paste0( + "Directory containing subdirectories that represent ", + "individual forecast locations, with a directory name ", + "that indicates the target pathogen and reference date" + ) + ) |> + argparser::add_argument( + "output_path", + help = "path to which to save the table" + ) |> + argparser::add_argument( + "--exclude", + help = "locations to exclude, as a whitespace-separated string", + default = "" + ) +argv <- argparser::parse_args(p) + +main( + argv$model_batch_dir, + argv$output_path, + stringr::str_split_1(argv$exclude, " ") +) diff --git a/pipelines/default_priors.py b/pipelines/default_priors.py new file mode 100644 index 00000000..4f9d61ab --- /dev/null +++ b/pipelines/default_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1) +r_logsd = jnp.log(jnp.sqrt(2)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.04, 0.02, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + 
DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(50), jnp.log(2)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? + +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. +phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1)) diff --git a/pipelines/diagnostic_report/custom.scss b/pipelines/diagnostic_report/custom.scss new file mode 100644 index 00000000..b5d73809 --- /dev/null +++ b/pipelines/diagnostic_report/custom.scss @@ -0,0 +1,8 @@ +/*-- scss:rules --*/ +aside#bslib-sidebar-1 { + font-size: 1.1em; +} + +aside#bslib-sidebar-1 * { + font-size: 1.1em; +} diff --git a/pipelines/diagnostic_report/render_website.R b/pipelines/diagnostic_report/render_website.R new file mode 100644 index 00000000..035c405e --- /dev/null +++ b/pipelines/diagnostic_report/render_website.R @@ -0,0 +1,66 @@ +library(tidyverse) +library(fs) +library(quarto) + +base_dir <- path( + "~/pyrenew-hew", "nssp_demo", + "private_data", + "pyrenew-test-output", + "influenza_r_2024-11-12_f_2024-08-29_t_2024-11-11" +) +# parse this from CLI + +# The site should be contained in a single directory for easy linking between +# pages and sharing html files +site_output_dir <- path(base_dir, "diagnostic_report") +template_dir <- dir <- path("nssp_demo", "diagnostic_report") +css_file_name <- path("custom", ext = "scss") + 
+template_css_path <- path(template_dir, css_file_name) |> path_real() +template_qmd_path <- path(template_dir, "template", ext = "qmd") + + +wd_css <- tryCatch( + path_real(css_file_name), + error = function(e) { + message("An error occurred: ", e$message) + FALSE + } +) + +# Temporarily create template css in working directory +# otherwise quarto_render won't be able to find it +if (template_css_path != wd_css) { + file_copy(template_css_path, css_file_name, overwrite = TRUE) +} + + +quarto_render_tbl <- + tibble(state_dir = dir_ls(base_dir, type = "directory")) |> + mutate(qmd_path = path(site_output_dir, path_file(state_dir), ext = "qmd")) + +dir_create(site_output_dir) + +# Copy template with new file names to output directory +walk(quarto_render_tbl$qmd_path, function(x) { + file_copy(template_qmd_path, x, overwrite = TRUE) +}) + +# Render all qmd's +pwalk( + quarto_render_tbl, + function(state_dir, qmd_path) { + quarto_render( + input = qmd_path, + execute_params = list(model_dir_raw = state_dir) + ) + } +) + +# Delete qmd's +file_delete(quarto_render_tbl$qmd_path) + +# Clean up css in working directory +if (template_css_path != wd_css) { + file_delete(css_file_name) +} diff --git a/pipelines/diagnostic_report/template.qmd b/pipelines/diagnostic_report/template.qmd new file mode 100644 index 00000000..327ef4c2 --- /dev/null +++ b/pipelines/diagnostic_report/template.qmd @@ -0,0 +1,79 @@ +--- +title: "PyRenew-HEW Model Diagnostics" +format: + dashboard: + theme: + - flatly + - custom.scss + embed-resources: false +params: + model_dir_raw: "/home/xum8/pyrenew-hew/nssp_demo/private_data/pyrenew-test-output/influenza_r_2024-11-06_f_2024-08-18_t_2024-10-31/CA/" # pragma: allowlist-secret +--- + + +```{r Parse Params} +library(reticulate) +library(tidyverse) +library(fs) + +model_dir <- path(params$model_dir_raw) +this_state <- path_file(model_dir) +available_states <- model_dir %>% + path_dir() %>% + dir_ls(type = "directory") %>% + path_file() +``` + +## 
{.sidebar} + +```{r Render Sidebar} +#| output: asis +library(htmltools) +formatted_available_states <- + available_states %>% + map(\(x) a(x, href = path(x, ext = "html"))) %>% + map(p) %>% + map_chr(as.character) %>% + str_c(collapse = "") + +cat(formatted_available_states) +``` + +# Run Info + +::: {.card title="Model Metadata"} +This state is `r this_state`. +::: + +# Forecasts + +```{r Example Forecast} +#| title: A Forecast +library(cowplot) +theme_set(cowplot::theme_cowplot()) +eval_dat <- read_tsv(path(model_dir, "eval_data", ext = "tsv")) + +ggplot(eval_dat, aes(date, ed_visits, color = data_type)) + + facet_wrap(~disease, scale = "free_y") + + geom_point() +``` + +# MCMC Diagnostics + +```{python Load InferenceData} +from pathlib import Path +import polars as pl +from itables import to_html_datatable +import arviz as az + +model_dir = Path(r.params["model_dir_raw"]) +idata = az.from_netcdf(Path(model_dir, "inference_data.nc")) +idata_summary = az.summary(idata) +``` + +```{r Render InferenceData} +#| title: MCMC Summary +library(DT) +datatable(py$idata_summary) +# highlight or exclusively show bad rhats and ess? 
+``` diff --git a/pipelines/fit_model.py b/pipelines/fit_model.py new file mode 100644 index 00000000..9ecefd2a --- /dev/null +++ b/pipelines/fit_model.py @@ -0,0 +1,103 @@ +import argparse +import pickle +from pathlib import Path + +import jax +import numpy as np +from build_model import build_model_from_dir + + +def fit_and_save_model( + model_run_dir: str, + n_warmup: int = 1000, + n_samples: int = 1000, + n_chains: int = 4, + rng_key: int = None, +) -> None: + if rng_key is None: + rng_key = np.random.randint(0, 10000) + if isinstance(rng_key, int): + rng_key = jax.random.key(rng_key) + else: + raise ValueError( + "rng_key must be an integer with which " + "to seed :func:`jax.random.key`" + ) + ( + my_model, + data_observed_disease_hospital_admissions, + right_truncation_offset, + ) = build_model_from_dir(model_run_dir) + my_model.run( + num_warmup=n_warmup, + num_samples=n_samples, + rng_key=rng_key, + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + right_truncation_offset=right_truncation_offset, + mcmc_args=dict(num_chains=n_chains, progress_bar=True), + nuts_args=dict(find_heuristic_step_size=True), + ) + + my_model.mcmc.sampler = None + + with open( + model_run_dir / "posterior_samples.pickle", + "wb", + ) as file: + pickle.dump(my_model.mcmc, file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Fit the hospital-only wastewater model." + ) + parser.add_argument( + "model_run_dir", + type=Path, + help=( + "Path to a directory containing model fitting data. " + "The completed fit will be saved here." + ), + ) + parser.add_argument( + "--n-warmup", + type=int, + default=1000, + help=( + "Number of warmup iterations for the No-U-Turn sampler " + "(Default: 1000)." + ), + ) + parser.add_argument( + "--n-samples", + type=int, + default=1000, + help=( + "Number of sampling iterations after warmup " + "for the No-U-Turn sampler " + "(Default: 1000)." 
+ ), + ) + parser.add_argument( + "--n-chains", + type=int, + default=4, + help=("Number of duplicate MCMC chains to run " "(Default 4)."), + ) + parser.add_argument( + "--rng-key", + type=int, + default=None, + help=( + "Integer with which to seed the pseudorandom" + "number generator. If none is specified, a " + "pseudorandom seed will be drawn via " + "np.random.randint" + ), + ) + + args = parser.parse_args() + + fit_and_save_model(**vars(args)) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py new file mode 100644 index 00000000..5a4388f7 --- /dev/null +++ b/pipelines/forecast_state.py @@ -0,0 +1,419 @@ +import argparse +import logging +import os +import shutil +import subprocess +from datetime import datetime, timedelta +from pathlib import Path + +import numpyro +import polars as pl +from prep_data import process_and_save_state +from save_eval_data import save_eval_data + +numpyro.set_host_device_count(4) + +from fit_model import fit_and_save_model # noqa +from generate_predictive import generate_and_save_predictions # noqa + + +def baseline_forecasts( + model_run_dir: Path, n_forecast_days: int, n_samples: int +) -> None: + result = subprocess.run( + [ + "Rscript", + "pipelines/timeseries_forecasts.R", + f"{model_run_dir}", + "--n-forecast-days", + f"{n_forecast_days}", + "--n-samples", + f"{n_samples}", + ], + capture_output=True, + ) + if result.returncode != 0: + raise RuntimeError(f"baseline_forecasts: {result.stderr}") + return None + + +def convert_inferencedata_to_parquet(model_run_dir: Path) -> None: + result = subprocess.run( + [ + "Rscript", + "pipelines/convert_inferencedata_to_parquet.R", + f"{model_run_dir}", + ], + capture_output=True, + ) + if result.returncode != 0: + raise RuntimeError( + f"convert_inferencedata_to_parquet: {result.stderr}" + ) + return None + + +def postprocess_forecast(model_run_dir: Path) -> None: + result = subprocess.run( + [ + "Rscript", + "pipelines/postprocess_state_forecast.R", + 
f"{model_run_dir}", + ], + capture_output=True, + ) + if result.returncode != 0: + raise RuntimeError(f"postprocess_forecast: {result.stderr}") + return None + + +def score_forecast(model_run_dir: Path) -> None: + result = subprocess.run( + [ + "Rscript", + "pipelines/score_forecast.R", + f"{model_run_dir}", + ], + capture_output=True, + ) + if result.returncode != 0: + raise RuntimeError(f"score_forecast: {result.stderr}") + return None + + +def get_available_reports( + data_dir: str | Path, glob_pattern: str = "*.parquet" +): + return [ + datetime.strptime(f.stem, "%Y-%m-%d").date() + for f in Path(data_dir).glob(glob_pattern) + ] + + +def main( + disease: str, + report_date: str, + state: str, + facility_level_nssp_data_dir: Path | str, + state_level_nssp_data_dir: Path | str, + param_data_dir: Path | str, + priors_path: Path | str, + output_data_dir: Path | str, + n_training_days: int, + n_forecast_days: int, + n_chains: int, + n_warmup: int, + n_samples: int, + exclude_last_n_days: int = 0, + score: bool = False, + eval_data_path: Path = None, +): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + available_facility_level_reports = get_available_reports( + facility_level_nssp_data_dir + ) + + available_state_level_reports = get_available_reports( + state_level_nssp_data_dir + ) + first_available_state_report = min(available_state_level_reports) + last_available_state_report = max(available_state_level_reports) + + if report_date == "latest": + report_date = max(available_facility_level_reports) + else: + report_date = datetime.strptime(report_date, "%Y-%m-%d").date() + + if report_date in available_state_level_reports: + state_report_date = report_date + elif report_date > last_available_state_report: + state_report_date = last_available_state_report + elif report_date > first_available_state_report: + raise ValueError( + "Dataset appear to be missing some state-level " + f"reports. 
First entry is {first_available_state_report}, " + f"last is {last_available_state_report}, but no entry " + f"for {report_date}" + ) + else: + raise ValueError( + "Requested report date is earlier than the first " + "state-level vintage. This is not currently supported" + ) + + logger.info(f"Report date: {report_date}") + if state_report_date is not None: + logger.info(f"Using state-level data as of: {state_report_date}") + + # + 1 because max date in dataset is report_date - 1 + last_training_date = report_date - timedelta(days=exclude_last_n_days + 1) + + # Can only trigger when exclude_last_n_days is negative. + if last_training_date >= report_date: + raise ValueError( + "Last training date must be before the report date. " + f"Got a last training date of {last_training_date} " + f"with a report date of {report_date}." + ) + + logger.info(f"last training date: {last_training_date}") + + first_training_date = last_training_date - timedelta( + days=n_training_days - 1 + ) + + logger.info(f"First training date {first_training_date}") + + facility_level_nssp_data, state_level_nssp_data = None, None + + if report_date in available_facility_level_reports: + logger.info( + "Facility level data available for " "the given report date" + ) + facility_datafile = f"{report_date}.parquet" + facility_level_nssp_data = pl.scan_parquet( + Path(facility_level_nssp_data_dir, facility_datafile) + ) + if state_report_date in available_state_level_reports: + logger.info("State-level data available for the given report " "date.") + state_datafile = f"{state_report_date}.parquet" + state_level_nssp_data = pl.scan_parquet( + Path(state_level_nssp_data_dir, state_datafile) + ) + if facility_level_nssp_data is None and state_level_nssp_data is None: + raise ValueError( + "No data available for the requested report date " f"{report_date}" + ) + + param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet")) + model_batch_dir_name = ( + f"{disease.lower()}_r_{report_date}_f_" + f"{first_training_date}_t_{last_training_date}" + ) + +
model_batch_dir = Path(output_data_dir, model_batch_dir_name) + + model_run_dir = Path(model_batch_dir, "model_runs", state) + + os.makedirs(model_run_dir, exist_ok=True) + + logger.info(f"Using priors from {priors_path}...") + shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) + + logger.info(f"Processing {state}") + process_and_save_state( + state_abb=state, + disease=disease, + facility_level_nssp_data=facility_level_nssp_data, + state_level_nssp_data=state_level_nssp_data, + report_date=report_date, + state_level_report_date=state_report_date, + first_training_date=first_training_date, + last_training_date=last_training_date, + param_estimates=param_estimates, + model_run_dir=model_run_dir, + logger=logger, + ) + logger.info("Data preparation complete.") + + logger.info("Fitting model") + fit_and_save_model( + model_run_dir, + n_warmup=n_warmup, + n_samples=n_samples, + n_chains=n_chains, + ) + logger.info("Model fitting complete") + + logger.info("Performing posterior prediction / forecasting...") + + n_days_past_last_training = n_forecast_days + exclude_last_n_days + generate_and_save_predictions(model_run_dir, n_days_past_last_training) + + logger.info( + "Performing baseline forecasting and non-target pathogen " + "forecasting..." 
+ ) + n_denominator_samples = n_samples * n_chains + baseline_forecasts( + model_run_dir, n_days_past_last_training, n_denominator_samples + ) + logger.info("Forecasting complete.") + logger.info("Getting eval data...") + if eval_data_path is None: + raise ValueError("No path to an evaluation dataset provided.") + save_eval_data( + state=state, + report_date=report_date, + disease=disease, + first_training_date=first_training_date, + last_training_date=last_training_date, + latest_comprehensive_path=eval_data_path, + output_data_dir=model_run_dir, + last_eval_date=report_date + timedelta(days=n_forecast_days), + ) + + logger.info("Converting inferencedata to parquet...") + convert_inferencedata_to_parquet(model_run_dir) + logger.info("Conversion complete.") + + logger.info("Postprocessing forecast...") + postprocess_forecast(model_run_dir) + logger.info("Postprocessing complete.") + + if score: + logger.info("Scoring forecast...") + score_forecast(model_run_dir) + + logger.info( + "Single state pipeline complete " + f"for state {state} with " + f"report date {report_date}." + ) + return None + + +parser = argparse.ArgumentParser( + description="Create fit data for disease modeling." +) +parser.add_argument( + "--disease", + type=str, + required=True, + help="Disease to model (e.g., COVID-19, Influenza, RSV).", +) + +parser.add_argument( + "--state", + type=str, + required=True, + help=( + "Two letter abbreviation for the state to fit" + "(e.g. 'AK', 'AL', 'AZ', etc.)." 
+ ), +) + +parser.add_argument( + "--report-date", + type=str, + default="latest", + help="Report date in YYYY-MM-DD format or latest (default: latest).", +) + +parser.add_argument( + "--facility-level-nssp-data-dir", + type=Path, + default=Path("private_data", "nssp_etl_gold"), + help=( + "Directory in which to look for facility-level NSSP " "ED visit data" + ), +) + +parser.add_argument( + "--state-level-nssp-data-dir", + type=Path, + default=Path("private_data", "nssp_state_level_gold"), + help=("Directory in which to look for state-level NSSP " "ED visit data."), +) + +parser.add_argument( + "--param-data-dir", + type=Path, + default=Path("private_data", "prod_param_estimates"), + help=( + "Directory in which to look for parameter estimates" + "such as delay PMFs." + ), + required=True, +) + +parser.add_argument( + "--priors-path", + type=Path, + help=( + "Path to an executible python file defining random variables " + "that require priors as pyrenew RandomVariable objects." + ), + required=True, +) + + +parser.add_argument( + "--output-data-dir", + type=Path, + default="private_data", + help="Directory in which to save output data.", +) + +parser.add_argument( + "--n-training-days", + type=int, + default=180, + help="Number of training days (default: 180).", +) + +parser.add_argument( + "--n-forecast-days", + type=int, + default=28, + # help must be a single string; a tuple breaks --help rendering + help=( + "Number of days ahead to forecast relative to the " + "report date (default: 28)." + ), +) + + +parser.add_argument( + "--n-chains", + type=int, + default=4, + help="Number of MCMC chains to run (default: 4).", +) + +parser.add_argument( + "--n-warmup", + type=int, + default=1000, + help=("Number of warmup iterations per chain for NUTS" "(default: 1000)."), +) + +parser.add_argument( + "--n-samples", + type=int, + default=1000, + help=( + "Number of posterior samples to draw per " + "chain using NUTS (default: 1000)."
+ ), +) + +parser.add_argument( + "--exclude-last-n-days", + type=int, + default=0, + help=( + "Optionally exclude the final n days of available training " + "data (Default: 0, i.e. exclude no available data" + ), +) + +# BooleanOptionalAction generates --score / --no-score itself; it takes no `type` +parser.add_argument( + "--score", + default=False, + action=argparse.BooleanOptionalAction, + help=("If this flag is provided, will attempt to score the forecast."), +) + + +parser.add_argument( + "--eval-data-path", + type=Path, + help=("Path to a parquet file containing compehensive truth data."), +) + + +if __name__ == "__main__": + args = parser.parse_args() + numpyro.set_host_device_count(args.n_chains) + main(**vars(args)) diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py new file mode 100644 index 00000000..5a0d82d3 --- /dev/null +++ b/pipelines/generate_predictive.py @@ -0,0 +1,72 @@ +import argparse +import pickle +from pathlib import Path + +import arviz as az +from build_model import build_model_from_dir + + +def generate_and_save_predictions( + model_run_dir: str | Path, n_forecast_points: int +) -> None: + model_run_dir = Path(model_run_dir) + + ( + my_model, + data_observed_disease_hospital_admissions, + right_truncation_offset, + ) = build_model_from_dir(model_run_dir) + + my_model._init_model(1, 1) + fresh_sampler = my_model.mcmc.sampler + + with open( + model_run_dir / "posterior_samples.pickle", + "rb", + ) as file: + my_model.mcmc = pickle.load(file) + + my_model.mcmc.sampler = fresh_sampler + + posterior_predictive = my_model.posterior_predictive( + n_datapoints=len(data_observed_disease_hospital_admissions) + + n_forecast_points + ) + + idata = az.from_numpyro( + my_model.mcmc, + posterior_predictive=posterior_predictive, + ) + + idata.to_dataframe().to_csv( + model_run_dir / "inference_data.csv", index=False + ) + + # Save one netcdf for reloading + idata.to_netcdf(model_run_dir / "inference_data.nc") + + return None + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(
description=("Do posterior prediction from a pyrenew-hew fit.") + ) + parser.add_argument( + "model_run_dir", + type=Path, + help=( + "Path to a directory containing the model fitting data " + "and the posterior chains. " + "The completed predictive samples will be saved here." + ), + ) + parser.add_argument( + "--n-forecast-points", + type=int, + default=0, + help="Number of time points to forecast (Default: 0).", + ) + args = parser.parse_args() + + generate_and_save_predictions(**vars(args)) diff --git a/nssp_demo/fit_all_models.sh b/pipelines/iteration_helpers/loop_fit.sh similarity index 89% rename from nssp_demo/fit_all_models.sh rename to pipelines/iteration_helpers/loop_fit.sh index cf622600..837d5cf0 100755 --- a/nssp_demo/fit_all_models.sh +++ b/pipelines/iteration_helpers/loop_fit.sh @@ -13,5 +13,5 @@ BASE_DIR="$1" for SUBDIR in "$BASE_DIR"/*/; do # Run the Python script with the current subdirectory as the model_dir argument echo "$SUBDIR" - python fit_model.py --model_dir "$SUBDIR" + python fit_model.py "$SUBDIR" done diff --git a/nssp_demo/generate_all_predictive.sh b/pipelines/iteration_helpers/loop_generate_predictive.sh similarity index 83% rename from nssp_demo/generate_all_predictive.sh rename to pipelines/iteration_helpers/loop_generate_predictive.sh index 7c0aac22..b2510d21 100755 --- a/nssp_demo/generate_all_predictive.sh +++ b/pipelines/iteration_helpers/loop_generate_predictive.sh @@ -14,5 +14,5 @@ BASE_DIR="$1" for SUBDIR in "$BASE_DIR"/*/; do # Run the Python script with the current subdirectory as the model_dir argument echo "$SUBDIR" - python generate_predictive.py --model_dir "$SUBDIR" --n_forecast_points 28 + python generate_predictive.py "$SUBDIR" --n-forecast-points 28 done diff --git a/pipelines/iteration_helpers/loop_postprocess.sh b/pipelines/iteration_helpers/loop_postprocess.sh new file mode 100755 index 00000000..22603452 --- /dev/null +++ b/pipelines/iteration_helpers/loop_postprocess.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# 
Check if the base directory is provided as an argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +# Base directory containing subdirectories +BASE_DIR="$1" + +# Iterate over each subdirectory in the base directory +for SUBDIR in "$BASE_DIR"/*/; do + # Run the R script with the current subdirectory as the model_dir argument + echo "$SUBDIR" + Rscript convert_inferencedata_to_parquet.R "$SUBDIR" + Rscript postprocess_state_forecast.R "$SUBDIR" +done diff --git a/pipelines/iteration_helpers/loop_score.sh b/pipelines/iteration_helpers/loop_score.sh new file mode 100755 index 00000000..6ead92cb --- /dev/null +++ b/pipelines/iteration_helpers/loop_score.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Check if the base directory is provided as an argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +# Base directory containing subdirectories +BASE_DIR="$1" + +# Iterate over each subdirectory in the base directory +for SUBDIR in "$BASE_DIR"/*/; do + # Run the R script with the current subdirectory as the model_dir argument + echo "$SUBDIR" + Rscript score_forecast.R "$SUBDIR" +done diff --git a/pipelines/make_observed_data_table.py b/pipelines/make_observed_data_table.py new file mode 100644 index 00000000..6e30d91d --- /dev/null +++ b/pipelines/make_observed_data_table.py @@ -0,0 +1,72 @@ +import argparse +import datetime +from pathlib import Path + +import polars as pl +from prep_data import aggregate_facility_level_nssp_to_state, get_state_pop_df + + +def save_observed_data_table( + path_to_latest_data: str | Path, output_path: str | Path +): + data = pl.scan_parquet(path_to_latest_data) + + state_pop = get_state_pop_df() + + visits_by_disease = [ + pl.concat( + map( + lambda abb: aggregate_facility_level_nssp_to_state( + facility_level_nssp_data=data, + state_abb=abb, + disease=disease, + first_training_date=datetime.date(2023, 1, 1), + state_pop_df=state_pop, + ), + ["US"] + [x for x in state_pop["abb"]], + ) + ).filter(pl.col("disease") == disease) 
+ for disease in ["COVID-19", "Influenza", "RSV", "Total"] + ] + + full_table = ( + pl.concat(visits_by_disease) + .pivot(on="disease", values="ed_visits") + .select( + date=pl.col("date"), + location=pl.col("geo_value"), + count_covid=pl.col("COVID-19"), + frac_covid=pl.col("COVID-19") / pl.col("Total"), + pct_covid=100 * pl.col("COVID-19") / pl.col("Total"), + count_influenza=pl.col("Influenza"), + frac_influenza=pl.col("Influenza") / pl.col("Total"), + pct_influenza=100 * pl.col("Influenza") / pl.col("Total"), + count_rsv=pl.col("RSV"), + frac_rsv=pl.col("RSV") / pl.col("Total"), + pct_rsv=100 * pl.col("RSV") / pl.col("Total"), + count_total=pl.col("Total"), + ) + .sort(["location", "date"]) + ) + + full_table.write_csv(output_path, separator="\t") + + +parser = argparse.ArgumentParser() + +parser.add_argument( + "path_to_latest_data", + type=Path, + help=( + "Path to a parquet file containing the latest " + "ED visit observations." + ), +) + +parser.add_argument( + "output_path", type=Path, help="Save the output tsv file to this path." +) + +if __name__ == "__main__": + args = parser.parse_args() + save_observed_data_table(**vars(args)) diff --git a/pipelines/postprocess_scoring.R b/pipelines/postprocess_scoring.R new file mode 100644 index 00000000..ae2407f2 --- /dev/null +++ b/pipelines/postprocess_scoring.R @@ -0,0 +1,406 @@ +script_packages <- c( + "dplyr", + "scoringutils", + "lubridate", + "ggplot2", + "argparser" +) + +## load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + + +#' Summarise Scoring Table using quantile scores +#' +#' This function takes a scoring table and summarises it by calculating both +#' relative and absolute Weighted Interval Scores (WIS) for each model. 
+#' The relative WIS is computed by comparing each model to a baseline model +#' ("cdc_baseline"), while the absolute WIS, median absolute error (MAE), +#' and interval coverages (50% and 90%) are directly summarised from the +#' scoring table. +#' +#' @param quantile_scores A scoring object containing the scoring table with +#' quantile scores. +#' @param scale A string specifying the scale to filter the quantile +#' scores. Default is "natural". +#' +#' @return A data frame with summarised scores for each model, including +#' relative WIS, absolute WIS, MAE, and interval coverages (50% and 90%). +#' +#' @examples +#' # Assuming `quantile_scores` is a data frame with the necessary structure: +#' summarised_scores <- summarised_scoring_table( +#' quantile_scores, +#' scale = "natural" +#' ) +#' print(summarised_scores) +summarised_scoring_table <- function(quantile_scores, + scale = "natural") { + rel_wis <- quantile_scores |> + filter(scale == !!scale) |> + get_pairwise_comparisons(baseline = "cdc_baseline") |> + group_by(model) |> + summarise(rel_wis = mean(wis_scaled_relative_skill)) + + abs_wis <- quantile_scores |> + filter(scale == !!scale) |> + summarise_scores(by = "model") |> + select(model, + abs_wis = wis, + mae = ae_median, + interval_coverage_50, + interval_coverage_90 + ) + + summarised_scores <- left_join(abs_wis, rel_wis, by = "model") + return(summarised_scores) +} + + +location_summary_table <- function(quantile_scores, + scale = "natural") { + rel_wis <- quantile_scores |> + filter(scale == !!scale) |> + get_pairwise_comparisons( + baseline = "cdc_baseline", + by = "location" + ) |> + group_by(model, location) |> + summarise(rel_wis = mean(wis_scaled_relative_skill)) + + abs_wis <- quantile_scores |> + filter(scale == !!scale) |> + summarise_scores(by = c("model", "location")) |> + select(model, + location, + abs_wis = wis, + mae = ae_median, + interval_coverage_50, + interval_coverage_90 + ) + + summarised_scores <- left_join(abs_wis, rel_wis, + 
by = c("model", "location") + ) + return(summarised_scores) +} + +#' @title Epiweekly Scoring Plot +#' This function generates a line plot of the weighted interval scores (WIS) for +#' different models over epiweeks. +#' @param quantile_scores A scoringutils output table +#' containing the unsummarised quantile scores. +#' @param scale A character string specifying the scale +#' to filter the quantile +#' scores. Default is "natural". +#' @return A ggplot object representing the epiweekly scoring plot. +epiweekly_scoring_plot <- function(quantile_scores, scale = "natural") { + epiweekly_score_fig <- quantile_scores |> + filter(scale == !!scale) |> + mutate(epiweek = epiweek(date), epiyear = epiyear(date)) |> + get_pairwise_comparisons( + by = c("epiweek", "epiyear"), + baseline = "cdc_baseline" + ) |> + mutate(epidate = forecasttools::epiweek_to_date( + epiweek, + epiyear, + day_of_week = 1 + )) |> + group_by(model, epidate) |> + summarise( + wis = mean(wis_scaled_relative_skill), + .groups = "drop" + ) |> + as_tibble() |> + ggplot(aes(x = epidate, y = wis, color = model)) + + geom_line() + + geom_point() + # Add points to the line plot + labs( + title = "Epiweekly Scoring by Model", + x = "Epiweek start dates", + y = "Relative Weighted Interval Score (WIS)" + ) + + scale_y_continuous(trans = "log10") + + theme_minimal() + + return(epiweekly_score_fig) +} + +location_rel_wis_plot <- function(location, quantile_scores, ...) { + return(epiweekly_scoring_plot( + quantile_scores |> + dplyr::filter(location == !!location), + ... + ) + ggtitle( + glue::glue("Relative WIS over time for {location}") + )) +} + +location_score_table <- function(location, quantile_scores, ...) 
{ + return +} + +#' Save a list of plots as a PDF, with a +#' grid of `nrow` by `ncol` plots per page +#' +#' @param list_of_plots list of plots to save to PDF +#' @param save_path path to which to save the plots +#' @param nrow Number of rows of plots per page +#' (passed to [gridExtra::marrangeGrob()]) +#' Default `1`. +#' @param ncol Number of columns of plots per page +#' (passed to [gridExtra::marrangeGrob()]). +#' Default `1`. +#' @param width page width in device units (passed to +#' [ggplot2::ggsave()]). Default `8.5`. +#' @param height page height in device units (passed to +#' [ggplot2::ggsave()]). Default `11`. +#' @return `TRUE` on success. +#' @export +plots_to_pdf <- function(list_of_plots, + save_path, + nrow = 1, + ncol = 1, + width = 8.5, + height = 11) { + # fs::path_ext() returns the extension without a leading dot + if (!(tolower(fs::path_ext(save_path)) == "pdf")) { + cli::cli_abort("Filepath must end with `.pdf`") + } + cli::cli_inform("Saving plots to {save_path}") + ggplot2::ggsave( + filename = save_path, + plot = gridExtra::marrangeGrob(list_of_plots, + nrow = nrow, + ncol = ncol + ), + width = width, + height = height + ) + return(TRUE) +} + +relative_wis_by_location <- function(scores, + baseline_model = "cdc_baseline") { + scoring_data <- scores |> + get_pairwise_comparisons( + by = c("date", "location"), + baseline = baseline_model + ) |> + group_by(model, location) |> + summarise( + relative_wis = mean(wis_scaled_relative_skill), + .groups = "drop" + ) |> + filter(model == "pyrenew-hew") + + min_wis <- min(scoring_data$relative_wis) + max_wis <- max(scoring_data$relative_wis) + max_overall <- max(1 / min_wis, max_wis) + theme_minimal() + + + fig <- scoring_data |> + arrange(relative_wis) |> + mutate(location = factor(location, + ordered = TRUE, + levels = location + )) |> + ggplot( + aes( + y = location, + x = relative_wis, + group = model + ) + ) + + geom_point( + shape = 21, + size = 3, + fill = "darkblue", + color = "black" + ) + + geom_vline( + xintercept = 1, + linetype = "dashed" + ) + +
scale_x_continuous(trans = "log10") + + coord_cartesian(xlim = c(1 / max_overall, max_overall)) + + theme_minimal() + + return(fig) +} + + +main <- function(path_to_scores, + output_directory, + output_prefix = "") { + get_save_path <- function(filename, ext = "pdf") { + fs::path(output_directory, + glue::glue("{output_prefix}{filename}"), + ext = ext + ) + } + + scores <- readRDS(path_to_scores) + + quantile_scores <- scores$quantile_scores + + locations <- unique(quantile_scores$location) |> + purrr::set_names() + + message("Plotting relative WIS by forecast date across locations...") + + rel_wis_by_date <- epiweekly_scoring_plot( + quantile_scores, + scale = "log" + ) + + rel_wis_by_date_save_path <- get_save_path("relative_wis_by_date") + + message(glue::glue("Saving figure to {rel_wis_by_date_save_path}...")) + ggsave(rel_wis_by_date_save_path, + rel_wis_by_date, + width = 8, + height = 4 + ) + + + message("Plotting relative WIS by forecast date and location...") + rel_wis_by_date_and_location <- purrr::map(locations, + location_rel_wis_plot, + quantile_scores = + quantile_scores, + scale = "log" + ) + + rel_wis_by_date_loc_save_path <- get_save_path( + "relative_wis_by_date_and_location" + ) + + message( + glue::glue("Saving figure to {rel_wis_by_date_loc_save_path}...") + ) + + plots_to_pdf(rel_wis_by_date_and_location, + rel_wis_by_date_loc_save_path, + width = 8, + height = 4 + ) + + message("Plotting WIS components by location for pyrenew-hew...") + wis_components_by_location <- + scoringutils::plot_wis( + quantile_scores |> + filter(model == "pyrenew-hew"), + x = "location" + ) + wis_comp_by_loc_save_path <- get_save_path( + "wis_components_by_location", + ext = "png" + ) + ## scoringutils wis component plots do not save well + ## as vector images on some devices, so we rasterize + ## them to PNGs + ggsave( + wis_comp_by_loc_save_path, + wis_components_by_location + ) + + wis_components_by_model <- + scoringutils::plot_wis(quantile_scores, + x = "model" 
+ ) + + wis_comp_by_model_save_path <- get_save_path( + "wis_components_by_model", + ext = "png" + ) + ggsave( + wis_comp_by_model_save_path, + wis_components_by_model + ) + + message("Plotting relative WIS across dates by location") + rel_wis_by_location <- relative_wis_by_location( + quantile_scores + ) + + rel_wis_by_location_save_path <- get_save_path( + "relative_wis_by_location" + ) + + ggsave(rel_wis_by_location_save_path, + rel_wis_by_location, + height = 10, + width = 4 + ) + + message("Making tables...") + table_all <- summarised_scoring_table( + quantile_scores, + scale = "log" + ) + + table_all_save_path <- get_save_path( + "overall_scores", + ext = "tsv" + ) + readr::write_tsv(table_all, table_all_save_path) + + + table_locs <- location_summary_table(quantile_scores, + scale = "log" + ) + + table_locs_save_path <- get_save_path( + "scores_by_location", + ext = "tsv" + ) + readr::write_tsv(table_locs, table_locs_save_path) + + message("Done with score postprocessing.") +} + +p <- arg_parser(paste0( + "Postprocess a raw score table, creating summary plots ", + "and tables." +)) |> + add_argument( + "path_to_scores", + help = paste0( + "Path to a file holding all scores, as an .rds ", + "file, in the list of scoringutils objects output ", + "format of collate_score_tables.R" + ) + ) |> + add_argument("--output-directory", + help = paste0( + "Output directory in which to save the ", + "generated score plots and tables. ", + "Default '.', i.e. the current working ", + "directory" + ), + default = "." + ) |> + add_argument("--output-prefix", + help = paste0( + "Prefix to append to output file names, e.g. ", + "the name(s) of the target disease(s) ", + "and/or epidemiological signal(s). 
", + "Default '' (no prefix)" + ), + default = "" + ) + + +argv <- parse_args(p) + +main( + argv$path_to_scores, + argv$output_directory, + argv$output_prefix +) diff --git a/pipelines/postprocess_state_forecast.R b/pipelines/postprocess_state_forecast.R new file mode 100644 index 00000000..2e023322 --- /dev/null +++ b/pipelines/postprocess_state_forecast.R @@ -0,0 +1,275 @@ +script_packages <- c( + "dplyr", + "stringr", + "purrr", + "ggplot2", + "tidybayes", + "fs", + "cowplot", + "glue", + "scales", + "argparser", + "arrow", + "tidyr", + "readr", + "here", + "forcats", + "hewr" +) + +## load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + + +make_one_forecast_fig <- function(target_disease, + combined_dat, + last_training_date, + data_vintage_date, + posterior_predictive_ci, + state_abb, + y_transform = "identity") { + y_scale <- if (str_starts(target_disease, "prop")) { + scale_y_continuous("Proportion of Emergency Department Visits", + labels = percent, + transform = y_transform + ) + } else { + scale_y_continuous("Emergency Department Visits", + labels = comma, + transform = y_transform + ) + } + + title <- if (target_disease == "Other") { + glue("Other ED Visits in {state_abb}") + } else { + glue("{disease_name_pretty} ED Visits in {state_abb}") + } + + ggplot(mapping = aes(date, .value)) + + geom_lineribbon( + data = posterior_predictive_ci |> filter(disease == target_disease), + mapping = aes(ymin = .lower, ymax = .upper), + color = "#08519c", + key_glyph = draw_key_rect, + step = "mid" + ) + + scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ percent(as.numeric(.)) + ) + + geom_point( + mapping = aes(color = data_type), size = 1.5, + data = combined_dat |> + filter( + disease == target_disease, + date <= max(posterior_predictive_ci$date) + ) |> + mutate(data_type = fct_rev(data_type)) |> + arrange(desc(data_type)) + ) + + 
scale_color_manual( + name = "Data Type", + values = c("olivedrab1", "deeppink"), + labels = str_to_title + ) + + geom_vline(xintercept = last_training_date, linetype = "dashed") + + annotate( + geom = "text", + x = last_training_date, + y = -Inf, + label = "Fit Period ←\n", + hjust = "right", + vjust = "bottom" + ) + + annotate( + geom = "text", + x = last_training_date, + y = -Inf, label = "→ Forecast Period\n", + hjust = "left", + vjust = "bottom", + ) + + ggtitle(title, subtitle = glue("as of {data_vintage_date}")) + + y_scale + + scale_x_date("Date") + + theme(legend.position = "bottom") +} + + +postprocess_state_forecast <- function(model_run_dir) { + state_abb <- model_run_dir |> + path_split() |> + pluck(1) |> + tail(1) + + train_data_path <- path(model_run_dir, "data", ext = "csv") + eval_data_path <- path(model_run_dir, "eval_data", ext = "tsv") + posterior_predictive_path <- path(model_run_dir, "mcmc_tidy", + "pyrenew_posterior_predictive", + ext = "parquet" + ) + other_ed_visits_path <- path( + model_run_dir, + "other_ed_visits_forecast", + ext = "parquet" + ) + + train_dat <- read_csv(train_data_path, show_col_types = FALSE) + + data_vintage_date <- max(train_dat$date) + 1 + # this should be stored as metadata somewhere else, instead of being + # computed like this + + eval_dat <- read_tsv(eval_data_path, show_col_types = FALSE) |> + mutate(data_type = "eval") + + combined_dat <- + bind_rows( + train_dat |> + filter(data_type == "train"), + eval_dat + ) |> + mutate( + disease = if_else( + disease == disease_name_nssp, + "Disease", # assign a common name for + # use in plotting functions + disease + ) + ) |> + pivot_wider(names_from = disease, values_from = ed_visits) |> + mutate( + Other = Total - Disease, + prop_disease_ed_visits = Disease / Total + ) |> + select(-Total) |> + mutate(time = dense_rank(date)) |> + pivot_longer(c(Disease, Other, prop_disease_ed_visits), + names_to = "disease", + values_to = ".value" + ) + + + last_training_date <- 
combined_dat |> + filter(data_type == "train") |> + pull(date) |> + max() + + posterior_predictive <- read_parquet(posterior_predictive_path) + + other_ed_visits_forecast <- + read_parquet(other_ed_visits_path) |> + rename(Other = other_ed_visits) + + other_ed_visits_samples <- + bind_rows( + combined_dat |> + filter( + data_type == "train", + disease == "Other", + date <= last_training_date + ) |> + select(date, Other = .value) |> + expand_grid(.draw = 1:max(other_ed_visits_forecast$.draw)), + other_ed_visits_forecast + ) + + posterior_predictive_samples <- + posterior_predictive |> + gather_draws(observed_hospital_admissions[time]) |> + pivot_wider(names_from = .variable, values_from = .value) |> + rename(Disease = observed_hospital_admissions) |> + ungroup() |> + mutate(date = min(combined_dat$date) + time) |> + left_join(other_ed_visits_samples, + by = c(".draw", "date") + ) |> + mutate(prop_disease_ed_visits = Disease / (Disease + Other)) |> + pivot_longer(c(Other, Disease, prop_disease_ed_visits), + names_to = "disease", + values_to = ".value" + ) + + arrow::write_parquet( + posterior_predictive_samples, + path(model_run_dir, "forecast_samples", + ext = "parquet" + ) + ) + + posterior_predictive_ci <- + posterior_predictive_samples |> + select(date, disease, .value) |> + group_by(date, disease) |> + median_qi(.width = c(0.5, 0.8, 0.95)) + + + arrow::write_parquet( + posterior_predictive_ci, + path(model_run_dir, "forecast_ci", + ext = "parquet" + ) + ) + + + all_forecast_plots <- map( + set_names(unique(combined_dat$disease)), + ~ make_one_forecast_fig( + .x, + combined_dat, + last_training_date, + data_vintage_date, + posterior_predictive_ci, + state_abb, + ) + ) + + all_forecast_plots_log <- map( + set_names(unique(combined_dat$disease)), + ~ make_one_forecast_fig( + .x, + combined_dat, + last_training_date, + data_vintage_date, + posterior_predictive_ci, + state_abb, + y_transform = "log10" + ) + ) + + iwalk(all_forecast_plots, ~ save_plot( + filename = 
path(model_run_dir, glue("{.y}_forecast_plot"), ext = "pdf"), + plot = .x, + device = cairo_pdf, base_height = 6 + )) + iwalk(all_forecast_plots_log, ~ save_plot( + filename = path(model_run_dir, glue("{.y}_forecast_plot_log"), ext = "pdf"), + plot = .x, + device = cairo_pdf, base_height = 6 + )) +} + + +theme_set(theme_minimal_grid()) + +# Create a parser +p <- arg_parser("Generate forecast figures") |> + add_argument( + "model_run_dir", + help = "Directory containing the model data and output.", + ) + +argv <- parse_args(p) +model_run_dir <- path(argv$model_run_dir) + + +disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease + +disease_name_formatter <- c("COVID-19" = "COVID-19", "Influenza" = "Flu") +disease_name_pretty <- unname(disease_name_formatter[disease_name_nssp]) + +postprocess_state_forecast(model_run_dir) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py new file mode 100644 index 00000000..ed815856 --- /dev/null +++ b/pipelines/prep_data.py @@ -0,0 +1,347 @@ +import datetime +import json +import logging +import os +from logging import Logger +from pathlib import Path + +import polars as pl + +_disease_map = { + "COVID-19": "COVID-19/Omicron", +} + +_inverse_disease_map = {v: k for k, v in _disease_map.items()} + + +def aggregate_to_national( + data: pl.LazyFrame, + geo_values_to_include, + first_date_to_include: datetime.date, + national_geo_value="US", +): + assert national_geo_value not in geo_values_to_include + return ( + data.filter( + pl.col("geo_value").is_in(geo_values_to_include), + pl.col("reference_date") >= first_date_to_include, + ) + .group_by(["disease", "metric", "geo_type", "reference_date"]) + .agg(geo_value=pl.lit(national_geo_value), value=pl.col("value").sum()) + ) + + +def process_state_level_data( + state_level_nssp_data: pl.LazyFrame, + state_abb: str, + disease: str, + first_training_date: datetime.date, + state_pop_df: pl.DataFrame, +) -> pl.DataFrame: + logging.basicConfig(level=logging.INFO) + 
logger = logging.getLogger(__name__) + + if state_level_nssp_data is None: + return pl.DataFrame( + schema={ + "date": pl.Date, + "geo_value": pl.Utf8, + "disease": pl.Utf8, + "ed_visits": pl.Float64, + } + ) + + disease_key = _disease_map.get(disease, disease) + + if state_abb == "US": + logger.info("Aggregating state-level data to national") + state_level_nssp_data = aggregate_to_national( + state_level_nssp_data, + state_pop_df["abb"].unique(), + first_training_date, + national_geo_value="US", + ) + + return ( + state_level_nssp_data.filter( + pl.col("disease").is_in([disease_key, "Total"]), + pl.col("metric") == "count_ed_visits", + pl.col("geo_value") == state_abb, + pl.col("geo_type") == "state", + pl.col("reference_date") >= first_training_date, + ) + .select( + [ + pl.col("reference_date").alias("date"), + pl.col("geo_value").cast(pl.Utf8), + pl.col("disease").cast(pl.Utf8), + pl.col("value").alias("ed_visits"), + ] + ) + .with_columns( + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), + ) + .sort(["date", "disease"]) + .collect() + ) + + +def aggregate_facility_level_nssp_to_state( + facility_level_nssp_data: pl.LazyFrame, + state_abb: str, + disease: str, + first_training_date: str, + state_pop_df: pl.DataFrame, +) -> pl.DataFrame: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + if facility_level_nssp_data is None: + return pl.DataFrame( + schema={ + "date": pl.Date, + "geo_value": pl.Utf8, + "disease": pl.Utf8, + "ed_visits": pl.Float64, + } + ) + + disease_key = _disease_map.get(disease, disease) + + if state_abb == "US": + logger.info("Aggregating facility-level data to national") + facility_level_nssp_data = aggregate_to_national( + facility_level_nssp_data, + state_pop_df["abb"].unique(), + first_training_date, + national_geo_value="US", + ) + + return ( + facility_level_nssp_data.filter( + pl.col("disease").is_in([disease_key, "Total"]), + pl.col("metric") == "count_ed_visits", + 
pl.col("geo_value") == state_abb, + pl.col("reference_date") >= first_training_date, + ) + .group_by(["reference_date", "disease"]) + .agg(pl.col("value").sum().alias("ed_visits")) + .with_columns( + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), + geo_value=pl.lit(state_abb).cast(pl.Utf8), + ) + .rename({"reference_date": "date"}) + .sort(["date", "disease"]) + .select(["date", "geo_value", "disease", "ed_visits"]) + .collect(streaming=True) + # setting streaming = True explicitly + # avoids an `Option::unwrap()` on a `None` value + # error. Cause of error not known but presumably + # related to how parquets are processed. + ) + + +def verify_no_date_gaps(df: pl.DataFrame): + expected_length = df.select( + dur=((pl.col("date").max() - pl.col("date").min()).dt.total_days() + 1) + ).to_numpy()[0] + if not df.height == 2 * expected_length: + raise ValueError("Data frame appears to have date gaps") + + +def get_state_pop_df(): + facts = pl.read_csv( + "https://raw.githubusercontent.com/k5cents/usa/" + "refs/heads/master/data-raw/facts.csv" + ) + states = pl.read_csv( + "https://raw.githubusercontent.com/k5cents/usa/" + "refs/heads/master/data-raw/states.csv" + ) + + state_pop_df = facts.join(states, on="name").select( + ["abb", "name", "population"] + ) + + return state_pop_df + + +def get_pmfs(param_estimates: pl.LazyFrame, state_abb: str, disease: str): + generation_interval_pmf = ( + param_estimates.filter( + (pl.col("geo_value").is_null()) + & (pl.col("disease") == disease) + & (pl.col("parameter") == "generation_interval") + & (pl.col("end_date").is_null()) # most recent estimate + ) + .collect() + .get_column("value") + .to_list()[0] + ) + + delay_pmf = ( + param_estimates.filter( + (pl.col("geo_value").is_null()) + & (pl.col("disease") == disease) + & (pl.col("parameter") == "delay") + & (pl.col("end_date").is_null()) # most recent estimate + ) + .collect() + .get_column("value") + .to_list()[0] + ) + + right_truncation_pmf = ( + 
param_estimates.filter(
+            (pl.col("geo_value") == state_abb)
+            & (pl.col("disease") == disease)
+            & (pl.col("parameter") == "right_truncation")
+            & (pl.col("end_date").is_null())
+        )
+        .filter(pl.col("reference_date") == pl.col("reference_date").max())
+        .collect()
+        .get_column("value")
+        .to_list()[0]
+    )
+
+    return (generation_interval_pmf, delay_pmf, right_truncation_pmf)
+
+
+def process_and_save_state(
+    state_abb: str,
+    disease: str,
+    report_date: datetime.date,
+    state_level_report_date: datetime.date,
+    first_training_date: datetime.date,
+    last_training_date: datetime.date,
+    param_estimates: pl.LazyFrame,
+    model_run_dir: Path,
+    logger: Logger = None,
+    facility_level_nssp_data: pl.LazyFrame = None,
+    state_level_nssp_data: pl.LazyFrame = None,
+) -> None:
+    """Prepare one location's data and model inputs and save them.
+
+    Combines state-level and facility-level NSSP ED-visit data, labels
+    rows as train/test by ``last_training_date``, and writes ``data.csv``
+    and ``data_for_model_fit.json`` into ``model_run_dir``.
+
+    Args:
+        state_abb: Location abbreviation; ``"US"`` uses the summed
+            population of all states.
+        disease: Disease name used to select rows and PMF estimates.
+        report_date: Report date; its distance in days from
+            ``last_training_date`` is the right-truncation offset.
+        state_level_report_date: Not referenced by this function.
+        first_training_date: Earliest date kept from either source.
+        last_training_date: Last date labelled ``"train"``.
+        param_estimates: Table queried for the generation-interval,
+            delay and right-truncation PMFs.
+        model_run_dir: Output directory (created if missing).
+        logger: Logger to use; defaults to this module's logger.
+        facility_level_nssp_data: Facility-level data, or ``None``.
+        state_level_nssp_data: State-level data, or ``None``.
+
+    Raises:
+        ValueError: If both data sources are ``None``.
+    """
+    logging.basicConfig(level=logging.INFO)
+    # Honor a caller-supplied logger instead of silently replacing it.
+    if logger is None:
+        logger = logging.getLogger(__name__)
+
+    if facility_level_nssp_data is None and state_level_nssp_data is None:
+        raise ValueError(
+            "Must provide at least one "
+            "of facility-level and state-level "
+            "NSSP data"
+        )
+
+    state_pop_df = get_state_pop_df()
+
+    if state_abb == "US":
+        state_pop = state_pop_df["population"].sum()
+    else:
+        state_pop = (
+            state_pop_df.filter(pl.col("abb") == state_abb)
+            .get_column("population")
+            .to_list()[0]
+        )
+
+    (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs(
+        param_estimates=param_estimates, state_abb=state_abb, disease=disease
+    )
+
+    right_truncation_offset = (report_date - last_training_date).days
+
+    aggregated_facility_data = aggregate_facility_level_nssp_to_state(
+        facility_level_nssp_data=facility_level_nssp_data,
+        state_abb=state_abb,
+        disease=disease,
+        first_training_date=first_training_date,
+        state_pop_df=state_pop_df,
+    )
+
+    state_level_data = process_state_level_data(
+        state_level_nssp_data=state_level_nssp_data,
+        state_abb=state_abb,
+        disease=disease,
+        first_training_date=first_training_date,
+        state_pop_df=state_pop_df,
+    )
+
+    if aggregated_facility_data.height > 0:
+
first_facility_level_data_date = aggregated_facility_data.get_column( + "date" + ).min() + state_level_data = state_level_data.filter( + pl.col("date") < first_facility_level_data_date + ) + + data_to_save = ( + pl.concat([state_level_data, aggregated_facility_data]) + .with_columns( + pl.when(pl.col("date") <= last_training_date) + .then(pl.lit("train")) + .otherwise(pl.lit("test")) + .alias("data_type"), + ) + .sort(["date", "disease"]) + ) + + verify_no_date_gaps(data_to_save) + + train_disease_ed_visits = ( + data_to_save.filter( + pl.col("data_type") == "train", pl.col("disease") == disease + ) + .get_column("ed_visits") + .to_list() + ) + + test_disease_ed_visits = ( + data_to_save.filter( + pl.col("data_type") == "test", + pl.col("disease") == disease, + ) + .get_column("ed_visits") + .to_list() + ) + + train_total_ed_visits = ( + data_to_save.filter( + pl.col("data_type") == "train", pl.col("disease") == "Total" + ) + .get_column("ed_visits") + .to_list() + ) + + test_total_ed_visits = ( + data_to_save.filter( + pl.col("data_type") == "test", pl.col("disease") == "Total" + ) + .get_column("ed_visits") + .to_list() + ) + + data_for_model_fit = { + "inf_to_hosp_pmf": delay_pmf, + "generation_interval_pmf": generation_interval_pmf, + "right_truncation_pmf": right_truncation_pmf, + "data_observed_disease_hospital_admissions": train_disease_ed_visits, + "data_observed_disease_hospital_admissions_test": test_disease_ed_visits, + "data_observed_total_hospital_admissions": train_total_ed_visits, + "data_observed_total_hospital_admissions_test": test_total_ed_visits, + "state_pop": state_pop, + "right_truncation_offset": right_truncation_offset, + } + + os.makedirs(model_run_dir, exist_ok=True) + + if logger is not None: + logger.info(f"Saving {state_abb} to {model_run_dir}") + data_to_save.write_csv(Path(model_run_dir, "data.csv")) + + with open( + Path(model_run_dir, "data_for_model_fit.json"), "w" + ) as json_file: + json.dump(data_for_model_fit, json_file) + + 
return None
diff --git a/pipelines/pull_state_timeseries.py b/pipelines/pull_state_timeseries.py
new file mode 100644
index 00000000..ffba0e6c
--- /dev/null
+++ b/pipelines/pull_state_timeseries.py
@@ -0,0 +1,146 @@
+import argparse
+import datetime
+import logging
+from pathlib import Path
+
+import polars as pl
+
+
+def main(
+    nssp_data_dir,
+    output_path,
+    report_date: str | datetime.date,
+    first_date_to_pull: str | datetime.date = None,
+    separator="\t",
+    diseases=("covid", "influenza", "rsv"),
+):
+    """Pull ED-visit counts for several diseases from one NSSP report.
+
+    Reads ``<report_date>.parquet`` from ``nssp_data_dir``, sums visit
+    counts per reference date and location, adds per-disease fraction
+    and percent-of-total columns, prints the table and writes it to
+    ``output_path``.
+
+    Args:
+        nssp_data_dir: Directory containing the report parquet files.
+        output_path: Path the delimited output file is written to.
+        report_date: A ``datetime.date``, a ``YYYY-MM-DD`` string, or
+            ``"latest"`` to use the greatest parquet file stem found
+            in ``nssp_data_dir``.
+        first_date_to_pull: First reference date to keep (inclusive),
+            as a ``datetime.date`` or ``YYYY-MM-DD`` string. ``None``
+            keeps every available date.
+        separator: Field separator for the output file.
+        diseases: Keys (``"covid"``, ``"influenza"``, ``"rsv"``) of
+            the diseases to pull; total visits are always included.
+
+    Raises:
+        ValueError: If ``report_date`` or ``first_date_to_pull`` has
+            an unsupported type.
+    """
+    diseases_to_column_names = dict(
+        covid="COVID-19/Omicron",
+        influenza="Influenza",
+        rsv="RSV",
+        total="Total",
+    )
+
+    diseases_to_pull = [
+        diseases_to_column_names.get(disease) for disease in diseases
+    ]
+
+    col_names_to_pull = diseases_to_pull + ["Total"]
+
+    if isinstance(report_date, str):
+        if report_date == "latest":
+            report_date = max(
+                f.stem for f in Path(nssp_data_dir).glob("*.parquet")
+            )
+        report_date = datetime.datetime.strptime(
+            report_date, "%Y-%m-%d"
+        ).date()
+    elif not isinstance(report_date, datetime.date):
+        raise ValueError(
+            "`report_date` must be either "
+            "a `datetime.date` object, or a string "
+            "giving a date in ISO8601 format."
+        )
+
+    if first_date_to_pull is None:
+        first_date_to_pull = pl.col("reference_date").min()
+    elif isinstance(first_date_to_pull, str):
+        first_date_to_pull = datetime.datetime.strptime(
+            first_date_to_pull, "%Y-%m-%d"
+        ).date()
+    elif not isinstance(first_date_to_pull, datetime.date):
+        # One concatenated message: stray commas between the pieces
+        # would turn them into separate exception arguments.
+        raise ValueError(
+            "`first_date_to_pull` must be `None`, "
+            "in which case all available dates are pulled, "
+            "a `datetime.date` object, or a string "
+            "giving a date in ISO8601 format."
+        )
+
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    logger.info(f"Report date: {report_date}")
+
+    datafile = f"{report_date}.parquet"
+    nssp_data = pl.scan_parquet(Path(nssp_data_dir, datafile))
+
+    data = (
+        nssp_data.filter(
+            pl.col("disease").is_in(col_names_to_pull),
+            pl.col("metric") == "count_ed_visits",
+            # Inclusive bound: a strict `>` would drop the first date
+            # itself (and the earliest date when the default is used).
+            pl.col("reference_date") >= first_date_to_pull,
+            pl.col("report_date") == report_date,
+        )
+        .select(["reference_date", "geo_value", "disease", "value"])
+        .group_by(["reference_date", "geo_value", "disease"])
+        .agg(value=pl.col("value").sum())
+        .sort(["reference_date", "geo_value"])
+        .collect()
+        .pivot(on="disease", index=["reference_date", "geo_value"])
+        .rename(
+            {
+                v: f"count_{k}"
+                for k, v in diseases_to_column_names.items()
+                if v in col_names_to_pull
+            }
+        )
+        .with_columns(
+            **{
+                f"frac_{x}": (pl.col(f"count_{x}") / pl.col("count_total"))
+                for x in diseases
+            }
+        )
+        .with_columns(
+            **{f"pct_{x}": (100.0 * pl.col(f"frac_{x}")) for x in diseases}
+        )
+        .select(
+            [
+                pl.col("reference_date").alias("date"),
+                pl.col("geo_value").alias("location"),
+            ]
+            + [
+                item
+                for x in diseases
+                for item in [f"count_{x}", f"frac_{x}", f"pct_{x}"]
+            ]
+            + ["count_total"]
+        )
+    )
+
+    print(data)
+
+    logger.info(f"Saving data to {output_path}.")
+
+    data.write_csv(file=output_path, separator=separator)
+
+    logger.info("Data preparation complete.")
+
+
+parser = argparse.ArgumentParser(
+    description="Pull NSSP data across pathogens."
+) +parser.add_argument( + "nssp_data_dir", + type=Path, + help=( + "Directory in which to look for NSSP data gold table " + ".parquet files." + ), +) +parser.add_argument( + "output_path", + type=Path, + help="Path to which to save the output file, as a tsv.", +) + +parser.add_argument( + "--report-date", + type=str, + default="latest", + help="Report date in YYYY-MM-DD format or latest (default: latest)", +) + + +if __name__ == "__main__": + args = parser.parse_args() + main(**vars(args)) diff --git a/pipelines/save_eval_data.py b/pipelines/save_eval_data.py new file mode 100644 index 00000000..6a85389b --- /dev/null +++ b/pipelines/save_eval_data.py @@ -0,0 +1,47 @@ +import datetime +import logging +from pathlib import Path + +import polars as pl +from prep_data import get_state_pop_df, process_state_level_data + + +def save_eval_data( + state: str, + disease: str, + report_date: datetime.date, + first_training_date, + last_training_date, + latest_comprehensive_path: Path | str, + output_data_dir: Path | str, + last_eval_date: datetime.date = None, + output_file_name: str = "eval_data.tsv", +): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + logger.info("Reading in truth data...") + state_level_nssp_data = pl.scan_parquet(latest_comprehensive_path) + + if last_eval_date is not None: + state_level_nssp_data = state_level_nssp_data.filter( + pl.col("reference_date") <= last_eval_date + ) + + state_level_data = ( + process_state_level_data( + state_level_nssp_data=state_level_nssp_data, + state_abb=state, + disease=disease, + first_training_date=first_training_date, + state_pop_df=get_state_pop_df(), + ) + .with_columns(data_type=pl.lit("eval")) + .sort(["date", "disease"]) + ) + + state_level_data.write_csv( + Path(output_data_dir, output_file_name), separator="\t" + ) + + return None diff --git a/pipelines/score_forecast.R b/pipelines/score_forecast.R new file mode 100644 index 00000000..62afd9a9 --- /dev/null +++ 
b/pipelines/score_forecast.R @@ -0,0 +1,245 @@ +script_packages <- c( + "dplyr", + "scoringutils", + "arrow", + "argparser" +) + +## load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + + +#' Score Forecasts +#' +#' This function scores forecast data using the `scoringutils` package. It takes +#' in scorable data, that is data which has a joined truth data and forecast +#' data, and scores it. +#' +#' This function aims at scoring _sampled_ forecasts. Care must be taken to +#' select the appropriate columns for the observed and predicted values, as well +#' as the forecast unit. The expected `sample_id` column is `.draw` due to +#' expecting input from a tidybayes format. +#' +#' NB: this function assumes that _log-scale_ scoring is the default. If you +#' want to vary this behaviour, you can splat additional arguments to +#' `scoringutils::transform_forecasts` such as the identity transformation e.g. +#' `fun = identity` with `label = "identity"`. +#' +#' If more than one model is present in the data, in the column `model_col` the +#' function will add relative skill metrics to the output. +#' +#' @param scorable_data A data frame containing the data to be scored. +#' @param forecast_unit A string specifying the forecast unit. +#' @param observed A string specifying the column name for observed +#' values. +#' @param predicted A string specifying the column name for predicted +#' values. +#' @param sample_id A string specifying the column name for sample +#' IDs. Default is ".draw". +#' @param model_col A string specifying the column name for models. +#' @param ... Additional arguments passed to +#' `scoringutils::transform_forecasts`. +#' +#' @return A data frame with scored forecasts and relative skill metrics. 
+#' @export +score_single_run <- function( + scorable_data, quantile_only_data, forecast_unit, observed, predicted, + sample_id = ".draw", model_col = "model", ...) { + forecast_sample_df <- scorable_data |> + scoringutils::as_forecast_sample( + forecast_unit = forecast_unit, + observed = observed, + predicted = predicted, + sample_id = sample_id + ) + + quantile_only_df <- quantile_only_data |> + scoringutils::as_forecast_quantile( + forecast_unit = forecast_unit, + observed = observed, + predicted = predicted + ) + + quants <- unique(quantile_only_df$quantile_level) + quantiles_from_samples_df <- forecast_sample_df |> + scoringutils::as_forecast_quantile(probs = quants) + + + + forecast_quantile_df <- + dplyr::bind_rows( + quantile_only_df, + quantiles_from_samples_df + ) |> + scoringutils::as_forecast_quantile() + + sample_scores <- forecast_sample_df |> + scoringutils::transform_forecasts(...) |> + scoringutils::score() + + quantile_scores <- forecast_quantile_df |> + scoringutils::transform_forecasts(...) |> + scoringutils::score() + # Add relative skill if more than one model is present + if (n_distinct(scorable_data[[model_col]]) > 1) { + sample_scores <- scoringutils::add_relative_skill(sample_scores) + quantile_scores <- scoringutils::add_relative_skill(quantile_scores) + } + return(list( + sample_scores = sample_scores, + quantile_scores = quantile_scores + )) +} + + +prep_truth_data <- function(truth_data_path) { + dat <- readr::read_tsv(truth_data_path, + show_col_types = FALSE + ) |> + filter(data_type == "eval") |> + rename(true_value = ed_visits) + + truth_data_valid <- ( + dplyr::n_distinct(dat$disease) == 2 & + "Total" %in% dat$disease & + xor( + "COVID-19" %in% dat$disease, + "Influenza" %in% dat$disease + )) + + if (!truth_data_valid) { + err_dis <- paste(unique(dat$disease), collapse = "', ") + stop( + "Evaluation data 'disease' column must ", + "have exactly two uniques entries: 'Total' ", + "and exactly one of 'COVID-19', 'Influenza'. 
", + glue::glue("Got: '{err_dis}") + ) + } + + prepped_dat <- dat |> + mutate(disease = ifelse(disease %in% c("COVID-19", "Influenza"), + "Disease", + disease + )) |> + tidyr::pivot_wider( + names_from = "disease", + values_from = "true_value" + ) |> + mutate(prop_disease_ed_visits = Disease / Total) |> + tidyr::pivot_longer( + c(Disease, Total, prop_disease_ed_visits), + names_to = "disease", + values_to = "true_value" + ) + + return(prepped_dat) +} + +read_and_score_location <- function(model_run_dir, + eval_data_filename = "eval_data", + eval_data_file_ext = "tsv", + parquet_file_ext = "parquet", + rds_file_ext = "rds") { + message(glue::glue("Scoring {model_run_dir}...")) + forecast_path <- fs::path( + model_run_dir, + "forecast_samples", + ext = parquet_file_ext + ) + ts_baseline_path <- fs::path( + model_run_dir, + "baseline_ts_prop_ed_visits_forecast", + ext = parquet_file_ext + ) + cdc_baseline_path <- fs::path( + model_run_dir, + "baseline_cdc_prop_ed_visits_forecast", + ext = parquet_file_ext + ) + + truth_path <- fs::path(model_run_dir, + eval_data_filename, + ext = eval_data_file_ext + ) + + actual_data <- prep_truth_data(truth_path) + + pyrenew <- arrow::read_parquet(forecast_path) |> + mutate(model = "pyrenew-hew") |> + select(date, .draw, disease, model, .value) + + ts_baseline <- arrow::read_parquet(ts_baseline_path) |> + mutate( + model = "ts_baseline", + disease = "prop_disease_ed_visits" + ) |> + select(date, + .draw, + disease, + model, + .value = prop_disease_ed_visits + ) + + cdc_baseline <- arrow::read_parquet(cdc_baseline_path) |> + mutate( + model = "cdc_baseline", + disease = "prop_disease_ed_visits" + ) |> + select(date, + disease, + quantile_level, + .value = baseline_ed_visit_prop_forecast, + model + ) + + quantile_forecasts_to_score <- inner_join( + cdc_baseline, + actual_data, + by = c("disease", "date") + ) + + sample_forecasts_to_score <- bind_rows( + pyrenew, + ts_baseline + ) |> + inner_join(actual_data, + by = c("disease", 
"date") + ) |> + filter(disease == "prop_disease_ed_visits") + + max_visits <- actual_data |> + filter(disease == "Total") |> + pull(true_value) |> + max() + + scored <- score_single_run( + sample_forecasts_to_score, + quantile_forecasts_to_score, + forecast_unit = c("date", "model"), + observed = "true_value", + sample_id = ".draw", + predicted = ".value", + offset = 1 / max_visits + ) + + readr::write_rds(scored, fs::path(model_run_dir, + "score_table", + ext = rds_file_ext + )) +} + +# Create a parser +p <- arg_parser("Score a single location forecast") |> + add_argument( + "model_run_dir", + help = "Directory containing the model data and output." + ) + +argv <- parse_args(p) + +read_and_score_location(argv$model_run_dir) diff --git a/pipelines/tests/README.md b/pipelines/tests/README.md new file mode 100644 index 00000000..aae5fdfa --- /dev/null +++ b/pipelines/tests/README.md @@ -0,0 +1,10 @@ +# Test data folder + +This folder is aimed at running test-mode scripts for validating the inference +pipeline on the test data. The test data is stored in subdirectories. 
+ +To run the test scripts, execute the following command from the `pipelines` directory: + +```bash +% bash ./tests/test_run.sh ./tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs 1000 28 +``` diff --git a/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data.csv b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data.csv new file mode 100644 index 00000000..b92e5e21 --- /dev/null +++ b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data.csv @@ -0,0 +1,181 @@ +date,disease,ed_visits,data_type +2023-11-01,COVID-19,6,train +2023-11-02,COVID-19,6,train +2023-11-03,COVID-19,7,train +2023-11-04,COVID-19,10,train +2023-11-05,COVID-19,10,train +2023-11-06,COVID-19,12,train +2023-11-07,COVID-19,12,train +2023-11-08,COVID-19,10,train +2023-11-09,COVID-19,8,train +2023-11-10,COVID-19,15,train +2023-11-11,COVID-19,8,train +2023-11-12,COVID-19,9,train +2023-11-13,COVID-19,9,train +2023-11-14,COVID-19,13,train +2023-11-15,COVID-19,17,train +2023-11-16,COVID-19,7,train +2023-11-17,COVID-19,12,train +2023-11-18,COVID-19,10,train +2023-11-19,COVID-19,13,train +2023-11-20,COVID-19,10,train +2023-11-21,COVID-19,12,train +2023-11-22,COVID-19,15,train +2023-11-23,COVID-19,19,train +2023-11-24,COVID-19,19,train +2023-11-25,COVID-19,22,train +2023-11-26,COVID-19,17,train +2023-11-27,COVID-19,19,train +2023-11-28,COVID-19,14,train +2023-11-29,COVID-19,17,train +2023-11-30,COVID-19,19,train +2023-12-01,COVID-19,18,train +2023-12-02,COVID-19,13,train +2023-12-03,COVID-19,24,train +2023-12-04,COVID-19,21,train +2023-12-05,COVID-19,35,train +2023-12-06,COVID-19,26,train +2023-12-07,COVID-19,25,train +2023-12-08,COVID-19,30,train +2023-12-09,COVID-19,26,train +2023-12-10,COVID-19,20,train +2023-12-11,COVID-19,29,train +2023-12-12,COVID-19,38,train +2023-12-13,COVID-19,35,train +2023-12-14,COVID-19,41,train +2023-12-15,COVID-19,30,train +2023-12-16,COVID-19,37,train 
+2023-12-17,COVID-19,35,train +2023-12-18,COVID-19,46,train +2023-12-19,COVID-19,38,train +2023-12-20,COVID-19,23,train +2023-12-21,COVID-19,38,train +2023-12-22,COVID-19,22,train +2023-12-23,COVID-19,28,train +2023-12-24,COVID-19,23,train +2023-12-25,COVID-19,31,train +2023-12-26,COVID-19,19,train +2023-12-27,COVID-19,23,train +2023-12-28,COVID-19,17,train +2023-12-29,COVID-19,23,train +2023-12-30,COVID-19,26,train +2023-12-31,COVID-19,17,train +2024-01-01,COVID-19,17,train +2024-01-02,COVID-19,12,train +2024-01-03,COVID-19,13,train +2024-01-04,COVID-19,9,train +2024-01-05,COVID-19,22,train +2024-01-06,COVID-19,12,train +2024-01-07,COVID-19,13,train +2024-01-08,COVID-19,17,train +2024-01-09,COVID-19,14,train +2024-01-10,COVID-19,12,train +2024-01-11,COVID-19,6,train +2024-01-12,COVID-19,10,train +2024-01-13,COVID-19,10,train +2024-01-14,COVID-19,4,train +2024-01-15,COVID-19,12,train +2024-01-16,COVID-19,9,train +2024-01-17,COVID-19,8,train +2024-01-18,COVID-19,9,train +2024-01-19,COVID-19,8,train +2024-01-20,COVID-19,6,train +2024-01-21,COVID-19,13,train +2024-01-22,COVID-19,7,train +2024-01-23,COVID-19,8,train +2024-01-24,COVID-19,13,train +2024-01-25,COVID-19,9,train +2024-01-26,COVID-19,9,train +2024-01-27,COVID-19,17,train +2024-01-28,COVID-19,7,train +2024-01-29,COVID-19,10,train +2023-11-01,Total,105,train +2023-11-02,Total,105,train +2023-11-03,Total,104,train +2023-11-04,Total,109,train +2023-11-05,Total,105,train +2023-11-06,Total,126,train +2023-11-07,Total,118,train +2023-11-08,Total,99,train +2023-11-09,Total,119,train +2023-11-10,Total,115,train +2023-11-11,Total,106,train +2023-11-12,Total,123,train +2023-11-13,Total,104,train +2023-11-14,Total,124,train +2023-11-15,Total,102,train +2023-11-16,Total,102,train +2023-11-17,Total,130,train +2023-11-18,Total,126,train +2023-11-19,Total,112,train +2023-11-20,Total,97,train +2023-11-21,Total,109,train +2023-11-22,Total,107,train +2023-11-23,Total,102,train +2023-11-24,Total,120,train 
+2023-11-25,Total,125,train +2023-11-26,Total,109,train +2023-11-27,Total,110,train +2023-11-28,Total,100,train +2023-11-29,Total,118,train +2023-11-30,Total,128,train +2023-12-01,Total,123,train +2023-12-02,Total,113,train +2023-12-03,Total,122,train +2023-12-04,Total,121,train +2023-12-05,Total,151,train +2023-12-06,Total,136,train +2023-12-07,Total,140,train +2023-12-08,Total,120,train +2023-12-09,Total,143,train +2023-12-10,Total,123,train +2023-12-11,Total,121,train +2023-12-12,Total,152,train +2023-12-13,Total,125,train +2023-12-14,Total,138,train +2023-12-15,Total,153,train +2023-12-16,Total,134,train +2023-12-17,Total,134,train +2023-12-18,Total,137,train +2023-12-19,Total,149,train +2023-12-20,Total,111,train +2023-12-21,Total,135,train +2023-12-22,Total,121,train +2023-12-23,Total,129,train +2023-12-24,Total,117,train +2023-12-25,Total,147,train +2023-12-26,Total,118,train +2023-12-27,Total,128,train +2023-12-28,Total,118,train +2023-12-29,Total,140,train +2023-12-30,Total,119,train +2023-12-31,Total,110,train +2024-01-01,Total,105,train +2024-01-02,Total,108,train +2024-01-03,Total,112,train +2024-01-04,Total,94,train +2024-01-05,Total,103,train +2024-01-06,Total,111,train +2024-01-07,Total,120,train +2024-01-08,Total,126,train +2024-01-09,Total,120,train +2024-01-10,Total,124,train +2024-01-11,Total,101,train +2024-01-12,Total,128,train +2024-01-13,Total,114,train +2024-01-14,Total,102,train +2024-01-15,Total,97,train +2024-01-16,Total,89,train +2024-01-17,Total,112,train +2024-01-18,Total,116,train +2024-01-19,Total,109,train +2024-01-20,Total,97,train +2024-01-21,Total,115,train +2024-01-22,Total,118,train +2024-01-23,Total,117,train +2024-01-24,Total,106,train +2024-01-25,Total,102,train +2024-01-26,Total,100,train +2024-01-27,Total,126,train +2024-01-28,Total,116,train +2024-01-29,Total,116,train diff --git a/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data_for_model_fit.json 
b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data_for_model_fit.json new file mode 100644 index 00000000..8a1c17ad --- /dev/null +++ b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data_for_model_fit.json @@ -0,0 +1,381 @@ +{ + "gt_max": 15, + "hosp_delay_max": 55, + "inf_to_hosp_pmf": [0.0, 0.00469384736487552, 0.0145200073436112, 0.0278627741704387, 0.0423656492135518, 0.0558071445014868, 0.0665713169684116, 0.0737925805176124, 0.0772854627892072, 0.0773666390616176, 0.0746515449009948, 0.0698761436052596, 0.0637663813017696, 0.0569581929821651, 0.0499600186601535, 0.0431457477049282, 0.0367662806214046, 0.0309702535668237, 0.0258273785539499, 0.0213504646948306, 0.0175141661880584, 0.0142698211023571, 0.0115565159519833, 0.00930888979824423, 0.00746229206759214, 0.00595605679409682, 0.00473519993107751, 0.00375117728281842, 0.00296198928038098, 0.00233187862772459, 0.00183079868293457, 0.00143377454057296, 0.00107076258525208, 0.000773006742366448, 0.000539573690886396, 0.000364177599116743, 0.000237727628685579, 0.000150157714457011, 9.18283319498657e-05, 5.44079947589853e-05, 3.12548818921464e-05, 1.74202619730274e-05, 9.42698047424712e-06, 4.95614149002087e-06, 2.53275674485913e-06, 1.25854819834554e-06, 6.08116579596933e-07, 2.85572858589747e-07, 1.30129404249734e-07, 5.73280599448306e-08, 2.4219376577964e-08, 9.6316861194457e-09, 3.43804936850951e-09, 9.34806280366887e-10, 0.0], + "mwpd": 227000.0, + "ot": 90, + "n_subpops": 5, + "n_ww_sites": 4, + "n_ww_lab_sites": 5, + "owt": 88, + "oht": 90, + "n_censored": 0, + "n_uncensored": 88, + "uot": 50, + "ht": 35, + "n_weeks": 18, + "ind_m": [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + 
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] + ], + "tot_weeks": 25, + "p_hosp_m": [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 
0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] + ], + "generation_interval_pmf": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], + "ts": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "state_pop": 3000000.0, + "subpop_size": [400000.0, 200000.0, 100000.0, 50000.0, 2250000.0], + "norm_pop": 3000000.0, + "ww_sampled_times": [2, 5, 6, 6, 8, 9, 11, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 18, 18, 18, 19, 20, 21, 22, 23, 23, 25, 26, 27, 29, 29, 29, 31, 32, 32, 33, 33, 34, 36, 36, 37, 37, 37, 39, 42, 42, 42, 43, 45, 45, 46, 47, 48, 51, 53, 58, 58, 59, 59, 63, 63, 64, 65, 65, 67, 70, 70, 73, 73, 74, 75, 76, 76, 76, 78, 80, 81, 82, 83, 83, 84, 87, 89, 91, 92, 93, 93, 95], + "hosp_times": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], + "ww_sampled_lab_sites": [1, 5, 4, 5, 3, 2, 4, 1, 3, 2, 3, 5, 1, 3, 5, 3, 5, 2, 4, 5, 2, 3, 1, 2, 3, 5, 4, 3, 3, 2, 3, 4, 5, 1, 3, 3, 5, 5, 3, 4, 2, 3, 4, 2, 1, 3, 4, 4, 1, 5, 3, 2, 1, 4, 4, 2, 4, 1, 2, 1, 2, 4, 1, 3, 1, 1, 3, 4, 5, 1, 3, 2, 3, 4, 5, 4, 5, 5, 2, 4, 4, 4, 4, 3, 2, 3, 5, 1], + "ww_log_lod": [5.09434727489065, 4.9806950154474, 4.73771588167502, 4.9806950154474, 5.2940513994166, 4.76390186342314, 4.73771588167502, 5.09434727489065, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.9806950154474, 5.09434727489065, 5.2940513994166, 4.9806950154474, 5.2940513994166, 4.9806950154474, 4.76390186342314, 4.73771588167502, 4.9806950154474, 
4.76390186342314, 5.2940513994166, 5.09434727489065, 4.76390186342314, 5.2940513994166, 4.9806950154474, 4.73771588167502, 5.2940513994166, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.9806950154474, 5.09434727489065, 5.2940513994166, 5.2940513994166, 4.9806950154474, 4.9806950154474, 5.2940513994166, 4.73771588167502, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.76390186342314, 5.09434727489065, 5.2940513994166, 4.73771588167502, 4.73771588167502, 5.09434727489065, 4.9806950154474, 5.2940513994166, 4.76390186342314, 5.09434727489065, 4.73771588167502, 4.73771588167502, 4.76390186342314, 4.73771588167502, 5.09434727489065, 4.76390186342314, 5.09434727489065, 4.76390186342314, 4.73771588167502, 5.09434727489065, 5.2940513994166, 5.09434727489065, 5.09434727489065, 5.2940513994166, 4.73771588167502, 4.9806950154474, 5.09434727489065, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.73771588167502, 4.9806950154474, 4.73771588167502, 4.9806950154474, 4.9806950154474, 4.76390186342314, 4.73771588167502, 4.73771588167502, 4.73771588167502, 4.73771588167502, 5.2940513994166, 4.76390186342314, 5.2940513994166, 4.9806950154474, 5.09434727489065], + "ww_censored": [], + "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88], + "data_observed_disease_hospital_admissions": [6, 6, 7, 10, 10, 12, 12, 10, 8, 15, 8, 9, 9, 13, 17, 7, 12, 10, 13, 10, 12, 15, 19, 19, 22, 17, 19, 14, 17, 19, 18, 13, 24, 21, 35, 26, 25, 30, 26, 20, 29, 38, 35, 41, 30, 37, 35, 46, 38, 23, 38, 22, 28, 23, 31, 19, 23, 17, 23, 26, 17, 17, 12, 13, 9, 22, 12, 13, 17, 14, 12, 6, 10, 10, 4, 12, 9, 8, 9, 8, 6, 13, 7, 8, 13, 9, 9, 17, 7, 10], + "day_of_week": [5, 6, 7, 
1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3], + "log_conc": [7.84933755868547, 8.75878317201939, 7.94731683160414, 7.84271554481825, 6.84314706058328, 8.49580667918879, 8.70688802330958, 8.3080964891881, 7.51476199898784, 8.73263730516131, 7.67705894528627, 9.16138870419256, 8.2291748710282, 7.57842684442178, 8.7658468204119, 7.59564541757592, 8.6485267724375, 8.73847850883661, 8.53496849868392, 7.81189950945825, 8.72816004875312, 7.77632957957527, 8.35751634986049, 8.61237112259723, 7.14724348782869, 8.98381918138397, 9.71929374032385, 7.73223113044046, 7.64136587504144, 8.93727178027946, 7.69319599090821, 9.3539099812084, 9.77803780431265, 8.83676433943342, 7.80095273126195, 7.88428742699397, 10.6861051151009, 10.7177204634667, 7.86033075211836, 9.44031060901259, 9.19820314756664, 8.10945148438765, 9.47368587127031, 9.30021922706583, 8.71406234479695, 7.82242543443078, 8.73443519589195, 9.246306907259, 8.61154444152258, 10.8332932698813, 7.39925321859574, 9.06066397101092, 8.61830748102478, 8.86917291106784, 9.02943162748827, 8.42736799456162, 8.11764762377314, 8.03266723037298, 8.41720674557318, 7.98228105459503, 8.42950370265189, 8.47947015286844, 7.81419539933152, 6.95672231924945, 7.74818768373832, 7.66575220799484, 6.86777849754671, 7.76297112083073, 6.95229085580734, 7.56030545976231, 6.64452132423733, 7.65706249921245, 6.80982040079249, 8.28375427477922, 10.3920632090468, 8.17535318491482, 8.62625413957733, 7.21373104269779, 8.15256350454411, 8.38397288981646, 8.40418626057656, 8.23933140611506, 7.91321568315416, 6.89107975822161, 8.02468424809018, 6.87101220865236, 8.34645251673327, 7.62822976068377], + "compute_likelihood": 1, + 
"include_ww": 1, + "include_hosp": 1, + "if_l": 15, + "infection_feedback_pmf": [0.161701189933765, 0.320525743089203, 0.242198071982593, 0.134825252524032, 0.0689141939998525, 0.0346219683116734, 0.017497710736154, 0.00908172017279556, 0.00483656086299504, 0.00260732346885217, 0.00143298046642562, 0.00082002579123121, 0.0004729600977183, 0.000284420637980485, 0.000179877924728358], + "viral_shedding_pars": [5.0, 1.0, 5.1, 0.5, 17.0, 3.0], + "autoreg_rt_a": 2, + "autoreg_rt_b": 40, + "autoreg_rt_site_a": 1, + "autoreg_rt_site_b": 4, + "autoreg_p_hosp_a": 1, + "autoreg_p_hosp_b": 100, + "inv_sqrt_phi_prior_mean": 0.1, + "inv_sqrt_phi_prior_sd": 0.1414214, + "r_prior_mean": 1, + "r_prior_sd": 1, + "log10_g_prior_mean": 12, + "log10_g_prior_sd": 2, + "i_first_obs_over_n_prior_a": 1.0015, + "i_first_obs_over_n_prior_b": 5.9985, + "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], + "mean_initial_exp_growth_rate_prior_mean": 0, + "mean_initial_exp_growth_rate_prior_sd": 0.01, + "sigma_initial_exp_growth_rate_prior_mode": 0, + "sigma_initial_exp_growth_rate_prior_sd": 0.05, + "mode_sigma_ww_site_prior_mode": 1, + "mode_sigma_ww_site_prior_sd": 1, + "sd_log_sigma_ww_site_prior_mode": 0, + "sd_log_sigma_ww_site_prior_sd": 0.693, + "eta_sd_sd": 0.01, + "sigma_i_first_obs_prior_mode": 0, + "sigma_i_first_obs_prior_sd": 0.5, + "p_hosp_prior_mean": 0.01, + "p_hosp_sd_logit": 0.3, + "p_hosp_w_sd_sd": 0.01, + "ww_site_mod_sd_sd": 0.25, + "inf_feedback_prior_logmean": 6.37408, + "inf_feedback_prior_logsd": 0.4, + "sigma_rt_prior": 0.1, + "log_phi_g_prior_mean": -2.302585, + "log_phi_g_prior_sd": 5, + "ww_sampled_sites": [1, 4, 3, 4, 2, 1, 3, 1, 2, 1, 2, 4, 1, 2, 4, 2, 4, 1, 3, 4, 1, 2, 1, 1, 2, 4, 3, 2, 2, 1, 2, 3, 4, 1, 2, 2, 4, 4, 2, 3, 1, 2, 3, 1, 1, 2, 3, 3, 1, 4, 2, 1, 1, 3, 3, 1, 3, 1, 1, 1, 1, 3, 1, 2, 1, 1, 2, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4, 4, 1, 3, 3, 3, 3, 2, 1, 2, 4, 1], + "lab_site_to_site_map": [1, 1, 2, 3, 4], + "right_truncation_pmf": [1], + 
"right_truncation_offset": 0 +} diff --git a/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/eval_data.tsv b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/eval_data.tsv new file mode 100644 index 00000000..2b4d5846 --- /dev/null +++ b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/eval_data.tsv @@ -0,0 +1,57 @@ +date disease ed_visits data_type +2024-01-29 COVID-19 17 eval +2024-01-30 COVID-19 10 eval +2024-01-31 COVID-19 13 eval +2024-02-01 COVID-19 17 eval +2024-02-02 COVID-19 18 eval +2024-02-03 COVID-19 9 eval +2024-02-04 COVID-19 12 eval +2024-02-05 COVID-19 21 eval +2024-02-06 COVID-19 17 eval +2024-02-07 COVID-19 16 eval +2024-02-08 COVID-19 18 eval +2024-02-09 COVID-19 16 eval +2024-02-10 COVID-19 18 eval +2024-02-11 COVID-19 27 eval +2024-02-12 COVID-19 22 eval +2024-02-13 COVID-19 19 eval +2024-02-14 COVID-19 23 eval +2024-02-15 COVID-19 22 eval +2024-02-16 COVID-19 21 eval +2024-02-17 COVID-19 35 eval +2024-02-18 COVID-19 25 eval +2024-02-19 COVID-19 32 eval +2024-02-20 COVID-19 40 eval +2024-02-21 COVID-19 25 eval +2024-02-22 COVID-19 41 eval +2024-02-23 COVID-19 48 eval +2024-02-24 COVID-19 34 eval +2024-02-25 COVID-19 37 eval +2024-01-29 Total 119 eval +2024-01-30 Total 128 eval +2024-01-31 Total 119 eval +2024-02-01 Total 116 eval +2024-02-02 Total 112 eval +2024-02-03 Total 108 eval +2024-02-04 Total 128 eval +2024-02-05 Total 117 eval +2024-02-06 Total 108 eval +2024-02-07 Total 121 eval +2024-02-08 Total 112 eval +2024-02-09 Total 123 eval +2024-02-10 Total 120 eval +2024-02-11 Total 115 eval +2024-02-12 Total 117 eval +2024-02-13 Total 130 eval +2024-02-14 Total 126 eval +2024-02-15 Total 136 eval +2024-02-16 Total 122 eval +2024-02-17 Total 144 eval +2024-02-18 Total 118 eval +2024-02-19 Total 131 eval +2024-02-20 Total 130 eval +2024-02-21 Total 124 eval +2024-02-22 Total 131 eval +2024-02-23 Total 150 eval +2024-02-24 Total 134 eval +2024-02-25 Total 
141 eval diff --git a/nssp_demo/priors.py b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/priors.py similarity index 64% rename from nssp_demo/priors.py rename to pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/priors.py index 07637134..4f9d61ab 100644 --- a/nssp_demo/priors.py +++ b/pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/priors.py @@ -4,11 +4,6 @@ from numpyro.infer.reparam import LocScaleReparam from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew_hew.utils import convert_to_logmean_log_sd - -# many of these should probably be different depending on if we are modeling flu -# or covid - i0_first_obs_n_rv = DistributionalVariable( "i0_first_obs_n_rv", dist.Beta(1, 10), @@ -18,13 +13,15 @@ "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) ) -r_logmean, r_logsd = convert_to_logmean_log_sd(1, 1) +r_logmean = jnp.log(1) +r_logsd = jnp.log(jnp.sqrt(2)) + log_r_mu_intercept_rv = DistributionalVariable( "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) ) eta_sd_rv = DistributionalVariable( - "eta_sd", dist.TruncatedNormal(0, 0.01, low=0) + "eta_sd", dist.TruncatedNormal(0.04, 0.02, low=0) ) autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40)) @@ -34,16 +31,14 @@ "inf_feedback", DistributionalVariable( "inf_feedback_raw", - dist.LogNormal(6.37408, 0.4), + dist.LogNormal(jnp.log(50), jnp.log(2)), ), transforms=transformation.AffineTransform(loc=0, scale=-1), ) # Could be reparameterized? 
-# Note: multiplied by 1/2 from hosp model -# this actually represents ed admissions -p_hosp_mean_rv = DistributionalVariable( - "p_hosp_mean", +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", dist.Normal( transformation.SigmoidTransform().inv(0.005), 0.3, @@ -51,17 +46,19 @@ ) # logit scale -p_hosp_w_sd_rv = DistributionalVariable( - "p_hosp_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) ) -autoreg_p_hosp_rv = DistributionalVariable("autoreg_p_hosp", dist.Beta(1, 100)) +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) -hosp_wday_effect_rv = TransformedVariable( - "hosp_wday_effect", +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", DistributionalVariable( - "hosp_wday_effect_raw", + "ed_visit_wday_effect_raw", dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), ), transformation.AffineTransform(loc=0, scale=7), diff --git a/pipelines/tests/test_run.sh b/pipelines/tests/test_run.sh new file mode 100644 index 00000000..8042cf41 --- /dev/null +++ b/pipelines/tests/test_run.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# Check if the base directory is provided as an argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +# Base directory containing subdirectories +BASE_DIR="$1" +N_SAMPLES=$2 +N_AHEAD=$3 + +# Iterate over each subdirectory in the base directory +echo "TEST-MODE: Running loop over subdirectories in $BASE_DIR" +echo "For $N_SAMPLES samples on 1 chain, and $N_AHEAD forecast points" +for SUBDIR in "$BASE_DIR"/*/; do + echo "TEST-MODE: Inference for $SUBDIR" + python fit_model.py "$SUBDIR" --n-chains 1 --n-samples $N_SAMPLES + echo "TEST-MODE: Finished inference" + echo "TEST-MODE: Generating posterior predictions for $SUBDIR" + python generate_predictive.py "$SUBDIR" --n-forecast-points $N_AHEAD + echo "TEST-MODE: Finished generating posterior predictions" + echo 
"TEST-MODE: Converting inferencedata to parquet for $SUBDIR" + Rscript convert_inferencedata_to_parquet.R "$SUBDIR" + echo "TEST-MODE: Finished converting inferencedata to parquet" + echo "TEST-MODE: Forecasting baseline models for $SUBDIR" + Rscript timeseries_forecasts.R "$SUBDIR" --n-forecast-days $N_AHEAD --n-samples $N_SAMPLES + echo "TEST-MODE: Finished forecasting baseline models" + echo "TEST-MODE: Postprocessing state forecast for $SUBDIR" + Rscript postprocess_state_forecast.R "$SUBDIR" + echo "TEST-MODE: Finished postprocessing state forecast" + echo "TEST-MODE: Scoring forecast for $SUBDIR" + Rscript score_forecast.R "$SUBDIR" + echo "TEST-MODE: Finished scoring forecast" +done diff --git a/pipelines/timeseries_forecasts.R b/pipelines/timeseries_forecasts.R new file mode 100644 index 00000000..1251029a --- /dev/null +++ b/pipelines/timeseries_forecasts.R @@ -0,0 +1,259 @@ +script_packages <- c( + "dplyr", + "tidyr", + "tibble", + "readr", + "stringr", + "fs", + "fable", + "jsonlite", + "argparser", + "arrow", + "glue", + "epipredict", + "epiprocess", + "purrr", + "rlang", + "glue", + "hewr" +) + +## load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + + +to_prop_forecast <- function(forecast_disease_count, + forecast_other_count, + disease_count_col = + "baseline_ed_visit_count_forecast", + other_count_col = + "other_ed_visits", + output_col = "prop_disease_ed_visits") { + result <- inner_join( + forecast_disease_count, + forecast_other_count, + by = c(".draw", "date") + ) |> + mutate( + !!output_col := + .data[[disease_count_col]] / + (.data[[disease_count_col]] + + .data[[other_count_col]]) + ) + + return(result) +} + + +#' Fit and Forecast Time Series Data +#' +#' This function fits a combination ensemble model to the training data and +#' generates forecast samples for a specified number of days. 
+#' +#' @param data A data frame containing the time series data. It should have a +#' column named `data_type` to distinguish between training and other data. +#' @param n_forecast_days An integer specifying the number of days to forecast. +#' Default is 28. +#' @param n_samples An integer specifying the number of forecast samples to +#' generate. Default is 2000. +#' @param target_col A string specifying the name of the target column in the +#' data. Default is "ed_visits". +#' @param output_col A string specifying the name of the output column for the +#' forecasted values. Default is "other_ed_visits". +# +#' @return A tibble containing the forecast samples with columns for date, +#' draw number, and forecasted values. +fit_and_forecast <- function(data, + n_forecast_days = 28, + n_samples = 2000, + target_col = "ed_visits", + output_col = "other_ed_visits") { + forecast_horizon <- glue("{n_forecast_days} days") + target_sym <- sym(target_col) + output_sym <- sym(output_col) + + max_visits <- data |> + pull(!!target_sym) |> + max(na.rm = TRUE) + offset <- 1 / max_visits + + fit <- + data |> + as_tsibble(index = date) |> + filter(data_type == "train") |> + model( + comb_model = combination_ensemble( + ETS(log(!!target_sym + offset) ~ trend(method = c("N", "M", "A"))), + ARIMA(log(!!target_sym + offset)) + ) + ) + + forecast_samples <- fit |> + generate(h = forecast_horizon, times = n_samples) |> + as_tibble() |> + mutate("{output_col}" := .sim, .draw = as.integer(.rep)) |> # nolint + select(date, .draw, !!output_sym) + + forecast_samples +} + +#' Generate CDC Flat Forecast +#' +#' This function generates a CDC flat forecast for the given data and returns +#' a data frame containing the forecasted values with columns for quantile +#' levels, reference dates, and target end dates suitable for use with +#' `scoringutils`. +#' +#' @param data A data frame containing the input data. 
+#' @param target_col A string specifying the column name of the target variable +#' in the data. Default is "ed_visits". +#' @param output_col A string specifying the column name for the output variable +#' in the forecast. Default is "other_ed_visits". +#' @param ... Additional arguments passed to the +#' `epipredict::cdc_baseline_args_list` function. +#' @return A data frame containing the forecasted values with columns for +#' quantile levels, (forecast) dates, and target values +cdc_flat_forecast <- function(data, + target_col = "ed_visits_target", + output_col = "cdc_flat_ed_visits", + ...) { + opts <- cdc_baseline_args_list(...) + # coerce data to epiprocess::epi_df format + epi_data <- data |> + filter(data_type == "train") |> + mutate(geo_value = "us", time_value = date) |> + as_epi_df() + # fit the model + cdc_flat_fit <- cdc_baseline_forecaster(epi_data, target_col, opts) + # generate forecast + cdc_flat_forecast <- cdc_flat_fit$predictions |> + pivot_quantiles_longer(.pred_distn) |> + mutate(!!output_col := values) |> + rename( + quantile_level = quantile_levels, report_date = forecast_date, + date = target_date + ) |> + select(date, quantile_level, all_of(output_col)) + + return(cdc_flat_forecast) +} + +main <- function(model_run_dir, n_forecast_days = 28, n_samples = 2000) { + # to do: do this with json data that has dates + data_path <- path(model_run_dir, "data", ext = "csv") + + target_and_other_data <- read_csv( + data_path, + col_types = cols( + disease = col_character(), + data_type = col_character(), + ed_visits = col_double(), + date = col_date() + ) + ) |> + mutate(disease = if_else( + disease == disease_name_nssp, + "Disease", disease + )) |> + pivot_wider(names_from = disease, values_from = ed_visits) |> + mutate(Other = Total - Disease) |> + select(date, + ed_visits_target = Disease, ed_visits_other = Other, + data_type + ) + ## Time series forecasting + ## Fit and forecast other (non-target-disease) ED visits using a combination + ## 
ensemble model + forecast_other <- fit_and_forecast( + target_and_other_data, + n_forecast_days, + n_samples, + target_col = "ed_visits_other", + output_col = "other_ed_visits" + ) + baseline_ts_count <- fit_and_forecast( + target_and_other_data, + n_forecast_days, + n_samples, + target_col = "ed_visits_target", + output_col = "baseline_ed_visit_count_forecast" + ) + ## Generate CDC flat forecast for the target disease number of ED visits + baseline_cdc_count <- cdc_flat_forecast( + target_and_other_data, + target_col = "ed_visits_target", + output_col = "baseline_ed_visit_count_forecast", + data_frequency = "1 day", + aheads = 1:n_forecast_days + ) + + baseline_ts_prop <- baseline_ts_count |> + to_prop_forecast(forecast_other) + + baseline_cdc_prop <- cdc_flat_forecast( + target_and_other_data |> + mutate(ed_visits_prop = ed_visits_target / + (ed_visits_target + ed_visits_other)), + target_col = "ed_visits_prop", + output_col = "baseline_ed_visit_prop_forecast", + data_frequency = "1 day", + aheads = 1:n_forecast_days + ) + + to_save <- tribble( + ~basename, ~value, + "other_ed_visits_forecast", forecast_other, + "baseline_ts_count_ed_visits_forecast", baseline_ts_count, + "baseline_ts_prop_ed_visits_forecast", baseline_ts_prop, + "baseline_cdc_count_ed_visits_forecast", baseline_cdc_count, + "baseline_cdc_prop_ed_visits_forecast", baseline_cdc_prop + ) |> + mutate(save_path = path( + !!model_run_dir, basename, + ext = "parquet" + )) + + + walk2( + to_save$value, + to_save$save_path, + write_parquet + ) +} + + +p <- arg_parser( + "Forecast other (non-target-disease) ED visits for a given location." 
+) |> + add_argument( + "model_run_dir", + help = "Directory containing the model data and output.", + ) |> + add_argument( + "--n-forecast-days", + help = "Number of days to forecast.", + default = 28L + ) |> + add_argument( + "--n-samples", + help = "Number of samples to generate.", + default = 2000L + ) + +argv <- parse_args(p) +model_run_dir <- path(argv$model_run_dir) +n_forecast_days <- argv$n_forecast_days +n_samples <- argv$n_samples + +disease_name_nssp_map <- c( + "covid-19" = "COVID-19", + "influenza" = "Influenza" +) + +disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease + +main(model_run_dir, n_forecast_days, n_samples) diff --git a/pipelines/utils.py b/pipelines/utils.py new file mode 100644 index 00000000..c18215cc --- /dev/null +++ b/pipelines/utils.py @@ -0,0 +1,135 @@ +""" +Python utilities for the NSSP ED visit forecasting +pipeline. +""" + +import datetime +import os +import re +from collections.abc import MutableSequence +from pathlib import Path + +disease_map_lower_ = {"influenza": "Influenza", "covid-19": "COVID-19"} + + +def ensure_listlike(x): + """ + Ensure that an object either behaves like a + :class:`MutableSequence` and if not return a + one-item :class:`list` containing the object. + + Useful for handling list-of-strings inputs + alongside single strings. + + Based on this _`StackOverflow approach + `. + + Parameters + ---------- + x + The item to ensure is :class:`list`-like. + + Returns + ------- + MutableSequence + ``x`` if ``x`` is a :class:`MutableSequence` + otherwise ``[x]`` (i.e. a one-item list containing + ``x``. + """ + return x if isinstance(x, MutableSequence) else [x] + + +def parse_model_batch_dir_name(model_batch_dir_name): + """ + Parse the name of a model batch directory, + returning a dictionary of parsed values. + + Parameters + ---------- + model_batch_dir_name + Model batch directory name to parse. 
+ + Returns + ------- + dict + A dictionary with keys 'disease', 'report_date', + 'first_training_date', and 'last_training_date'. + """ + regex_match = re.match(r"(.+)_r_(.+)_f_(.+)_t_(.+)", model_batch_dir_name) + if regex_match: + disease, report_date, first_training_date, last_training_date = ( + regex_match.groups() + ) + else: + raise ValueError( + "Invalid model batch directory name format: " + f"{model_batch_dir_name}" + ) + return dict( + disease=disease_map_lower_[disease], + report_date=datetime.strptime(report_date, "%Y-%m-%d").date(), + first_training_date=datetime.strptime( + first_training_date, "%Y-%m-%d" + ).date(), + last_training_date=datetime.strptime( + last_training_date, "%Y-%m-%d" + ).date(), + ) + + +def get_all_forecast_dirs( + parent_dir: Path | str, + diseases: str | list[str], + report_date: str | datetime.date = None, +) -> list[str]: + """ + Get all the subdirectories within a parent directory + that match the pattern for a forecast run for a + given disease and optionally a given report date. + + Parameters + ---------- + parent_dir + Directory in which to look for forecast subdirectories. + + diseases + Name of the diseases to match, as a list of strings, + or a single disease as a string. + + Returns + ------- + list[str] + Matching directories, if any, otherwise an empty + list. + + Raises + ------ + ValueError + Given an invalid ``report_date``. + """ + diseases = ensure_listlike(diseases) + + if report_date is None: + report_date_str = "" + elif isinstance(report_date, str): + report_date_str = report_date + elif isinstance(report_date, datetime.date): + report_date_str = f"{report_date:%Y-%m-%d}" + else: + raise ValueError( + "report_date must be one of None, " + "a string in the format YYYY-MM-DD " + "or a datetime.date instance. " + f"Got {type(report_date)}." 
+ ) + valid_starts = tuple( + [f"{disease.lower()}_r_{report_date_str}" for disease in diseases] + ) + # by convention, disease names are + # lowercase in directory patterns + + return [ + f.name + for f in os.scandir(parent_dir) + if f.is_dir() and f.name.startswith(valid_starts) + ] diff --git a/pyproject.toml b/pyproject.toml index bad4d0c7..db52edc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,12 @@ pyyaml = "^6.0.2" jupyter = "^1.0.0" ipykernel = "^6.29.5" polars = "^1.5.0" +pypdf = "^5.1.0" +pyarrow = "^18.0.0" +pygit2 = "^1.16.0" +[tool.poetry.group.azurebatch.dependencies] +azuretools = {git = "https://github.com/cdcent/cfa-stf-azuretools"} [build-system] requires = ["poetry-core"] diff --git a/pyrenew-hew.Rproj b/pyrenew-hew.Rproj deleted file mode 100644 index e83436a3..00000000 --- a/pyrenew-hew.Rproj +++ /dev/null @@ -1,16 +0,0 @@ -Version: 1.0 - -RestoreWorkspace: Default -SaveWorkspace: Default -AlwaysSaveHistory: Default - -EnableCodeIndexing: Yes -UseSpacesForTab: Yes -NumSpacesForTab: 2 -Encoding: UTF-8 - -RnwWeave: Sweave -LaTeX: pdfLaTeX - -AutoAppendNewline: Yes -StripTrailingWhitespace: Yes diff --git a/pyrenew_hew/hosp_only_ww_model.py b/pyrenew_hew/hosp_only_ww_model.py index 77f4a2e7..6ef5dd2f 100644 --- a/pyrenew_hew/hosp_only_ww_model.py +++ b/pyrenew_hew/hosp_only_ww_model.py @@ -85,23 +85,26 @@ def validate(self): # numpydoc ignore=GL08 def sample( self, n_datapoints=None, - data_observed_hospital_admissions=None, + data_observed_disease_hospital_admissions=None, right_truncation_offset=None, ): # numpydoc ignore=GL08 - if n_datapoints is None and data_observed_hospital_admissions is None: + if ( + n_datapoints is None + and data_observed_disease_hospital_admissions is None + ): raise ValueError( "Either n_datapoints or data_observed_hosp_admissions " "must be passed." 
) elif ( n_datapoints is not None - and data_observed_hospital_admissions is not None + and data_observed_disease_hospital_admissions is not None ): raise ValueError( - "Cannot pass both n_datapoints and data_observed_hospital_admissions." + "Cannot pass both n_datapoints and data_observed_disease_hospital_admissions." ) elif n_datapoints is None: - n_datapoints = len(data_observed_hospital_admissions) + n_datapoints = len(data_observed_disease_hospital_admissions) else: n_datapoints = n_datapoints @@ -225,7 +228,7 @@ def sample( observed_hospital_admissions = hospital_admission_obs_rv( mu=latent_hospital_admissions_now, - obs=data_observed_hospital_admissions, + obs=data_observed_disease_hospital_admissions, ) return observed_hospital_admissions @@ -360,7 +363,7 @@ def create_hosp_only_ww_model_from_stan_data(stan_data_file): uot = len(jnp.array(stan_data["inf_to_hosp"])) state_pop = stan_data["state_pop"] - data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) + data_observed_disease_hospital_admissions = jnp.array(stan_data["hosp"]) right_truncation_pmf_rv = DeterministicVariable( "right_truncation_pmf", jnp.array(1) ) @@ -384,4 +387,4 @@ def create_hosp_only_ww_model_from_stan_data(stan_data_file): n_initialization_points=uot, ) - return my_model, data_observed_hospital_admissions + return my_model, data_observed_disease_hospital_admissions diff --git a/renv.lock b/renv.lock deleted file mode 100644 index 1bf4c236..00000000 --- a/renv.lock +++ /dev/null @@ -1,2342 +0,0 @@ -{ - "R": { - "Version": "4.4.0", - "Repositories": [ - { - "Name": "CRAN", - "URL": "https://packagemanager.posit.co/cran/latest" - } - ] - }, - "Packages": { - "AzureAuth": { - "Package": "AzureAuth", - "Version": "1.3.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "httr", - "jose", - "jsonlite", - "openssl", - "rappdirs", - "utils" - ], - "Hash": "3ce531ce76e84cf7f86c4deb01f225ce" - }, - "AzureGraph": { - "Package": "AzureGraph", - 
"Version": "1.3.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "AzureAuth", - "R", - "R6", - "curl", - "httr", - "jsonlite", - "openssl", - "utils" - ], - "Hash": "9e5449c6a5d0a7d3a03097667ff63850" - }, - "AzureRMR": { - "Package": "AzureRMR", - "Version": "2.4.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "AzureAuth", - "AzureGraph", - "R", - "R6", - "httr", - "jsonlite", - "parallel", - "utils", - "uuid" - ], - "Hash": "5a24a1da5e363cb279807ccc67af511e" - }, - "AzureStor": { - "Package": "AzureStor", - "Version": "3.7.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "AzureRMR", - "R", - "R6", - "httr", - "mime", - "openssl", - "utils", - "vctrs", - "xml2" - ], - "Hash": "ddb01acbe698467420adec8580198916" - }, - "BH": { - "Package": "BH", - "Version": "1.84.0-0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "a8235afbcd6316e6e91433ea47661013" - }, - "CFAEpiNow2Pipeline": { - "Package": "CFAEpiNow2Pipeline", - "Version": "0.0.0.9000", - "Source": "GitHub", - "RemoteType": "github", - "RemoteHost": "api.github.com", - "RemoteUsername": "CDCgov", - "RemoteRepo": "cfa-epinow2-pipeline", - "RemoteRef": "main", - "RemoteSha": "c04c8f857e49530ffad946ea582f404a42098735", - "Remotes": "github::epiforecasts/EpiNow2@bcf297cf36a93cc56123bc3c9e8cebfb1421a962", - "Requirements": [ - "AzureRMR", - "AzureStor", - "DBI", - "EpiNow2", - "R", - "cli", - "duckdb", - "rlang" - ], - "Hash": "8e012d5b5d7114f9ff707973a74afd79" - }, - "DBI": { - "Package": "DBI", - "Version": "1.2.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "065ae649b05f1ff66bb0c793107508f5" - }, - "EpiNow2": { - "Package": "EpiNow2", - "Version": "1.4.9000", - "Source": "GitHub", - "RemoteType": "github", - "RemoteHost": "api.github.com", - "RemoteUsername": "epiforecasts", - "RemoteRepo": "EpiNow2", - "RemoteRef": "bcf297cf36a93cc56123bc3c9e8cebfb1421a962", - 
"RemoteSha": "bcf297cf36a93cc56123bc3c9e8cebfb1421a962", - "Requirements": [ - "BH", - "R", - "R.utils", - "Rcpp", - "RcppEigen", - "RcppParallel", - "StanHeaders", - "checkmate", - "data.table", - "futile.logger", - "future", - "future.apply", - "ggplot2", - "lifecycle", - "lubridate", - "methods", - "patchwork", - "progressr", - "purrr", - "rlang", - "rstan", - "rstantools", - "runner", - "scales", - "stats", - "truncnorm", - "utils" - ], - "Hash": "6f97be142a54e5ab72037e51b693f17b" - }, - "MASS": { - "Package": "MASS", - "Version": "7.3-60.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "methods", - "stats", - "utils" - ], - "Hash": "2f342c46163b0b54d7b64d1f798e2c78" - }, - "Matrix": { - "Package": "Matrix", - "Version": "1.7-0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "grid", - "lattice", - "methods", - "stats", - "utils" - ], - "Hash": "1920b2f11133b12350024297d8a4ff4a" - }, - "QuickJSR": { - "Package": "QuickJSR", - "Version": "1.3.1", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "46cd5be5141c4d4eedb1a6f89acf7c29" - }, - "R.methodsS3": { - "Package": "R.methodsS3", - "Version": "1.8.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "278c286fd6e9e75d0c2e8f731ea445c8" - }, - "R.oo": { - "Package": "R.oo", - "Version": "1.26.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R.methodsS3", - "methods", - "utils" - ], - "Hash": "4fed809e53ddb5407b3da3d0f572e591" - }, - "R.utils": { - "Package": "R.utils", - "Version": "2.12.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R.methodsS3", - "R.oo", - "methods", - "tools", - "utils" - ], - "Hash": "3dc2829b790254bfba21e60965787651" - }, - "R6": { - "Package": "R6", - "Version": "2.5.1", - "Source": "Repository", - "Repository": "CRAN", - 
"Requirements": [ - "R" - ], - "Hash": "470851b6d5d0ac559e9d01bb352b4021" - }, - "RColorBrewer": { - "Package": "RColorBrewer", - "Version": "1.1-3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "45f0398006e83a5b10b72a90663d8d8c" - }, - "Rcpp": { - "Package": "Rcpp", - "Version": "1.0.13", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods", - "utils" - ], - "Hash": "f27411eb6d9c3dada5edd444b8416675" - }, - "RcppEigen": { - "Package": "RcppEigen", - "Version": "0.3.4.0.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "stats", - "utils" - ], - "Hash": "4ac8e423216b8b70cb9653d1b3f71eb9" - }, - "RcppParallel": { - "Package": "RcppParallel", - "Version": "5.1.9", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "f38a72a419b91faac0ce5d9eee04c120" - }, - "RcppTOML": { - "Package": "RcppTOML", - "Version": "0.2.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp" - ], - "Hash": "c232938949fcd8126034419cc529333a" - }, - "StanHeaders": { - "Package": "StanHeaders", - "Version": "2.32.10", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "RcppEigen", - "RcppParallel" - ], - "Hash": "c35dc5b81d7ffb1018aa090dff364ecb" - }, - "abind": { - "Package": "abind", - "Version": "1.4-5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods", - "utils" - ], - "Hash": "4f57884290cc75ab22f4af9e9d4ca862" - }, - "arrayhelpers": { - "Package": "arrayhelpers", - "Version": "1.1-0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods", - "svUnit", - "utils" - ], - "Hash": "3d4e52d458784c335af3846f2de64f75" - }, - "askpass": { - "Package": "askpass", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "sys" - ], - "Hash": "cad6cf7f1d5f6e906700b9d3e718c796" - }, - 
"backports": { - "Package": "backports", - "Version": "1.5.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "e1e1b9d75c37401117b636b7ae50827a" - }, - "base64enc": { - "Package": "base64enc", - "Version": "0.1-3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "543776ae6848fde2f48ff3816d0628bc" - }, - "bit": { - "Package": "bit", - "Version": "4.0.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "d242abec29412ce988848d0294b208fd" - }, - "bit64": { - "Package": "bit64", - "Version": "4.0.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "bit", - "methods", - "stats", - "utils" - ], - "Hash": "9fe98599ca456d6552421db0d6772d8f" - }, - "blob": { - "Package": "blob", - "Version": "1.2.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods", - "rlang", - "vctrs" - ], - "Hash": "40415719b5a479b87949f3aa0aee737c" - }, - "bookdown": { - "Package": "bookdown", - "Version": "0.40", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "htmltools", - "jquerylib", - "knitr", - "rmarkdown", - "tinytex", - "xfun", - "yaml" - ], - "Hash": "896a79478a50c78fb035a37148638f4e" - }, - "broom": { - "Package": "broom", - "Version": "1.0.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "backports", - "dplyr", - "generics", - "glue", - "lifecycle", - "purrr", - "rlang", - "stringr", - "tibble", - "tidyr" - ], - "Hash": "a4652c36d1f8abfc3ddf4774f768c934" - }, - "bslib": { - "Package": "bslib", - "Version": "0.8.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "base64enc", - "cachem", - "fastmap", - "grDevices", - "htmltools", - "jquerylib", - "jsonlite", - "lifecycle", - "memoise", - "mime", - "rlang", - "sass" - ], - "Hash": "b299c6741ca9746fb227debcb0f9fb6c" - }, - "cachem": { - "Package": "cachem", - 
"Version": "1.1.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "fastmap", - "rlang" - ], - "Hash": "cd9a672193789068eb5a2aad65a0dedf" - }, - "callr": { - "Package": "callr", - "Version": "3.7.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "processx", - "utils" - ], - "Hash": "d7e13f49c19103ece9e58ad2d83a7354" - }, - "cellranger": { - "Package": "cellranger", - "Version": "1.1.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "rematch", - "tibble" - ], - "Hash": "f61dbaec772ccd2e17705c1e872e9e7c" - }, - "checkmate": { - "Package": "checkmate", - "Version": "2.3.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "backports", - "utils" - ], - "Hash": "0e14e01ce07e7c88fd25de6d4260d26b" - }, - "cli": { - "Package": "cli", - "Version": "3.6.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "b21916dd77a27642b447374a5d30ecf3" - }, - "clipr": { - "Package": "clipr", - "Version": "0.8.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "utils" - ], - "Hash": "3f038e5ac7f41d4ac41ce658c85e3042" - }, - "cmdstanr": { - "Package": "cmdstanr", - "Version": "0.8.1.9000", - "Source": "GitHub", - "RemoteType": "github", - "RemoteHost": "api.github.com", - "RemoteUsername": "stan-dev", - "RemoteRepo": "cmdstanr", - "RemoteRef": "master", - "RemoteSha": "f2e152b88fde5c2cde01ff078d5715b3b6248628", - "Requirements": [ - "R", - "R6", - "checkmate", - "data.table", - "jsonlite", - "posterior", - "processx", - "rlang", - "withr" - ], - "Hash": "1b4ced11b3b6c23f0e90a6becee4d983" - }, - "coda": { - "Package": "coda", - "Version": "0.19-4.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "lattice" - ], - "Hash": "af436915c590afc6fffc3ce3a5be1569" - }, - "codetools": { - "Package": "codetools", - "Version": "0.2-19", - "Source": "Repository", - 
"Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "c089a619a7fae175d149d89164f8c7d8" - }, - "colorspace": { - "Package": "colorspace", - "Version": "2.1-1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "methods", - "stats" - ], - "Hash": "d954cb1c57e8d8b756165d7ba18aa55a" - }, - "conflicted": { - "Package": "conflicted", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "memoise", - "rlang" - ], - "Hash": "bb097fccb22d156624fd07cd2894ddb6" - }, - "cowplot": { - "Package": "cowplot", - "Version": "1.1.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "ggplot2", - "grDevices", - "grid", - "gtable", - "methods", - "rlang", - "scales" - ], - "Hash": "8ef2084dd7d28847b374e55440e4f8cb" - }, - "cpp11": { - "Package": "cpp11", - "Version": "0.4.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "5a295d7d963cc5035284dcdbaf334f4e" - }, - "crayon": { - "Package": "crayon", - "Version": "1.5.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grDevices", - "methods", - "utils" - ], - "Hash": "859d96e65ef198fd43e82b9628d593ef" - }, - "curl": { - "Package": "curl", - "Version": "5.2.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "8f27335f2bcff4d6035edcc82d7d46de" - }, - "data.table": { - "Package": "data.table", - "Version": "1.15.4", - "Source": "Repository", - "Repository": "RSPM", - "Requirements": [ - "R", - "methods" - ], - "Hash": "8ee9ac56ef633d0c7cab8b2ca87d683e" - }, - "dbplyr": { - "Package": "dbplyr", - "Version": "2.5.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "DBI", - "R", - "R6", - "blob", - "cli", - "dplyr", - "glue", - "lifecycle", - "magrittr", - "methods", - "pillar", - "purrr", - "rlang", - "tibble", - "tidyr", - "tidyselect", - "utils", - 
"vctrs", - "withr" - ], - "Hash": "39b2e002522bfd258039ee4e889e0fd1" - }, - "desc": { - "Package": "desc", - "Version": "1.4.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "cli", - "utils" - ], - "Hash": "99b79fcbd6c4d1ce087f5c5c758b384f" - }, - "digest": { - "Package": "digest", - "Version": "0.6.37", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "33698c4b3127fc9f506654607fb73676" - }, - "distributional": { - "Package": "distributional", - "Version": "0.4.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "generics", - "lifecycle", - "numDeriv", - "rlang", - "stats", - "utils", - "vctrs" - ], - "Hash": "3bad76869f2257ea4fd00a3c08c2bcce" - }, - "dplyr": { - "Package": "dplyr", - "Version": "1.1.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "cli", - "generics", - "glue", - "lifecycle", - "magrittr", - "methods", - "pillar", - "rlang", - "tibble", - "tidyselect", - "utils", - "vctrs" - ], - "Hash": "fedd9d00c2944ff00a0e2696ccf048ec" - }, - "dtplyr": { - "Package": "dtplyr", - "Version": "1.3.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "data.table", - "dplyr", - "glue", - "lifecycle", - "rlang", - "tibble", - "tidyselect", - "vctrs" - ], - "Hash": "54ed3ea01b11e81a86544faaecfef8e2" - }, - "duckdb": { - "Package": "duckdb", - "Version": "1.0.0-2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "DBI", - "R", - "methods", - "utils" - ], - "Hash": "c68785a280aa69dbe449b3cf98fa3dd1" - }, - "evaluate": { - "Package": "evaluate", - "Version": "0.24.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "a1066cbc05caee9a4bf6d90f194ff4da" - }, - "fansi": { - "Package": "fansi", - "Version": "1.0.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - 
"grDevices", - "utils" - ], - "Hash": "962174cf2aeb5b9eea581522286a911f" - }, - "farver": { - "Package": "farver", - "Version": "2.1.2", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "680887028577f3fa2a81e410ed0d6e42" - }, - "fastmap": { - "Package": "fastmap", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "aa5e1cd11c2d15497494c5292d7ffcc8" - }, - "fontawesome": { - "Package": "fontawesome", - "Version": "0.5.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "htmltools", - "rlang" - ], - "Hash": "c2efdd5f0bcd1ea861c2d4e2a883a67d" - }, - "forcats": { - "Package": "forcats", - "Version": "1.0.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "magrittr", - "rlang", - "tibble" - ], - "Hash": "1a0a9a3d5083d0d573c4214576f1e690" - }, - "formatR": { - "Package": "formatR", - "Version": "1.14", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "63cb26d12517c7863f5abb006c5e0f25" - }, - "fs": { - "Package": "fs", - "Version": "1.6.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "15aeb8c27f5ea5161f9f6a641fafd93a" - }, - "futile.logger": { - "Package": "futile.logger", - "Version": "1.4.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "futile.options", - "lambda.r", - "utils" - ], - "Hash": "99f0ace8c05ec7d3683d27083c4f1e7e" - }, - "futile.options": { - "Package": "futile.options", - "Version": "1.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "0d9bf02413ddc2bbe8da9ce369dcdd2b" - }, - "future": { - "Package": "future", - "Version": "1.34.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "digest", - "globals", - "listenv", - "parallel", - "parallelly", - "utils" - ], - "Hash": "475771e3edb711591476be387c9a8c2e" - }, - 
"future.apply": { - "Package": "future.apply", - "Version": "1.11.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "future", - "globals", - "parallel", - "utils" - ], - "Hash": "afe1507511629f44572e6c53b9baeb7c" - }, - "gargle": { - "Package": "gargle", - "Version": "1.5.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "fs", - "glue", - "httr", - "jsonlite", - "lifecycle", - "openssl", - "rappdirs", - "rlang", - "stats", - "utils", - "withr" - ], - "Hash": "fc0b272e5847c58cd5da9b20eedbd026" - }, - "generics": { - "Package": "generics", - "Version": "0.1.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "15e9634c0fcd294799e9b2e929ed1b86" - }, - "ggdist": { - "Package": "ggdist", - "Version": "3.3.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "cli", - "distributional", - "ggplot2", - "glue", - "grid", - "gtable", - "numDeriv", - "quadprog", - "rlang", - "scales", - "tibble", - "vctrs", - "withr" - ], - "Hash": "86ebb3543cdad6520be9bf8863167a9a" - }, - "ggplot2": { - "Package": "ggplot2", - "Version": "3.5.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "MASS", - "R", - "cli", - "glue", - "grDevices", - "grid", - "gtable", - "isoband", - "lifecycle", - "mgcv", - "rlang", - "scales", - "stats", - "tibble", - "vctrs", - "withr" - ], - "Hash": "44c6a2f8202d5b7e878ea274b1092426" - }, - "globals": { - "Package": "globals", - "Version": "0.16.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "codetools" - ], - "Hash": "2580567908cafd4f187c1e5a91e98b7f" - }, - "glue": { - "Package": "glue", - "Version": "1.7.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "methods" - ], - "Hash": "e0b3a53876554bd45879e596cdb10a52" - }, - "googledrive": { - "Package": "googledrive", - "Version": "2.1.1", - "Source": 
"Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "gargle", - "glue", - "httr", - "jsonlite", - "lifecycle", - "magrittr", - "pillar", - "purrr", - "rlang", - "tibble", - "utils", - "uuid", - "vctrs", - "withr" - ], - "Hash": "e99641edef03e2a5e87f0a0b1fcc97f4" - }, - "googlesheets4": { - "Package": "googlesheets4", - "Version": "1.1.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cellranger", - "cli", - "curl", - "gargle", - "glue", - "googledrive", - "httr", - "ids", - "lifecycle", - "magrittr", - "methods", - "purrr", - "rematch2", - "rlang", - "tibble", - "utils", - "vctrs", - "withr" - ], - "Hash": "d6db1667059d027da730decdc214b959" - }, - "gridExtra": { - "Package": "gridExtra", - "Version": "2.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grDevices", - "graphics", - "grid", - "gtable", - "utils" - ], - "Hash": "7d7f283939f563670a697165b2cf5560" - }, - "gtable": { - "Package": "gtable", - "Version": "0.3.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "grid", - "lifecycle", - "rlang" - ], - "Hash": "e18861963cbc65a27736e02b3cd3c4a0" - }, - "haven": { - "Package": "haven", - "Version": "2.5.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "cpp11", - "forcats", - "hms", - "lifecycle", - "methods", - "readr", - "rlang", - "tibble", - "tidyselect", - "vctrs" - ], - "Hash": "9171f898db9d9c4c1b2c745adc2c1ef1" - }, - "here": { - "Package": "here", - "Version": "1.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "rprojroot" - ], - "Hash": "24b224366f9c2e7534d2344d10d59211" - }, - "highr": { - "Package": "highr", - "Version": "0.11", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "xfun" - ], - "Hash": "d65ba49117ca223614f71b60d85b8ab7" - }, - "hms": { - "Package": "hms", - "Version": "1.1.3", - "Source": "Repository", - 
"Repository": "CRAN", - "Requirements": [ - "lifecycle", - "methods", - "pkgconfig", - "rlang", - "vctrs" - ], - "Hash": "b59377caa7ed00fa41808342002138f9" - }, - "htmltools": { - "Package": "htmltools", - "Version": "0.5.8.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "base64enc", - "digest", - "fastmap", - "grDevices", - "rlang", - "utils" - ], - "Hash": "81d371a9cc60640e74e4ab6ac46dcedc" - }, - "httr": { - "Package": "httr", - "Version": "1.4.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "curl", - "jsonlite", - "mime", - "openssl" - ], - "Hash": "ac107251d9d9fd72f0ca8049988f1d7f" - }, - "ids": { - "Package": "ids", - "Version": "1.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "openssl", - "uuid" - ], - "Hash": "99df65cfef20e525ed38c3d2577f7190" - }, - "inline": { - "Package": "inline", - "Version": "0.3.19", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods" - ], - "Hash": "1deaf1de3eac7e1d3377954b3a283652" - }, - "isoband": { - "Package": "isoband", - "Version": "0.2.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "grid", - "utils" - ], - "Hash": "0080607b4a1a7b28979aecef976d8bc2" - }, - "jose": { - "Package": "jose", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "jsonlite", - "openssl" - ], - "Hash": "acf397c005e2d96a4b7616bf7dfc3112" - }, - "jquerylib": { - "Package": "jquerylib", - "Version": "0.1.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "htmltools" - ], - "Hash": "5aab57a3bd297eee1c1d862735972182" - }, - "jsonlite": { - "Package": "jsonlite", - "Version": "1.8.8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "methods" - ], - "Hash": "e1b9c55281c5adc4dd113652d9e26768" - }, - "knitr": { - "Package": "knitr", - "Version": "1.48", - "Source": "Repository", - "Repository": "CRAN", - 
"Requirements": [ - "R", - "evaluate", - "highr", - "methods", - "tools", - "xfun", - "yaml" - ], - "Hash": "acf380f300c721da9fde7df115a5f86f" - }, - "labeling": { - "Package": "labeling", - "Version": "0.4.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "graphics", - "stats" - ], - "Hash": "b64ec208ac5bc1852b285f665d6368b3" - }, - "lambda.r": { - "Package": "lambda.r", - "Version": "1.2.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "formatR" - ], - "Hash": "b1e925c4b9ffeb901bacf812cbe9a6ad" - }, - "lattice": { - "Package": "lattice", - "Version": "0.22-6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics", - "grid", - "stats", - "utils" - ], - "Hash": "cc5ac1ba4c238c7ca9fa6a87ca11a7e2" - }, - "lifecycle": { - "Package": "lifecycle", - "Version": "1.0.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "rlang" - ], - "Hash": "b8552d117e1b808b09a832f589b79035" - }, - "listenv": { - "Package": "listenv", - "Version": "0.9.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "e2fca3e12e4db979dccc6e519b10a7ee" - }, - "loo": { - "Package": "loo", - "Version": "2.8.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "checkmate", - "matrixStats", - "parallel", - "posterior", - "stats" - ], - "Hash": "b0fe731e5bd801dda962ac5057a548f6" - }, - "lubridate": { - "Package": "lubridate", - "Version": "1.9.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "generics", - "methods", - "timechange" - ], - "Hash": "680ad542fbcf801442c83a6ac5a2126c" - }, - "magrittr": { - "Package": "magrittr", - "Version": "2.0.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "7ce2733a9826b3aeb1775d56fd305472" - }, - "matrixStats": { - "Package": "matrixStats", - "Version": "1.3.0", - 
"Source": "Repository", - "Repository": "RSPM", - "Requirements": [ - "R" - ], - "Hash": "4b3ea27a19d669c0405b38134d89a9d1" - }, - "memoise": { - "Package": "memoise", - "Version": "2.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "cachem", - "rlang" - ], - "Hash": "e2817ccf4a065c5d9d7f2cfbe7c1d78c" - }, - "mgcv": { - "Package": "mgcv", - "Version": "1.9-1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Matrix", - "R", - "graphics", - "methods", - "nlme", - "splines", - "stats", - "utils" - ], - "Hash": "110ee9d83b496279960e162ac97764ce" - }, - "mime": { - "Package": "mime", - "Version": "0.12", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "tools" - ], - "Hash": "18e9c28c1d3ca1560ce30658b22ce104" - }, - "modelr": { - "Package": "modelr", - "Version": "0.1.11", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "broom", - "magrittr", - "purrr", - "rlang", - "tibble", - "tidyr", - "tidyselect", - "vctrs" - ], - "Hash": "4f50122dc256b1b6996a4703fecea821" - }, - "munsell": { - "Package": "munsell", - "Version": "0.5.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "colorspace", - "methods" - ], - "Hash": "4fd8900853b746af55b81fda99da7695" - }, - "nlme": { - "Package": "nlme", - "Version": "3.1-164", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "graphics", - "lattice", - "stats", - "utils" - ], - "Hash": "a623a2239e642806158bc4dc3f51565d" - }, - "numDeriv": { - "Package": "numDeriv", - "Version": "2016.8-1.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "df58958f293b166e4ab885ebcad90e02" - }, - "openssl": { - "Package": "openssl", - "Version": "2.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "askpass" - ], - "Hash": "c62edf62de70cadf40553e10c739049d" - }, - "parallelly": { - "Package": "parallelly", - "Version": "1.38.0", 
- "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "parallel", - "tools", - "utils" - ], - "Hash": "6e8b139c1904f5e9e14c69db64453bbe" - }, - "patchwork": { - "Package": "patchwork", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "cli", - "ggplot2", - "grDevices", - "graphics", - "grid", - "gtable", - "rlang", - "stats", - "utils" - ], - "Hash": "9c8ab14c00ac07e9e04d1664c0b74486" - }, - "pillar": { - "Package": "pillar", - "Version": "1.9.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "cli", - "fansi", - "glue", - "lifecycle", - "rlang", - "utf8", - "utils", - "vctrs" - ], - "Hash": "15da5a8412f317beeee6175fbc76f4bb" - }, - "pkgbuild": { - "Package": "pkgbuild", - "Version": "1.4.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "callr", - "cli", - "desc", - "processx" - ], - "Hash": "a29e8e134a460a01e0ca67a4763c595b" - }, - "pkgconfig": { - "Package": "pkgconfig", - "Version": "2.0.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "utils" - ], - "Hash": "01f28d4278f15c76cddbea05899c5d6f" - }, - "posterior": { - "Package": "posterior", - "Version": "1.6.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "abind", - "checkmate", - "distributional", - "matrixStats", - "methods", - "parallel", - "pillar", - "rlang", - "stats", - "tensorA", - "tibble", - "vctrs" - ], - "Hash": "fc1213566f2ed9f0b15bef656ed1000b" - }, - "prettyunits": { - "Package": "prettyunits", - "Version": "1.2.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "6b01fc98b1e86c4f705ce9dcfd2f57c7" - }, - "processx": { - "Package": "processx", - "Version": "3.8.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "ps", - "utils" - ], - "Hash": "0c90a7d71988856bad2a2a45dd871bb9" - }, - "progress": { - "Package": "progress", - 
"Version": "1.2.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "crayon", - "hms", - "prettyunits" - ], - "Hash": "f4625e061cb2865f111b47ff163a5ca6" - }, - "progressr": { - "Package": "progressr", - "Version": "0.14.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "digest", - "utils" - ], - "Hash": "ac50c4ffa8f6a46580dd4d7813add3c4" - }, - "ps": { - "Package": "ps", - "Version": "1.7.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "878b467580097e9c383acbb16adab57a" - }, - "purrr": { - "Package": "purrr", - "Version": "1.0.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "lifecycle", - "magrittr", - "rlang", - "vctrs" - ], - "Hash": "1cba04a4e9414bdefc9dcaa99649a8dc" - }, - "quadprog": { - "Package": "quadprog", - "Version": "1.5-8", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "5f919ae5e7f83a6f91dcf2288943370d" - }, - "ragg": { - "Package": "ragg", - "Version": "1.3.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "systemfonts", - "textshaping" - ], - "Hash": "e3087db406e079a8a2fd87f413918ed3" - }, - "rappdirs": { - "Package": "rappdirs", - "Version": "0.3.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "5e3c5dc0b071b21fa128676560dbe94d" - }, - "readr": { - "Package": "readr", - "Version": "2.1.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "cli", - "clipr", - "cpp11", - "crayon", - "hms", - "lifecycle", - "methods", - "rlang", - "tibble", - "tzdb", - "utils", - "vroom" - ], - "Hash": "9de96463d2117f6ac49980577939dfb3" - }, - "readxl": { - "Package": "readxl", - "Version": "1.4.3", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cellranger", - "cpp11", - "progress", - "tibble", - "utils" - ], - 
"Hash": "8cf9c239b96df1bbb133b74aef77ad0a" - }, - "rematch": { - "Package": "rematch", - "Version": "2.0.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "cbff1b666c6fa6d21202f07e2318d4f1" - }, - "rematch2": { - "Package": "rematch2", - "Version": "2.1.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "tibble" - ], - "Hash": "76c9e04c712a05848ae7a23d2f170a40" - }, - "renv": { - "Package": "renv", - "Version": "1.0.7", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "utils" - ], - "Hash": "397b7b2a265bc5a7a06852524dabae20" - }, - "reprex": { - "Package": "reprex", - "Version": "2.1.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "callr", - "cli", - "clipr", - "fs", - "glue", - "knitr", - "lifecycle", - "rlang", - "rmarkdown", - "rstudioapi", - "utils", - "withr" - ], - "Hash": "97b1d5361a24d9fb588db7afe3e5bcbf" - }, - "rlang": { - "Package": "rlang", - "Version": "1.1.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "3eec01f8b1dee337674b2e34ab1f9bc1" - }, - "rmarkdown": { - "Package": "rmarkdown", - "Version": "2.28", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "bslib", - "evaluate", - "fontawesome", - "htmltools", - "jquerylib", - "jsonlite", - "knitr", - "methods", - "tinytex", - "tools", - "utils", - "xfun", - "yaml" - ], - "Hash": "062470668513dcda416927085ee9bdc7" - }, - "rprojroot": { - "Package": "rprojroot", - "Version": "2.0.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "4c8415e0ec1e29f3f4f6fc108bef0144" - }, - "rstan": { - "Package": "rstan", - "Version": "2.32.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "BH", - "QuickJSR", - "R", - "Rcpp", - "RcppEigen", - "RcppParallel", - "StanHeaders", - "ggplot2", - "gridExtra", - "inline", - "loo", - "methods", - "pkgbuild", - "stats4" - ], - 
"Hash": "8a5b5978f888a3477c116e0395d006f8" - }, - "rstantools": { - "Package": "rstantools", - "Version": "2.4.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "Rcpp", - "RcppParallel", - "desc", - "stats", - "utils" - ], - "Hash": "23813e635fcd210c33e154aa46d0a21a" - }, - "rstudioapi": { - "Package": "rstudioapi", - "Version": "0.16.0", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "96710351d642b70e8f02ddeb237c46a7" - }, - "runner": { - "Package": "runner", - "Version": "0.4.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "Rcpp", - "methods", - "parallel" - ], - "Hash": "28a8c14ad9f77d5b275938c65128c7e7" - }, - "rvest": { - "Package": "rvest", - "Version": "1.0.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "httr", - "lifecycle", - "magrittr", - "rlang", - "selectr", - "tibble", - "xml2" - ], - "Hash": "0bcf0c6f274e90ea314b812a6d19a519" - }, - "sass": { - "Package": "sass", - "Version": "0.4.9", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R6", - "fs", - "htmltools", - "rappdirs", - "rlang" - ], - "Hash": "d53dbfddf695303ea4ad66f86e99b95d" - }, - "scales": { - "Package": "scales", - "Version": "1.3.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "RColorBrewer", - "cli", - "farver", - "glue", - "labeling", - "lifecycle", - "munsell", - "rlang", - "viridisLite" - ], - "Hash": "c19df082ba346b0ffa6f833e92de34d1" - }, - "selectr": { - "Package": "selectr", - "Version": "0.4-2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "R6", - "methods", - "stringr" - ], - "Hash": "3838071b66e0c566d55cc26bd6e27bf4" - }, - "stringi": { - "Package": "stringi", - "Version": "1.8.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats", - "tools", - "utils" - ], - "Hash": "39e1144fd75428983dc3f63aa53dfa91" - }, - 
"stringr": { - "Package": "stringr", - "Version": "1.5.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "magrittr", - "rlang", - "stringi", - "vctrs" - ], - "Hash": "960e2ae9e09656611e0b8214ad543207" - }, - "svUnit": { - "Package": "svUnit", - "Version": "1.0.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "utils" - ], - "Hash": "894d0aff9585344af4ab78e2b4d60ab7" - }, - "sys": { - "Package": "sys", - "Version": "3.4.2", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "3a1be13d68d47a8cd0bfd74739ca1555" - }, - "systemfonts": { - "Package": "systemfonts", - "Version": "1.1.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cpp11", - "lifecycle" - ], - "Hash": "213b6b8ed5afbf934843e6c3b090d418" - }, - "tensorA": { - "Package": "tensorA", - "Version": "0.36.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "stats" - ], - "Hash": "0d587599172f2ffda2c09cb6b854e0e5" - }, - "textshaping": { - "Package": "textshaping", - "Version": "0.4.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cpp11", - "lifecycle", - "systemfonts" - ], - "Hash": "5142f8bc78ed3d819d26461b641627ce" - }, - "tibble": { - "Package": "tibble", - "Version": "3.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "fansi", - "lifecycle", - "magrittr", - "methods", - "pillar", - "pkgconfig", - "rlang", - "utils", - "vctrs" - ], - "Hash": "a84e2cc86d07289b3b6f5069df7a004c" - }, - "tidybayes": { - "Package": "tidybayes", - "Version": "3.0.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "arrayhelpers", - "cli", - "coda", - "dplyr", - "ggdist", - "ggplot2", - "magrittr", - "methods", - "posterior", - "rlang", - "tibble", - "tidyr", - "tidyselect", - "vctrs", - "withr" - ], - "Hash": "3853596458d8b9c81f7b39ee920c5319" - }, - "tidyr": 
{ - "Package": "tidyr", - "Version": "1.3.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "cpp11", - "dplyr", - "glue", - "lifecycle", - "magrittr", - "purrr", - "rlang", - "stringr", - "tibble", - "tidyselect", - "utils", - "vctrs" - ], - "Hash": "915fb7ce036c22a6a33b5a8adb712eb1" - }, - "tidyselect": { - "Package": "tidyselect", - "Version": "1.2.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "rlang", - "vctrs", - "withr" - ], - "Hash": "829f27b9c4919c16b593794a6344d6c0" - }, - "tidyverse": { - "Package": "tidyverse", - "Version": "2.0.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "broom", - "cli", - "conflicted", - "dbplyr", - "dplyr", - "dtplyr", - "forcats", - "ggplot2", - "googledrive", - "googlesheets4", - "haven", - "hms", - "httr", - "jsonlite", - "lubridate", - "magrittr", - "modelr", - "pillar", - "purrr", - "ragg", - "readr", - "readxl", - "reprex", - "rlang", - "rstudioapi", - "rvest", - "stringr", - "tibble", - "tidyr", - "xml2" - ], - "Hash": "c328568cd14ea89a83bd4ca7f54ae07e" - }, - "timechange": { - "Package": "timechange", - "Version": "0.3.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cpp11" - ], - "Hash": "c5f3c201b931cd6474d17d8700ccb1c8" - }, - "tinytex": { - "Package": "tinytex", - "Version": "0.52", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "xfun" - ], - "Hash": "cfbad971a71f0e27cec22e544a08bc3b" - }, - "truncnorm": { - "Package": "truncnorm", - "Version": "1.0-9", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "ef5b32c5194351ff409dfb37ca9468f1" - }, - "tzdb": { - "Package": "tzdb", - "Version": "0.4.0", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cpp11" - ], - "Hash": "f561504ec2897f4d46f0c7657e488ae1" - }, - "utf8": { - "Package": "utf8", - 
"Version": "1.2.4", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "62b65c52671e6665f803ff02954446e9" - }, - "uuid": { - "Package": "uuid", - "Version": "1.2-1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "34e965e62a41fcafb1ca60e9b142085b" - }, - "vctrs": { - "Package": "vctrs", - "Version": "0.6.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "glue", - "lifecycle", - "rlang" - ], - "Hash": "c03fa420630029418f7e6da3667aac4a" - }, - "viridisLite": { - "Package": "viridisLite", - "Version": "0.4.2", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R" - ], - "Hash": "c826c7c4241b6fc89ff55aaea3fa7491" - }, - "vroom": { - "Package": "vroom", - "Version": "1.6.5", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "bit64", - "cli", - "cpp11", - "crayon", - "glue", - "hms", - "lifecycle", - "methods", - "progress", - "rlang", - "stats", - "tibble", - "tidyselect", - "tzdb", - "vctrs", - "withr" - ], - "Hash": "390f9315bc0025be03012054103d227c" - }, - "withr": { - "Package": "withr", - "Version": "3.0.1", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "graphics" - ], - "Hash": "07909200e8bbe90426fbfeb73e1e27aa" - }, - "wwinference": { - "Package": "wwinference", - "Version": "0.0.0.9000", - "Source": "GitHub", - "RemoteType": "github", - "RemoteHost": "api.github.com", - "RemoteUsername": "CDCgov", - "RemoteRepo": "ww-inference-model", - "RemoteRef": "main", - "RemoteSha": "fd76331687daf3ba727eae1aa5d999124289518e", - "Remotes": "stan-dev/cmdstanr", - "Requirements": [ - "R", - "RcppTOML", - "checkmate", - "cli", - "cmdstanr", - "dplyr", - "fs", - "ggplot2", - "glue", - "lubridate", - "posterior", - "purrr", - "rlang", - "scales", - "tibble", - "tidybayes", - "tidyr", - "withr" - ], - "Hash": "43bd8fe6dc5294bc60ab514345e25b8c" - }, - "xfun": 
{ - "Package": "xfun", - "Version": "0.47", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "grDevices", - "stats", - "tools" - ], - "Hash": "36ab21660e2d095fef0d83f689e0477c" - }, - "xml2": { - "Package": "xml2", - "Version": "1.3.6", - "Source": "Repository", - "Repository": "CRAN", - "Requirements": [ - "R", - "cli", - "methods", - "rlang" - ], - "Hash": "1d0336142f4cd25d8d23cd3ba7a8fb61" - }, - "yaml": { - "Package": "yaml", - "Version": "2.3.10", - "Source": "Repository", - "Repository": "CRAN", - "Hash": "51dab85c6c98e50a18d7551e9d49f76c" - } - } -} diff --git a/renv/.gitignore b/renv/.gitignore deleted file mode 100644 index 0ec0cbba..00000000 --- a/renv/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -library/ -local/ -cellar/ -lock/ -python/ -sandbox/ -staging/ diff --git a/renv/activate.R b/renv/activate.R deleted file mode 100644 index 29143208..00000000 --- a/renv/activate.R +++ /dev/null @@ -1,1220 +0,0 @@ - -local({ - - # the requested version of renv - version <- "1.0.7" - attr(version, "sha") <- NULL - - # the project directory - project <- Sys.getenv("RENV_PROJECT") - if (!nzchar(project)) - project <- getwd() - - # use start-up diagnostics if enabled - diagnostics <- Sys.getenv("RENV_STARTUP_DIAGNOSTICS", unset = "FALSE") - if (diagnostics) { - start <- Sys.time() - profile <- tempfile("renv-startup-", fileext = ".Rprof") - utils::Rprof(profile) - on.exit({ - utils::Rprof(NULL) - elapsed <- signif(difftime(Sys.time(), start, units = "auto"), digits = 2L) - writeLines(sprintf("- renv took %s to run the autoloader.", format(elapsed))) - writeLines(sprintf("- Profile: %s", profile)) - print(utils::summaryRprof(profile)) - }, add = TRUE) - } - - # figure out whether the autoloader is enabled - enabled <- local({ - - # first, check config option - override <- getOption("renv.config.autoloader.enabled") - if (!is.null(override)) - return(override) - - # if we're being run in a context where R_LIBS is already set, - # don't load 
-- presumably we're being run as a sub-process and - # the parent process has already set up library paths for us - rcmd <- Sys.getenv("R_CMD", unset = NA) - rlibs <- Sys.getenv("R_LIBS", unset = NA) - if (!is.na(rlibs) && !is.na(rcmd)) - return(FALSE) - - # next, check environment variables - # TODO: prefer using the configuration one in the future - envvars <- c( - "RENV_CONFIG_AUTOLOADER_ENABLED", - "RENV_AUTOLOADER_ENABLED", - "RENV_ACTIVATE_PROJECT" - ) - - for (envvar in envvars) { - envval <- Sys.getenv(envvar, unset = NA) - if (!is.na(envval)) - return(tolower(envval) %in% c("true", "t", "1")) - } - - # enable by default - TRUE - - }) - - # bail if we're not enabled - if (!enabled) { - - # if we're not enabled, we might still need to manually load - # the user profile here - profile <- Sys.getenv("R_PROFILE_USER", unset = "~/.Rprofile") - if (file.exists(profile)) { - cfg <- Sys.getenv("RENV_CONFIG_USER_PROFILE", unset = "TRUE") - if (tolower(cfg) %in% c("true", "t", "1")) - sys.source(profile, envir = globalenv()) - } - - return(FALSE) - - } - - # avoid recursion - if (identical(getOption("renv.autoloader.running"), TRUE)) { - warning("ignoring recursive attempt to run renv autoloader") - return(invisible(TRUE)) - } - - # signal that we're loading renv during R startup - options(renv.autoloader.running = TRUE) - on.exit(options(renv.autoloader.running = NULL), add = TRUE) - - # signal that we've consented to use renv - options(renv.consent = TRUE) - - # load the 'utils' package eagerly -- this ensures that renv shims, which - # mask 'utils' packages, will come first on the search path - library(utils, lib.loc = .Library) - - # unload renv if it's already been loaded - if ("renv" %in% loadedNamespaces()) - unloadNamespace("renv") - - # load bootstrap tools - `%||%` <- function(x, y) { - if (is.null(x)) y else x - } - - catf <- function(fmt, ..., appendLF = TRUE) { - - quiet <- getOption("renv.bootstrap.quiet", default = FALSE) - if (quiet) - 
return(invisible()) - - msg <- sprintf(fmt, ...) - cat(msg, file = stdout(), sep = if (appendLF) "\n" else "") - - invisible(msg) - - } - - header <- function(label, - ..., - prefix = "#", - suffix = "-", - n = min(getOption("width"), 78)) - { - label <- sprintf(label, ...) - n <- max(n - nchar(label) - nchar(prefix) - 2L, 8L) - if (n <= 0) - return(paste(prefix, label)) - - tail <- paste(rep.int(suffix, n), collapse = "") - paste0(prefix, " ", label, " ", tail) - - } - - heredoc <- function(text, leave = 0) { - - # remove leading, trailing whitespace - trimmed <- gsub("^\\s*\\n|\\n\\s*$", "", text) - - # split into lines - lines <- strsplit(trimmed, "\n", fixed = TRUE)[[1L]] - - # compute common indent - indent <- regexpr("[^[:space:]]", lines) - common <- min(setdiff(indent, -1L)) - leave - paste(substring(lines, common), collapse = "\n") - - } - - startswith <- function(string, prefix) { - substring(string, 1, nchar(prefix)) == prefix - } - - bootstrap <- function(version, library) { - - friendly <- renv_bootstrap_version_friendly(version) - section <- header(sprintf("Bootstrapping renv %s", friendly)) - catf(section) - - # attempt to download renv - catf("- Downloading renv ... ", appendLF = FALSE) - withCallingHandlers( - tarball <- renv_bootstrap_download(version), - error = function(err) { - catf("FAILED") - stop("failed to download:\n", conditionMessage(err)) - } - ) - catf("OK") - on.exit(unlink(tarball), add = TRUE) - - # now attempt to install - catf("- Installing renv ... 
", appendLF = FALSE) - withCallingHandlers( - status <- renv_bootstrap_install(version, tarball, library), - error = function(err) { - catf("FAILED") - stop("failed to install:\n", conditionMessage(err)) - } - ) - catf("OK") - - # add empty line to break up bootstrapping from normal output - catf("") - - return(invisible()) - } - - renv_bootstrap_tests_running <- function() { - getOption("renv.tests.running", default = FALSE) - } - - renv_bootstrap_repos <- function() { - - # get CRAN repository - cran <- getOption("renv.repos.cran", "https://cloud.r-project.org") - - # check for repos override - repos <- Sys.getenv("RENV_CONFIG_REPOS_OVERRIDE", unset = NA) - if (!is.na(repos)) { - - # check for RSPM; if set, use a fallback repository for renv - rspm <- Sys.getenv("RSPM", unset = NA) - if (identical(rspm, repos)) - repos <- c(RSPM = rspm, CRAN = cran) - - return(repos) - - } - - # check for lockfile repositories - repos <- tryCatch(renv_bootstrap_repos_lockfile(), error = identity) - if (!inherits(repos, "error") && length(repos)) - return(repos) - - # retrieve current repos - repos <- getOption("repos") - - # ensure @CRAN@ entries are resolved - repos[repos == "@CRAN@"] <- cran - - # add in renv.bootstrap.repos if set - default <- c(FALLBACK = "https://cloud.r-project.org") - extra <- getOption("renv.bootstrap.repos", default = default) - repos <- c(repos, extra) - - # remove duplicates that might've snuck in - dupes <- duplicated(repos) | duplicated(names(repos)) - repos[!dupes] - - } - - renv_bootstrap_repos_lockfile <- function() { - - lockpath <- Sys.getenv("RENV_PATHS_LOCKFILE", unset = "renv.lock") - if (!file.exists(lockpath)) - return(NULL) - - lockfile <- tryCatch(renv_json_read(lockpath), error = identity) - if (inherits(lockfile, "error")) { - warning(lockfile) - return(NULL) - } - - repos <- lockfile$R$Repositories - if (length(repos) == 0) - return(NULL) - - keys <- vapply(repos, `[[`, "Name", FUN.VALUE = character(1)) - vals <- vapply(repos, `[[`, 
"URL", FUN.VALUE = character(1)) - names(vals) <- keys - - return(vals) - - } - - renv_bootstrap_download <- function(version) { - - sha <- attr(version, "sha", exact = TRUE) - - methods <- if (!is.null(sha)) { - - # attempting to bootstrap a development version of renv - c( - function() renv_bootstrap_download_tarball(sha), - function() renv_bootstrap_download_github(sha) - ) - - } else { - - # attempting to bootstrap a release version of renv - c( - function() renv_bootstrap_download_tarball(version), - function() renv_bootstrap_download_cran_latest(version), - function() renv_bootstrap_download_cran_archive(version) - ) - - } - - for (method in methods) { - path <- tryCatch(method(), error = identity) - if (is.character(path) && file.exists(path)) - return(path) - } - - stop("All download methods failed") - - } - - renv_bootstrap_download_impl <- function(url, destfile) { - - mode <- "wb" - - # https://bugs.r-project.org/bugzilla/show_bug.cgi?id=17715 - fixup <- - Sys.info()[["sysname"]] == "Windows" && - substring(url, 1L, 5L) == "file:" - - if (fixup) - mode <- "w+b" - - args <- list( - url = url, - destfile = destfile, - mode = mode, - quiet = TRUE - ) - - if ("headers" %in% names(formals(utils::download.file))) - args$headers <- renv_bootstrap_download_custom_headers(url) - - do.call(utils::download.file, args) - - } - - renv_bootstrap_download_custom_headers <- function(url) { - - headers <- getOption("renv.download.headers") - if (is.null(headers)) - return(character()) - - if (!is.function(headers)) - stopf("'renv.download.headers' is not a function") - - headers <- headers(url) - if (length(headers) == 0L) - return(character()) - - if (is.list(headers)) - headers <- unlist(headers, recursive = FALSE, use.names = TRUE) - - ok <- - is.character(headers) && - is.character(names(headers)) && - all(nzchar(names(headers))) - - if (!ok) - stop("invocation of 'renv.download.headers' did not return a named character vector") - - headers - - } - - 
renv_bootstrap_download_cran_latest <- function(version) { - - spec <- renv_bootstrap_download_cran_latest_find(version) - type <- spec$type - repos <- spec$repos - - baseurl <- utils::contrib.url(repos = repos, type = type) - ext <- if (identical(type, "source")) - ".tar.gz" - else if (Sys.info()[["sysname"]] == "Windows") - ".zip" - else - ".tgz" - name <- sprintf("renv_%s%s", version, ext) - url <- paste(baseurl, name, sep = "/") - - destfile <- file.path(tempdir(), name) - status <- tryCatch( - renv_bootstrap_download_impl(url, destfile), - condition = identity - ) - - if (inherits(status, "condition")) - return(FALSE) - - # report success and return - destfile - - } - - renv_bootstrap_download_cran_latest_find <- function(version) { - - # check whether binaries are supported on this system - binary <- - getOption("renv.bootstrap.binary", default = TRUE) && - !identical(.Platform$pkgType, "source") && - !identical(getOption("pkgType"), "source") && - Sys.info()[["sysname"]] %in% c("Darwin", "Windows") - - types <- c(if (binary) "binary", "source") - - # iterate over types + repositories - for (type in types) { - for (repos in renv_bootstrap_repos()) { - - # retrieve package database - db <- tryCatch( - as.data.frame( - utils::available.packages(type = type, repos = repos), - stringsAsFactors = FALSE - ), - error = identity - ) - - if (inherits(db, "error")) - next - - # check for compatible entry - entry <- db[db$Package %in% "renv" & db$Version %in% version, ] - if (nrow(entry) == 0) - next - - # found it; return spec to caller - spec <- list(entry = entry, type = type, repos = repos) - return(spec) - - } - } - - # if we got here, we failed to find renv - fmt <- "renv %s is not available from your declared package repositories" - stop(sprintf(fmt, version)) - - } - - renv_bootstrap_download_cran_archive <- function(version) { - - name <- sprintf("renv_%s.tar.gz", version) - repos <- renv_bootstrap_repos() - urls <- file.path(repos, "src/contrib/Archive/renv", 
name) - destfile <- file.path(tempdir(), name) - - for (url in urls) { - - status <- tryCatch( - renv_bootstrap_download_impl(url, destfile), - condition = identity - ) - - if (identical(status, 0L)) - return(destfile) - - } - - return(FALSE) - - } - - renv_bootstrap_download_tarball <- function(version) { - - # if the user has provided the path to a tarball via - # an environment variable, then use it - tarball <- Sys.getenv("RENV_BOOTSTRAP_TARBALL", unset = NA) - if (is.na(tarball)) - return() - - # allow directories - if (dir.exists(tarball)) { - name <- sprintf("renv_%s.tar.gz", version) - tarball <- file.path(tarball, name) - } - - # bail if it doesn't exist - if (!file.exists(tarball)) { - - # let the user know we weren't able to honour their request - fmt <- "- RENV_BOOTSTRAP_TARBALL is set (%s) but does not exist." - msg <- sprintf(fmt, tarball) - warning(msg) - - # bail - return() - - } - - catf("- Using local tarball '%s'.", tarball) - tarball - - } - - renv_bootstrap_download_github <- function(version) { - - enabled <- Sys.getenv("RENV_BOOTSTRAP_FROM_GITHUB", unset = "TRUE") - if (!identical(enabled, "TRUE")) - return(FALSE) - - # prepare download options - pat <- Sys.getenv("GITHUB_PAT") - if (nzchar(Sys.which("curl")) && nzchar(pat)) { - fmt <- "--location --fail --header \"Authorization: token %s\"" - extra <- sprintf(fmt, pat) - saved <- options("download.file.method", "download.file.extra") - options(download.file.method = "curl", download.file.extra = extra) - on.exit(do.call(base::options, saved), add = TRUE) - } else if (nzchar(Sys.which("wget")) && nzchar(pat)) { - fmt <- "--header=\"Authorization: token %s\"" - extra <- sprintf(fmt, pat) - saved <- options("download.file.method", "download.file.extra") - options(download.file.method = "wget", download.file.extra = extra) - on.exit(do.call(base::options, saved), add = TRUE) - } - - url <- file.path("https://api.github.com/repos/rstudio/renv/tarball", version) - name <- sprintf("renv_%s.tar.gz", 
version) - destfile <- file.path(tempdir(), name) - - status <- tryCatch( - renv_bootstrap_download_impl(url, destfile), - condition = identity - ) - - if (!identical(status, 0L)) - return(FALSE) - - renv_bootstrap_download_augment(destfile) - - return(destfile) - - } - - # Add Sha to DESCRIPTION. This is stop gap until #890, after which we - # can use renv::install() to fully capture metadata. - renv_bootstrap_download_augment <- function(destfile) { - sha <- renv_bootstrap_git_extract_sha1_tar(destfile) - if (is.null(sha)) { - return() - } - - # Untar - tempdir <- tempfile("renv-github-") - on.exit(unlink(tempdir, recursive = TRUE), add = TRUE) - untar(destfile, exdir = tempdir) - pkgdir <- dir(tempdir, full.names = TRUE)[[1]] - - # Modify description - desc_path <- file.path(pkgdir, "DESCRIPTION") - desc_lines <- readLines(desc_path) - remotes_fields <- c( - "RemoteType: github", - "RemoteHost: api.github.com", - "RemoteRepo: renv", - "RemoteUsername: rstudio", - "RemotePkgRef: rstudio/renv", - paste("RemoteRef: ", sha), - paste("RemoteSha: ", sha) - ) - writeLines(c(desc_lines[desc_lines != ""], remotes_fields), con = desc_path) - - # Re-tar - local({ - old <- setwd(tempdir) - on.exit(setwd(old), add = TRUE) - - tar(destfile, compression = "gzip") - }) - invisible() - } - - # Extract the commit hash from a git archive. Git archives include the SHA1 - # hash as the comment field of the tarball pax extended header - # (see https://www.kernel.org/pub/software/scm/git/docs/git-archive.html) - # For GitHub archives this should be the first header after the default one - # (512 byte) header. 
- renv_bootstrap_git_extract_sha1_tar <- function(bundle) { - - # open the bundle for reading - # We use gzcon for everything because (from ?gzcon) - # > Reading from a connection which does not supply a 'gzip' magic - # > header is equivalent to reading from the original connection - conn <- gzcon(file(bundle, open = "rb", raw = TRUE)) - on.exit(close(conn)) - - # The default pax header is 512 bytes long and the first pax extended header - # with the comment should be 51 bytes long - # `52 comment=` (11 chars) + 40 byte SHA1 hash - len <- 0x200 + 0x33 - res <- rawToChar(readBin(conn, "raw", n = len)[0x201:len]) - - if (grepl("^52 comment=", res)) { - sub("52 comment=", "", res) - } else { - NULL - } - } - - renv_bootstrap_install <- function(version, tarball, library) { - - # attempt to install it into project library - dir.create(library, showWarnings = FALSE, recursive = TRUE) - output <- renv_bootstrap_install_impl(library, tarball) - - # check for successful install - status <- attr(output, "status") - if (is.null(status) || identical(status, 0L)) - return(status) - - # an error occurred; report it - header <- "installation of renv failed" - lines <- paste(rep.int("=", nchar(header)), collapse = "") - text <- paste(c(header, lines, output), collapse = "\n") - stop(text) - - } - - renv_bootstrap_install_impl <- function(library, tarball) { - - # invoke using system2 so we can capture and report output - bin <- R.home("bin") - exe <- if (Sys.info()[["sysname"]] == "Windows") "R.exe" else "R" - R <- file.path(bin, exe) - - args <- c( - "--vanilla", "CMD", "INSTALL", "--no-multiarch", - "-l", shQuote(path.expand(library)), - shQuote(path.expand(tarball)) - ) - - system2(R, args, stdout = TRUE, stderr = TRUE) - - } - - renv_bootstrap_platform_prefix <- function() { - - # construct version prefix - version <- paste(R.version$major, R.version$minor, sep = ".") - prefix <- paste("R", numeric_version(version)[1, 1:2], sep = "-") - - # include SVN revision for 
development versions of R - # (to avoid sharing platform-specific artefacts with released versions of R) - devel <- - identical(R.version[["status"]], "Under development (unstable)") || - identical(R.version[["nickname"]], "Unsuffered Consequences") - - if (devel) - prefix <- paste(prefix, R.version[["svn rev"]], sep = "-r") - - # build list of path components - components <- c(prefix, R.version$platform) - - # include prefix if provided by user - prefix <- renv_bootstrap_platform_prefix_impl() - if (!is.na(prefix) && nzchar(prefix)) - components <- c(prefix, components) - - # build prefix - paste(components, collapse = "/") - - } - - renv_bootstrap_platform_prefix_impl <- function() { - - # if an explicit prefix has been supplied, use it - prefix <- Sys.getenv("RENV_PATHS_PREFIX", unset = NA) - if (!is.na(prefix)) - return(prefix) - - # if the user has requested an automatic prefix, generate it - auto <- Sys.getenv("RENV_PATHS_PREFIX_AUTO", unset = NA) - if (is.na(auto) && getRversion() >= "4.4.0") - auto <- "TRUE" - - if (auto %in% c("TRUE", "True", "true", "1")) - return(renv_bootstrap_platform_prefix_auto()) - - # empty string on failure - "" - - } - - renv_bootstrap_platform_prefix_auto <- function() { - - prefix <- tryCatch(renv_bootstrap_platform_os(), error = identity) - if (inherits(prefix, "error") || prefix %in% "unknown") { - - msg <- paste( - "failed to infer current operating system", - "please file a bug report at https://github.com/rstudio/renv/issues", - sep = "; " - ) - - warning(msg) - - } - - prefix - - } - - renv_bootstrap_platform_os <- function() { - - sysinfo <- Sys.info() - sysname <- sysinfo[["sysname"]] - - # handle Windows + macOS up front - if (sysname == "Windows") - return("windows") - else if (sysname == "Darwin") - return("macos") - - # check for os-release files - for (file in c("/etc/os-release", "/usr/lib/os-release")) - if (file.exists(file)) - return(renv_bootstrap_platform_os_via_os_release(file, sysinfo)) - - # check for 
redhat-release files - if (file.exists("/etc/redhat-release")) - return(renv_bootstrap_platform_os_via_redhat_release()) - - "unknown" - - } - - renv_bootstrap_platform_os_via_os_release <- function(file, sysinfo) { - - # read /etc/os-release - release <- utils::read.table( - file = file, - sep = "=", - quote = c("\"", "'"), - col.names = c("Key", "Value"), - comment.char = "#", - stringsAsFactors = FALSE - ) - - vars <- as.list(release$Value) - names(vars) <- release$Key - - # get os name - os <- tolower(sysinfo[["sysname"]]) - - # read id - id <- "unknown" - for (field in c("ID", "ID_LIKE")) { - if (field %in% names(vars) && nzchar(vars[[field]])) { - id <- vars[[field]] - break - } - } - - # read version - version <- "unknown" - for (field in c("UBUNTU_CODENAME", "VERSION_CODENAME", "VERSION_ID", "BUILD_ID")) { - if (field %in% names(vars) && nzchar(vars[[field]])) { - version <- vars[[field]] - break - } - } - - # join together - paste(c(os, id, version), collapse = "-") - - } - - renv_bootstrap_platform_os_via_redhat_release <- function() { - - # read /etc/redhat-release - contents <- readLines("/etc/redhat-release", warn = FALSE) - - # infer id - id <- if (grepl("centos", contents, ignore.case = TRUE)) - "centos" - else if (grepl("redhat", contents, ignore.case = TRUE)) - "redhat" - else - "unknown" - - # try to find a version component (very hacky) - version <- "unknown" - - parts <- strsplit(contents, "[[:space:]]")[[1L]] - for (part in parts) { - - nv <- tryCatch(numeric_version(part), error = identity) - if (inherits(nv, "error")) - next - - version <- nv[1, 1] - break - - } - - paste(c("linux", id, version), collapse = "-") - - } - - renv_bootstrap_library_root_name <- function(project) { - - # use project name as-is if requested - asis <- Sys.getenv("RENV_PATHS_LIBRARY_ROOT_ASIS", unset = "FALSE") - if (asis) - return(basename(project)) - - # otherwise, disambiguate based on project's path - id <- substring(renv_bootstrap_hash_text(project), 1L, 8L) - 
paste(basename(project), id, sep = "-") - - } - - renv_bootstrap_library_root <- function(project) { - - prefix <- renv_bootstrap_profile_prefix() - - path <- Sys.getenv("RENV_PATHS_LIBRARY", unset = NA) - if (!is.na(path)) - return(paste(c(path, prefix), collapse = "/")) - - path <- renv_bootstrap_library_root_impl(project) - if (!is.null(path)) { - name <- renv_bootstrap_library_root_name(project) - return(paste(c(path, prefix, name), collapse = "/")) - } - - renv_bootstrap_paths_renv("library", project = project) - - } - - renv_bootstrap_library_root_impl <- function(project) { - - root <- Sys.getenv("RENV_PATHS_LIBRARY_ROOT", unset = NA) - if (!is.na(root)) - return(root) - - type <- renv_bootstrap_project_type(project) - if (identical(type, "package")) { - userdir <- renv_bootstrap_user_dir() - return(file.path(userdir, "library")) - } - - } - - renv_bootstrap_validate_version <- function(version, description = NULL) { - - # resolve description file - # - # avoid passing lib.loc to `packageDescription()` below, since R will - # use the loaded version of the package by default anyhow. 
note that - # this function should only be called after 'renv' is loaded - # https://github.com/rstudio/renv/issues/1625 - description <- description %||% packageDescription("renv") - - # check whether requested version 'version' matches loaded version of renv - sha <- attr(version, "sha", exact = TRUE) - valid <- if (!is.null(sha)) - renv_bootstrap_validate_version_dev(sha, description) - else - renv_bootstrap_validate_version_release(version, description) - - if (valid) - return(TRUE) - - # the loaded version of renv doesn't match the requested version; - # give the user instructions on how to proceed - dev <- identical(description[["RemoteType"]], "github") - remote <- if (dev) - paste("rstudio/renv", description[["RemoteSha"]], sep = "@") - else - paste("renv", description[["Version"]], sep = "@") - - # display both loaded version + sha if available - friendly <- renv_bootstrap_version_friendly( - version = description[["Version"]], - sha = if (dev) description[["RemoteSha"]] - ) - - fmt <- heredoc(" - renv %1$s was loaded from project library, but this project is configured to use renv %2$s. - - Use `renv::record(\"%3$s\")` to record renv %1$s in the lockfile. - - Use `renv::restore(packages = \"renv\")` to install renv %2$s into the project library. 
- ") - catf(fmt, friendly, renv_bootstrap_version_friendly(version), remote) - - FALSE - - } - - renv_bootstrap_validate_version_dev <- function(version, description) { - expected <- description[["RemoteSha"]] - is.character(expected) && startswith(expected, version) - } - - renv_bootstrap_validate_version_release <- function(version, description) { - expected <- description[["Version"]] - is.character(expected) && identical(expected, version) - } - - renv_bootstrap_hash_text <- function(text) { - - hashfile <- tempfile("renv-hash-") - on.exit(unlink(hashfile), add = TRUE) - - writeLines(text, con = hashfile) - tools::md5sum(hashfile) - - } - - renv_bootstrap_load <- function(project, libpath, version) { - - # try to load renv from the project library - if (!requireNamespace("renv", lib.loc = libpath, quietly = TRUE)) - return(FALSE) - - # warn if the version of renv loaded does not match - renv_bootstrap_validate_version(version) - - # execute renv load hooks, if any - hooks <- getHook("renv::autoload") - for (hook in hooks) - if (is.function(hook)) - tryCatch(hook(), error = warnify) - - # load the project - renv::load(project) - - TRUE - - } - - renv_bootstrap_profile_load <- function(project) { - - # if RENV_PROFILE is already set, just use that - profile <- Sys.getenv("RENV_PROFILE", unset = NA) - if (!is.na(profile) && nzchar(profile)) - return(profile) - - # check for a profile file (nothing to do if it doesn't exist) - path <- renv_bootstrap_paths_renv("profile", profile = FALSE, project = project) - if (!file.exists(path)) - return(NULL) - - # read the profile, and set it if it exists - contents <- readLines(path, warn = FALSE) - if (length(contents) == 0L) - return(NULL) - - # set RENV_PROFILE - profile <- contents[[1L]] - if (!profile %in% c("", "default")) - Sys.setenv(RENV_PROFILE = profile) - - profile - - } - - renv_bootstrap_profile_prefix <- function() { - profile <- renv_bootstrap_profile_get() - if (!is.null(profile)) - 
return(file.path("profiles", profile, "renv")) - } - - renv_bootstrap_profile_get <- function() { - profile <- Sys.getenv("RENV_PROFILE", unset = "") - renv_bootstrap_profile_normalize(profile) - } - - renv_bootstrap_profile_set <- function(profile) { - profile <- renv_bootstrap_profile_normalize(profile) - if (is.null(profile)) - Sys.unsetenv("RENV_PROFILE") - else - Sys.setenv(RENV_PROFILE = profile) - } - - renv_bootstrap_profile_normalize <- function(profile) { - - if (is.null(profile) || profile %in% c("", "default")) - return(NULL) - - profile - - } - - renv_bootstrap_path_absolute <- function(path) { - - substr(path, 1L, 1L) %in% c("~", "/", "\\") || ( - substr(path, 1L, 1L) %in% c(letters, LETTERS) && - substr(path, 2L, 3L) %in% c(":/", ":\\") - ) - - } - - renv_bootstrap_paths_renv <- function(..., profile = TRUE, project = NULL) { - renv <- Sys.getenv("RENV_PATHS_RENV", unset = "renv") - root <- if (renv_bootstrap_path_absolute(renv)) NULL else project - prefix <- if (profile) renv_bootstrap_profile_prefix() - components <- c(root, renv, prefix, ...) 
- paste(components, collapse = "/") - } - - renv_bootstrap_project_type <- function(path) { - - descpath <- file.path(path, "DESCRIPTION") - if (!file.exists(descpath)) - return("unknown") - - desc <- tryCatch( - read.dcf(descpath, all = TRUE), - error = identity - ) - - if (inherits(desc, "error")) - return("unknown") - - type <- desc$Type - if (!is.null(type)) - return(tolower(type)) - - package <- desc$Package - if (!is.null(package)) - return("package") - - "unknown" - - } - - renv_bootstrap_user_dir <- function() { - dir <- renv_bootstrap_user_dir_impl() - path.expand(chartr("\\", "/", dir)) - } - - renv_bootstrap_user_dir_impl <- function() { - - # use local override if set - override <- getOption("renv.userdir.override") - if (!is.null(override)) - return(override) - - # use R_user_dir if available - tools <- asNamespace("tools") - if (is.function(tools$R_user_dir)) - return(tools$R_user_dir("renv", "cache")) - - # try using our own backfill for older versions of R - envvars <- c("R_USER_CACHE_DIR", "XDG_CACHE_HOME") - for (envvar in envvars) { - root <- Sys.getenv(envvar, unset = NA) - if (!is.na(root)) - return(file.path(root, "R/renv")) - } - - # use platform-specific default fallbacks - if (Sys.info()[["sysname"]] == "Windows") - file.path(Sys.getenv("LOCALAPPDATA"), "R/cache/R/renv") - else if (Sys.info()[["sysname"]] == "Darwin") - "~/Library/Caches/org.R-project.R/R/renv" - else - "~/.cache/R/renv" - - } - - renv_bootstrap_version_friendly <- function(version, shafmt = NULL, sha = NULL) { - sha <- sha %||% attr(version, "sha", exact = TRUE) - parts <- c(version, sprintf(shafmt %||% " [sha: %s]", substring(sha, 1L, 7L))) - paste(parts, collapse = "") - } - - renv_bootstrap_exec <- function(project, libpath, version) { - if (!renv_bootstrap_load(project, libpath, version)) - renv_bootstrap_run(version, libpath) - } - - renv_bootstrap_run <- function(version, libpath) { - - # perform bootstrap - bootstrap(version, libpath) - - # exit early if we're just 
testing bootstrap - if (!is.na(Sys.getenv("RENV_BOOTSTRAP_INSTALL_ONLY", unset = NA))) - return(TRUE) - - # try again to load - if (requireNamespace("renv", lib.loc = libpath, quietly = TRUE)) { - return(renv::load(project = getwd())) - } - - # failed to download or load renv; warn the user - msg <- c( - "Failed to find an renv installation: the project will not be loaded.", - "Use `renv::activate()` to re-initialize the project." - ) - - warning(paste(msg, collapse = "\n"), call. = FALSE) - - } - - renv_json_read <- function(file = NULL, text = NULL) { - - jlerr <- NULL - - # if jsonlite is loaded, use that instead - if ("jsonlite" %in% loadedNamespaces()) { - - json <- tryCatch(renv_json_read_jsonlite(file, text), error = identity) - if (!inherits(json, "error")) - return(json) - - jlerr <- json - - } - - # otherwise, fall back to the default JSON reader - json <- tryCatch(renv_json_read_default(file, text), error = identity) - if (!inherits(json, "error")) - return(json) - - # report an error - if (!is.null(jlerr)) - stop(jlerr) - else - stop(json) - - } - - renv_json_read_jsonlite <- function(file = NULL, text = NULL) { - text <- paste(text %||% readLines(file, warn = FALSE), collapse = "\n") - jsonlite::fromJSON(txt = text, simplifyVector = FALSE) - } - - renv_json_read_default <- function(file = NULL, text = NULL) { - - # find strings in the JSON - text <- paste(text %||% readLines(file, warn = FALSE), collapse = "\n") - pattern <- '["](?:(?:\\\\.)|(?:[^"\\\\]))*?["]' - locs <- gregexpr(pattern, text, perl = TRUE)[[1]] - - # if any are found, replace them with placeholders - replaced <- text - strings <- character() - replacements <- character() - - if (!identical(c(locs), -1L)) { - - # get the string values - starts <- locs - ends <- locs + attr(locs, "match.length") - 1L - strings <- substring(text, starts, ends) - - # only keep those requiring escaping - strings <- grep("[[\\]{}:]", strings, perl = TRUE, value = TRUE) - - # compute replacements - 
replacements <- sprintf('"\032%i\032"', seq_along(strings)) - - # replace the strings - mapply(function(string, replacement) { - replaced <<- sub(string, replacement, replaced, fixed = TRUE) - }, strings, replacements) - - } - - # transform the JSON into something the R parser understands - transformed <- replaced - transformed <- gsub("{}", "`names<-`(list(), character())", transformed, fixed = TRUE) - transformed <- gsub("[[{]", "list(", transformed, perl = TRUE) - transformed <- gsub("[]}]", ")", transformed, perl = TRUE) - transformed <- gsub(":", "=", transformed, fixed = TRUE) - text <- paste(transformed, collapse = "\n") - - # parse it - json <- parse(text = text, keep.source = FALSE, srcfile = NULL)[[1L]] - - # construct map between source strings, replaced strings - map <- as.character(parse(text = strings)) - names(map) <- as.character(parse(text = replacements)) - - # convert to list - map <- as.list(map) - - # remap strings in object - remapped <- renv_json_read_remap(json, map) - - # evaluate - eval(remapped, envir = baseenv()) - - } - - renv_json_read_remap <- function(json, map) { - - # fix names - if (!is.null(names(json))) { - lhs <- match(names(json), names(map), nomatch = 0L) - rhs <- match(names(map), names(json), nomatch = 0L) - names(json)[rhs] <- map[lhs] - } - - # fix values - if (is.character(json)) - return(map[[json]] %||% json) - - # handle true, false, null - if (is.name(json)) { - text <- as.character(json) - if (text == "true") - return(TRUE) - else if (text == "false") - return(FALSE) - else if (text == "null") - return(NULL) - } - - # recurse - if (is.recursive(json)) { - for (i in seq_along(json)) { - json[i] <- list(renv_json_read_remap(json[[i]], map)) - } - } - - json - - } - - # load the renv profile, if any - renv_bootstrap_profile_load(project) - - # construct path to library root - root <- renv_bootstrap_library_root(project) - - # construct library prefix for platform - prefix <- renv_bootstrap_platform_prefix() - - # 
construct full libpath - libpath <- file.path(root, prefix) - - # run bootstrap code - renv_bootstrap_exec(project, libpath, version) - - invisible() - -}) diff --git a/renv/settings.json b/renv/settings.json deleted file mode 100644 index dec23cf4..00000000 --- a/renv/settings.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "bioconductor.version": null, - "external.libraries": [], - "ignored.packages": [], - "package.dependency.fields": [ - "Imports", - "Depends", - "LinkingTo" - ], - "ppm.enabled": true, - "ppm.ignored.urls": [], - "r.version": null, - "snapshot.type": "implicit", - "use.cache": true, - "vcs.ignore.cellar": true, - "vcs.ignore.library": true, - "vcs.ignore.local": true, - "vcs.manage.ignores": true -} From 193806fdb3ea5ddefc4a4bba2b5a228e15327f34 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 22 Nov 2024 16:18:20 -0500 Subject: [PATCH 45/50] reorganize files within demos --- .gitignore | 4 ++-- demos/{hosp_only_ww_model => }/data/fit/stan_data.json | 0 .../data/fit_hosp_only/stan_data.json | 0 demos/{hosp_only_ww_model => }/hosp_only_ww_model.qmd | 0 demos/{hosp_only_ww_model => }/model_comp.qmd | 0 .../site_level_dynamics_model_demo.qmd => ww_model_demo.qmd} | 2 +- demos/{hosp_only_ww_model => }/wwinference.Rmd | 0 ...evel_dynamics_model.py => ww_site_level_dynamics_model.py} | 0 8 files changed, 3 insertions(+), 3 deletions(-) rename demos/{hosp_only_ww_model => }/data/fit/stan_data.json (100%) rename demos/{hosp_only_ww_model => }/data/fit_hosp_only/stan_data.json (100%) rename demos/{hosp_only_ww_model => }/hosp_only_ww_model.qmd (100%) rename demos/{hosp_only_ww_model => }/model_comp.qmd (100%) rename demos/{hosp_only_ww_model/site_level_dynamics_model_demo.qmd => ww_model_demo.qmd} (99%) rename demos/{hosp_only_ww_model => }/wwinference.Rmd (100%) rename pyrenew_hew/{site_level_dynamics_model.py => ww_site_level_dynamics_model.py} (100%) diff --git a/.gitignore b/.gitignore index 9f92efd7..dff68a0b 100644 --- a/.gitignore +++ b/.gitignore @@ 
-391,8 +391,8 @@ docs/site/ poetry.lock -notebooks/*_files/ -notebooks/*.md +demos/*_files/ +demos/*.md private_data/* *_files/ diff --git a/demos/hosp_only_ww_model/data/fit/stan_data.json b/demos/data/fit/stan_data.json similarity index 100% rename from demos/hosp_only_ww_model/data/fit/stan_data.json rename to demos/data/fit/stan_data.json diff --git a/demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json b/demos/data/fit_hosp_only/stan_data.json similarity index 100% rename from demos/hosp_only_ww_model/data/fit_hosp_only/stan_data.json rename to demos/data/fit_hosp_only/stan_data.json diff --git a/demos/hosp_only_ww_model/hosp_only_ww_model.qmd b/demos/hosp_only_ww_model.qmd similarity index 100% rename from demos/hosp_only_ww_model/hosp_only_ww_model.qmd rename to demos/hosp_only_ww_model.qmd diff --git a/demos/hosp_only_ww_model/model_comp.qmd b/demos/model_comp.qmd similarity index 100% rename from demos/hosp_only_ww_model/model_comp.qmd rename to demos/model_comp.qmd diff --git a/demos/hosp_only_ww_model/site_level_dynamics_model_demo.qmd b/demos/ww_model_demo.qmd similarity index 99% rename from demos/hosp_only_ww_model/site_level_dynamics_model_demo.qmd rename to demos/ww_model_demo.qmd index 9f2fe6f4..15db096f 100644 --- a/demos/hosp_only_ww_model/site_level_dynamics_model_demo.qmd +++ b/demos/ww_model_demo.qmd @@ -15,7 +15,7 @@ import numpyro.distributions.transforms as transforms from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew_hew.site_level_dynamics_model import ww_site_level_dynamics_model +from pyrenew_hew.ww_site_level_dynamics_model import ww_site_level_dynamics_model from pyrenew_hew.utils import convert_to_logmean_log_sd import pyrenew_hew.plotting as plotting diff --git a/demos/hosp_only_ww_model/wwinference.Rmd b/demos/wwinference.Rmd similarity index 100% rename from demos/hosp_only_ww_model/wwinference.Rmd rename to 
demos/wwinference.Rmd diff --git a/pyrenew_hew/site_level_dynamics_model.py b/pyrenew_hew/ww_site_level_dynamics_model.py similarity index 100% rename from pyrenew_hew/site_level_dynamics_model.py rename to pyrenew_hew/ww_site_level_dynamics_model.py From 22cb04990f8f3bafeb9101ff5367126f2587625b Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 9 Dec 2024 16:57:57 -0500 Subject: [PATCH 46/50] update stan_data for ww_model --- demos/hosp_only_ww_model/wwinference.Rmd | 7 ++----- demos/ww_model/data/fit/stan_data.json | 26 ++++++++++++------------ 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/demos/hosp_only_ww_model/wwinference.Rmd b/demos/hosp_only_ww_model/wwinference.Rmd index 0355f349..fced9fae 100644 --- a/demos/hosp_only_ww_model/wwinference.Rmd +++ b/demos/hosp_only_ww_model/wwinference.Rmd @@ -614,12 +614,9 @@ plot(draws_hosp_only, what = "global_rt", forecast_date = forecast_date) ```{r copy results} library(fs) -base_dir <- path("data") -fit_dir <- path(base_dir, "fit") -fit_hosp_only_dir <- path(base_dir, "fit_hosp_only") +fit_dir <- path("demos", "ww_model", "data", "fit") +fit_hosp_only_dir <- path("demos", "hosp_only_ww_model", "data", "fit_hosp_only") dir_create(fit_dir) -dir_create(fit_hosp_only_dir) - file_copy( path = ww_fit$fit$result$data_file(), diff --git a/demos/ww_model/data/fit/stan_data.json b/demos/ww_model/data/fit/stan_data.json index 9f6b833d..27c6ee66 100644 --- a/demos/ww_model/data/fit/stan_data.json +++ b/demos/ww_model/data/fit/stan_data.json @@ -7,10 +7,10 @@ "n_subpops": 5, "n_ww_sites": 4, "n_ww_lab_sites": 5, - "owt": 98, + "owt": 94, "oht": 90, - "n_censored": 1, - "n_uncensored": 97, + "n_censored": 0, + "n_uncensored": 94, "uot": 50, "ht": 35, "n_weeks": 18, @@ -324,17 +324,17 @@ "state_pop": 3000000.0, "subpop_size": [2250000.0, 400000.0, 200000.0, 100000.0, 50000.0], "norm_pop": 3000000.0, - "ww_sampled_times": [2, 2, 2, 5, 6, 6, 8, 9, 11, 12, 12, 12, 13, 14, 14, 14, 15, 15, 16, 18, 18, 18, 19, 20, 21, 
22, 22, 23, 23, 25, 26, 26, 27, 29, 29, 29, 31, 32, 32, 33, 33, 34, 36, 36, 37, 37, 39, 40, 42, 42, 42, 43, 45, 46, 46, 47, 48, 51, 52, 53, 56, 57, 58, 58, 59, 59, 63, 63, 64, 65, 65, 67, 70, 70, 73, 74, 75, 76, 76, 77, 78, 80, 81, 83, 83, 84, 86, 87, 89, 89, 90, 90, 92, 92, 92, 94, 94, 94], + "ww_sampled_times": [1, 1, 2, 4, 5, 8, 9, 11, 11, 12, 14, 14, 14, 15, 17, 18, 19, 21, 22, 23, 24, 26, 26, 26, 27, 27, 27, 30, 30, 31, 31, 31, 32, 32, 35, 35, 36, 36, 36, 38, 38, 39, 40, 41, 42, 44, 44, 46, 46, 49, 49, 51, 51, 53, 54, 55, 56, 57, 57, 58, 58, 59, 59, 59, 61, 62, 62, 63, 65, 66, 67, 68, 69, 69, 70, 71, 72, 72, 74, 75, 78, 79, 80, 82, 84, 86, 86, 87, 87, 89, 90, 91, 93, 93], "hosp_times": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], - "ww_sampled_subpops": [2, 4, 5, 4, 3, 4, 2, 2, 3, 2, 2, 5, 4, 2, 2, 4, 2, 4, 5, 2, 3, 4, 2, 2, 2, 2, 5, 2, 4, 3, 2, 5, 3, 2, 2, 3, 4, 2, 3, 3, 4, 4, 2, 3, 2, 3, 2, 5, 2, 2, 3, 3, 4, 2, 5, 2, 2, 4, 5, 4, 5, 5, 2, 4, 2, 5, 2, 5, 3, 2, 2, 2, 2, 2, 4, 2, 2, 2, 4, 4, 4, 3, 4, 2, 3, 3, 2, 3, 2, 4, 4, 5, 2, 4, 5, 2, 4, 5], + "ww_sampled_subpops": [2, 3, 5, 3, 3, 2, 3, 2, 5, 3, 2, 2, 5, 4, 2, 2, 3, 2, 5, 3, 3, 2, 4, 5, 2, 3, 5, 2, 3, 2, 2, 5, 2, 2, 3, 5, 2, 2, 4, 2, 2, 3, 4, 5, 5, 3, 4, 4, 5, 3, 5, 2, 5, 3, 2, 5, 5, 2, 5, 4, 5, 3, 4, 5, 2, 3, 4, 3, 3, 3, 3, 5, 2, 4, 3, 2, 2, 4, 5, 2, 5, 2, 5, 3, 2, 2, 2, 4, 5, 5, 5, 5, 2, 4], "lab_site_to_subpop_map": [2, 2, 3, 4, 5], - "ww_sampled_lab_sites": [1, 4, 5, 4, 3, 4, 2, 2, 3, 1, 2, 5, 4, 1, 2, 4, 2, 4, 5, 2, 3, 4, 1, 2, 1, 2, 5, 2, 4, 3, 2, 5, 3, 1, 2, 3, 4, 1, 3, 3, 4, 4, 2, 3, 2, 3, 1, 5, 1, 2, 3, 3, 4, 2, 5, 2, 1, 4, 5, 4, 5, 5, 1, 4, 1, 5, 1, 5, 3, 1, 2, 1, 1, 2, 4, 1, 2, 2, 4, 4, 4, 3, 4, 
2, 3, 3, 1, 3, 1, 4, 4, 5, 2, 4, 5, 2, 4, 5], - "ww_log_lod": [4.81786386350134, 4.93524984908241, 4.78269939060126, 4.93524984908241, 5.01370230665543, 4.93524984908241, 5.14825526105204, 5.14825526105204, 5.01370230665543, 4.81786386350134, 5.14825526105204, 4.78269939060126, 4.93524984908241, 4.81786386350134, 5.14825526105204, 4.93524984908241, 5.14825526105204, 4.93524984908241, 4.78269939060126, 5.14825526105204, 5.01370230665543, 4.93524984908241, 4.81786386350134, 5.14825526105204, 4.81786386350134, 5.14825526105204, 4.78269939060126, 5.14825526105204, 4.93524984908241, 5.01370230665543, 5.14825526105204, 4.78269939060126, 5.01370230665543, 4.81786386350134, 5.14825526105204, 5.01370230665543, 4.93524984908241, 4.81786386350134, 5.01370230665543, 5.01370230665543, 4.93524984908241, 4.93524984908241, 5.14825526105204, 5.01370230665543, 5.14825526105204, 5.01370230665543, 4.81786386350134, 4.78269939060126, 4.81786386350134, 5.14825526105204, 5.01370230665543, 5.01370230665543, 4.93524984908241, 5.14825526105204, 4.78269939060126, 5.14825526105204, 4.81786386350134, 4.93524984908241, 4.78269939060126, 4.93524984908241, 4.78269939060126, 4.78269939060126, 4.81786386350134, 4.93524984908241, 4.81786386350134, 4.78269939060126, 4.81786386350134, 4.78269939060126, 5.01370230665543, 4.81786386350134, 5.14825526105204, 4.81786386350134, 4.81786386350134, 5.14825526105204, 4.93524984908241, 4.81786386350134, 5.14825526105204, 5.14825526105204, 4.93524984908241, 4.93524984908241, 4.93524984908241, 5.01370230665543, 4.93524984908241, 5.14825526105204, 5.01370230665543, 5.01370230665543, 4.81786386350134, 5.01370230665543, 4.81786386350134, 4.93524984908241, 4.93524984908241, 4.78269939060126, 5.14825526105204, 4.93524984908241, 4.78269939060126, 5.14825526105204, 4.93524984908241, 4.78269939060126], - "ww_censored": [66], - "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98], - "hosp": [14, 15, 11, 22, 13, 17, 26, 19, 21, 21, 18, 32, 22, 29, 15, 38, 30, 26, 26, 36, 37, 41, 38, 37, 50, 57, 56, 46, 55, 49, 53, 58, 58, 59, 63, 68, 66, 90, 71, 74, 90, 79, 66, 76, 91, 79, 94, 78, 93, 82, 100, 92, 61, 93, 60, 71, 62, 72, 53, 55, 49, 59, 68, 46, 50, 45, 36, 32, 56, 37, 41, 41, 35, 45, 37, 33, 27, 33, 34, 34, 38, 27, 43, 30, 33, 34, 35, 33, 35, 40], + "ww_sampled_lab_sites": [1, 3, 5, 3, 3, 2, 3, 1, 5, 3, 1, 2, 5, 4, 2, 1, 3, 1, 5, 3, 3, 1, 4, 5, 1, 3, 5, 1, 3, 1, 2, 5, 1, 2, 3, 5, 1, 2, 4, 1, 2, 3, 4, 5, 5, 3, 4, 4, 5, 3, 5, 2, 5, 3, 1, 5, 5, 1, 5, 4, 5, 3, 4, 5, 1, 3, 4, 3, 3, 3, 3, 5, 2, 4, 3, 2, 2, 4, 5, 1, 5, 2, 5, 3, 1, 1, 2, 4, 5, 5, 5, 5, 1, 4], + "ww_log_lod": [5.00549506716936, 5.21075017260572, 5.06712344199376, 5.21075017260572, 5.21075017260572, 4.66396345552081, 5.21075017260572, 5.00549506716936, 5.06712344199376, 5.21075017260572, 5.00549506716936, 4.66396345552081, 5.06712344199376, 4.77608017908556, 4.66396345552081, 5.00549506716936, 5.21075017260572, 5.00549506716936, 5.06712344199376, 5.21075017260572, 5.21075017260572, 5.00549506716936, 4.77608017908556, 5.06712344199376, 5.00549506716936, 5.21075017260572, 5.06712344199376, 5.00549506716936, 5.21075017260572, 5.00549506716936, 4.66396345552081, 5.06712344199376, 5.00549506716936, 4.66396345552081, 5.21075017260572, 5.06712344199376, 5.00549506716936, 4.66396345552081, 4.77608017908556, 5.00549506716936, 4.66396345552081, 5.21075017260572, 4.77608017908556, 5.06712344199376, 5.06712344199376, 5.21075017260572, 4.77608017908556, 4.77608017908556, 5.06712344199376, 5.21075017260572, 5.06712344199376, 4.66396345552081, 5.06712344199376, 5.21075017260572, 5.00549506716936, 5.06712344199376, 5.06712344199376, 5.00549506716936, 
5.06712344199376, 4.77608017908556, 5.06712344199376, 5.21075017260572, 4.77608017908556, 5.06712344199376, 5.00549506716936, 5.21075017260572, 4.77608017908556, 5.21075017260572, 5.21075017260572, 5.21075017260572, 5.21075017260572, 5.06712344199376, 4.66396345552081, 4.77608017908556, 5.21075017260572, 4.66396345552081, 4.66396345552081, 4.77608017908556, 5.06712344199376, 5.00549506716936, 5.06712344199376, 4.66396345552081, 5.06712344199376, 5.21075017260572, 5.00549506716936, 5.00549506716936, 4.66396345552081, 4.77608017908556, 5.06712344199376, 5.06712344199376, 5.06712344199376, 5.06712344199376, 5.00549506716936, 4.77608017908556], + "ww_censored": [], + "ww_uncensored": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94], + "hosp": [25, 17, 25, 24, 26, 25, 27, 37, 32, 19, 18, 33, 22, 31, 34, 47, 45, 28, 38, 43, 33, 47, 38, 48, 43, 39, 52, 40, 49, 63, 47, 64, 58, 48, 92, 54, 53, 81, 56, 77, 84, 74, 62, 67, 74, 75, 89, 100, 65, 83, 96, 74, 59, 57, 60, 74, 70, 69, 60, 50, 75, 60, 53, 54, 50, 56, 48, 55, 41, 37, 50, 50, 39, 30, 31, 23, 35, 34, 33, 16, 23, 16, 21, 28, 29, 26, 30, 30, 27, 23], "day_of_week": [5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3], - "log_conc": [7.66802724479397, 8.21650318956816, 6.4512477164068, 7.75785732441969, 9.04472643003972, 7.8876093738573, 8.85695271080419, 8.77448654283081, 9.25474422044025, 
8.48197724221017, 9.18487615016851, 7.73586744786872, 9.15190271402333, 8.57063681753772, 9.22494958414535, 8.8811814139552, 8.95303168179479, 8.87479731488632, 7.04631257099227, 9.17451866263841, 9.81185744840712, 8.34692842110394, 8.56700168572782, 9.35764330055156, 8.64805057574179, 9.45079742597308, 7.30867303574344, 9.41472892758601, 9.12723564522376, 9.63427957851573, 9.4877830585957, 9.02899614473423, 9.86293199017148, 8.81668715131234, 9.35443062233028, 9.86957090190018, 8.95533551342666, 9.10840763554723, 9.89336595991182, 10.2035349739338, 8.51721998551749, 8.50763772111075, 9.68146528671227, 10.1267878546749, 9.82586519456221, 10.4562946689248, 9.39359538798059, 9.06912617169831, 9.13017806227903, 9.87195987699652, 10.1617885270518, 9.92070098259861, 9.04498782443653, 9.99736049469332, 7.8122356539209, 9.90761719216578, 9.19954542861679, 9.20908242423643, 7.53800674845928, 9.23810876627582, 8.74766338564374, 5.68405424571368, 8.94327722754565, 8.13815511398509, 8.72901851354369, 2.39134969530063, 8.55874357420902, 7.99413714656807, 8.70447685985134, 8.7648016157833, 9.28255076496634, 8.71171479812706, 8.80114398496407, 9.28818292020013, 7.6967588326828, 8.58440395540122, 8.89885558938604, 9.12890628468619, 7.639038061726, 7.70375332709297, 8.03272855574736, 8.59715668661873, 7.85031287265023, 8.95008137695999, 9.11122924030913, 9.12994052023167, 8.37562782560992, 8.46419130893256, 8.46685643123332, 8.15758459191837, 7.81253430699302, 6.6648633811922, 9.23344255196987, 8.06975580003644, 6.09213112399348, 8.97077401961493, 7.26054982586712, 8.02898834978338], + "log_conc": [7.81764945377784, 7.6933600870002, 8.06724360667125, 7.91499647953297, 7.89413202757162, 8.65370848464167, 8.23703266230228, 8.27379492054683, 7.92724156512979, 7.94840425601809, 8.64481518407572, 8.73244075027552, 9.19316034661646, 8.60407057002823, 8.6722245014186, 8.87458785451969, 8.37923688766207, 8.91291868471269, 9.3578500019116, 8.33166286089283, 8.53454734151903, 
9.16766777808109, 8.96778038239455, 8.57719573084511, 9.20437837140793, 8.95410274800007, 8.74059860130401, 9.10087573858902, 8.90131876883448, 9.08543023375679, 9.34542713278413, 9.92717583346023, 9.04312218630707, 9.57234097328664, 9.22198691752427, 9.49685074212238, 9.43672149942452, 9.61238931697528, 9.45242352426935, 9.36553178108906, 9.74293017201537, 9.0701963960744, 8.94662370930357, 8.95638676515249, 9.74171912141298, 9.20842822682359, 9.75413779255807, 9.15460750061054, 10.4986679283212, 9.01927790953137, 10.4227816651643, 9.48588325481423, 9.48290465379698, 9.06478545633172, 9.08506701681588, 10.3033927317299, 9.85477070600969, 8.95081048553881, 8.73677822709366, 8.84052991146567, 10.1298699894474, 8.87813562213935, 9.54864918729216, 10.0110817497961, 8.76636294806532, 9.03971685283071, 8.65351607279411, 8.41286808299863, 8.3719146541868, 8.36653561742599, 8.3261383206485, 10.7363688007168, 8.76415179845973, 9.14624029412538, 8.52527063701149, 8.73740006526049, 8.48666982792886, 8.18740210951593, 8.87078952071086, 8.12123299690221, 7.28523592880074, 8.71867664086082, 8.38027830353204, 7.89336545678093, 8.38380845068392, 8.4538151926047, 8.48281378055811, 8.30548667238886, 7.46444236066205, 9.45727778541334, 6.86310169078759, 7.72880012484138, 8.44521537189331, 8.56755001360401], "compute_likelihood": 1, "include_ww": 1, "include_hosp": 1, @@ -353,8 +353,8 @@ "r_prior_sd": 1, "log10_g_prior_mean": 12, "log10_g_prior_sd": 2, - "i_first_obs_over_n_prior_a": 1.00280952380952, - "i_first_obs_over_n_prior_b": 5.99719047619048, + "i_first_obs_over_n_prior_a": 1.00402380952381, + "i_first_obs_over_n_prior_b": 5.99597619047619, "hosp_wday_effect_prior_alpha": [5, 5, 5, 5, 5, 5, 5], "mean_initial_exp_growth_rate_prior_mean": 0, "mean_initial_exp_growth_rate_prior_sd": 0.01, From 44436bb0013723a4bc13599762567eeb42c50741 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 9 Dec 2024 17:04:40 -0500 Subject: [PATCH 47/50] pre-commit changes --- 
demos/hosp_only_ww_model/wwinference.Rmd | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/demos/hosp_only_ww_model/wwinference.Rmd b/demos/hosp_only_ww_model/wwinference.Rmd index fced9fae..93b6c66e 100644 --- a/demos/hosp_only_ww_model/wwinference.Rmd +++ b/demos/hosp_only_ww_model/wwinference.Rmd @@ -615,8 +615,11 @@ plot(draws_hosp_only, what = "global_rt", forecast_date = forecast_date) library(fs) fit_dir <- path("demos", "ww_model", "data", "fit") -fit_hosp_only_dir <- path("demos", "hosp_only_ww_model", "data", "fit_hosp_only") +fit_hosp_only_dir <- path( + "demos", "hosp_only_ww_model", "data", "fit_hosp_only" +) dir_create(fit_dir) +dir_create(fit_hosp_only_dir) file_copy( path = ww_fit$fit$result$data_file(), From 19ddc079ed217187c9e74324c982006b2cbcfbe7 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 10 Dec 2024 16:50:29 -0500 Subject: [PATCH 48/50] add model comparison demo --- demos/ww_model/model_comp.qmd | 157 ++++++++++++++++++++ demos/ww_model/ww_model_demo.qmd | 13 +- pyrenew_hew/ww_site_level_dynamics_model.py | 82 ++++------ 3 files changed, 191 insertions(+), 61 deletions(-) create mode 100644 demos/ww_model/model_comp.qmd diff --git a/demos/ww_model/model_comp.qmd b/demos/ww_model/model_comp.qmd new file mode 100644 index 00000000..7f2ff749 --- /dev/null +++ b/demos/ww_model/model_comp.qmd @@ -0,0 +1,157 @@ +--- +title: "PyRenew and wwinference Fit and Forecast Comparison" +format: gfm +editor: visual +--- + +This document shows graphical comparisons for key variables in the PyRenew model fit to example data (notebooks/pyrenew_hew_model.qmd) and Stan model fit to example data (notebooks/wwinference.Rmd). In order to render this document, those notebooks must be rendered first. 
+ +```{r} +#| output: false +library(tidyverse) +library(tidybayes) +library(fs) +library(cmdstanr) +library(posterior) +library(jsonlite) +library(scales) +library(here) +library(forecasttools) +ci_width <- c(0.5, 0.8, 0.95) +fit_dir <- here(path("demos/ww_model/data/fit")) +``` + +## Load Data + +```{r} +hosp_data <- tibble(.value = path(fit_dir, "stan_data", ext = "json") |> + jsonlite::read_json() |> + pluck("hosp") |> + unlist()) |> + mutate(time = row_number()) + +stan_files <- + dir_ls(fit_dir, + glob = "*wwinference*" + ) |> + enframe(name = NULL, value = "file_path") |> + mutate(file_details = path_ext_remove(path_file(file_path))) |> + separate_wider_delim(file_details, + delim = "-", + names = c("model", "date", "chain", "hash") + ) |> + mutate(date = ymd_hm(date)) |> + filter(date == max(date)) |> + pull(file_path) + + +stan_tidy_draws <- read_cmdstan_csv(stan_files)$post_warmup_draws |> + tidy_draws() + +pyrenew_tidy_draws <- + path(fit_dir, "inference_data", ext = "csv") |> + read_csv() |> + forecasttools::inferencedata_to_tidy_draws() +``` + +## Calculate Credible Intervals for Plotting + +```{r} +combined_ci_for_plotting <- + bind_rows( + deframe(pyrenew_tidy_draws)$posterior_predictive |> + gather_draws(observed_hospital_admissions[time], state_rt[time], ihr[time], r_subpop_t[time,group]) |> + median_qi(.width = ci_width) |> + mutate(model = "pyrenew"), + stan_tidy_draws |> + gather_draws(pred_hosp[time], rt[time], p_hosp[time],r_subpop_t[group,time]) |> + mutate(.variable = case_when( + .variable == "pred_hosp" ~ "observed_hospital_admissions", + .variable == "p_hosp" ~ "ihr", + .variable == "rt" ~ "state_rt", + TRUE ~ .variable + )) |> + median_qi(.width = ci_width) |> + mutate(model = "stan") + ) +``` + +## Hospital Admission Comparison + +```{r} +combined_ci_for_plotting |> + filter(.variable == "observed_hospital_admissions") |> + ggplot(aes(time, .value)) + + facet_wrap(~model) + + geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = 
"#08519c") + + scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ percent(as.numeric(.)) + ) + + geom_point(data = hosp_data) + + cowplot::theme_cowplot() + + ggtitle("Vignette Data Model Comparison") + + scale_y_continuous("Hospital Admissions") + + scale_x_continuous("Time") + + theme(legend.position = "bottom") +``` + +## Rt Comparions + +```{r} +combined_ci_for_plotting |> + filter(.variable == "state_rt") |> + ggplot(aes(time, .value)) + + facet_wrap(~model) + + geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + + scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ percent(as.numeric(.)) + ) + + cowplot::theme_cowplot() + + ggtitle("Vignette Data Model Comparison") + + scale_y_log10("State Rt", breaks = scales::log_breaks(n = 6)) + + scale_x_continuous("Time") + + theme(legend.position = "bottom") + + geom_hline(yintercept = 1, linetype = "dashed") +``` + +## Subpopulation Rt Comparions +```{r} +combined_ci_for_plotting |> + filter(.variable == "r_subpop_t") |> + mutate(group = if_else(model == "pyrenew", group + 1, group)) |> #adjust for index python starting from 1 + ggplot(aes(time, .value)) + + facet_grid(rows = vars(group), cols = vars(model)) + + geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + + scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ percent(as.numeric(.)) + ) + + cowplot::theme_cowplot() + + ggtitle("Vignette Data Model Comparison") + + scale_y_log10("Subpopulation Rt") + + scale_x_continuous("Time") + + theme(legend.position = "bottom") +``` + +## IHR Comparison + +```{r} +combined_ci_for_plotting |> + filter(.variable == "ihr") |> + ggplot(aes(time, .value)) + + facet_wrap(~model) + + geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + + scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ percent(as.numeric(.)) + ) + + cowplot::theme_cowplot() + + ggtitle("Vignette Data Model Comparison") + + scale_y_log10("IHR 
(p_hosp)", breaks = scales::log_breaks(n = 6)) + + scale_x_continuous("Time") + + theme(legend.position = "bottom") +``` + +IHR lengths are different (Stan model generates an unnecessarily long version, see https://github.com/CDCgov/ww-inference-model/issues/43#issuecomment-2330269879) diff --git a/demos/ww_model/ww_model_demo.qmd b/demos/ww_model/ww_model_demo.qmd index 15db096f..adcdff85 100644 --- a/demos/ww_model/ww_model_demo.qmd +++ b/demos/ww_model/ww_model_demo.qmd @@ -336,8 +336,8 @@ Fit the model to observed hospital admissions and wastewater data ```{python} # | label: model fit my_model.run( - num_warmup=100, - num_samples=100, + num_warmup=750, + num_samples=500, rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, data_observed_log_conc=data_observed_log_conc, @@ -434,8 +434,8 @@ my_model_hosp_only_fit = ww_site_level_dynamics_model( ```{python} # | label: hosp only fit my_model_hosp_only_fit.run( - num_warmup=100, - num_samples=100, + num_warmup=750, + num_samples=500, rng_key=jax.random.key(223), data_observed_hospital_admissions=data_observed_hospital_admissions, mcmc_args=dict(num_chains=4) @@ -461,3 +461,8 @@ plotting.plot_predictive(idata_hosp_only,name='observed_hospital_admissions') ```{python} plotting.plot_predictive(idata_hosp_only,'r_subpop_t') ``` + + +```{python} +idata.to_dataframe().to_csv("data/fit/inference_data.csv", index=False) +``` \ No newline at end of file diff --git a/pyrenew_hew/ww_site_level_dynamics_model.py b/pyrenew_hew/ww_site_level_dynamics_model.py index 50fbe7d7..dacd551f 100644 --- a/pyrenew_hew/ww_site_level_dynamics_model.py +++ b/pyrenew_hew/ww_site_level_dynamics_model.py @@ -87,9 +87,7 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = ( - sigma_initial_exp_growth_rate_rv - ) + self.sigma_initial_exp_growth_rate_rv = 
sigma_initial_exp_growth_rate_rv self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( @@ -141,9 +139,7 @@ def sample( data_observed_log_conc=None, is_predictive=False, ): # numpydoc ignore=GL08 - if ( - n_datapoints is None - ): # calculate model calibration period based on data + if n_datapoints is None: # calculate model calibration period based on data if ( data_observed_hospital_admissions is None and data_observed_log_conc is None @@ -218,11 +214,8 @@ def sample( transforms.logit(i_first_obs_over_n) + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) ) - initial_exp_growth_rate_ref_subpop = ( - mean_initial_exp_growth_rate - + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 - ) + initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 ) offset_ref_log_r_t = self.offset_ref_log_r_t_rv() @@ -247,9 +240,7 @@ def sample( ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = ( - self.sigma_initial_exp_growth_rate_rv() - ) + sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( "initial_exp_growth_rate_non_ref_subpop_raw", dist.Normal( @@ -311,9 +302,7 @@ def sample( axis=1, ) - numpyro.deterministic( - "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop - ) + numpyro.deterministic("i_first_obs_over_n_subpop", i_first_obs_over_n_subpop) numpyro.deterministic( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -333,9 +322,7 @@ def sample( ) numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_subpop_rv = DeterministicVariable( - "i0_subpop", jnp.exp(log_i0_subpop) - ) + i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) initial_exp_growth_rate_subpop_rv = 
DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -374,9 +361,7 @@ def sample( r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum( - self.pop_fraction * new_i_subpop, axis=1 - ) + state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_subpop, axis=1) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # Hospital admission component @@ -413,13 +398,11 @@ def sample( hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = ( - compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] - ) + potential_latent_hospital_admissions = compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] latent_hospital_admissions = ( potential_latent_hospital_admissions @@ -427,9 +410,7 @@ def sample( * hosp_wday_effect * self.state_pop ) - numpyro.deterministic( - "latent_hospital_admissions", latent_hospital_admissions - ) + numpyro.deterministic("latent_hospital_admissions", latent_hospital_admissions) hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv @@ -481,9 +462,7 @@ def batch_colvolve_fn(m): "sigma_ww_site", DistributionalVariable( "log_sigma_ww_site", - dist.Normal( - jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site - ), + dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site), reparam=LocScaleReparam(0), ), transforms=transforms.ExpTransform(), @@ -499,17 +478,13 @@ def batch_colvolve_fn(m): ] # multiply the expected observed genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = ( - exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] - ) + 
exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] numpyro.sample( "log_conc_obs", dist.Normal( loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_uncensored] - ], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], ), obs=( data_observed_log_conc[self.ww_uncensored] @@ -520,24 +495,21 @@ def batch_colvolve_fn(m): if self.ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[ - self.ww_sampled_lab_sites[self.ww_censored] - ], + scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], ).log_cdf(self.ww_log_lod[self.ww_censored]) - numpyro.factor("log_prob_censored", log_cdf_values.sum()) + numpyro.factor("log_prob_censored", log_cdf_values.sum()) site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_subpop_map] - + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, scale=sigma_ww_site, ), ) - state_model_net_i = jnp.convolve( - state_inf_per_capita, s, mode="valid" - )[-n_datapoints:] + state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ + -n_datapoints: + ] numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -548,17 +520,13 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_log_c", state_log_c) expected_state_ww_conc = jnp.exp(state_log_c) - numpyro.deterministic( - "expected_state_ww_conc", expected_state_ww_conc - ) + numpyro.deterministic("expected_state_ww_conc", expected_state_ww_conc) state_rt = ( state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack( - (jnp.array([0]), jnp.array(generation_interval_pmf)) - ), + jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), mode="valid", )[-n_datapoints:] ) From 949f231a290459bef0b23e46c49b924672ea7550 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Wed, 11 Dec 2024 11:14:23 
-0600 Subject: [PATCH 49/50] make pre-commit happy --- demos/ww_model/model_comp.qmd | 6 +- demos/ww_model/ww_model_demo.qmd | 2 +- pyrenew_hew/ww_site_level_dynamics_model.py | 80 ++++++++++++++------- 3 files changed, 60 insertions(+), 28 deletions(-) diff --git a/demos/ww_model/model_comp.qmd b/demos/ww_model/model_comp.qmd index 7f2ff749..deb24f8a 100644 --- a/demos/ww_model/model_comp.qmd +++ b/demos/ww_model/model_comp.qmd @@ -60,11 +60,11 @@ pyrenew_tidy_draws <- combined_ci_for_plotting <- bind_rows( deframe(pyrenew_tidy_draws)$posterior_predictive |> - gather_draws(observed_hospital_admissions[time], state_rt[time], ihr[time], r_subpop_t[time,group]) |> + gather_draws(observed_hospital_admissions[time], state_rt[time], ihr[time], r_subpop_t[time, group]) |> median_qi(.width = ci_width) |> mutate(model = "pyrenew"), stan_tidy_draws |> - gather_draws(pred_hosp[time], rt[time], p_hosp[time],r_subpop_t[group,time]) |> + gather_draws(pred_hosp[time], rt[time], p_hosp[time], r_subpop_t[group, time]) |> mutate(.variable = case_when( .variable == "pred_hosp" ~ "observed_hospital_admissions", .variable == "p_hosp" ~ "ihr", @@ -120,7 +120,7 @@ combined_ci_for_plotting |> ```{r} combined_ci_for_plotting |> filter(.variable == "r_subpop_t") |> - mutate(group = if_else(model == "pyrenew", group + 1, group)) |> #adjust for index python starting from 1 + mutate(group = if_else(model == "pyrenew", group + 1, group)) |> # adjust for index python starting from 1 ggplot(aes(time, .value)) + facet_grid(rows = vars(group), cols = vars(model)) + geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") + diff --git a/demos/ww_model/ww_model_demo.qmd b/demos/ww_model/ww_model_demo.qmd index adcdff85..9d7d2932 100644 --- a/demos/ww_model/ww_model_demo.qmd +++ b/demos/ww_model/ww_model_demo.qmd @@ -465,4 +465,4 @@ plotting.plot_predictive(idata_hosp_only,'r_subpop_t') ```{python} idata.to_dataframe().to_csv("data/fit/inference_data.csv", index=False) -``` \ No newline 
at end of file +``` diff --git a/pyrenew_hew/ww_site_level_dynamics_model.py b/pyrenew_hew/ww_site_level_dynamics_model.py index dacd551f..fca6d313 100644 --- a/pyrenew_hew/ww_site_level_dynamics_model.py +++ b/pyrenew_hew/ww_site_level_dynamics_model.py @@ -87,7 +87,9 @@ def __init__( self.sigma_rt_rv = sigma_rt_rv self.i_first_obs_over_n_rv = i_first_obs_over_n_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( @@ -139,7 +141,9 @@ def sample( data_observed_log_conc=None, is_predictive=False, ): # numpydoc ignore=GL08 - if n_datapoints is None: # calculate model calibration period based on data + if ( + n_datapoints is None + ): # calculate model calibration period based on data if ( data_observed_hospital_admissions is None and data_observed_log_conc is None @@ -214,8 +218,11 @@ def sample( transforms.logit(i_first_obs_over_n) + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) ) - initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + initial_exp_growth_rate_ref_subpop = ( + mean_initial_exp_growth_rate + + jnp.where( + self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + ) ) offset_ref_log_r_t = self.offset_ref_log_r_t_rv() @@ -240,7 +247,9 @@ def sample( ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() + sigma_initial_exp_growth_rate = ( + self.sigma_initial_exp_growth_rate_rv() + ) initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( "initial_exp_growth_rate_non_ref_subpop_raw", dist.Normal( @@ -302,7 +311,9 @@ def sample( 
axis=1, ) - numpyro.deterministic("i_first_obs_over_n_subpop", i_first_obs_over_n_subpop) + numpyro.deterministic( + "i_first_obs_over_n_subpop", i_first_obs_over_n_subpop + ) numpyro.deterministic( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -322,7 +333,9 @@ def sample( ) numpyro.deterministic("rtu_subpop", rtu_subpop) - i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) + i0_subpop_rv = DeterministicVariable( + "i0_subpop", jnp.exp(log_i0_subpop) + ) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -361,7 +374,9 @@ def sample( r_subpop_t = inf_with_feedback_proc_sample.rt numpyro.deterministic("r_subpop_t", r_subpop_t) - state_inf_per_capita = jnp.sum(self.pop_fraction * new_i_subpop, axis=1) + state_inf_per_capita = jnp.sum( + self.pop_fraction * new_i_subpop, axis=1 + ) numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) # Hospital admission component @@ -398,11 +413,13 @@ def sample( hosp_wday_effect = tile_until_n(hosp_wday_effect_raw, n_datapoints) - potential_latent_hospital_admissions = compute_delay_ascertained_incidence( - p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, - delay_incidence_to_observation_pmf=inf_to_hosp, - )[-n_datapoints:] + potential_latent_hospital_admissions = ( + compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=state_inf_per_capita, + delay_incidence_to_observation_pmf=inf_to_hosp, + )[-n_datapoints:] + ) latent_hospital_admissions = ( potential_latent_hospital_admissions @@ -410,7 +427,9 @@ def sample( * hosp_wday_effect * self.state_pop ) - numpyro.deterministic("latent_hospital_admissions", latent_hospital_admissions) + numpyro.deterministic( + "latent_hospital_admissions", latent_hospital_admissions + ) hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv @@ -462,7 
+481,9 @@ def batch_colvolve_fn(m): "sigma_ww_site", DistributionalVariable( "log_sigma_ww_site", - dist.Normal(jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site), + dist.Normal( + jnp.log(mode_sigma_ww_site), sd_log_sigma_ww_site + ), reparam=LocScaleReparam(0), ), transforms=transforms.ExpTransform(), @@ -478,13 +499,17 @@ def batch_colvolve_fn(m): ] # multiply the expected observed genomes by the site-specific multiplier at that sampling time - exp_obs_log_v = exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + exp_obs_log_v = ( + exp_obs_log_v_true + ww_site_mod[self.ww_sampled_lab_sites] + ) numpyro.sample( "log_conc_obs", dist.Normal( loc=exp_obs_log_v[self.ww_uncensored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_uncensored]], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_uncensored] + ], ), obs=( data_observed_log_conc[self.ww_uncensored] @@ -495,21 +520,24 @@ def batch_colvolve_fn(m): if self.ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=exp_obs_log_v[self.ww_censored], - scale=sigma_ww_site[self.ww_sampled_lab_sites[self.ww_censored]], + scale=sigma_ww_site[ + self.ww_sampled_lab_sites[self.ww_censored] + ], ).log_cdf(self.ww_log_lod[self.ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) site_ww_pred_log = numpyro.sample( "site_ww_pred_log", dist.Normal( - loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + ww_site_mod, + loc=model_log_v_ot[:, self.lab_site_to_subpop_map] + + ww_site_mod, scale=sigma_ww_site, ), ) - state_model_net_i = jnp.convolve(state_inf_per_capita, s, mode="valid")[ - -n_datapoints: - ] + state_model_net_i = jnp.convolve( + state_inf_per_capita, s, mode="valid" + )[-n_datapoints:] numpyro.deterministic("state_model_net_i", state_model_net_i) state_log_c = ( @@ -520,13 +548,17 @@ def batch_colvolve_fn(m): numpyro.deterministic("state_log_c", state_log_c) expected_state_ww_conc = jnp.exp(state_log_c) - numpyro.deterministic("expected_state_ww_conc", 
expected_state_ww_conc) + numpyro.deterministic( + "expected_state_ww_conc", expected_state_ww_conc + ) state_rt = ( state_inf_per_capita[-n_datapoints:] / jnp.convolve( state_inf_per_capita, - jnp.hstack((jnp.array([0]), jnp.array(generation_interval_pmf))), + jnp.hstack( + (jnp.array([0]), jnp.array(generation_interval_pmf)) + ), mode="valid", )[-n_datapoints:] ) From 9e6e8c5664e2df0d35d4e6a638d85e675e0ebd7c Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Wed, 11 Dec 2024 11:18:21 -0600 Subject: [PATCH 50/50] more pre-commit --- demos/ww_model/model_comp.qmd | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/demos/ww_model/model_comp.qmd b/demos/ww_model/model_comp.qmd index deb24f8a..5c75cd9d 100644 --- a/demos/ww_model/model_comp.qmd +++ b/demos/ww_model/model_comp.qmd @@ -60,11 +60,17 @@ pyrenew_tidy_draws <- combined_ci_for_plotting <- bind_rows( deframe(pyrenew_tidy_draws)$posterior_predictive |> - gather_draws(observed_hospital_admissions[time], state_rt[time], ihr[time], r_subpop_t[time, group]) |> + gather_draws( + observed_hospital_admissions[time], state_rt[time], + ihr[time], r_subpop_t[time, group] + ) |> median_qi(.width = ci_width) |> mutate(model = "pyrenew"), stan_tidy_draws |> - gather_draws(pred_hosp[time], rt[time], p_hosp[time], r_subpop_t[group, time]) |> + gather_draws( + pred_hosp[time], rt[time], p_hosp[time], + r_subpop_t[group, time] + ) |> mutate(.variable = case_when( .variable == "pred_hosp" ~ "observed_hospital_admissions", .variable == "p_hosp" ~ "ihr", @@ -117,10 +123,12 @@ combined_ci_for_plotting |> ``` ## Subpopulation Rt Comparions + ```{r} combined_ci_for_plotting |> filter(.variable == "r_subpop_t") |> - mutate(group = if_else(model == "pyrenew", group + 1, group)) |> # adjust for index python starting from 1 + mutate(group = if_else(model == "pyrenew", group + 1, group)) |> + # adjust for index python starting from 1 ggplot(aes(time, .value)) + facet_grid(rows = vars(group), cols = 
vars(model)) + geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") +