From 3e3980b326d3fa26a3e1a5c84f95214b01de140c Mon Sep 17 00:00:00 2001 From: The kauldron Authors Date: Fri, 29 Nov 2024 11:07:55 -0800 Subject: [PATCH] Fix bug which led to crash in kd-test PiperOrigin-RevId: 701322169 --- kauldron/train/trainer_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kauldron/train/trainer_lib.py b/kauldron/train/trainer_lib.py index 5b809e19..5413da25 100644 --- a/kauldron/train/trainer_lib.py +++ b/kauldron/train/trainer_lib.py @@ -399,7 +399,7 @@ def context_specs(self) -> context_lib.Context: # Instead just creating the `spec` should be enough. m_batch = data_utils.mock_batch_from_elem_spec(elem_spec, elem_sharding) - context = jax.eval_shape( + _, context = jax.eval_shape( self.trainstep._step, # pylint: disable=protected-access self.state_specs, m_batch,