# How to add dask support
See the steps below for how basic dask support can be added to WRF-Python in v2.0.
Let's choose a simple example -- the 'tv' diagnostic from getvar. In particular, we're interested in the 'getter' function found in g_temp.py. The code is below (note: we're ignoring the metadata and units handling for this example):
```python
def get_tv(wrfin, timeidx=0, method="cat", squeeze=True,
           cache=None, meta=True, _key=None,
           units="K"):
    varnames = ("T", "P", "PB", "QVAPOR")
    ncvars = extract_vars(wrfin, timeidx, varnames, method, squeeze, cache,
                          meta=False, _key=_key)

    t = ncvars["T"]
    p = ncvars["P"]
    pb = ncvars["PB"]
    qv = ncvars["QVAPOR"]

    full_t = t + Constants.T_BASE
    full_p = p + pb
    tk = _tk(full_p, full_t)
    tv = _tv(tk, qv)

    return tv
```
In the code above, we extract a few variables from a WRF file, compute the full pressure and full potential temperature from their base and perturbation components, then compute tk (temperature in kelvin) and tv (virtual temperature). Below, let's show how this could be rewritten using xarray and dask.
## Create thin wrappers
The _tv and _tk code in extension.py calls a Fortran routine and performs several common operations via wrapt decorators. Unfortunately, wrapt-decorated functions don't serialize (pickle), so we need to create thin wrappers around them.
Also, in order to create the dask tasks, we need functions to pass to dask's map_blocks routine, so we'll also need wrappers around the "base + perturbation" operations above.
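The pickling constraint is easy to demonstrate with plain pickle. This toy sketch uses a nested (local) function to stand in for a wrapt-decorated one -- wrapt itself is not needed to reproduce the behavior, and all names here are made up for illustration:

```python
import pickle

def module_level_wrapper(x):
    # Plain module-level functions pickle by reference, so dask's
    # multiprocessing/distributed schedulers can ship them to workers.
    return x + 1

def make_decorated():
    # Stands in for a wrapt-decorated function: the callable that comes
    # back is a nested (local) function, which pickle cannot serialize.
    def inner(x):
        return x + 1
    return inner

def can_pickle(obj):
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False

top_level_ok = can_pickle(module_level_wrapper)   # works in a normal module
nested_ok = can_pickle(make_decorated())          # fails
```

This is why the wrappers below are defined at module level rather than generated by decorators.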
Let's start with the _tv and _tk wrappers. Since OpenMP is already supported at the Fortran level, these functions take an additional argument to set the number of OpenMP threads. Note that dask can do what OpenMP does, so this is entirely optional, but if you want to combine dask tasks with OpenMP for the low-level computation, this is one way to do it (note: using the omp_threads argument was easier than trying to get environment variables to work with the multiprocessing scheduler).
```python
def tk_wrap(pressure, theta, omp_threads=1):
    from wrf.extension import _tk, omp_set_num_threads

    omp_set_num_threads(omp_threads)
    result = _tk(pressure, theta)

    return result


def tv_wrap(temp_k, qvapor, omp_threads=1):
    from wrf.extension import _tv, omp_set_num_threads

    omp_set_num_threads(omp_threads)
    result = _tv(temp_k, qvapor)

    return result
```
Next, let's make a thin wrapper for the "base + perturbation" operation:
```python
def pert_add(base, perturbation):
    return base + perturbation
```
In the above wrappers, we're assuming we want each operation to run as a separate dask task. Alternatively, you could put all of the above into one wrapper so they run as a single dask task. It is left as an exercise to the WRF-Python 2.x developer to determine which performs better (we suspect that a single task will, since it minimizes I/O and message passing in a distributed environment).
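As a sketch of the single-task alternative, all four steps can run inside one function handed to map_blocks. Note the Fortran _tk/_tv calls are replaced here with NumPy stand-ins using the standard meteorological formulas; the constants (T_BASE, P1000MB, Rd/Cp) are assumptions for this illustration, and a real implementation would call the thin wrappers instead:

```python
import numpy as np

# Assumed constants for the stand-in formulas (T_BASE matches WRF's
# 300 K potential temperature offset; the rest are dry-air values).
T_BASE = 300.0
P1000MB = 100000.0
RD_OVER_CP = 287.04 / 1004.0

def tv_fused(t, p, pb, qv, omp_threads=1):
    """Run all four steps as one dask task.

    NumPy stand-ins replace the Fortran _tk/_tv calls; in WRF-Python
    2.x these lines would call the thin wrappers instead.
    """
    full_t = t + T_BASE                              # full potential temperature
    full_p = p + pb                                  # full pressure
    tk = full_t * (full_p / P1000MB) ** RD_OVER_CP   # theta -> temperature (K)
    tv = tk * (0.622 + qv) / (0.622 * (1.0 + qv))    # virtual temperature (K)
    return tv
```

With this version, the getter would need only a single map_blocks call over t, p, pb, and qv.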
## Create a 'getter' function for the 'tv' diagnostic
The original getter function above takes several common arguments (wrfin, timeidx, method, squeeze, etc.). Since much of what WRF-Python 1.x implements is now handled by xarray, a lot of this can be gutted in WRF-Python 2.x in favor of xarray. So, this getter function only needs to take an xarray Dataset argument and the number of OpenMP threads (only if you want users to control OpenMP threading).
For this example, we're going to assume dask is installed, but in a real implementation, the xarray.DataArray.data attribute might return a plain numpy array, so the code should be prepared for that. Here, we're only trying to show how to make dask work.
Here is the code:
```python
from wrf import Constants
from dask.array import map_blocks

def tv_getter(ds, omp_threads=1):
    # Grab the underlying dask arrays from the Dataset
    t = ds["T"].data
    p = ds["P"].data
    pb = ds["PB"].data
    qv = ds["QVAPOR"].data

    # Build the task graph; nothing is computed yet
    full_t = map_blocks(pert_add, Constants.T_BASE, t, dtype=t.dtype)
    full_p = map_blocks(pert_add, pb, p, dtype=p.dtype)
    tk = map_blocks(tk_wrap, full_p, full_t, omp_threads, dtype=p.dtype)
    tv = map_blocks(tv_wrap, tk, qv, omp_threads, dtype=p.dtype)

    return tv
```
Note that the above returns a dask array with no metadata, so nothing has actually happened yet other than building the dask task graph. To actually compute something, you need to call compute() on the returned dask object. Adding metadata is beyond the scope of this tutorial and is left as an exercise to whoever implements WRF-Python 2.x.
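The lazy behavior is easy to see with a toy map_blocks graph (the function and array sizes here are made up for illustration and have nothing to do with WRF):

```python
import numpy as np
import dask.array as da

def add_offset(block):
    # Runs once per chunk, and only when compute() is triggered.
    return block + 1.0

x = da.ones((4, 4), chunks=(2, 2))              # four 2x2 chunks
y = da.map_blocks(add_offset, x, dtype=x.dtype)

# At this point y is only a task graph; no chunk has been evaluated.
result = y.compute()                            # now a numpy array of 2.0s
```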
Let's test the getter function below (assuming your getter function has already been imported):
```python
import xarray

# Open the WRF output files as a single dataset
ds = xarray.open_mfdataset("/path/to/wrf_vortex_multi/moving_nest/wrfout_d02*",
                           parallel=True)

tv = tv_getter(ds, omp_threads=2)

# Now actually compute tv (note: the result is a numpy array).
# Let's use 4 workers with 2 OpenMP threads each, for a total of 8 CPUs.
tv_result = tv.compute(num_workers=4)
```
## Create an xarray extension if you want an object-oriented API
If you want your API to work on the Dataset object itself, rather than through a separate function, xarray has an easy way to add your own extensions, described here: http://xarray.pydata.org/en/stable/internals.html#extending-xarray
Since most of the WRF routines work on Datasets, we're going to create a Dataset extension. Again, the result will be a dask array, so deciding when the actual computation is performed and the metadata applied is left as an exercise to whoever implements WRF-Python 2.x, but this should illustrate the basic concepts.
First, let's create the xarray extension class, which will add a new 'wrf' attribute to the Dataset API.
```python
import xarray

_FUNC_MAP = {"tv": tv_getter}

@xarray.register_dataset_accessor("wrf")
class WRFDatasetExtension(object):
    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def getvar(self, product, omp_threads=1, **kwargs):
        return _FUNC_MAP[product](self._obj, omp_threads, **kwargs)
```
Now if you want to use this:
```python
ds = xarray.open_mfdataset("/path/to/wrf_vortex_multi/moving_nest/wrfout_d02*",
                           parallel=True)

tv = ds.wrf.getvar("tv", omp_threads=2)

# Compute the result
result = tv.compute(num_workers=4)
```
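Independently of WRF, a minimal toy accessor shows the registration mechanics end to end (the 'demo' name and total method are made up for this sketch):

```python
import numpy as np
import xarray

@xarray.register_dataset_accessor("demo")
class DemoAccessor(object):
    """Toy accessor; 'demo' and 'total' are hypothetical names."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def total(self, name):
        # A real implementation would dispatch through something like
        # _FUNC_MAP and return a lazy dask array instead of a float.
        return float(self._obj[name].sum())

ds = xarray.Dataset({"T": ("x", np.arange(4.0))})
value = ds.demo.total("T")    # 0 + 1 + 2 + 3
```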
## Notes
- Make sure you supply the dtype argument to map_blocks, or it will try to send an array of 1's to your function to determine the returned type. For some reason this caused a locked dask worker, possibly because the routine returns before the multiprocessing worker is ready for the result, or some other race condition. The threaded scheduler was fine. In any case, supplying dtype solves this problem.
- The cape and wetbulb routines open a file in Fortran to read a lookup table, since this behavior was ported directly from NCL. There appear to be thread-safety problems with this when testing the dask code with cape. In any case, reading the lookup table should be moved out of Fortran and the arrays passed to the routine instead.
- WRF-Python releases the Global Interpreter Lock (GIL) before calling Fortran routines, so the threaded scheduler is fine.
- Add OpenACC support to the WRF-Python routines and this thing should really fly... at the speed of I/O.
- A creative auto-chunker is needed; the examples above require the user to chunk manually.
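The dtype note above can be sketched with a toy function (all names and sizes here are made up). Newer dask versions probe with an empty "meta" array rather than an array of 1's, and explicitly supplying meta= alongside dtype= avoids the probe call entirely:

```python
import numpy as np
import dask.array as da

calls = []

def double(block):
    calls.append(block.shape)   # record every invocation, probes included
    return block * 2.0

x = da.ones((4,), chunks=(2,))   # two chunks of two elements

# Supplying dtype (and here meta as well) tells dask the output type up
# front, so double() is only ever called on the real chunks.
y = da.map_blocks(double, x, dtype=np.float64,
                  meta=np.array((), dtype=np.float64))
out = y.compute(scheduler="synchronous")
```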