
[QAT] Low-bit FSDP all-gather for QAT #1224

Open
gau-nernst opened this issue Nov 5, 2024 · 1 comment
@gau-nernst
Collaborator

gau-nernst commented Nov 5, 2024

Had this idea and discussed briefly with @andrewor14.

Conceptually, the current QAT + FSDP flow looks like this:

  • sharded FP32 weight -> all-gather in BF16 -> fake quantize

However, we can do a low-bit all-gather instead, since the weight can be quantized before the all-gather:

  • sharded FP32 weight -> (real) quantize -> all-gather in low-bit -> dequantize
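The two flows above are numerically equivalent: fake quantize is just quantize followed by dequantize, and per-tensor quantization commutes with concatenating shards. A minimal sketch in plain Python (no torch; symmetric per-tensor int8, all names illustrative) that simulates the all-gather as shard concatenation:

```python
# Per-tensor symmetric int8 quantization, range [-127, 127].
def quantize(xs, scale):
    # Real quantize: float -> int8 values (this is what goes over the wire).
    return [max(-127, min(127, round(x / scale))) for x in xs]

def dequantize(qs, scale):
    return [q * scale for q in qs]

def fake_quantize(xs, scale):
    # QAT fake quantize: quantize then immediately dequantize, stays float.
    return dequantize(quantize(xs, scale), scale)

# Two GPU shards of one weight; one shared per-tensor scale.
shard0, shard1 = [0.5, -1.0, 0.25], [1.0, -0.75, 0.0]
scale = max(abs(x) for x in shard0 + shard1) / 127

# Current flow: all-gather floats (simulated as concat), then fake quantize.
path1 = fake_quantize(shard0 + shard1, scale)

# Proposed flow: quantize each shard, all-gather int8, then dequantize.
gathered_int8 = quantize(shard0, scale) + quantize(shard1, scale)
path2 = dequantize(gathered_int8, scale)

assert path1 == path2  # same values, but path2 moved int8 over the wire
```

(A per-channel scale would commute the same way as long as scales are computed consistently; the sketch uses a single per-tensor scale for brevity.)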

In terms of perf, we are basically comparing (ignoring potential fusion around these ops):

  1. BF16 all-gather + fake quantize
  2. (Real) quantize (on the 1/NGPU local shard) + low-bit all-gather + dequantize

This might be a small perf win, especially when distributed comm is the bottleneck. It might also be useful for QAT recipes in torchtune.
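A back-of-envelope estimate of the comms-volume saving, using a hypothetical 4096x4096 linear weight and an int8 payload with a single fp32 per-tensor scale (sizes are illustrative, not measured):

```python
# Bytes moved by the all-gather for one hypothetical 4096x4096 weight.
numel = 4096 * 4096
bf16_bytes = numel * 2       # current: BF16 all-gather, 2 bytes/elem
int8_bytes = numel * 1 + 4   # proposed: int8 payload + one fp32 scale

ratio = bf16_bytes / int8_bytes
print(f"{ratio:.2f}x less all-gather traffic")  # ~2x for int8
```

Lower-bit schemes (e.g. int4 packed two values per byte) would shrink the payload further, at the cost of extra pack/unpack work.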

This is probably low priority, so I'm just leaving it here in case anyone is interested in implementing it. We would need to quantify the speedup, if any.

In terms of implementation, we can follow the float8 design (https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/float8/fsdp_utils.py)
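At a high level, the float8 code linked above hangs the cast off FSDP2's tensor-subclass extension points (`fsdp_pre_all_gather` / `fsdp_post_all_gather`): the pre-hook runs on each rank's local shard and returns the low-bit payload plus metadata, and the post-hook rebuilds the full weight from the gathered payload. A stripped-down, torch-free sketch of that shape (function names mirror the hooks, but the real hooks live on a tensor subclass and have different signatures; everything here is illustrative):

```python
# Simplified shape of float8-style FSDP all-gather hooks, in plain Python.

def fsdp_pre_all_gather(local_shard, scale):
    # Runs per rank on the local shard (cost scales with 1/NGPU).
    # Returns the int8 payload to all-gather plus metadata (the scale).
    payload = [max(-127, min(127, round(x / scale))) for x in local_shard]
    return payload, scale

def fsdp_post_all_gather(gathered_payload, scale):
    # Runs on the gathered low-bit data: dequantize back to float,
    # yielding the full (unsharded) weight for the forward pass.
    return [q * scale for q in gathered_payload]

# Fake a two-rank all-gather by concatenating the per-rank payloads.
scale = 1.0 / 127
p0, _ = fsdp_pre_all_gather([0.1, -0.2], scale)
p1, _ = fsdp_pre_all_gather([0.3, 0.0], scale)
full_weight = fsdp_post_all_gather(p0 + p1, scale)
# full_weight approximates [0.1, -0.2, 0.3, 0.0] to within one quant step.
```

For QAT specifically, the post-hook would dequantize (so the rest of the model still sees a fake-quantized float weight), whereas vkuzo's point below is that keeping the gathered weight in low-bit form would also let the matmul itself run in low precision.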

@vkuzo
Contributor

vkuzo commented Nov 5, 2024

This would chain nicely with also doing the matrix multiply in low precision.
