[Question] Is there gradient accumulation support for training? #2332

Open
liuslnlp opened this issue Aug 22, 2024 · 4 comments

Comments

@liuslnlp

I am tuning hyper-parameters on two different compute clusters. Since the number of GPUs on these clusters varies, I need to use gradient accumulation (GA) to ensure that the total batch size is equal. Does torchrec support GA?
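The arithmetic behind the question: with GA, each optimizer step sees per-GPU batch × number of GPUs × accumulation steps samples. Here is a minimal sketch for choosing the step count on each cluster. The function name and the batch sizes are hypothetical.

```python
def accumulation_steps(global_batch: int, per_gpu_batch: int, num_gpus: int) -> int:
    """Micro-batches each GPU must accumulate so that
    per_gpu_batch * num_gpus * steps == global_batch."""
    per_step = per_gpu_batch * num_gpus  # samples per forward/backward across all GPUs
    if global_batch % per_step != 0:
        raise ValueError(
            f"global batch {global_batch} is not a multiple of {per_gpu_batch} x {num_gpus}"
        )
    return global_batch // per_step


# Hypothetical setup: the same global batch of 8192 on a 16-GPU and a 64-GPU cluster.
print(accumulation_steps(8192, 128, 16))  # 4
print(accumulation_steps(8192, 128, 64))  # 1
```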

@JacoCheung

Although this is a feature I'm looking for as well, considering that the embedding lookup backend is FBGEMM, which fuses the optimizer update into the backward pass at every step, I would expect that GA is not supported.
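The following toy, pure-Python model illustrates the fused update. It is not FBGEMM code. The class name, the one-scalar-per-row table, and the plain SGD rule are all assumptions. Because each micro-batch's gradient is consumed and discarded inside `backward`, nothing is left to accumulate, and two micro-batches produce two separate updates instead of one.

```python
from typing import Dict


class FusedSGDTable:
    """Toy embedding table whose backward pass also runs the optimizer.

    Illustrative only: one scalar weight per row and plain SGD are assumptions.
    """

    def __init__(self, num_rows: int, lr: float) -> None:
        self.weight = [0.0] * num_rows
        self.lr = lr

    def backward(self, row_grads: Dict[int, float]) -> None:
        # Gradients for the looked-up rows are applied right away and then
        # dropped; no weight gradient is ever written back for the caller.
        for row, grad in row_grads.items():
            self.weight[row] -= self.lr * grad


table = FusedSGDTable(num_rows=4, lr=0.1)
table.backward({0: 1.0, 2: 0.5})  # micro-batch 1: update applied immediately
table.backward({0: 1.0})          # micro-batch 2: a second, separate update
print(table.weight)               # [-0.2, 0.0, -0.05, 0.0]
```

Under plain SGD with gradient accumulation and a loss averaged over the two micro-batches, row 0 would instead get a single update of -0.1 and row 2 a single update of -0.025.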

@gouchangjiang

Hi Jaco. In your experience, how hard would it be to add GA support to the FBGEMM CPU/CUDA kernels?

@JacoCheung

Hi @gouchangjiang, I'm not an FBGEMM expert, but I don't think it's a trivial amount of work. It's feasible, but it may violate FBGEMM's design principles.

FBGEMM's design principle is to eliminate the wgrad (weight gradient) write-back, so users cannot access the wgrad. You can, of course, allocate a buffer, pass it into the backward kernels, and remove the update and optimizer-state code (the original FBGEMM kernels are templated on the optimizer and partially instantiated). But you pay for it in several ways:

  1. Extra memory footprint and time. The wgrad is typically a sparse tensor (you probably don't want a dense one), so its shape is dynamic.
  2. Sparse tensor accumulation and an exposed update step. GA means you have to trigger an update method explicitly. If the wgrad is a sparse tensor, you have to implement your own accumulation operations and optimizer.
  3. An adapter from FBGEMM to TorchRec EBC/EC (EmbeddingBagCollection/EmbeddingCollection). TorchRec has a deep call stack, so even if you manage to expose the wgrad from FBGEMM, you still need changes in the TorchRec codebase.
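This toy, pure-Python sketch of the buffer approach makes costs 1 and 2 in the list above concrete. It assumes plain SGD and one scalar per embedding row. `SparseGradAccumulator` and its methods are hypothetical and are not FBGEMM or TorchRec APIs. Gradients are summed into a row-keyed buffer whose size depends on which rows were looked up. The weights change only when an explicit update runs after `accumulation_steps` micro-batches.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List


class SparseGradAccumulator:
    """Hypothetical buffer-based gradient accumulation for a row-sparse table."""

    def __init__(self, weight: List[float], lr: float, accumulation_steps: int) -> None:
        self.weight = weight  # updated in place
        self.lr = lr
        self.accumulation_steps = accumulation_steps
        # Row index -> summed gradient. Which rows appear is only known at run
        # time, so the buffer has a dynamic size (cost 1 above).
        self.buffer: DefaultDict[int, float] = defaultdict(float)
        self.micro_steps = 0

    def backward(self, row_grads: Dict[int, float]) -> None:
        # Write gradients into the buffer instead of updating the weights.
        for row, grad in row_grads.items():
            self.buffer[row] += grad
        self.micro_steps += 1
        if self.micro_steps == self.accumulation_steps:
            self._step()

    def _step(self) -> None:
        # Explicit update (cost 2 above): plain SGD on the gradient averaged
        # over the accumulated micro-batches, then reset the buffer.
        for row, grad_sum in self.buffer.items():
            self.weight[row] -= self.lr * grad_sum / self.accumulation_steps
        self.buffer.clear()
        self.micro_steps = 0


weight = [0.0] * 4
acc = SparseGradAccumulator(weight, lr=0.1, accumulation_steps=2)
acc.backward({0: 1.0, 2: 0.5})  # buffered; weights unchanged
acc.backward({0: 1.0})          # second micro-batch triggers one update
print(weight)                   # [-0.1, 0.0, -0.025, 0.0]
```

A stateful optimizer such as Adagrad would also have to update its per-row state only inside `_step`. Wiring this into TorchRec's EBC/EC modules (cost 3) is not covered by this sketch.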

@gouchangjiang

Thank you, @JacoCheung. That's quite a lot of work.
