
Cannot run FSDP2 with low bit optim from AO #1189

Open
nighting0le01 opened this issue Oct 29, 2024 · 16 comments · May be fixed by #1217

@nighting0le01

[rank7]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank7]:   File "<frozen runpy>", line 88, in _run_code
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 226, in <module>
[rank7]:     main()
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 218, in main
[rank7]:     trainer.fit(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank7]:     return trainer_fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank7]:     self._run(model, ckpt_path=ckpt_path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank7]:     results = self._run_stage()
[rank7]:               ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank7]:     self.fit_loop.run()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank7]:     self.advance()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank7]:     self.epoch_loop.run(self._data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank7]:     self.advance(data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 269, in advance
[rank7]:     call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 218, in _call_callback_hooks
[rank7]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 316, in on_train_batch_end
[rank7]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank7]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 715, in _save_none_monitor_checkpoint
[rank7]:     self._save_checkpoint(trainer, filepath)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
[rank7]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1365, in save_checkpoint
[rank7]:     self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/model_parallel.py", line 321, in save_checkpoint
[rank7]:     _distributed_checkpoint_save(converted_state, path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 867, in _distributed_checkpoint_save
[rank7]:     save(converted_state, checkpoint_id=path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 429, in inner_func
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 152, in save
[rank7]:     return _save_state_dict(
[rank7]:            ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 316, in _save_state_dict
[rank7]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 191, in reduce_scatter
[rank7]:     raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 303, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 101, in create_local_plan
[rank7]:     plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 399, in create_default_local_save_plan
[rank7]:     requests += _create_write_items(fqn, obj)
[rank7]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 222, in _create_write_items
[rank7]:     return object.__create_write_items__(fqn, object)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 598, in __create_write_items__
[rank7]:     return [_create_write_items_for_dtensor(fqn, object)]
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 86, in _create_write_items_for_dtensor
[rank7]:     properties=TensorProperties.create_from_tensor(tensor.to_local()),
[rank7]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/metadata.py", line 108, in create_from_tensor
[rank7]:     pin_memory=tensor.is_pinned(),
[rank7]:                ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 377, in _dispatch__torch_function__
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 393, in _dispatch__torch_dispatch__
[rank7]:     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
[rank7]: NotImplementedError: OptimState8bit dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), arg_types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), kwarg_types={}
@nighting0le01
Author

@gau-nernst can you please take a look? aten.is_pinned is not implemented, which causes failures when saving optimizer states.

@gau-nernst
Collaborator

Do you have a small reproduction? Yeah, I don't think we currently test saving/loading optimizers under FSDP2. I will add tests for it.
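
For concreteness, a minimal repro along these lines is the kind of thing such a test could look like (a sketch only: it assumes torchrun, the FSDP2 fully_shard API from torch.distributed._composable.fsdp, and a toy model; none of the specifics come from the original report):

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed._composable.fsdp import fully_shard  # import location may differ across torch versions
from torchao.prototype.low_bit_optim import AdamW8bit

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# toy model, sharded with FSDP2 so its parameters become DTensors
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.Linear(128, 128)).cuda()
fully_shard(model)
optim = AdamW8bit(model.parameters())

# one optimizer step so the 8-bit optimizer states get created
model(torch.randn(4, 128, device="cuda")).sum().backward()
optim.step()
optim.zero_grad()

# this is the DCP code path that ends up calling aten.is_pinned on OptimState8bit
dcp.save({"optim": optim.state_dict()}, checkpoint_id="checkpoint_dir")
dist.destroy_process_group()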

@nighting0le01
Author

Unfortunately I cannot share the original code, but the ability to save optimizer states properly with FSDP2 and other parallelisms implemented with DTensor is crucial for making use of these low-bit optimizers. @gau-nernst

@gau-nernst
Collaborator

gau-nernst commented Oct 30, 2024

It looks like you are using PyTorch Lightning, and it calls torch.distributed.checkpoint.state_dict_saver.save(). This function requires the tensor subclass to implement aten.is_pinned.

I tested that the normal torch.save(optim.state_dict(), "state_dict.ckpt") works fine. @nighting0le01 In the meantime, is it possible for you to switch to plain torch.save() for checkpointing?
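
Something like this is what I have in mind (a rough sketch; fsdp_optim stands for the optimizer of the FSDP2-wrapped model, the per-rank file names are made up, and it only works as long as world size and sharding stay the same):

import torch
import torch.distributed as dist

# save each rank's local shard of the optimizer state
rank = dist.get_rank()
torch.save(fsdp_optim.state_dict(), f"optim_rank{rank}.ckpt")

# later, with the same world size and sharding layout
state_dict = torch.load(f"optim_rank{rank}.ckpt", weights_only=False)  # state contains tensor subclasses
fsdp_optim.load_state_dict(state_dict)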

@awgu What are the benefits of using torch.distributed.checkpoint.state_dict_saver.save() over the plain torch.save()? From my understanding of https://pytorch.org/docs/stable/distributed.checkpoint.html, it seems like the former will handle some kind of resharding when loading? Is the saving the same?

Implementing the aten.is_pinned op is simple, but since I'm not too familiar with torch.distributed.checkpoint, what is the recommended/correct way to save and load a (sharded) optim state dict with it (so that I can add the correct tests)? Is it something like this:

from torch.distributed.checkpoint import state_dict_saver, state_dict_loader

fsdp_model = ...
fsdp_optim = AdamW8bit(fsdp_model.parameters())

# do some training, so optimizer states are initialized

# DCP save is collective; all ranks pass the same checkpoint_id (per-rank files are written inside it)
state_dict_saver.save(fsdp_optim.state_dict(), checkpoint_id="optim_ckpt")

# new sharded optim. optimizer states are not initialized
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
state_dict = new_fsdp_optim.state_dict()

# this requires aten.detach, and it doesn't seem to load optim state when the new optim state is empty (i.e. not initialized)
state_dict_loader.load(state_dict, checkpoint_id="optim_ckpt")
new_fsdp_optim.load_state_dict(state_dict)
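
Or is it more like the torch.distributed.checkpoint.state_dict helpers that torchtitan-style code appears to use? Roughly (again just a sketch, reusing fsdp_model/fsdp_optim from above and an arbitrary checkpoint_id):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

# save: get an FQN-keyed, DTensor-sharded view of the optimizer state
optim_sd = get_optimizer_state_dict(fsdp_model, fsdp_optim)
dcp.save({"optim": optim_sd}, checkpoint_id="optim_ckpt")

# load into a fresh optimizer (I believe the helper initializes empty optimizer state first)
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
optim_sd = get_optimizer_state_dict(fsdp_model, new_fsdp_optim)
dcp.load({"optim": optim_sd}, checkpoint_id="optim_ckpt")
set_optimizer_state_dict(fsdp_model, new_fsdp_optim, optim_state_dict=optim_sd)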

@awgu
Contributor

awgu commented Oct 30, 2024

I agree that for simplicity, using torch.save directly would unblock the use case.

@nighting0le01
Author

@gau-nernst @awgu Hi, I don't think just using torch.save is enough for cases with TP + FSDP2 or even higher degrees of parallelism. My reference is this repo: https://github.com/pytorch/torchtitan. They also use distributed checkpointing.

@fegin

fegin commented Oct 30, 2024

@gau-nernst @nighting0le01 If there is no resharding (no parallelism change and no world size changes), then torch.save is enough. Post-processing of the checkpoint can be harder with torch.save because it is another form of resharding.

@nighting0le01
Author

@fegin Thanks, but I think resharding would be required in most cases, since one typically switches from training to inference and to a different number of GPUs, so supporting it might be worthwhile.
@gau-nernst @awgu

@gau-nernst
Collaborator

@nighting0le01 Can you produce a minimal snippet showing how you save and load the optimizer state dict? Without it, I don't have much to go on to investigate and make sure it works for you.

@nighting0le01
Author

nighting0le01 commented Nov 1, 2024

Hi @gau-nernst, I got loading to work by implementing the detach and is_pinned ops, but can you advise how to handle slice for this optimizer?
I'm just working with the repo here: https://github.com/pytorch/torchtitan/tree/main/torchtitan

I run into this issue while loading:

[rank0]: Traceback (most recent call last): (RANK 7)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 248, in all_gather
[rank0]:     result = map_fun()
[rank0]:              ^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 228, in read_data
[rank0]:     all_reads = storage_reader.read_data(final_local_plan, planner)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/filesystem.py", line 661, in read_data
[rank0]:     tensor = narrow_tensor_by_index(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/_shard/_utils.py", line 24, in narrow_tensor_by_index
[rank0]:     narrowed_tensor = narrowed_tensor.narrow(idx, offset, size)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 377, in _dispatch__torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 393, in _dispatch__torch_dispatch__
[rank0]:     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
[rank0]: NotImplementedError: OptimState8bit dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.slice', overload='Tensor')>, types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), arg_types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>, <class 'int'>, <class 'int'>, <class 'int'>), kwarg_types={}

import torch
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

aten = torch.ops.aten


@OptimState8bit.implements(aten.is_pinned.default)
def _(func, types, args, kwargs):
    x = args[0]  # the OptimState8bit instance
    # pinned only if all underlying tensors are pinned
    return x.codes.is_pinned() and x.scale.is_pinned() and x.qmap.is_pinned()


@OptimState8bit.implements(aten.detach.default)
def _(func, types, args, kwargs):
    x = args[0]
    return OptimState8bit(x.codes.detach(), x.scale.detach(), x.qmap.detach(), x.signed)


@OptimState8bit.implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
    # aten.slice.Tensor may be called with fewer than 5 positional args
    x = args[0]
    dim = args[1] if len(args) > 1 else 0
    start = args[2] if len(args) > 2 else None
    end = args[3] if len(args) > 3 else None
    step = args[4] if len(args) > 4 else 1
    if step != 1:
        raise NotImplementedError("Step sizes other than 1 are not supported for OptimState8bit slicing.")
    if start is None:
        start = 0
    if end is None:
        end = x.shape[dim]
    # slice the quantized codes
    codes_sliced = aten.slice.Tensor(x.codes, dim, start, end, step)
    # work out which quantization blocks the slice touches
    # (treating the codes as flattened, since block_size groups contiguous elements)
    flat_start = start * x.codes.stride()[dim]
    flat_end = end * x.codes.stride()[dim]
    block_start = flat_start // x.block_size
    block_end = (flat_end + x.block_size - 1) // x.block_size  # ceiling division
    # slice the per-block scales accordingly; qmap is shared, so it stays the same
    scale_sliced = x.scale[block_start:block_end]
    return OptimState8bit(codes_sliced, scale_sliced, x.qmap, x.signed)

@gau-nernst
Collaborator

What are the exact commands that you use to run it? Again, if you can minify the code to a small snippet, that would be great.

@gau-nernst
Copy link
Collaborator

gau-nernst commented Nov 3, 2024

I have implemented aten.is_pinned and aten.detach in #1217. It works with the default torchtitan config from my testing. Again, I need the full command/config (PP/TP if any) that you used to reproduce the issue with aten.slice (also your multi-GPU setup; it might be relevant), because I don't know how the slice op is being used here. I can't help you without being able to reproduce the problem on my side.

Regarding stability/divergence, you can try reducing the block size, but I think it won't help much since 128 is already pretty small. There could be some bugs in certain training configurations. Again, I can't help much without a reproduction.
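
For reference, the block size is a constructor argument on the torchao low-bit optimizers, so reducing it would look something like this (the block_size keyword is assumed here and the value is only an example):

from torchao.prototype.low_bit_optim import AdamW8bit

# smaller blocks mean more per-block scales, i.e. finer-grained quantization
optim = AdamW8bit(model.parameters(), block_size=64)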

@gau-nernst
Collaborator

@nighting0le01 Can you try my branch over at #1217 to see if it works for you? From my testing, it works with the default config (only data parallel, no tensor/pipeline parallel). Using torch==2.6.0.dev20241102+cu124. I didn't need to implement aten.slice.

I needed to add the following lines so that dcp.load() can load the tensor subclasses:

import torch
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8

subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)
torch.serialization.add_safe_globals(subclasses)

as well as apply the patch from pytorch/pytorch#139575.

@nighting0le01
Author

nighting0le01 commented Nov 3, 2024

Hi @gau-nernst, the issue with aten.slice came up only while loading, not during saving, and only with TP + FSDP. I had already added the add_safe_globals lines above to enable saving, as well as the patch you mention, before sharing the comment above (it fails with the same aten.slice traceback I posted earlier).

@nighting0le01
Author

nighting0le01 commented Nov 3, 2024

@gau-nernst I'm 100% sure we need is_pinned, detach, and aten.slice. My implementation is slightly modified (based on torchtitan) and uses the model_parallel strategy from Lightning. Could you please guide me on the implementation of slice for this optimizer? I haven't gone through the entire paper yet.

I have something here: #1189 (comment)

  1. Also, what can I do for stability? I see divergence with a block size of 128 with 8-bit; shall I keep reducing it?
