Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Gradient Accumulation fix across all models + trainer fully in forward() #34283

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1172,7 +1173,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1083,7 +1084,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/modular_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
```python
Expand Down Expand Up @@ -1002,7 +1003,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1064,7 +1065,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
```python
Expand Down Expand Up @@ -805,7 +806,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1071,7 +1072,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[Union[int, None]] = None,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1542,7 +1543,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

aux_loss = None
if output_router_logits:
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1326,7 +1327,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

aux_loss = None
if output_router_logits:
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/mllama/modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1949,7 +1950,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/nemotron/modeling_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1083,7 +1084,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1126,7 +1127,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1290,7 +1291,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

aux_loss = None
if output_router_logits:
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1250,7 +1251,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1300,7 +1301,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1467,7 +1468,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

aux_loss = None
if output_router_logits:
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1204,7 +1205,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1392,7 +1393,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

aux_loss = None
if output_router_logits:
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
r"""
labels (`List[Dict]` of len `(batch_size,)`, *optional*):
Expand Down Expand Up @@ -2128,6 +2129,7 @@ def forward(
enc_topk_logits=enc_topk_logits,
enc_topk_bboxes=enc_topk_bboxes,
denoising_meta_values=denoising_meta_values,
**loss_kwargs,
)

if not return_dict:
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/zamba/modeling_zamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1477,7 +1478,7 @@ def forward(

loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,8 @@ def __init__(
self.model_wrapped = model
self.model = model

self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model.forward).parameters

self.neftune_noise_alpha = args.neftune_noise_alpha

self.compute_metrics = compute_metrics
Expand Down Expand Up @@ -3610,8 +3612,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
labels = inputs.pop("labels")
else:
labels = None
# if num_items_in_batch is not None:
# inputs["num_items_in_batch"] = num_items_in_batch
if self.model_accepts_loss_kwargs:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if condition doesn't seem to work for PeftModel class (it only has kwargs not loss_kwargs 🫠 )

I tried just changing that condition to if True and ran some tests, and the loss calculation worked perfectly for a LORA on a Llama 3 1B.

I'm wondering if there's a safe/non-breaking way to support peft models here as well?

loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
Expand Down
Loading