diff --git a/rslp/forest_loss_driver/inference/config/forest_loss_driver_predict_pipeline_config.yaml b/rslp/forest_loss_driver/inference/config/forest_loss_driver_predict_pipeline_config.yaml index d51ac65..5eb9165 100644 --- a/rslp/forest_loss_driver/inference/config/forest_loss_driver_predict_pipeline_config.yaml +++ b/rslp/forest_loss_driver/inference/config/forest_loss_driver_predict_pipeline_config.yaml @@ -10,7 +10,7 @@ materialize_pipeline_args: materialize_args: apply_windows_args: workers: 112 -select_best_images_args: +select_least_cloudy_images_args: workers: 112 model_predict_args: model_cfg_fname: "data/forest_loss_driver/config_satlaspretrain_flip_oldmodel_unfreeze.yaml" # should be path from the top of the repo IF NOT ABSOLUTE PATH diff --git a/tests/integration/forest_loss_driver/test_predict_pipeline.py b/tests/integration/forest_loss_driver/test_predict_pipeline.py index 81f944d..533405b 100644 --- a/tests/integration/forest_loss_driver/test_predict_pipeline.py +++ b/tests/integration/forest_loss_driver/test_predict_pipeline.py @@ -7,6 +7,7 @@ import uuid from collections.abc import Generator from datetime import datetime, timezone +from unittest.mock import patch import pytest from google.cloud import storage @@ -19,6 +20,7 @@ ) from rslp.forest_loss_driver.predict_pipeline import ForestLossDriverPredictionPipeline from rslp.log_utils import get_logger +from rslp.main import main TEST_ID = str(uuid.uuid4()) logger = get_logger(__name__) @@ -38,6 +40,12 @@ def test_bucket() -> Generator[storage.Bucket, None, None]: yield bucket +@pytest.fixture +def predict_pipeline_config_path() -> str: + """The path to the config file used for inference.""" + return "rslp/forest_loss_driver/inference/config/forest_loss_driver_predict_pipeline_config.yaml" + + @pytest.fixture def predict_pipeline_config( inference_dataset_config_path: str, @@ -158,3 +166,28 @@ def test_predict_pipeline( assert ( abs(actual - expected) < tol ), f"Probability difference {abs(actual - expected)} exceeds threshold {tol}" + + +def test_forest_loss_driver_predict_cli_config_load( + predict_pipeline_config_path: str, +) -> None: + def assert_config(pred_pipeline_config: PredictPipelineConfig) -> bool: + # Verify the config is the correct type + logger.info(f"Pred pipeline config: {pred_pipeline_config}") + assert isinstance(pred_pipeline_config, PredictPipelineConfig) + return True + + with ( + patch( + "sys.argv", + [ + "rslp", + "forest_loss_driver", + "predict", + "--pred_pipeline_config", + predict_pipeline_config_path, + ], + ), + patch("rslp.forest_loss_driver.workflows", {"predict": assert_config}), + ): + main()