diff --git a/.github/workflows/containers.yaml b/.github/workflows/containers.yaml index 0ba0d5ba..bf91e07e 100644 --- a/.github/workflows/containers.yaml +++ b/.github/workflows/containers.yaml @@ -20,6 +20,7 @@ jobs: outputs: tag: ${{ steps.image-tag.outputs.tag }} commit-msg: ${{ steps.commit-message.outputs.message }} + branch: ${{ steps.branch-name.outputs.branch }} steps: @@ -115,7 +116,9 @@ jobs: with: push: true # This can be toggled manually for tweaking. tags: | - ${{ env.REGISTRY}}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }} + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }} file: ./Containerfile build-args: | TAG=${{ needs.build-dependencies-image.outputs.tag }} + GIT_COMMIT_SHA=${{ github.event.pull_request.head.sha || github.sha }} + GIT_BRANCH_NAME=${{ needs.build-dependencies-image.outputs.branch }} diff --git a/Containerfile b/Containerfile index f627bf32..20b98455 100644 --- a/Containerfile +++ b/Containerfile @@ -1,6 +1,11 @@ ARG TAG=latest FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG} +ARG GIT_COMMIT_SHA +ENV GIT_COMMIT_SHA=$GIT_COMMIT_SHA + +ARG GIT_BRANCH_NAME +ENV GIT_BRANCH_NAME=$GIT_BRANCH_NAME COPY ./hewr /pyrenew-hew/hewr diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index c5e8d445..c5c4442d 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -8,7 +8,9 @@ import numpyro import polars as pl +import yaml from prep_data import process_and_save_state +from pygit2 import Repository from save_eval_data import save_eval_data numpyro.set_host_device_count(4) @@ -17,6 +19,27 @@ from generate_predictive import generate_and_save_predictions # noqa +def record_git_info(model_run_dir: Path): + metadata_file = Path(model_run_dir, "metadata.yaml") + try: + repo = Repository(os.getcwd()) + branch_name = os.environ.get( + "GIT_BRANCH_NAME", Path(repo.head.name).stem + ) + commit_sha = os.environ.get("GIT_COMMIT_SHA", str(repo.head.target)) + except: + branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown") + commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown") + + metadata = { + "branch_name": branch_name, + "commit_sha": commit_sha, + } + + with open(metadata_file, "w") as file: + yaml.dump(metadata, file) + + def generate_epiweekly(model_run_dir: Path) -> None: result = subprocess.run( [ @@ -238,6 +261,9 @@ def main( os.makedirs(model_run_dir, exist_ok=True) + logger.info("Recording git info...") + record_git_info(model_run_dir) + logger.info(f"Using priors from {priors_path}...") shutil.copyfile(priors_path, Path(model_run_dir, "priors.py"))