In the 2nd `train_batch_iter`: even though we set `half_param.grad = None` in `_accumulate_grad`, the half grads are for some reason not released until the end of the forward (specifically, not until after the `column_linear` in the last `cast_to_fp32` in `nanotron/models/llama.py:906:forward_with_hidden_states`).
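For reference, a minimal sketch of the accumulation step described above; the fp32 buffer name and the function signature are assumptions for illustration, only the `half_param.grad = None` line inside `_accumulate_grad` comes from the report:

```python
import torch

def _accumulate_grad(half_param: torch.nn.Parameter, fp32_grad_buffer: torch.Tensor) -> None:
    """Hypothetical sketch: accumulate a half-precision grad into an fp32 buffer."""
    # Add the bf16/fp16 grad into the fp32 accumulation buffer.
    fp32_grad_buffer.add_(half_param.grad.float())
    # Drop our reference so the half grad can be freed before the next forward.
    # The issue: even after this, the memory is not actually released until
    # late into the next forward pass.
    half_param.grad = None
```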
Possible culprits:
- something in `column_linear` keeps a reference to our half grads
- torch cannot be relied upon to release grads before a new forward
Update: found another case (with `batch_accum=3`) where grads don't get cleared until the end of the forward -> so the culprit isn't `cast_to_fp32`?
Adding `torch.cuda.empty_cache()` in `PipelineEngine.forward` before the forward solves the issue, but calling `empty_cache()` before every forward is expensive.
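A rough sketch of the workaround, assuming a simplified `PipelineEngine.forward` (the real nanotron engine has a different interface):

```python
import torch

class PipelineEngine:
    """Simplified stand-in for nanotron's pipeline engine (interface is assumed)."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # Workaround: return unused cached blocks (including the ones still
        # holding the already-freed half grads) to the driver before the forward.
        # This caps peak memory, but later allocations may need fresh cudaMalloc
        # calls, which is why doing it before every forward is expensive.
        torch.cuda.empty_cache()
        return self.model(batch)
```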
Here's a plot summarizing the issue *(memory-usage plot attached in the original comment)*.