diff --git a/rs_tools/_src/geoprocessing/msg/geoprocessor_msg.py b/rs_tools/_src/geoprocessing/msg/geoprocessor_msg.py index 5ad19d7..6c70981 100644 --- a/rs_tools/_src/geoprocessing/msg/geoprocessor_msg.py +++ b/rs_tools/_src/geoprocessing/msg/geoprocessor_msg.py @@ -313,8 +313,8 @@ def preprocess_files(self): logger.error(f"Skipping {itime} due to error loading") continue - # remove crs from dataset - ds = ds.drop_vars('msg_seviri_fes_3km') + # remove crs from dataset + # ds = ds.drop_vars('msg_seviri_fes_3km') # NOTE: Uncommented to keep coordinate reference system # remove attrs that cause netcdf error for var in ds.data_vars: diff --git a/rs_tools/_src/preprocessing/prepatcher.py b/rs_tools/_src/preprocessing/prepatcher.py index d044682..af5cca5 100644 --- a/rs_tools/_src/preprocessing/prepatcher.py +++ b/rs_tools/_src/preprocessing/prepatcher.py @@ -1,26 +1,29 @@ -import autoroot -import numpy as np -from xrpatcher._src.base import XRDAPatcher -import rioxarray +from __future__ import annotations + +import gc import os -from pathlib import Path from dataclasses import dataclass -from typing import Optional, List, Union, Tuple -from tqdm import tqdm -from rs_tools._src.utils.io import get_list_filenames +from pathlib import Path + +import numpy as np import typer -from loguru import logger import xarray as xr +from loguru import logger +from rs_tools._src.utils.io import get_list_filenames +from tqdm import tqdm +from xrpatcher._src.base import XRDAPatcher + def _check_filetype(file_type: str) -> bool: """checks instrument for GOES data.""" - if file_type in ["nc", "np"]: + if file_type in ["nc", "np", "tif"]: return True else: msg = "Unrecognized file type" - msg += f"\nNeeds to be 'nc' or 'np'. Others are not yet tested" + msg += f"\nNeeds to be 'nc', 'np', or 'tif'. Others are not yet tested" raise ValueError(msg) - + + def _check_nan_count(arr: np.array, nan_cutoff: float) -> bool: """ Check if the number of NaN values in the given array is below a specified cutoff. @@ -37,11 +40,15 @@ def _check_nan_count(arr: np.array, nan_cutoff: float) -> bool: # get total pixel count total_count = int(arr.size) # check if nan_count is within allowed cutoff - if nan_count/total_count <= nan_cutoff: + + pct_nan = nan_count / total_count + + if pct_nan <= nan_cutoff: return True else: return False + @dataclass(frozen=True) class PrePatcher: """ @@ -53,7 +60,7 @@ class PrePatcher: patch_size (int): The size of each patch. stride_size (int): The stride size for generating patches. nan_cutoff (float): The cutoff value for allowed NaN count in a patch. - save_filetype (str): The file type to save patches as. Options are [nc, np]. + save_filetype (str): The file type to save patches as. Options are [nc, np, tif]. Methods: nc_files(self) -> List[str]: Returns a list of all NetCDF filenames in the read_path directory. @@ -61,14 +68,14 @@ class PrePatcher: """ read_path: str - save_path: str + save_path: str patch_size: int - stride_size: int + stride_size: int nan_cutoff: float save_filetype: str @property - def nc_files(self) -> List[str]: + def nc_files(self) -> list[str]: """ Returns a list of all NetCDF filenames in the read_path directory. @@ -91,6 +98,25 @@ def save_patches(self): pbar.set_description(f"Processing: {itime}") # open dataset ds = xr.open_dataset(ifile, engine="netcdf4") + + if self.save_filetype == "tif": + # concatenate variables + ds_temp = xr.concat( + [ds.cloud_mask, ds.latitude, ds.longitude], dim="band" + ) + # name data variables "Rad" + ds_temp = ds_temp.to_dataset(name="Rad") + ds_temp = ds_temp.drop(["cloud_mask", "latitude", "longitude"]) + ds_temp = ds_temp.assign_coords( + band=["cloud_mask", "latitude", "longitude"] + ) + # merge with original dataset + ds = xr.merge([ds_temp.Rad, ds.Rad]) + # store band names to be attached to da later + band_names = [str(i) for i in ds.band.values] + del ds_temp + gc.collect() + # extract radiance data array da = ds.Rad # define patch parameters @@ -104,36 +130,82 @@ def save_patches(self): os.makedirs(self.save_path) for i, ipatch in tqdm(enumerate(patcher), total=len(patcher)): - data = ipatch.data # extract data patch + data = ipatch.data # extract data + # logger.info(f'stride size {self.stride_size} ') if _check_nan_count(data, self.nan_cutoff): if self.save_filetype == "nc": # reconvert to dataset to attach band_wavelength and time - ipatch = ipatch.to_dataset(name='Rad') - ipatch = ipatch.assign_coords({'time': ds.time.values}) - ipatch = ipatch.assign_coords({'band_wavelength': ds.band_wavelength.values}) + ipatch = ipatch.to_dataset(name="Rad") + ipatch = ipatch.assign_coords({"time": ds.time.values}) + ipatch = ipatch.assign_coords( + {"band_wavelength": ds.band_wavelength.values} + ) # compile filename - file_path = Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc") + file_path = Path(self.save_path).joinpath( + f"{itime}_patch_{i}.nc" + ) # remove file if it already exists if os.path.exists(file_path): os.remove(file_path) - # save patch to netcdf - ipatch.to_netcdf(Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc"), engine="netcdf4") + # save patch to netcdf + ipatch.to_netcdf( + Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc"), + engine="netcdf4", + ) + elif self.save_filetype == "tif": + # reconvert to dataset to attach band_wavelength and time + # ds.attrs['band_names'] = [str(i) for i in ds.band.values] + # compile filename + file_path = Path(self.save_path).joinpath( + f"{itime}_patch_{i}.nc" + ) + # remove file if it already exists + if os.path.exists(file_path): + os.remove(file_path) + # add band names as attribute + ipatch.attrs["band_names"] = band_names + # save patch to tiff + ipatch.rio.to_raster( + Path(self.save_path).joinpath(f"{itime}_patch_{i}.tif") + ) elif self.save_filetype == "np": # save as numpy files - np.save(Path(self.save_path).joinpath(f"{itime}_radiance_patch_{i}"), data) - np.save(Path(self.save_path).joinpath(f"{itime}_latitude_patch_{i}"), ipatch.latitude.values) - np.save(Path(self.save_path).joinpath(f"{itime}_longitude_patch_{i}"), ipatch.longitude.values) - np.save(Path(self.save_path).joinpath(f"{itime}_cloudmask_patch_{i}"), ipatch.cloud_mask.values) + np.save( + Path(self.save_path).joinpath( + f"{itime}_radiance_patch_{i}" + ), + data, + ) + np.save( + Path(self.save_path).joinpath( + f"{itime}_latitude_patch_{i}" + ), + ipatch.latitude.values, + ) + np.save( + Path(self.save_path).joinpath( + f"{itime}_longitude_patch_{i}" + ), + ipatch.longitude.values, + ) + np.save( + Path(self.save_path).joinpath( + f"{itime}_cloudmask_patch_{i}" + ), + ipatch.cloud_mask.values, + ) else: - logger.info(f'NaN count exceeded for patch {i} of timestamp {itime}.') + pass + # logger.info(f'NaN count exceeded for patch {i} of timestamp {itime}.') + def prepatch( - read_path: str = "./", - save_path: str = "./", - patch_size: int = 256, - stride_size: int = 256, - nan_cutoff: float = 0.5, - save_filetype: str = 'nc' + read_path: str = "./", + save_path: str = "./", + patch_size: int = 256, + stride_size: int = 256, + nan_cutoff: float = 0.5, + save_filetype: str = "nc", ): """ Patches satellite data into smaller patches for training. @@ -151,21 +223,23 @@ def prepatch( _check_filetype(file_type=save_filetype) # Initialize Prepatcher + logger.info(f"Patching Files...: {read_path}") logger.info(f"Initializing Prepatcher...") prepatcher = PrePatcher( - read_path=read_path, + read_path=read_path, save_path=save_path, patch_size=patch_size, stride_size=stride_size, nan_cutoff=nan_cutoff, - save_filetype=save_filetype - ) + save_filetype=save_filetype, + ) logger.info(f"Patching Files...: {save_path}") prepatcher.save_patches() logger.info(f"Finished Prepatching Script...!") -if __name__ == '__main__': + +if __name__ == "__main__": """ python scripts/pipeline/prepatch.py --read-path "/path/to/netcdf/file" --save-path /path/to/save/patches """