Add DDP token averaging for equivalent non-parallel training similar to #34191 #34242

Open
sbwww opened this issue Oct 18, 2024 · 2 comments
Labels
Discussion Discussion on a topic (keep it focused or open a new issue though) Feature request Request for a new feature

sbwww commented Oct 18, 2024

Feature request

Token averaging in gradient accumulation was fixed in #34191. However, token averaging in DDP seems to have the same issue.


Expected behavior

Counting every token that contributes to the loss in one optimizer step (on each GPU, in each gradient accumulation step, and in each microbatch), the total becomes:

$$ntokens=\sum\limits_{GPUs} \sum\limits_{gas} \sum\limits_{microb} (label\neq-100)$$

I believe the loss should be averaged over this global token count in a single normalization, so that DDP training is equivalent to non-parallel training.
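To make the triple sum concrete, here is a small pure-Python sketch that counts loss tokens for one optimizer step. The nested labels[gpu][gas_step][microbatch] layout and all label values are hypothetical, and each microbatch is flattened into a single list of labels for simplicity; -100 is the ignore index from the equation above.

```python
# Hypothetical labels for one optimizer step, indexed as
# labels[gpu][gas_step][microbatch]; each microbatch is flattened to one list.
IGNORE_INDEX = -100  # positions with this label do not contribute to the loss

labels = [
    # GPU 0: two gradient accumulation steps, one microbatch each
    [[[5, 7, -100, -100]], [[3, 9, 4, -100]]],
    # GPU 1: two gradient accumulation steps, one microbatch each
    [[[1, 2, 3, 4]], [[8, 6, -100, -100]]],
]


def count_loss_tokens(microbatch):
    """Count positions whose label is not the ignore index."""
    return sum(1 for label in microbatch if label != IGNORE_INDEX)


# ntokens as defined above: sum over GPUs, gas steps, and microbatches.
ntokens_global = sum(
    count_loss_tokens(mb) for gpu in labels for gas_step in gpu for mb in gas_step
)

# Per-GPU counts: only the inner double sum over gas steps and microbatches.
ntokens_per_gpu = [
    sum(count_loss_tokens(mb) for gas_step in gpu for mb in gas_step)
    for gpu in labels
]

print(ntokens_global)   # 11
print(ntokens_per_gpu)  # [5, 6]
```

The global count (11) is the sum of the per-GPU counts (5 and 6); the proposal is to normalize the loss by the global count rather than by each GPU's own count.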


Current issue

Prior to #34191, the loss/gradients were averaged over $\sum\limits_{GPUs}$, $\sum\limits_{gas}$, and $\sum\limits_{microb}$ separately. The num_items_in_batch introduced in #34191 corresponds to:

$$ntokens=\sum\limits_{gas} \sum\limits_{microb} (label\neq-100)$$

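So the loss/gradients are now averaged over $\sum\limits_{GPUs}$ and $\left(\sum\limits_{gas}\sum\limits_{microb}\right)$ separately. However, this still does not seem equivalent to non-parallel training: each GPU divides its loss by its own token count, and DDP then averages the gradients across GPUs, so whenever GPUs see different numbers of tokens, those tokens are weighted differently than in a single-device run.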
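To spell out why the two normalizations differ, write $W$ for the number of GPUs (world size), $n_r$ for the token count on rank $r$ as in the equation above, and $S_r$ for the summed per-token loss on rank $r$ (this notation is introduced here for illustration). Assuming DDP averages gradients across ranks, as PyTorch's DistributedDataParallel does by default, per-GPU normalization and a single non-parallel run give:

$$\nabla_{DDP} = \frac{1}{W}\sum\limits_{r=1}^{W} \frac{\nabla S_r}{n_r} \qquad \nabla_{single} = \frac{\sum\limits_{r=1}^{W} \nabla S_r}{\sum\limits_{r=1}^{W} n_r}$$

These agree when all $n_r$ are equal, but not in general. For example, with $W = 2$, $n_1 = 2$, and $n_2 = 4$, each token on the first GPU gets weight $\frac{1}{4}$ and each token on the second gets $\frac{1}{8}$, whereas a single run weights every token by $\frac{1}{6}$.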

Can we also include $\sum\limits_{GPUs}$ when computing num_items_in_batch, e.g. with something like all_reduce(num_items_in_batch)?

Motivation

DDP training does not seem to be fully equivalent to non-parallel training when GPUs see different numbers of loss tokens.

Related comments: #34191 (comment)

Your contribution

I found a fairseq implementation of this feature:

https://github.com/facebookresearch/fairseq/blob/018621f3cca02ca9de945dc082c3fb1a7f9f2deb/fairseq/trainer.py#L932-L949

sbwww added the Feature request label Oct 18, 2024
muellerzr (Contributor) commented Oct 18, 2024

I observed this as well while running some experiments (results were close after the fix, but not exact). Would you like to take a stab at a PR? :)

LysandreJik added the Discussion label Oct 21, 2024
techkang commented:
A simple implementation could be:

  1. add all_reduce(num_items_in_batch, op=SUM) after: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2416
  2. add loss *= get_world_size() after: https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L26
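The two steps above can be checked with a small pure-Python simulation. This is a sketch under stated assumptions, not the transformers code: the per-token losses are hypothetical numbers, ddp_mean is a hypothetical stand-in for DDP's gradient averaging, and because the gradient scales linearly with the loss, comparing the normalized losses is enough. In a real run, step 1 would be torch.distributed.all_reduce on a tensor holding num_items_in_batch, with op=torch.distributed.ReduceOp.SUM.

```python
# Pure-Python simulation of the proposed fix (all numbers are hypothetical).
# Assumptions: losses are already restricted to tokens with label != -100,
# and DDP averages gradients across ranks, simulated by ddp_mean below.

# Hypothetical per-token losses for one optimizer step on each rank.
per_rank_token_losses = [
    [0.5, 1.0],            # rank 0: 2 loss tokens
    [2.0, 1.0, 3.0, 2.0],  # rank 1: 4 loss tokens
]
world_size = len(per_rank_token_losses)


def ddp_mean(per_rank_values):
    """Hypothetical stand-in for DDP's all-reduce: sum over ranks, divide by world size."""
    return sum(per_rank_values) / world_size


# Non-parallel reference: one mean over every loss token in the step.
all_losses = [loss for rank in per_rank_token_losses for loss in rank]
reference = sum(all_losses) / len(all_losses)

# Current behavior: each rank divides by its local token count, then DDP averages.
local_counts = [len(losses) for losses in per_rank_token_losses]
current = ddp_mean(
    [sum(losses) / n for losses, n in zip(per_rank_token_losses, local_counts)]
)

# Step 1: all_reduce(num_items_in_batch, op=SUM) -> global token count.
num_items_in_batch = sum(local_counts)
# Step 2: loss *= world_size, so DDP's division by world_size cancels out.
proposed = ddp_mean(
    [sum(losses) / num_items_in_batch * world_size for losses in per_rank_token_losses]
)

print(f"reference={reference:.4f} current={current:.4f} proposed={proposed:.4f}")
# reference=1.5833 current=1.3750 proposed=1.5833
```

With unequal token counts (2 vs. 4), the per-GPU normalization gives 1.3750, while the proposed version matches the non-parallel reference of 1.5833. The multiplication by world_size is what cancels the division by world_size that DDP applies when averaging gradients.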
