Skip to content

Commit

Permalink
Merge from main
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz committed Jan 18, 2024
2 parents 6458321 + c81b56c commit 927fe4b
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 236 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
python-version: '3.10'

- name: Check that tag version matches code version
run: |
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,5 @@ bar
*.nc
*.npz
*.json
*.req
dev/
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

The `ai-models` command is used to run AI-based weather forecasting models. These models need to be installed independently.

## Usage

Although the source code `ai-models` and its plugins are available under open sources licences, some model weights may be available under a different licence. For example some models make their weights available under the CC-BY-NC-SA 4.0 license, which does not allow commercial use. For more informations, please check the license associated with each model on their main home page, that we link from each of the corresponding plugins.

## Prerequisites

Before using the `ai-models` command, ensure you have the following prerequisites:
Expand Down
2 changes: 1 addition & 1 deletion ai_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

__version__ = "0.2.5"
__version__ = "0.3.2"
23 changes: 23 additions & 0 deletions ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,29 @@ def _main():
help="Length of forecast in hours.",
)

parser.add_argument(
"--hindcast-reference-year",
help="For encoding hincast-like outputs",
)

parser.add_argument(
"--staging-dates",
help="For encoding hincast-like outputs",
)

parser.add_argument(
"--only-gpu",
help="Fail if GPU is not available",
action="store_true",
)

parser.add_argument(
"--deterministic",
help="Fail if GPU is not available",
action="store_true",
)

# TODO: deprecate that option
parser.add_argument(
"--model-version",
default="latest",
Expand Down
8 changes: 7 additions & 1 deletion ai_models/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
import os
import pickle
import zipfile
from typing import Any

LOG = logging.getLogger(__name__)


class FakeStorage:
def __init__(self):
Expand Down Expand Up @@ -43,7 +46,7 @@ def tidy(x):
if isinstance(x, (int, float, str, bool)):
return x

return str(type(x))
return x


def peek(path):
Expand All @@ -57,6 +60,9 @@ def peek(path):
)
data_pkl = b

LOG.info(f"Found data.pkl at {data_pkl}")

with zipfile.ZipFile(path, "r") as f:
unpickler = UnpicklerWrapper(f.open(data_pkl, "r"))
x = tidy(unpickler.load())
return tidy(x)
53 changes: 48 additions & 5 deletions ai_models/inputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def _patch(self, **kargs):

@cached_property
def fields_sfc(self):
param = self.owner.param_sfc
if not param:
return cml.load_source("empty")

LOG.info(f"Loading surface fields from {self.WHERE}")

return ekd.from_source(
Expand All @@ -33,7 +37,7 @@ def fields_sfc(self):
**self._patch(
date=date,
time=time,
param=self.owner.param_sfc,
param=param,
grid=self.owner.grid,
area=self.owner.area,
**self.owner.retrieve,
Expand All @@ -45,8 +49,11 @@ def fields_sfc(self):

@cached_property
def fields_pl(self):
LOG.info(f"Loading pressure fields from {self.WHERE}")
param, level = self.owner.param_level_pl
if not (param and level):
return cml.load_source("empty")

LOG.info(f"Loading pressure fields from {self.WHERE}")
return ekd.from_source(
"multi",
[
Expand All @@ -64,9 +71,33 @@ def fields_pl(self):
],
)

@cached_property
def fields_ml(self):
param, level = self.owner.param_level_ml
if not (param and level):
return cml.load_source("empty")

LOG.info(f"Loading model fields from {self.WHERE}")
return ekd.from_source(
"multi",
[
self.ml_load_source(
**self._patch(
date=date,
time=time,
param=param,
level=level,
grid=self.owner.grid,
area=self.owner.area,
)
)
for date, time in self.owner.datetimes()
],
)

@cached_property
def all_fields(self):
return self.fields_sfc + self.fields_pl
return self.fields_sfc + self.fields_pl + self.fields_ml


class MarsInput(RequestBasedInput):
Expand All @@ -85,6 +116,11 @@ def sfc_load_source(self, **kwargs):
logging.debug("load source mars %s", kwargs)
return ekd.from_source("mars", kwargs)

def ml_load_source(self, **kwargs):
kwargs["levtype"] = "ml"
logging.debug("load source mars %s", kwargs)
return ekd.from_source("mars", kwargs)


class CdsInput(RequestBasedInput):
WHERE = "CDS"
Expand All @@ -97,6 +133,9 @@ def sfc_load_source(self, **kwargs):
kwargs["product_type"] = "reanalysis"
return ekd.from_source("cds", "reanalysis-era5-single-levels", kwargs)

def ml_load_source(self, **kwargs):
raise NotImplementedError("CDS does not support model levels")


class FileInput:
def __init__(self, owner, file, **kwargs):
Expand All @@ -105,11 +144,15 @@ def __init__(self, owner, file, **kwargs):

@cached_property
def fields_sfc(self):
return ekd.from_source("file", self.file).sel(levtype="sfc")
return self.all_fields.sel(levtype="sfc")

@cached_property
def fields_pl(self):
return ekd.from_source("file", self.file).sel(levtype="pl")
return self.all_fields.sel(levtype="pl")

@cached_property
def fields_ml(self):
return self.all_fields.sel(levtype="ml")

@cached_property
def all_fields(self):
Expand Down
Loading

0 comments on commit 927fe4b

Please sign in to comment.