diff --git a/notebooks/Data_Drive_PINN.ipynb b/notebooks/Data_Drive_PINN.ipynb new file mode 100644 index 0000000..fc27e51 --- /dev/null +++ b/notebooks/Data_Drive_PINN.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "import cartopy.crs as ccrs\n", + "import cartopy.feature as cfeature\n", + "\n", + "# Load the saved Zarr data\n", + "zarr_path = '~/shared-public/mind_the_chl_gap/U-Net_with_CHL_pred.zarr'\n", + "zarr_ds = xr.open_zarr(zarr_path)['gapfree_pred']\n", + "\n", + "# Select the date you want to plot\n", + "date_to_plot = '2022-01-01' # Replace with the desired date\n", + "zarr_date = zarr_ds.sel(time=date_to_plot)\n", + "\n", + "# Load the Level 3 CHL data\n", + "level3_path = '~/shared-public/mind_the_chl_gap/IO.zarr'\n", + "level3_ds = xr.open_zarr(level3_path)\n", + "level3_chl = level3_ds['CHL_cmes-level3'].sel(time=date_to_plot)\n", + "sst = level3_ds['sst'].sel(time=date_to_plot)\n", + "u_wind = level3_ds['u_wind'].sel(time=date_to_plot)\n", + "v_wind = level3_ds['v_wind'].sel(time=date_to_plot)\n", + "air_temp = level3_ds['air_temp'].sel(time=date_to_plot)\n", + "ug_curr = level3_ds['ug_curr'].sel(time=date_to_plot)\n", + "# Plot the data\n", + "fig, axes = plt.subplots(nrows=7, ncols=1, figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()})\n", + "\n", + "# Plot the log-scaled Level 3 CHL data\n", + "ax = axes[0]\n", + "level3_chl_log = np.log(level3_chl.where(~np.isnan(level3_chl), np.nan))\n", + "im = ax.imshow(level3_chl_log, vmin=np.nanmin(level3_chl_log), vmax=np.nanmax(level3_chl_log), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('Log-scaled Level 3 CHL')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "ax = axes[1]\n", + "im = ax.imshow(sst, vmin=np.nanmin(sst), vmax=np.nanmax(sst), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('SST')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "ax = axes[2]\n", + "im = ax.imshow(u_wind, vmin=np.nanmin(u_wind), vmax=np.nanmax(u_wind), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('u_wind')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "ax = axes[3]\n", + "im = ax.imshow(v_wind, vmin=np.nanmin(v_wind), vmax=np.nanmax(v_wind), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('v_wind')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "ax = axes[4]\n", + "im = ax.imshow(air_temp, vmin=np.nanmin(air_temp), vmax=np.nanmax(air_temp), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('air_temp')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "ax = axes[5]\n", + "im = ax.imshow(ug_curr, vmin=np.nanmin(ug_curr), vmax=np.nanmax(ug_curr), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('ug_curr')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "ax = axes[6]\n", + "gapfill_chl_log = zarr_date\n", + "im = ax.imshow(gapfill_chl_log, vmin=np.nanmin(gapfill_chl_log), vmax=np.nanmax(gapfill_chl_log), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax.set_title('Log-scaled U-Net Gapfilled CHL Prediction')\n", + "ax.add_feature(cfeature.COASTLINE)\n", + "ax.set_xlabel('Longitude')\n", + "ax.set_ylabel('Latitude')\n", + "\n", + "fig.colorbar(im, ax=axes.ravel().tolist(), location='right', shrink=0.9)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: deepxde in /srv/conda/envs/notebook/lib/python3.11/site-packages (1.12.0)\n", + "Requirement already satisfied: matplotlib in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (3.8.0)\n", + "Requirement already satisfied: numpy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.24.4)\n", + "Requirement already satisfied: scikit-learn in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.3.0)\n", + "Requirement already satisfied: scikit-optimize>=0.9.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (0.10.2)\n", + "Requirement already satisfied: scipy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.11.2)\n", + "Requirement already satisfied: joblib>=0.11 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (1.3.2)\n", + "Requirement already satisfied: pyaml>=16.9 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (24.7.0)\n", + "Requirement already satisfied: packaging>=21.3 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (23.1)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-learn->deepxde) (3.2.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (1.1.1)\n", + "Requirement already satisfied: cycler>=0.10 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (0.11.0)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (4.42.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (1.4.5)\n", + "Requirement already satisfied: pillow>=6.2.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (9.5.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (3.1.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (2.8.2)\n", + "Requirement already satisfied: PyYAML in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pyaml>=16.9->scikit-optimize>=0.9.0->deepxde) (6.0.1)\n", + "Requirement already satisfied: six>=1.5 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib->deepxde) (1.16.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "pip install deepxde" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using backend: pytorch\n", + "Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.\n", + "paddle supports more examples now and is recommended.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import os\n", + "os.environ[\"DDEBACKEND\"] = \"pytorch\"\n", + "import deepxde as dde\n", + "import matplotlib.pyplot as plt\n", + "import cartopy.crs as ccrs\n", + "import cartopy.feature as cfeature\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import xarray as xr\n", + "import deepxde as dde\n", + "\n", + "# Load the data\n", + "zarr_path = '~/shared-public/mind_the_chl_gap/U-Net_with_CHL_pred.zarr'\n", + "zarr_ds = xr.open_zarr(zarr_path)['gapfree_pred']\n", + "\n", + "level3_path = '~/shared-public/mind_the_chl_gap/IO.zarr'\n", + "level3_ds = xr.open_zarr(level3_path)\n", + "\n", + "# Prepare the input data (v)\n", + "variables = ['CHL_cmes-level3', 'sst', 'u_wind', 'v_wind', 'air_temp', 'ug_curr']\n", + "input_data = []\n", + "\n", + "for var in variables:\n", + " data = level3_ds[var].values\n", + " data = np.log(data) if var == 'CHL_cmes-level3' else data\n", + " input_data.append(data)\n", + "\n", + "v = np.stack(input_data, axis=-1)\n", + "\n", + "# Prepare the output data (u)\n", + "u = zarr_ds.values\n", + "\n", + "# Reshape the data\n", + "v = v.reshape(-1, v.shape[-1]) # (num_points, num_variables)\n", + "u = u.reshape(-1) # (num_points,)\n", + "\n", + "# Split the data into training and testing sets\n", + "n_train = int(0.8 * len(u))\n", + "X_train, y_train = (v[:n_train], np.zeros((n_train, 1))), u[:n_train]\n", + "X_test, y_test = (v[n_train:], np.zeros((len(u) - n_train, 1))), u[n_train:]\n", + "# Set up the data\n", + "data = dde.data.TripleCartesianProd(\n", + " X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test\n", + ")\n", + "\n", + "# Define the DeepONet architecture\n", + "m = v.shape[1] # number of input variables\n", + "dim_x = 1 # dimension of spatial input (in this case, just a placeholder)\n", + "\n", + "net = dde.nn.DeepONetCartesianProd(\n", + " [m, 64, 64], # branch net\n", + " [dim_x, 64, 64], # trunk net\n", + " \"relu\",\n", + " \"Glorot normal\",\n", + ")\n", + "\n", + "# Create the model\n", + "model = dde.Model(data, net)\n", + "\n", + "# Compile the model\n", + "model.compile(\"adam\", lr=0.001, metrics=[\"mean l2 relative error\"])\n", + "\n", + "# Create a custom loss function\n", + "def custom_loss(inputs, outputs, targets):\n", + " return torch.mean((outputs - targets)**2)\n", + "\n", + "# Create the model\n", + "model = dde.Model(data, net)\n", + "print(v.dtype)\n", + "print(x.dtype)\n", + "print(u.dtype)\n", + "\n", + "# Compile the model\n", + "model.compile(\"adam\", lr=0.001, loss=custom_loss, metrics=[\"mean l2 relative error\"])\n", + "\n", + "print(\"Training the model...\")\n", + "\n", + "# Train the model\n", + "losshistory, train_state = model.train(iterations=10000, batch_size=32) \n", + "\n", + "print(\"Making predictions...\")\n", + "\n", + "# Make predictions\n", + "y_pred = model.predict(X_test)\n", + "\n", + "# Reshape the predictions back to the original shape\n", + "y_pred = y_pred.reshape(zarr_ds.shape[1:])\n", + "\n", + "print(\"Visualizing results...\")\n", + "\n", + "# Visualize the results\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), subplot_kw={'projection': ccrs.PlateCarree()})\n", + "\n", + "# True gap-filled CHL\n", + "im1 = ax1.imshow(zarr_ds.isel(time=0), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax1.set_title('True Gap-filled CHL')\n", + "ax1.add_feature(cfeature.COASTLINE)\n", + "ax1.set_xlabel('Longitude')\n", + "ax1.set_ylabel('Latitude')\n", + "fig.colorbar(im1, ax=ax1)\n", + "\n", + "# Predicted gap-filled CHL\n", + "im2 = ax2.imshow(y_pred, extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", + "ax2.set_title('Predicted Gap-filled CHL')\n", + "ax2.add_feature(cfeature.COASTLINE)\n", + "ax2.set_xlabel('Longitude')\n", + "ax2.set_ylabel('Latitude')\n", + "fig.colorbar(im2, ax=ax2)\n", + "\n", + "plt.show()\n", + "\n", + "print(\"Process completed.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}