Skip to content

Commit

Permalink
filter requests
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 6, 2024
1 parent c8bed62 commit bcc3041
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 7 deletions.
18 changes: 18 additions & 0 deletions ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def _main(argv):
help=("Dump the requests in JSON format."),
)

parser.add_argument(
"--retrieve-fields-type",
help="Type of field to retrieve. To use with --retrieve-requests.",
choices=["constants", "prognostics", "all"],
default="all",
)

parser.add_argument(
"--retrieve-only-one-date",
help="Only retrieve the last date/time. To use with --retrieve-requests.",
action="store_true",
)

parser.add_argument(
"--dump-provenance",
metavar="FILE",
Expand Down Expand Up @@ -188,6 +201,11 @@ def _main(argv):
help="For encoding hincast-like outputs",
)

parser.add_argument(
"--hindcast-reference-date",
help="For encoding hincast-like outputs",
)

parser.add_argument(
"--staging-dates",
help="For encoding hincast-like outputs",
Expand Down
57 changes: 55 additions & 2 deletions ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def write(self, *args, **kwargs):
def collect_archive_requests(self, written):
if self.archive_requests:
handle, path = written
if self.hindcast_reference_year:
if self.hindcast_reference_year or self.hindcast_reference_date:
# The clone is necessary because the handle
# does not always return recently set keys
handle = handle.clone()
Expand Down Expand Up @@ -333,7 +333,7 @@ def print_requests(self):
for r in requests:
self._print_request("retrieve", r)

def _requests(self):
def _requests_unfiltered(self):
result = []

first = dict(
Expand Down Expand Up @@ -375,6 +375,55 @@ def _requests(self):

return result

def _requests(self):
    """Return the retrieve requests narrowed by the retrieve options.

    Starts from ``self._requests_unfiltered()`` and keeps a request only
    if it passes both:

    * the field-type filter selected by ``self.retrieve_fields_type``
      (``"constants"``, ``"prognostics"`` or ``"all"``), and
    * when ``self.retrieve_only_one_date`` is true, the date filter,
      which keeps requests whose ``date``/``time`` equal
      ``max(self.datetimes())``.

    The ``"param"`` entry of a surface (``levtype == "sfc"``) request may
    be narrowed in place; the original parameter order is preserved so
    that the resulting requests are reproducible from run to run.

    Raises:
        KeyError: if ``self.retrieve_fields_type`` is not one of the
            three supported names.
    """

    def keep_everything(request):
        return True

    def keep_constants(request):
        # We check for 'sfc' because param 'z' can be ambiguous
        if request.get("levtype") != "sfc":
            return False

        wanted = [
            p for p in dict.fromkeys(request.get("param", [])) if p in constants
        ]
        if not wanted:
            return False

        request["param"] = wanted
        return True

    def keep_prognostics(request):
        # TODO: We assume here that prognostic fields are
        # the ones that are not constant. This may not always be true
        if request.get("levtype") != "sfc":
            return True

        wanted = [
            p for p in dict.fromkeys(request.get("param", [])) if p not in constants
        ]
        if not wanted:
            return False

        request["param"] = wanted
        return True

    select_fields = {
        "constants": keep_constants,
        "prognostics": keep_prognostics,
        "all": keep_everything,
    }[self.retrieve_fields_type]

    # ``constant_fields`` may not be implemented by every model, so it is
    # only read when a field filter actually needs it, and only once.
    if select_fields is keep_everything:
        constants = frozenset()
    else:
        constants = frozenset(self.constant_fields)

    if self.retrieve_only_one_date:
        # Computed once instead of once per request.
        last_date, last_time = max(self.datetimes())

        def select_dates(request):
            return request["date"] == last_date and request["time"] == last_time

    else:
        select_dates = keep_everything

    return [
        r
        for r in self._requests_unfiltered()
        if select_fields(r) and select_dates(r)
    ]

def patch_retrieve_request(self, request):
# Overridden in subclasses if needed
pass
Expand Down Expand Up @@ -413,6 +462,10 @@ def gridpoints(self):
def start_datetime(self):
return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime()

@property
def constant_fields(self):
    """Names of the model's constant fields.

    The base implementation always raises ``NotImplementedError``; a
    model that wants its retrieve requests filtered by field type has to
    provide this property.
    """
    error = NotImplementedError("constant_fields")
    raise error

def write_input_fields(
self,
fields,
Expand Down
32 changes: 27 additions & 5 deletions ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,21 @@ def write(self, data, *args, check=False, **kwargs):


class HindcastReLabel:
def __init__(
    self,
    owner,
    output,
    hindcast_reference_year=None,
    hind_cast_reference_date=None,
    hindcast_reference_date=None,
    **kwargs,
):
    """Wrap ``output`` so that written fields are relabelled as hindcasts.

    Args:
        owner: Stored as ``self.owner``.
        output: The wrapped output, stored as ``self.output``.
        hindcast_reference_year: Reference year; converted with ``int``
            when truthy, otherwise stored as ``None``.
        hind_cast_reference_date: Reference date; converted with ``int``
            when truthy, otherwise stored as ``None``. Kept for backward
            compatibility with the original parameter spelling.
        hindcast_reference_date: Same value under the spelling that
            ``get_output`` forwards in its keyword arguments. Used only
            when ``hind_cast_reference_date`` is ``None``.
        **kwargs: Ignored.

    Raises:
        AssertionError: if neither a reference year nor a reference date
            is given.
    """
    # ``get_output`` passes ``hindcast_reference_date``; without this alias
    # the value was swallowed by ``**kwargs`` and construction failed with
    # a missing-argument TypeError.
    if hind_cast_reference_date is None:
        hind_cast_reference_date = hindcast_reference_date

    self.owner = owner
    self.output = output
    self.hindcast_reference_year = (
        int(hindcast_reference_year) if hindcast_reference_year else None
    )
    # Attribute name kept as-is: ``write`` reads ``self.hind_cast_reference_date``.
    self.hind_cast_reference_date = (
        int(hind_cast_reference_date) if hind_cast_reference_date else None
    )
    assert (
        self.hindcast_reference_year is not None
        or self.hind_cast_reference_date is not None
    ), "HindcastReLabel needs a hindcast reference year or a hindcast reference date"

def write(self, *args, **kwargs):
if "hdate" in kwargs:
Expand All @@ -105,7 +116,11 @@ def write(self, *args, **kwargs):

if hdate is not None:
# Input was a hindcast
referenceDate = self.hindcast_reference_year * 10000 + date % 10000
referenceDate = (
self.hind_cast_reference_date
if self.hind_cast_reference_date is not None
else self.hindcast_reference_year * 10000 + date % 10000
)
assert date == referenceDate, (
date,
referenceDate,
Expand All @@ -115,7 +130,11 @@ def write(self, *args, **kwargs):
kwargs["referenceDate"] = referenceDate
kwargs["hdate"] = hdate
else:
referenceDate = self.hindcast_reference_year * 10000 + date % 10000
referenceDate = (
self.hind_cast_reference_date
if self.hind_cast_reference_date is not None
else self.hindcast_reference_year * 10000 + date % 10000
)
kwargs["referenceDate"] = referenceDate
kwargs["hdate"] = date

Expand All @@ -134,7 +153,10 @@ def write(self, *args, **kwargs):

def get_output(name, owner, *args, **kwargs):
result = available_outputs()[name].load()(owner, *args, **kwargs)
if kwargs.get("hindcast_reference_year") is not None:
if (
kwargs.get("hindcast_reference_year") is not None
or kwargs.get("hindcast_reference_date") is not None
):
result = HindcastReLabel(owner, result, **kwargs)
return result

Expand Down

0 comments on commit bcc3041

Please sign in to comment.