Skip to content

Commit

Permalink
Add remote parameter lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Mar 12, 2024
1 parent ec969b6 commit 415355b
Showing 1 changed file with 62 additions and 14 deletions.
76 changes: 62 additions & 14 deletions ai_models/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import tempfile
import time
from functools import cached_property
from functools import cache, cached_property
from urllib.parse import urljoin

import climetlab as cml
Expand All @@ -17,23 +17,24 @@

class RemoteModel(Model):
def __init__(self, **kwargs):
kwargs["download_assets"] = False

super().__init__(**kwargs)

self.cfg = kwargs
self.client = RemoteClient()
self.cfg["download_assets"] = False
self.cfg["assets_extra_dir"] = None
self._param = {}
self.api = RemoteClient()

super().__init__(**self.cfg)

def run(self):
with tempfile.TemporaryDirectory() as tmpdirname:
input_file = os.path.join(tmpdirname, "input.grib")
output_file = os.path.join(tmpdirname, "output.grib")
self.all_fields.save(input_file)

self.client.input_file = input_file
self.client.output_file = output_file
self.api.input_file = input_file
self.api.output_file = output_file

self.client.run(self.cfg)
self.api.run(self.cfg)

ds = cml.load_source("file", output_file)
for field in ds:
Expand All @@ -42,6 +43,41 @@ def run(self):
def parse_model_args(self, args):
return None

def __getattr__(self, name):
return self.get_param(name)

@cache
def get_param(self, name):
return self.api.get_param(self.cfg["model"], name).get(name, None)

@cached_property
def param_level_ml(self):
return self.get_param("param_level_ml") or ([], [])

@cached_property
def param_level_pl(self):
return self.get_param("param_level_pl") or ([], [])

@cached_property
def param_sfc(self):
return self.get_param("param_sfc") or []

@cached_property
def lagged(self):
return self.get_param("lagged") or False

@cached_property
def version(self):
return self.get_param("version") or 1

@cached_property
def grib_extra_metadata(self):
return self.get_param("grib_extra_metadata") or {}

@cached_property
def retrieve(self):
return self.get_param("retrieve") or {}


class BearerAuth(requests.auth.AuthBase):
def __init__(self, token):
Expand Down Expand Up @@ -133,17 +169,29 @@ def run(self, cfg: dict):

LOG.debug("Result written to %s", self.output_file)

def _request(self, type, href, data=None, json=None, auth=None):
r = robust(type, retry_after=self._timeout)(
def get_param(self, model, param):
if isinstance(param, str):
return self._request(
requests.get, f"metadata/{model}/{param}", with_status=False
)
else:
return self._request(
requests.post, f"metadata/{model}", json=param, with_status=False
)

def _request(self, type, href, data=None, json=None, auth=None, with_status=True):
r = robust(type, retry_after=30)(
urljoin(self.url, href),
json=json,
data=data,
auth=self.auth,
timeout=self._timeout,
)

status, href = self._update_state(r)
return status, href
if with_status:
status, href = self._update_state(r)
return status, href
else:
return r.json()

def _update_state(self, response: requests.Response):
if response.status_code == 401:
Expand Down

0 comments on commit 415355b

Please sign in to comment.