diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index d19f29a5b..a183e6572 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -51,7 +51,7 @@ jobs: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] - container: + container: image: ghcr.io/e3sm-project/containers-e3sm-diags-test-data:e3sm-diags-test-data-0.0.2 steps: - id: skip_check diff --git a/auxiliary_tools/cdat_regression_testing/892-bottleneck/debug_ref_u.py b/auxiliary_tools/cdat_regression_testing/892-bottleneck/debug_ref_u.py new file mode 100644 index 000000000..88787d4ef --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/892-bottleneck/debug_ref_u.py @@ -0,0 +1,85 @@ +""" +This script is used to debug the bottleneck issue in the reference u variable. +""" + +# %% +import timeit + +import xarray as xr + +# Perlmutter +# ---------- +# filepaths = [ +# "/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series/ERA5/ua_197901_201912.nc" +# ] + +# LCRC +# ----- +filepaths = [ + "/lcrc/group/e3sm/diagnostics/observations/Atm/time-series/ERA5/ua_197901_201912.nc" +] +time_slice = slice("1996-01-15", "1997-01-15", None) + +# %% +# Test case 1 - OPEN_MFDATASET() + "ua" dataset (76 GB) + subsetting + `.load()` +# Result: .load() hangs when using `open_mfdataset` +# ------------------------------------------------------------------------------ +ds_ua_omfd = xr.open_mfdataset( + filepaths, + decode_times=True, + use_cftime=True, + coords="minimal", + compat="override", +) +ds_ua_omfd_sub = ds_ua_omfd.sel(time=time_slice) + +# %% +start_time = timeit.default_timer() +ds_ua_omfd_sub.load() +elapsed = timeit.default_timer() - start_time +print(f"Time taken to load ds_xc_sub: {elapsed} seconds") + +# %% +# Test case 2 - OPEN_DATASET() + "ua" dataset (76 GB) + subsetting + `.load()` +# Result: load() works fine when using `open_dataset` +# 
------------------------------------------------------------------------------ +ds_ua_od = xc.open_dataset( + filepaths[0], + add_bounds=["X", "Y", "T"], + decode_times=True, + use_cftime=True, + # coords="minimal", + # compat="override", +) +ds_ua_od_sub = ds_ua_od.sel(time=time_slice) + +# %% +start_time = timeit.default_timer() +ds_ua_od_sub.load() +elapsed = timeit.default_timer() - start_time +print(f"Time taken to load ds_xc_sub: {elapsed} seconds") + +# %% +# Test case 3 - OPEN_MFDATASET() + "pr" dataset (2 GB) + subsetting + `.load()` +# Result: ds.load() works fine with pr variable, but not with ua variable +# Notes: pr is 3D variable (time, lat, lon), ua is a 4D variable (time, lat, lon, plev). +# ------------------------------------------------------------------------------ +filepaths_pr = [ + "/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series/ERA5/pr_197901_201912.nc" +] +ds_pr = xc.open_mfdataset( + filepaths_pr, + add_bounds=["X", "Y", "T"], + decode_times=True, + use_cftime=True, + coords="minimal", + compat="override", +) + +# %% +# pr dataset is ~2 GB without subsetting. There is no need to subset. 
+start_time = timeit.default_timer() +ds_pr.load() +elapsed = timeit.default_timer() - start_time +print(f"Time taken to load ds_xc_sub_0: {elapsed} seconds") +# %% diff --git a/auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.cfg b/auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.cfg new file mode 100644 index 000000000..3d06adcde --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.cfg @@ -0,0 +1,13 @@ +[#] +sets = ["lat_lon"] +case_id = "ERA5" +variables = ["U"] +ref_name = "ERA5" +reference_name = "ERA5 Reanalysis" +seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +plevs = [850.0] +test_colormap = "PiYG_r" +reference_colormap = "PiYG_r" +contour_levels = [-20, -15, -10, -8, -5, -3, -1, 1, 3, 5, 8, 10, 15, 20] +diff_levels = [-8, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6, 8] +regrid_method = "bilinear" \ No newline at end of file diff --git a/auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.py b/auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.py new file mode 100644 index 000000000..452901cd6 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.py @@ -0,0 +1,39 @@ +import sys +import os +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.run import runner + +param = CoreParameter() + + +param.reference_data_path = ( + "/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series" +) +param.test_data_path = "/global/cfs/cdirs/e3sm/chengzhu/eamxx/post/data/rgr" +param.test_name = "eamxx_decadal" +param.seasons = ["ANN"] +# param.save_netcdf = True + +param.ref_timeseries_input = True +# Years to slice the ref data, base this off the years in the filenames. 
+param.ref_start_yr = "1996" +param.ref_end_yr = "1996" + +prefix = "/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/892-bottleneck" +param.results_dir = os.path.join(prefix, "eamxx_decadal_1996_1107_edv3") + +cfg_path = "auxiliary_tools/cdat_regression_testing/892-bottleneck/run_script.cfg" +sys.argv.extend(["--diags", cfg_path]) + +runner.sets_to_run = [ + "lat_lon", + "zonal_mean_xy", + "zonal_mean_2d", + "zonal_mean_2d_stratosphere", + "polar", + "cosp_histogram", + "meridional_mean_2d", + "annual_cycle_zonal_mean", +] + +runner.run_diags([param]) diff --git a/auxiliary_tools/cdat_regression_testing/892-bottleneck/xr_mvce_e3sm_data.py b/auxiliary_tools/cdat_regression_testing/892-bottleneck/xr_mvce_e3sm_data.py new file mode 100644 index 000000000..119f869b8 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/892-bottleneck/xr_mvce_e3sm_data.py @@ -0,0 +1,20 @@ +# %% +import timeit + +import xarray as xr + +filepaths = [ + "/lcrc/group/e3sm/diagnostics/observations/Atm/time-series/ERA5/ua_197901_201912.nc" +] + +ds = xr.open_mfdataset(filepaths) + +ds_sub = ds.sel(time=slice("1996-01-15", "1997-01-15", None)) + +# %% +start_time = timeit.default_timer() +ds_sub.ua.load() +elapsed = timeit.default_timer() - start_time +print(f"Time taken to load ds_xc_sub: {elapsed} seconds") + +# %% diff --git a/auxiliary_tools/cdat_regression_testing/892-bottleneck/xr_mvce_gh.py b/auxiliary_tools/cdat_regression_testing/892-bottleneck/xr_mvce_gh.py new file mode 100644 index 000000000..2baf2ef65 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/892-bottleneck/xr_mvce_gh.py @@ -0,0 +1,48 @@ +# %% +import numpy as np +import pandas as pd +import xarray as xr +import timeit + +import dask.array as da + +# %% +# Define the dimensions +time = 12 +plev = 37 +lat = 721 +lon = 1440 + +# Create the data arrays using dask. +data = da.random.random(size=(time, plev, lat, lon), chunks=(12, 37, 721, 1440)).astype( + np.float32 +) + +# Create the coordinates. 
+times = pd.date_range("2000-01-01", periods=time) +plevs = np.linspace(100000, 10, plev) +lats = np.linspace(-90, 90, lat) +lons = np.linspace(0, 360, lon, endpoint=False) + +# Create the dataset and write out to a file. +ds = xr.Dataset( + {"data": (["time", "plev", "lat", "lon"], data)}, + coords={"time": times, "plev": plevs, "lat": lats, "lon": lons}, +) +# %% +ds.to_netcdf("dask_bottleneck.nc") + +# %% +# Open the dataset. +ds_open = xr.open_mfdataset("dask_bottleneck.nc") + +# %% +# Load the dataset into memory +start_time = timeit.default_timer() +ds.load() +end_time = timeit.default_timer() + +print(f"Time taken to load the dataset: {end_time - start_time} seconds") + + +# %% diff --git a/conda-env/ci.yml b/conda-env/ci.yml index c8039b294..0a684be41 100644 --- a/conda-env/ci.yml +++ b/conda-env/ci.yml @@ -15,7 +15,7 @@ dependencies: - cartopy >=0.17.0 - cartopy_offlinedata - cf-units - - dask + - dask <2024.12.0 - esmpy >=8.4.0 - lxml - mache >=0.15.0 diff --git a/conda-env/dev.yml b/conda-env/dev.yml index 578800e85..2b1c4e495 100644 --- a/conda-env/dev.yml +++ b/conda-env/dev.yml @@ -13,7 +13,7 @@ dependencies: - cartopy >=0.17.0 - cartopy_offlinedata - cf-units - - dask + - dask <2024.12.0 - esmpy >=8.4.0 - lxml - mache >=0.15.0 diff --git a/e3sm_diags/derivations/derivations.py b/e3sm_diags/derivations/derivations.py index 14eef3104..a176987f4 100644 --- a/e3sm_diags/derivations/derivations.py +++ b/e3sm_diags/derivations/derivations.py @@ -109,6 +109,10 @@ ("pr",): lambda pr: qflxconvert_units(rename(pr)), ("PRECC", "PRECL"): lambda precc, precl: prect(precc, precl), ("sat_gauge_precip",): rename, + ("precip_liq_surf_mass_flux", "precip_ice_surf_mass_flux"): prect, # EAMxx + ("precip_total_surf_mass_flux",): lambda pr: convert_units( + rename(pr), target_units="mm/day" + ), # EAMxx ("PrecipLiqSurfMassFlux", "PrecipIceSurfMassFlux"): prect, # EAMxx }, "PRECST": { @@ -142,12 +146,21 @@ lower_limit=0.9, ), ), + ( + ("surf_radiative_T", "ocnfrac"), + 
lambda ts, ocnfrac: _apply_land_sea_mask( + convert_units(ts, target_units="degC"), + ocnfrac, + lower_limit=0.9, + ), + ), (("SST",), lambda sst: convert_units(sst, target_units="degC")), ] ), "TMQ": OrderedDict( [ (("PREH2O",), rename), + (("VapWaterPath",), rename), # EAMxx ( ("prw",), lambda prw: convert_units(rename(prw), target_units="kg/m2"), @@ -159,10 +172,7 @@ ("ALBEDO",): rename, ("SOLIN", "FSNTOA"): lambda solin, fsntoa: albedo(solin, solin - fsntoa), ("rsdt", "rsut"): albedo, - ( - "SW_flux_up_at_model_top", - "SW_clrsky_flux_up_at_model_top", - ): swcf, # EAMxx + ("SW_flux_dn_at_model_top", "SW_flux_up_at_model_top"): albedo, # EAMxx }, "ALBEDOC": OrderedDict( [ @@ -172,6 +182,10 @@ lambda solin, fsntoac: albedoc(solin, solin - fsntoac), ), (("rsdt", "rsutcs"), lambda rsdt, rsutcs: albedoc(rsdt, rsutcs)), + ( + ("SW_flux_dn_at_model_top", "SW_clrsky_flux_up_at_model_top"), + lambda rsdt, rsutcs: albedoc(rsdt, rsutcs), + ), # EAMxx ] ), "ALBEDO_SRF": OrderedDict( @@ -182,6 +196,10 @@ ("FSDS", "FSNS"), lambda fsds, fsns: albedo_srf(fsds, fsds - fsns), ), + ( + ("SW_flux_dn_at_model_bot", "SW_flux_up_at_model_bot"), + lambda rsds, rsus: albedo_srf(rsds, rsus), + ), # EAMxx ] ), # Pay attention to the positive direction of SW and LW fluxes @@ -202,6 +220,7 @@ lambda fsntoa, fsntoac: swcf(fsntoa, fsntoac), ), (("rsut", "rsutcs"), lambda rsutcs, rsut: swcf(rsut, rsutcs)), + (("ShortwaveCloudForcing",), rename), # EAMxx ] ), "SWCFSRF": OrderedDict( @@ -217,6 +236,15 @@ ), (("sfc_cre_net_sw_mon",), rename), (("FSNS", "FSNSC"), lambda fsns, fsnsc: swcfsrf(fsns, fsnsc)), + ( + ( + "SW_flux_dn_at_model_bot", + "SW_flux_up_at_model_bot", + "SW_clrsky_flux_dn_at_model_bot", + "SW_clrsky_flux_up_at_model_bot", + ), + lambda fsds, fsus, fsdsc, fsusc: swcfsrf(fsds - fsus, fsdsc - fsusc), + ), # EAMxx ] ), "LWCF": OrderedDict( @@ -236,6 +264,7 @@ lambda flntoa, flntoac: lwcf(flntoa, flntoac), ), (("rlut", "rlutcs"), lambda rlutcs, rlut: lwcf(rlut, rlutcs)), + 
(("LongwaveCloudForcing",), rename), # EAMxx ] ), "LWCFSRF": OrderedDict( @@ -251,6 +280,13 @@ ), (("sfc_cre_net_lw_mon",), rename), (("FLNS", "FLNSC"), lambda flns, flnsc: lwcfsrf(flns, flnsc)), + ( + ( + "LW_flux_dn_at_model_bot", + "LW_clrsky_flux_dn_at_model_bot", + ), + lambda flds, fldsc: -lwcfsrf(flds, fldsc), + ), # EAMxx ] ), "NETCF": OrderedDict( @@ -282,6 +318,10 @@ lambda swcf, lwcf: netcf2(swcf, lwcf), ), (("SWCF", "LWCF"), lambda swcf, lwcf: netcf2(swcf, lwcf)), + ( + ("ShortwaveCloudForcing", "LongwaveCloudForcing"), + lambda swcf, lwcf: netcf2(swcf, lwcf), + ), # EAMxx ( ("FSNTOA", "FSNTOAC", "FLNTOA", "FLNTOAC"), lambda fsntoa, fsntoac, flntoa, flntoac: netcf4( @@ -322,6 +362,21 @@ ("FSNS", "FSNSC", "FLNSC", "FLNS"), lambda fsns, fsnsc, flnsc, flns: netcf4srf(fsns, fsnsc, flnsc, flns), ), + ( + ( + "SW_flux_dn_at_model_bot", + "SW_flux_up_at_model_bot", + "SW_clrsky_flux_dn_at_model_bot", + "SW_clrsky_flux_up_at_model_bot", + "LW_flux_up_at_model_bot", + "LW_clrsky_flux_dn_at_model_bot", + "LW_flux_up_at_model_bot", + "LW_flux_dn_at_model_bot", + ), + lambda fsds, fsus, fsdsc, fsusc, flusc, fldsc, flus, flds: netcf4srf( + fsds - fsus, fsdsc - fsusc, flusc - fldsc, flus - flds + ), + ), # EAMxx ] ), "FLNS": OrderedDict( @@ -331,6 +386,10 @@ lambda sfc_net_lw_all_mon: -sfc_net_lw_all_mon, ), (("rlds", "rlus"), lambda rlds, rlus: netlw(rlds, rlus)), + ( + ("LW_flux_dn_at_model_bot", "LW_flux_up_at_model_bot"), + lambda rlds, rlus: netlw(rlds, rlus), + ), ] ), "FLNSC": OrderedDict( @@ -343,18 +402,24 @@ ("sfc_net_lw_clr_t_mon",), lambda sfc_net_lw_clr_mon: -sfc_net_lw_clr_mon, ), + ( + ("LW_clrsky_flux_dn_at_model_bot", "LW_flux_up_at_model_bot"), + lambda rlds, rlus: netlw(rlds, rlus), + ), # EAMxx ] ), - "FLDS": OrderedDict([(("rlds",), rename)]), + "FLDS": OrderedDict([(("rlds",), rename), (("LW_flux_dn_at_model_bot",), rename)]), "FLUS": OrderedDict( [ (("rlus",), rename), + (("LW_flux_up_at_model_bot",), rename), # EAMxx (("FLDS", "FLNS"), lambda 
FLDS, FLNS: flus(FLDS, FLNS)), ] ), "FLDSC": OrderedDict( [ (("rldscs",), rename), + (("LW_clrsky_flux_dn_at_model_bot",), rename), # EAMxx (("TS", "FLNSC"), lambda ts, flnsc: fldsc(ts, flnsc)), ] ), @@ -362,23 +427,42 @@ [ (("sfc_net_sw_all_mon",), rename), (("rsds", "rsus"), lambda rsds, rsus: netsw(rsds, rsus)), + ( + ("SW_flux_dn_at_model_bot", "SW_flux_up_at_model_bot"), + lambda rsds, rsus: netsw(rsds, rsus), + ), # EAMxx ] ), "FSNSC": OrderedDict( [ (("sfc_net_sw_clr_mon",), rename), (("sfc_net_sw_clr_t_mon",), rename), + ( + ("SW_clrsky_flux_dn_at_model_bot", "SW_clrsky_flux_up_at_model_bot"), + lambda rsds, rsus: netsw(rsds, rsus), + ), # EAMxx ] ), - "FSDS": OrderedDict([(("rsds",), rename)]), + "FSDS": OrderedDict( + [(("rsds",), rename), (("SW_flux_dn_at_model_bot",), rename)], + ), "FSUS": OrderedDict( [ (("rsus",), rename), + (("SW_flux_up_at_model_bot",), rename), # EAMxx (("FSDS", "FSNS"), lambda FSDS, FSNS: fsus(FSDS, FSNS)), ] ), - "FSUSC": OrderedDict([(("rsuscs",), rename)]), - "FSDSC": OrderedDict([(("rsdscs",), rename), (("rsdsc",), rename)]), + "FSUSC": OrderedDict( + [(("rsuscs",), rename), (("SW_clrsky_flux_up_at_model_bot",), rename)] + ), + "FSDSC": OrderedDict( + [ + (("rsdscs",), rename), + (("rsdsc",), rename), + (("SW_clrsky_flux_dn_at_model_bot",), rename), + ] + ), # Net surface heat flux: W/(m^2) "NET_FLUX_SRF": OrderedDict( [ @@ -408,12 +492,25 @@ rsds, rsus, rlds, rlus, hfls, hfss ), ), + ( + ( + "SW_flux_dn_at_model_bot", + "SW_flux_up_at_model_bot", + "LW_flux_dn_at_model_bot", + "LW_flux_up_at_model_bot", + "surface_upward_latent_heat_flux", + "surf_sens_flux", + ), + lambda rsds, rsus, rlds, rlus, hfls, hfss: netflux6( + rsds, rsus, rlds, rlus, hfls, hfss + ), # EAMxx + ), ] ), "FLUT": {("rlut",): rename, ("LW_flux_up_at_model_top",): rename}, - "FSUTOA": {("rsut",): rename}, - "FSUTOAC": {("rsutcs",): rename}, - "FLNT": {("FLNT",): rename}, + "FSUTOA": {("rsut",): rename, ("SW_flux_up_at_model_top",): rename}, + "FSUTOAC": 
{("rsutcs",): rename, ("SW_clrsky_flux_up_at_model_top",): rename}, + "FLNT": {("FLNT",): rename, ("LW_flux_up_at_model_top",): rename}, "FLUTC": {("rlutcs",): rename, ("LW_clrsky_flux_up_at_model_top",): rename}, "FSNTOA": { ("FSNTOA",): rename, @@ -453,6 +550,12 @@ prect(precc, precl), landfrac, lower_limit=0.5 ), ), + ( + ("precip_liq_surf_mass_flux", "precip_ice_surf_mass_flux", "landfrac"), + lambda precc, precl, landfrac: _apply_land_sea_mask( + prect(precc, precl), landfrac, lower_limit=0.5 + ), # EAMxx + ), ] ), "Z3": OrderedDict( @@ -462,6 +565,10 @@ lambda zg: convert_units(rename(zg), target_units="hectometer"), ), (("Z3",), lambda z3: convert_units(z3, target_units="hectometer")), + ( + ("z_mid",), + lambda z3: convert_units(z3, target_units="hectometer"), + ), # EAMxx ] ), "PSL": { @@ -474,7 +581,7 @@ "T": { ("ta",): rename, ("T",): lambda t: convert_units(t, target_units="K"), - ("T_2m",): lambda t: convert_units(t, target_units="DegC"), # EAMxx + ("T_mid",): lambda t: convert_units(t, target_units="K"), # EAMxx }, "U": OrderedDict( [ @@ -496,6 +603,7 @@ lambda t: convert_units(t, target_units="DegC"), ), (("tas",), lambda t: convert_units(t, target_units="DegC")), + (("T_2m",), lambda t: convert_units(t, target_units="DegC")), # EAMxx ] ), # Surface water flux: kg/((m^2)*s) @@ -508,7 +616,7 @@ "LHFLX": { ("hfls",): rename, ("QFLX",): qflx_convert_to_lhflx_approxi, - ("surface_upward_latent_heat_flux",): rename, # EAMxx "s^-3 kg" + ("surface_upward_latent_heat_flux",): rename, # EAMxx }, "SHFLX": { ("hfss",): rename, @@ -528,6 +636,14 @@ lower_limit=0.65, ), ), + ( + ("LiqWaterPath", "ocnfrac"), + lambda tgcldlwp, ocnfrac: _apply_land_sea_mask( + convert_units(tgcldlwp, target_units="g/m^2"), + ocnfrac, + lower_limit=0.65, + ), # EAMxx + ), ] ), "PRECT_OCN": OrderedDict( @@ -544,6 +660,14 @@ lower_limit=0.65, ), ), + ( + ("precip_liq_surf_mass_flux", "precip_liq_surf_mass_flux", "ocnfrac"), + lambda a, b, ocnfrac: _apply_land_sea_mask( + aplusb(a, b, 
target_units="mm/day"), + ocnfrac, + lower_limit=0.65, + ), # EAMxx + ), ] ), "PREH2O_OCN": OrderedDict( @@ -582,6 +706,10 @@ ("CLOUD",), lambda cldtot: convert_units(cldtot, target_units="%"), ), + ( + ("cldfrac_tot_for_analysis",), + lambda cldtot: convert_units(cldtot, target_units="%"), + ), ] ), # below for COSP output @@ -693,14 +821,21 @@ "RELHUM": { ("hur",): lambda hur: convert_units(hur, target_units="%"), ("RELHUM",): lambda relhum: convert_units(relhum, target_units="%"), + ("RelativeHumidity",): lambda relhum: convert_units( + relhum, target_units="%" + ), # EAMxx }, "OMEGA": { ("wap",): lambda wap: convert_units(wap, target_units="mbar/day"), ("OMEGA",): lambda omega: convert_units(omega, target_units="mbar/day"), + ("omega",): lambda omega: convert_units( + omega, target_units="mbar/day" + ), # EAMxx }, "Q": { ("hus",): lambda q: convert_units(rename(q), target_units="g/kg"), ("Q",): lambda q: convert_units(rename(q), target_units="g/kg"), + ("qv",): lambda q: convert_units(rename(q), target_units="g/kg"), # EAMxx ("SHUM",): lambda shum: convert_units(shum, target_units="g/kg"), }, "H2OLNZ": { @@ -739,9 +874,14 @@ ("surf_radiative_T",): rename, # EAMxx }, "PS": {("ps",): rename}, - "U10": {("sfcWind",): rename}, + "U10": { + ("sfcWind",): rename, + ("wind_speed_10m",): rename, # EAMxx + ("si10",): rename, + }, "QREFHT": { ("QREFHT",): lambda q: convert_units(q, target_units="g/kg"), + ("qv_2m",): lambda q: convert_units(q, target_units="g/kg"), # EAMxx ("huss",): lambda q: convert_units(q, target_units="g/kg"), ("d2m", "sp"): qsat, }, @@ -754,9 +894,18 @@ ("surf_mom_flux_V",): lambda tauv: -tauv, # EAMxx }, "CLDICE": {("cli",): rename}, - "TGCLDIWP": {("clivi",): rename}, - "CLDLIQ": {("clw",): rename}, - "TGCLDCWP": {("clwvi",): rename}, + "TGCLDIWP": { + ("clivi",): rename, + ("IceWaterPath",): rename, # EAMxx + }, + "CLDLIQ": { + ("clw",): rename, + ("qc",): rename, # EAMxx + }, + "TGCLDCWP": { + ("clwvi",): rename, + ("LiqWaterPath",): rename, # 
EAMxx + }, "O3": {("o3",): rename}, "PminusE": { ("PminusE",): pminuse_convert_units, diff --git a/e3sm_diags/derivations/formulas.py b/e3sm_diags/derivations/formulas.py index 78db38f99..f6200986f 100644 --- a/e3sm_diags/derivations/formulas.py +++ b/e3sm_diags/derivations/formulas.py @@ -143,8 +143,8 @@ def so4_mass_sum(a1: xr.DataArray, a2: xr.DataArray): with xr.set_options(keep_attrs=True): var = (a1 + a2) * AIR_DENS * 1e9 var.name = "so4_mass" - var.units = "\u03bcg/m3" - var.long_name = "SO4 mass conc." + var.attrs["units"] = "\u03bcg/m3" + var.attrs["long_name"] = "SO4 mass conc." return var @@ -324,6 +324,7 @@ def rst(rsdt: xr.DataArray, rsut: xr.DataArray): var = rsdt - rsut var.name = "FSNTOA" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA net shortwave flux" return var @@ -334,6 +335,7 @@ def rstcs(rsdt: xr.DataArray, rsutcs: xr.DataArray): var = rsdt - rsutcs var.name = "FSNTOAC" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA net shortwave flux clear-sky" return var @@ -344,6 +346,7 @@ def swcfsrf(fsns: xr.DataArray, fsnsc: xr.DataArray): var = fsns - fsnsc var.name = "SCWFSRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface shortwave cloud forcing" return var @@ -354,6 +357,7 @@ def lwcfsrf(flns: xr.DataArray, flnsc: xr.DataArray): var = -(flns - flnsc) var.name = "LCWFSRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface longwave cloud forcing" return var @@ -364,6 +368,7 @@ def swcf(fsntoa: xr.DataArray, fsntoac: xr.DataArray): var = fsntoa - fsntoac var.name = "SWCF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA shortwave cloud forcing" return var @@ -374,6 +379,7 @@ def lwcf(flntoa: xr.DataArray, flntoac: xr.DataArray): var = flntoa - flntoac var.name = "LWCF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA longwave cloud forcing" return var @@ -384,6 +390,7 @@ def netcf2(swcf: xr.DataArray, lwcf: xr.DataArray): var = swcf + lwcf var.name = "NETCF" + 
var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA net cloud forcing" return var @@ -399,6 +406,7 @@ def netcf4( var = fsntoa - fsntoac + flntoa - flntoac var.name = "NETCF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA net cloud forcing" return var @@ -409,6 +417,7 @@ def netcf2srf(swcf: xr.DataArray, lwcf: xr.DataArray): var = swcf + lwcf var.name = "NETCF_SRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface net cloud forcing" return var @@ -424,6 +433,7 @@ def netcf4srf( var = fsntoa - fsntoac + flntoa - flntoac var.name = "NETCF4SRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface net cloud forcing" return var @@ -445,6 +455,7 @@ def restom(fsnt: xr.DataArray, flnt: xr.DataArray): var = fsnt - flnt var.name = "RESTOM" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOM(top of model) Radiative flux" return var @@ -454,7 +465,8 @@ def restom3(swdn: xr.DataArray, swup: xr.DataArray, lwup: xr.DataArray): with xr.set_options(keep_attrs=True): var = swdn - swup - lwup - var.long_name = "TOM(top of model) Radiative flux" + var.attrs["units"] = "W/m2" + var.attrs["long_name"] = "TOM(top of model) Radiative flux" return var @@ -465,6 +477,7 @@ def restoa(fsnt: xr.DataArray, flnt: xr.DataArray): var = fsnt - flnt var.name = "RESTOA" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "TOA(top of atmosphere) Radiative flux" return var @@ -475,6 +488,7 @@ def flus(flds: xr.DataArray, flns: xr.DataArray): var = flns + flds var.name = "FLUS" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Upwelling longwave flux at surface" return var @@ -485,6 +499,7 @@ def fsus(fsds: xr.DataArray, fsns: xr.DataArray): var = fsds - fsns var.name = "FSUS" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Upwelling shortwave flux at surface" return var @@ -495,6 +510,7 @@ def netsw(rsds: xr.DataArray, rsus: xr.DataArray): var = rsds - rsus var.name = "FSNS" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = 
"Surface SW Radiative flux" return var @@ -505,6 +521,7 @@ def netlw(rlds: xr.DataArray, rlus: xr.DataArray): var = -(rlds - rlus) var.name = "NET_FLUX_SRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface LW Radiative flux" return var @@ -517,6 +534,7 @@ def netflux4( var = fsns - flns - lhflx - shflx var.name = "NET_FLUX_SRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface Net flux" return var @@ -534,6 +552,7 @@ def netflux6( var = rsds - rsus + (rlds - rlus) - hfls - hfss var.name = "NET_FLUX_SRF" + var.attrs["units"] = "W/m2" var.attrs["long_name"] = "Surface Net flux" return var diff --git a/e3sm_diags/derivations/utils.py b/e3sm_diags/derivations/utils.py index 7c1fd7e92..b794f8b5c 100644 --- a/e3sm_diags/derivations/utils.py +++ b/e3sm_diags/derivations/utils.py @@ -82,14 +82,10 @@ def convert_units(var: xr.DataArray, target_units: str): # noqa: C901 elif var.attrs["units"] in ["gN/m^2/day", "gP/m^2/day", "gC/m^2/day"]: pass else: - temp = cf_units.Unit(var.attrs["units"]) - target = cf_units.Unit(target_units) - coeff, offset = temp.convert(1, target), temp.convert(0, target) - - # Keep all of the attributes except the units. - with xr.set_options(keep_attrs=True): - var = coeff * var + offset + original_udunit = cf_units.Unit(var.attrs["units"]) + target_udunit = cf_units.Unit(target_units) + var.values = original_udunit.convert(var.values, target_udunit) var.attrs["units"] = target_units return var diff --git a/e3sm_diags/driver/__init__.py b/e3sm_diags/driver/__init__.py index 7ceab7b9b..630d0c539 100644 --- a/e3sm_diags/driver/__init__.py +++ b/e3sm_diags/driver/__init__.py @@ -8,34 +8,4 @@ # The keys for the land and ocean fraction variables in the # `LAND_OCEAN_MASK_PATH` file. -LAND_FRAC_KEY = "LANDFRAC" -OCEAN_FRAC_KEY = "OCNFRAC" - - -def _get_region_mask_var_key(region: str): - """Get the region's mask variable key. - - This variable key can be used to map the the variable data in a sdataset. 
- Only land and ocean regions are supported. - - Parameters - ---------- - region : str - The region. - - Returns - ------- - str - The variable key, either "LANDFRAC" or "OCNFRAC". - - Raises - ------ - ValueError - If the region passed is not land or ocean. - """ - if "land" in region: - return LAND_FRAC_KEY - elif "ocean" in region: - return OCEAN_FRAC_KEY - - raise ValueError(f"Only land and ocean regions are supported, not '{region}'.") +FRAC_REGION_KEYS = {"land": ("LANDFRAC", "landfrac"), "ocean": ("OCNFRAC", "ocnfrac")} diff --git a/e3sm_diags/driver/aerosol_budget_driver.py b/e3sm_diags/driver/aerosol_budget_driver.py index 9c1de7d00..faf0c4005 100644 --- a/e3sm_diags/driver/aerosol_budget_driver.py +++ b/e3sm_diags/driver/aerosol_budget_driver.py @@ -3,6 +3,7 @@ script is integrated in e3sm_diags by Jill Zhang, with input from Kai Zhang, Taufiq Hassan, Xue Zheng, Ziming Ke, Susannah Burrows, and Naser Mahfouz. """ + from __future__ import annotations import csv diff --git a/e3sm_diags/driver/utils/climo_xr.py b/e3sm_diags/driver/utils/climo_xr.py index bb229048c..acbe73fa2 100644 --- a/e3sm_diags/driver/utils/climo_xr.py +++ b/e3sm_diags/driver/utils/climo_xr.py @@ -1,8 +1,8 @@ """This module stores climatology functions operating on Xarray objects. - This file will eventually be refactored to use xCDAT's climatology API. """ + from typing import Dict, List, Literal, get_args import numpy as np diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 94b109e66..bc5fa34df 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -1,6 +1,5 @@ """This module stores the Dataset class, which is the primary class for I/O. - This Dataset class operates on `xr.Dataset` objects, which are created using netCDF files. These `xr.Dataset` contain either the reference or test variable. This variable can either be from a climatology file or a time series file. @@ -8,6 +7,7 @@ calculated. 
Reference and test variables can also be derived using other variables from dataset files. """ + from __future__ import annotations import collections @@ -28,7 +28,7 @@ DerivedVariableMap, DerivedVariablesMap, ) -from e3sm_diags.driver import LAND_FRAC_KEY, LAND_OCEAN_MASK_PATH, OCEAN_FRAC_KEY +from e3sm_diags.driver import FRAC_REGION_KEYS, LAND_OCEAN_MASK_PATH from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQS, ClimoFreq, climo from e3sm_diags.driver.utils.regrid import HYBRID_SIGMA_KEYS from e3sm_diags.logger import custom_logger @@ -388,7 +388,9 @@ def get_climo_dataset(self, var: str, season: ClimoFreq) -> xr.Dataset: if self.is_time_series: ds = self.get_time_series_dataset(var) + ds_climo = climo(ds, self.var, season).to_dataset() + ds_climo = ds_climo.bounds.add_missing_bounds(axes=["X", "Y"]) return ds_climo @@ -710,12 +712,13 @@ def _get_dataset_with_derived_climo_var(self, ds: xr.Dataset) -> xr.Dataset: return ds_derived # None of the entries in the derived variables dictionary worked, - # so try to get the variable directly from he dataset. + # so try to get the variable directly from the dataset. if target_var in ds.data_vars.keys(): return ds raise IOError( - f"The dataset file has no matching source variables for {target_var}" + f"The dataset file has no matching source variables for {target_var} and " + f"could not be derived using {list(target_var_map.keys())}." ) def _get_matching_climo_src_vars( @@ -1133,14 +1136,14 @@ def _subset_time_series_dataset(self, ds: xr.Dataset, var: str) -> xr.Dataset: xr.Dataset The subsetted time series dataset. 
""" - ds_sub = self._subset_vars_and_load(ds, var) - - time_slice = self._get_time_slice(ds_sub) - ds_sub = ds_sub.sel(time=time_slice).squeeze() + time_slice = self._get_time_slice(ds) + ds_sub = ds.sel(time=time_slice).squeeze() if self.is_sub_monthly: ds_sub = self._exclude_sub_monthly_coord_spanning_year(ds_sub) + ds_sub = self._subset_vars_and_load(ds_sub, var) + return ds_sub def _get_time_slice(self, ds: xr.Dataset) -> slice: @@ -1428,18 +1431,60 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset: The xr.Dataset object containing the land sea mask variables "LANDFRAC" and "OCNFRAC". """ - try: - ds_land_frac = self.get_climo_dataset(LAND_FRAC_KEY, season) # type: ignore - ds_ocean_frac = self.get_climo_dataset(OCEAN_FRAC_KEY, season) # type: ignore - except IOError as e: - logger.info( - f"{e}. Using default land sea mask located at `{LAND_OCEAN_MASK_PATH}`." - ) + ds_mask = self._get_land_sea_mask_dataset(season) - ds_mask = xr.open_dataset(LAND_OCEAN_MASK_PATH) - ds_mask = squeeze_time_dim(ds_mask) - else: - ds_mask = xr.merge([ds_land_frac, ds_ocean_frac]) + if ds_mask is None: + logger.info("No land sea mask datasets were found for the given season.") + ds_mask = self._get_default_land_sea_mask_dataset() + + return ds_mask + + def _get_land_sea_mask_dataset(self, season: str) -> xr.Dataset | None: + """Get the land sea mask dataset for the given season. + + Parameters + ---------- + season : str + The season to subset on. + + Returns + ------- + xr.Dataset | None + The land sea mask dataset for the given season, or None if not + found. 
+ """ + land_keys = FRAC_REGION_KEYS["land"] + ocn_keys = FRAC_REGION_KEYS["ocean"] + + datasets = [] + + for land_key, ocn_key in zip(land_keys, ocn_keys): + try: + ds_land = self.get_climo_dataset(land_key, season) # type: ignore + ds_ocn = self.get_climo_dataset(ocn_key, season) # type: ignore + except IOError: + pass + else: + datasets.append(ds_land) + datasets.append(ds_ocn) + + if len(datasets) == 2: + return xr.merge(datasets) + + return None + + def _get_default_land_sea_mask_dataset(self) -> xr.Dataset: + """Get the default land sea mask dataset. + + Returns + ------- + xr.Dataset + The default land sea mask dataset. + """ + logger.info(f"Using default land sea mask located at `{LAND_OCEAN_MASK_PATH}`.") + + ds_mask = xr.open_dataset(LAND_OCEAN_MASK_PATH) + ds_mask = squeeze_time_dim(ds_mask) return ds_mask @@ -1482,6 +1527,7 @@ def _subset_vars_and_load(self, ds: xr.Dataset, var: str) -> xr.Dataset: ] ds = ds[[var] + keep_vars] + # FIXME: `ds.load()` on `ds_ref` causes deadlock. ds.load(scheduler="sync") return ds diff --git a/e3sm_diags/driver/utils/regrid.py b/e3sm_diags/driver/utils/regrid.py index 98e30a619..d8c04561b 100644 --- a/e3sm_diags/driver/utils/regrid.py +++ b/e3sm_diags/driver/utils/regrid.py @@ -6,7 +6,7 @@ import xcdat as xc from e3sm_diags.derivations.default_regions_xr import REGION_SPECS -from e3sm_diags.driver import _get_region_mask_var_key +from e3sm_diags.driver import FRAC_REGION_KEYS from e3sm_diags.logger import custom_logger if TYPE_CHECKING: @@ -189,8 +189,7 @@ def _apply_land_sea_mask( ds: xr.Dataset The dataset containing the variable. ds_mask : xr.Dataset - The dataset containing the land sea region mask variables, "LANDFRAC" - and "OCEANFRAC". + The dataset containing the land sea region mask variable(s). 
var_key : str The key the variable region : Literal["land", "ocean"] @@ -243,7 +242,7 @@ def _apply_land_sea_mask( ds_new = ds.copy() ds_new = _drop_unused_ilev_axis(ds) output_grid = ds_new.regridder.grid - mask_var_key = _get_region_mask_var_key(region) + mask_var_key = _get_region_mask_var_key(ds_mask, region) ds_mask_new = _drop_unused_ilev_axis(ds_mask) ds_mask_regrid = ds_mask_new.regridder.horizontal( @@ -457,6 +456,41 @@ def _drop_unused_ilev_axis(ds: xr.Dataset) -> xr.Dataset: return ds_new +def _get_region_mask_var_key(ds_mask: xr.Dataset, region: str): + """Get the region's mask variable key. + + This variable key can be used to map the the variable data in a dataset. + Only land and ocean regions are supported. + + Parameters + ---------- + ds_mask : xr.Dataset + The dataset containing the land and ocean mask variables. + region : str + The region. + + Returns + ------- + Tuple[str, ...] + A tuple of valid keys for the land or ocean fraction variable. + + Raises + ------ + ValueError + If the region passed is not land or ocean. 
+ """ + for region_prefix in ["land", "ocean"]: + if region_prefix in region: + region_keys = FRAC_REGION_KEYS.get(region_prefix) + + if region_keys is None: + raise ValueError(f"Only land and ocean regions are supported, not '{region}'.") + + for key in region_keys: + if key in ds_mask.data_vars: + return key + + def regrid_z_axis_to_plevs( dataset: xr.Dataset, var_key: str, diff --git a/e3sm_diags/metrics/metrics.py b/e3sm_diags/metrics/metrics.py index 333980643..d98fe519d 100644 --- a/e3sm_diags/metrics/metrics.py +++ b/e3sm_diags/metrics/metrics.py @@ -1,4 +1,5 @@ """This module stores functions to calculate metrics using Xarray objects.""" + from __future__ import annotations from typing import List, Literal diff --git a/examples/e3sm_diags_for_eamxx/README.md b/examples/e3sm_diags_for_eamxx/README.md new file mode 100644 index 000000000..2aa41d565 --- /dev/null +++ b/examples/e3sm_diags_for_eamxx/README.md @@ -0,0 +1,23 @@ +# Initial Instruction to Run E3SM Diags on EAMxx output (e.g. monthly ne30pg2 output) + +0. Secure an interactive compute node and to activate the E3SM-Unified enviroment: + +salloc --nodes 1 --qos interactive --time 02:00:00 --constraint cpu --account e3sm + +source /global/common/software/e3sm/anaconda_envs/load_latest_e3sm_unified_pm-cpu.sh + +(The version of E3SM Diags (v3) that has EAMxx variable support is available in E3SM-Unified v1.11 (Mid Feb 2025 release. ) + +1. To remap monthly ne30pg2 data to regular lat-lon data to prepare for E3SM Diags run. An example usage based on a EAMxx decadal run is provided in following script ``nco.sh``. To run the script: + +bash nco.sh + +2. Generate a python script for running E3SM Diags. 
Two examples are provided here:
+
+python run_e3sm_diags_1996.py: to compare 1996 climatology from EAMxx to available 1990 obs climatology
+
+python run_e3sm_diags_climo.py: to compare 1996 climatology from EAMxx to pre-calculated obs climatology
+
+
+
+
diff --git a/examples/e3sm_diags_for_eamxx/nco.sh b/examples/e3sm_diags_for_eamxx/nco.sh
new file mode 100644
index 000000000..22be04ead
--- /dev/null
+++ b/examples/e3sm_diags_for_eamxx/nco.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+source /global/common/software/e3sm/anaconda_envs/load_latest_e3sm_unified_pm-cpu.sh
+
+drc_in=/global/cfs/cdirs/e3sm/chengzhu/eamxx/run
+drc_out=/global/cfs/cdirs/e3sm/chengzhu/eamxx/post/data
+caseid=output.scream.decadal.monthlyAVG_ne30pg2.AVERAGE.nmonths_x1
+
+# spoofed climatology files with data from 1995-09 to 1996-08
+
+# create climatology files
+cd ${drc_in};ls ${caseid}*1996-0[1-8]*.nc ${caseid}*1995-09*.nc ${caseid}*1995-1[0-2]*.nc | ncclimo -P eamxx --fml_nm=eamxx_decadal --yr_srt=1996 --yr_end=1996 --drc_out=$drc_out
+
+
+map=/global/cfs/projectdirs/e3sm/zender/maps/map_ne30pg2_to_cmip6_180x360_traave.20231201.nc
+# remapping climo files to regular lat-lon
+cd $drc_out;ls *.nc | ncremap -P eamxx --prm_opt=time,lwband,swband,ilev,lev,plev,cosp_tau,cosp_cth,cosp_prs,dim2,ncol --map=${map} --drc_out=${drc_out}/rgr
+
+exit
+
diff --git a/examples/e3sm_diags_for_eamxx/run_e3sm_diags_1996.py b/examples/e3sm_diags_for_eamxx/run_e3sm_diags_1996.py
new file mode 100644
index 000000000..3a44ebd06
--- /dev/null
+++ b/examples/e3sm_diags_for_eamxx/run_e3sm_diags_1996.py
@@ -0,0 +1,34 @@
+import os
+from e3sm_diags.parameter.core_parameter import CoreParameter
+from e3sm_diags.run import runner
+
+param = CoreParameter()
+
+#param.reference_data_path = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology'
+#param.test_data_path = '/global/cfs/cdirs/e3sm/zhang40/e3sm_diags_for_EAMxx/data/Cess'
+#param.reference_data_path =
'/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology' +param.reference_data_path = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series' +param.test_data_path = '/global/cfs/cdirs/e3sm/chengzhu/eamxx/post/data/rgr' +param.test_name = 'eamxx_decadal' +param.seasons = ["ANN"] +#param.save_netcdf = True + +param.ref_timeseries_input = True +# Years to slice the ref data, base this off the years in the filenames. +param.ref_start_yr = "1996" +param.ref_end_yr = "1996" + +prefix = '/global/cfs/cdirs/e3sm/www/zhang40/tests/eamxx' +param.results_dir = os.path.join(prefix, 'eamxx_decadal_1996_1212_edv3') + +runner.sets_to_run = ["lat_lon", + "zonal_mean_xy", + "zonal_mean_2d", + "zonal_mean_2d_stratosphere", + "polar", + "cosp_histogram", + "meridional_mean_2d", + "annual_cycle_zonal_mean",] + +runner.run_diags([param]) + diff --git a/examples/e3sm_diags_for_eamxx/run_e3sm_diags_climo.py b/examples/e3sm_diags_for_eamxx/run_e3sm_diags_climo.py new file mode 100644 index 000000000..f8a84fe1d --- /dev/null +++ b/examples/e3sm_diags_for_eamxx/run_e3sm_diags_climo.py @@ -0,0 +1,29 @@ +import os +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.run import runner + +param = CoreParameter() + +#param.reference_data_path = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology' +#param.test_data_path = '/global/cfs/cdirs/e3sm/zhang40/e3sm_diags_for_EAMxx/data/Cess' +param.reference_data_path = '/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/climatology' +param.test_data_path = '/global/cfs/cdirs/e3sm/chengzhu/eamxx/post/data/rgr' +param.test_name = 'eamxx_decadal' +param.seasons = ["ANN"] +#param.save_netcdf = True + +prefix = '/global/cfs/cdirs/e3sm/www/zhang40/tests/eamxx' +param.results_dir = os.path.join(prefix, 'eamxx_decadal_1212') + +runner.sets_to_run = ["lat_lon", + "zonal_mean_xy", + "zonal_mean_2d", + "zonal_mean_2d_stratosphere", + "polar", + "cosp_histogram", + "meridional_mean_2d", + 
"annual_cycle_zonal_mean", + ] + +runner.run_diags([param]) + diff --git a/pyproject.toml b/pyproject.toml index 26a0803df..427c5483d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ # This package is not available on PyPI. # "cartopy_offlinedata", "cf-units", - "dask", + "dask <2024.12.0", "esmpy >=8.4.0", "lxml", "mache >=0.15.0", diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py index 2c89c92cb..030f753bb 100644 --- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py +++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py @@ -561,6 +561,63 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest @pytest.mark.xfail( reason="Need to figure out why to create dummy incorrect time scalar variable with Xarray." ) + def test_returns_climo_dataset_with_derived_variable(self): + # We will derive the "PRECT" variable using the "pr" variable. + ds_pr = xr.Dataset( + coords={ + **spatial_coords, + "time": xr.DataArray( + dims="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + ), + ], + dtype=object, + ), + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + data_vars={ + **spatial_bounds, + "pr": xr.DataArray( + xr.DataArray( + data=np.array( + [ + [[1.0, 1.0], [1.0, 1.0]], + ] + ), + dims=["time", "lat", "lon"], + attrs={"units": "mm/s"}, + ) + ), + }, + ) + + parameter = _create_parameter_object( + "ref", "climo", self.data_path, "2000", "2001" + ) + parameter.ref_file = "pr_200001_200112.nc" + ds_pr.to_netcdf(f"{self.data_path}/{parameter.ref_file}") + + ds = Dataset(parameter, data_type="ref") + + result = ds.get_climo_dataset("PRECT", season="ANN") + expected = ds_pr.copy() + expected = expected.squeeze(dim="time").drop_vars("time") + expected["PRECT"] = expected["pr"] * 3600 * 24 + expected["PRECT"].attrs["units"] = "mm/day" + 
expected = expected.drop_vars("pr") + + xr.testing.assert_identical(result, expected) + + @pytest.mark.xfail def test_returns_climo_dataset_using_derived_var_directly_from_dataset_and_replaces_scalar_time_var( self, ): @@ -750,6 +807,7 @@ def test_returns_climo_dataset_using_climo_of_time_series_files(self): # Set all of the correct attributes. expected = expected.assign(**spatial_coords) # type: ignore expected = expected.drop_dims("time") + expected = expected.bounds.add_missing_bounds(axes=["X", "Y"]) xr.testing.assert_identical(result, expected)