Skip to content

Commit

Permalink
add quiet option (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
malmans2 authored Apr 29, 2024
1 parent 2a155e2 commit e2f2012
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions c3s_eqc_automatic_quality_control/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
# limitations under the License.

import calendar
import contextlib
import fnmatch
import functools
import itertools
import os
import pathlib
from collections.abc import Callable
from collections.abc import Callable, Iterator
from typing import Any

import cacholote
Expand All @@ -48,6 +50,19 @@
_SORTED_REQUEST_PARAMETERS = ("area", "grid")


@contextlib.contextmanager
def _set_env(**kwargs: Any) -> Iterator[None]:
old_environ = dict(os.environ)
try:
os.environ.update(
{k.upper(): str(v) for k, v in kwargs.items() if v is not None}
)
yield
finally:
os.environ.clear()
os.environ.update(old_environ)


def compute_stop_date(switch_month_day: int | None = None) -> pd.Period:
today = pd.Timestamp.today()
if switch_month_day is None:
Expand Down Expand Up @@ -306,7 +321,8 @@ def get_sources(
) -> list[str]:
source: set[str] = set()

for request in tqdm.tqdm(request_list, disable=len(request_list) <= 1):
disable = os.getenv("TQDM_DISABLE", "False") == "True"
for request in tqdm.tqdm(request_list, disable=disable):
data = _cached_retrieve(collection_id, request)
if content := getattr(data, "_content", None):
source.update(map(str, content))
Expand Down Expand Up @@ -479,6 +495,7 @@ def download_and_transform(
n_jobs: int | None = None,
invalidate_cache: bool | None = None,
cached_open_mfdataset_kwargs: bool | dict[str, Any] = {},
quiet: bool = False,
**open_mfdataset_kwargs: Any,
) -> xr.Dataset:
"""
Expand Down Expand Up @@ -513,13 +530,17 @@ def download_and_transform(
cached_open_mfdataset_kwargs: bool | dict
Kwargs to be passed on to xr.open_mfdataset for cached files.
If True, use open_mfdataset_kwargs used for raw files.
quiet: bool
Whether to disable progress bars.
**open_mfdataset_kwargs:
Kwargs to be passed on to xr.open_mfdataset for raw files.
Returns
-------
xr.Dataset
"""
assert isinstance(quiet, bool)

if n_jobs is None:
n_jobs = N_JOBS

Expand Down Expand Up @@ -557,12 +578,15 @@ def download_and_transform(
if use_cache and transform_chunks:
# Cache each chunk transformed
sources = []
for request in tqdm.tqdm(request_list):
for request in tqdm.tqdm(request_list, disable=quiet):
if invalidate_cache:
cacholote.delete(
func.func, *func.args, request_list=[request], **func.keywords
)
with cacholote.config.set(return_cache_entry=True):
with (
cacholote.config.set(return_cache_entry=True),
_set_env(tqdm_disable=True),
):
sources.append(func(request_list=[request]).result["args"][0]["href"])
ds = xr.open_mfdataset(sources, **cached_open_mfdataset_kwargs)
else:
Expand All @@ -571,7 +595,8 @@ def download_and_transform(
cacholote.delete(
func.func, *func.args, request_list=request_list, **func.keywords
)
ds = func(request_list=request_list)
with _set_env(tqdm_disable=quiet):
ds = func(request_list=request_list)

ds.attrs.pop("coordinates", None) # Previously added to guarantee roundtrip
return ds

0 comments on commit e2f2012

Please sign in to comment.