Cannot run FSDP2 with low bit optim from AO #1189
@gau-nernst can you please take a look? is_pinned is not implemented, hence it causes issues when saving optimizer states.
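For context, here is a minimal sketch of how a wrapper tensor subclass can answer aten.is_pinned and aten.detach calls via __torch_dispatch__. This is an illustration under assumed names (ToySubclassTensor, codes), not torchao's actual implementation.

import torch

aten = torch.ops.aten

class ToySubclassTensor(torch.Tensor):
    # hypothetical wrapper subclass holding a quantized payload in `codes`
    @staticmethod
    def __new__(cls, codes):
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, dtype=torch.float32, device=codes.device
        )

    def __init__(self, codes):
        self.codes = codes

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.is_pinned.default:
            # checkpointing code may ask whether the state tensor lives in pinned memory;
            # delegate the query to the underlying payload tensor
            return args[0].codes.is_pinned()
        if func is aten.detach.default:
            # return a new wrapper around a detached payload
            return cls(args[0].codes.detach())
        raise NotImplementedError(f"{func} is not handled in this sketch")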
Do you have a small reproduction? Yeah, I think we don't currently test saving/loading optimizers in FSDP2. Will add tests for it.
Unfortunately I cannot share the original code, but the ability to save optimizer states properly with FSDP2 and other parallelisms implemented with DTensor is crucial for making use of these low-bit optimizers. @gau-nernst
It looks like you are using PyTorch Lightning, and it calls torch.distributed.checkpoint to save, which is where the missing is_pinned gets hit. I tested that the normal torch.save/torch.load flow works. @awgu What are the benefits of using torch.distributed.checkpoint over plain torch.save here? Implementing is_pinned should be straightforward. Here is the snippet I used for testing:

from torch.distributed.checkpoint import state_dict_saver, state_dict_loader
fsdp_model = ...
fsdp_optim = AdamW8bit(fsdp_model.parameters())
# do some training, so optimizer states are initialized
rank = torch.distributed.get_rank()
state_dict_saver.save(fsdp_optim.state_dict(), checkpoint_id=f"state_dict_rank{rank}.ckpt")
# new sharded optim. optimizer states are not initialized
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
state_dict = new_fsdp_optim.state_dict()
# this requires aten.detach, and it doesn't seem to load the optim state when the new optim state is empty (i.e. not initialized)
state_dict_loader.load(state_dict, checkpoint_id=f"state_dict_rank{rank}.ckpt")
new_fsdp_optim.load_state_dict(state_dict)
I agree that for simplicity, using torch.save/torch.load is a reasonable option when no resharding is needed.
@gau-nernst @awgu hi guys, I don't think just using torch.save is enough for cases with TP + FSDP2 or even higher degrees of parallelism. My reference is this repo: https://github.com/pytorch/torchtitan . They also use distributed checkpointing.
@gau-nernst @nighting0le01 If there is no resharding (no parallelism change and no world-size change), then torch.save/torch.load should be sufficient.
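A minimal sketch of that no-resharding approach, assuming each rank simply persists its own shard; the function names and file paths here are illustrative, not an existing API.

import torch
import torch.distributed as dist

def save_optim_plain(optim, prefix="optim_rank"):
    # each rank writes its own shard of the optimizer state
    torch.save(optim.state_dict(), f"{prefix}{dist.get_rank()}.pt")

def load_optim_plain(optim, prefix="optim_rank"):
    # weights_only=False because the low-bit states are custom tensor subclasses;
    # alternatively, register them with torch.serialization.add_safe_globals (see below)
    state = torch.load(f"{prefix}{dist.get_rank()}.pt", weights_only=False)
    optim.load_state_dict(state)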
@fegin thanks, but I think in general resharding would be required in most cases, since switching from training to inference often means moving to a different number of GPUs, so it might be worthwhile to support.
@nighting0le01 Can you produce a minimal snippet showing how you save and load the optimizer state dict? Without it, I don't have much to investigate to make sure it works for you.
Hi @gau-nernst, I got the loading to work by implementing the detach and is_pinned methods, but can you advise how to go about aten.slice in the case of this optimizer? I run into this issue while loading.
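For reference, a rough standalone sketch (not torchao's actual code) of what an aten.slice handler could look like for a block-wise quantized state. The layout is an assumption for illustration: one uint8 code per element in `codes` plus one scale per block of `block_size` elements in `scale`.

import torch

aten = torch.ops.aten

class ToyBlockwiseState(torch.Tensor):
    # hypothetical 1-D block-wise quantized optimizer state
    @staticmethod
    def __new__(cls, codes, scale, block_size):
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, dtype=torch.float32, device=codes.device
        )

    def __init__(self, codes, scale, block_size):
        self.codes = codes          # uint8 payload, shape (N,)
        self.scale = scale          # fp32 scales, shape (N // block_size,)
        self.block_size = block_size

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.slice.Tensor:
            x = args[0]
            dim = args[1] if len(args) > 1 else 0
            start = args[2] if len(args) > 2 else None
            end = args[3] if len(args) > 3 else None
            step = args[4] if len(args) > 4 else 1
            start = 0 if start is None else start
            end = x.codes.numel() if end is None else min(end, x.codes.numel())
            if dim != 0 or step != 1:
                raise NotImplementedError("sketch only handles contiguous dim-0 slices")
            if start % x.block_size or end % x.block_size:
                raise NotImplementedError("sketch requires block-aligned slice bounds")
            # slice the codes element-wise and the scales block-wise, keeping them consistent
            return cls(
                x.codes[start:end],
                x.scale[start // x.block_size : end // x.block_size],
                x.block_size,
            )
        raise NotImplementedError(f"{func} not handled in this sketch")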
What are the exact commands that you use to run it? Again, if you can minify the code to a small snippet, that would be great.
@gau-nernst I just modified this line to use the 8-bit variant: https://github.com/pytorch/torchtitan/blob/d7cabfb6cc987f9310a4bfa5b87b1dbb8974b10c/torchtitan/optimizer.py#L35 . This is the train script: https://github.com/pytorch/torchtitan/blob/main/run_llama_train.sh
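Roughly the kind of change described above, as an illustration only; the helper name and its arguments are assumptions, not torchtitan's actual optimizer-building code.

from torchao.prototype.low_bit_optim import AdamW8bit

def build_optimizer(model, lr, weight_decay):
    # hypothetical stand-in for torchtitan's optimizer construction,
    # swapped to the torchao 8-bit AdamW variant
    return AdamW8bit(model.parameters(), lr=lr, weight_decay=weight_decay)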
I have implemented the missing ops. Regarding stability/divergence, you can try reducing the block size, but I think it won't help much since 128 is already pretty small. There could be some bugs in certain training configurations. Again, I can't help much without a reproduction.
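On the block-size point, a hedged example of how the block size can be configured when constructing the optimizer; the exact default and accepted values depend on the torchao version, and 64 here is purely illustrative.

import torch
from torchao.prototype.low_bit_optim import AdamW4bit

model = torch.nn.Linear(256, 256)
# block_size sets the block-wise quantization granularity of the optimizer state;
# smaller blocks give finer-grained scales at a small memory cost
optim = AdamW4bit(model.parameters(), lr=3e-4, block_size=64)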
@nighting0le01 Can you try my branch over at #1217 to see if it works for you? From my testing, it works with the default config (only data parallel, no tensor/pipeline parallel), using torch==2.6.0.dev20241102+cu124. I didn't need to implement aten.slice there. I did need to add the following lines so that checkpointing works with the optimizer state subclasses:

import torch
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8

subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)
torch.serialization.add_safe_globals(subclasses)

as well as patch the following code: pytorch/pytorch#139575
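As a small, hedged illustration of what the add_safe_globals registration enables (the file name is made up): once the subclasses are allowlisted, a weights_only load of a checkpoint containing these optimizer states is no longer rejected by the restricted unpickler.

import torch
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

torch.serialization.add_safe_globals([OptimState8bit])
# hypothetical file containing an 8-bit optimizer state dict saved earlier with torch.save
state = torch.load("optim_state.pt", weights_only=True)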
Hi @gau-nernst, so the issue with aten.slice came up only while loading, not during saving, and that too with TP + FSDP. I added the above add_safe_globals to enable saving as well, and the patch you mentioned, before sharing the above comment.
@gau-nernst I'm 100% sure we need is_pinned, detach, and aten.slice. My implementation is slightly modified (based on torchtitan) and uses the model_parallel strategy from Lightning. Could you please guide me on the implementation of slice for this optimizer, since I haven't gone through the entire paper yet? I have something here: #1189 (comment)