FP8 training doesn't work with DeepSpeed #3360

Open
2 of 4 tasks
XiaobingSuper opened this issue Jan 23, 2025 · 1 comment · May be fixed by #3361

Comments

@XiaobingSuper

System Info

- `Accelerate` version: 1.3.0
- Platform: Linux-5.4.250-4-velinux1u1-amd64-x86_64-with-glibc2.35
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.10.12
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 1928.86 GB
- GPU type: NVIDIA H20
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Create the DeepSpeed config JSON, saved as config.json (referenced by the Accelerate config below):

{
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": false
    },
    "gradient_clipping": 1.0,
    "bf16": {"enabled": true},
    "fp16": {"enabled": false},
    "zero_allow_untested_optimizer": true
}

and an Accelerate config YAML:

distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_config_file: "config.json"
  zero3_init_flag: true
num_processes: 1

Then run the following script:

from accelerate import Accelerator
from accelerate.utils import has_transformer_engine_layers, FP8RecipeKwargs
from fp8_utils import get_training_utilities  # helper module from accelerate's benchmarks/fp8/transformer_engine directory

MODEL_NAME = "bert-base-cased"

FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]

accelerator = Accelerator(
    mixed_precision="fp8", kwargs_handlers=kwargs_handlers
)

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
    MODEL_NAME, accelerator=accelerator
)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

assert has_transformer_engine_layers(model), "Model should have Transformer Engine layers"
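
Assuming the script above is saved as repro.py and the Accelerate YAML as accelerate_config.yaml (both filenames are placeholders, not from the report), it can be launched with:

accelerate launch --config_file accelerate_config.yaml repro.py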

Expected behavior

Expected: after accelerator.prepare(...), the model contains Transformer Engine layers. Actual: the script fails with the assertion error "Model should have Transformer Engine layers", i.e. the FP8/TE conversion is not applied when DeepSpeed is the distributed backend.
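
As a quick check of what the assertion is reporting, here is a small diagnostic sketch (my own, assuming Transformer Engine is installed and importable as transformer_engine) that counts plain PyTorch vs. Transformer Engine linear layers in the prepared model from the script above:

import torch.nn as nn
import transformer_engine.pytorch as te

# "model" is the model returned by accelerator.prepare in the repro script.
num_nn_linear = sum(isinstance(m, nn.Linear) for m in model.modules())
num_te_linear = sum(isinstance(m, te.Linear) for m in model.modules())
print(f"nn.Linear layers: {num_nn_linear}, te.Linear layers: {num_te_linear}")
# With the DeepSpeed path above, te.Linear stays at 0, which is exactly what
# the failing assertion reports.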

@XiaobingSuper XiaobingSuper linked a pull request Jan 23, 2025 that will close this issue
@XiaobingSuper
Author

You can also reproduce this issue by running https://github.com/huggingface/accelerate/blob/main/benchmarks/fp8/transformer_engine/distrib_deepspeed.py:

CUDA_VISIBLE_DEVICES=4,5 accelerate launch distrib_deepspeed.py

(screenshot attached)
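
A possible manual workaround sketch (mine, not from the report, and untested with ZeRO-3 init): convert the model to Transformer Engine layers before calling accelerator.prepare, using accelerate's conversion helper; the import path below is an assumption.

from accelerate.utils.transformer_engine import convert_model  # assumed import path

# Swap nn.Linear / nn.LayerNorm modules for their Transformer Engine counterparts
# in place, then hand the converted model to DeepSpeed via accelerator.prepare.
convert_model(model)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)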
