Fix/async callbacks #102

Open. Wants to merge 108 commits into base: develop. Changes shown are from all commits.

Commits (108)
b12fac8  Refactor Callbacks (HCookie, Sep 24, 2024)
29a8477  Update changelog (HCookie, Sep 24, 2024)
15824be  Fix TypeError (HCookie, Sep 24, 2024)
4077bf4  Move to hydra.instantiate (HCookie, Sep 25, 2024)
494d39d  Merge remote-tracking branch 'origin/develop' into fix/refactor_callb… (HCookie, Sep 25, 2024)
fe37c02  Add __all__ (HCookie, Sep 25, 2024)
2d8275c  Add to base config (HCookie, Sep 25, 2024)
230eb0e  Fix nested list (HCookie, Sep 25, 2024)
5547b20  Fix nested get issue (HCookie, Sep 26, 2024)
1d80cfb  Fix type checking (HCookie, Sep 27, 2024)
e79dfc7  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 1, 2024)
96ab74c  feat: edge plot in callbacks (JPXKQX, Oct 1, 2024)
4aeb1a5  feat: set default extra callbacks (JPXKQX, Oct 1, 2024)
816b3af  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 2, 2024)
644038f  fix: typing & refactoring (JPXKQX, Oct 2, 2024)
8356cd4  fix: remove list comprehension (JPXKQX, Oct 2, 2024)
930e4d2  Refactor according to PR (HCookie, Oct 2, 2024)
52ea91f  Update deprecation warning (HCookie, Oct 4, 2024)
0dd81b7  Merge branch 'fxi/refactor_callbacks' into feature/graph-features-cal… (JPXKQX, Oct 4, 2024)
332f746  Merge pull request #71 from ecmwf/feature/graph-features-callback (HCookie, Oct 4, 2024)
bb8b9bb  Refactor: Remove backwards compatability, (HCookie, Oct 10, 2024)
0349be2  Fix tests (HCookie, Oct 10, 2024)
1e97ff1  PR Fixes (HCookie, Oct 15, 2024)
d7f713e  Merge branch 'develop' into fix/refactor_callbacks (HCookie, Oct 18, 2024)
ebfaf90  Merge remote-tracking branch 'origin/develop' into fix/refactor_callb… (HCookie, Oct 18, 2024)
460c8ba  Update Changelog (HCookie, Oct 18, 2024)
5671c7e  Merge branch 'develop' into fix/refactor_callbacks (HCookie, Oct 21, 2024)
21c05de  Refactor rollout (#87) (HCookie, Oct 21, 2024)
3c5e144  Remove batch frequency from LongRolloutPlots (HCookie, Oct 21, 2024)
5742754  Merge remote-tracking branch 'origin/develop' into fix/refactor_callb… (HCookie, Oct 21, 2024)
8671543  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 22, 2024)
382728c  Remove TP reference (HCookie, Oct 22, 2024)
6fa66cc  Remove missing config reference (HCookie, Oct 23, 2024)
752a94b  Authentication support for mlflow sync (#51) (gmertes, Oct 11, 2024)
11af912  New mlflow authentication API (#78) (gmertes, Oct 11, 2024)
a21bea6  Update changelog (HCookie, Sep 24, 2024)
4453557  rebase (mc4117, Oct 23, 2024)
a3dd9c5  Update deprecation warning (HCookie, Oct 4, 2024)
046275b  Refactor: Remove backwards compatability, (HCookie, Oct 10, 2024)
ee2fcc3  add scatter plot (mc4117, Oct 22, 2024)
799f314  adding async (mc4117, Oct 23, 2024)
30a26b3  fix (mc4117, Oct 23, 2024)
73ce9e5  tests (mc4117, Oct 23, 2024)
a1a9019  fix failing tests (mc4117, Oct 23, 2024)
82d7b63  rm change to ds valid (mc4117, Oct 23, 2024)
3f49079  precommit hooks (mc4117, Oct 23, 2024)
474e271  fix linting (mc4117, Oct 23, 2024)
72af5f8  rebase (mc4117, Oct 23, 2024)
22a96cd  Update deprecation warning (HCookie, Oct 4, 2024)
02ac1c6  Refactor: Remove backwards compatability, (HCookie, Oct 10, 2024)
ae39622  add scatter plot (mc4117, Oct 22, 2024)
51557f6  adding async (mc4117, Oct 23, 2024)
4a97f2b  fix (mc4117, Oct 23, 2024)
bb6d3e9  tests (mc4117, Oct 23, 2024)
d53bf50  fix failing tests (mc4117, Oct 23, 2024)
c9cd810  rm change to ds valid (mc4117, Oct 23, 2024)
64a1144  precommit hooks (mc4117, Oct 23, 2024)
07efe0b  fix linting (mc4117, Oct 23, 2024)
a3122ec  Merge branch 'fix/refactor_callbacks' into fxi/async_callbacks (HCookie, Oct 23, 2024)
da4a22b  revert unnecessary config changes (mc4117, Oct 23, 2024)
ff8f8c4  fix merge conflict (mc4117, Oct 23, 2024)
443bbef  change config files (mc4117, Oct 23, 2024)
110fb64  Swapped histogram and spectrum (HCookie, Oct 23, 2024)
23cc785  Update copyright notice (HCookie, Oct 23, 2024)
bfe76f3  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 23, 2024)
6b759fe  Merge branch 'fxi/refactor_callbacks' into fxi/async_callbacks (HCookie, Oct 23, 2024)
5a6880e  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 24, 2024)
51a455d  Fix issues with split of PlotAdditionalMetrics (HCookie, Oct 24, 2024)
3318675  Merge branch 'fxi/refactor_callbacks' of github.com:ecmwf/anemoi-trai… (HCookie, Oct 24, 2024)
77bd65d  Merge remote-tracking branch 'origin/develop' into fix/refactor_callb… (HCookie, Oct 24, 2024)
3c6e1af  Fix CHANGELOG (HCookie, Oct 25, 2024)
86059a9  Fix documentation for callbacks (HCookie, Oct 25, 2024)
0bce490  Add all callback submodules to docs (HCookie, Oct 25, 2024)
f5057c6  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 25, 2024)
d6e1d9c  Apply suggestions from code review (HCookie, Oct 25, 2024)
4c145bb  Merge branch 'fxi/refactor_callbacks' into fxi/async_callbacks (mc4117, Oct 25, 2024)
6073d84  Fix init args issue in RolloutPlots (HCookie, Oct 25, 2024)
f1d883f  Add rollout_eval config (HCookie, Oct 25, 2024)
66bd306  Add training mode to rollout step (HCookie, Oct 28, 2024)
8dfe25d  Force LongRolloutPlots to plot in serial (HCookie, Oct 28, 2024)
942e06f  Add warning to LongRolloutPlots when async (HCookie, Oct 28, 2024)
8e6ab30  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 28, 2024)
84072a6  Fix asserrt calculation (HCookie, Oct 28, 2024)
42b59e5  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 28, 2024)
30dfd45  Apply post_processors before plotting in LongRolloutPlots (HCookie, Oct 28, 2024)
eebaf16  Fix reference to batch (HCookie, Oct 28, 2024)
b31da0e  Merge branch 'develop' into fxi/refactor_callbacks (HCookie, Oct 28, 2024)
8b2a30e  Fix debug config (HCookie, Oct 28, 2024)
f0b4ad7  Merge branch 'fxi/refactor_callbacks' into fxi/async_callbacks (HCookie, Oct 29, 2024)
7fe2c05  Merge remote-tracking branch 'origin/develop' into fix/refactor_callb… (HCookie, Oct 29, 2024)
44efa8c  brinding plot for mean wave direction and fixing type hinting (anaprietonem, Oct 29, 2024)
288c2a1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 29, 2024)
5392363  Merge branch 'fxi/refactor_callbacks' into fxi/async_callbacks (anaprietonem, Oct 29, 2024)
426a4e3  Merge remote-tracking branch 'origin/develop' into fxi/async_callbacks (HCookie, Oct 29, 2024)
f42dbaa  Merge remote-tracking branch 'refs/remotes/origin/fxi/async_callbacks… (anaprietonem, Oct 29, 2024)
ca6974c  Merge branch 'fxi/async_callbacks' of github.com:ecmwf/anemoi-trainin… (anaprietonem, Oct 29, 2024)
15c634d  Merge branch 'develop' into fxi/async_callbacks (anaprietonem, Oct 29, 2024)
5918a58  add changelog entry (anaprietonem, Oct 29, 2024)
40141f6  fixes for async plots to work (anaprietonem, Oct 29, 2024)
c8e7138  fix pre-commit styling (anaprietonem, Oct 29, 2024)
31e5e82  improved loop closing and readability (anaprietonem, Nov 1, 2024)
848e76b  fixing for pre-commit hooks (anaprietonem, Nov 1, 2024)
ab2919f  Merge branch 'develop' into fxi/async_callbacks (anaprietonem, Nov 1, 2024)
806b663  remove commented block (anaprietonem, Nov 4, 2024)
d2b5b17  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 4, 2024)
606194b  address sugestion for args and kwargs and missing type hints (anaprietonem, Nov 6, 2024)
2b2a4fa  update flag to datashader rather than scatter (anaprietonem, Nov 6, 2024)
97d7cbf  Merge branch 'fxi/async_callbacks' of github.com:ecmwf/anemoi-trainin… (anaprietonem, Nov 6, 2024)
CHANGELOG.md (2 additions, 0 deletions)

@@ -18,6 +18,7 @@ Keep it human-readable, your future self will thank you!
- Enable longer validation rollout than training
### Added
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
- Include option to use datashader and optimised asyncronohous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)
- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
- Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)
- Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)
@@ -95,6 +96,7 @@ Keep it human-readable, your future self will thank you!

- Updated configuration examples in documentation and corrected links - [#46](https://github.com/ecmwf/anemoi-training/pull/46)
- Remove credential prompt from mlflow login, replace with seed refresh token via web - [#78](https://github.com/ecmwf/anemoi-training/pull/78)
- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)
- Update CODEOWNERS

## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/releases/tag/0.1.0) - 2024-08-16
pyproject.toml (1 addition, 0 deletions)

@@ -46,6 +46,7 @@ dependencies = [
"anemoi-graphs",
"anemoi-models>=0.3",
"anemoi-utils[provenance]>=0.3.10",
"datashader>=0.16.3",
"einops>=0.6.1",
"hydra-core>=1.3",
"matplotlib>=3.7.1",
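The new `datashader>=0.16.3` pin backs the rasterised plotting option added in this PR: datashader aggregates large point sets onto a fixed-size raster, so rendering cost scales with pixels rather than points. For orientation, a minimal sketch of the standard datashader pipeline; the column names, canvas size, and synthetic data are illustrative only, not taken from this PR:

```python
# Minimal datashader pipeline: aggregate points onto a raster, then shade it.
# Column names ("lon", "lat", "value") and canvas size are illustrative.
import datashader as ds
import datashader.transfer_functions as tf
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "lon": rng.uniform(-180, 180, 100_000),
        "lat": rng.uniform(-90, 90, 100_000),
        "value": rng.normal(size=100_000),
    }
)

canvas = ds.Canvas(plot_width=800, plot_height=400)
agg = canvas.points(df, "lon", "lat", agg=ds.mean("value"))  # mean value per pixel
img = tf.shade(agg, how="linear")  # a fixed-size image, however many points went in
```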
src/anemoi/training/config/diagnostics/plot/detailed.yaml (1 addition, 0 deletions)
@@ -1,4 +1,5 @@
asynchronous: True # Whether to plot asynchronously
scatter: False # Choose which technique to use for plotting
frequency: # Frequency of the plotting
batch: 750
epoch: 5
src/anemoi/training/config/diagnostics/plot/simple.yaml (1 addition, 0 deletions)
@@ -1,4 +1,5 @@
asynchronous: True # Whether to plot asynchronously
scatter: False # Choose which technique to use for plotting
frequency: # Frequency of the plotting
batch: 750
epoch: 10
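Both plot config variants gain the same toggle. One wrinkle worth noting: the YAML here names the flag `scatter`, while the callback code further down reads `config.diagnostics.plot.datashader`; commit 2b2a4fa ("update flag to datashader rather than scatter") appears to be the rename that reconciles the two. A small sketch of how the callbacks consume such a config via OmegaConf; the inline dict is illustrative, standing in for the YAML after the rename:

```python
from omegaconf import OmegaConf

# Illustrative config mirroring simple.yaml / detailed.yaml after the rename.
config = OmegaConf.create(
    {"diagnostics": {"plot": {"asynchronous": True, "datashader": False}}}
)

# Mirrors the gating in BasePlotCallback.__init__ shown further down.
if config.diagnostics.plot.asynchronous:
    print("plots will be scheduled on a background event loop")
```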
src/anemoi/training/data/datamodule.py (1 addition, 2 deletions)

@@ -140,8 +140,7 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = self.rollout
r = max(r, self.config.dataloader.get("validation_rollout", 1))
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
src/anemoi/training/diagnostics/callbacks/__init__.py (1 addition, 0 deletions)

@@ -14,6 +14,7 @@
from typing import TYPE_CHECKING
from typing import Callable
from typing import Iterable
from typing import Optional

from hydra.utils import instantiate
from omegaconf import DictConfig
src/anemoi/training/diagnostics/callbacks/plot.py (81 additions, 78 deletions)

@@ -7,13 +7,13 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# ruff: noqa: ANN001

from __future__ import annotations

import asyncio
import copy
import logging
import sys
import threading
import time
import traceback
from abc import ABC
@@ -23,8 +23,6 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@@ -43,33 +41,14 @@
from anemoi.training.losses.weightedloss import BaseWeightedLoss

if TYPE_CHECKING:
from typing import Any

import pytorch_lightning as pl
from omegaconf import OmegaConf

LOGGER = logging.getLogger(__name__)


class ParallelExecutor(ThreadPoolExecutor):
"""Wraps parallel execution and provides accurate information about errors.

Extends ThreadPoolExecutor to preserve the original traceback and line number.

Reference: https://stackoverflow.com/questions/19309514/getting-original-line-
number-for-exception-in-concurrent-futures/24457608#24457608
"""

def submit(self, fn: Any, *args, **kwargs) -> Callable:
"""Submits the wrapped function instead of `fn`."""
return super().submit(self._function_wrapper, fn, *args, **kwargs)

def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable:
"""Wraps `fn` in order to preserve the traceback of any kind of."""
try:
return fn(*args, **kwargs)
except Exception as exc:
raise sys.exc_info()[0](traceback.format_exc()) from exc


class BasePlotCallback(Callback, ABC):
"""Factory for creating a callback that plots data to Experiment Logging."""

@@ -93,11 +72,21 @@ def __init__(self, config: OmegaConf) -> None:

self.plot = self._plot
self._executor = None
self._error: BaseException = None
self.datashader_plotting = config.diagnostics.plot.datashader

if self.config.diagnostics.plot.asynchronous:
self._executor = ParallelExecutor(max_workers=1)
self._error: BaseException | None = None
LOGGER.info("Setting up asynchronous plotting ...")
self.plot = self._async_plot
self._executor = ThreadPoolExecutor(max_workers=1)
self.loop_thread = threading.Thread(target=self.start_event_loop, daemon=True)
self.loop_thread.start()

def start_event_loop(self) -> None:
"""Start the event loop in a separate thread."""
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

@rank_zero_only
def _output_figure(
@@ -113,27 +102,48 @@ def _output_figure(
save_path = Path(
self.save_basedir,
"plots",
f"{tag}_epoch{epoch:03d}.png",
f"{tag}_epoch{epoch:03d}.jpg",
)

save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, dpi=100, bbox_inches="tight")
fig.canvas.draw()
image_array = np.array(fig.canvas.renderer.buffer_rgba())
plt.imsave(save_path, image_array, dpi=100)
if self.config.diagnostics.log.wandb.enabled:
import wandb

logger.experiment.log({exp_log_tag: wandb.Image(fig)})

if self.config.diagnostics.log.mlflow.enabled:
run_id = logger.run_id
logger.experiment.log_artifact(run_id, str(save_path))

plt.close(fig) # cleanup

@rank_zero_only
def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None:
"""To execute the plot function but ensuring we catch any errors."""
try:
self._plot(trainer, *args, **kwargs)
except BaseException:
import os

LOGGER.exception(traceback.format_exc())
os._exit(1) # to force exit when sanity val steps are used

def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
"""Method is called to close the threads."""
"""Teardown the callback."""
del trainer, pl_module, stage # unused
LOGGER.info("Teardown of the Plot Callback ...")

if self._executor is not None:
self._executor.shutdown(wait=True)
LOGGER.info("waiting and shutting down the executor ...")
self._executor.shutdown(wait=False, cancel_futures=True)

self.loop.call_soon_threadsafe(self.loop.stop)
self.loop_thread.join()
# Step 3: Close the asyncio event loop
self.loop_thread._stop()
self.loop_thread._delete()

def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor:
if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None:
Expand All @@ -147,31 +157,39 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
) -> None:
"""Plotting function to be implemented by subclasses."""

# Async function to run the plot function in the background thread
async def submit_plot(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None:
"""Async function or coroutine to schedule the plot function."""
loop = asyncio.get_running_loop()
# run_in_executor doesn't support keyword arguments,
await loop.run_in_executor(
self._executor,
self._plot_with_error_catching,
trainer,
args,
kwargs,
) # One because loop.run_in_executor expects positional arguments, not keyword arguments

@rank_zero_only
def _async_plot(
self,
trainer: pl.Trainer,
*args: list,
**kwargs: dict,
*args: Any,
**kwargs: Any,
) -> None:
"""To execute the plot function but ensuring we catch any errors."""
future = self._executor.submit(
self._plot,
trainer,
*args,
**kwargs,
)
# otherwise the error won't be thrown till the validation epoch is finished
try:
future.result()
except Exception:
LOGGER.exception("Critical error occurred in asynchronous plots.")
sys.exit(1)
"""Run the plot function asynchronously.

This is the function that is called by the callback. It schedules the plot
function to run in the background thread. Since we have an event loop running in
the background thread, we need to schedule the plot function to run in that
loop.
"""
asyncio.run_coroutine_threadsafe(self.submit_plot(trainer, *args, **kwargs), self.loop)


class BasePerBatchPlotCallback(BasePlotCallback):
@@ -192,26 +210,12 @@ def __init__(self, config: OmegaConf, batch_frequency: int | None = None):
super().__init__(config)
self.batch_frequency = batch_frequency or self.config.diagnostics.plot.frequency.batch

@abstractmethod
@rank_zero_only
def _plot(
def on_validation_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
**kwargs,
) -> None:
"""Plotting function to be implemented by subclasses."""

@rank_zero_only
def on_validation_batch_end(
self,
trainer,
pl_module,
output,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
**kwargs,
@@ -310,12 +314,12 @@ def __init__(
@rank_zero_only
def _plot(
self,
trainer,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx,
epoch,
batch_idx: int,
epoch: int,
) -> None:
_ = output

@@ -406,9 +410,9 @@ def _plot(
@rank_zero_only
def on_validation_batch_end(
self,
trainer,
pl_module,
output,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
) -> None:
@@ -451,18 +455,17 @@ def _plot(
pl_module: pl.LightningModule,
epoch: int,
) -> None:
_ = epoch
model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model

fig = plot_graph_node_features(model)
fig = plot_graph_node_features(model, datashader=self.datashader_plotting)

tag = "node_trainable_params"
exp_log_tag = "node_trainable_params"

self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
epoch=epoch,
tag=tag,
exp_log_tag=exp_log_tag,
)
@@ -493,7 +496,6 @@ def _plot(
pl_module: pl.LightningModule,
epoch: int,
) -> None:
_ = epoch

model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model
fig = plot_graph_edge_features(model)
Expand All @@ -504,7 +506,7 @@ def _plot(
self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
epoch=epoch,
tag=tag,
exp_log_tag=exp_log_tag,
)
@@ -785,6 +787,7 @@ def _plot(
data[0, ...].squeeze(),
data[rollout_step + 1, ...].squeeze(),
output_tensor[rollout_step, ...],
datashader=self.datashader_plotting,
precip_and_related_fields=self.precip_and_related_fields,
)

@@ -874,7 +877,7 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
@@ -956,7 +959,7 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
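Taken together, the plot.py changes replace the old ParallelExecutor (a ThreadPoolExecutor subclass that re-raised worker exceptions with their tracebacks) with a persistent asyncio event loop running in a daemon thread: `_async_plot` schedules a coroutine onto that loop with `run_coroutine_threadsafe`, and the coroutine hands the blocking matplotlib work to a single-worker ThreadPoolExecutor via `run_in_executor`. Below is a self-contained sketch of that pattern with names simplified; unlike the diff's teardown, it stops the loop and joins the thread rather than touching `Thread` internals, and it adds an explicit ready-event so nothing can be scheduled before the loop exists. It is an illustration of the pattern, not the PR's code:

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class AsyncPlotter:
    """Sketch of the BasePlotCallback pattern: loop thread plus one worker."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._loop_ready = threading.Event()
        self.loop_thread = threading.Thread(target=self._start_event_loop, daemon=True)
        self.loop_thread.start()

    def _start_event_loop(self) -> None:
        """Run a private event loop forever in this background thread."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self._loop_ready.set()
        self.loop.run_forever()

    def _render(self, tag: str) -> None:
        """Stands in for the expensive, blocking matplotlib work."""
        time.sleep(0.1)
        print(f"plotted {tag}")

    async def _submit(self, tag: str) -> None:
        # run_in_executor forwards positional arguments only, which is why the
        # diff packs the plot args/kwargs into positional parameters.
        await asyncio.get_running_loop().run_in_executor(self._executor, self._render, tag)

    def plot(self, tag: str) -> None:
        """Called from the training thread; returns immediately."""
        self._loop_ready.wait()
        asyncio.run_coroutine_threadsafe(self._submit(tag), self.loop)

    def teardown(self) -> None:
        # cancel_futures requires Python 3.9+.
        self._executor.shutdown(wait=False, cancel_futures=True)
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.loop_thread.join()


plotter = AsyncPlotter()
plotter.plot("gnn_epoch000")
time.sleep(0.5)  # let the background worker finish before tearing down
plotter.teardown()
```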
src/anemoi/training/diagnostics/maps.py (1 addition, 1 deletion)

@@ -32,7 +32,7 @@ def __init__(self) -> None:
def __call__(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
lon_rad = np.radians(lon)
lat_rad = np.radians(lat)
x = [v - 2 * np.pi if v > np.pi else v for v in lon_rad]
x = np.array([v - 2 * np.pi if v > np.pi else v for v in lon_rad], dtype=lon_rad.dtype)
y = lat_rad
return x, y

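The maps.py fix makes `x` a proper `np.ndarray` with the input's dtype instead of a Python list, so both return values behave uniformly downstream. As an aside, a fully vectorised equivalent of that comprehension would avoid the Python-level loop entirely; this is a suggestion under the same wrapping convention, not what the PR ships:

```python
import numpy as np


def wrap_longitude(lon: np.ndarray) -> np.ndarray:
    """Map longitudes in degrees to radians in (-pi, pi], as in __call__ above."""
    lon_rad = np.radians(lon)
    # Vectorised form of: [v - 2 * np.pi if v > np.pi else v for v in lon_rad]
    return np.where(lon_rad > np.pi, lon_rad - 2 * np.pi, lon_rad)


print(wrap_longitude(np.array([0.0, 90.0, 270.0])))  # [ 0.  1.5708  -1.5708]
```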