Skip to content

Commit

Permalink
Experimental support for netcdf
Browse files Browse the repository at this point in the history
- Adds new netcdf output
- Checks for attrs to allow none grib data
  • Loading branch information
HCookie committed Oct 31, 2024
1 parent c5847cd commit 699f37c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@ opendata = "ai_models.inputs.opendata:OpenDataInput"
[project.entry-points."ai_models.output"]
file = "ai_models.outputs:FileOutput"
none = "ai_models.outputs:NoneOutput"
netcdf = "ai_models.outputs:NetCDFOutput"
4 changes: 4 additions & 0 deletions src/ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ def write_input_fields(
if ignore is None:
ignore = []

if all(map(lambda x: x is None, fields.metadata("shortName", default = None))):
LOG.warning("Could not find 'shortName' in metadata. Are you using a grib input? Skipping writing input fields")
return

with self.timer("Writing step 0"):
for field in fields:
if field.metadata("shortName") in ignore:
Expand Down
35 changes: 35 additions & 0 deletions src/ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import warnings
from functools import cached_property
from collections import defaultdict

import earthkit.data as ekd
import entrypoints
Expand Down Expand Up @@ -110,6 +111,40 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
LOG.info("Writing results to %s", self.path)

class NetCDFOutput(Output):
def __init__(self, owner, path, metadata, **kwargs):
metadata.setdefault("stream", "oper")
metadata.setdefault("expver", owner.expver)
metadata.setdefault("class", "ml")

self.path = path
self.owner = owner
self.metadata = metadata

self._outputs = defaultdict(list)

def write(self, data, *args, check = False, **kwargs):
template = kwargs.pop("template")
step = kwargs.pop("step")
import xarray as xr

xarray_obj: xr.DataArray = template.to_xarray()
xarray_obj.data = data
xarray_obj = xarray_obj.assign_coords(step = step)
if 'pl' in xarray_obj.coords:
xarray_obj = xarray_obj.expand_dims('pl')
if 'ml' in xarray_obj.coords:
xarray_obj = xarray_obj.expand_dims('ml')

self._outputs[step].append(xarray_obj)

def flush(self, *args, **kwargs):
import xarray as xr

output = xr.concat(map(xr.merge, self._outputs.values()), dim = 'step')
output.attrs.update(self.metadata)
output.to_netcdf(self.path)


class NoneOutput(Output):
def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit 699f37c

Please sign in to comment.