diff --git a/esmvaltool/diag_scripts/validation.py b/esmvaltool/diag_scripts/validation.py index d72cc5e1ad..52c8d1a561 100644 --- a/esmvaltool/diag_scripts/validation.py +++ b/esmvaltool/diag_scripts/validation.py @@ -29,27 +29,11 @@ def _get_provenance_record(cfg, plot_file, caption, loc): """Create a provenance record describing the diagnostic data and plot.""" - all_input_files = [ - k for k in cfg["input_data"].keys() if k.endswith(".nc") - ] - if "_vs_" in plot_file: - model_1 = plot_file.split("_vs_")[0].split("_")[-1] - if plot_file.endswith(".png"): - model_2 = plot_file.split("_vs_")[1].strip(".png") - elif plot_file.endswith(".nc"): - model_2 = plot_file.split("_vs_")[1].strip(".nc") - ancestor_1 = [ - k for k in all_input_files if model_1 in os.path.basename(k) - ][0] - ancestor_2 = [ - k for k in all_input_files if model_2 in os.path.basename(k) - ][0] - ancestor_files = [ancestor_1, ancestor_2] - else: - model = os.path.basename(plot_file).split("_")[0] - ancestor_files = [ - k for k in all_input_files if model in os.path.basename(k) - ] + ancestor_files = [] + for dataset in cfg['input_data'].values(): + if (dataset['alias'] in plot_file and + dataset['short_name'] in plot_file): + ancestor_files.append(dataset['filename']) record = { 'caption': caption, 'statistics': ['mean'], @@ -72,9 +56,9 @@ def _get_provenance_record(cfg, plot_file, caption, loc): def plot_contour(cube, cfg, plt_title, file_name): """Plot a contour with iris.quickplot (qplot).""" if len(cube.shape) == 2: - qplt.contourf(cube, cmap='RdYlBu_r', bbox_inches='tight') + qplt.contourf(cube, cmap='RdYlBu_r') else: - qplt.contourf(cube[0], cmap='RdYlBu_r', bbox_inches='tight') + qplt.contourf(cube[0], cmap='RdYlBu_r') plt.title(plt_title) plt.gca().coastlines() plt.tight_layout() @@ -138,7 +122,10 @@ def plot_latlon_cubes(cube_1, # plot each cube var = data_names.split('_')[0] if not obs_name: - cube_names = [data_names.split('_')[1], data_names.split('_')[3]] + cube_names = [ + data_names.replace(f'{var}_', '').split('_vs_')[i] for i in + range(2) + ] for cube, cube_name in zip(cubes, cube_names): if not season: plot_file_path = os.path.join( @@ -179,23 +166,40 @@ def plot_zonal_cubes(cube_1, cube_2, cfg, plot_data): # xcoordinate: latotude or longitude (str) data_names, xcoordinate, period = plot_data var = data_names.split('_')[0] - cube_names = [data_names.split('_')[1], data_names.split('_')[3]] + cube_names = data_names.replace(var + '_', '').split('_vs_') lat_points = cube_1.coord(xcoordinate).points plt.plot(lat_points, cube_1.data, label=cube_names[0]) plt.plot(lat_points, cube_2.data, label=cube_names[1]) + plt.title(f'Annual Climatology of {var}' if period == 'alltime' + else f'{period} of {var}') if xcoordinate == 'latitude': - plt.title(period + ' Zonal Mean for ' + var + ' ' + data_names) + axis = plt.gca() + axis.set_xticks([-60, -30, 0, 30, 60], + labels=['60\N{DEGREE SIGN} S', + '30\N{DEGREE SIGN} S', + '0\N{DEGREE SIGN}', + '30\N{DEGREE SIGN} N', + '60\N{DEGREE SIGN} N']) elif xcoordinate == 'longitude': - plt.title(period + ' Meridional Mean for ' + var + ' ' + data_names) + axis = plt.gca() + axis.set_xticks([0, 60, 120, 180, 240, 300, 360], + labels=['0\N{DEGREE SIGN} E', + '60\N{DEGREE SIGN} E', + '120\N{DEGREE SIGN} E', + '180\N{DEGREE SIGN} E', + '240\N{DEGREE SIGN} E', + '300\N{DEGREE SIGN} E', + '0\N{DEGREE SIGN} E']) plt.xlabel(xcoordinate + ' (deg)') - plt.ylabel(var) + plt.ylabel(f'{var} [{str(cube_1.units)}]') plt.tight_layout() plt.grid() plt.legend() + png_name = f'{xcoordinate}_{period}_{data_names}.png' if xcoordinate == 'latitude': - png_name = 'Zonal_Mean_' + xcoordinate + '_' + data_names + '.png' + png_name = 'Zonal_Mean_' + png_name elif xcoordinate == 'longitude': - png_name = 'Merid_Mean_' + xcoordinate + '_' + data_names + '.png' + png_name = 'Merid_Mean_' + png_name plot_file_path = os.path.join(cfg['plot_dir'], period, png_name) plt.savefig(plot_file_path) save_plotted_cubes( @@ -252,13 +256,13 @@ def coordinate_collapse(data_set, cfg): if 'mask_threshold' in cfg: thr = cfg['mask_threshold'] data_set.data = np.ma.masked_array(data_set.data, - mask=(mask_cube.data > thr)) + mask=mask_cube.data > thr) else: logger.warning('Could not find masking threshold') logger.warning('Please specify it if needed') logger.warning('Masking on 0-values = True (masked value)') data_set.data = np.ma.masked_array(data_set.data, - mask=(mask_cube.data == 0)) + mask=mask_cube.data == 0) # if zonal mean on LON if analysis_type == 'zonal_mean':