Skip to content

Commit

Permalink
Strict should always be TRUE, get order correct for epiweekly data generation in forecast_state
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Nov 23, 2024
1 parent b8612dc commit 36e26a9
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
2 changes: 1 addition & 1 deletion hewr/R/process_state_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ process_state_forecast <- function(model_run_dir, save = TRUE) {
dplyr::group_by(disease) |>
dplyr::group_modify(~ forecasttools::daily_to_epiweekly(.x,
value_col = ".value", weekly_value_name = ".value",
strict = strict
strict = TRUE
)) |>
dplyr::ungroup() |>
tidyr::pivot_wider(
Expand Down
8 changes: 4 additions & 4 deletions pipelines/batch/setup_eval_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def main(
"target": "/pyrenew-hew/params",
},
{
"source": "pyrenew-hew-prod-output",
"source": "pyrenew-test-output/eval2",
"target": "/pyrenew-hew/output",
},
{
Expand All @@ -123,7 +123,7 @@ def main(
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days 365 "
"--n-training-days 90 "
"--n-warmup 1000 "
"--n-samples 500 "
"--facility-level-nssp-data-dir nssp-etl/gold "
Expand All @@ -133,7 +133,7 @@ def main(
"--output-data-dir output "
"--priors-path config/eval_priors.py "
"--report-date {report_date:%Y-%m-%d} "
"--exclude-last-n-days 2 "
"--exclude-last-n-days 5 "
"--score "
"--eval-data-path "
"nssp-archival-vintages/latest_comprehensive.parquet"
Expand All @@ -152,7 +152,7 @@ def main(

report_dates = [
datetime.date(2023, 10, 11) + datetime.timedelta(weeks=x)
for x in range(30)
for x in range(4, 30)
]

for disease, report_date, loc in itertools.product(
Expand Down
28 changes: 14 additions & 14 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,21 @@ def main(
model_run_dir=model_run_dir,
logger=logger,
)
logger.info("Getting eval data...")
if eval_data_path is None:
raise ValueError("No path to an evaluation dataset provided.")
save_eval_data(
state=state,
report_date=report_date,
disease=disease,
first_training_date=first_training_date,
last_training_date=last_training_date,
latest_comprehensive_path=eval_data_path,
output_data_dir=model_run_dir,
last_eval_date=report_date + timedelta(days=n_forecast_days),
)

logger.info("Generating epiweekly data...")
logger.info("Generating epiweekly datasets from daily datasets...")
generate_epiweekly(model_run_dir)

logger.info("Data preparation complete.")
Expand Down Expand Up @@ -258,19 +271,6 @@ def main(
model_run_dir, n_days_past_last_training, n_denominator_samples
)
logger.info("All forecasting complete.")
logger.info("Getting eval data...")
if eval_data_path is None:
raise ValueError("No path to an evaluation dataset provided.")
save_eval_data(
state=state,
report_date=report_date,
disease=disease,
first_training_date=first_training_date,
last_training_date=last_training_date,
latest_comprehensive_path=eval_data_path,
output_data_dir=model_run_dir,
last_eval_date=report_date + timedelta(days=n_forecast_days),
)

logger.info("Converting inferencedata to parquet...")
convert_inferencedata_to_parquet(model_run_dir)
Expand Down
13 changes: 5 additions & 8 deletions pipelines/generate_epiweekly.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ purrr::walk(script_packages, \(pkg) {
#' specified directory.
convert_daily_to_epiweekly <- function(
model_run_dir, strict = TRUE,
day_of_week = 1, dataname = "data", ext = "csv") {
day_of_week = 7, dataname = "data", ext = "csv") {
if (!ext %in% c("csv", "tsv")) {
stop("Invalid file extension. Only 'csv' and 'tsv' are allowed.")
}
Expand All @@ -53,12 +53,6 @@ convert_daily_to_epiweekly <- function(
) |>
mutate(.draw = 1)

the_data_type <- daily_data$data_type[1]
if (any(daily_data$data_type != the_data_type)) {
stop(glue::glue("The data_type column must contain only {the_data_type}
values."))
}

epiweekly_data <- daily_data |>
group_by(disease) |>
group_modify(~ forecasttools::daily_to_epiweekly(.x,
Expand All @@ -70,7 +64,10 @@ convert_daily_to_epiweekly <- function(
day_of_week = day_of_week
)) |>
select(date, disease, ed_visits) |>
mutate(data_type = the_data_type)
inner_join(daily_data |> select(date, disease, data_type),
by = c("date", "disease")
)
# epiweek end date determines data_type classification

output_file <- path(model_run_dir, glue::glue("epiweekly_{dataname}"),
ext = ext
Expand Down

0 comments on commit 36e26a9

Please sign in to comment.