From bcc3041ce45fc99d8d96458bfe293b9179570f2e Mon Sep 17 00:00:00 2001
From: Baudouin Raoult
Date: Mon, 6 May 2024 19:09:19 +0000
Subject: [PATCH] filter requests

---
Review notes (this section is not part of the commit message):

* HindcastReLabel.__init__ now spells its new parameter and attribute
  `hindcast_reference_date` (was `hind_cast_reference_date`).
  get_output() tests the key "hindcast_reference_date" and forwards
  **kwargs unchanged, so with the old spelling the key fell into **kwargs,
  the required parameter stayed unfilled and construction would raise
  TypeError. The new spelling also matches --hindcast-reference-date and
  the attribute read in Model.collect_archive_requests.
* The patch text is restored to one diff line per text line. Hunk sizes
  are unchanged, except that the last hunk no longer carries its trailing
  blank context line (header is now -134,6 +153,9).
* The post-image blob id c8bb887 of ai_models/outputs/__init__.py was
  computed before the rename above.

 ai_models/__main__.py         | 18 +++++++++++
 ai_models/model.py            | 57 +++++++++++++++++++++++++++++++++--
 ai_models/outputs/__init__.py | 32 +++++++++++++++++---
 3 files changed, 100 insertions(+), 7 deletions(-)

diff --git a/ai_models/__main__.py b/ai_models/__main__.py
index 7e3308c..1258dad 100644
--- a/ai_models/__main__.py
+++ b/ai_models/__main__.py
@@ -75,6 +75,19 @@ def _main(argv):
         help=("Dump the requests in JSON format."),
     )
 
+    parser.add_argument(
+        "--retrieve-fields-type",
+        help="Type of field to retrieve. To use with --retrieve-requests.",
+        choices=["constants", "prognostics", "all"],
+        default="all",
+    )
+
+    parser.add_argument(
+        "--retrieve-only-one-date",
+        help="Only retrieve the last date/time. To use with --retrieve-requests.",
+        action="store_true",
+    )
+
     parser.add_argument(
         "--dump-provenance",
         metavar="FILE",
@@ -188,6 +201,11 @@ def _main(argv):
         help="For encoding hincast-like outputs",
     )
 
+    parser.add_argument(
+        "--hindcast-reference-date",
+        help="For encoding hincast-like outputs",
+    )
+
     parser.add_argument(
         "--staging-dates",
         help="For encoding hincast-like outputs",
diff --git a/ai_models/model.py b/ai_models/model.py
index 5593043..3bfb412 100644
--- a/ai_models/model.py
+++ b/ai_models/model.py
@@ -125,7 +125,7 @@ def write(self, *args, **kwargs):
     def collect_archive_requests(self, written):
         if self.archive_requests:
             handle, path = written
-            if self.hindcast_reference_year:
+            if self.hindcast_reference_year or self.hindcast_reference_date:
                 # The clone is necessary because the handle
                 # does not return always return recently set keys
                 handle = handle.clone()
@@ -333,7 +333,7 @@ def print_requests(self):
         for r in requests:
             self._print_request("retrieve", r)
 
-    def _requests(self):
+    def _requests_unfiltered(self):
         result = []
 
         first = dict(
@@ -375,6 +375,55 @@ def _requests(self):
 
         return result
 
+    def _requests(self):
+
+        def filter_constant(request):
+            # We check for 'sfc' because param 'z' can be ambiguous
+            if request.get("levtype") == "sfc":
+                param = set(self.constant_fields) & set(request.get("param", []))
+                if param:
+                    request["param"] = list(param)
+                    return True
+
+            return False
+
+        def filter_prognostic(request):
+            # TODO: We assume here that prognostic fields are
+            # the ones that are not constant. This may not always be true
+            if request.get("levtype") == "sfc":
+                param = set(request.get("param", [])) - set(self.constant_fields)
+                if param:
+                    request["param"] = list(param)
+                    return True
+                return False
+
+            return True
+
+        def filter_last_date(request):
+            date, time = max(self.datetimes())
+            return request["date"] == date and request["time"] == time
+
+        def noop(request):
+            return request
+
+        filter_type = {
+            "constants": filter_constant,
+            "prognostics": filter_prognostic,
+            "all": noop,
+        }[self.retrieve_fields_type]
+
+        filter_dates = {
+            True: filter_last_date,
+            False: noop,
+        }[self.retrieve_only_one_date]
+
+        result = []
+        for r in self._requests_unfiltered():
+            if filter_type(r) and filter_dates(r):
+                result.append(r)
+
+        return result
+
     def patch_retrieve_request(self, request):
         # Overriden in subclasses if needed
         pass
@@ -413,6 +462,10 @@ def gridpoints(self):
     def start_datetime(self):
         return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime()
 
+    @property
+    def constant_fields(self):
+        raise NotImplementedError("constant_fields")
+
     def write_input_fields(
         self,
         fields,
diff --git a/ai_models/outputs/__init__.py b/ai_models/outputs/__init__.py
index b58e97b..c8bb887 100644
--- a/ai_models/outputs/__init__.py
+++ b/ai_models/outputs/__init__.py
@@ -82,10 +82,21 @@ def write(self, data, *args, check=False, **kwargs):
 
 
 class HindcastReLabel:
-    def __init__(self, owner, output, hindcast_reference_year, **kwargs):
+    def __init__(
+        self, owner, output, hindcast_reference_year, hindcast_reference_date, **kwargs
+    ):
         self.owner = owner
         self.output = output
-        self.hindcast_reference_year = int(hindcast_reference_year)
+        self.hindcast_reference_year = (
+            int(hindcast_reference_year) if hindcast_reference_year else None
+        )
+        self.hindcast_reference_date = (
+            int(hindcast_reference_date) if hindcast_reference_date else None
+        )
+        assert (
+            self.hindcast_reference_year is not None
+            or self.hindcast_reference_date is not None
+        )
 
     def write(self, *args, **kwargs):
         if "hdate" in kwargs:
@@ -105,7 +116,11 @@ def write(self, *args, **kwargs):
 
         if hdate is not None:
             # Input was a hindcast
-            referenceDate = self.hindcast_reference_year * 10000 + date % 10000
+            referenceDate = (
+                self.hindcast_reference_date
+                if self.hindcast_reference_date is not None
+                else self.hindcast_reference_year * 10000 + date % 10000
+            )
             assert date == referenceDate, (
                 date,
                 referenceDate,
@@ -115,7 +130,11 @@ def write(self, *args, **kwargs):
             kwargs["referenceDate"] = referenceDate
             kwargs["hdate"] = hdate
         else:
-            referenceDate = self.hindcast_reference_year * 10000 + date % 10000
+            referenceDate = (
+                self.hindcast_reference_date
+                if self.hindcast_reference_date is not None
+                else self.hindcast_reference_year * 10000 + date % 10000
+            )
             kwargs["referenceDate"] = referenceDate
             kwargs["hdate"] = date
 
@@ -134,6 +153,9 @@ def write(self, *args, **kwargs):
 
 def get_output(name, owner, *args, **kwargs):
     result = available_outputs()[name].load()(owner, *args, **kwargs)
-    if kwargs.get("hindcast_reference_year") is not None:
+    if (
+        kwargs.get("hindcast_reference_year") is not None
+        or kwargs.get("hindcast_reference_date") is not None
+    ):
         result = HindcastReLabel(owner, result, **kwargs)
     return result