From 55c42ac52157b8b55127e615bced0a07570758eb Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 2 Sep 2024 07:47:54 +0000 Subject: [PATCH 01/14] feat: compute usable dates Co-authored-by: Magnus Sikora --- src/anemoi/training/data/dataset.py | 17 +++++--- src/anemoi/training/utils/usable_indices.py | 46 +++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 src/anemoi/training/utils/usable_indices.py diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index caaa986b..883d7593 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -19,6 +19,7 @@ from torch.utils.data import get_worker_info from anemoi.training.utils.seeding import get_base_seed +from anemoi.training.utils.usable_indices import get_usable_indices LOGGER = logging.getLogger(__name__) @@ -110,6 +111,15 @@ def resolution(self) -> dict: """Return dataset resolution.""" return self.data.resolution + @cached_property + def valid_dates(self) -> np.ndarray: + """Return valid dates. + + If there are no missing dates, total number of valid ICs is + dataset length minus rollout minus additional multistep inputs. + """ + return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def per_worker_init(self, n_workers: int, worker_id: int) -> None: """Called by worker_init_func on each copy of dataset. @@ -125,11 +135,8 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: """ self.worker_id = worker_id - # Total number of valid ICs is dataset length minus rollout minus additional multistep inputs - len_corrected = len(self.data) - (self.rollout + (self.multi_step - 1)) * self.timeincrement - # Divide this equally across shards (one shard per group!) - shard_size = len_corrected // self.model_comm_num_groups + shard_size = len(self.valid_dates) // self.model_comm_num_groups shard_start = self.model_comm_group_id * shard_size + (self.multi_step - 1) * self.timeincrement shard_end = min((self.model_comm_group_id + 1) * shard_size, len(self.data) - self.rollout * self.timeincrement) @@ -149,7 +156,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = np.arange(low, high, dtype=np.uint32) + self.chunk_index_range = self.valid_dates[np.arange(low, high, dtype=np.uint32)] # each worker must have a different seed for its random number generator, # otherwise all the workers will output exactly the same data diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py new file mode 100644 index 00000000..8989fd5f --- /dev/null +++ b/src/anemoi/training/utils/usable_indices.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import numpy as np + + +def get_usable_indices( + missing_indices: set[int] | None, + series_length: int, + rollout: int, + multistep: int, + timeincrement: int = 1, +) -> np.ndarray: + """Get the usable indices of a series whit missing indices. + Parameters + ---------- + missing_indices : set[int] + Dataset to be used. + series_length : int + Length of the series. + rollout : int + Number of steps to roll out. + multistep : int + Number of previous indices to include as predictors. + timeincrement : int + Time increment, by default 1. + Returns + ------- + usable_indices : np.array + Array of usable indices. + """ + prev_invalid_dates = (multistep - 1) * timeincrement + next_invalid_dates = rollout * timeincrement + + usable_indices = np.arange(series_length) # set of all indices + + # No missing indices + if missing_indices is None: + return usable_indices[prev_invalid_dates : series_length - next_invalid_dates] + + missing_indices |= {-1, len(missing_indices)} # to filter initial and final indices + + # Missing indices + for i in missing_indices: + usable_indices = usable_indices[(usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)] + + return usable_indices From 7ad7784c13ab8a1ccd24d1937b70223d518bdca8 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 2 Sep 2024 08:28:27 +0000 Subject: [PATCH 02/14] test: usable indices --- src/anemoi/training/data/dataset.py | 2 +- src/anemoi/training/utils/usable_indices.py | 15 ++++++- tests/utils/test_usable_indices.py | 45 +++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 tests/utils/test_usable_indices.py diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 883d7593..bc3633ab 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -114,7 +114,7 @@ def resolution(self) -> dict: @cached_property def valid_dates(self) -> np.ndarray: """Return valid dates. - + If there are no missing dates, total number of valid ICs is dataset length minus rollout minus additional multistep inputs. """ diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py index 8989fd5f..5a126d9a 100644 --- a/src/anemoi/training/utils/usable_indices.py +++ b/src/anemoi/training/utils/usable_indices.py @@ -1,3 +1,10 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + from __future__ import annotations import numpy as np @@ -11,6 +18,7 @@ def get_usable_indices( timeincrement: int = 1, ) -> np.ndarray: """Get the usable indices of a series whit missing indices. + Parameters ---------- missing_indices : set[int] @@ -23,6 +31,7 @@ def get_usable_indices( Number of previous indices to include as predictors. timeincrement : int Time increment, by default 1. + Returns ------- usable_indices : np.array @@ -37,10 +46,12 @@ def get_usable_indices( if missing_indices is None: return usable_indices[prev_invalid_dates : series_length - next_invalid_dates] - missing_indices |= {-1, len(missing_indices)} # to filter initial and final indices + missing_indices |= {-1, series_length} # to filter initial and final indices # Missing indices for i in missing_indices: - usable_indices = usable_indices[(usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)] + usable_indices = usable_indices[ + (usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates) + ] return usable_indices diff --git a/tests/utils/test_usable_indices.py b/tests/utils/test_usable_indices.py new file mode 100644 index 00000000..6bc5c83f --- /dev/null +++ b/tests/utils/test_usable_indices.py @@ -0,0 +1,45 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import numpy as np + +from anemoi.training.utils.usable_indices import get_usable_indices + + +def test_get_usable_indices() -> None: + """Test get_usable_indices function.""" + # Test base case + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=1, timeincrement=1) + expected_values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) + assert np.allclose(valid_indices, expected_values) + + # Test multiple steps inputs + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=2, timeincrement=1) + expected_values = np.array([1, 2, 3, 4, 5, 6, 7, 8]) + assert np.allclose(valid_indices, expected_values) + + # Test roll out + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=2, multistep=1, timeincrement=1) + expected_values = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + assert np.allclose(valid_indices, expected_values) + + # Test longer time increments + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=2, timeincrement=2) + expected_values = np.array([2, 3, 4, 5, 6, 7]) + assert np.allclose(valid_indices, expected_values) + + # Test missing indices + missing_indices = {7, 5} + valid_indices = get_usable_indices( + missing_indices=missing_indices, + series_length=10, + rollout=1, + multistep=2, + timeincrement=1, + ) + expected_values = np.array([1, 2, 3]) + assert np.allclose(valid_indices, expected_values) From ef5da0e661d304c57244900d5d54582474e3333e Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:25:16 +0100 Subject: [PATCH 03/14] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index de7b1209..7ee3a9a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Keep it human-readable, your future self will thank you! - Enable the callback for plotting a histogram for variables containing NaNs - Enforce same binning for histograms comparing true data to predicted data +- Support training for datasets with missing time steps ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/compare/x.x.x...0.1.0) - 2024-08-16 From 09e5cba22ca77a9d407c963e68e1dfa6287062f2 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 5 Sep 2024 15:06:43 +0000 Subject: [PATCH 04/14] feat: add link to PR --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ee3a9a4..e1642f08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ Keep it human-readable, your future self will thank you! - Enable the callback for plotting a histogram for variables containing NaNs - Enforce same binning for histograms comparing true data to predicted data -- Support training for datasets with missing time steps +- Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/compare/x.x.x...0.1.0) - 2024-08-16 From 30913fe7136d367902d4eab3c89627a1b5b3348b Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 6 Sep 2024 14:18:43 +0000 Subject: [PATCH 05/14] refactor: get_usable_indices --- src/anemoi/training/utils/usable_indices.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py index 5a126d9a..7bdd5cbd 100644 --- a/src/anemoi/training/utils/usable_indices.py +++ b/src/anemoi/training/utils/usable_indices.py @@ -42,9 +42,8 @@ def get_usable_indices( usable_indices = np.arange(series_length) # set of all indices - # No missing indices if missing_indices is None: - return usable_indices[prev_invalid_dates : series_length - next_invalid_dates] + missing_indices = set() missing_indices |= {-1, series_length} # to filter initial and final indices From e17e3a4af1858e3296f6e16d849ce24685d81a18 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 6 Sep 2024 14:19:41 +0000 Subject: [PATCH 06/14] fix: shard_start/end Co-authored-by: Magnus Sikora --- src/anemoi/training/data/dataset.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index bc3633ab..e2aa12bd 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -112,11 +112,16 @@ def resolution(self) -> dict: return self.data.resolution @cached_property - def valid_dates(self) -> np.ndarray: - """Return valid dates. + def valid_date_indices(self) -> np.ndarray: + """Return valid date indices. + + A date t is valid if we can sample the sequence + (t - multistep + 1, ..., t + rollout) + without missing data (if time_increment is 1). If there are no missing dates, total number of valid ICs is - dataset length minus rollout minus additional multistep inputs. + dataset length minus rollout minus additional multistep inputs + (if time_increment is 1). """ return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) @@ -136,9 +141,9 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: self.worker_id = worker_id # Divide this equally across shards (one shard per group!) - shard_size = len(self.valid_dates) // self.model_comm_num_groups - shard_start = self.model_comm_group_id * shard_size + (self.multi_step - 1) * self.timeincrement - shard_end = min((self.model_comm_group_id + 1) * shard_size, len(self.data) - self.rollout * self.timeincrement) + shard_size = len(self.valid_date_indices) // self.model_comm_num_groups + shard_start = self.model_comm_group_id * shard_size + shard_end = (self.model_comm_group_id + 1) * shard_size shard_len = shard_end - shard_start self.n_samples_per_worker = shard_len // n_workers @@ -156,7 +161,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = self.valid_dates[np.arange(low, high, dtype=np.uint32)] + self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)] # each worker must have a different seed for its random number generator, # otherwise all the workers will output exactly the same data From f79b0c48ea0a9ceaf0960e00e367ff34deb6f4fd Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 2 Sep 2024 07:47:54 +0000 Subject: [PATCH 07/14] feat: compute usable dates Co-authored-by: Magnus Sikora --- src/anemoi/training/data/dataset.py | 17 +++++--- src/anemoi/training/utils/usable_indices.py | 46 +++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 src/anemoi/training/utils/usable_indices.py diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index caaa986b..883d7593 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -19,6 +19,7 @@ from torch.utils.data import get_worker_info from anemoi.training.utils.seeding import get_base_seed +from anemoi.training.utils.usable_indices import get_usable_indices LOGGER = logging.getLogger(__name__) @@ -110,6 +111,15 @@ def resolution(self) -> dict: """Return dataset resolution.""" return self.data.resolution + @cached_property + def valid_dates(self) -> np.ndarray: + """Return valid dates. + + If there are no missing dates, total number of valid ICs is + dataset length minus rollout minus additional multistep inputs. + """ + return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def per_worker_init(self, n_workers: int, worker_id: int) -> None: """Called by worker_init_func on each copy of dataset. @@ -125,11 +135,8 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: """ self.worker_id = worker_id - # Total number of valid ICs is dataset length minus rollout minus additional multistep inputs - len_corrected = len(self.data) - (self.rollout + (self.multi_step - 1)) * self.timeincrement - # Divide this equally across shards (one shard per group!) - shard_size = len_corrected // self.model_comm_num_groups + shard_size = len(self.valid_dates) // self.model_comm_num_groups shard_start = self.model_comm_group_id * shard_size + (self.multi_step - 1) * self.timeincrement shard_end = min((self.model_comm_group_id + 1) * shard_size, len(self.data) - self.rollout * self.timeincrement) @@ -149,7 +156,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = np.arange(low, high, dtype=np.uint32) + self.chunk_index_range = self.valid_dates[np.arange(low, high, dtype=np.uint32)] # each worker must have a different seed for its random number generator, # otherwise all the workers will output exactly the same data diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py new file mode 100644 index 00000000..8989fd5f --- /dev/null +++ b/src/anemoi/training/utils/usable_indices.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import numpy as np + + +def get_usable_indices( + missing_indices: set[int] | None, + series_length: int, + rollout: int, + multistep: int, + timeincrement: int = 1, +) -> np.ndarray: + """Get the usable indices of a series whit missing indices. + Parameters + ---------- + missing_indices : set[int] + Dataset to be used. + series_length : int + Length of the series. + rollout : int + Number of steps to roll out. + multistep : int + Number of previous indices to include as predictors. + timeincrement : int + Time increment, by default 1. + Returns + ------- + usable_indices : np.array + Array of usable indices. + """ + prev_invalid_dates = (multistep - 1) * timeincrement + next_invalid_dates = rollout * timeincrement + + usable_indices = np.arange(series_length) # set of all indices + + # No missing indices + if missing_indices is None: + return usable_indices[prev_invalid_dates : series_length - next_invalid_dates] + + missing_indices |= {-1, len(missing_indices)} # to filter initial and final indices + + # Missing indices + for i in missing_indices: + usable_indices = usable_indices[(usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)] + + return usable_indices From 86adf33862eea1d86b283b417101137d75dfe152 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 2 Sep 2024 08:28:27 +0000 Subject: [PATCH 08/14] test: usable indices --- src/anemoi/training/data/dataset.py | 2 +- src/anemoi/training/utils/usable_indices.py | 15 ++++++- tests/utils/test_usable_indices.py | 45 +++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 tests/utils/test_usable_indices.py diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 883d7593..bc3633ab 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -114,7 +114,7 @@ def resolution(self) -> dict: @cached_property def valid_dates(self) -> np.ndarray: """Return valid dates. - + If there are no missing dates, total number of valid ICs is dataset length minus rollout minus additional multistep inputs. """ diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py index 8989fd5f..5a126d9a 100644 --- a/src/anemoi/training/utils/usable_indices.py +++ b/src/anemoi/training/utils/usable_indices.py @@ -1,3 +1,10 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + from __future__ import annotations import numpy as np @@ -11,6 +18,7 @@ def get_usable_indices( timeincrement: int = 1, ) -> np.ndarray: """Get the usable indices of a series whit missing indices. + Parameters ---------- missing_indices : set[int] @@ -23,6 +31,7 @@ def get_usable_indices( Number of previous indices to include as predictors. timeincrement : int Time increment, by default 1. + Returns ------- usable_indices : np.array @@ -37,10 +46,12 @@ def get_usable_indices( if missing_indices is None: return usable_indices[prev_invalid_dates : series_length - next_invalid_dates] - missing_indices |= {-1, len(missing_indices)} # to filter initial and final indices + missing_indices |= {-1, series_length} # to filter initial and final indices # Missing indices for i in missing_indices: - usable_indices = usable_indices[(usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)] + usable_indices = usable_indices[ + (usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates) + ] return usable_indices diff --git a/tests/utils/test_usable_indices.py b/tests/utils/test_usable_indices.py new file mode 100644 index 00000000..6bc5c83f --- /dev/null +++ b/tests/utils/test_usable_indices.py @@ -0,0 +1,45 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import numpy as np + +from anemoi.training.utils.usable_indices import get_usable_indices + + +def test_get_usable_indices() -> None: + """Test get_usable_indices function.""" + # Test base case + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=1, timeincrement=1) + expected_values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) + assert np.allclose(valid_indices, expected_values) + + # Test multiple steps inputs + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=2, timeincrement=1) + expected_values = np.array([1, 2, 3, 4, 5, 6, 7, 8]) + assert np.allclose(valid_indices, expected_values) + + # Test roll out + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=2, multistep=1, timeincrement=1) + expected_values = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + assert np.allclose(valid_indices, expected_values) + + # Test longer time increments + valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=2, timeincrement=2) + expected_values = np.array([2, 3, 4, 5, 6, 7]) + assert np.allclose(valid_indices, expected_values) + + # Test missing indices + missing_indices = {7, 5} + valid_indices = get_usable_indices( + missing_indices=missing_indices, + series_length=10, + rollout=1, + multistep=2, + timeincrement=1, + ) + expected_values = np.array([1, 2, 3]) + assert np.allclose(valid_indices, expected_values) From 3c59da4695f566801b4f07a820f3d1854263498f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:25:16 +0100 Subject: [PATCH 09/14] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63b1ae7e..0f3f2e6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you! - Enforce same binning for histograms comparing true data to predicted data - Fix: Inference checkpoints are now saved according the frequency settings defined in the config - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) +- Feature: Support training for datasets with missing time steps ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/releases/tag/0.1.0) - 2024-08-16 From 0e665f02f45855f9e617934b16f9b9a1d631df12 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 5 Sep 2024 15:06:43 +0000 Subject: [PATCH 10/14] feat: add link to PR --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f3f2e6e..7637b7d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ Keep it human-readable, your future self will thank you! - Enforce same binning for histograms comparing true data to predicted data - Fix: Inference checkpoints are now saved according the frequency settings defined in the config - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) -- Feature: Support training for datasets with missing time steps +- Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/releases/tag/0.1.0) - 2024-08-16 From fe9011476d9bbac55c553e0c31b7e789316ddd9d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 6 Sep 2024 14:18:43 +0000 Subject: [PATCH 11/14] refactor: get_usable_indices --- src/anemoi/training/utils/usable_indices.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py index 5a126d9a..7bdd5cbd 100644 --- a/src/anemoi/training/utils/usable_indices.py +++ b/src/anemoi/training/utils/usable_indices.py @@ -42,9 +42,8 @@ def get_usable_indices( usable_indices = np.arange(series_length) # set of all indices - # No missing indices if missing_indices is None: - return usable_indices[prev_invalid_dates : series_length - next_invalid_dates] + missing_indices = set() missing_indices |= {-1, series_length} # to filter initial and final indices From 89660e82c0f5c69313ac52d8004a408f631a102f Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 6 Sep 2024 14:19:41 +0000 Subject: [PATCH 12/14] fix: shard_start/end Co-authored-by: Magnus Sikora --- src/anemoi/training/data/dataset.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index bc3633ab..e2aa12bd 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -112,11 +112,16 @@ def resolution(self) -> dict: return self.data.resolution @cached_property - def valid_dates(self) -> np.ndarray: - """Return valid dates. + def valid_date_indices(self) -> np.ndarray: + """Return valid date indices. + + A date t is valid if we can sample the sequence + (t - multistep + 1, ..., t + rollout) + without missing data (if time_increment is 1). If there are no missing dates, total number of valid ICs is - dataset length minus rollout minus additional multistep inputs. + dataset length minus rollout minus additional multistep inputs + (if time_increment is 1). """ return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) @@ -136,9 +141,9 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: self.worker_id = worker_id # Divide this equally across shards (one shard per group!) - shard_size = len(self.valid_dates) // self.model_comm_num_groups - shard_start = self.model_comm_group_id * shard_size + (self.multi_step - 1) * self.timeincrement - shard_end = min((self.model_comm_group_id + 1) * shard_size, len(self.data) - self.rollout * self.timeincrement) + shard_size = len(self.valid_date_indices) // self.model_comm_num_groups + shard_start = self.model_comm_group_id * shard_size + shard_end = (self.model_comm_group_id + 1) * shard_size shard_len = shard_end - shard_start self.n_samples_per_worker = shard_len // n_workers @@ -156,7 +161,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: high, ) - self.chunk_index_range = self.valid_dates[np.arange(low, high, dtype=np.uint32)] + self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)] # each worker must have a different seed for its random number generator, # otherwise all the workers will output exactly the same data From 4144dc888e59aa0d5b28f2f43e8d263fa3462a21 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 11 Sep 2024 14:19:56 +0000 Subject: [PATCH 13/14] fix: update CHANGELOG.md --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c01c7b2..7637b7d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,6 @@ Keep it human-readable, your future self will thank you! - Enable the callback for plotting a histogram for variables containing NaNs - Enforce same binning for histograms comparing true data to predicted data -- Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) - Fix: Inference checkpoints are now saved according the frequency settings defined in the config - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) From 9a68434d6eb7f414c7591e0ae1e72f4b75b52a47 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 11 Sep 2024 14:21:31 +0000 Subject: [PATCH 14/14] fix: add more PRs to changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7637b7d3..8ee266ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ Keep it human-readable, your future self will thank you! - Enable the callback for plotting a histogram for variables containing NaNs - Enforce same binning for histograms comparing true data to predicted data -- Fix: Inference checkpoints are now saved according the frequency settings defined in the config +- Fix: Inference checkpoints are now saved according the frequency settings defined in the config [#37](https://github.com/ecmwf/anemoi-training/pull/37) - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48)