-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Small quality of life improvements to `pipelines/postprocess_state_fo…
…recast.R` (#154)
- Loading branch information
1 parent
72f816b
commit d6dd2c5
Showing
7 changed files
with
466 additions
and
387 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(get_all_model_batch_dirs) | ||
export(make_forecast_figure) | ||
export(parse_model_batch_dir_path) | ||
export(parse_model_run_dir_path) | ||
export(process_state_forecast) | ||
export(to_epiweekly_quantile_table) | ||
export(to_epiweekly_quantiles) | ||
importFrom(rlang,.data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#' Make Forecast Figure | ||
#' | ||
#' @param target_disease a disease matching the disease columns in combined_dat | ||
#' and forecast_ci | ||
#' @param combined_dat `combined_dat` from the result of process_state_forecast | ||
#' @param forecast_ci `forecast_ci` from the result of process_state_forecast | ||
#' @param disease_name "COVID-19" or "Influenza" | ||
#' @param data_vintage_date date that the data was collected | ||
#' @param y_transform a character passed as the transform argument to | ||
#' ggplot2::scale_y_continuous() | ||
#' | ||
#' @return a ggplot object | ||
#' @export | ||
|
||
make_forecast_figure <- function(target_disease, | ||
combined_dat, | ||
forecast_ci, | ||
disease_name = c("COVID-19", "Influenza"), | ||
data_vintage_date, | ||
y_transform = "identity") { | ||
disease_name <- rlang::arg_match(disease_name) | ||
disease_name_pretty <- c( | ||
"COVID-19" = "COVID-19", | ||
"Influenza" = "Flu" | ||
)[disease_name] | ||
state_abb <- unique(combined_dat$geo_value)[1] | ||
|
||
y_scale <- if (stringr::str_starts(target_disease, "prop")) { | ||
ggplot2::scale_y_continuous("Proportion of Emergency Department Visits", | ||
labels = scales::label_percent(), | ||
transform = y_transform | ||
) | ||
} else { | ||
ggplot2::scale_y_continuous("Emergency Department Visits", | ||
labels = scales::label_comma(), | ||
transform = y_transform | ||
) | ||
} | ||
|
||
|
||
title <- if (target_disease == "Other") { | ||
glue::glue("Other ED Visits in {state_abb}") | ||
} else { | ||
glue::glue("{disease_name_pretty} ED Visits in {state_abb}") | ||
} | ||
|
||
last_training_date <- combined_dat |> | ||
dplyr::filter(data_type == "train") |> | ||
dplyr::pull(date) |> | ||
max() | ||
|
||
ggplot2::ggplot(mapping = ggplot2::aes(date, .value)) + | ||
ggdist::geom_lineribbon( | ||
data = forecast_ci |> dplyr::filter(disease == target_disease), | ||
mapping = ggplot2::aes(ymin = .lower, ymax = .upper), | ||
color = "#08519c", | ||
key_glyph = ggplot2::draw_key_rect, | ||
step = "mid" | ||
) + | ||
ggplot2::scale_fill_brewer( | ||
name = "Credible Interval Width", | ||
labels = ~ scales::label_percent()(as.numeric(.)) | ||
) + | ||
ggplot2::geom_point( | ||
mapping = ggplot2::aes(color = data_type), size = 1.5, | ||
data = combined_dat |> | ||
dplyr::filter( | ||
disease == target_disease, | ||
date <= max(forecast_ci$date) | ||
) |> | ||
dplyr::mutate(data_type = forcats::fct_rev(data_type)) |> | ||
dplyr::arrange(dplyr::desc(data_type)) | ||
) + | ||
ggplot2::scale_color_manual( | ||
name = "Data Type", | ||
values = c("olivedrab1", "deeppink"), | ||
labels = stringr::str_to_title | ||
) + | ||
ggplot2::geom_vline(xintercept = last_training_date, linetype = "dashed") + | ||
ggplot2::annotate( | ||
geom = "text", | ||
x = last_training_date, | ||
y = -Inf, | ||
label = "Fit Period \u2190\n", | ||
hjust = "right", | ||
vjust = "bottom" | ||
) + | ||
ggplot2::annotate( | ||
geom = "text", | ||
x = last_training_date, | ||
y = -Inf, label = "\u2192 Forecast Period\n", | ||
hjust = "left", | ||
vjust = "bottom", | ||
) + | ||
ggplot2::ggtitle(title, | ||
subtitle = glue::glue("as of {data_vintage_date}") | ||
) + | ||
y_scale + | ||
ggplot2::scale_x_date("Date") + | ||
cowplot::theme_minimal_grid() + | ||
ggplot2::theme(legend.position = "bottom") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#' Process state forecast | ||
#' | ||
#' @param model_run_dir Model run directory | ||
#' @param save Logical indicating whether or not to save | ||
#' | ||
#' @return a list with three tibbles: combined_dat, forecast_samples, | ||
#' and forecast_ci | ||
#' @export | ||
process_state_forecast <- function(model_run_dir, save = TRUE) { | ||
disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease | ||
|
||
train_data_path <- fs::path(model_run_dir, "data", ext = "csv") | ||
train_dat <- readr::read_csv(train_data_path, show_col_types = FALSE) | ||
|
||
eval_data_path <- fs::path(model_run_dir, "eval_data", ext = "tsv") | ||
eval_dat <- readr::read_tsv(eval_data_path, show_col_types = FALSE) |> | ||
dplyr::mutate(data_type = "eval") | ||
|
||
posterior_predictive_path <- fs::path(model_run_dir, "mcmc_tidy", | ||
"pyrenew_posterior_predictive", | ||
ext = "parquet" | ||
) | ||
posterior_predictive <- arrow::read_parquet(posterior_predictive_path) | ||
|
||
|
||
other_ed_visits_path <- fs::path(model_run_dir, "other_ed_visits_forecast", | ||
ext = "parquet" | ||
) | ||
other_ed_visits_forecast <- arrow::read_parquet(other_ed_visits_path) |> | ||
dplyr::rename(Other = other_ed_visits) | ||
|
||
combined_dat <- | ||
dplyr::bind_rows( | ||
train_dat |> | ||
dplyr::filter(data_type == "train"), | ||
eval_dat | ||
) |> | ||
dplyr::mutate( | ||
disease = dplyr::if_else( | ||
disease == disease_name_nssp, | ||
"Disease", # assign a common name for | ||
# use in plotting functions | ||
disease | ||
) | ||
) |> | ||
tidyr::pivot_wider(names_from = disease, values_from = ed_visits) |> | ||
dplyr::mutate( | ||
Other = Total - Disease, | ||
prop_disease_ed_visits = Disease / Total | ||
) |> | ||
dplyr::select(-Total) |> | ||
dplyr::mutate(time = dplyr::dense_rank(date)) |> | ||
tidyr::pivot_longer(c(Disease, Other, prop_disease_ed_visits), | ||
names_to = "disease", | ||
values_to = ".value" | ||
) | ||
|
||
last_training_date <- combined_dat |> | ||
dplyr::filter(data_type == "train") |> | ||
dplyr::pull(date) |> | ||
max() | ||
|
||
other_ed_visits_samples <- | ||
dplyr::bind_rows( | ||
combined_dat |> | ||
dplyr::filter( | ||
data_type == "train", | ||
disease == "Other", | ||
date <= last_training_date | ||
) |> | ||
dplyr::select(date, Other = .value) |> | ||
tidyr::expand_grid(.draw = 1:max(other_ed_visits_forecast$.draw)), | ||
other_ed_visits_forecast | ||
) | ||
|
||
|
||
forecast_samples <- | ||
posterior_predictive |> | ||
tidybayes::gather_draws(observed_hospital_admissions[time]) |> | ||
tidyr::pivot_wider(names_from = .variable, values_from = .value) |> | ||
dplyr::rename(Disease = observed_hospital_admissions) |> | ||
dplyr::ungroup() |> | ||
dplyr::mutate(date = min(combined_dat$date) + time) |> | ||
dplyr::left_join(other_ed_visits_samples, | ||
by = c(".draw", "date") | ||
) |> | ||
dplyr::mutate(prop_disease_ed_visits = Disease / (Disease + Other)) |> | ||
tidyr::pivot_longer(c(Other, Disease, prop_disease_ed_visits), | ||
names_to = "disease", | ||
values_to = ".value" | ||
) | ||
|
||
|
||
forecast_ci <- | ||
forecast_samples |> | ||
dplyr::select(date, disease, .value) |> | ||
dplyr::group_by(date, disease) |> | ||
ggdist::median_qi(.width = c(0.5, 0.8, 0.95)) | ||
|
||
# Save data | ||
if (save) { | ||
arrow::write_parquet( | ||
combined_dat, | ||
fs::path(model_run_dir, | ||
"combined_training_eval_data", | ||
ext = "parquet" | ||
) | ||
) | ||
|
||
arrow::write_parquet( | ||
forecast_samples, | ||
fs::path(model_run_dir, "forecast_samples", | ||
ext = "parquet" | ||
) | ||
) | ||
|
||
arrow::write_parquet( | ||
forecast_ci, | ||
fs::path(model_run_dir, "forecast_ci", | ||
ext = "parquet" | ||
) | ||
) | ||
} | ||
return(list( | ||
combined_dat = combined_dat, | ||
forecast_samples = forecast_samples, | ||
forecast_ci = forecast_ci | ||
)) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.