Another tensor subclass to hold the quantized weight. If AQT (`AffineQuantizedTensor`) has basic support for backward, maybe we can use AQT directly; otherwise, we need another subclass.
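For illustration, a minimal sketch (not torchao code; the name `Int8QATWeight` and the int8 + per-row-scale layout are assumptions) of what such a subclass could hold. A real version would still need the backward / `__torch_dispatch__` support mentioned above:

```python
import torch

class Int8QATWeight(torch.Tensor):
    """Hypothetical subclass holding an int8 payload + per-row scale for a QAT weight shard."""

    @staticmethod
    def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor, requires_grad: bool = False):
        # Advertise the dequantized dtype/shape to the rest of the framework.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            int_data.shape,
            dtype=scale.dtype,
            device=int_data.device,
            requires_grad=requires_grad,
        )

    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor, requires_grad: bool = False):
        self._int_data = int_data  # int8 shard data
        self._scale = scale        # e.g. per-row scale, shape (rows, 1)

    def dequantize(self) -> torch.Tensor:
        # The fake-quantized weight the forward pass would actually use.
        return self._int_data.to(self._scale.dtype) * self._scale
```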
Had this idea and discussed briefly with @andrewor14.
Conceptually, the current QAT + FSDP flow is: all-gather the full-precision weight shards, then fake-quantize the unsharded weight before the matmul.
However, we can do a low-bit all-gather instead, since the weight can be quantized before the all-gather.
In terms of perf, we are basically comparing (ignoring potential fusion around this) all-gathering the full-precision weight and fake-quantizing it locally vs. quantizing the local shard first, all-gathering the low-bit data plus scales, and dequantizing after the gather.
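As a rough single-process illustration (no actual `torch.distributed` calls; `torch.cat` stands in for the all-gather, and the symmetric per-row int8 quantizer is just an assumption for the sketch), both orderings produce the same fake-quantized weight, while the low-bit path moves roughly half the bytes in the gather:

```python
import torch

def quantize_int8_per_row(w: torch.Tensor):
    # Symmetric per-row int8 quantization (illustrative only).
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(scale.dtype) * scale

# Four bf16 weight shards, as FSDP would hold them (sharded along dim 0).
shards = [torch.randn(64, 128, dtype=torch.bfloat16) for _ in range(4)]

# (a) current: all-gather bf16 (2 bytes/elem on the wire), then fake-quantize locally.
w_full = torch.cat(shards)                       # stands in for the bf16 all-gather
w_fakeq_a = dequantize(*quantize_int8_per_row(w_full))

# (b) proposed: quantize each shard, all-gather int8 data + scales (~1 byte/elem), then dequantize.
quantized = [quantize_int8_per_row(s) for s in shards]
int_full = torch.cat([q for q, _ in quantized])  # stands in for the int8 all-gather
scale_full = torch.cat([s for _, s in quantized])
w_fakeq_b = dequantize(int_full, scale_full)

# Per-row scales are computed within each dim-0 shard, so both orderings agree exactly.
assert torch.equal(w_fakeq_a, w_fakeq_b)
print("gather bytes (a):", w_full.numel() * 2)
print("gather bytes (b):", int_full.numel() * 1 + scale_full.numel() * 2)
```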
This might be a small perf win, especially when distributed communication is the bottleneck. It might be useful for QAT recipes in torchtune.
This is probably low priority, so I'm just leaving it here in case anyone is interested in implementing it. We'd need to quantify the speedup, if any.
In terms of implementation, we can follow the float8 design (https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/float8/fsdp_utils.py).
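A rough sketch of how that pattern could be mirrored for the subclass above, using the `fsdp_pre_all_gather` / `fsdp_post_all_gather` hooks that the float8 weight subclass in that file implements (exact hook signatures and return conventions vary across PyTorch/torchao versions; this is untested illustration, not a proposed implementation):

```python
from typing import Any, Optional, Tuple
import torch

class Int8QATWeight(torch.Tensor):
    # ... __new__ / __init__ / dequantize as in the sketch above ...

    def fsdp_pre_all_gather(self, mesh) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # The shard is already quantized, so FSDP only all-gathers int8 data + scales.
        return (self._int_data, self._scale), None

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        int_data, scale = all_gather_outputs
        if out is not None:
            # FSDP may hand us a preallocated unsharded weight to fill in place.
            out.copy_(int_data.to(param_dtype) * scale)
            return
        # Return the unsharded fake-quantized weight plus the inner tensors FSDP
        # should keep alive until it reshards.
        weight = int_data.to(param_dtype) * scale
        return weight, (int_data, scale)
```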