Issue 110: Add epiweekly fits and scoring #159

Merged Dec 2, 2024 (38 commits)

Commits
3ed55f8
Create generate_epiweekly.R
SamuelBrand1 Nov 22, 2024
0964737
Update generate_epiweekly.R
SamuelBrand1 Nov 22, 2024
7492852
restyle
SamuelBrand1 Nov 22, 2024
128ad65
add convert data to epiweekly to test mode script
SamuelBrand1 Nov 22, 2024
8c2f63f
add step to fit baseline models to epiweekly
SamuelBrand1 Nov 22, 2024
27fe335
add epiweekly scoring to score forecasts
SamuelBrand1 Nov 22, 2024
e166ac3
add tidyr dep in score_forecast
SamuelBrand1 Nov 22, 2024
4f3549a
Merge branch 'main' into epiweekly-fits
dylanhmorris Nov 22, 2024
ef52f4a
Merge branch 'main' into epiweekly-fits
dylanhmorris Nov 22, 2024
6627518
Move epiweekly sample saving into hewr
dylanhmorris Nov 23, 2024
b8612dc
Add generate_epiweekly to forecast_state.py, rename baseline_forecast…
dylanhmorris Nov 23, 2024
36e26a9
Strict should always be TRUE, get order correct for epiweekly data ge…
dylanhmorris Nov 23, 2024
f63fd6d
Typo and namespace fixes
dylanhmorris Nov 23, 2024
2bfb324
Add 95 coverage to default quantile scoring
dylanhmorris Nov 23, 2024
dc409b0
Filter samples to after forecast date for ease of scoring
dylanhmorris Nov 23, 2024
3e9dcef
Correct variable name
dylanhmorris Nov 23, 2024
46342b1
Also filter quantile scores
dylanhmorris Nov 23, 2024
0178e33
Debug collate_score_tables
dylanhmorris Nov 23, 2024
0f1a9e8
File ext handling, variable name typo
dylanhmorris Nov 23, 2024
94843d1
Fix collate plots
dylanhmorris Nov 23, 2024
b0a2a8f
Tweak setup eval job
dylanhmorris Nov 23, 2024
067b6ac
Configurable output subdir
dylanhmorris Nov 23, 2024
30726fc
Typo fix
dylanhmorris Nov 23, 2024
e3cc2ac
Tweak eval job
dylanhmorris Nov 23, 2024
3560132
Update pipelines/score_forecast.R
SamuelBrand1 Nov 25, 2024
714c620
tweaks to epiweekly fits (#164)
dylanhmorris Nov 25, 2024
a43441f
Merge branch 'main' into epiweekly-fits
dylanhmorris Nov 26, 2024
a2fa9af
Update pipelines/batch/setup_prod_job.py
SamuelBrand1 Nov 26, 2024
0145294
change to if_else pattern
SamuelBrand1 Nov 27, 2024
57148a2
Add missing streaming=True to collect() calls in prep_data.py
dylanhmorris Nov 27, 2024
95838c5
Use more forecasttools in score_hubverse.R
dylanhmorris Nov 27, 2024
ab6b1f3
Updates to hubverse scoring
dylanhmorris Dec 2, 2024
cded78d
Update generate_epiweekly.R
SamuelBrand1 Dec 2, 2024
ee059ff
Merge branch 'epiweekly-fits' of https://github.com/CDCgov/pyrenew-he…
SamuelBrand1 Dec 2, 2024
34f47ba
restyle
SamuelBrand1 Dec 2, 2024
e70b0f8
Add option for epiweekly denominators to hubverse table
dylanhmorris Dec 2, 2024
c696aa6
Fix docs for other ed visit flag in to_epiweekly_quantile_table
dylanhmorris Dec 2, 2024
2f38c1f
Remove score_hubverse.R
dylanhmorris Dec 2, 2024
30 changes: 30 additions & 0 deletions hewr/R/process_state_forecast.R
@@ -90,6 +90,29 @@ process_state_forecast <- function(model_run_dir, save = TRUE) {
values_to = ".value"
)

epiweekly_forecast_samples <- forecast_samples |>
dplyr::filter(disease != "prop_disease_ed_visits") |>
dplyr::group_by(disease) |>
dplyr::group_modify(~ forecasttools::daily_to_epiweekly(.x,
value_col = ".value", weekly_value_name = ".value",
strict = TRUE
)) |>
dplyr::ungroup() |>
tidyr::pivot_wider(
names_from = disease,
values_from = .value
) |>
dplyr::mutate(prop_disease_ed_visits = Disease /
(Disease + Other)) |>
tidyr::pivot_longer(c(Disease, Other, prop_disease_ed_visits),
names_to = "disease",
values_to = ".value"
) |>
dplyr::mutate(date = forecasttools::epiweek_to_date(
epiweek,
epiyear,
day_of_week = 7
))

forecast_ci <-
forecast_samples |>
@@ -113,6 +136,12 @@ process_state_forecast <- function(model_run_dir, save = TRUE) {
ext = "parquet"
)
)
arrow::write_parquet(
epiweekly_forecast_samples,
fs::path(model_run_dir, "epiweekly_forecast_samples",
ext = "parquet"
)
)

arrow::write_parquet(
forecast_ci,
@@ -124,6 +153,7 @@ process_state_forecast <- function(model_run_dir, save = TRUE) {
return(list(
combined_dat = combined_dat,
forecast_samples = forecast_samples,
epiweekly_forecast_samples = epiweekly_forecast_samples,
forecast_ci = forecast_ci
))
}
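
The new block aggregates the daily forecast draws to epiweekly totals with forecasttools::daily_to_epiweekly() and only then recomputes prop_disease_ed_visits, so the weekly proportion is a ratio of weekly sums rather than an average of daily proportions. Below is a minimal Python/polars sketch of that idea, not code from the PR: ISO weeks stand in for MMWR epiweeks, and the toy counts and column names are assumptions.

from datetime import date

import polars as pl

# Two weeks of made-up daily ED-visit counts for a single posterior draw.
daily = pl.DataFrame(
    {
        "date": pl.date_range(date(2024, 11, 17), date(2024, 11, 30), eager=True),
        "Disease": [float(x) for x in range(5, 19)],
        "Other": [100.0] * 14,
    }
)

epiweekly = (
    daily.with_columns(
        pl.col("date").dt.iso_year().alias("epiyear"),
        pl.col("date").dt.week().alias("epiweek"),
    )
    # Sum counts within each week (the R helper's strict = TRUE additionally
    # guards against incomplete weeks).
    .group_by("epiyear", "epiweek")
    .agg(pl.col("Disease").sum(), pl.col("Other").sum())
    # Recompute the proportion from the weekly totals, as in the R code above.
    .with_columns(
        (pl.col("Disease") / (pl.col("Disease") + pl.col("Other")))
        .alias("prop_disease_ed_visits")
    )
    .sort("epiyear", "epiweek")
)
print(epiweekly)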
4 changes: 3 additions & 1 deletion hewr/R/to_epiweekly_quantile_table.R
@@ -94,7 +94,9 @@ to_epiweekly_quantiles <- function(model_run_dir,
#' @export
to_epiweekly_quantile_table <- function(model_batch_dir,
exclude = NULL) {
- locations_to_process <- fs::dir_ls(model_batch_dir,
+ model_runs_path <- fs::path(model_batch_dir, "model_runs")
+
+ locations_to_process <- fs::dir_ls(model_runs_path,
type = "directory"
)

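In effect, to_epiweekly_quantile_table() now looks for one directory per location under a model_runs/ subdirectory of the batch directory rather than directly under it. A small Python sketch of the assumed layout and lookup (the batch directory name is illustrative):

from pathlib import Path

# Expected layout after this change:
#   <model_batch_dir>/
#     model_runs/
#       CA/  MN/  US/  ...   (one run directory per location)
model_batch_dir = Path("output", "example_model_batch")  # illustrative name

locations_to_process = [
    p.name for p in (model_batch_dir / "model_runs").iterdir() if p.is_dir()
]
print(locations_to_process)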
31 changes: 27 additions & 4 deletions pipelines/batch/setup_eval_job.py
@@ -8,6 +8,7 @@
import argparse
import datetime
import itertools
from pathlib import Path

import polars as pl
from azure.batch import models
@@ -21,6 +22,7 @@ def main(
job_id: str,
pool_id: str,
diseases: str,
output_subdir: str | Path = "./",
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
excluded_locations: list[str] = [
@@ -46,6 +48,11 @@
as a whitespace-separated string. Supported
values are 'COVID-19' and 'Influenza'.

output_subdir
Subdirectory of the output blob storage container
in which to save results.


container_image_name:
Name of the container to use for the job.
This container should exist within the Azure
@@ -108,7 +115,7 @@
"target": "/pyrenew-hew/params",
},
{
- "source": "pyrenew-hew-prod-output",
+ "source": "pyrenew-test-output",
"target": "/pyrenew-hew/output",
},
{
@@ -123,17 +130,17 @@
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
- "--n-training-days 365 "
+ "--n-training-days {n_training} "
"--n-warmup 1000 "
"--n-samples 500 "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
"--param-data-dir params "
- "--output-data-dir output "
+ "--output-dir {output_dir} "
"--priors-path config/eval_priors.py "
"--report-date {report_date:%Y-%m-%d} "
- "--exclude-last-n-days 2 "
+ "--exclude-last-n-days {exclude_last_n} "
"--score "
"--eval-data-path "
"nssp-archival-vintages/latest_comprehensive.parquet"
@@ -158,12 +165,17 @@
for disease, report_date, loc in itertools.product(
disease_list, report_dates, all_locations
):
n_training = 90
exclude_last_n = 3
task = get_task_config(
f"{job_id}-{loc}-{disease}-{report_date}",
base_call=base_call.format(
state=loc,
disease=disease,
report_date=report_date,
n_training=n_training,
exclude_last_n=exclude_last_n,
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
)
@@ -190,6 +202,17 @@
),
)

parser.add_argument(
"--output-subdir",
type=str,
help=(
"Subdirectory of the output blob storage container "
"in which to save results."
),
default="./",
)


parser.add_argument(
"--container-image-name",
type=str,
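With these changes, the training window, the exclude-last-n-days cutoff, and the output directory are filled in per task via base_call.format() instead of being hard-coded. A short sketch of how one resolved command line looks, using the per-task values set in the loop above; the report date and output subdirectory are illustrative, not taken from the PR.

import datetime
from pathlib import Path

# Abbreviated copy of the template above, keeping only the changed pieces.
base_call = (
    "python pipelines/forecast_state.py "
    "--disease {disease} --state {state} "
    "--n-training-days {n_training} "
    "--output-dir {output_dir} "
    "--report-date {report_date:%Y-%m-%d} "
    "--exclude-last-n-days {exclude_last_n}"
)

print(
    base_call.format(
        disease="COVID-19",
        state="US",
        n_training=90,        # set per task in the loop above
        exclude_last_n=3,     # likewise
        report_date=datetime.date(2024, 11, 16),         # illustrative
        output_dir=str(Path("output", "eval-2024-11")),  # "output" / --output-subdir
    )
)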
240 changes: 240 additions & 0 deletions pipelines/batch/setup_parameter_inference_job.py
@@ -0,0 +1,240 @@
"""
Set up a multi-location, multi-disease parameter
inference run for pyrenew-hew on Azure Batch.
"""

import argparse
import itertools
from pathlib import Path

import polars as pl
from azure.batch import models
from azuretools.auth import EnvCredentialHandler
from azuretools.client import get_batch_service_client
from azuretools.job import create_job_if_not_exists
from azuretools.task import get_container_settings, get_task_config


def main(
job_id: str,
pool_id: str,
diseases: str | list[str],
output_subdir: str | Path = "./",
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
excluded_locations: list[str] = [
"AS",
"GU",
"MO",
"MP",
"PR",
"UM",
"VI",
"WY",
],
) -> None:
"""
job_id
Name for the Batch job.

pool_id
Azure Batch pool on which to run the job.

diseases
Name(s) of disease(s) to run as part of the job,
as a single string (one disease) or a list of strings.
Supported values are 'COVID-19' and 'Influenza'.

output_subdir
Subdirectory of the output blob storage container
in which to save results.

container_image_name:
Name of the container to use for the job.
This container should exist within the Azure
Container Registry account associated with
the job. Default 'pyrenew-hew'.
The container registry account name and endpoint
will be obtained from local environment variables
via a :class:`azuretools.auth.EnvCredentialHandler`.

container_image_version
Version of the container to use. Default 'latest'.

excluded_locations
List of two letter USPS location abbreviations to
exclude from the job. Defaults to locations for which
we typically do not have available NSSP ED visit data:
``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``.

Returns
-------
None
"""
supported_diseases = ["COVID-19", "Influenza"]

disease_list = diseases

invalid_diseases = set(disease_list) - set(supported_diseases)
if invalid_diseases:
raise ValueError(
f"Unsupported diseases: {', '.join(invalid_diseases)}; "
f"supported diseases are: {', '.join(supported_diseases)}"
)

pyrenew_hew_output_container = "pyrenew-test-output"
n_warmup = 1000
n_samples = 500

creds = EnvCredentialHandler()
client = get_batch_service_client(creds)
job = models.JobAddParameter(
id=job_id,
pool_info=models.PoolInformation(pool_id=pool_id),
)
create_job_if_not_exists(client, job, verbose=True)

container_image = (
f"{creds.azure_container_registry_account}."
f"{creds.azure_container_registry_domain}/"
f"{container_image_name}:{container_image_version}"
)
container_settings = get_container_settings(
container_image,
working_directory="containerImageDefault",
mount_pairs=[
{
"source": "nssp-etl",
"target": "/pyrenew-hew/nssp-etl",
},
{
"source": "nssp-archival-vintages",
"target": "/pyrenew-hew/nssp-archival-vintages",
},
{
"source": "prod-param-estimates",
"target": "/pyrenew-hew/params",
},
{
"source": pyrenew_hew_output_container,
"target": "/pyrenew-hew/output",
},
{
"source": "pyrenew-hew-config",
"target": "/pyrenew-hew/config",
},
],
)

base_call = (
"/bin/bash -c '"
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days 450 "
"--n-warmup {n_warmup} "
"--n-samples {n_samples} "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
"--param-data-dir params "
"--output-dir {output_dir} "
"--priors-path config/parameter_inference_priors.py "
"--report-date {report_date} "
"--exclude-last-n-days 5 "
"--no-score "
"--eval-data-path "
"nssp-archival-vintages/latest_comprehensive.parquet"
"'"
)

# to be replaced by forecasttools-py table
locations = pl.read_csv(
"https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
)

all_locations = [
loc
for loc in ["US"] + locations.get_column("STUSAB").to_list()
if loc not in excluded_locations
]

for disease, state in itertools.product(disease_list, all_locations):
task = get_task_config(
f"{job_id}-{state}-{disease}-prod",
base_call=base_call.format(
state=state,
disease=disease,
report_date="2024-04-01",
n_warmup=n_warmup,
n_samples=n_samples,
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
)
client.task.add(job_id, task)

return None


parser = argparse.ArgumentParser()

parser.add_argument("job_id", type=str, help="Name for the Azure batch job")
parser.add_argument(
"pool_id",
type=str,
help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
"diseases",
type=str,
help=(
"Name(s) of disease(s) to run as part of the job, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'."
),
)

parser.add_argument(
"--output-subdir",
type=str,
help=(
"Subdirectory of the output blob storage container "
"in which to save results."
),
default="./",
)

parser.add_argument(
"--container-image-name",
type=str,
help="Name of the container to use for the job.",
default="pyrenew-hew",
)

parser.add_argument(
"--container-image-version",
type=str,
help="Version of the container to use for the job.",
default="latest",
)

parser.add_argument(
"--excluded-locations",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"exclude from the job, as a whitespace-separated "
"string. Defaults to a set of locations for which "
"we typically do not have available NSSP ED visit "
"data: 'AS GU MO MP PR UM VI WY'."
),
default="AS GU MO MP PR UM VI WY",
)


if __name__ == "__main__":
args = parser.parse_args()
args.diseases = args.diseases.split()
args.excluded_locations = args.excluded_locations.split()
main(**vars(args))
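
For reference, a hypothetical way to launch the new parameter-inference job programmatically; the job, pool, and subdirectory names are made up, the import path assumes the script's directory is on sys.path, and actually running it requires the Azure Batch credentials that EnvCredentialHandler reads from the environment (it will submit real tasks).

# Equivalent to, e.g.:
#   python pipelines/batch/setup_parameter_inference_job.py \
#       param-inference-2024-12 pyrenew-pool "COVID-19 Influenza" \
#       --output-subdir parameter-inference/2024-12
from setup_parameter_inference_job import main  # assumes pipelines/batch on sys.path

main(
    job_id="param-inference-2024-12",             # hypothetical job name
    pool_id="pyrenew-pool",                       # hypothetical pool name
    diseases=["COVID-19", "Influenza"],
    output_subdir="parameter-inference/2024-12",  # hypothetical subdirectory
)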