Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Some documentation fixes and fixes in Jax point cost functions. #118

Merged
merged 6 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11"]
os: [macOS, ubuntu]
inlcude:
- os: macos-latest
Expand Down
2 changes: 1 addition & 1 deletion REQUIREMENTS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ pooch
cmweather
cdsapi
xarray
datatree
xarray-datatree
1 change: 1 addition & 0 deletions continuous_integration/environment-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ dependencies:
- jaxopt
- tensorflow>=2.6
- tensorflow-probability
- xarray-datatree
1 change: 1 addition & 0 deletions doc/environment_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ dependencies:
- sphinx-gallery
- sphinx-copybutton
- sphinx-design
- xarray-datatree
1 change: 1 addition & 0 deletions doc/source/contributors_guide/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Examples of unacceptable behavior by participants include:
advances

Trolling, insulting/derogatory comments, and personal or political attacks

Public or private harassment

Publishing others' private information, such as a physical or electronic
Expand Down
8 changes: 3 additions & 5 deletions examples/README.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
PyDDA Example Gallery
====================

Different examples are given on how to retrieve winds using HRRR and radar data.

Example grid data files for Hurricane Florence are available at:

https://drive.google.com/drive/folders/1pcQxWRJV78xuJePTZnlXPPpMe1qut0ie
In this section, we show different examples on:
* How to use HRRR to initalize your wind retrieval
* How to adjust the variational retrieval parameters
5 changes: 2 additions & 3 deletions pydda/cost_functions/_cost_functions_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,8 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0):
),
jnp.abs(z - the_point["z"]) < roi,
)
J += jnp.sum(
((u[the_box] - the_point["u"]) ** 2 + (v[the_box] - the_point["v"]) ** 2)
)
the_box = jnp.where(the_box, 1.0, 0.0)
J += jnp.sum(((u - the_point["u"]) ** 2 + (v - the_point["v"]) ** 2) * the_box)

return J * Cp

Expand Down
4 changes: 0 additions & 4 deletions pydda/cost_functions/cost_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
TENSORFLOW_AVAILABLE = False

try:
from jax.config import config

config.update("jax_enable_x64", True)
import jax.numpy as jnp

JAX_AVAILABLE = True
Expand Down Expand Up @@ -858,7 +855,6 @@ def grad_jax(winds, parameters):
parameters.point_list,
Cp=parameters.Cpoint,
roi=parameters.roi,
upper_bc=parameters.upper_bc,
)
return grad

Expand Down
1 change: 1 addition & 0 deletions pydda/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

read_grid
read_from_pyart_grid
read_hpl
"""

from .read_grid import read_grid, read_from_pyart_grid
1 change: 0 additions & 1 deletion pydda/io/read_grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import xarray as xr
import xradar as xd
import numpy as np

from glob import glob
Expand Down
21 changes: 15 additions & 6 deletions pydda/retrieval/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,23 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs):
"""
Does a wind retrieval over nested grids. The nested grids are created using PyART's
:func:`pyart.map.grid_from_radars` function and then placed into a tree structure using
dictionaries. Each node of the tree has three parameters:
'input_grids': The list of PyART grids for the given level of the grid
'kwargs': The list of key word arguments for input to the get_dd_wind_field function for the set of grids.
If this is None, then the default keyword arguments are carried from the keyword arguments of this function.
'children': The list of trees that are the children of this node.
:func:`dataTree`s. Each node of the tree has three parameters:
.. list-table:: Title
:widths: 25 100
:header-rows: 1

* - Dictionary key
- Description
* - input_grids
- The list of PyART grids for the given level of the grid
* - kwargs
- The list of key word arguments for input to the :py:func:`pydda.retrieval.get_dd_wind_field` function for the set of grids.
* - children
- The list of trees that are the children of this node.

The function will output the same tree, with the list of output grids of each level output to the 'output_grids'
member of the tree structure.
member of the tree structure. If *kwargs* is set to None, then the input keyword arguments will be
used throughout the retrieval.
"""

# Look for radars in current level
Expand Down
12 changes: 6 additions & 6 deletions pydda/retrieval/wind_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ def get_dd_wind_field(
Using Tensorflow or Jax expands PyDDA's capabiability to take advantage of GPU-based systems.
In addition, these two implementations use automatic differentation to calculate the gradient
of the cost function in order to optimize the gradient calculation.
TensorFlow 2.6 and tensorflow-probability are required for the TensorFlow-basedengine.
TensorFlow 2.6 and tensorflow-probability are required for the TensorFlow-based engine.
The latest version of Jax is required for the Jax-based engine.
points: None or list of dicts
Point observations as returned by :func:`pydda.constraints.get_iem_obs`. Set
Expand Down Expand Up @@ -1413,9 +1413,9 @@ def get_dd_wind_field(
The list of fields in the first grid in Grids that contain the custom
data interpolated to the Grid's grid specification. Helper functions
to create such gridded fields for HRRR and NetCDF WRF data exist
in ::pydda.constraints::. PyDDA will look for fields named U_(model
field name), V_(model field name), and W_(model field name). For
example, if you have U_hrrr, V_hrrr, and W_hrrr, then specify ["hrrr"]
in :py:func:`pydda.constraints`. PyDDA will look for fields named *U_(model
field name)*, *V_(model field name)*, and *W_(model field name)*. For
example, if you have *U_hrrr*, *V_hrrr*, and *W_hrrr*, then specify *["hrrr"]*
into model_fields.
output_cost_functions: bool
Set to True to output the value of each cost function every
Expand All @@ -1429,9 +1429,9 @@ def get_dd_wind_field(
wind_tol: float
Stop iterations after maximum change in winds is less than this value.
tolerance: float
Tolerance for L2 norm of gradient before stopping.
Tolerance for :math:`L_{2}` norm of gradient before stopping.
max_wind_magnitude: float
Constrain the optimization to have :math:`|u|, :math:`|w|`, and :math:`|w| < x` m/s.
Constrain the optimization to have :math:`|u|`, :math:`|w|`, and :math:`|w| < x` m/s.

Returns
=======
Expand Down
4 changes: 0 additions & 4 deletions pydda/tests/test_cost_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,6 @@ def test_vert_vorticity_tf():
def test_point_cost():
u = 1 * np.ones((10, 10, 10))
v = 1 * np.ones((10, 10, 10))
0 * np.ones((10, 10, 10))

x = np.linspace(-10, 10, 10)
y = np.linspace(-10, 10, 10)
z = np.linspace(-10, 10, 10)
Expand Down Expand Up @@ -556,8 +554,6 @@ def test_point_cost():
def test_point_cost_jax():
u = 1 * np.ones((10, 10, 10))
v = 1 * np.ones((10, 10, 10))
0 * np.ones((10, 10, 10))

x = np.linspace(-10, 10, 10)
y = np.linspace(-10, 10, 10)
z = np.linspace(-10, 10, 10)
Expand Down
Loading