Have ._step return the full Context
PiperOrigin-RevId: 699444877
Conchylicultor authored and The kauldron Authors committed Nov 29, 2024
1 parent b867e0d commit dbfc5d8
Showing 2 changed files with 76 additions and 78 deletions.
21 changes: 21 additions & 0 deletions kauldron/train/context.py
@@ -62,6 +62,8 @@ class Context:
opt_state: The state of the optimizer prior to the update. (available after
the backward pass, e.g. for metrics). The old state is chosen to be
consistent with parameters which are also pre-update.
metric_states: The states of the metrics (after the backward pass).
summary_states: The states of the summaries (after the backward pass).
"""

# These are always available:
@@ -80,6 +82,9 @@ class Context:
grads: Any = None
updates: Any = None
opt_state: Any = None
# These become available after the metrics computation:
metric_states: Any = None
summary_states: Any = None

replace = dataclasses.replace

@@ -100,3 +105,19 @@ def from_state_and_batch(

def flatten(self) -> dict[str, Any]:
return kontext.flatten_with_path(self)

def get_aux_state(
self,
*,
return_losses: bool = False,
return_metrics: bool = False,
return_summaries: bool = False,
) -> train_step.AuxiliariesState:
"""Returns the auxiliaries for the step."""
from kauldron.train import train_step # pylint: disable=g-import-not-at-top

return train_step.AuxiliariesState(
loss_states=self.loss_states if return_losses else None,
metric_states=self.metric_states if return_metrics else None,
summary_states=self.summary_states if return_summaries else None,
)
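
For orientation, a minimal usage sketch of the accessor added above (not part of the diff). It assumes `state` and `batch` come from an existing trainer, and `my_loss_states` / `my_metric_states` are placeholder pytrees standing in for whatever the forward pass produced:

from kauldron.train import context as context_lib

# Build a context from the current train state and batch, as `_step` does.
ctx = context_lib.Context.from_state_and_batch(state=state, batch=batch)

# A real step fills these fields along the way; placeholders here.
ctx = ctx.replace(
    loss_states=my_loss_states,
    metric_states=my_metric_states,
)

# Only the requested sub-trees are propagated; everything else stays None.
aux = ctx.get_aux_state(return_losses=True, return_metrics=True)
assert aux.summary_states is None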
133 changes: 55 additions & 78 deletions kauldron/train/train_step.py
@@ -230,54 +230,37 @@ class Auxiliaries(config_util.UpdateFromRootCfg):
)

@jax.named_call
def get_state(
self,
context: context_lib.Context,
*,
# TODO(epot): Better signature
return_losses: bool = False,
return_metrics: bool = False,
return_summaries: bool = False,
) -> AuxiliariesState:
def update_context(self, context: context_lib.Context) -> context_lib.Context:
"""Get auxilaries."""
aux = AuxiliariesState()
if return_losses:
# TODO(epot): Cleanup loss-states:
# * Re-compute the states here if `context.loss_states` is None (e.g.
# if in eval)
# * Split `kd/losses/base:compute_losses` into `get_state` and
# `compute_losses(loss_states) -> float`
# * Unify all the `m.get_state_from_context` patterns for metrics,
# summaries, and losses.
aux = aux.replace(loss_states=context.loss_states)

if return_metrics:
aux = aux.replace(
metric_states=jax.tree.map(
lambda m: m.get_state_from_context(context), self.metrics
)
)

if return_summaries:
# TODO(klausg): remove legacy summaries protocol once all are migrated
# legacy summaries protocol:
aux = aux.replace(
summary_kwargs={
k: _gather_kwargs_with_reraise(k, summary, context)
for k, summary in self.summaries.items()
}
)
# new summaries as metrics protocol:
def _get_summary_state(summary):
if isinstance(summary, kd_metrics.Metric):
return summary.get_state_from_context(context)
else:
return kd_metrics.EmptyState()

aux = aux.replace(
summary_states=jax.tree.map(_get_summary_state, self.summaries)
# TODO(epot): Cleanup loss-states:
# * Split `kd/losses/base:compute_losses` into `get_state` and
# `compute_losses(loss_states) -> float`
# * Unify all the `m.get_state_from_context` patterns for metrics,
# summaries, and losses.

# Compute the loss states here if missing (e.g. in eval or when
# `kd.train.forward` is called rather than `kd.train.forward_with_loss`)
if context.loss_states is None:
loss_states = jax.tree.map(
lambda m: m.get_state_from_context(context), self.losses
)
return aux
else:
loss_states = context.loss_states

metric_states = jax.tree.map(
lambda m: m.get_state_from_context(context), self.metrics
)
summary_states = jax.tree.map(
lambda m: m.get_state_from_context(context), self.summaries
)

return dataclasses.replace(
context,
loss_states=loss_states,
metric_states=metric_states,
summary_states=summary_states,
)
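
To make the `get_state_from_context` pattern above concrete, here is a self-contained toy sketch (not kauldron code): `ToyContext` and `MeanValue` are made-up stand-ins, but the `jax.tree.map` call has the same shape as in `update_context`.

import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyContext:
  loss_total: jnp.ndarray  # stand-in for the real Context fields


@dataclasses.dataclass(frozen=True)
class MeanValue:
  """Toy metric: reads one field of the context and averages it."""

  key: str

  def get_state_from_context(self, context: Any) -> jnp.ndarray:
    return jnp.mean(getattr(context, self.key))


metrics = {"train_loss": MeanValue(key="loss_total")}
context = ToyContext(loss_total=jnp.array([0.5, 1.5]))

# Unregistered objects are leaves for jax.tree.map, so the lambda runs once
# per metric, mirroring the losses/metrics/summaries maps above.
metric_states = jax.tree.map(
    lambda m: m.get_state_from_context(context), metrics
)
# metric_states == {"train_loss": Array(1., dtype=float32)}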


@dataclasses.dataclass(kw_only=True, eq=True, frozen=True)
@@ -411,38 +394,39 @@ def step(
] = frozenset(),
) -> tuple[TrainState, AuxiliariesState]:
"""Training step: forward, losses, gradients, update, and metrics."""
# This function is just a small wrapper around `_step` for:
# * Checkify error handling
# * Selecting which auxiliaries to return
# * Sharding
# If reading the code, you can likely skip this function and go directly
# to `_step`.

if checkify_error_categories:
step_fn = checkify.checkify(self._step, errors=checkify_error_categories)
error, (state, aux) = step_fn(
state,
batch,
return_losses=return_losses,
return_metrics=return_metrics,
return_summaries=return_summaries,
)
aux = aux.replace(error=error)
error, (state, ctx) = step_fn(state, batch)
else:
state, aux = self._step(
state,
batch,
return_losses=return_losses,
return_metrics=return_metrics,
return_summaries=return_summaries,
)
error = None
state, ctx = self._step(state, batch)

return state, aux
# TODO(epot): More flexible way to select the subset of context to return.
# Should also have a way to return the full context.
aux_state = ctx.get_aux_state(
return_losses=return_losses,
return_metrics=return_metrics,
return_summaries=return_summaries,
)
aux_state = aux_state.replace(error=error)
return sharding_lib.with_sharding_constraint(
(state, aux_state),
(self.sharding.state, self.sharding.aux),
)
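
Caller-side the public signature is unchanged; a hedged usage sketch, assuming a configured `trainstep` object plus `state` and `batch` from an existing training loop:

state, aux = trainstep.step(
    state,
    batch,
    return_losses=True,     # keep the loss states in the returned aux
    return_metrics=False,   # unrequested fields come back as None
    return_summaries=False,
)
# The full Context stays internal to `_step`; only the selected
# AuxiliariesState (plus the optional checkify error) is returned.
assert aux.metric_states is None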

def _step(
self,
state: TrainState,
batch: PyTree[Any],
*,
return_losses: bool = False,
return_metrics: bool = False,
return_summaries: bool = False
) -> tuple[TrainState, AuxiliariesState]:
) -> tuple[TrainState, context_lib.Context]:
"""Training step to be wrapped by checkify and called by `step`."""
# TODO(epot): Should `jax.named_call` be moved downstream directly in optax?
# NOTE: ensure that evaluation metrics are computed from the OLD model state
# *before* backprop gradients are applied.
grad_fn = jax.grad(
@@ -451,6 +435,7 @@ def _step(
has_aux=True,
allow_int=True,
)
# TODO(epot): Should `jax.named_call` be moved downstream directly in optax?
grad_fn = jax.named_call(grad_fn, name="grad_fn")

context = context_lib.Context.from_state_and_batch(state=state, batch=batch)
@@ -482,17 +467,9 @@ def _step(
opt_state=state.opt_state,
)

aux_state = self.aux.get_state(
context,
return_losses=return_losses,
return_metrics=return_metrics,
return_summaries=return_summaries,
)
context = self.aux.update_context(context)

return sharding_lib.with_sharding_constraint(
(next_state, aux_state),
(self.sharding.state, self.sharding.aux),
)
return next_state, context


def forward(
@@ -597,4 +574,4 @@ def forward(self, context, **kwargs):
)

def get_aux(self, context, **kwargs):
return self.get_state(context, **kwargs)
return self.update_context(context).get_aux_state(**kwargs)
