Skip to content

Commit

Permalink
Fixed formatting and added types
Browse files Browse the repository at this point in the history
  • Loading branch information
royagrace committed Oct 2, 2024
1 parent ca7aeb8 commit 1b8efe6
Showing 1 changed file with 63 additions and 53 deletions.
116 changes: 63 additions & 53 deletions tools/RAiDER/models/weatherModel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime as dt
import os
from abc import ABC, abstractmethod
from typing import Optional
from pathlib import Path
from typing import Optional, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -37,7 +38,7 @@ class WeatherModel(ABC):
_dataset: Optional[str]

def __init__(self) -> None:
# Initialize model-specific constants/parameters
"""Initialize model-specific constants/parameters."""
self._k1 = None
self._k2 = None
self._k3 = None
Expand All @@ -56,7 +57,7 @@ def __init__(self) -> None:

self._classname = None
self._dataset = None
self._Name = None
self._Name = ''
self._wmLoc = None

self._model_level_type = 'ml'
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(self) -> None:
self._hydrostatic_ztd = None

def __str__(self) -> str:
"""Prints out the weather model information."""
string = '\n'
string += '======Weather Model class object=====\n'
string += f'Weather model time: {self._time}\n'
Expand Down Expand Up @@ -128,25 +130,27 @@ def __str__(self) -> str:
string += '=====================================\n'
return string

def Model(self):
def Model(self) -> str:
"""Returns the name of the weather model."""
return self._Name

def dtime(self):
def dtime(self) -> int:
"""Returns the availability of the weather model in hours."""
return self._time_res

def getLLRes(self):
def getLLRes(self) -> float:
"""Returns the grid spacing."""
return np.max([self._lat_res, self._lon_res])

def fetch(self, out, time) -> None:
def fetch(self, out: Path, time: dt.datetime) -> None:
"""
Checks the input datetime against the valid date range for the model and then
calls the model _fetch routine.
Args:
----------
out -
ll_bounds - 4 x 1 array, SNWE
time = UTC datetime
out: Path
time: dt.datetime, = UTC datetime
"""
self.checkTime(time)
self.setTime(time)
Expand All @@ -159,14 +163,15 @@ def fetch(self, out, time) -> None:
raise

@abstractmethod
def _fetch(self, out):
def _fetch(self, out: Path): # noqa: ANN202
"""Placeholder method. Should be implemented in each weather model type class."""
pass

def getTime(self):
def getTime(self) -> dt.datetime:
"""Returns the time of the weather model."""
return self._time

def setTime(self, time, fmt='%Y-%m-%dT%H:%M:%S') -> None:
def setTime(self, time: dt.datetime, fmt: str='%Y-%m-%dT%H:%M:%S') -> None:
"""Set the time for a weather model."""
if isinstance(time, str):
self._time = dt.datetime.strptime(time, fmt)
Expand All @@ -177,10 +182,11 @@ def setTime(self, time, fmt='%Y-%m-%dT%H:%M:%S') -> None:
if self._time.tzinfo is None:
self._time = self._time.replace(tzinfo=dt.timezone(offset=dt.timedelta()))

def get_latlon_bounds(self):
def get_latlon_bounds(self) -> Union[list, np.ndarray]:
"""Returns the bounds of the weather model."""
return self._ll_bounds

def set_latlon_bounds(self, ll_bounds, Nextra=2, output_spacing=None) -> None:
def set_latlon_bounds(self, ll_bounds: Union[list, np.ndarray], Nextra: int=2, output_spacing: float=None) -> None:
"""
Need to correct lat/lon bounds because not all of the weather models have valid
data exactly bounded by -90/90 (lats) and -180/180 (lons); for GMAO and MERRA2,
Expand Down Expand Up @@ -212,10 +218,10 @@ def set_latlon_bounds(self, ll_bounds, Nextra=2, output_spacing=None) -> None:

self._ll_bounds = np.array([S, N, W, E])

def get_wmLoc(self):
def get_wmLoc(self) -> Path:
"""Get the path to the direct with the weather model files."""
if self._wmLoc is None:
wmLoc = os.path.join(os.getcwd(), 'weather_files')
wmLoc = Path.join(Path.getcwd(), 'weather_files')
else:
wmLoc = self._wmLoc
return wmLoc
Expand All @@ -224,7 +230,7 @@ def set_wmLoc(self, weather_model_directory: str) -> None:
"""Set the path to the directory with the weather model files."""
self._wmLoc = weather_model_directory

def load(self, *args, _zlevels=None, **kwargs):
def load(self, *args: tuple, _zlevels: Union[np.ndarray, list]=None, **kwargs: dict) -> None:
"""
Calls the load_weather method. Each model class should define a load_weather
method appropriate for that class. 'args' should be one or more filenames.
Expand All @@ -234,7 +240,7 @@ def load(self, *args, _zlevels=None, **kwargs):
path_wm_raw = make_raw_weather_data_filename(outLoc, self.Model(), self.getTime())
self._out_name = self.out_file(outLoc)

if os.path.exists(self._out_name):
if Path.exists(self._out_name):
return self._out_name
else:
# Load the weather just for the query points
Expand All @@ -253,11 +259,11 @@ def load(self, *args, _zlevels=None, **kwargs):
return None

@abstractmethod
def load_weather(self, *args, **kwargs):
def load_weather(self, *args: tuple, **kwargs: dict) -> None:
"""Placeholder method. Should be implemented in each weather model type class."""
pass

def plot(self, plotType='pqt', savefig=True):
def plot(self, plotType: str='pqt', savefig: bool=True) -> str:
"""Plotting method. Valid plot types are 'pqt'."""
if plotType == 'pqt':
plot = plots.plot_pqt(self, savefig)
Expand All @@ -267,7 +273,7 @@ def plot(self, plotType='pqt', savefig=True):
raise RuntimeError(f'WeatherModel.plot: No plotType named {plotType}')
return plot

def checkTime(self, time) -> None:
def checkTime(self, time: dt.datetime) -> None:
"""
Checks the time against the lag time and valid date range for the given model type.
Expand Down Expand Up @@ -299,7 +305,7 @@ def checkTime(self, time) -> None:
if time > dt.datetime.now(dt.timezone.utc) - self._lag_time:
raise DatetimeOutsideRange(self.Model(), time)

def setLevelType(self, levelType) -> None:
def setLevelType(self, levelType: str) -> None:
"""Set the level type to model levels or pressure levels."""
if levelType in 'ml pl nat prs'.split():
self._model_level_type = levelType
Expand All @@ -311,11 +317,11 @@ def setLevelType(self, levelType) -> None:
else:
self.__pressure_levels__()

def _convertmb2Pa(self, pres):
def _convertmb2Pa(self, pres: Union[float, int, np.ndarray]) -> Union[float, int, np.ndarray]:
"""Convert pressure in millibars to Pascals."""
return 100 * pres

def _get_heights(self, lats, geo_hgt, geo_ht_fill=np.nan) -> None:
def _get_heights(self, lats: np.ndarray, geo_hgt: np.ndarray, geo_ht_fill: np.ndarray=np.nan) -> None:
"""Transform geo heights to WGS84 ellipsoidal heights."""
geo_ht_fix = np.where(geo_hgt != geo_ht_fill, geo_hgt, np.nan)
lats_full = np.broadcast_to(lats[..., np.newaxis], geo_ht_fix.shape)
Expand Down Expand Up @@ -352,16 +358,17 @@ def _get_hydro_refractivity(self) -> None:
"""Calculate the hydrostatic delay from pressure and temperature."""
self._hydrostatic_refractivity = self._k1 * self._p / self._t

def getWetRefractivity(self):
def getWetRefractivity(self) -> np.ndarray:
"""Returns the data cube of refractivity."""
return self._wet_refractivity

def getHydroRefractivity(self):
def getHydroRefractivity(self) -> np.ndarray:
"""Returns the data cube of hydrostatic refractivity."""
return self._hydrostatic_refractivity

def _adjust_grid(self, ll_bounds=None) -> None:
def _adjust_grid(self, ll_bounds: Union[list, tuple, np.ndarray]=None) -> None:
"""This function pads the weather grid with a level at self._zmin, if it does not already go that low."""
"""
This function pads the weather grid with a level at self._zmin, if
it does not already go that low.
<<The functionality below has been removed.>>
<<It also removes levels that are above self._zmax, since they are not needed.>>
"""
Expand Down Expand Up @@ -393,7 +400,7 @@ def _getZTD(self) -> None:
self._hydrostatic_ztd = hydro_total
self._wet_ztd = wet_total

def _getExtent(self, lats, lons):
def _getExtent(self, lats: np.ndarray, lons: np.ndarray) -> np.ndarray:
"""Get the bounding box around a set of lats/lons."""
if (lats.size == 1) & (lons.size == 1):
return [lats - self._lat_res, lats + self._lat_res, lons - self._lon_res, lons + self._lon_res]
Expand All @@ -407,7 +414,7 @@ def _getExtent(self, lats, lons):
raise RuntimeError('Not a valid lat/lon shape')

@property
def bbox(self) -> list:
def bbox(self) -> Union[list, tuple, np.ndarray]:
"""
Obtains the bounding box of the weather model in lat/lon CRS.
Expand All @@ -423,7 +430,7 @@ def bbox(self) -> list:
"""
if self._bbox is None:
path_weather_model = self.out_file(self.get_wmLoc())
if not os.path.exists(path_weather_model):
if not Path.exists(path_weather_model):
raise ValueError('Need to save cropped weather model as netcdf')

with xr.load_dataset(path_weather_model) as ds:
Expand Down Expand Up @@ -459,8 +466,8 @@ def checkValidBounds(
if not box(W, S, E, N).intersects(self._valid_bounds):
raise ValueError(f'The requested location is unavailable for {self._Name}')

def checkContainment(self, ll_bounds, buffer_deg: float = 1e-5) -> bool:
""" "
def checkContainment(self, ll_bounds: Union[list, tuple, np.ndarray], buffer_deg: float = 1e-5) -> bool:
"""
Checks containment of weather model bbox of outLats and outLons
provided.
Expand Down Expand Up @@ -510,7 +517,7 @@ def checkContainment(self, ll_bounds, buffer_deg: float = 1e-5) -> bool:

return weather_model_box.contains(input_box)

def _isOutside(self, extent1, extent2) -> bool:
def _isOutside(self, extent1: list, extent2: list) -> bool:
"""
Determine whether any of extent1 lies outside extent2.
extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon].
Expand All @@ -521,7 +528,7 @@ def _isOutside(self, extent1, extent2) -> bool:
t4 = extent1[3] > extent2[3]
return np.any([t1, t2, t3, t4])

def _trimExtent(self, extent) -> None:
def _trimExtent(self, extent: list) -> None:
"""Get the bounding box around a set of lats/lons."""
lat = self._lats.copy()
lon = self._lons.copy()
Expand Down Expand Up @@ -553,11 +560,12 @@ def _trimExtent(self, extent) -> None:
self._wet_refractivity = self._wet_refractivity[index1:index2, index3:index4, ...]
self._hydrostatic_refractivity = self._hydrostatic_refractivity[index1:index2, index3:index4, :]

def _calculategeoh(self, z, lnsp):
def _calculategeoh(self, z: np.ndarray, lnsp: np.ndarray) -> Union[list, tuple, np.ndarray]:
"""
Function to calculate pressure, geopotential, and geopotential height
from the surface pressure and model levels provided by a weather model.
The model levels are numbered from the highest eleveation to the lowest.
Inputs:
self - weather model object with parameters a, b defined
z - 3-D array of surface heights for the location(s) of interest
Expand All @@ -570,14 +578,15 @@ def _calculategeoh(self, z, lnsp):
"""
return calcgeoh(lnsp, self._t, self._q, z, self._a, self._b, self._R_d, self._levels)

def getProjection(self):
def getProjection(self) -> CRS:
"""Returns: the native weather projection, which should be a pyproj object."""
return self._proj

def getPoints(self):
def getPoints(self) -> Union[list, tuple, np.ndarray]:
"""Returns something."""
return self._xs.copy(), self._ys.copy(), self._zs.copy()

def _uniform_in_z(self, _zlevels=None) -> None:
def _uniform_in_z(self, _zlevels: Union[np.ndarray, list]=None) -> None:
"""Interpolate all variables to a regular grid in z."""
nx, ny = self._p.shape[:2]

Expand Down Expand Up @@ -605,17 +614,18 @@ def _checkForNans(self) -> None:
self._t = fillna3D(self._t, fill_value=1e16) # to avoid division by zero later on
self._e = fillna3D(self._e)

def out_file(self, outLoc):
def out_file(self, outLoc: str) -> Path:
"""Returns outloc."""
f = make_weather_model_filename(
self._Name,
self._time,
self._ll_bounds,
)
return os.path.join(outLoc, f)
return Path.join(outLoc, f)

def filename(self, time=None, outLoc='weather_files'):
def filename(self, time: dt.datetime=None, outLoc: str='weather_files') -> str:
"""Create a filename to store the weather model."""
os.makedirs(outLoc, exist_ok=True)
Path.mkdir(outLoc, exist_ok=True)

if time is None:
if self._time is None:
Expand All @@ -632,7 +642,7 @@ def filename(self, time=None, outLoc='weather_files'):
self.files = [f]
return f

def write(self):
def write(self) -> str:
"""
By calling the abstract/modular netcdf writer
(RAiDER.utilFcns.write2NETCDF4core), write the weather model data
Expand Down Expand Up @@ -700,7 +710,8 @@ def write(self):
return f


def make_weather_model_filename(name, time, ll_bounds) -> str:
def make_weather_model_filename(name: str, time: dt.datetime, ll_bounds: Union[list, tuple, np.ndarray]) -> str:
"""Creates the filename for the weather model."""
s = np.floor(ll_bounds[0])
S = f'{np.abs(s):.0f}S' if s < 0 else f'{s:.0f}N'

Expand All @@ -715,14 +726,14 @@ def make_weather_model_filename(name, time, ll_bounds) -> str:
return f'{name}_{time.strftime("%Y_%m_%d_T%H_%M_%S")}_{S}_{N}_{W}_{E}.nc'


def make_raw_weather_data_filename(outLoc, name, time):
def make_raw_weather_data_filename(outLoc: str, name: str, time: dt.datetime) -> str:
"""Filename generator for the raw downloaded weather model data."""
date_string = dt.datetime.strftime(time, '%Y_%m_%d_T%H_%M_%S')
f = os.path.join(outLoc, f'{name}_{date_string}.nc')
return f


def find_svp(t):
def find_svp(t: np.ndarray) -> np.ndarray:
"""Calculate standard vapor presure. Should be model-specific."""
# From TRAIN:
# Could not find the wrf used equation as they appear to be
Expand Down Expand Up @@ -754,8 +765,7 @@ def find_svp(t):
svp = svp * 100
return svp.astype(np.float32)


def get_mapping(proj):
def get_mapping(proj: CRS) -> CRS:
"""Get CF-complient projection information from a proj."""
# In case of WGS-84 lat/lon, keep it simple
if proj.to_epsg() == 4326:
Expand All @@ -764,8 +774,8 @@ def get_mapping(proj):
return proj.to_wkt()


def checkContainment_raw(path_wm_raw, ll_bounds, buffer_deg: float = 1e-5) -> bool:
""" "
def checkContainment_raw(path_wm_raw: Path, ll_bounds: Union[list, tuple, np.ndarray], buffer_deg: float = 1e-5) -> bool:
"""
Checks if existing raw weather model contains
requested ll_bounds.
Expand Down

0 comments on commit 1b8efe6

Please sign in to comment.