Skip to content

Commit

Permalink
Small quality of life improvements to `pipelines/postprocess_state_fo…
Browse files Browse the repository at this point in the history
…recast.R` (#154)
  • Loading branch information
damonbayer authored Nov 22, 2024
1 parent 72f816b commit d6dd2c5
Show file tree
Hide file tree
Showing 7 changed files with 466 additions and 387 deletions.
2 changes: 2 additions & 0 deletions hewr/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Generated by roxygen2: do not edit by hand

export(get_all_model_batch_dirs)
export(make_forecast_figure)
export(parse_model_batch_dir_path)
export(parse_model_run_dir_path)
export(process_state_forecast)
export(to_epiweekly_quantile_table)
export(to_epiweekly_quantiles)
importFrom(rlang,.data)
102 changes: 102 additions & 0 deletions hewr/R/make_forecast_figure.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#' Make Forecast Figure
#'
#' Draws the forecast credible-interval ribbons for one `disease` value,
#' overlays the observed points from `combined_dat`, and marks the last
#' training date with a dashed vertical line separating the fit period from
#' the forecast period.
#'
#' @param target_disease a disease matching the disease columns in
#' `combined_dat` and `forecast_ci`. Values starting with `"prop"` get a
#' percent-formatted y axis; `"Other"` gets the title
#' "Other ED Visits in <state>".
#' @param combined_dat `combined_dat` from the result of
#' [process_state_forecast()]. The columns `date`, `.value`, `data_type`,
#' `disease` and `geo_value` are read.
#' @param forecast_ci `forecast_ci` from the result of
#' [process_state_forecast()]. The columns `date`, `disease`, `.value`,
#' `.lower` and `.upper` are read.
#' @param disease_name "COVID-19" or "Influenza"; used for the plot title
#' ("Influenza" is shown as "Flu").
#' @param data_vintage_date date that the data was collected; shown in the
#' subtitle as "as of <data_vintage_date>".
#' @param y_transform a character passed as the transform argument to
#' [ggplot2::scale_y_continuous()]
#'
#' @return a ggplot object
#' @export
make_forecast_figure <- function(target_disease,
                                 combined_dat,
                                 forecast_ci,
                                 disease_name = c("COVID-19", "Influenza"),
                                 data_vintage_date,
                                 y_transform = "identity") {
  disease_name <- rlang::arg_match(disease_name)
  disease_name_pretty <- c(
    "COVID-19" = "COVID-19",
    "Influenza" = "Flu"
  )[disease_name]
  # Only the first distinct geo_value is used in the title.
  state_abb <- unique(combined_dat$geo_value)[1]

  # Proportions are labelled as percentages, counts with thousands separators.
  y_scale <- if (stringr::str_starts(target_disease, "prop")) {
    ggplot2::scale_y_continuous(
      "Proportion of Emergency Department Visits",
      labels = scales::label_percent(),
      transform = y_transform
    )
  } else {
    ggplot2::scale_y_continuous(
      "Emergency Department Visits",
      labels = scales::label_comma(),
      transform = y_transform
    )
  }

  title <- if (target_disease == "Other") {
    glue::glue("Other ED Visits in {state_abb}")
  } else {
    glue::glue("{disease_name_pretty} ED Visits in {state_abb}")
  }

  last_training_date <- combined_dat |>
    dplyr::filter(.data$data_type == "train") |>
    dplyr::pull("date") |>
    max()

  # Observed points, restricted to the date range covered by the forecast.
  # Rows are sorted by descending (reversed) data_type, which fixes the
  # order in which the points are drawn.
  observed_points <- combined_dat |>
    dplyr::filter(
      .data$disease == target_disease,
      .data$date <= max(forecast_ci$date)
    ) |>
    dplyr::mutate(data_type = forcats::fct_rev(.data$data_type)) |>
    dplyr::arrange(dplyr::desc(.data$data_type))

  ggplot2::ggplot(mapping = ggplot2::aes(.data$date, .data$.value)) +
    ggdist::geom_lineribbon(
      data = forecast_ci |>
        dplyr::filter(.data$disease == target_disease),
      mapping = ggplot2::aes(ymin = .data$.lower, ymax = .data$.upper),
      color = "#08519c",
      key_glyph = ggplot2::draw_key_rect,
      step = "mid"
    ) +
    ggplot2::scale_fill_brewer(
      name = "Credible Interval Width",
      labels = ~ scales::label_percent()(as.numeric(.))
    ) +
    ggplot2::geom_point(
      mapping = ggplot2::aes(color = .data$data_type),
      size = 1.5,
      data = observed_points
    ) +
    ggplot2::scale_color_manual(
      name = "Data Type",
      values = c("olivedrab1", "deeppink"),
      labels = stringr::str_to_title
    ) +
    ggplot2::geom_vline(
      xintercept = last_training_date,
      linetype = "dashed"
    ) +
    # The two labels sit at the bottom of the panel (y = -Inf) on either
    # side of the dashed line; the trailing newline lifts the text off the
    # axis.
    ggplot2::annotate(
      geom = "text",
      x = last_training_date,
      y = -Inf,
      label = "Fit Period \u2190\n",
      hjust = "right",
      vjust = "bottom"
    ) +
    ggplot2::annotate(
      geom = "text",
      x = last_training_date,
      y = -Inf,
      label = "\u2192 Forecast Period\n",
      hjust = "left",
      vjust = "bottom"
    ) +
    ggplot2::ggtitle(
      title,
      subtitle = glue::glue("as of {data_vintage_date}")
    ) +
    y_scale +
    ggplot2::scale_x_date("Date") +
    cowplot::theme_minimal_grid() +
    ggplot2::theme(legend.position = "bottom")
}
129 changes: 129 additions & 0 deletions hewr/R/process_state_forecast.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#' Process state forecast
#'
#' Reads the training data (`data.csv`), evaluation data (`eval_data.tsv`),
#' posterior predictive draws
#' (`mcmc_tidy/pyrenew_posterior_predictive.parquet`) and the other ED visits
#' forecast (`other_ed_visits_forecast.parquet`) from `model_run_dir`, and
#' combines them into observed data, forecast samples, and forecast credible
#' intervals for the disease count, the other ED visit count, and the
#' proportion of ED visits due to the disease.
#'
#' @param model_run_dir Model run directory
#' @param save Logical indicating whether or not to save. When `TRUE`, the
#' three results are written to `model_run_dir` as
#' `combined_training_eval_data.parquet`, `forecast_samples.parquet` and
#' `forecast_ci.parquet`.
#'
#' @return a list with three tibbles: combined_dat, forecast_samples,
#' and forecast_ci
#' @export
process_state_forecast <- function(model_run_dir, save = TRUE) {
  disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease

  train_data_path <- fs::path(model_run_dir, "data", ext = "csv")
  train_dat <- readr::read_csv(train_data_path, show_col_types = FALSE)

  eval_data_path <- fs::path(model_run_dir, "eval_data", ext = "tsv")
  eval_dat <- readr::read_tsv(eval_data_path, show_col_types = FALSE) |>
    dplyr::mutate(data_type = "eval")

  posterior_predictive_path <- fs::path(
    model_run_dir,
    "mcmc_tidy",
    "pyrenew_posterior_predictive",
    ext = "parquet"
  )
  posterior_predictive <- arrow::read_parquet(posterior_predictive_path)

  other_ed_visits_path <- fs::path(
    model_run_dir,
    "other_ed_visits_forecast",
    ext = "parquet"
  )
  other_ed_visits_forecast <- arrow::read_parquet(other_ed_visits_path) |>
    dplyr::rename(Other = "other_ed_visits")

  # Observed data in long format: one row per date and series, where the
  # series are the disease count, the remaining ("Other") count, and the
  # disease share of total visits.
  combined_dat <-
    dplyr::bind_rows(
      train_dat |>
        dplyr::filter(.data$data_type == "train"),
      eval_dat
    ) |>
    dplyr::mutate(
      disease = dplyr::if_else(
        .data$disease == disease_name_nssp,
        "Disease", # assign a common name for
        # use in plotting functions
        .data$disease
      )
    ) |>
    tidyr::pivot_wider(names_from = "disease", values_from = "ed_visits") |>
    dplyr::mutate(
      Other = .data$Total - .data$Disease,
      prop_disease_ed_visits = .data$Disease / .data$Total
    ) |>
    dplyr::select(-"Total") |>
    dplyr::mutate(time = dplyr::dense_rank(.data$date)) |>
    tidyr::pivot_longer(
      c("Disease", "Other", "prop_disease_ed_visits"),
      names_to = "disease",
      values_to = ".value"
    )

  last_training_date <- combined_dat |>
    dplyr::filter(.data$data_type == "train") |>
    dplyr::pull("date") |>
    max()

  # Repeat the observed training-period "Other" counts once per draw id so
  # that they can be stacked with the forecast draws and joined on
  # (.draw, date) below. seq_len() is used so that a maximum below 1 does
  # not produce a descending sequence.
  max_draw <- max(other_ed_visits_forecast$.draw)
  other_ed_visits_samples <-
    dplyr::bind_rows(
      combined_dat |>
        dplyr::filter(
          .data$data_type == "train",
          .data$disease == "Other",
          .data$date <= last_training_date
        ) |>
        dplyr::select("date", Other = ".value") |>
        tidyr::expand_grid(.draw = seq_len(max_draw)),
      other_ed_visits_forecast
    )

  forecast_samples <-
    posterior_predictive |>
    tidybayes::gather_draws(observed_hospital_admissions[time]) |>
    tidyr::pivot_wider(names_from = ".variable", values_from = ".value") |>
    dplyr::rename(Disease = "observed_hospital_admissions") |>
    dplyr::ungroup() |>
    # The model time index is turned into a calendar date by offsetting
    # from the earliest observed date.
    # NOTE(review): this assumes `time` is a 0-based daily index - confirm.
    dplyr::mutate(date = min(combined_dat$date) + .data$time) |>
    dplyr::left_join(
      other_ed_visits_samples,
      by = c(".draw", "date")
    ) |>
    dplyr::mutate(
      prop_disease_ed_visits = .data$Disease / (.data$Disease + .data$Other)
    ) |>
    tidyr::pivot_longer(
      c("Other", "Disease", "prop_disease_ed_visits"),
      names_to = "disease",
      values_to = ".value"
    )

  # Median and 50% / 80% / 95% intervals per date and series.
  forecast_ci <-
    forecast_samples |>
    dplyr::select("date", "disease", ".value") |>
    dplyr::group_by(.data$date, .data$disease) |>
    ggdist::median_qi(.width = c(0.5, 0.8, 0.95))

  # Save data
  if (save) {
    arrow::write_parquet(
      combined_dat,
      fs::path(
        model_run_dir,
        "combined_training_eval_data",
        ext = "parquet"
      )
    )

    arrow::write_parquet(
      forecast_samples,
      fs::path(model_run_dir, "forecast_samples", ext = "parquet")
    )

    arrow::write_parquet(
      forecast_ci,
      fs::path(model_run_dir, "forecast_ci", ext = "parquet")
    )
  }

  list(
    combined_dat = combined_dat,
    forecast_samples = forecast_samples,
    forecast_ci = forecast_ci
  )
}
36 changes: 36 additions & 0 deletions hewr/man/make_forecast_figure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions hewr/man/process_state_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d6dd2c5

Please sign in to comment.