From 3607397985830ba1feab76d4107f72cc672038d7 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 23 Jan 2025 18:39:19 -0500 Subject: [PATCH] use priors from stan model --- pipelines/build_pyrenew_model.py | 25 ++++++++++++++++++------- pyrenew_hew/pyrenew_hew_model.py | 6 +++--- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index b898f6c6..90401eb7 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -4,7 +4,9 @@ from pathlib import Path import jax.numpy as jnp +import numpyro.distributions as dist from pyrenew.deterministic import DeterministicVariable +from pyrenew.randomvariable import DistributionalVariable from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData from pyrenew_hew.pyrenew_hew_model import ( @@ -111,16 +113,25 @@ def build_model_from_dir( # placeholder my_wastewater_obs_model = WastewaterObservationProcess( - t_peak_rv=DeterministicVariable("t_peak", 1), - dur_shed_after_peak_rv=DeterministicVariable("dur_shed_after_peak", 1), + t_peak_rv=DistributionalVariable( + "t_peak", dist.TruncatedNormal(5, 1, low=0) + ), + dur_shed_after_peak_rv=DistributionalVariable( + "dur_shed_after_peak", dist.TruncatedNormal(12, 3, low=0) + ), log10_genome_per_inf_ind_rv=DeterministicVariable( - "log10_genome_per_inf_ind", 0 + "log10_genome_per_inf_ind", dist.Normal(12, 2) + ), + mode_sigma_ww_site_rv=DistributionalVariable( + "mode_sigma_ww_site", + dist.TruncatedNormal(1, 1, low=0), + ), + sd_log_sigma_ww_site_rv=DistributionalVariable( + "sd_log_sigma_ww_site", dist.TruncatedNormal(0, 0.693, low=0) ), - mode_sigma_ww_site_rv=DeterministicVariable("mode_sigma_ww_site", 0), - sd_log_sigma_ww_site_rv=DeterministicVariable( - "sd_log_sigma_ww_site", 0 + mode_sd_ww_site_rv=DistributionalVariable( + "mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0) ), - mode_sd_ww_site_rv=DeterministicVariable("mode_sd_ww_site", 0), ww_ml_produced_per_day=None, ww_uncensored=None, ww_censored=None, diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 36001e84..0da469fd 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -521,12 +521,12 @@ def sample( ): t_peak = self.t_peak_rv() dur_shed = self.dur_shed_after_peak_rv() - viral_kinetics_trajectory = get_viral_trajectory( + viral_kinetics = get_viral_trajectory( t_peak, dur_shed, self.max_shed_interval ) def batch_colvolve_fn(m): - return jnp.convolve(m, viral_kinetics_trajectory, mode="valid") + return jnp.convolve(m, viral_kinetics, mode="valid") model_net_inf_ind_shedding = jax.vmap( batch_colvolve_fn, in_axes=1, out_axes=1 @@ -611,7 +611,7 @@ def batch_colvolve_fn(m): ) state_net_inf_ind_shedding = jnp.convolve( - latent_infections, viral_kinetics_trajectory, mode="valid" + latent_infections, viral_kinetics, mode="valid" )[-n_datapoints:] numpyro.deterministic( "state_net_inf_ind_shedding", state_net_inf_ind_shedding