Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 21, 2023
1 parent 29fda67 commit 948801a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,4 @@ bar
*.npz
*.json
*.req
dev/
35 changes: 3 additions & 32 deletions ai_models/inputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,34 +46,6 @@ def fields_sfc(self):
],
)

@cached_property
def fields_sfc_fc(self):
param, step = self.owner.param_sfc_fc
if not (param and step):
return cml.load_source("empty")

LOG.info(f"Loading FC surface fields from {self.WHERE}")
assert len(step) == 1, step # For now, we only support one step

return cml.load_source(
"multi",
[
self.sfc_load_source(
**self._patch(
type="fc",
date=date,
time=time,
param=param,
step=step,
grid=self.owner.grid,
area=self.owner.area,
**self.owner.retrieve,
)
)
for date, time in self.owner.datetimes()
],
)

@cached_property
def fields_pl(self):
param, level = self.owner.param_level_pl
Expand Down Expand Up @@ -128,15 +100,14 @@ def all_fields(self):
"all_fields",
"sfc",
len(self.fields_sfc),
"sfc_fc",
len(self.fields_sfc_fc),
"pl",
len(self.fields_pl),
"ml",
len(self.fields_ml),
'total', len(self.fields_sfc) + len(self.fields_sfc_fc) + len(self.fields_pl) + len(self.fields_ml)
"total",
len(self.fields_sfc) + len(self.fields_pl) + len(self.fields_ml),
)
return self.fields_sfc + self.fields_sfc_fc + self.fields_pl + self.fields_ml
return self.fields_sfc + self.fields_pl + self.fields_ml


class MarsInput(RequestBasedInput):
Expand Down
21 changes: 20 additions & 1 deletion ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from collections import defaultdict
from functools import cached_property

import climetlab as cml
import entrypoints
import numpy as np
from climetlab.utils.humanize import seconds
from multiurl import download

Expand Down Expand Up @@ -66,7 +68,6 @@ class Model:
param_level_ml = ([], []) # param, level
param_level_pl = ([], []) # param, level
param_sfc = [] # param
param_sfc_fc = ([], []) # param, step

def __init__(self, input, output, download_assets, **kwargs):
self.input = get_input(input, self, **kwargs)
Expand Down Expand Up @@ -394,6 +395,24 @@ def provenance(self):

return gather_provenance_info(self.asset_files)

def forcing_and_constants(self, date, param):
ds = cml.load_source(
"constants",
self.all_fields,
date=date,
param=param,
)

return ds.to_numpy(dtype=np.float32)

@cached_property
def gridpoints(self):
return len(self.all_fields[0].grid_points()[0])

@cached_property
def start_datetime(self):
return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime()


def load_model(name, **kwargs):
return available_models()[name].load()(**kwargs)
Expand Down

0 comments on commit 948801a

Please sign in to comment.