Skip to content

Commit

Permalink
Move xr_flexsel to its own module; functionize bits of it.
Browse files Browse the repository at this point in the history
  • Loading branch information
samsrabin committed Feb 10, 2024
1 parent 3dae192 commit 538ab01
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 202 deletions.
201 changes: 1 addition & 200 deletions python/ctsm/crop_calendars/cropcal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import xarray as xr
from ctsm.crop_calendars.xr_flexsel import xr_flexsel


def define_pftlist():
Expand Down Expand Up @@ -256,206 +257,6 @@ def vegtype_str2int(vegtype_str, vegtype_mainlist=None):
return indices


def _xr_flexsel_vegtype(xr_object, selection, patches1d_itype_veg):
    """
    Subset xr_object along "patch" (and "ivt", if present) according to a vegtype selection.

    selection may be a single item or a list of vegetation type names (str), vegetation type
    integers, or a per-patch boolean mask. Returns the subsetted object.
    """
    # Convert to list, if needed
    if not isinstance(selection, list):
        selection = [selection]

    # Convert to indices, if needed
    if isinstance(selection[0], str):
        selection = vegtype_str2int(selection)

    # Get list of boolean(s). bool MUST be tested before int: bool is a subclass of int, so
    # testing int first would treat a mask as the vegetation type numbers 0 and 1.
    selection_is_mask = isinstance(selection[0], (bool, np.bool_))
    if selection_is_mask:
        if len(selection) != len(xr_object.patch):
            raise ValueError(
                "If providing boolean 'vegtype' argument to xr_flexsel(), it must be the"
                f" same length as xr_object.patch ({len(selection)} vs."
                f" {len(xr_object.patch)})"
            )
        is_vegtype = selection
    elif isinstance(selection[0], int):
        if patches1d_itype_veg is None:
            patches1d_itype_veg = xr_object.patches1d_itype_veg.values
        elif isinstance(patches1d_itype_veg, xr.DataArray):
            patches1d_itype_veg = patches1d_itype_veg.values
        is_vegtype = is_each_vegtype(patches1d_itype_veg, selection, "ok_exact")
    else:
        raise TypeError(f"Not sure how to handle 'vegtype' of type {type(selection[0])}")

    xr_object = xr_object.isel(patch=[i for i, x in enumerate(is_vegtype) if x])

    # A per-patch mask carries no vegetation type numbers, so it cannot be used to subset "ivt"
    if not selection_is_mask and "ivt" in xr_object:
        xr_object = xr_object.isel(
            ivt=is_each_vegtype(xr_object.ivt.values, selection, "ok_exact")
        )
    return xr_object


def _xr_flexsel_interpret_selection(selection):
    """
    Work out how a selection with no explicit selection type should be interpreted.

    Returns (selection, this_type, is_inefficient). this_type is int when the selection looks like
    positional indices, the string "values" for a slice that cannot be positional, and otherwise
    a type. selection is returned cast to int if it was a non-integer array holding only
    non-negative whole numbers; is_inefficient is True whenever that element-by-element check of
    a non-integer array had to be made.
    """
    is_inefficient = False
    if isinstance(selection, slice):
        if selection == slice(0):
            raise ValueError("slice(0) will be empty")
        slice_members = [
            member
            for member in (selection.start, selection.stop, selection.step)
            if member is not None
        ]
        if not slice_members:
            raise TypeError("slice is all None?")
        # Check the type BEFORE comparing with 0, so that label slices (e.g., of date strings)
        # are classified as values instead of raising TypeError on the comparison.
        if all(isinstance(member, int) and member >= 0 for member in slice_members):
            this_type = int
        else:
            this_type = "values"
    elif isinstance(selection, np.ndarray):
        if selection.dtype.kind in np.typecodes["AllInteger"]:
            this_type = int
        else:
            is_inefficient = True
            this_type = None
            for member in selection:
                if member < 0 or member % 1 > 0:
                    this_type = "values" if isinstance(member, int) else type(member)
                    break
            if this_type is None:
                this_type = int
                selection = selection.astype(int)
    else:
        this_type = type(selection)
    return selection, this_type, is_inefficient


def _xr_flexsel_trim_1d_axes(xr_object, key, selection, selection_type):
    """
    Trim the dimensions of a Dataset's "1d_jxy"/"1d_ixy" variables to match a lat/lon selection.

    For every data variable whose name contains "1d_jxy" (key "lat") or "1d_ixy" (key "lon"),
    keep only the entries whose coordinate is included in the selection, and rewrite the
    variable's values as 1-based positions within the included coordinates. Returns the trimmed
    Dataset; the selection along `key` itself is left to the caller.
    """
    if selection_type == "indices":
        incl_coords = xr_object[key].values[selection]
    elif selection_type == "values":
        if isinstance(selection, slice):
            incl_coords = xr_object.sel({key: selection}, drop=False)[key].values
        else:
            incl_coords = selection
    else:
        raise TypeError(f"selection_type {selection_type} not recognized")
    # A single index or value gives a scalar, which supports neither membership tests nor
    # elementwise comparison; always work with a 1-d array.
    incl_coords = np.atleast_1d(incl_coords)

    this_xy = {"lat": "jxy", "lon": "ixy"}[key]
    substring = f"1d_{this_xy}"
    # Snapshot the matching names now, because xr_object is replaced inside the loop
    matches = [x for x in list(xr_object.keys()) if substring in x]
    for var in matches:
        if len(xr_object[var].dims) != 1:
            raise RuntimeError(
                f"Expected {var} to have 1 dimension, but it has"
                f" {len(xr_object[var].dims)}: {xr_object[var].dims}"
            )
        dim = xr_object[var].dims[0]
        # The variable's values are used as 1-based indices into the key coordinate
        coords = xr_object[key].values[xr_object[var].values.astype(int) - 1]
        ok_ind = []
        new_1d_this_xy = []
        for i, member in enumerate(coords):
            where_included = np.flatnonzero(incl_coords == member)
            if where_included.size > 0:
                ok_ind.append(i)
                new_1d_this_xy.append(where_included[0] + 1)
        xr_object = xr_object.isel({dim: ok_ind})
        # Built as a 1-d array directly, so a single retained member keeps shape (1,)
        xr_object[var].values = np.array(new_1d_this_xy)
    return xr_object


def xr_flexsel(xr_object, patches1d_itype_veg=None, warn_about_seltype_interp=True, **kwargs):
    """
    Flexibly subset time(s) and/or vegetation type(s) from an xarray Dataset or DataArray.

    - Keyword arguments like dimension=selection.
    - Selections can be individual values or slice()s.
    - Optimize memory usage by beginning keyword argument list with the selections that will result
      in the largest reduction of object size.
    - Use dimension "vegtype" to extract patches of designated vegetation type (can be string or
      integer, or a boolean mask with one element per patch).
    - Can also do dimension=function---e.g., time=np.mean will take the mean over the time
      dimension. np.mean is the only function recognized.
    - Unless told otherwise, an int, a slice of non-negative ints, or an array of non-negative
      whole numbers is interpreted as positional indices; anything else as coordinate values.
      Override this with a keyword like dimension__indices=... or dimension__values=...

    Args:
        xr_object: Object to subset.
        patches1d_itype_veg: Vegetation type of each patch, as an array or DataArray. If None,
            xr_object.patches1d_itype_veg is used. Only used for integer/string "vegtype"
            selections.
        warn_about_seltype_interp: Whether to print a message whenever the selection type had to
            be interpreted from the selection itself.
        **kwargs: The selections, applied in the order given.

    Returns:
        The subsetted object.

    Raises:
        ValueError: Unrecognized function; failed mean; slice(0); wrong-length boolean vegtype.
        TypeError: Unsupported vegtype type; all-None slice; unrecognized selection type.
        RuntimeError: A "1d_jxy"/"1d_ixy" variable does not have exactly one dimension.
    """
    # Setup
    havewarned = False
    delimiter = "__"

    for key, selection in kwargs.items():
        if callable(selection):
            # It would have been really nice to do selection(xr_object, axis=key), but numpy
            # methods and xarray methods disagree on "axis" vs. "dimension." So instead, just do
            # this manually.
            if selection is not np.mean:
                raise ValueError(f"xr_flexsel() doesn't recognize function {selection}")
            try:
                xr_object = xr_object.mean(dim=key)
            except Exception as err:
                raise ValueError(
                    f"Failed to take mean of dimension {key}. Try doing so outside of"
                    " xr_flexsel()."
                ) from err
            continue

        if key == "vegtype":
            xr_object = _xr_flexsel_vegtype(xr_object, selection, patches1d_itype_veg)
            continue

        # Parse selection type, if provided
        if delimiter in key:
            key, selection_type = key.split(delimiter)

        # Check type of selection
        else:
            selection, this_type, is_inefficient = _xr_flexsel_interpret_selection(selection)

            warn_about_this_seltype_interp = warn_about_seltype_interp
            if this_type is list and selection and isinstance(selection[0], str):
                selection_type = "values"
                warn_about_this_seltype_interp = False
            elif this_type == int:
                selection_type = "indices"
            else:
                selection_type = "values"

            if warn_about_this_seltype_interp:
                # Suggest suppressing selection type interpretation warnings
                if not havewarned:
                    print(
                        "xr_flexsel(): Suppress all 'selection type interpretation' messages by"
                        " specifying warn_about_seltype_interp=False"
                    )
                    havewarned = True
                if is_inefficient:
                    extra = " This will also improve efficiency for large selections."
                else:
                    extra = ""
                print(
                    f"xr_flexsel(): Selecting {key} as {selection_type} because selection was"
                    f" interpreted as {this_type}. If not correct, specify selection type"
                    " ('indices' or 'values') in keyword like"
                    f" '{key}{delimiter}SELECTIONTYPE=...' instead of '{key}=...'.{extra}"
                )

        if selection_type not in ("indices", "values"):
            raise TypeError(f"selection_type {selection_type} not recognized")

        # Trim along relevant 1d axes
        if key in ("lat", "lon") and isinstance(xr_object, xr.Dataset):
            xr_object = _xr_flexsel_trim_1d_axes(xr_object, key, selection, selection_type)

        # Perform selection
        if selection_type == "indices":
            # Have to select like this instead of with index directly because otherwise
            # assign_coords() will throw an error. Not sure why.
            if isinstance(selection, int):
                # Single integer? Turn it into a slice.
                selection = slice(selection, selection + 1)
            elif (
                isinstance(selection, np.ndarray)
                and selection.dtype.kind not in np.typecodes["AllInteger"]
            ):
                selection = selection.astype(int)
            xr_object = xr_object.isel({key: selection})
        else:
            xr_object = xr_object.sel({key: selection})

    return xr_object


def get_patch_ivts(this_ds, this_pftlist):
"""
Get PFT of each patch, in both integer and string forms.
Expand Down
5 changes: 3 additions & 2 deletions python/ctsm/crop_calendars/generate_gdds_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
sys.path.insert(1, _CTSM_PYTHON)
import ctsm.crop_calendars.cropcal_utils as utils # pylint: disable=wrong-import-position
import ctsm.crop_calendars.cropcal_module as cc # pylint: disable=wrong-import-position
from ctsm.crop_calendars.xr_flexsel import xr_flexsel # pylint: disable=wrong-import-position

CAN_PLOT = True
try:
Expand Down Expand Up @@ -573,10 +574,10 @@ def import_and_process_1yr(
continue

vegtype_int = utils.vegtype_str2int(vegtype_str)[0]
this_crop_full_patchlist = list(utils.xr_flexsel(h2_ds, vegtype=vegtype_str).patch.values)
this_crop_full_patchlist = list(xr_flexsel(h2_ds, vegtype=vegtype_str).patch.values)

# Get time series for each patch of this type
this_crop_ds = utils.xr_flexsel(h2_incl_ds, vegtype=vegtype_str)
this_crop_ds = xr_flexsel(h2_incl_ds, vegtype=vegtype_str)
this_crop_gddaccum_da = this_crop_ds[clm_gdd_var]
if save_figs:
this_crop_gddharv_da = this_crop_ds["GDDHARV"]
Expand Down
Loading

0 comments on commit 538ab01

Please sign in to comment.