diff --git a/src/transport_performance/utils/raster.py b/src/transport_performance/utils/raster.py index 05bf3fc2..fa801e59 100644 --- a/src/transport_performance/utils/raster.py +++ b/src/transport_performance/utils/raster.py @@ -13,6 +13,10 @@ from rioxarray.merge import merge_arrays from rasterio.warp import Resampling +from transport_performance.utils.defence import ( + _check_parent_dir_exists, + _is_expected_filetype, +) def merge_raster_files( @@ -146,8 +150,7 @@ def sum_resample_file( """ # defend against case where the provided input dir does not exist - if not os.path.exists(input_filepath): - raise FileNotFoundError(f"{input_filepath} can not be found") + _is_expected_filetype(input_filepath, "input_filepath", exp_ext=".tif") xds = rioxarray.open_rasterio(input_filepath, masked=True) @@ -161,7 +164,6 @@ def sum_resample_file( ) # make output_filepath's directory if it does not exist - if not os.path.exists(os.path.dirname(output_filepath)): - os.mkdir(output_filepath) + _check_parent_dir_exists(output_filepath, "output_filepath", create=True) xds_resampled.rio.to_raster(output_filepath) diff --git a/tests/utils/test_raster.py b/tests/utils/test_raster.py index 249ede0d..b1dcf37c 100644 --- a/tests/utils/test_raster.py +++ b/tests/utils/test_raster.py @@ -15,6 +15,9 @@ import xarray as xr import rioxarray # noqa: F401 - import required for xarray but not needed here +from typing import Type +from pytest_lazyfixture import lazy_fixture +from _pytest.python_api import RaisesContext from transport_performance.utils.raster import ( merge_raster_files, sum_resample_file, @@ -159,6 +162,31 @@ def resample_xarr_fpath( return out_filepath +@pytest.fixture +def save_empty_text_file(resample_xarr_fpath: str) -> str: + """Save an empty text file. + + Parameters + ---------- + resample_xarr_fpath : str + File path to dummy raster data, used to make sure file is in the same + directory. + + Returns + ------- + str + Dummy text file name. + + """ + # save an empty text file to the same directory + working_dir = os.path.dirname(resample_xarr_fpath) + test_file_name = "text.txt" + with open(os.path.join(working_dir, test_file_name), "w") as f: + f.write("") + + return test_file_name + + class TestUtilsRaster: """A class to test utils/raster functions.""" @@ -222,9 +250,13 @@ def test_sum_resample_file(self, resample_xarr_fpath: str) -> None: # useful when using -rP flag in pytest to see directory print(f"Temp filepath for resampling test: {resample_xarr_fpath}") - # resample to input and set the output location + # set the output location to sub dir in a different folder + # adding different sub dir to test resolution of issue 121 output_fpath = os.path.join( - os.path.dirname(resample_xarr_fpath), "output.tif" + os.path.dirname(os.path.dirname(resample_xarr_fpath)), + "resample_outputs", + "outputs", + "output.tif", ) sum_resample_file(resample_xarr_fpath, output_fpath) @@ -241,3 +273,40 @@ def test_sum_resample_file(self, resample_xarr_fpath: str) -> None: # assert correct resampling values (summing consitiuent grids) expected_result = np.array([[[14, 22], [46, 54]]]) assert np.array_equal(expected_result, xds_out.to_numpy()) + + @pytest.mark.parametrize( + "input_path, file_name, expected", + [ + # test file that does not exist + ( + lazy_fixture("resample_xarr_fpath"), + "test.tif", + pytest.raises(FileExistsError), + ), + # test file with an invalid file extension + ( + lazy_fixture("resample_xarr_fpath"), + lazy_fixture("save_empty_text_file"), + pytest.raises(ValueError), + ), + ], + ) + def test_sum_resample_on_fail( + self, input_path: str, file_name: str, expected: Type[RaisesContext] + ) -> None: + """Test sum_resample_file in failing cases. + + Parameters + ---------- + input_path : str + path to input dummy raster data + file_name : str + name of file to be tested + expected : Type[RaisesContext] + exception to test with + + """ + with expected: + input_folder = os.path.dirname(input_path) + fpath = os.path.join(input_folder, file_name) + sum_resample_file(fpath, "")