diff --git a/auxiliary_tools/template_cdat_regression_test.ipynb b/auxiliary_tools/template_cdat_regression_test.ipynb
deleted file mode 100644
index 8b4d00bd1..000000000
--- a/auxiliary_tools/template_cdat_regression_test.ipynb
+++ /dev/null
@@ -1,1333 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# CDAT Migration Regression Test (FY24)\n",
- "\n",
- "This notebook is used to perform regression testing between the development and\n",
- "production versions of a diagnostic set.\n",
- "\n",
- "## How it works\n",
- "\n",
- "It compares the relative differences (%) between two sets of `.json` files in two\n",
- "separate directories, one for the refactored code and the other for the `main` branch.\n",
- "\n",
- "It will display metrics values with relative differences >= 2%. Relative differences are used instead of absolute differences because:\n",
- "\n",
- "- Relative differences are in percentages, which shows the scale of the differences.\n",
- "- Absolute differences are just a raw number that doesn't factor in\n",
- " floating point size (e.g., 100.00 vs. 0.0001), which can be misleading.\n",
- "\n",
- "## How to use\n",
- "\n",
- "PREREQUISITE: The diagnostic set's metrics stored in `.json` files in two directories\n",
- "(dev and `main` branches).\n",
- "\n",
- "1. Make a copy of this notebook.\n",
- "2. Run `mamba create -n cdat_regression_test -y -c conda-forge \"python<3.12\" pandas matplotlib-base ipykernel`\n",
- "3. Run `mamba activate cdat_regression_test`\n",
- "4. Update `DEV_PATH` and `PROD_PATH` in the copy of your notebook.\n",
- "5. Run all cells IN ORDER.\n",
- "6. Review results for any outstanding differences (>= 2%).\n",
- " - Debug these differences (e.g., bug in metrics functions, incorrect variable references, etc.)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Setup Code\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import glob\n",
- "import math\n",
- "from typing import List\n",
- "\n",
- "import pandas as pd\n",
- "\n",
- "# TODO: Update DEV_RESULTS and PROD_RESULTS to your diagnostic sets.\n",
- "DEV_PATH = \"/global/cfs/cdirs/e3sm/www/vo13/examples_658/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model\"\n",
- "PROD_PATH = \"/global/cfs/cdirs/e3sm/www/vo13/examples/ex1_modTS_vs_modTS_3years/lat_lon/model_vs_model\"\n",
- "\n",
- "DEV_GLOB = sorted(glob.glob(DEV_PATH + \"/*.json\"))\n",
- "PROD_GLOB = sorted(glob.glob(PROD_PATH + \"/*.json\"))\n",
- "\n",
- "# The names of the columns that store percentage difference values.\n",
- "PERCENTAGE_COLUMNS = [\n",
- " \"test DIFF (%)\",\n",
- " \"ref DIFF (%)\",\n",
- " \"test_regrid DIFF (%)\",\n",
- " \"ref_regrid DIFF (%)\",\n",
- " \"diff DIFF (%)\",\n",
- " \"misc DIFF (%)\",\n",
- "]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Core Functions\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_metrics(filepaths: List[str]) -> pd.DataFrame:\n",
- " \"\"\"Get the metrics using a glob of `.json` metric files in a directory.\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " filepaths : List[str]\n",
- " The filepaths for metrics `.json` files.\n",
- "\n",
- " Returns\n",
- " -------\n",
- " pd.DataFrame\n",
- " The DataFrame containing the metrics for all of the variables in\n",
- " the results directory.\n",
- " \"\"\"\n",
- " metrics = []\n",
- "\n",
- " for filepath in filepaths:\n",
- " df = pd.read_json(filepath)\n",
- "\n",
- " filename = filepath.split(\"/\")[-1]\n",
- " var_key = filename.split(\"-\")[1]\n",
- "\n",
- " # Add the variable key to the MultiIndex and update the index\n",
- " # before stacking to make the DataFrame easier to parse.\n",
- " multiindex = pd.MultiIndex.from_product([[var_key], [*df.index]])\n",
- " df = df.set_index(multiindex)\n",
- " df.stack()\n",
- "\n",
- " metrics.append(df)\n",
- "\n",
- " df_final = pd.concat(metrics)\n",
- "\n",
- " # Reorder columns and drop \"unit\" column (string dtype breaks Pandas\n",
- " # arithmetic).\n",
- " df_final = df_final[[\"test\", \"ref\", \"test_regrid\", \"ref_regrid\", \"diff\", \"misc\"]]\n",
- "\n",
- " return df_final\n",
- "\n",
- "\n",
- "def get_rel_diffs(df_actual: pd.DataFrame, df_reference: pd.DataFrame) -> pd.DataFrame:\n",
- " \"\"\"Get the relative differences between two DataFrames.\n",
- "\n",
- " Formula: abs(actual - reference) / abs(actual)\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " df_actual : pd.DataFrame\n",
- " The first DataFrame representing \"actual\" results (dev branch).\n",
- " df_reference : pd.DataFrame\n",
- " The second DataFrame representing \"reference\" results (main branch).\n",
- "\n",
- " Returns\n",
- " -------\n",
- " pd.DataFrame\n",
- " The DataFrame containing absolute and relative differences between\n",
- " the metrics DataFrames.\n",
- " \"\"\"\n",
- " df_diff = abs(df_actual - df_reference) / abs(df_actual)\n",
- " df_diff = df_diff.add_suffix(\" DIFF (%)\")\n",
- "\n",
- " return df_diff\n",
- "\n",
- "\n",
- "def sort_columns(df: pd.DataFrame) -> pd.DataFrame:\n",
- " \"\"\"Sorts the order of the columns for the final DataFrame output.\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " df : pd.DataFrame\n",
- " The final DataFrame output.\n",
- "\n",
- " Returns\n",
- " -------\n",
- " pd.DataFrame\n",
- " The final DataFrame output with sorted columns.\n",
- " \"\"\"\n",
- " columns = [\n",
- " \"test_dev\",\n",
- " \"test_prod\",\n",
- " \"test DIFF (%)\",\n",
- " \"ref_dev\",\n",
- " \"ref_prod\",\n",
- " \"ref DIFF (%)\",\n",
- " \"test_regrid_dev\",\n",
- " \"test_regrid_prod\",\n",
- " \"test_regrid DIFF (%)\",\n",
- " \"ref_regrid_dev\",\n",
- " \"ref_regrid_prod\",\n",
- " \"ref_regrid DIFF (%)\",\n",
- " \"diff_dev\",\n",
- " \"diff_prod\",\n",
- " \"diff DIFF (%)\",\n",
- " \"misc_dev\",\n",
- " \"misc_prod\",\n",
- " \"misc DIFF (%)\",\n",
- " ]\n",
- "\n",
- " df_new = df.copy()\n",
- " df_new = df_new[columns]\n",
- "\n",
- " return df_new\n",
- "\n",
- "\n",
- "def update_diffs_to_pct(df: pd.DataFrame):\n",
- " \"\"\"Update relative diff columns from float to string percentage.\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " df : pd.DataFrame\n",
- " The final DataFrame containing metrics and diffs (floats).\n",
- "\n",
- " Returns\n",
- " -------\n",
- " pd.DataFrame\n",
- " The final DataFrame containing metrics and diffs (str percentage).\n",
- " \"\"\"\n",
- " df_new = df.copy()\n",
- " df_new[PERCENTAGE_COLUMNS] = df_new[PERCENTAGE_COLUMNS].map(\n",
- " lambda x: \"{0:.2f}%\".format(x * 100) if not math.isnan(x) else x\n",
- " )\n",
- "\n",
- " return df_new"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 1. Get the DataFrame containing development and production metrics.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "df_metrics_dev = get_metrics(DEV_GLOB)\n",
- "df_metrics_prod = get_metrics(PROD_GLOB)\n",
- "df_metrics_all = pd.concat(\n",
- " [df_metrics_dev.add_suffix(\"_dev\"), df_metrics_prod.add_suffix(\"_prod\")],\n",
- " axis=1,\n",
- " join=\"outer\",\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2. Get DataFrame for differences >= 2%.\n",
- "\n",
- "- Get the relative differences for all metrics\n",
- "- Filter down metrics to those with differences >= 2%\n",
- " - If all cells in a row are NaN (< 2%), the entire row is dropped to make the results easier to parse.\n",
- " - Any remaining NaN cells are below < 2% difference and **should be ignored**.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "df_metrics_diffs = get_rel_diffs(df_metrics_dev, df_metrics_prod)\n",
- "df_metrics_diffs_thres = df_metrics_diffs[df_metrics_diffs >= 0.02]\n",
- "df_metrics_diffs_thres = df_metrics_diffs_thres.dropna(\n",
- " axis=0, how=\"all\", ignore_index=False\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 3. Combine both DataFrames to get the final result.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "df_final = df_metrics_diffs_thres.join(df_metrics_all)\n",
- "df_final = sort_columns(df_final)\n",
- "df_final = update_diffs_to_pct(df_final)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 4. Display final DataFrame and review results.\n",
- "\n",
- "- Red cells are differences >= 2%\n",
- "- `nan` cells are differences < 2% and **should be ignored**\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " var_key | \n",
- " metric | \n",
- " test_dev | \n",
- " test_prod | \n",
- " test DIFF (%) | \n",
- " ref_dev | \n",
- " ref_prod | \n",
- " ref DIFF (%) | \n",
- " test_regrid_dev | \n",
- " test_regrid_prod | \n",
- " test_regrid DIFF (%) | \n",
- " ref_regrid_dev | \n",
- " ref_regrid_prod | \n",
- " ref_regrid DIFF (%) | \n",
- " diff_dev | \n",
- " diff_prod | \n",
- " diff DIFF (%) | \n",
- " misc_dev | \n",
- " misc_prod | \n",
- " misc DIFF (%) | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " FLUT | \n",
- " max | \n",
- " 299.911864 | \n",
- " 299.355074 | \n",
- " nan | \n",
- " 300.162128 | \n",
- " 299.776167 | \n",
- " nan | \n",
- " 299.911864 | \n",
- " 299.355074 | \n",
- " nan | \n",
- " 300.162128 | \n",
- " 299.776167 | \n",
- " nan | \n",
- " 9.492359 | \n",
- " 9.788809 | \n",
- " 3.12% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " FLUT | \n",
- " min | \n",
- " 124.610884 | \n",
- " 125.987072 | \n",
- " nan | \n",
- " 122.878196 | \n",
- " 124.148986 | \n",
- " nan | \n",
- " 124.610884 | \n",
- " 125.987072 | \n",
- " nan | \n",
- " 122.878196 | \n",
- " 124.148986 | \n",
- " nan | \n",
- " -15.505809 | \n",
- " -17.032325 | \n",
- " 9.84% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " FSNS | \n",
- " max | \n",
- " 269.789702 | \n",
- " 269.798166 | \n",
- " nan | \n",
- " 272.722362 | \n",
- " 272.184917 | \n",
- " nan | \n",
- " 269.789702 | \n",
- " 269.798166 | \n",
- " nan | \n",
- " 272.722362 | \n",
- " 272.184917 | \n",
- " nan | \n",
- " 20.647929 | \n",
- " 24.859852 | \n",
- " 20.40% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " FSNS | \n",
- " min | \n",
- " 16.897423 | \n",
- " 17.760889 | \n",
- " 5.11% | \n",
- " 16.710134 | \n",
- " 16.237061 | \n",
- " 2.83% | \n",
- " 16.897423 | \n",
- " 17.760889 | \n",
- " 5.11% | \n",
- " 16.710134 | \n",
- " 16.237061 | \n",
- " 2.83% | \n",
- " -28.822277 | \n",
- " -28.324921 | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " FSNTOA | \n",
- " max | \n",
- " 360.624327 | \n",
- " 360.209193 | \n",
- " nan | \n",
- " 362.188816 | \n",
- " 361.778529 | \n",
- " nan | \n",
- " 360.624327 | \n",
- " 360.209193 | \n",
- " nan | \n",
- " 362.188816 | \n",
- " 361.778529 | \n",
- " nan | \n",
- " 18.602276 | \n",
- " 22.624266 | \n",
- " 21.62% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " FSNTOA | \n",
- " mean | \n",
- " 239.859777 | \n",
- " 240.001860 | \n",
- " nan | \n",
- " 241.439641 | \n",
- " 241.544384 | \n",
- " nan | \n",
- " 239.859777 | \n",
- " 240.001860 | \n",
- " nan | \n",
- " 241.439641 | \n",
- " 241.544384 | \n",
- " nan | \n",
- " -1.579864 | \n",
- " -1.542524 | \n",
- " 2.36% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " FSNTOA | \n",
- " min | \n",
- " 44.907041 | \n",
- " 48.256818 | \n",
- " 7.46% | \n",
- " 47.223502 | \n",
- " 50.339608 | \n",
- " 6.60% | \n",
- " 44.907041 | \n",
- " 48.256818 | \n",
- " 7.46% | \n",
- " 47.223502 | \n",
- " 50.339608 | \n",
- " 6.60% | \n",
- " -23.576184 | \n",
- " -23.171864 | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " LHFLX | \n",
- " max | \n",
- " 282.280453 | \n",
- " 289.079940 | \n",
- " 2.41% | \n",
- " 275.792933 | \n",
- " 276.297281 | \n",
- " nan | \n",
- " 282.280453 | \n",
- " 289.079940 | \n",
- " 2.41% | \n",
- " 275.792933 | \n",
- " 276.297281 | \n",
- " nan | \n",
- " 47.535503 | \n",
- " 53.168924 | \n",
- " 11.85% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " LHFLX | \n",
- " mean | \n",
- " 88.379609 | \n",
- " 88.470270 | \n",
- " nan | \n",
- " 88.969550 | \n",
- " 88.976266 | \n",
- " nan | \n",
- " 88.379609 | \n",
- " 88.470270 | \n",
- " nan | \n",
- " 88.969550 | \n",
- " 88.976266 | \n",
- " nan | \n",
- " -0.589942 | \n",
- " -0.505996 | \n",
- " 14.23% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " LHFLX | \n",
- " min | \n",
- " -0.878371 | \n",
- " -0.549248 | \n",
- " 37.47% | \n",
- " -1.176561 | \n",
- " -0.946110 | \n",
- " 19.59% | \n",
- " -0.878371 | \n",
- " -0.549248 | \n",
- " 37.47% | \n",
- " -1.176561 | \n",
- " -0.946110 | \n",
- " 19.59% | \n",
- " -34.375924 | \n",
- " -33.902769 | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " LWCF | \n",
- " max | \n",
- " 78.493653 | \n",
- " 77.473220 | \n",
- " nan | \n",
- " 86.121959 | \n",
- " 84.993825 | \n",
- " nan | \n",
- " 78.493653 | \n",
- " 77.473220 | \n",
- " nan | \n",
- " 86.121959 | \n",
- " 84.993825 | \n",
- " nan | \n",
- " 9.616057 | \n",
- " 10.796104 | \n",
- " 12.27% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " LWCF | \n",
- " mean | \n",
- " 24.373224 | \n",
- " 24.370539 | \n",
- " nan | \n",
- " 24.406697 | \n",
- " 24.391579 | \n",
- " nan | \n",
- " 24.373224 | \n",
- " 24.370539 | \n",
- " nan | \n",
- " 24.406697 | \n",
- " 24.391579 | \n",
- " nan | \n",
- " -0.033473 | \n",
- " -0.021040 | \n",
- " 37.14% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " LWCF | \n",
- " min | \n",
- " -0.667812 | \n",
- " -0.617107 | \n",
- " 7.59% | \n",
- " -1.360010 | \n",
- " -1.181787 | \n",
- " 13.10% | \n",
- " -0.667812 | \n",
- " -0.617107 | \n",
- " 7.59% | \n",
- " -1.360010 | \n",
- " -1.181787 | \n",
- " 13.10% | \n",
- " -10.574643 | \n",
- " -10.145188 | \n",
- " 4.06% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " NETCF | \n",
- " max | \n",
- " 13.224604 | \n",
- " 12.621825 | \n",
- " 4.56% | \n",
- " 13.715438 | \n",
- " 13.232716 | \n",
- " 3.52% | \n",
- " 13.224604 | \n",
- " 12.621825 | \n",
- " 4.56% | \n",
- " 13.715438 | \n",
- " 13.232716 | \n",
- " 3.52% | \n",
- " 10.899344 | \n",
- " 10.284825 | \n",
- " 5.64% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " NETCF | \n",
- " min | \n",
- " -66.633044 | \n",
- " -66.008633 | \n",
- " nan | \n",
- " -64.832041 | \n",
- " -67.398047 | \n",
- " 3.96% | \n",
- " -66.633044 | \n",
- " -66.008633 | \n",
- " nan | \n",
- " -64.832041 | \n",
- " -67.398047 | \n",
- " 3.96% | \n",
- " -17.923932 | \n",
- " -17.940099 | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " NET_FLUX_SRF | \n",
- " max | \n",
- " 155.691338 | \n",
- " 156.424180 | \n",
- " nan | \n",
- " 166.556120 | \n",
- " 166.506173 | \n",
- " nan | \n",
- " 155.691338 | \n",
- " 156.424180 | \n",
- " nan | \n",
- " 166.556120 | \n",
- " 166.506173 | \n",
- " nan | \n",
- " 59.819449 | \n",
- " 61.672824 | \n",
- " 3.10% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " NET_FLUX_SRF | \n",
- " mean | \n",
- " 0.394016 | \n",
- " 0.516330 | \n",
- " 31.04% | \n",
- " -0.068186 | \n",
- " 0.068584 | \n",
- " 200.58% | \n",
- " 0.394016 | \n",
- " 0.516330 | \n",
- " 31.04% | \n",
- " -0.068186 | \n",
- " 0.068584 | \n",
- " 200.58% | \n",
- " 0.462202 | \n",
- " 0.447746 | \n",
- " 3.13% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " NET_FLUX_SRF | \n",
- " min | \n",
- " -284.505205 | \n",
- " -299.505024 | \n",
- " 5.27% | \n",
- " -280.893287 | \n",
- " -290.202934 | \n",
- " 3.31% | \n",
- " -284.505205 | \n",
- " -299.505024 | \n",
- " 5.27% | \n",
- " -280.893287 | \n",
- " -290.202934 | \n",
- " 3.31% | \n",
- " -75.857589 | \n",
- " -85.852089 | \n",
- " 13.18% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " PRECT | \n",
- " max | \n",
- " 17.289951 | \n",
- " 17.071276 | \n",
- " nan | \n",
- " 20.264862 | \n",
- " 20.138274 | \n",
- " nan | \n",
- " 17.289951 | \n",
- " 17.071276 | \n",
- " nan | \n",
- " 20.264862 | \n",
- " 20.138274 | \n",
- " nan | \n",
- " 2.344111 | \n",
- " 2.406625 | \n",
- " 2.67% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " PRECT | \n",
- " mean | \n",
- " 3.053802 | \n",
- " 3.056760 | \n",
- " nan | \n",
- " 3.074885 | \n",
- " 3.074978 | \n",
- " nan | \n",
- " 3.053802 | \n",
- " 3.056760 | \n",
- " nan | \n",
- " 3.074885 | \n",
- " 3.074978 | \n",
- " nan | \n",
- " -0.021083 | \n",
- " -0.018218 | \n",
- " 13.59% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " PSL | \n",
- " min | \n",
- " 970.981710 | \n",
- " 971.390765 | \n",
- " nan | \n",
- " 973.198437 | \n",
- " 973.235326 | \n",
- " nan | \n",
- " 970.981710 | \n",
- " 971.390765 | \n",
- " nan | \n",
- " 973.198437 | \n",
- " 973.235326 | \n",
- " nan | \n",
- " -6.328677 | \n",
- " -6.104610 | \n",
- " 3.54% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " PSL | \n",
- " rmse | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " 1.042884 | \n",
- " 0.979981 | \n",
- " 6.03% | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " RESTOM | \n",
- " max | \n",
- " 84.295502 | \n",
- " 83.821906 | \n",
- " nan | \n",
- " 87.707944 | \n",
- " 87.451262 | \n",
- " nan | \n",
- " 84.295502 | \n",
- " 83.821906 | \n",
- " nan | \n",
- " 87.707944 | \n",
- " 87.451262 | \n",
- " nan | \n",
- " 17.396283 | \n",
- " 21.423616 | \n",
- " 23.15% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " RESTOM | \n",
- " mean | \n",
- " 0.481549 | \n",
- " 0.656560 | \n",
- " 36.34% | \n",
- " 0.018041 | \n",
- " 0.162984 | \n",
- " 803.40% | \n",
- " 0.481549 | \n",
- " 0.656560 | \n",
- " 36.34% | \n",
- " 0.018041 | \n",
- " 0.162984 | \n",
- " 803.40% | \n",
- " 0.463508 | \n",
- " 0.493576 | \n",
- " 6.49% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " RESTOM | \n",
- " min | \n",
- " -127.667181 | \n",
- " -129.014673 | \n",
- " nan | \n",
- " -127.417586 | \n",
- " -128.673508 | \n",
- " nan | \n",
- " -127.667181 | \n",
- " -129.014673 | \n",
- " nan | \n",
- " -127.417586 | \n",
- " -128.673508 | \n",
- " nan | \n",
- " -15.226249 | \n",
- " -14.869614 | \n",
- " 2.34% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 25 | \n",
- " SHFLX | \n",
- " max | \n",
- " 114.036895 | \n",
- " 112.859646 | \n",
- " nan | \n",
- " 116.870038 | \n",
- " 116.432591 | \n",
- " nan | \n",
- " 114.036895 | \n",
- " 112.859646 | \n",
- " nan | \n",
- " 116.870038 | \n",
- " 116.432591 | \n",
- " nan | \n",
- " 28.320656 | \n",
- " 27.556755 | \n",
- " 2.70% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 26 | \n",
- " SHFLX | \n",
- " min | \n",
- " -88.650312 | \n",
- " -88.386947 | \n",
- " nan | \n",
- " -85.809438 | \n",
- " -85.480377 | \n",
- " nan | \n",
- " -88.650312 | \n",
- " -88.386947 | \n",
- " nan | \n",
- " -85.809438 | \n",
- " -85.480377 | \n",
- " nan | \n",
- " -27.776625 | \n",
- " -28.363053 | \n",
- " 2.11% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 27 | \n",
- " SST | \n",
- " min | \n",
- " -1.788055 | \n",
- " -1.788055 | \n",
- " nan | \n",
- " -1.676941 | \n",
- " -1.676941 | \n",
- " nan | \n",
- " -1.788055 | \n",
- " -1.788055 | \n",
- " nan | \n",
- " -1.676941 | \n",
- " -1.676941 | \n",
- " nan | \n",
- " -4.513070 | \n",
- " -2.993272 | \n",
- " 33.68% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 28 | \n",
- " SWCF | \n",
- " max | \n",
- " -0.518025 | \n",
- " -0.536844 | \n",
- " 3.63% | \n",
- " -0.311639 | \n",
- " -0.331616 | \n",
- " 6.41% | \n",
- " -0.518025 | \n",
- " -0.536844 | \n",
- " 3.63% | \n",
- " -0.311639 | \n",
- " -0.331616 | \n",
- " 6.41% | \n",
- " 11.668939 | \n",
- " 12.087077 | \n",
- " 3.58% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 29 | \n",
- " SWCF | \n",
- " min | \n",
- " -123.625017 | \n",
- " -122.042043 | \n",
- " nan | \n",
- " -131.053537 | \n",
- " -130.430161 | \n",
- " nan | \n",
- " -123.625017 | \n",
- " -122.042043 | \n",
- " nan | \n",
- " -131.053537 | \n",
- " -130.430161 | \n",
- " nan | \n",
- " -21.415249 | \n",
- " -20.808973 | \n",
- " 2.83% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 30 | \n",
- " TREFHT | \n",
- " max | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 4.981757 | \n",
- " 5.126185 | \n",
- " 2.90% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 31 | \n",
- " TREFHT | \n",
- " max | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 4.867855 | \n",
- " 5.126185 | \n",
- " 2.90% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 32 | \n",
- " TREFHT | \n",
- " max | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 4.981757 | \n",
- " 5.126185 | \n",
- " 5.31% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 33 | \n",
- " TREFHT | \n",
- " max | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 31.141508 | \n",
- " 31.058424 | \n",
- " nan | \n",
- " 29.819210 | \n",
- " 29.721868 | \n",
- " nan | \n",
- " 4.867855 | \n",
- " 5.126185 | \n",
- " 5.31% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 34 | \n",
- " TREFHT | \n",
- " mean | \n",
- " 14.769946 | \n",
- " 14.741707 | \n",
- " nan | \n",
- " 13.842013 | \n",
- " 13.800258 | \n",
- " nan | \n",
- " 14.769946 | \n",
- " 14.741707 | \n",
- " nan | \n",
- " 13.842013 | \n",
- " 13.800258 | \n",
- " nan | \n",
- " 0.927933 | \n",
- " 0.941449 | \n",
- " 2.28% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 35 | \n",
- " TREFHT | \n",
- " mean | \n",
- " 9.214224 | \n",
- " 9.114572 | \n",
- " nan | \n",
- " 8.083349 | \n",
- " 7.957917 | \n",
- " nan | \n",
- " 9.214224 | \n",
- " 9.114572 | \n",
- " nan | \n",
- " 8.083349 | \n",
- " 7.957917 | \n",
- " nan | \n",
- " 1.130876 | \n",
- " 1.156655 | \n",
- " 2.28% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 36 | \n",
- " TREFHT | \n",
- " min | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -0.681558 | \n",
- " -0.624371 | \n",
- " 8.39% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 37 | \n",
- " TREFHT | \n",
- " min | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -0.681558 | \n",
- " -0.624371 | \n",
- " 8.39% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 38 | \n",
- " TREFHT | \n",
- " min | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -0.681558 | \n",
- " -0.624371 | \n",
- " 8.39% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 39 | \n",
- " TREFHT | \n",
- " min | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -56.266677 | \n",
- " -55.623001 | \n",
- " nan | \n",
- " -58.159250 | \n",
- " -57.542053 | \n",
- " nan | \n",
- " -0.681558 | \n",
- " -0.624371 | \n",
- " 8.39% | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- "
\n",
- " \n",
- " 40 | \n",
- " TREFHT | \n",
- " rmse | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " 1.160718 | \n",
- " 1.179995 | \n",
- " 2.68% | \n",
- "
\n",
- " \n",
- " 41 | \n",
- " TREFHT | \n",
- " rmse | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " nan | \n",
- " 1.343169 | \n",
- " 1.379141 | \n",
- " 2.68% | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df_final.reset_index(names=[\"var_key\", \"metric\"]).style.map(\n",
- " lambda x: \"background-color : red\" if isinstance(x, str) else \"\",\n",
- " subset=pd.IndexSlice[:, PERCENTAGE_COLUMNS],\n",
- ")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "cdat_regression_test",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
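
For reference, the comparison workflow the deleted notebook implemented condenses to a few steps: read each branch's metrics `.json` files into one DataFrame, compute relative differences, and keep rows at or above the 2% threshold. The sketch below mirrors the notebook's logic rather than prescribing a replacement; the paths and the `read_metrics` helper name are placeholders, and `DataFrame.map` assumes pandas >= 2.1.

```python
import glob
import math

import pandas as pd

# Placeholder paths -- point these at the dev and `main` metrics directories.
DEV_PATH = "/path/to/dev/lat_lon/model_vs_model"
PROD_PATH = "/path/to/main/lat_lon/model_vs_model"

METRIC_COLS = ["test", "ref", "test_regrid", "ref_regrid", "diff", "misc"]


def read_metrics(path: str) -> pd.DataFrame:
    """Concatenate every metrics `.json` file in a directory into one DataFrame."""
    frames = []

    for filepath in sorted(glob.glob(path + "/*.json")):
        df = pd.read_json(filepath)

        # Index rows by (variable, metric) so dev and prod align when subtracted.
        var_key = filepath.split("/")[-1].split("-")[1]
        df = df.set_index(pd.MultiIndex.from_product([[var_key], [*df.index]]))

        frames.append(df[METRIC_COLS])

    return pd.concat(frames)


df_dev = read_metrics(DEV_PATH)
df_prod = read_metrics(PROD_PATH)

# Relative difference: abs(dev - prod) / abs(dev), filtered at the 2% threshold.
df_rel = (abs(df_dev - df_prod) / abs(df_dev)).add_suffix(" DIFF (%)")
df_flagged = df_rel[df_rel >= 0.02].dropna(axis=0, how="all")

# Format surviving cells as percentage strings for review (pandas >= 2.1).
print(df_flagged.map(lambda x: f"{x * 100:.2f}%" if not math.isnan(x) else x))
```
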
diff --git a/e3sm_diags/driver/qbo_driver.py b/e3sm_diags/driver/qbo_driver.py
index 396456598..3379f4c46 100644
--- a/e3sm_diags/driver/qbo_driver.py
+++ b/e3sm_diags/driver/qbo_driver.py
@@ -8,10 +8,7 @@
import scipy.fftpack
import xarray as xr
import xcdat as xc
-<<<<<<< HEAD
from scipy.signal import detrend
-=======
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _get_output_dir, _write_to_netcdf
@@ -29,12 +26,9 @@
# The region will always be 5S5N
REGION = "5S5N"
-<<<<<<< HEAD
# Target power spectral vertical level for the wavelet diagnostic.
POW_SPEC_LEV = 20.0
-=======
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
class MetricsDict(TypedDict):
qbo: xr.DataArray
@@ -43,11 +37,8 @@ class MetricsDict(TypedDict):
period_new: np.ndarray
psd_x_new: np.ndarray
amplitude_new: np.ndarray
-<<<<<<< HEAD
wave_period: np.ndarray
wavelet: np.ndarray
-=======
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
name: str
@@ -105,7 +96,6 @@ def run_diag(parameter: QboParameter) -> QboParameter:
x_ref, ref_dict["period_new"]
)
-<<<<<<< HEAD
# Diagnostic 4: calculate the Wavelet
test_dict["wave_period"], test_dict["wavelet"] = _calculate_wavelet(
test_dict["qbo"]
@@ -114,8 +104,6 @@ def run_diag(parameter: QboParameter) -> QboParameter:
ref_dict["qbo"]
)
-=======
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
parameter.var_id = var_key
parameter.output_file = "qbo_diags"
parameter.main_title = (
@@ -135,15 +123,7 @@ def run_diag(parameter: QboParameter) -> QboParameter:
# Write the metrics to .json files.
test_dict["name"] = test_ds._get_test_name()
-<<<<<<< HEAD
ref_dict["name"] = ref_ds._get_ref_name()
-=======
-
- try:
- ref_dict["name"] = ref_ds._get_ref_name()
- except AttributeError:
- ref_dict["name"] = parameter.ref_name
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
_save_metrics_to_json(parameter, test_dict, "test") # type: ignore
_save_metrics_to_json(parameter, ref_dict, "ref") # type: ignore
@@ -172,11 +152,7 @@ def _save_metrics_to_json(
metrics_dict[key] = metrics_dict[key].tolist() # type: ignore
with open(abs_path, "w") as outfile:
-<<<<<<< HEAD
json.dump(metrics_dict, outfile, default=str)
-=======
- json.dump(metrics_dict, outfile)
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
logger.info("Metrics saved in: {}".format(abs_path))
@@ -379,7 +355,6 @@ def deseason(xraw):
# i.e., get the difference between this month's value and its "usual" value
x_deseasoned[month_index] = xraw[month_index] - xclim[month]
return x_deseasoned
-<<<<<<< HEAD
def _calculate_wavelet(var: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
@@ -444,5 +419,3 @@ def _get_psd_from_wavelet(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
psd = np.mean(np.square(np.abs(cwtmatr)), axis=1)
return (period, psd)
-=======
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py
index 1b851970e..71d7480cf 100644
--- a/e3sm_diags/driver/utils/dataset_xr.py
+++ b/e3sm_diags/driver/utils/dataset_xr.py
@@ -316,15 +316,6 @@ def _get_ref_name(self, default_name: str | None = None) -> str:
return self.parameter.ref_name
-<<<<<<< HEAD
-=======
- raise AttributeError(
- "Either `parameter.short_ref_name`, `parameter.reference_name`, or "
- "`parameter.ref_name` must be set to get the name and years attribute for "
- "reference datasets."
- )
-
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
def _get_global_attr_from_climo_dataset(
self, attr: str, season: ClimoFreq
) -> str | None:
@@ -444,7 +435,6 @@ def _get_climo_dataset(self, season: str) -> xr.Dataset:
)
ds = squeeze_time_dim(ds)
-<<<<<<< HEAD
ds = self._subset_vars_and_load(ds, self.var)
return ds
@@ -475,9 +465,6 @@ def _add_cf_attrs_to_z_axes(self, ds: xr.Dataset) -> xr.Dataset:
if axis_attr is None:
ds[dim].attrs["axis"] = "Z"
-=======
- ds = self._subset_vars_and_load(ds)
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
return ds
@@ -815,45 +802,6 @@ def _get_matching_climo_src_vars(
return None
- def _subset_vars_and_load(self, ds: xr.Dataset) -> xr.Dataset:
- """Subset for variables needed for processing and load into memory.
-
- Subsetting the dataset reduces its memory footprint. Loading is
- necessary because there seems to be an issue with `open_mfdataset()`
- and using the multiprocessing scheduler defined in e3sm_diags,
- resulting in timeouts and resource locking. To avoid this, we load the
- multi-file dataset into memory before performing downstream operations.
-
- Source: https://github.com/pydata/xarray/issues/3781
-
- Parameters
- ----------
- ds : xr.Dataset
- The dataset.
-
- Returns
- -------
- xr.Dataset
- The dataset subsetted and loaded into memory.
- """
- # slat and slon are lat lon pair for staggered FV grid included in
- # remapped files.
- if "slat" in ds.dims:
- ds = ds.drop_dims(["slat", "slon"])
-
- all_vars_keys = list(ds.data_vars.keys())
- hybrid_var_keys = set(list(sum(HYBRID_SIGMA_KEYS.values(), ())))
- keep_vars = [
- var
- for var in all_vars_keys
- if "bnd" in var or "bounds" in var or var in hybrid_var_keys
- ]
- ds = ds[[self.var] + keep_vars]
-
- ds.load(scheduler="sync")
-
- return ds
-
# --------------------------------------------------------------------------
# Time series related methods
# --------------------------------------------------------------------------
diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py
index b96a92f36..5351a9cb5 100644
--- a/e3sm_diags/parameter/core_parameter.py
+++ b/e3sm_diags/parameter/core_parameter.py
@@ -46,10 +46,6 @@
from e3sm_diags.driver.utils.dataset_xr import Dataset
-if TYPE_CHECKING:
- from e3sm_diags.driver.utils.dataset_xr import Dataset
-
-
class CoreParameter:
def __init__(self):
# File I/O
diff --git a/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py b/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py
deleted file mode 100644
index 765235095..000000000
--- a/e3sm_diags/plot/cartopy/aerosol_aeronet_plot.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import os
-
-import cartopy.crs as ccrs
-import matplotlib
-import numpy as np
-
-from e3sm_diags.driver.utils.general import get_output_dir
-from e3sm_diags.logger import custom_logger
-from e3sm_diags.metrics import mean
-from e3sm_diags.plot.cartopy.deprecated_lat_lon_plot import plot_panel
-
-matplotlib.use("Agg")
-import matplotlib.pyplot as plt # isort:skip # noqa: E402
-
-logger = custom_logger(__name__)
-
-plotTitle = {"fontsize": 11.5}
-plotSideTitle = {"fontsize": 9.5}
-
-
-def plot(test, test_site, ref_site, parameter):
- # Plot scatter plot
- # Position and sizes of subplot axes in page coordinates (0 to 1)
- # (left, bottom, width, height) in page coordinates
- panel = [
- (0.09, 0.40, 0.72, 0.30),
- (0.19, 0.2, 0.62, 0.30),
- ]
- # Border padding relative to subplot axes for saving individual panels
- # (left, bottom, right, top) in page coordinates
- border = (-0.06, -0.03, 0.13, 0.03)
-
- fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
- fig.suptitle(parameter.var_id, x=0.5, y=0.97)
- proj = ccrs.PlateCarree()
- max1 = test.max()
- min1 = test.min()
- mean1 = mean(test)
- # TODO: Replace this function call with `e3sm_diags.plot.utils._add_colormap()`.
- plot_panel(
- 0,
- fig,
- proj,
- test,
- parameter.contour_levels,
- parameter.test_colormap,
- (parameter.test_name_yrs, None, None),
- parameter,
- stats=(max1, mean1, min1),
- )
-
- ax = fig.add_axes(panel[1])
- ax.set_title(f"{parameter.var_id} from AERONET sites")
-
- # define 1:1 line, and x y axis limits
-
- if parameter.var_id == "AODVIS":
- x1 = np.arange(0.01, 3.0, 0.1)
- y1 = np.arange(0.01, 3.0, 0.1)
- plt.xlim(0.03, 1)
- plt.ylim(0.03, 1)
- else:
- x1 = np.arange(0.0001, 1.0, 0.01)
- y1 = np.arange(0.0001, 1.0, 0.01)
- plt.xlim(0.001, 0.3)
- plt.ylim(0.001, 0.3)
-
- plt.loglog(x1, y1, "-k", linewidth=0.5)
- plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5)
- plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5)
-
- corr = np.corrcoef(ref_site, test_site)
- xmean = np.mean(ref_site)
- ymean = np.mean(test_site)
- ax.text(
- 0.3,
- 0.9,
- f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}",
- horizontalalignment="right",
- verticalalignment="top",
- transform=ax.transAxes,
- )
-
- # axis ticks
- plt.tick_params(axis="both", which="major")
- plt.tick_params(axis="both", which="minor")
-
- # axis labels
- plt.xlabel(f"ref: {parameter.ref_name_yrs}")
- plt.ylabel(f"test: {parameter.test_name_yrs}")
-
- plt.loglog(ref_site, test_site, "kx", markersize=3.0, mfc="none")
-
- # legend
- plt.legend(frameon=False, prop={"size": 5})
-
- # TODO: This section can be refactored to use `plot.utils._save_plot()`.
- for f in parameter.output_format:
- f = f.lower().split(".")[-1]
- fnm = os.path.join(
- get_output_dir(parameter.current_set, parameter),
- f"{parameter.output_file}" + "." + f,
- )
- plt.savefig(fnm)
- logger.info(f"Plot saved in: {fnm}")
-
- for f in parameter.output_format_subplot:
- fnm = os.path.join(
- get_output_dir(parameter.current_set, parameter),
- parameter.output_file,
- )
- page = fig.get_size_inches()
- i = 0
- for p in panel:
- # Extent of subplot
- subpage = np.array(p).reshape(2, 2)
- subpage[1, :] = subpage[0, :] + subpage[1, :]
- subpage = subpage + np.array(border).reshape(2, 2)
- subpage = list(((subpage) * page).flatten()) # type: ignore
- extent = matplotlib.transforms.Bbox.from_extents(*subpage)
- # Save subplot
- fname = fnm + ".%i." % (i) + f
- plt.savefig(fname, bbox_inches=extent)
-
- orig_fnm = os.path.join(
- get_output_dir(parameter.current_set, parameter),
- parameter.output_file,
- )
- fname = orig_fnm + ".%i." % (i) + f
- logger.info(f"Sub-plot saved in: {fname}")
-
- i += 1
diff --git a/e3sm_diags/plot/cartopy/deprecated_lat_lon_plot.py b/e3sm_diags/plot/cartopy/deprecated_lat_lon_plot.py
deleted file mode 100644
index 4eaebcf80..000000000
--- a/e3sm_diags/plot/cartopy/deprecated_lat_lon_plot.py
+++ /dev/null
@@ -1,360 +0,0 @@
-"""
-WARNING: This module has been deprecated and replaced by
-`e3sm_diags.plot.lat_lon_plot.py`. This file is temporarily kept because
-`e3sm_diags.plot.cartopy.aerosol_aeronet_plot.plot` references the
-`plot_panel()` function. Once the aerosol_aeronet set is refactored, this
-file can be deleted.
-"""
-from __future__ import print_function
-
-import os
-
-import cartopy.crs as ccrs
-import cartopy.feature as cfeature
-import cdutil
-import matplotlib
-import numpy as np
-import numpy.ma as ma
-from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
-
-from e3sm_diags.derivations.default_regions import regions_specs
-from e3sm_diags.driver.utils.general import get_output_dir
-from e3sm_diags.logger import custom_logger
-from e3sm_diags.plot import get_colormap
-
-matplotlib.use("Agg")
-import matplotlib.colors as colors # isort:skip # noqa: E402
-import matplotlib.pyplot as plt # isort:skip # noqa: E402
-
-logger = custom_logger(__name__)
-
-plotTitle = {"fontsize": 11.5}
-plotSideTitle = {"fontsize": 9.5}
-
-# Position and sizes of subplot axes in page coordinates (0 to 1)
-panel = [
- (0.1691, 0.6810, 0.6465, 0.2258),
- (0.1691, 0.3961, 0.6465, 0.2258),
- (0.1691, 0.1112, 0.6465, 0.2258),
-]
-
-# Border padding relative to subplot axes for saving individual panels
-# (left, bottom, right, top) in page coordinates
-border = (-0.06, -0.03, 0.13, 0.03)
-
-
-def add_cyclic(var):
- lon = var.getLongitude()
- return var(longitude=(lon[0], lon[0] + 360.0, "coe"))
-
-
-def get_ax_size(fig, ax):
- bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
- width, height = bbox.width, bbox.height
- width *= fig.dpi
- height *= fig.dpi
- return width, height
-
-
-def determine_tick_step(degrees_covered):
- if degrees_covered > 180:
- return 60
- if degrees_covered > 60:
- return 30
- elif degrees_covered > 30:
- return 10
- elif degrees_covered > 20:
- return 5
- else:
- return 1
-
-
-def plot_panel( # noqa: C901
- n, fig, proj, var, clevels, cmap, title, parameters, stats=None
-):
- var = add_cyclic(var)
- lon = var.getLongitude()
- lat = var.getLatitude()
- var = ma.squeeze(var.asma())
-
- # Contour levels
- levels = None
- norm = None
- if len(clevels) > 0:
- levels = [-1.0e8] + clevels + [1.0e8]
- norm = colors.BoundaryNorm(boundaries=levels, ncolors=256)
-
- # ax.set_global()
- region_str = parameters.regions[0]
- region = regions_specs[region_str]
- global_domain = True
- full_lon = True
- if "domain" in region.keys(): # type: ignore
- # Get domain to plot
- domain = region["domain"] # type: ignore
- global_domain = False
- else:
- # Assume global domain
- domain = cdutil.region.domain(latitude=(-90.0, 90, "ccb"))
- kargs = domain.components()[0].kargs
- lon_west, lon_east, lat_south, lat_north = (0, 360, -90, 90)
- if "longitude" in kargs:
- full_lon = False
- lon_west, lon_east, _ = kargs["longitude"]
- # Note: cartopy has a problem with gridlines across the dateline (https://github.com/SciTools/cartopy/issues/821). Regions crossing the dateline are not supported yet.
- if lon_west > 180 and lon_east > 180:
- lon_west = lon_west - 360
- lon_east = lon_east - 360
-
- if "latitude" in kargs:
- lat_south, lat_north, _ = kargs["latitude"]
- lon_covered = lon_east - lon_west
- lon_step = determine_tick_step(lon_covered)
- xticks = np.arange(lon_west, lon_east, lon_step)
- # Subtract 0.50 to get 0 W to show up on the right side of the plot.
- # If less than 0.50 is subtracted, then 0 W will overlap 0 E on the left side of the plot.
- # If a number is added, then the value won't show up at all.
- if global_domain or full_lon:
- xticks = np.append(xticks, lon_east - 0.50)
- proj = ccrs.PlateCarree(central_longitude=180)
- else:
- xticks = np.append(xticks, lon_east)
- lat_covered = lat_north - lat_south
- lat_step = determine_tick_step(lat_covered)
- yticks = np.arange(lat_south, lat_north, lat_step)
- yticks = np.append(yticks, lat_north)
-
- # Contour plot
- ax = fig.add_axes(panel[n], projection=proj)
- ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=proj)
- cmap = get_colormap(cmap, parameters)
- p1 = ax.contourf(
- lon,
- lat,
- var,
- transform=ccrs.PlateCarree(),
- norm=norm,
- levels=levels,
- cmap=cmap,
- extend="both",
- )
-
- # ax.set_aspect('auto')
- # Full world would be aspect 360/(2*180) = 1
- ax.set_aspect((lon_east - lon_west) / (2 * (lat_north - lat_south)))
- ax.coastlines(lw=0.3)
- if not global_domain and "RRM" in region_str:
- ax.coastlines(resolution="50m", color="black", linewidth=1)
- state_borders = cfeature.NaturalEarthFeature(
- category="cultural",
- name="admin_1_states_provinces_lakes",
- scale="50m",
- facecolor="none",
- )
- ax.add_feature(state_borders, edgecolor="black")
- if title[0] is not None:
- ax.set_title(title[0], loc="left", fontdict=plotSideTitle)
- if title[1] is not None:
- ax.set_title(title[1], fontdict=plotTitle)
- if title[2] is not None:
- ax.set_title(title[2], loc="right", fontdict=plotSideTitle)
- ax.set_xticks(xticks, crs=ccrs.PlateCarree())
- ax.set_yticks(yticks, crs=ccrs.PlateCarree())
- lon_formatter = LongitudeFormatter(zero_direction_label=True, number_format=".0f")
- lat_formatter = LatitudeFormatter()
- ax.xaxis.set_major_formatter(lon_formatter)
- ax.yaxis.set_major_formatter(lat_formatter)
- ax.tick_params(labelsize=8.0, direction="out", width=1)
- ax.xaxis.set_ticks_position("bottom")
- ax.yaxis.set_ticks_position("left")
-
- # Color bar
- cbax = fig.add_axes((panel[n][0] + 0.6635, panel[n][1] + 0.0215, 0.0326, 0.1792))
- cbar = fig.colorbar(p1, cax=cbax)
- w, h = get_ax_size(fig, cbax)
-
- if levels is None:
- cbar.ax.tick_params(labelsize=9.0, length=0)
-
- else:
- maxval = np.amax(np.absolute(levels[1:-1]))
- if maxval < 0.2:
- fmt = "%5.3f"
- pad = 28
- elif maxval < 10.0:
- fmt = "%5.2f"
- pad = 25
- elif maxval < 100.0:
- fmt = "%5.1f"
- pad = 25
- elif maxval > 9999.0:
- fmt = "%.0f"
- pad = 40
- else:
- fmt = "%6.1f"
- pad = 30
-
- cbar.set_ticks(levels[1:-1])
- labels = [fmt % level for level in levels[1:-1]]
- cbar.ax.set_yticklabels(labels, ha="right")
- cbar.ax.tick_params(labelsize=9.0, pad=pad, length=0)
-
- # Min, Mean, Max
- fig.text(
- panel[n][0] + 0.6635,
- panel[n][1] + 0.2107,
- "Max\nMean\nMin",
- ha="left",
- fontdict=plotSideTitle,
- )
-
- fmt_m = []
- # printing in scientific notation if value greater than 10^5
- for i in range(len(stats[0:3])):
- fs = "1e" if stats[i] > 100000.0 else "2f"
- fmt_m.append(fs)
- fmt_metrics = f"%.{fmt_m[0]}\n%.{fmt_m[1]}\n%.{fmt_m[2]}"
-
- fig.text(
- panel[n][0] + 0.7635,
- panel[n][1] + 0.2107,
- # "%.2f\n%.2f\n%.2f" % stats[0:3],
- fmt_metrics % stats[0:3],
- ha="right",
- fontdict=plotSideTitle,
- )
-
- # RMSE, CORR
- if len(stats) == 5:
- fig.text(
- panel[n][0] + 0.6635,
- panel[n][1] - 0.0105,
- "RMSE\nCORR",
- ha="left",
- fontdict=plotSideTitle,
- )
- fig.text(
- panel[n][0] + 0.7635,
- panel[n][1] - 0.0105,
- "%.2f\n%.2f" % stats[3:5],
- ha="right",
- fontdict=plotSideTitle,
- )
-
- # grid resolution info:
- if n == 2 and "RRM" in region_str:
- dlat = lat[2] - lat[1]
- dlon = lon[2] - lon[1]
- fig.text(
- panel[n][0] + 0.4635,
- panel[n][1] - 0.04,
- "Resolution: {:.2f}x{:.2f}".format(dlat, dlon),
- ha="left",
- fontdict=plotSideTitle,
- )
-
-
-def plot(reference, test, diff, metrics_dict, parameter):
- # Create figure, projection
- fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
- proj = ccrs.PlateCarree()
-
- # Figure title
- fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18)
-
- # First two panels
- min1 = metrics_dict["test"]["min"]
- mean1 = metrics_dict["test"]["mean"]
- max1 = metrics_dict["test"]["max"]
-
- plot_panel(
- 0,
- fig,
- proj,
- test,
- parameter.contour_levels,
- parameter.test_colormap,
- (parameter.test_name_yrs, parameter.test_title, test.units),
- parameter,
- stats=(max1, mean1, min1),
- )
-
- if not parameter.model_only:
- min2 = metrics_dict["ref"]["min"]
- mean2 = metrics_dict["ref"]["mean"]
- max2 = metrics_dict["ref"]["max"]
-
- plot_panel(
- 1,
- fig,
- proj,
- reference,
- parameter.contour_levels,
- parameter.reference_colormap,
- (parameter.ref_name_yrs, parameter.reference_title, reference.units),
- parameter,
- stats=(max2, mean2, min2),
- )
-
- # Third panel
- min3 = metrics_dict["diff"]["min"]
- mean3 = metrics_dict["diff"]["mean"]
- max3 = metrics_dict["diff"]["max"]
- r = metrics_dict["misc"]["rmse"]
- c = metrics_dict["misc"]["corr"]
- plot_panel(
- 2,
- fig,
- proj,
- diff,
- parameter.diff_levels,
- parameter.diff_colormap,
- (None, parameter.diff_title, test.units),
- parameter,
- stats=(max3, mean3, min3, r, c),
- )
-
- # Save figure
- for f in parameter.output_format:
- f = f.lower().split(".")[-1]
- fnm = os.path.join(
- get_output_dir(parameter.current_set, parameter),
- parameter.output_file + "." + f,
- )
- plt.savefig(fnm)
- logger.info(f"Plot saved in: {fnm}")
-
- # Save individual subplots
- if parameter.ref_name == "":
- panels = [panel[0]]
- else:
- panels = panel
-
- for f in parameter.output_format_subplot:
- fnm = os.path.join(
- get_output_dir(parameter.current_set, parameter),
- parameter.output_file,
- )
- page = fig.get_size_inches()
- i = 0
- for p in panels:
- # Extent of subplot
- subpage = np.array(p).reshape(2, 2)
- subpage[1, :] = subpage[0, :] + subpage[1, :]
- subpage = subpage + np.array(border).reshape(2, 2)
- subpage = list(((subpage) * page).flatten()) # type: ignore
- extent = matplotlib.transforms.Bbox.from_extents(*subpage)
- # Save subplot
- fname = fnm + ".%i." % (i) + f
- plt.savefig(fname, bbox_inches=extent)
-
- orig_fnm = os.path.join(
- get_output_dir(parameter.current_set, parameter),
- parameter.output_file,
- )
- fname = orig_fnm + ".%i." % (i) + f
- logger.info(f"Sub-plot saved in: {fname}")
-
- i += 1
-
- plt.close()
diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py
deleted file mode 100644
index a72bf5dce..000000000
--- a/e3sm_diags/plot/cartopy/zonal_mean_2d_plot.py
+++ /dev/null
@@ -1,187 +0,0 @@
-from typing import List, Optional, Tuple
-
-import matplotlib
-import numpy as np
-import xarray as xr
-import xcdat as xc
-
-from e3sm_diags.driver.utils.type_annotations import MetricsDict
-from e3sm_diags.logger import custom_logger
-from e3sm_diags.parameter.core_parameter import CoreParameter
-from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS
-from e3sm_diags.plot.utils import (
- DEFAULT_PANEL_CFG,
- _add_colorbar,
- _add_contour_plot,
- _add_min_mean_max_text,
- _add_rmse_corr_text,
- _configure_titles,
- _configure_x_and_y_axes,
- _get_c_levels_and_norm,
- _save_plot,
-)
-
-matplotlib.use("Agg")
-import matplotlib.pyplot as plt # isort:skip # noqa: E402
-
-logger = custom_logger(__name__)
-
-
-# Configs for x axis ticks and x axis limits.
-X_TICKS = np.array([-90, -60, -30, 0, 30, 60, 90])
-X_LIM = -90, 90
-
-
-def plot(
- parameter: CoreParameter,
- da_test: xr.DataArray,
- da_ref: xr.DataArray,
- da_diff: xr.DataArray,
- metrics_dict: MetricsDict,
-):
- """Plot the variable's metrics generated by the zonal_mean_2d set.
-
- Parameters
- ----------
- parameter : CoreParameter
- The CoreParameter object containing plot configurations.
- da_test : xr.DataArray
- The test data.
- da_ref : xr.DataArray
- The reference data.
- da_diff : xr.DataArray
- The difference between `da_test` and `da_ref` (both are regridded to
- the lower resolution of the two beforehand).
- metrics_dict : MetricsDict
- The metrics.
- """
- fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
- fig.suptitle(parameter.main_title, x=0.5, y=0.96, fontsize=18)
-
- # The variable units.
- units = metrics_dict["units"]
-
- # Add the first subplot for test data.
- min1 = metrics_dict["test"]["min"] # type: ignore
- mean1 = metrics_dict["test"]["mean"] # type: ignore
- max1 = metrics_dict["test"]["max"] # type: ignore
-
- _add_colormap(
- 0,
- da_test,
- fig,
- parameter,
- parameter.test_colormap,
- parameter.contour_levels,
- title=(parameter.test_name_yrs, parameter.test_title, units), # type: ignore
- metrics=(max1, mean1, min1), # type: ignore
- )
-
- # Add the second and third subplots for ref data and the differences,
- # respectively.
- min2 = metrics_dict["ref"]["min"] # type: ignore
- mean2 = metrics_dict["ref"]["mean"] # type: ignore
- max2 = metrics_dict["ref"]["max"] # type: ignore
-
- _add_colormap(
- 1,
- da_ref,
- fig,
- parameter,
- parameter.reference_colormap,
- parameter.contour_levels,
- title=(parameter.ref_name_yrs, parameter.reference_title, units), # type: ignore
- metrics=(max2, mean2, min2), # type: ignore
- )
-
- min3 = metrics_dict["diff"]["min"] # type: ignore
- mean3 = metrics_dict["diff"]["mean"] # type: ignore
- max3 = metrics_dict["diff"]["max"] # type: ignore
- r = metrics_dict["misc"]["rmse"] # type: ignore
- c = metrics_dict["misc"]["corr"] # type: ignore
-
- _add_colormap(
- 2,
- da_diff,
- fig,
- parameter,
- parameter.diff_colormap,
- parameter.diff_levels,
- title=(None, parameter.diff_title, da_diff.attrs["units"]), #
- metrics=(max3, mean3, min3, r, c), # type: ignore
- )
-
- _save_plot(fig, parameter)
-
- plt.close()
-
-
-def _add_colormap(
- subplot_num: int,
- var: xr.DataArray,
- fig: plt.Figure,
- parameter: CoreParameter,
- color_map: str,
- contour_levels: List[float],
- title: Tuple[Optional[str], str, str],
- metrics: Tuple[float, ...],
-):
- lat = xc.get_dim_coords(var, axis="Y")
- plev = xc.get_dim_coords(var, axis="Z")
- var = var.squeeze()
-
- # Configure contour levels
- # --------------------------------------------------------------------------
- c_levels, norm = _get_c_levels_and_norm(contour_levels)
-
- # Add the contour plot
- # --------------------------------------------------------------------------
- ax = fig.add_axes(DEFAULT_PANEL_CFG[subplot_num], projection=None)
-
- contour_plot = _add_contour_plot(
- ax, parameter, var, lat, plev, color_map, None, norm, c_levels
- )
-
- # Configure the aspect ratio and plot titles.
- # --------------------------------------------------------------------------
- ax.set_aspect("auto")
- _configure_titles(ax, title)
-
- # Configure x and y axis.
- # --------------------------------------------------------------------------
- _configure_x_and_y_axes(ax, X_TICKS, None, None, parameter.current_set)
- ax.set_xlim(X_LIM)
-
- if parameter.plot_log_plevs:
- ax.set_yscale("log")
-
- if parameter.plot_plevs:
- plev_ticks = parameter.plevs
- plt.yticks(plev_ticks, plev_ticks)
-
- # For default plevs, specify the pressure axis and show the 50 mb tick
- # at the top.
- if (
- not parameter.plot_log_plevs
- and not parameter.plot_plevs
- and parameter.plevs == DEFAULT_PLEVS
- ):
- plev_ticks = parameter.plevs
- new_ticks = [plev_ticks[0]] + plev_ticks[1::2]
- new_ticks = [int(x) for x in new_ticks]
- plt.yticks(new_ticks, new_ticks)
-
- plt.ylabel("pressure (mb)")
- ax.invert_yaxis()
-
- # Add and configure the color bar.
- # --------------------------------------------------------------------------
- _add_colorbar(fig, subplot_num, DEFAULT_PANEL_CFG, contour_plot, c_levels)
-
- # Add metrics text.
- # --------------------------------------------------------------------------
- # Min, Mean, Max
- _add_min_mean_max_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)
-
- if len(metrics) == 5:
- _add_rmse_corr_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)
diff --git a/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py b/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py
deleted file mode 100644
index 004f3c93d..000000000
--- a/e3sm_diags/plot/cartopy/zonal_mean_2d_stratosphere_plot.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import xarray as xr
-
-from e3sm_diags.driver.utils.type_annotations import MetricsDict
-from e3sm_diags.parameter.core_parameter import CoreParameter
-from e3sm_diags.plot.cartopy.zonal_mean_2d_plot import plot as base_plot
-
-
-def plot(
- parameter: CoreParameter,
- da_test: xr.DataArray,
- da_ref: xr.DataArray,
- da_diff: xr.DataArray,
- metrics_dict: MetricsDict,
-):
- return base_plot(parameter, da_test, da_ref, da_diff, metrics_dict)
diff --git a/e3sm_diags/plot/utils.py b/e3sm_diags/plot/utils.py
index 668a20981..37481d69b 100644
--- a/e3sm_diags/plot/utils.py
+++ b/e3sm_diags/plot/utils.py
@@ -134,7 +134,6 @@ def _add_grid_res_info(fig, subplot_num, region_key, lat, lon, panel_configs):
ha="left",
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _make_lon_cyclic(var: xr.DataArray):
diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py
index e81c122bc..2c89c92cb 100644
--- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py
+++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py
@@ -558,69 +558,9 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest
xr.testing.assert_identical(result, expected)
-<<<<<<< HEAD
@pytest.mark.xfail(
reason="Need to figure out why to create dummy incorrect time scalar variable with Xarray."
)
-=======
- def test_returns_climo_dataset_with_derived_variable(self):
- # We will derive the "PRECT" variable using the "pr" variable.
- ds_pr = xr.Dataset(
- coords={
- **spatial_coords,
- "time": xr.DataArray(
- dims="time",
- data=np.array(
- [
- cftime.DatetimeGregorian(
- 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False
- ),
- ],
- dtype=object,
- ),
- attrs={
- "axis": "T",
- "long_name": "time",
- "standard_name": "time",
- "bounds": "time_bnds",
- },
- ),
- },
- data_vars={
- **spatial_bounds,
- "pr": xr.DataArray(
- xr.DataArray(
- data=np.array(
- [
- [[1.0, 1.0], [1.0, 1.0]],
- ]
- ),
- dims=["time", "lat", "lon"],
- attrs={"units": "mm/s"},
- )
- ),
- },
- )
-
- parameter = _create_parameter_object(
- "ref", "climo", self.data_path, "2000", "2001"
- )
- parameter.ref_file = "pr_200001_200112.nc"
- ds_pr.to_netcdf(f"{self.data_path}/{parameter.ref_file}")
-
- ds = Dataset(parameter, data_type="ref")
-
- result = ds.get_climo_dataset("PRECT", season="ANN")
- expected = ds_pr.copy()
- expected = expected.squeeze(dim="time").drop_vars("time")
- expected["PRECT"] = expected["pr"] * 3600 * 24
- expected["PRECT"].attrs["units"] = "mm/day"
- expected = expected.drop_vars("pr")
-
- xr.testing.assert_identical(result, expected)
-
- @pytest.mark.xfail
->>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
def test_returns_climo_dataset_using_derived_var_directly_from_dataset_and_replaces_scalar_time_var(
self,
):