Skip to content

Commit

Permalink
use priors from stan model
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Jan 23, 2025
1 parent c589679 commit 3607397
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
25 changes: 18 additions & 7 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from pathlib import Path

import jax.numpy as jnp
import numpyro.distributions as dist

Check warning on line 7 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L7

Added line #L7 was not covered by tests
from pyrenew.deterministic import DeterministicVariable
from pyrenew.randomvariable import DistributionalVariable

Check warning on line 9 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L9

Added line #L9 was not covered by tests

from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData
from pyrenew_hew.pyrenew_hew_model import (
Expand Down Expand Up @@ -111,16 +113,25 @@ def build_model_from_dir(

# placeholder
my_wastewater_obs_model = WastewaterObservationProcess(

Check warning on line 115 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L115

Added line #L115 was not covered by tests
t_peak_rv=DeterministicVariable("t_peak", 1),
dur_shed_after_peak_rv=DeterministicVariable("dur_shed_after_peak", 1),
t_peak_rv=DistributionalVariable(
"t_peak", dist.TruncatedNormal(5, 1, low=0)
),
dur_shed_after_peak_rv=DistributionalVariable(
"dur_shed_after_peak", dist.TruncatedNormal(12, 3, low=0)
),
log10_genome_per_inf_ind_rv=DeterministicVariable(
"log10_genome_per_inf_ind", 0
"log10_genome_per_inf_ind", dist.Normal(12, 2)
),
mode_sigma_ww_site_rv=DistributionalVariable(
"mode_sigma_ww_site",
dist.TruncatedNormal(1, 1, low=0),
),
sd_log_sigma_ww_site_rv=DistributionalVariable(
"sd_log_sigma_ww_site", dist.TruncatedNormal(0, 0.693, low=0)
),
mode_sigma_ww_site_rv=DeterministicVariable("mode_sigma_ww_site", 0),
sd_log_sigma_ww_site_rv=DeterministicVariable(
"sd_log_sigma_ww_site", 0
mode_sd_ww_site_rv=DistributionalVariable(
"mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0)
),
mode_sd_ww_site_rv=DeterministicVariable("mode_sd_ww_site", 0),
ww_ml_produced_per_day=None,
ww_uncensored=None,
ww_censored=None,
Expand Down
6 changes: 3 additions & 3 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,12 @@ def sample(
):
t_peak = self.t_peak_rv()
dur_shed = self.dur_shed_after_peak_rv()
viral_kinetics_trajectory = get_viral_trajectory(
viral_kinetics = get_viral_trajectory(

Check warning on line 524 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L522-L524

Added lines #L522 - L524 were not covered by tests
t_peak, dur_shed, self.max_shed_interval
)

def batch_colvolve_fn(m):
return jnp.convolve(m, viral_kinetics_trajectory, mode="valid")
return jnp.convolve(m, viral_kinetics, mode="valid")

Check warning on line 529 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L528-L529

Added lines #L528 - L529 were not covered by tests

model_net_inf_ind_shedding = jax.vmap(

Check warning on line 531 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L531

Added line #L531 was not covered by tests
batch_colvolve_fn, in_axes=1, out_axes=1
Expand Down Expand Up @@ -611,7 +611,7 @@ def batch_colvolve_fn(m):
)

state_net_inf_ind_shedding = jnp.convolve(

Check warning on line 613 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L613

Added line #L613 was not covered by tests
latent_infections, viral_kinetics_trajectory, mode="valid"
latent_infections, viral_kinetics, mode="valid"
)[-n_datapoints:]
numpyro.deterministic(

Check warning on line 616 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L616

Added line #L616 was not covered by tests
"state_net_inf_ind_shedding", state_net_inf_ind_shedding
Expand Down

0 comments on commit 3607397

Please sign in to comment.