diff --git a/ai_models/__main__.py b/ai_models/__main__.py index 32e2cf4..fa23868 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -191,6 +191,13 @@ def _main(): action="store_true", ) + parser.add_argument( + "--deterministic", + help="Fail if GPU is not available", + action="store_true", + ) + + # TODO: deprecate that option parser.add_argument( "--model-version", diff --git a/ai_models/inputs/__init__.py b/ai_models/inputs/__init__.py index b4d0196..d66553b 100644 --- a/ai_models/inputs/__init__.py +++ b/ai_models/inputs/__init__.py @@ -96,17 +96,6 @@ def fields_ml(self): @cached_property def all_fields(self): - print( - "all_fields", - "sfc", - len(self.fields_sfc), - "pl", - len(self.fields_pl), - "ml", - len(self.fields_ml), - "total", - len(self.fields_sfc) + len(self.fields_pl) + len(self.fields_ml), - ) return self.fields_sfc + self.fields_pl + self.fields_ml diff --git a/ai_models/model.py b/ai_models/model.py index 31dd8d3..7f2df8c 100644 --- a/ai_models/model.py +++ b/ai_models/model.py @@ -179,6 +179,16 @@ def device(self): return device + def torch_deterministic_mode(self): + import torch + LOG.info("Setting deterministic mode for PyTorch") + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + + @cached_property def providers(self): import platform @@ -403,6 +413,8 @@ def forcing_and_constants(self, date, param): param=param, ) + assert len(ds) == len(param), (len(ds), len(param), date) + return ds.to_numpy(dtype=np.float32) @cached_property diff --git a/ai_models/outputs/__init__.py b/ai_models/outputs/__init__.py index 619aca5..bad612e 100644 --- a/ai_models/outputs/__init__.py +++ b/ai_models/outputs/__init__.py @@ -9,6 +9,7 @@ import logging import climetlab as cml +import numpy as np LOG = logging.getLogger(__name__) @@ -36,8 +37,10 @@ def __init__(self, owner, path, metadata, **kwargs): **self.grib_keys, ) - def write(self, *args, **kwargs): - return self.output.write(*args, **kwargs) + def write(self, data, *args, **kwargs): + assert not np.isnan(data).any(), (args, kwargs) + return self.output.write(data, *args, **kwargs) + class HindcastReLabel: diff --git a/setup.py b/setup.py index cd0bb03..7d14fe3 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ def read(fname): include_package_data=True, install_requires=[ "entrypoints", - "climetlab>=0.17.1, <= 0.18.6", + "climetlab>=0.19.0", "multiurl", "ecmwflibs>=0.5.3", "gputil",