Skip to content

Commit

Permalink
Rename kd.train.Auxiliaries -> kd.train.AuxiliariesState
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701217653
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 29, 2024
1 parent ec238be commit 7bfb131
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions kauldron/evals/eval_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
def continuous_eval(
trainer: trainer_lib.Trainer,
eval_names: list[str],
) -> dict[str, train_step.Auxiliaries]:
) -> dict[str, train_step.AuxiliariesState]:
"""Continuous evaluation.
Trigger an evaluation every time a new checkpoint is detected.
Expand Down Expand Up @@ -75,7 +75,7 @@ def continuous_eval(
trainer.evals[name].discard_opt for name in eval_names
),
)
aux = {eval_name: train_step.Auxiliaries() for eval_name in eval_names}
aux = {eval_name: train_step.AuxiliariesState() for eval_name in eval_names}

# If preempted, the last checkpoint might be re-computed. There could be
# some race condition where the metrics are written twice for one step, but
Expand Down
4 changes: 2 additions & 2 deletions kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def ds_iter(self) -> data.IterableDataset:

def evaluate(
self, state: train_step.TrainState, step: int
) -> train_step.Auxiliaries:
) -> train_step.AuxiliariesState:
"""Run one full evaluation."""
self._assert_root_cfg_resolved()
if self.discard_opt:
Expand Down Expand Up @@ -286,7 +286,7 @@ def basic_eval_step(
state: train_step.TrainState,
batch,
sharding: sharding_lib.ShardingStrategy,
) -> train_step.Auxiliaries:
) -> train_step.AuxiliariesState:
"""Call the model (pmap version)."""
# Note that step is train step (from train state), NOT `eval_step`
ctx = context_lib.Context.from_state_and_batch(state=state, batch=batch)
Expand Down
2 changes: 1 addition & 1 deletion kauldron/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from kauldron.train.rngs_lib import RngStreams
from kauldron.train.setup_utils import Setup
from kauldron.train.setup_utils import TqdmInfo
from kauldron.train.train_step import Auxiliaries
from kauldron.train.train_step import AuxiliariesState
from kauldron.train.train_step import ModelWithAux
from kauldron.train.train_step import TrainState
from kauldron.train.train_step import TrainStep
Expand Down
4 changes: 2 additions & 2 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def write_step_metrics(
self,
*,
step: int,
aux: train_step.Auxiliaries,
aux: train_step.AuxiliariesState,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down Expand Up @@ -576,7 +576,7 @@ def write_step_metrics(
self,
*,
step: int,
aux: train_step.Auxiliaries,
aux: train_step.AuxiliariesState,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down
2 changes: 1 addition & 1 deletion kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

def train_impl(
trainer: trainer_lib.Trainer,
) -> tuple[train_step.TrainState, Optional[train_step.Auxiliaries]]:
) -> tuple[train_step.TrainState, Optional[train_step.AuxiliariesState]]:
"""Implements of `Trainer.train`."""
setup = trainer.setup
setup.log_status("Configuring ...")
Expand Down
18 changes: 9 additions & 9 deletions kauldron/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def replace(self, **changes: Any) -> TrainState:


@flax.struct.dataclass
class Auxiliaries:
class AuxiliariesState:
"""Auxiliaries (intermediate states to be accumulated)."""

loss_states: Mapping[str, kd_metrics.State] = dataclasses.field(
Expand All @@ -97,10 +97,10 @@ class Auxiliaries:
_pred={}, _code={}, _metadata={}, _payload={}
)

def replace(self, **changes: Any) -> Auxiliaries:
def replace(self, **changes: Any) -> AuxiliariesState:
return dataclasses.replace(self, **changes)

def merge(self, other: Optional[Auxiliaries]) -> Auxiliaries:
def merge(self, other: Optional[AuxiliariesState]) -> AuxiliariesState:
"""Accumulate auxiliary."""
if other is None:
return self
Expand All @@ -112,13 +112,13 @@ def merge(self, other: Optional[Auxiliaries]) -> Auxiliaries:
),
)

def __or__(self, other: Auxiliaries | None) -> Auxiliaries:
def __or__(self, other: AuxiliariesState | None) -> AuxiliariesState:
"""Alias for `.merge()`: `aux = aux1 | aux2`."""
if other is None:
return self
return self.merge(other)

def __ror__(self, other: Auxiliaries | None) -> Auxiliaries:
def __ror__(self, other: AuxiliariesState | None) -> AuxiliariesState:
"""Alias for `.merge()`: `aux = aux1 | aux2`."""
if other is None:
return self
Expand Down Expand Up @@ -305,9 +305,9 @@ def get_aux(
return_losses: bool = False,
return_metrics: bool = False,
return_summaries: bool = False,
) -> Auxiliaries:
) -> AuxiliariesState:
"""Get auxilaries."""
aux = Auxiliaries()
aux = AuxiliariesState()
if return_losses:
aux = aux.replace(loss_states=context.loss_states)

Expand Down Expand Up @@ -456,7 +456,7 @@ def step(
checkify_error_categories: frozenset[
trainer_lib.CheckifyErrorCategory
] = frozenset(),
) -> tuple[TrainState, Auxiliaries]:
) -> tuple[TrainState, AuxiliariesState]:
"""Training step: forward, losses, gradients, update, and metrics."""
if checkify_error_categories:
step_fn = checkify.checkify(self._step, errors=checkify_error_categories)
Expand Down Expand Up @@ -487,7 +487,7 @@ def _step(
return_losses: bool = False,
return_metrics: bool = False,
return_summaries: bool = False
) -> tuple[TrainState, Auxiliaries]:
) -> tuple[TrainState, AuxiliariesState]:
"""Training step to be wrapped by checkify and called by `step`."""
# TODO(epot): Should `jax.named_call` be moved downstream directly in optax?
# NOTE: ensure that evaluation metrics are computed from the OLD model state
Expand Down
4 changes: 2 additions & 2 deletions kauldron/train/trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def init_state(
skip_optimizer=skip_optimizer,
)

def train(self) -> tuple[train_step.TrainState, train_step.Auxiliaries]:
def train(self) -> tuple[train_step.TrainState, train_step.AuxiliariesState]:
"""Main method that train/evaluate the object.
Similar to:
Expand All @@ -361,7 +361,7 @@ def train(self) -> tuple[train_step.TrainState, train_step.Auxiliaries]:
def continuous_eval(
self,
names: str | list[str],
) -> dict[str, train_step.Auxiliaries]:
) -> dict[str, train_step.AuxiliariesState]:
"""Main method that perform auxiliary tasks (evaluation, rendering,...).
Trigger an evaluation every time a new checkpoint is detected.
Expand Down

0 comments on commit 7bfb131

Please sign in to comment.