
How to add dask support

Bill Ladwig edited this page Apr 3, 2019 · 7 revisions

See the steps below for how dask support can be added to WRF-Python in v2.0.

Let's choose a simple example: the 'tv' (virtual temperature) diagnostic from getvar, in particular its 'getter' function in g_temp.py. The code is below, with the metadata and units handling omitted:

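from wrf.constants import Constants
from wrf.extension import _tk, _tv
from wrf.util import extract_vars

def get_tv(wrfin, timeidx=0, method="cat", squeeze=True,
           cache=None, meta=True, _key=None,
           units="K"):
    # Read the required fields from the WRF file(s)
    varnames = ("T", "P", "PB", "QVAPOR")
    ncvars = extract_vars(wrfin, timeidx, varnames, method, squeeze, cache,
                          meta=False, _key=_key)

    t = ncvars["T"]        # perturbation potential temperature
    p = ncvars["P"]        # perturbation pressure
    pb = ncvars["PB"]      # base-state pressure
    qv = ncvars["QVAPOR"]  # water vapor mixing ratio

    full_t = t + Constants.T_BASE   # full potential temperature
    full_p = p + pb                 # full pressure
    tk = _tk(full_p, full_t)        # temperature (K)

    tv = _tv(tk, qv)                # virtual temperature

    return tv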

In the code above, we extract a few variables from a WRF file, compute the full pressure and potential temperature (base state + perturbation), and then compute tk (temperature in kelvin) and tv (virtual temperature). Below, we show how this could be rewritten using xarray and dask, which can then easily be wrapped into an xarray extension.
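For reference, the chain computed by get_tv can be written as the equations below. Here T, P, PB and QVAPOR ($q_v$, water vapor mixing ratio in kg/kg) are the WRF output fields, and $T_{\mathrm{base}}$ is Constants.T_BASE (300 K). The $T_k$ and $T_v$ expressions are the standard meteorological relations. It is an assumption that the Fortran routines behind _tk and _tv use exactly these forms, with $p_0 = 100000$ Pa, $R_d/c_p \approx 0.286$ and $\epsilon = R_d/R_v \approx 0.622$; check this against the WRF-Python source.

```latex
\begin{aligned}
\theta &= T + T_{\mathrm{base}}, \qquad p = P + P_B \\
T_k &= \theta \left( \frac{p}{p_0} \right)^{R_d / c_p} \\
T_v &= T_k \, \frac{\epsilon + q_v}{\epsilon \, (1 + q_v)}
\end{aligned}
```

Every step is pointwise: each grid point depends only on the inputs at that same grid point.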

  1. Create thin wrappers

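    The _tv and _tk functions in extension.py call Fortran routines and perform several common operations via wrapt decorators. Unfortunately, wrapt-decorated functions cannot be pickled. Dask's distributed and multiprocessing schedulers must serialize the functions they send to workers, so we need to create thin wrappers around them. Each wrapper imports the decorated function inside its own body, so only the plain wrapper function has to be serialized.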

    To create the dask tasks, we also need functions to pass to dask's map_blocks routine, so we'll write a wrapper for the "base + perturbation" operations as well.

    Let's start with the _tv and _tk wrappers. Since OpenMP is already supported at the Fortran level, these wrappers take an additional argument that sets the number of OpenMP threads to use. Dask can already parallelize the work by running blocks on multiple threads, processes or workers, so this argument is entirely optional. It is one way to combine dask tasks with OpenMP for the low-level computation.

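    def tk_wrap(pressure, theta, omp_threads=1):
        # Import inside the function so that only this plain wrapper, not the
        # wrapt-decorated _tk, has to be pickled when dask ships the task
        from wrf.extension import _tk, omp_set_num_threads

        # Set the number of OpenMP threads used by the Fortran routine
        omp_set_num_threads(omp_threads)
        return _tk(pressure, theta)


    def tv_wrap(temp_k, qvapor, omp_threads=1):
        # Same pattern as tk_wrap, for virtual temperature
        from wrf.extension import _tv, omp_set_num_threads

        omp_set_num_threads(omp_threads)
        return _tv(temp_k, qvapor)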

    Next, let's make a thin wrapper for the "perturbation + base" operation:

    def pert_add(base, perturbation):
        return base + perturbation
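pert_add works element by element, and the tk and tv formulas are likewise pointwise. Because of this, each block of the input arrays can be processed independently and the results stitched back together, which is the property map_blocks relies on. The NumPy-only sketch below mimics that behavior serially. It uses a hypothetical apply_per_block helper (not part of dask or WRF-Python) and random stand-in data, so it runs without WRF output or dask:

```python
import numpy as np


def pert_add(base, perturbation):
    """Same elementwise "base + perturbation" wrapper as above."""
    return base + perturbation


def apply_per_block(func, a, b, n_blocks, axis=0):
    """Serially mimic dask.array.map_blocks for two same-shaped arrays: split
    both into matching blocks along `axis`, apply `func` to each pair of
    blocks, and concatenate the per-block results."""
    a_blocks = np.array_split(a, n_blocks, axis=axis)
    b_blocks = np.array_split(b, n_blocks, axis=axis)
    results = [func(x, y) for x, y in zip(a_blocks, b_blocks)]
    return np.concatenate(results, axis=axis)


# Random stand-ins for PB and P with WRF-like
# (Time, bottom_top, south_north, west_east) dimensions
rng = np.random.default_rng(0)
pb = rng.uniform(50000.0, 100000.0, size=(4, 5, 6, 7))
p = rng.uniform(-500.0, 500.0, size=(4, 5, 6, 7))

blockwise = apply_per_block(pert_add, pb, p, n_blocks=3, axis=0)
whole = pert_add(pb, p)
print(np.array_equal(blockwise, whole))  # True
```

A non-pointwise operation, such as a vertical interpolation, would not give the same result when split along the vertical axis. That operation would need chunks that keep the whole column together.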

    In the tk_wrap, tv_wrap and pert_add wrappers, we're assuming each operation should run as a separate dask task. You could instead put all of the above into one wrapper so the whole computation runs as a single dask task per block. Determining which performs better is left as an exercise to the WRF-Python 2.x developer. I suspect a single task will perform better, since it minimizes I/O and message passing in a distributed environment.
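One way to run everything as a single task is a fused per-block function, sketched below. _tk_numpy and _tv_numpy are hypothetical NumPy stand-ins for the Fortran-backed _tk and _tv. They use the standard formulas with assumed constants that should be checked against WRF-Python's Fortran. In a real implementation you would call tk_wrap and tv_wrap inside the fused function instead.

```python
import numpy as np

# ASSUMED constants for the NumPy stand-ins below; verify against the
# constants used by WRF-Python's Fortran before relying on them
T_BASE = 300.0           # K, base-state potential temperature (Constants.T_BASE)
P0 = 100000.0            # Pa, reference pressure
RD_OVER_CP = 2.0 / 7.0   # Rd / cp for dry air, approximately 0.286
EPS = 0.622              # Rd / Rv


def _tk_numpy(pressure, theta):
    """Hypothetical NumPy stand-in for wrf.extension._tk."""
    return theta * (pressure / P0) ** RD_OVER_CP


def _tv_numpy(temp_k, qvapor):
    """Hypothetical NumPy stand-in for wrf.extension._tv (qvapor in kg/kg)."""
    return temp_k * (EPS + qvapor) / (EPS * (1.0 + qvapor))


def tv_fused(t, p, pb, qv):
    """Compute virtual temperature for one block in a single function call,
    so map_blocks creates one task per block instead of four."""
    full_t = t + T_BASE   # full potential temperature (K)
    full_p = p + pb       # full pressure (Pa)
    tk = _tk_numpy(full_p, full_t)
    # Keep the input dtype (WRF output is usually float32)
    return _tv_numpy(tk, qv).astype(t.dtype, copy=False)


# Example on a tiny in-memory block: theta = 300 K at p = p0, so tk = 300 K
shape = (2, 3, 4)
t = np.zeros(shape, dtype=np.float32)
p = np.zeros(shape, dtype=np.float32)
pb = np.full(shape, 100000.0, dtype=np.float32)
qv = np.full(shape, 0.01, dtype=np.float32)
tv = tv_fused(t, p, pb, qv)
```

With dask arrays, the fused function would be scheduled as map_blocks(tv_fused, t, p, pb, qv, dtype=t.dtype). This produces one task per block, but tasks that call OpenMP internally can no longer have their thread counts set individually.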

  2. Create a 'getter' function for the 'tv' diagnostic

    The original getter above takes several common arguments (wrfin, timeidx, method, squeeze, etc.). Much of WRF-Python 1.x reimplements functionality that xarray already provides, such as concatenating multiple files, selecting times and squeezing dimensions. A lot of this can therefore be removed in WRF-Python 2.x in favor of xarray. This getter function only needs to take an xarray Dataset and, if you want users to control OpenMP, the number of OpenMP threads.

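    For this example, we assume dask is installed and the Dataset's variables are backed by dask arrays. In a real implementation, xarray.DataArray.data returns a NumPy array when the Dataset was opened without dask (for example, xarray.open_dataset without the chunks argument), so the code should be prepared to handle both cases. Here, we're only trying to show how to make dask work.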
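A minimal way to handle both cases is a small dispatch helper. The name apply_blockwise is hypothetical, not part of WRF-Python or dask. It calls dask.array.map_blocks when any argument is a dask array, and otherwise calls the function directly on the in-memory NumPy data. The sketch assumes the array arguments are either all dask arrays or all NumPy arrays.

```python
import numpy as np

try:
    import dask.array as da
except ImportError:  # dask is optional in this sketch
    da = None


def apply_blockwise(func, *args, dtype=None, **kwargs):
    """Apply func lazily per block when any argument is a dask array,
    otherwise call it eagerly on the in-memory (NumPy) arrays.

    Scalars and keyword arguments are passed through to func unchanged.
    """
    if da is not None and any(isinstance(a, da.Array) for a in args):
        return da.map_blocks(func, *args, dtype=dtype, **kwargs)
    return func(*args, **kwargs)
```

In a getter, each map_blocks(...) call could then become apply_blockwise(...) with the same arguments. The getter would return a dask array for chunked Datasets and a NumPy array otherwise.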

    Here is the code:

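    from dask.array import map_blocks
    from wrf import Constants

    # Assumes tk_wrap, tv_wrap and pert_add from step 1 are defined in (or
    # imported into) this module
    def tv_getter(ds, omp_threads=1):
        # With a dask-backed Dataset, .data returns dask arrays (no I/O yet)
        t = ds["T"].data
        p = ds["P"].data
        pb = ds["PB"].data
        qv = ds["QVAPOR"].data

        # pert_add takes only (base, perturbation); a non-array argument such
        # as the scalar T_BASE is passed unchanged to every block
        full_t = map_blocks(pert_add, Constants.T_BASE, t, dtype=t.dtype)
        full_p = map_blocks(pert_add, pb, p, dtype=p.dtype)

        # Extra keyword arguments are forwarded to the wrapped function
        tk = map_blocks(tk_wrap, full_p, full_t, omp_threads=omp_threads,
                        dtype=p.dtype)
        tv = map_blocks(tv_wrap, tk, qv, omp_threads=omp_threads,
                        dtype=tk.dtype)

        return tv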
        

    Note that tv_getter returns a dask array with no metadata. At this point nothing has been computed; dask has only built the task graph. To actually compute the values, call compute() on the returned dask array. Adding metadata is beyond the scope of this tutorial and is left to whoever implements WRF-Python 2.x.

    Let's test the getter function (assuming tv_getter and the step 1 wrappers have already been imported):

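    import xarray

    # Open all d02 output files as a single dask-backed Dataset. WRF files have
    # no Time coordinate variable, so concatenate along the Time dimension in
    # sorted file-name order ("nested") rather than by coordinates.
    ds = xarray.open_mfdataset("/path/to/wrf_vortex_multi/moving_nest/wrfout_d02*",
                               combine="nested", concat_dim="Time",
                               parallel=True)

    # Build the task graph (nothing is computed yet)
    tv = tv_getter(ds, omp_threads=1)

    # Now actually compute tv (note: result is a numpy array)
    tv_result = tv.compute()
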
  3. Create an xarray extension if you want an object oriented API

    If you want your API to work on the Dataset object itself rather than through a separate function, xarray provides an easy way to add extensions, described here: http://xarray.pydata.org/en/stable/internals.html#extending-xarray

    Since most of the WRF routines work on Datasets, we're going to create a Dataset extension. Again, the result will be a dask array. Deciding when the computation is triggered and how metadata is applied is left to whoever implements WRF-Python 2.x, but this example illustrates the basic concepts.

    First, let's create the xarray extension class, which will add a new 'wrf' attribute to the Dataset.

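    import xarray

    # Map product names to getter functions (assumes tv_getter from step 2
    # is in scope)
    _FUNC_MAP = {'tv': tv_getter}

    @xarray.register_dataset_accessor('wrf')
    class WRFDatasetExtension(object):
        def __init__(self, xarray_obj):
            # The Dataset this accessor is attached to (ds.wrf -> ds)
            self._obj = xarray_obj

        def getvar(self, product, omp_threads=1, **kwargs):
            # Dispatch to the getter registered for this product name
            return _FUNC_MAP[product](self._obj, omp_threads, **kwargs)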

    To use the extension:

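    import xarray

    # Importing the module that defines WRFDatasetExtension registers 'wrf'
    ds = xarray.open_mfdataset("/path/to/wrf_vortex_multi/moving_nest/wrfout_d02*",
                               combine="nested", concat_dim="Time",
                               parallel=True)

    # Builds the dask task graph for tv (nothing is computed yet)
    tv = ds.wrf.getvar("tv", omp_threads=1)

    # Compute the result
    result = tv.compute()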