Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 39c003d
Author: Subekshya Bidari <[email protected]>
Date:   Wed Jan 29 16:20:53 2025 -0500

    wastewater observation process (#310)

    * add ww obs process

    * pre-commit

    * update build pyrenew model

    * fix output assignment in test latent infection process

    * include some rvs as placeholders

    * change prior

    * use priors from stan model

    * fix typo

    * move data args to sample method

    * pre-commit

    * update prod_priors.py

    * sync

    * update generate_predictive.py

    * pre-commit

    * move model constant to prior as per dhm suggestion

    * make get_viral_trajectory a class method

    * remove outdated call to utils

    * output assignment

    * state -> population

    * code review suggestions

    * code review suggestions

    * update distributional variable sample call

    * Apply suggestions from code review

    Co-authored-by: Dylan H. Morris <[email protected]>

    * pre-commit

    * Apply suggestions from code review

    Co-authored-by: Dylan H. Morris <[email protected]>

    ---------

    Co-authored-by: Dylan H. Morris <[email protected]>

commit df3648d
Author: Dylan H. Morris <[email protected]>
Date:   Tue Jan 28 17:20:34 2025 +0000

    Improvements and bug fixes for epiweekly other hubverse tables and plots (#312)
  • Loading branch information
damonbayer committed Jan 29, 2025
1 parent 9121ea5 commit f652b42
Show file tree
Hide file tree
Showing 11 changed files with 484 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
#####
# R
- repo: https://github.com/lorenzwalthert/precommit
rev: v0.4.3.9001
rev: v0.4.3
hooks:
- id: style-files
- id: lintr
Expand Down
9 changes: 8 additions & 1 deletion demos/ww_model/ww_model_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
from pyrenew_hew.ww_site_level_dynamics_model import ww_site_level_dynamics_model
from pyrenew_hew.utils import convert_to_logmean_log_sd
import pyrenew_hew.plotting as plotting
numpyro.set_host_device_count(4)
Expand All @@ -29,6 +28,14 @@ We will use the data used in the `wwinference` [vignette](https://github.com/CDC
with open("data/fit/stan_data.json","r") as file:
stan_data = json.load(file)
# define functions called later
def convert_to_logmean_log_sd(mean, sd):
logmean = jnp.log(
jnp.power(mean, 2) / jnp.sqrt(jnp.power(sd, 2) + jnp.power(mean, 2))
)
logsd = jnp.sqrt(jnp.log(1 + (jnp.power(sd, 2) / jnp.power(mean, 2))))
return logmean, logsd
```

```{python}
Expand Down
111 changes: 95 additions & 16 deletions hewr/tests/testthat/test_to_epiweekly_quantile_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,20 @@ test_that("to_epiweekly_quantiles works as expected", {
) |> suppressMessages()

expect_s3_class(result, "tbl_df")
expect_setequal(c(
"epiweek", "epiyear", "quantile_value", "quantile_level", "location"
), colnames(result))
checkmate::expect_names(
colnames(result),
identical.to = c(
"epiweek",
"epiyear",
"quantile_value",
"quantile_level",
"location",
"source_samples"
)
)

expect_equal(draws_file_name, unique(result$source_samples))

expect_gt(nrow(result), 0)
}

Expand Down Expand Up @@ -127,7 +138,11 @@ test_that("to_epiweekly_quantiles handles missing forecast files", {


# tests for `to_epiweekly_quantile_table`
test_that("to_epiweekly_quantile_table handles multiple locations", {
test_that(paste0(
"to_epiweekly_quantile_table ",
"handles multiple locations ",
"and multiple source files"
), {
batch_dir_name <- "covid-19_r_2024-12-14_f_2024-12-08_t_2024-12-14"
tempdir <- withr::local_tempdir()

Expand All @@ -142,6 +157,17 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
if (loc != "loc3") {
disease_cols <- c(disease_cols, "prop_disease_ed_visits")
}
create_tidy_forecast_data(
directory = loc_dir,
filename = "epiweekly_with_epiweekly_other_samples.parquet",
date_cols = seq(
lubridate::ymd("2024-12-08"), lubridate::ymd("2024-12-14"),
by = "week"
),
disease_cols = disease_cols,
n_draw = 25,
with_epiweek = TRUE
)

create_tidy_forecast_data(
directory = loc_dir,
Expand All @@ -157,7 +183,10 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
})

## should succeed despite loc3 not having valid draws with strict = FALSE
result_w_both_locations <- to_epiweekly_quantile_table(temp_batch_dir) |>
result_w_both_locations <-
to_epiweekly_quantile_table(temp_batch_dir,
epiweekly_other_locations = "loc1"
) |>
suppressMessages()

## should error if strict = TRUE because loc3 does not have
Expand All @@ -168,6 +197,44 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
"did not find valid draws"
)

## should succeed with strict = TRUE if loc3 is excluded
alt_result_w_both_locations <- (
to_epiweekly_quantile_table(temp_batch_dir,
strict = TRUE,
exclude = "loc3"
)) |>
suppressMessages()

## results should be equivalent for loc2,
## but not for loc1
expect_equal(
result_w_both_locations |>
dplyr::filter(location == "loc2"),
alt_result_w_both_locations |>
dplyr::filter(location == "loc2")
)

## check that one used epiweekly
## other for loc1 while other used
## default, resulting in different values
loc1_a <- result_w_both_locations |>
dplyr::filter(location == "loc1") |>
dplyr::pull(.data$value)
loc1_b <- alt_result_w_both_locations |>
dplyr::filter(location == "loc1") |>
dplyr::pull(.data$value)

## length checks ensure that the
## number of allowed equalities _could_
## be reached if the vectors were mostly
## or entirely identical
expect_gt(length(loc1_a), 10)
expect_gt(length(loc1_b), 10)
expect_lt(
sum(loc1_a == loc1_b),
5
)

expect_s3_class(result_w_both_locations, "tbl_df")
expect_gt(nrow(result_w_both_locations), 0)
checkmate::expect_names(
Expand All @@ -181,20 +248,32 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
"output_type",
"output_type_id",
"value",
"other_ed_visit_forecast"
"source_samples"
)
)
expect_setequal(
c("loc1", "loc2"),
result_w_both_locations$location
result_w_both_locations$location,
c("loc1", "loc2")
)
expect_setequal(
alt_result_w_both_locations$location,
c("loc1", "loc2")
)
expect_false("loc3" %in% result_w_both_locations$location)

result_w_one_location <- to_epiweekly_quantile_table(
model_batch_dir = temp_batch_dir,
exclude = "loc1"
) |>
suppressMessages()
expect_true("loc2" %in% result_w_one_location$location)
expect_false("loc1" %in% result_w_one_location$location)
expect_setequal(
result_w_both_locations$source_samples,
c(
"epiweekly_samples",
"epiweekly_with_epiweekly_other_samples"
)
)

expect_setequal(
alt_result_w_both_locations$source_samples,
"epiweekly_samples"
)


expect_false("loc3" %in% result_w_both_locations$location)
expect_false("loc3" %in% alt_result_w_both_locations$location)
})
12 changes: 10 additions & 2 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,16 @@ def build_model_from_dir(
ihr_rv=priors["ihr_rv"],
)

# placeholder
my_wastewater_obs_model = WastewaterObservationProcess()
my_wastewater_obs_model = WastewaterObservationProcess(
t_peak_rv=priors["t_peak_rv"],
duration_shed_after_peak_rv=priors["duration_shed_after_peak_rv"],
log10_genome_per_inf_ind_rv=priors["log10_genome_per_inf_ind_rv"],
mode_sigma_ww_site_rv=priors["mode_sigma_ww_site_rv"],
sd_log_sigma_ww_site_rv=priors["sd_log_sigma_ww_site_rv"],
mode_sd_ww_site_rv=priors["mode_sd_ww_site_rv"],
max_shed_interval=priors["max_shed_interval"],
ww_ml_produced_per_day=priors["ww_ml_produced_per_day"],
)

my_model = PyrenewHEWModel(
population_size=population_size,
Expand Down
2 changes: 1 addition & 1 deletion pipelines/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def generate_and_save_predictions(
data=forecast_data,
sample_ed_visits=True,
sample_hospital_admissions=True,
sample_wastewater=True,
sample_wastewater=False,
)

idata = az.from_numpyro(
Expand Down
27 changes: 27 additions & 0 deletions pipelines/priors/prod_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,30 @@
hosp_admit_neg_bin_concentration_rv = DistributionalVariable(
"hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2)
)

t_peak_rv = DistributionalVariable("t_peak", dist.TruncatedNormal(5, 1, low=0))

duration_shed_after_peak_rv = DistributionalVariable(
"durtion_shed_after_peak", dist.TruncatedNormal(12, 3, low=0)
)

log10_genome_per_inf_ind_rv = DistributionalVariable(
"log10_genome_per_inf_ind", dist.Normal(12, 2)
)

mode_sigma_ww_site_rv = DistributionalVariable(
"mode_sigma_ww_site",
dist.TruncatedNormal(1, 1, low=0),
)

sd_log_sigma_ww_site_rv = DistributionalVariable(
"sd_log_sigma_ww_site", dist.TruncatedNormal(0, 0.693, low=0)
)

mode_sd_ww_site_rv = DistributionalVariable(
"mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0)
)

# model constants related to wastewater obs process
ww_ml_produced_per_day = 227000
max_shed_interval = 26
Loading

0 comments on commit f652b42

Please sign in to comment.