Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track git commit id when fitting models #242

Merged
merged 12 commits into from
Dec 16, 2024
5 changes: 4 additions & 1 deletion .github/workflows/containers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
outputs:
tag: ${{ steps.image-tag.outputs.tag }}
commit-msg: ${{ steps.commit-message.outputs.message }}
branch: ${{ steps.branch-name.outputs.branch }}

steps:

Expand Down Expand Up @@ -115,7 +116,9 @@ jobs:
with:
push: true # This can be toggled manually for tweaking.
tags: |
${{ env.REGISTRY}}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }}
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }}
file: ./Containerfile
build-args: |
TAG=${{ needs.build-dependencies-image.outputs.tag }}
GIT_COMMIT_SHA=${{ github.event.pull_request.head.sha || github.sha }}
GIT_BRANCH_NAME=${{ needs.build-dependencies-image.outputs.branch }}
5 changes: 5 additions & 0 deletions Containerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
ARG TAG=latest

FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG}
ARG GIT_COMMIT_SHA
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
ENV GIT_COMMIT_SHA=$GIT_COMMIT_SHA

ARG GIT_BRANCH_NAME
ENV GIT_BRANCH_NAME=$GIT_BRANCH_NAME

COPY ./hewr /pyrenew-hew/hewr

Expand Down
26 changes: 26 additions & 0 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import numpyro
import polars as pl
import yaml

Check warning on line 11 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L11

Added line #L11 was not covered by tests
from prep_data import process_and_save_state
from pygit2 import Repository

Check warning on line 13 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L13

Added line #L13 was not covered by tests
from save_eval_data import save_eval_data

numpyro.set_host_device_count(4)
Expand All @@ -17,6 +19,27 @@
from generate_predictive import generate_and_save_predictions # noqa


def record_git_info(model_run_dir: Path):
metadata_file = Path(model_run_dir, "metadata.yaml")
try:
repo = Repository(os.getcwd())
branch_name = os.environ.get(

Check warning on line 26 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L22-L26

Added lines #L22 - L26 were not covered by tests
"GIT_BRANCH_NAME", Path(repo.head.name).stem
)
commit_sha = os.environ.get("GIT_COMMIT_SHA", str(repo.head.target))
except:
branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown")
commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown")

Check warning on line 32 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L29-L32

Added lines #L29 - L32 were not covered by tests

metadata = {

Check warning on line 34 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L34

Added line #L34 was not covered by tests
"branch_name": branch_name,
"commit_sha": commit_sha,
}

with open(metadata_file, "w") as file:
yaml.dump(metadata, file)

Check warning on line 40 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L39-L40

Added lines #L39 - L40 were not covered by tests


def generate_epiweekly(model_run_dir: Path) -> None:
result = subprocess.run(
[
Expand Down Expand Up @@ -238,6 +261,9 @@

os.makedirs(model_run_dir, exist_ok=True)

logger.info("Recording git info...")
record_git_info(model_run_dir)

Check warning on line 265 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L264-L265

Added lines #L264 - L265 were not covered by tests

logger.info(f"Using priors from {priors_path}...")
shutil.copyfile(priors_path, Path(model_run_dir, "priors.py"))

Expand Down
Loading