Skip to content

Commit

Permalink
fix yaml and add test to catch bad load
Browse files Browse the repository at this point in the history
  • Loading branch information
Hgherzog committed Dec 20, 2024
1 parent 0d8d6dd commit 990dd6b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ materialize_pipeline_args:
materialize_args:
apply_windows_args:
workers: 112
select_best_images_args:
select_least_cloudy_images_args:
workers: 112
model_predict_args:
model_cfg_fname: "data/forest_loss_driver/config_satlaspretrain_flip_oldmodel_unfreeze.yaml" # should be path from the top of the repo IF NOT ABSOLUTE PATH
33 changes: 33 additions & 0 deletions tests/integration/forest_loss_driver/test_predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
from collections.abc import Generator
from datetime import datetime, timezone
from unittest.mock import patch

import pytest
from google.cloud import storage
Expand All @@ -19,6 +20,7 @@
)
from rslp.forest_loss_driver.predict_pipeline import ForestLossDriverPredictionPipeline
from rslp.log_utils import get_logger
from rslp.main import main

TEST_ID = str(uuid.uuid4())
logger = get_logger(__name__)
Expand All @@ -38,6 +40,12 @@ def test_bucket() -> Generator[storage.Bucket, None, None]:
yield bucket


@pytest.fixture
def predict_pipeline_config_path() -> str:
"""The path to the config file used for inference."""
return "rslp/forest_loss_driver/inference/config/forest_loss_driver_predict_pipeline_config.yaml"


@pytest.fixture
def predict_pipeline_config(
inference_dataset_config_path: str,
Expand Down Expand Up @@ -158,3 +166,28 @@ def test_predict_pipeline(
assert (
abs(actual - expected) < tol
), f"Probability difference {abs(actual - expected)} exceeds threshold {tol}"


def test_forest_loss_driver_predict_cli_config_load(
predict_pipeline_config_path: str,
) -> None:
def assert_config(pred_pipeline_config: PredictPipelineConfig) -> bool:
# Verify the config is the correct type
logger.info(f"Pred pipeline config: {pred_pipeline_config}")
assert isinstance(pred_pipeline_config, PredictPipelineConfig)
return True

with (
patch(
"sys.argv",
[
"rslp",
"forest_loss_driver",
"predict",
"--pred_pipeline_config",
predict_pipeline_config_path,
],
),
patch("rslp.forest_loss_driver.workflows", {"predict": assert_config}),
):
main()

0 comments on commit 990dd6b

Please sign in to comment.