From d6dd2c50f28ce08f2f954062e5dcd0b916888598 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Thu, 21 Nov 2024 18:24:21 -0600 Subject: [PATCH] Small quality of life improvements to `pipelines/postprocess_state_forecast.R` (#154) --- hewr/NAMESPACE | 2 + hewr/R/make_forecast_figure.R | 102 +++++++++ hewr/R/process_state_forecast.R | 129 +++++++++++ hewr/man/make_forecast_figure.Rd | 36 ++++ hewr/man/process_state_forecast.Rd | 20 ++ pipelines/forecast_state.py | 276 ++++++++++++------------ pipelines/postprocess_state_forecast.R | 288 ++++--------------------- 7 files changed, 466 insertions(+), 387 deletions(-) create mode 100644 hewr/R/make_forecast_figure.R create mode 100644 hewr/R/process_state_forecast.R create mode 100644 hewr/man/make_forecast_figure.Rd create mode 100644 hewr/man/process_state_forecast.Rd diff --git a/hewr/NAMESPACE b/hewr/NAMESPACE index 86a36e69..3966cdab 100644 --- a/hewr/NAMESPACE +++ b/hewr/NAMESPACE @@ -1,8 +1,10 @@ # Generated by roxygen2: do not edit by hand export(get_all_model_batch_dirs) +export(make_forecast_figure) export(parse_model_batch_dir_path) export(parse_model_run_dir_path) +export(process_state_forecast) export(to_epiweekly_quantile_table) export(to_epiweekly_quantiles) importFrom(rlang,.data) diff --git a/hewr/R/make_forecast_figure.R b/hewr/R/make_forecast_figure.R new file mode 100644 index 00000000..c581c42b --- /dev/null +++ b/hewr/R/make_forecast_figure.R @@ -0,0 +1,102 @@ +#' Make Forecast Figure +#' +#' @param target_disease a disease matching the disease columns in combined_dat +#' and forecast_ci +#' @param combined_dat `combined_dat` from the result of process_state_forecast +#' @param forecast_ci `forecast_ci` from the result of process_state_forecast +#' @param disease_name "COVID-19" or "Influenza" +#' @param data_vintage_date date that the data was collected +#' @param y_transform a character passed as the transform argument to +#' ggplot2::scale_y_continuous() +#' +#' @return a ggplot object +#' @export + +make_forecast_figure <- function(target_disease, + combined_dat, + forecast_ci, + disease_name = c("COVID-19", "Influenza"), + data_vintage_date, + y_transform = "identity") { + disease_name <- rlang::arg_match(disease_name) + disease_name_pretty <- c( + "COVID-19" = "COVID-19", + "Influenza" = "Flu" + )[disease_name] + state_abb <- unique(combined_dat$geo_value)[1] + + y_scale <- if (stringr::str_starts(target_disease, "prop")) { + ggplot2::scale_y_continuous("Proportion of Emergency Department Visits", + labels = scales::label_percent(), + transform = y_transform + ) + } else { + ggplot2::scale_y_continuous("Emergency Department Visits", + labels = scales::label_comma(), + transform = y_transform + ) + } + + + title <- if (target_disease == "Other") { + glue::glue("Other ED Visits in {state_abb}") + } else { + glue::glue("{disease_name_pretty} ED Visits in {state_abb}") + } + + last_training_date <- combined_dat |> + dplyr::filter(data_type == "train") |> + dplyr::pull(date) |> + max() + + ggplot2::ggplot(mapping = ggplot2::aes(date, .value)) + + ggdist::geom_lineribbon( + data = forecast_ci |> dplyr::filter(disease == target_disease), + mapping = ggplot2::aes(ymin = .lower, ymax = .upper), + color = "#08519c", + key_glyph = ggplot2::draw_key_rect, + step = "mid" + ) + + ggplot2::scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ scales::label_percent()(as.numeric(.)) + ) + + ggplot2::geom_point( + mapping = ggplot2::aes(color = data_type), size = 1.5, + data = combined_dat |> + dplyr::filter( + disease == target_disease, + date <= max(forecast_ci$date) + ) |> + dplyr::mutate(data_type = forcats::fct_rev(data_type)) |> + dplyr::arrange(dplyr::desc(data_type)) + ) + + ggplot2::scale_color_manual( + name = "Data Type", + values = c("olivedrab1", "deeppink"), + labels = stringr::str_to_title + ) + + ggplot2::geom_vline(xintercept = last_training_date, linetype = "dashed") + + ggplot2::annotate( + geom = "text", + x = last_training_date, + y = -Inf, + label = "Fit Period \u2190\n", + hjust = "right", + vjust = "bottom" + ) + + ggplot2::annotate( + geom = "text", + x = last_training_date, + y = -Inf, label = "\u2192 Forecast Period\n", + hjust = "left", + vjust = "bottom", + ) + + ggplot2::ggtitle(title, + subtitle = glue::glue("as of {data_vintage_date}") + ) + + y_scale + + ggplot2::scale_x_date("Date") + + cowplot::theme_minimal_grid() + + ggplot2::theme(legend.position = "bottom") +} diff --git a/hewr/R/process_state_forecast.R b/hewr/R/process_state_forecast.R new file mode 100644 index 00000000..af3729f9 --- /dev/null +++ b/hewr/R/process_state_forecast.R @@ -0,0 +1,129 @@ +#' Process state forecast +#' +#' @param model_run_dir Model run directory +#' @param save Logical indicating whether or not to save +#' +#' @return a list with three tibbles: combined_dat, forecast_samples, +#' and forecast_ci +#' @export +process_state_forecast <- function(model_run_dir, save = TRUE) { + disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease + + train_data_path <- fs::path(model_run_dir, "data", ext = "csv") + train_dat <- readr::read_csv(train_data_path, show_col_types = FALSE) + + eval_data_path <- fs::path(model_run_dir, "eval_data", ext = "tsv") + eval_dat <- readr::read_tsv(eval_data_path, show_col_types = FALSE) |> + dplyr::mutate(data_type = "eval") + + posterior_predictive_path <- fs::path(model_run_dir, "mcmc_tidy", + "pyrenew_posterior_predictive", + ext = "parquet" + ) + posterior_predictive <- arrow::read_parquet(posterior_predictive_path) + + + other_ed_visits_path <- fs::path(model_run_dir, "other_ed_visits_forecast", + ext = "parquet" + ) + other_ed_visits_forecast <- arrow::read_parquet(other_ed_visits_path) |> + dplyr::rename(Other = other_ed_visits) + + combined_dat <- + dplyr::bind_rows( + train_dat |> + dplyr::filter(data_type == "train"), + eval_dat + ) |> + dplyr::mutate( + disease = dplyr::if_else( + disease == disease_name_nssp, + "Disease", # assign a common name for + # use in plotting functions + disease + ) + ) |> + tidyr::pivot_wider(names_from = disease, values_from = ed_visits) |> + dplyr::mutate( + Other = Total - Disease, + prop_disease_ed_visits = Disease / Total + ) |> + dplyr::select(-Total) |> + dplyr::mutate(time = dplyr::dense_rank(date)) |> + tidyr::pivot_longer(c(Disease, Other, prop_disease_ed_visits), + names_to = "disease", + values_to = ".value" + ) + + last_training_date <- combined_dat |> + dplyr::filter(data_type == "train") |> + dplyr::pull(date) |> + max() + + other_ed_visits_samples <- + dplyr::bind_rows( + combined_dat |> + dplyr::filter( + data_type == "train", + disease == "Other", + date <= last_training_date + ) |> + dplyr::select(date, Other = .value) |> + tidyr::expand_grid(.draw = 1:max(other_ed_visits_forecast$.draw)), + other_ed_visits_forecast + ) + + + forecast_samples <- + posterior_predictive |> + tidybayes::gather_draws(observed_hospital_admissions[time]) |> + tidyr::pivot_wider(names_from = .variable, values_from = .value) |> + dplyr::rename(Disease = observed_hospital_admissions) |> + dplyr::ungroup() |> + dplyr::mutate(date = min(combined_dat$date) + time) |> + dplyr::left_join(other_ed_visits_samples, + by = c(".draw", "date") + ) |> + dplyr::mutate(prop_disease_ed_visits = Disease / (Disease + Other)) |> + tidyr::pivot_longer(c(Other, Disease, prop_disease_ed_visits), + names_to = "disease", + values_to = ".value" + ) + + + forecast_ci <- + forecast_samples |> + dplyr::select(date, disease, .value) |> + dplyr::group_by(date, disease) |> + ggdist::median_qi(.width = c(0.5, 0.8, 0.95)) + + # Save data + if (save) { + arrow::write_parquet( + combined_dat, + fs::path(model_run_dir, + "combined_training_eval_data", + ext = "parquet" + ) + ) + + arrow::write_parquet( + forecast_samples, + fs::path(model_run_dir, "forecast_samples", + ext = "parquet" + ) + ) + + arrow::write_parquet( + forecast_ci, + fs::path(model_run_dir, "forecast_ci", + ext = "parquet" + ) + ) + } + return(list( + combined_dat = combined_dat, + forecast_samples = forecast_samples, + forecast_ci = forecast_ci + )) +} diff --git a/hewr/man/make_forecast_figure.Rd b/hewr/man/make_forecast_figure.Rd new file mode 100644 index 00000000..043813d7 --- /dev/null +++ b/hewr/man/make_forecast_figure.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/make_forecast_figure.R +\name{make_forecast_figure} +\alias{make_forecast_figure} +\title{Make Forecast Figure} +\usage{ +make_forecast_figure( + target_disease, + combined_dat, + forecast_ci, + disease_name = c("COVID-19", "Influenza"), + data_vintage_date, + y_transform = "identity" +) +} +\arguments{ +\item{target_disease}{a disease matching the disease columns in combined_dat +and forecast_ci} + +\item{combined_dat}{\code{combined_dat} from the result of process_state_forecast} + +\item{forecast_ci}{\code{forecast_ci} from the result of process_state_forecast} + +\item{disease_name}{"COVID-19" or "Influenza"} + +\item{data_vintage_date}{date that the data was collected} + +\item{y_transform}{a character passed as the transform argument to +ggplot2::scale_y_continuous()} +} +\value{ +a ggplot object +} +\description{ +Make Forecast Figure +} diff --git a/hewr/man/process_state_forecast.Rd b/hewr/man/process_state_forecast.Rd new file mode 100644 index 00000000..89f01cf3 --- /dev/null +++ b/hewr/man/process_state_forecast.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/process_state_forecast.R +\name{process_state_forecast} +\alias{process_state_forecast} +\title{Process state forecast} +\usage{ +process_state_forecast(model_run_dir, save = TRUE) +} +\arguments{ +\item{model_run_dir}{Model run directory} + +\item{save}{Logical indicating whether or not to save} +} +\value{ +a list with three tibbles: combined_dat, forecast_samples, +and forecast_ci +} +\description{ +Process state forecast +} diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index 5a4388f7..12b70afc 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -274,146 +274,146 @@ def main( return None -parser = argparse.ArgumentParser( - description="Create fit data for disease modeling." -) -parser.add_argument( - "--disease", - type=str, - required=True, - help="Disease to model (e.g., COVID-19, Influenza, RSV).", -) - -parser.add_argument( - "--state", - type=str, - required=True, - help=( - "Two letter abbreviation for the state to fit" - "(e.g. 'AK', 'AL', 'AZ', etc.)." - ), -) - -parser.add_argument( - "--report-date", - type=str, - default="latest", - help="Report date in YYYY-MM-DD format or latest (default: latest).", -) - -parser.add_argument( - "--facility-level-nssp-data-dir", - type=Path, - default=Path("private_data", "nssp_etl_gold"), - help=( - "Directory in which to look for facility-level NSSP " "ED visit data" - ), -) - -parser.add_argument( - "--state-level-nssp-data-dir", - type=Path, - default=Path("private_data", "nssp_state_level_gold"), - help=("Directory in which to look for state-level NSSP " "ED visit data."), -) - -parser.add_argument( - "--param-data-dir", - type=Path, - default=Path("private_data", "prod_param_estimates"), - help=( - "Directory in which to look for parameter estimates" - "such as delay PMFs." - ), - required=True, -) - -parser.add_argument( - "--priors-path", - type=Path, - help=( - "Path to an executible python file defining random variables " - "that require priors as pyrenew RandomVariable objects." - ), - required=True, -) - - -parser.add_argument( - "--output-data-dir", - type=Path, - default="private_data", - help="Directory in which to save output data.", -) - -parser.add_argument( - "--n-training-days", - type=int, - default=180, - help="Number of training days (default: 180).", -) - -parser.add_argument( - "--n-forecast-days", - type=int, - default=28, - help=( - "Number of days ahead to forecast relative to the " - "report date (default: 28).", - ), -) - - -parser.add_argument( - "--n-chains", - type=int, - default=4, - help="Number of MCMC chains to run (default: 4).", -) - -parser.add_argument( - "--n-warmup", - type=int, - default=1000, - help=("Number of warmup iterations per chain for NUTS" "(default: 1000)."), -) - -parser.add_argument( - "--n-samples", - type=int, - default=1000, - help=( - "Number of posterior samples to draw per " - "chain using NUTS (default: 1000)." - ), -) - -parser.add_argument( - "--exclude-last-n-days", - type=int, - default=0, - help=( - "Optionally exclude the final n days of available training " - "data (Default: 0, i.e. exclude no available data" - ), -) - -parser.add_argument( - "--score", - type=bool, - action=argparse.BooleanOptionalAction, - help=("If this flag is provided, will attempt to score the forecast."), -) - - -parser.add_argument( - "--eval-data-path", - type=Path, - help=("Path to a parquet file containing compehensive truth data."), -) +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Create fit data for disease modeling." + ) + parser.add_argument( + "--disease", + type=str, + required=True, + help="Disease to model (e.g., COVID-19, Influenza, RSV).", + ) + + parser.add_argument( + "--state", + type=str, + required=True, + help=( + "Two letter abbreviation for the state to fit" + "(e.g. 'AK', 'AL', 'AZ', etc.)." + ), + ) + parser.add_argument( + "--report-date", + type=str, + default="latest", + help="Report date in YYYY-MM-DD format or latest (default: latest).", + ) -if __name__ == "__main__": + parser.add_argument( + "--facility-level-nssp-data-dir", + type=Path, + default=Path("private_data", "nssp_etl_gold"), + help=( + "Directory in which to look for facility-level NSSP " + "ED visit data" + ), + ) + + parser.add_argument( + "--state-level-nssp-data-dir", + type=Path, + default=Path("private_data", "nssp_state_level_gold"), + help=( + "Directory in which to look for state-level NSSP " "ED visit data." + ), + ) + + parser.add_argument( + "--param-data-dir", + type=Path, + default=Path("private_data", "prod_param_estimates"), + help=( + "Directory in which to look for parameter estimates" + "such as delay PMFs." + ), + required=True, + ) + + parser.add_argument( + "--priors-path", + type=Path, + help=( + "Path to an executible python file defining random variables " + "that require priors as pyrenew RandomVariable objects." + ), + required=True, + ) + + parser.add_argument( + "--output-data-dir", + type=Path, + default="private_data", + help="Directory in which to save output data.", + ) + + parser.add_argument( + "--n-training-days", + type=int, + default=180, + help="Number of training days (default: 180).", + ) + + parser.add_argument( + "--n-forecast-days", + type=int, + default=28, + help=( + "Number of days ahead to forecast relative to the " + "report date (default: 28).", + ), + ) + + parser.add_argument( + "--n-chains", + type=int, + default=4, + help="Number of MCMC chains to run (default: 4).", + ) + + parser.add_argument( + "--n-warmup", + type=int, + default=1000, + help=( + "Number of warmup iterations per chain for NUTS" "(default: 1000)." + ), + ) + + parser.add_argument( + "--n-samples", + type=int, + default=1000, + help=( + "Number of posterior samples to draw per " + "chain using NUTS (default: 1000)." + ), + ) + + parser.add_argument( + "--exclude-last-n-days", + type=int, + default=0, + help=( + "Optionally exclude the final n days of available training " + "data (Default: 0, i.e. exclude no available data" + ), + ) + + parser.add_argument( + "--score", + type=bool, + action=argparse.BooleanOptionalAction, + help=("If this flag is provided, will attempt to score the forecast."), + ) + + parser.add_argument( + "--eval-data-path", + type=Path, + help=("Path to a parquet file containing compehensive truth data."), + ) args = parser.parse_args() numpyro.set_host_device_count(args.n_chains) main(**vars(args)) diff --git a/pipelines/postprocess_state_forecast.R b/pipelines/postprocess_state_forecast.R index 2e023322..661fe72a 100644 --- a/pipelines/postprocess_state_forecast.R +++ b/pipelines/postprocess_state_forecast.R @@ -1,20 +1,6 @@ script_packages <- c( - "dplyr", - "stringr", - "purrr", - "ggplot2", - "tidybayes", - "fs", - "cowplot", - "glue", - "scales", - "argparser", - "arrow", - "tidyr", - "readr", - "here", - "forcats", - "hewr" + "argparser", "cowplot", "dplyr", "fs", "glue", "hewr", "purrr", + "tidyr" ) ## load in packages without messages @@ -25,238 +11,48 @@ purrr::walk(script_packages, \(pkg) { }) -make_one_forecast_fig <- function(target_disease, - combined_dat, - last_training_date, - data_vintage_date, - posterior_predictive_ci, - state_abb, - y_transform = "identity") { - y_scale <- if (str_starts(target_disease, "prop")) { - scale_y_continuous("Proportion of Emergency Department Visits", - labels = percent, - transform = y_transform - ) - } else { - scale_y_continuous("Emergency Department Visits", - labels = comma, - transform = y_transform - ) - } +save_forecast_figures <- function(model_run_dir) { + parsed_model_run_dir <- parse_model_run_dir_path(model_run_dir) + processed_forecast <- process_state_forecast(model_run_dir) - title <- if (target_disease == "Other") { - glue("Other ED Visits in {state_abb}") - } else { - glue("{disease_name_pretty} ED Visits in {state_abb}") - } - - ggplot(mapping = aes(date, .value)) + - geom_lineribbon( - data = posterior_predictive_ci |> filter(disease == target_disease), - mapping = aes(ymin = .lower, ymax = .upper), - color = "#08519c", - key_glyph = draw_key_rect, - step = "mid" - ) + - scale_fill_brewer( - name = "Credible Interval Width", - labels = ~ percent(as.numeric(.)) - ) + - geom_point( - mapping = aes(color = data_type), size = 1.5, - data = combined_dat |> - filter( - disease == target_disease, - date <= max(posterior_predictive_ci$date) - ) |> - mutate(data_type = fct_rev(data_type)) |> - arrange(desc(data_type)) - ) + - scale_color_manual( - name = "Data Type", - values = c("olivedrab1", "deeppink"), - labels = str_to_title - ) + - geom_vline(xintercept = last_training_date, linetype = "dashed") + - annotate( - geom = "text", - x = last_training_date, - y = -Inf, - label = "Fit Period ←\n", - hjust = "right", - vjust = "bottom" - ) + - annotate( - geom = "text", - x = last_training_date, - y = -Inf, label = "→ Forecast Period\n", - hjust = "left", - vjust = "bottom", - ) + - ggtitle(title, subtitle = glue("as of {data_vintage_date}")) + - y_scale + - scale_x_date("Date") + - theme(legend.position = "bottom") -} - - -postprocess_state_forecast <- function(model_run_dir) { - state_abb <- model_run_dir |> - path_split() |> - pluck(1) |> - tail(1) - - train_data_path <- path(model_run_dir, "data", ext = "csv") - eval_data_path <- path(model_run_dir, "eval_data", ext = "tsv") - posterior_predictive_path <- path(model_run_dir, "mcmc_tidy", - "pyrenew_posterior_predictive", - ext = "parquet" - ) - other_ed_visits_path <- path( - model_run_dir, - "other_ed_visits_forecast", - ext = "parquet" - ) - - train_dat <- read_csv(train_data_path, show_col_types = FALSE) - - data_vintage_date <- max(train_dat$date) + 1 - # this should be stored as metadata somewhere else, instead of being - # computed like this - - eval_dat <- read_tsv(eval_data_path, show_col_types = FALSE) |> - mutate(data_type = "eval") - - combined_dat <- - bind_rows( - train_dat |> - filter(data_type == "train"), - eval_dat + figure_save_tbl <- + expand_grid( + target_disease = unique(processed_forecast$combined_dat$disease), + y_transform = c("identity", "log10") ) |> - mutate( - disease = if_else( - disease == disease_name_nssp, - "Disease", # assign a common name for - # use in plotting functions - disease + mutate(path_suffix = c("identity" = "", "log10" = "_log")[y_transform]) |> + mutate(figure_path = path(model_run_dir, + glue("{target_disease}_forecast_plot{path_suffix}"), + ext = "pdf" + )) |> + mutate(figure = map2( + target_disease, y_transform, + \(target_disease, y_transform) { + make_forecast_figure( + target_disease = target_disease, + combined_dat = processed_forecast$combined_dat, + forecast_ci = processed_forecast$forecast_ci, + disease_name = parsed_model_run_dir$disease, + data_vintage_date = parsed_model_run_dir$report_date, + y_transform = y_transform + ) + } + )) + + + walk2( + figure_save_tbl$figure, figure_save_tbl$figure_path, + \(figure, figure_path) { + save_plot( + filename = figure_path, + plot = figure, + device = cairo_pdf, base_height = 6 ) - ) |> - pivot_wider(names_from = disease, values_from = ed_visits) |> - mutate( - Other = Total - Disease, - prop_disease_ed_visits = Disease / Total - ) |> - select(-Total) |> - mutate(time = dense_rank(date)) |> - pivot_longer(c(Disease, Other, prop_disease_ed_visits), - names_to = "disease", - values_to = ".value" - ) - - - last_training_date <- combined_dat |> - filter(data_type == "train") |> - pull(date) |> - max() - - posterior_predictive <- read_parquet(posterior_predictive_path) - - other_ed_visits_forecast <- - read_parquet(other_ed_visits_path) |> - rename(Other = other_ed_visits) - - other_ed_visits_samples <- - bind_rows( - combined_dat |> - filter( - data_type == "train", - disease == "Other", - date <= last_training_date - ) |> - select(date, Other = .value) |> - expand_grid(.draw = 1:max(other_ed_visits_forecast$.draw)), - other_ed_visits_forecast - ) - - posterior_predictive_samples <- - posterior_predictive |> - gather_draws(observed_hospital_admissions[time]) |> - pivot_wider(names_from = .variable, values_from = .value) |> - rename(Disease = observed_hospital_admissions) |> - ungroup() |> - mutate(date = min(combined_dat$date) + time) |> - left_join(other_ed_visits_samples, - by = c(".draw", "date") - ) |> - mutate(prop_disease_ed_visits = Disease / (Disease + Other)) |> - pivot_longer(c(Other, Disease, prop_disease_ed_visits), - names_to = "disease", - values_to = ".value" - ) - - arrow::write_parquet( - posterior_predictive_samples, - path(model_run_dir, "forecast_samples", - ext = "parquet" - ) - ) - - posterior_predictive_ci <- - posterior_predictive_samples |> - select(date, disease, .value) |> - group_by(date, disease) |> - median_qi(.width = c(0.5, 0.8, 0.95)) - - - arrow::write_parquet( - posterior_predictive_ci, - path(model_run_dir, "forecast_ci", - ext = "parquet" - ) - ) - - - all_forecast_plots <- map( - set_names(unique(combined_dat$disease)), - ~ make_one_forecast_fig( - .x, - combined_dat, - last_training_date, - data_vintage_date, - posterior_predictive_ci, - state_abb, - ) + } ) - - all_forecast_plots_log <- map( - set_names(unique(combined_dat$disease)), - ~ make_one_forecast_fig( - .x, - combined_dat, - last_training_date, - data_vintage_date, - posterior_predictive_ci, - state_abb, - y_transform = "log10" - ) - ) - - iwalk(all_forecast_plots, ~ save_plot( - filename = path(model_run_dir, glue("{.y}_forecast_plot"), ext = "pdf"), - plot = .x, - device = cairo_pdf, base_height = 6 - )) - iwalk(all_forecast_plots_log, ~ save_plot( - filename = path(model_run_dir, glue("{.y}_forecast_plot_log"), ext = "pdf"), - plot = .x, - device = cairo_pdf, base_height = 6 - )) } -theme_set(theme_minimal_grid()) - -# Create a parser p <- arg_parser("Generate forecast figures") |> add_argument( "model_run_dir", @@ -264,12 +60,6 @@ p <- arg_parser("Generate forecast figures") |> ) argv <- parse_args(p) -model_run_dir <- path(argv$model_run_dir) - -disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease - -disease_name_formatter <- c("COVID-19" = "COVID-19", "Influenza" = "Flu") -disease_name_pretty <- unname(disease_name_formatter[disease_name_nssp]) - -postprocess_state_forecast(model_run_dir) +model_run_dir <- path(argv$model_run_dir) +save_forecast_figures(model_run_dir)