Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENHANCEMENT: Autograd to jax #319

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions examples/tutorial_1_WaveBot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"import autograd.numpy as np\n",
"import jax.numpy as np\n",
"import capytaine as cpy\n",
"import matplotlib.pyplot as plt\n",
"from scipy.optimize import brute\n",
Expand Down Expand Up @@ -353,6 +353,10 @@
"outputs": [],
"source": [
"nsubsteps = 5\n",
"print(\"Value of wec:\", wec)\n",
"print(\"Value of results[0]:\", results[0])\n",
"print(\"Value of waves.sel(realization=0):\", waves.sel(realization=0))\n",
"print(\"Value of nsubsteps:\", nsubsteps)\n",
"pto_fdom, pto_tdom = pto.post_process(wec, results[0], waves.sel(realization=0), nsubsteps=nsubsteps)\n",
"wec_fdom, wec_tdom = wec.post_process(results[0], waves.sel(realization=0), nsubsteps=nsubsteps)"
]
Expand Down Expand Up @@ -851,7 +855,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "wot_dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -865,7 +869,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.13"
},
"vscode": {
"interpreter": {
Expand All @@ -874,5 +878,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
17 changes: 8 additions & 9 deletions examples/tutorial_2_AquaHarmonics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
"outputs": [],
"source": [
"import capytaine as cpy\n",
"import autograd.numpy as np\n",
"import jax.numpy as np\n",
"import numpy as onp\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
"from scipy.optimize import brute\n",
Expand Down Expand Up @@ -292,7 +293,7 @@
"y = np.arange(-1*torque_max, 1.0*torque_max, 5)\n",
"X, Y = np.meshgrid(x, y)\n",
"Z = power_loss(X, Y).copy()/1e3\n",
"Z[np.abs(X*Y) > power_max] = np.NaN # cut off area outside of power limit\n",
"Z = np.where(jax.numpy.abs(X*Y) > power_max, onp.NaN, Z) # cut off area outside of power limit\n",
"\n",
"fig = plt.figure(figsize=plt.figaspect(0.4))\n",
"ax = [fig.add_subplot(1, 2, 1, projection=\"3d\"),\n",
Expand Down Expand Up @@ -477,9 +478,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"wec = wot.WEC.from_bem(\n",
Expand Down Expand Up @@ -551,7 +550,7 @@
"scale_x_opt = 50e-2\n",
"scale_obj = 1e-3\n",
"\n",
"options = {'maxiter': 200, 'ftol': 1e-8}\n",
"options = {'maxiter': 200, 'tol': 1e-8}\n",
"\n",
"results = wec.solve(\n",
" waves,\n",
Expand Down Expand Up @@ -901,7 +900,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "wot_dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -915,7 +914,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.13"
},
"vscode": {
"interpreter": {
Expand All @@ -924,5 +923,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
93 changes: 53 additions & 40 deletions examples/tutorial_3_LUPA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"import os\n",
"import gmsh, pygmsh\n",
"import capytaine as cpy\n",
"import autograd.numpy as np\n",
"import jax.numpy as np\n",
"import numpy as onp\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"from scipy.optimize import brute\n",
Expand Down Expand Up @@ -276,7 +277,7 @@
" pitch_inertia_float + mass_float*d_float**2 + \n",
" pitch_inertia_spar + mass_spar*d_spar**2\n",
")\n",
"inertia = np.diag([mass_float, mass_spar, lupa_fb.disp_mass(), pitch_inertia])\n",
"inertia = np.diag(np.array([mass_float, mass_spar, lupa_fb.disp_mass(), pitch_inertia]))\n",
"\n",
"# additional DOFs\n",
"lupa_fb.add_translation_dof(name='Surge')\n",
Expand Down Expand Up @@ -403,14 +404,20 @@
" for j, idof in enumerate(influenced_dofs):\n",
" sp_idx += 1\n",
" if i == 0:\n",
" np.abs(bem_data.diffraction_force.sel(influenced_dof=idof)).plot(\n",
" ax=ax_ex[j], linestyle='dashed', label='Diffraction force')\n",
" np.abs(bem_data.Froude_Krylov_force.sel(influenced_dof=idof)).plot(\n",
" ax=ax_ex[j], linestyle='dashdot', label='Froude-Krylov force')\n",
" ex_handles, ex_labels = ax_ex[j].get_legend_handles_labels()\n",
" abs_diffraction_force = np.abs(np.array(bem_data.diffraction_force.sel(influenced_dof=idof)))\n",
" abs_Froude_Krylov_force = np.abs(np.array(bem_data.Froude_Krylov_force.sel(influenced_dof=idof)))\n",
" \n",
" # Plot the numpy arrays on the axes object ax_ex[j]\n",
" ax_ex[j].plot(abs_diffraction_force, linestyle='dashed', label='Diffraction force')\n",
" ax_ex[j].plot(abs_Froude_Krylov_force, linestyle='dashdot', label='Froude-Krylov force')\n",
" \n",
" # Set the title, xlabel, and ylabel of the axes object\n",
" ax_ex[j].set_title(f'{idof}')\n",
" ax_ex[j].set_xlabel('')\n",
" ax_ex[j].set_ylabel('')\n",
" \n",
" # Get the legend handles and labels\n",
" ex_handles, ex_labels = ax_ex[j].get_legend_handles_labels()\n",
" if j <= i:\n",
" bem_data.added_mass.sel(\n",
" radiating_dof=rdof, influenced_dof=idof).plot(ax=ax_am[i, j])\n",
Expand Down Expand Up @@ -837,16 +844,16 @@
" k_tt = nlines * (\n",
" pretension * fair_r / linelen * (fair_r + linelen*np.cos(theta)))\n",
" mat = np.zeros([7, 7])\n",
" mat[1, 1] = k_vv\n",
" mat[2, 2] = k_hh\n",
" mat[3, 3] = k_hh\n",
" mat[4, 4] = k_rr\n",
" mat[5, 5] = k_rr\n",
" mat[6, 6] = k_tt\n",
" mat[2, 5] = -k_rh\n",
" mat[5, 2] = -k_rh\n",
" mat[4, 3] = k_rh\n",
" mat[3, 4] = k_rh\n",
" mat = mat.at[1, 1].set(k_vv)\n",
" mat = mat.at[2, 2].set(k_hh)\n",
" mat = mat.at[3, 3].set(k_hh)\n",
" mat = mat.at[4, 4].set(k_rr)\n",
" mat = mat.at[5, 5].set(k_rr)\n",
" mat = mat.at[6, 6].set(k_tt)\n",
" mat = mat.at[2, 5].set(-k_rh)\n",
" mat = mat.at[5, 2].set(-k_rh)\n",
" mat = mat.at[4, 3].set(k_rh)\n",
" mat = mat.at[3, 4].set(k_rh)\n",
"\n",
" return mat"
]
Expand Down Expand Up @@ -936,7 +943,7 @@
"}\n",
"\n",
"# small amount of friction to avoid small/negative terms\n",
"friction = np.diag([2.0, 2.0, 2.0, 0])\n",
"friction = np.diag(np.array([2.0, 2.0, 2.0, 0]))\n",
"\n",
"# WEC\n",
"wec = wot.WEC.from_bem(bem_data,\n",
Expand Down Expand Up @@ -1020,26 +1027,32 @@
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"plt1 = np.abs(waves['south_max_90'].sel(realization=0)).plot(\n",
" ax=ax, color='C0', linestyle='solid', label='PW South, 90th percentile')\n",
"plt2 = np.abs(waves['south_max_annual'].sel(realization=0)).plot(\n",
" ax=ax, color='C0', linestyle='dotted', label='PW South, Max Annual')\n",
"plt3 = np.abs(waves['south_max_occurrence'].sel(realization=0)).plot(\n",
" ax=ax, color='C0', linestyle='dashed', label='PW South, Max Occurrence')\n",
"plt4 = np.abs(waves['south_min_10'].sel(realization=0)).plot(\n",
" ax=ax, color='C0', linestyle='dashdot', label='PW South, 10th percentile')\n",
"plt5 = np.abs(waves['north_max_90'].sel(realization=0)).plot(\n",
" ax=ax, color='C1', linestyle='solid', label='PW North, 90th percentile')\n",
"plt6 = np.abs(waves['north_max_annual'].sel(realization=0)).plot(\n",
" ax=ax, color='C1', linestyle='dotted', label='PW North, Max Annual')\n",
"plt7 = np.abs(waves['north_max_occurrence'].sel(realization=0)).plot(\n",
" ax=ax, color='C1', linestyle='dashed', label='PW North, Max Occurrence')\n",
"plt8 = np.abs(waves['north_min_10'].sel(realization=0)).plot(\n",
" ax=ax, color='C1', linestyle='dashdot', label='PW North, 10th percentile')\n",
"\n",
"# Convert JAX arrays to numpy arrays\n",
"south_max_90 = np.abs(np.array(waves['south_max_90'].sel(realization=0)))\n",
"south_max_annual = np.abs(np.array(waves['south_max_annual'].sel(realization=0)))\n",
"south_max_occurrence = np.abs(np.array(waves['south_max_occurrence'].sel(realization=0)))\n",
"south_min_10 = np.abs(np.array(waves['south_min_10'].sel(realization=0)))\n",
"north_max_90 = np.abs(np.array(waves['north_max_90'].sel(realization=0)))\n",
"north_max_annual = np.abs(np.array(waves['north_max_annual'].sel(realization=0)))\n",
"north_max_occurrence = np.abs(np.array(waves['north_max_occurrence'].sel(realization=0)))\n",
"north_min_10 = np.abs(np.array(waves['north_min_10'].sel(realization=0)))\n",
"\n",
"# Plot the numpy arrays on the axes object ax\n",
"ax.plot(south_max_90, color='C0', linestyle='solid', label='PW South, 90th percentile')\n",
"ax.plot(south_max_annual, color='C0', linestyle='dotted', label='PW South, Max Annual')\n",
"ax.plot(south_max_occurrence, color='C0', linestyle='dashed', label='PW South, Max Occurrence')\n",
"ax.plot(south_min_10, color='C0', linestyle='dashdot', label='PW South, 10th percentile')\n",
"ax.plot(north_max_90, color='C1', linestyle='solid', label='PW North, 90th percentile')\n",
"ax.plot(north_max_annual, color='C1', linestyle='dotted', label='PW North, Max Annual')\n",
"ax.plot(north_max_occurrence, color='C1', linestyle='dashed', label='PW North, Max Occurrence')\n",
"ax.plot(north_min_10, color='C1', linestyle='dashdot', label='PW North, 10th percentile')\n",
"\n",
"# Set the title of the axes object\n",
"ax.set_title('PacWave wave spectra, LWF scale', fontweight='bold')\n",
"plts = plt1 + plt2 + plt3 + plt4 + plt5 + plt6 + plt7 + plt8\n",
"ax.legend(plts, [pl.get_label() for pl in plts], ncols=1, frameon=False)"
"\n",
"# Get the legend handles and labels\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"ax.legend(handles, labels, ncols=1, frameon=False)"
]
},
{
Expand Down Expand Up @@ -1471,7 +1484,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "wot_dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -1485,7 +1498,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.13"
},
"vscode": {
"interpreter": {
Expand All @@ -1494,5 +1507,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
39 changes: 21 additions & 18 deletions examples/tutorial_4_Pioneer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"outputs": [],
"source": [
"import capytaine as cpy\n",
"import autograd.numpy as np\n",
"import jax.numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.linalg import block_diag\n",
"import xarray as xr\n",
Expand Down Expand Up @@ -109,10 +109,20 @@
"source": [
"#TODO: highlight the harmonics if wave freq and Tp with other markers+colors\n",
"fig, ax = plt.subplots()\n",
"np.abs(waves_regular).plot(marker = 'x', label=\"regular\")\n",
"np.abs(waves_irregular.sel(realization=0)).plot(marker = '*', label=\"irregular\")\n",
"# Convert JAX arrays to numpy arrays\n",
"waves_regular_np = np.abs(np.array(waves_regular)).ravel()\n",
"waves_irregular_np = np.abs(np.array(waves_irregular.sel(realization=0))).ravel()\n",
"\n",
"# Plot the numpy arrays on the axes object ax\n",
"ax.plot(waves_regular_np, marker='x', label='regular')\n",
"ax.plot(waves_irregular_np, marker='*', label='irregular')\n",
"\n",
"# Set the title of the axes object\n",
"ax.set_title('Wave elevation spectrum', fontweight='bold')\n",
"plt.legend()"
"\n",
"# Get the legend handles and labels\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"ax.legend(handles, labels)"
]
},
{
Expand Down Expand Up @@ -268,10 +278,10 @@
"fig_ex, ax_ex = plt.subplots(tight_layout=True, sharex=True)\n",
"\n",
"# Excitation\n",
"np.abs(bem_data.diffraction_force.sel(influenced_dof='Pitch')).plot(\n",
" ax=ax_ex, linestyle='dashed', label='Diffraction force')\n",
"np.abs(bem_data.Froude_Krylov_force.sel(influenced_dof='Pitch')).plot(\n",
" ax=ax_ex, linestyle='dashdot', label='Froude-Krylov force')\n",
"ax_ex.plot(np.abs(np.array(bem_data.diffraction_force.sel(influenced_dof='Pitch')).ravel()), \n",
" linestyle='dashed', label='Diffraction force')\n",
"ax_ex.plot(np.abs(np.array(bem_data.Froude_Krylov_force.sel(influenced_dof='Pitch')).ravel()), \n",
" linestyle='dashdot', label='Froude-Krylov force')\n",
"ex_handles, ex_labels = ax_ex.get_legend_handles_labels()\n",
"ax_ex.set_xlabel(f'$\\omega$', fontsize=10)\n",
"ax_ex.set_title('Wave Excitation Coefficients', fontweight='bold')\n",
Expand Down Expand Up @@ -958,18 +968,11 @@
" axi.label_outer()\n",
" axi.autoscale(axis='x', tight=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "wot_dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -983,9 +986,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
Loading