From 7bfb131b5d610f03a3b4c1acc1d8ab29edb1135c Mon Sep 17 00:00:00 2001
From: Etienne Pot
Date: Fri, 29 Nov 2024 01:41:04 -0800
Subject: [PATCH] Rename `kd.train.Auxiliaries` -> `kd.train.AuxiliariesState`

PiperOrigin-RevId: 701217653
---
 kauldron/evals/eval_impl.py     |  4 ++--
 kauldron/evals/evaluators.py    |  4 ++--
 kauldron/train/__init__.py      |  2 +-
 kauldron/train/metric_writer.py |  4 ++--
 kauldron/train/train_lib.py     |  2 +-
 kauldron/train/train_step.py    | 18 +++++++++---------
 kauldron/train/trainer_lib.py   |  4 ++--
 7 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/kauldron/evals/eval_impl.py b/kauldron/evals/eval_impl.py
index 25d91963..929c6555 100644
--- a/kauldron/evals/eval_impl.py
+++ b/kauldron/evals/eval_impl.py
@@ -42,7 +42,7 @@ def continuous_eval(
     trainer: trainer_lib.Trainer,
     eval_names: list[str],
-) -> dict[str, train_step.Auxiliaries]:
+) -> dict[str, train_step.AuxiliariesState]:
   """Continuous evaluation.
 
   Trigger an evaluation everytime a new checkpoint is detected.
@@ -75,7 +75,7 @@ def continuous_eval(
           trainer.evals[name].discard_opt for name in eval_names
       ),
   )
-  aux = {eval_name: train_step.Auxiliaries() for eval_name in eval_names}
+  aux = {eval_name: train_step.AuxiliariesState() for eval_name in eval_names}
 
   # If preempted, the last checkpoint might be re-computed. There could be
   # some race condition where the metrics are written twice for one step, but
diff --git a/kauldron/evals/evaluators.py b/kauldron/evals/evaluators.py
index fe865e6b..af2c8574 100644
--- a/kauldron/evals/evaluators.py
+++ b/kauldron/evals/evaluators.py
@@ -212,7 +212,7 @@ def ds_iter(self) -> data.IterableDataset:
 
   def evaluate(
       self, state: train_step.TrainState, step: int
-  ) -> train_step.Auxiliaries:
+  ) -> train_step.AuxiliariesState:
     """Run one full evaluation."""
     self._assert_root_cfg_resolved()
     if self.discard_opt:
@@ -286,7 +286,7 @@ def basic_eval_step(
     state: train_step.TrainState,
     batch,
     sharding: sharding_lib.ShardingStrategy,
-) -> train_step.Auxiliaries:
+) -> train_step.AuxiliariesState:
   """Call the model (pmap version)."""
   # Note that step is train step (from train state), NOT `eval_step`
   ctx = context_lib.Context.from_state_and_batch(state=state, batch=batch)
diff --git a/kauldron/train/__init__.py b/kauldron/train/__init__.py
index 80eecf4e..5402e2ac 100644
--- a/kauldron/train/__init__.py
+++ b/kauldron/train/__init__.py
@@ -20,7 +20,7 @@
 from kauldron.train.rngs_lib import RngStreams
 from kauldron.train.setup_utils import Setup
 from kauldron.train.setup_utils import TqdmInfo
-from kauldron.train.train_step import Auxiliaries
+from kauldron.train.train_step import AuxiliariesState
 from kauldron.train.train_step import ModelWithAux
 from kauldron.train.train_step import TrainState
 from kauldron.train.train_step import TrainStep
diff --git a/kauldron/train/metric_writer.py b/kauldron/train/metric_writer.py
index a48c4190..1b06524c 100644
--- a/kauldron/train/metric_writer.py
+++ b/kauldron/train/metric_writer.py
@@ -164,7 +164,7 @@ def write_step_metrics(
     self,
     *,
     step: int,
-    aux: train_step.Auxiliaries,
+    aux: train_step.AuxiliariesState,
     schedules: Mapping[str, optax.Schedule],
     log_summaries: bool,
     timer: Optional[chrono_utils.Chrono] = None,
@@ -576,7 +576,7 @@ def write_step_metrics(
     self,
     *,
     step: int,
-    aux: train_step.Auxiliaries,
+    aux: train_step.AuxiliariesState,
     schedules: Mapping[str, optax.Schedule],
     log_summaries: bool,
     timer: Optional[chrono_utils.Chrono] = None,
diff --git a/kauldron/train/train_lib.py b/kauldron/train/train_lib.py
index f7ac85db..6dae38c4 100644
--- a/kauldron/train/train_lib.py
+++ b/kauldron/train/train_lib.py
@@ -42,7 +42,7 @@ def train_impl(
     trainer: trainer_lib.Trainer,
-) -> tuple[train_step.TrainState, Optional[train_step.Auxiliaries]]:
+) -> tuple[train_step.TrainState, Optional[train_step.AuxiliariesState]]:
   """Implements of `Trainer.train`."""
   setup = trainer.setup
   setup.log_status("Configuring ...")
diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py
index b042a8f7..93d0177c 100644
--- a/kauldron/train/train_step.py
+++ b/kauldron/train/train_step.py
@@ -77,7 +77,7 @@ def replace(self, **changes: Any) -> TrainState:
 
 
 @flax.struct.dataclass
-class Auxiliaries:
+class AuxiliariesState:
   """Auxiliaries (intermediate states to be accumulated)."""
 
   loss_states: Mapping[str, kd_metrics.State] = dataclasses.field(
@@ -97,10 +97,10 @@ class AuxiliariesState:
       _pred={}, _code={}, _metadata={}, _payload={}
   )
 
-  def replace(self, **changes: Any) -> Auxiliaries:
+  def replace(self, **changes: Any) -> AuxiliariesState:
     return dataclasses.replace(self, **changes)
 
-  def merge(self, other: Optional[Auxiliaries]) -> Auxiliaries:
+  def merge(self, other: Optional[AuxiliariesState]) -> AuxiliariesState:
     """Accumulate auxiliary."""
     if other is None:
       return self
@@ -112,13 +112,13 @@ class AuxiliariesState:
         ),
     )
 
-  def __or__(self, other: Auxiliaries | None) -> Auxiliaries:
+  def __or__(self, other: AuxiliariesState | None) -> AuxiliariesState:
     """Alias for `.merge()`: `aux = aux1 | aux2`."""
     if other is None:
       return self
     return self.merge(other)
 
-  def __ror__(self, other: Auxiliaries | None) -> Auxiliaries:
+  def __ror__(self, other: AuxiliariesState | None) -> AuxiliariesState:
     """Alias for `.merge()`: `aux = aux1 | aux2`."""
     if other is None:
       return self
@@ -305,9 +305,9 @@ def get_aux(
       self,
       *,
       return_losses: bool = False,
       return_metrics: bool = False,
       return_summaries: bool = False,
-  ) -> Auxiliaries:
+  ) -> AuxiliariesState:
     """Get auxilaries."""
-    aux = Auxiliaries()
+    aux = AuxiliariesState()
 
     if return_losses:
       aux = aux.replace(loss_states=context.loss_states)
@@ -456,7 +456,7 @@ def step(
       checkify_error_categories: frozenset[
           trainer_lib.CheckifyErrorCategory
       ] = frozenset(),
-  ) -> tuple[TrainState, Auxiliaries]:
+  ) -> tuple[TrainState, AuxiliariesState]:
     """Training step: forward, losses, gradients, update, and metrics."""
     if checkify_error_categories:
       step_fn = checkify.checkify(self._step, errors=checkify_error_categories)
@@ -487,7 +487,7 @@ def _step(
       return_losses: bool = False,
       return_metrics: bool = False,
       return_summaries: bool = False
-  ) -> tuple[TrainState, Auxiliaries]:
+  ) -> tuple[TrainState, AuxiliariesState]:
     """Training step to be wrapped by checkify and called by `step`."""
     # TODO(epot): Should `jax.named_call` be moved downstream directly in optax?
     # NOTE: ensure that evaluation metrics are computed from the OLD model state
diff --git a/kauldron/train/trainer_lib.py b/kauldron/train/trainer_lib.py
index 342d17f4..9728562c 100644
--- a/kauldron/train/trainer_lib.py
+++ b/kauldron/train/trainer_lib.py
@@ -340,7 +340,7 @@ def init_state(
         skip_optimizer=skip_optimizer,
     )
 
-  def train(self) -> tuple[train_step.TrainState, train_step.Auxiliaries]:
+  def train(self) -> tuple[train_step.TrainState, train_step.AuxiliariesState]:
     """Main method that train/evaluate the object.
 
     Similar to:
@@ -361,7 +361,7 @@ def train(self) -> tuple[train_step.TrainState, train_step.Auxiliaries]:
   def continuous_eval(
       self,
       names: str | list[str],
-  ) -> dict[str, train_step.Auxiliaries]:
+  ) -> dict[str, train_step.AuxiliariesState]:
     """Main method that perform auxiliary tasks (evaluation, rendering,...).
 
     Trigger an evaluation everytime a new checkpoint is detected.
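
Usage sketch (illustrative, not part of the patch): the commit is a pure
rename, so downstream code only needs the new class name; construction and
the `merge`/`|` accumulation shown in the `train_step.py` hunks are
unchanged. A minimal sketch, assuming the usual `from kauldron import kd`
entry point; `my_trainer` stands in for a hypothetical, fully configured
`kd.train.Trainer`:

    from kauldron import kd

    # Before this patch, `Trainer.train()` returned
    # `(TrainState, kd.train.Auxiliaries)`; after it, the second element is a
    # `kd.train.AuxiliariesState` (same object, new name).
    state, aux = my_trainer.train()
    assert isinstance(aux, kd.train.AuxiliariesState)

    # Accumulation is unchanged: `|` aliases `.merge()`, and `None` is
    # absorbed on either side (see `__or__` / `__ror__` in the train_step.py
    # hunks), so `aux1 | aux2` merges two states and:
    assert (aux | None) is aux  # merging with None is a no-op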