Skip to content

Commit

Permalink
new index
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 28, 2023
1 parent 948801a commit a11a211
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 14 deletions.
7 changes: 7 additions & 0 deletions ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ def _main():
action="store_true",
)

parser.add_argument(
"--deterministic",
help="Fail if GPU is not available",
action="store_true",
)


# TODO: deprecate that option
parser.add_argument(
"--model-version",
Expand Down
11 changes: 0 additions & 11 deletions ai_models/inputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,6 @@ def fields_ml(self):

@cached_property
def all_fields(self):
print(
"all_fields",
"sfc",
len(self.fields_sfc),
"pl",
len(self.fields_pl),
"ml",
len(self.fields_ml),
"total",
len(self.fields_sfc) + len(self.fields_pl) + len(self.fields_ml),
)
return self.fields_sfc + self.fields_pl + self.fields_ml


Expand Down
12 changes: 12 additions & 0 deletions ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ def device(self):

return device

def torch_deterministic_mode(self):
import torch
LOG.info("Setting deterministic mode for PyTorch")
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)


@cached_property
def providers(self):
import platform
Expand Down Expand Up @@ -403,6 +413,8 @@ def forcing_and_constants(self, date, param):
param=param,
)

assert len(ds) == len(param), (len(ds), len(param), date)

return ds.to_numpy(dtype=np.float32)

@cached_property
Expand Down
7 changes: 5 additions & 2 deletions ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging

import climetlab as cml
import numpy as np

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -36,8 +37,10 @@ def __init__(self, owner, path, metadata, **kwargs):
**self.grib_keys,
)

def write(self, *args, **kwargs):
return self.output.write(*args, **kwargs)
def write(self, data, *args, **kwargs):
assert not np.isnan(data).any(), (args, kwargs)
return self.output.write(data, *args, **kwargs)



class HindcastReLabel:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def read(fname):
include_package_data=True,
install_requires=[
"entrypoints",
"climetlab>=0.17.1, <= 0.18.6",
"climetlab>=0.19.0",
"multiurl",
"ecmwflibs>=0.5.3",
"gputil",
Expand Down

0 comments on commit a11a211

Please sign in to comment.