Commit fbeb5a7

Clean

muellerzr committed Jan 16, 2025
1 parent be210db commit fbeb5a7

Showing 6 changed files with 67 additions and 64 deletions.
35 changes: 20 additions & 15 deletions benchmarks/fp8/torchao/non_distributed.py
@@ -18,16 +18,18 @@
This particular script verifies this for single GPU training.
"""

from functools import partial

import evaluate
import torch
from functools import partial
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchao.float8 import convert_to_float8_training
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import AORecipeKwargs, set_seed


@@ -169,8 +171,10 @@ def train_baseline():

def train_integration():
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
MODEL_NAME, accelerator=accelerator
)
model = accelerator.prepare(model)
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
model.train()
@@ -196,17 +200,18 @@ def train_integration():


if __name__ == "__main__":
# baseline_not_trained, baseline_trained = train_baseline()
baseline_not_trained, baseline_trained = train_baseline()
AcceleratorState._reset_state(True)
accelerator_not_trained, accelerator_trained = train_integration()
# assert (
# baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
# ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
# assert (
# baseline_not_trained["f1"] == accelerator_not_trained["f1"]
# ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
# assert (
# baseline_trained["accuracy"] == accelerator_trained["accuracy"]
# ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
# assert (
# baseline_trained["f1"] == accelerator_trained["f1"]
# ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
assert (
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
assert (
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
assert (
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
assert (
baseline_trained["f1"] == accelerator_trained["f1"]
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
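
For reference, a minimal sketch of the integration path this benchmark exercises, assuming `torchao` is installed and an FP8-capable GPU is available; the toy model and optimizer below are placeholders, not the benchmark's BERT helpers:

```python
import torch

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

# Placeholder model: stacked Linear layers stand in for the benchmark's BERT model.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.Linear(64, 64),
    torch.nn.Linear(64, 2),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# "fp8" mixed precision plus AORecipeKwargs selects the torchao float8 backend;
# prepare() then converts eligible nn.Linear layers in place.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
model, optimizer = accelerator.prepare(model, optimizer)
```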
24 changes: 16 additions & 8 deletions src/accelerate/accelerator.py
@@ -29,11 +29,12 @@
from types import MethodType
from typing import Any, Callable, Union

from accelerate.utils.imports import is_torchao_available
import torch
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards

from accelerate.utils.imports import is_torchao_available

from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
@@ -49,10 +50,8 @@
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
WEIGHTS_PATTERN_NAME,
AutocastKwargs,
AORecipeKwargs,
TERecipeKwargs,
MSAMPRecipeKwargs,
AutocastKwargs,
DataLoaderConfiguration,
DeepSpeedPlugin,
DistributedDataParallelKwargs,
@@ -66,18 +65,20 @@
KwargsHandler,
LoggerType,
MegatronLMPlugin,
MSAMPRecipeKwargs,
PrecisionType,
ProfileKwargs,
ProjectConfiguration,
RNGType,
TERecipeKwargs,
TorchDynamoPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
compare_versions,
convert_model,
convert_to_float8_training,
convert_outputs_to_fp32,
convert_to_float8_training,
ensure_weights_retied,
extract_model_from_parallel,
gather,
@@ -442,7 +443,9 @@ def __init__(
elif is_msamp_available():
self.msamp_recipe_handler = MSAMPRecipeKwargs()
else:
raise ImportError("Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed.")
raise ImportError(
"Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed."
)

self.delayed_fp8_autocast = False
if self.has_fp8_handler:
@@ -1627,8 +1630,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
def _prepare_ao(self, *args):
if not is_torchao_available():
raise ImportError("`torchao` was not found on your system. Please ensure that `torchao` is installed")
for model in self._models:
convert_to_float8_training(model, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func)
for arg in args:
if isinstance(arg, torch.nn.Module):
convert_to_float8_training(
arg,
config=self.ao_recipe_handler.config,
module_filter_func=self.ao_recipe_handler.module_filter_func,
)
return args

def _prepare_te(self, *args):
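
In short, the reworked `_prepare_ao` now converts only the `nn.Module` objects actually passed to `prepare()`, rather than every model the accelerator has registered. A rough standalone sketch of that filtering step (the helper name is illustrative, not part of the library):

```python
import torch

from accelerate.utils.ao import convert_to_float8_training


def convert_module_args_to_fp8(*args, config=None, module_filter_func=None):
    # Illustrative helper mirroring the new _prepare_ao loop: only nn.Module
    # arguments are converted in place; everything else passes through untouched.
    for arg in args:
        if isinstance(arg, torch.nn.Module):
            convert_to_float8_training(
                arg, config=config, module_filter_func=module_filter_func
            )
    return args
```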
8 changes: 4 additions & 4 deletions src/accelerate/utils/__init__.py
@@ -33,8 +33,8 @@
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
)
from .dataclasses import (
AutocastKwargs,
AORecipeKwargs,
AutocastKwargs,
BnbQuantizationConfig,
ComputeEnvironment,
CustomDtype,
@@ -52,15 +52,15 @@
KwargsHandler,
LoggerType,
MegatronLMPlugin,
MSAMPRecipeKwargs,
PrecisionType,
ProfileKwargs,
ProjectConfiguration,
RNGType,
SageMakerDistributedType,
TensorInformation,
TorchDynamoPlugin,
TERecipeKwargs,
MSAMPRecipeKwargs,
TorchDynamoPlugin,
add_model_config_to_megatron_parser,
)
from .environment import (
@@ -81,7 +81,6 @@
)
from .imports import (
deepspeed_required,
torchao_required,
get_ccl_version,
is_4bit_bnb_available,
is_8bit_bnb_available,
@@ -129,6 +128,7 @@
is_wandb_available,
is_weights_only_available,
is_xpu_available,
torchao_required,
)
from .modeling import (
align_module_device,
28 changes: 12 additions & 16 deletions src/accelerate/utils/ao.py
@@ -27,9 +27,7 @@ def find_first_last_linear_layers(model: torch.nn.Module):
"""
Finds the first and last linear layer names in a model.
This is needed during FP8 to avoid issues with
instability by keeping the first and last layers
unquantized.
This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
Ref: https://x.com/xariusrke/status/1826669142604141052
"""
@@ -72,31 +70,29 @@ def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name)

@torchao_required
def convert_to_float8_training(
model: torch.nn.Module,
config=None,
module_filter_func=None,
):
model: torch.nn.Module,
config=None,
module_filter_func=None,
):
"""
Converts all `nn.Linear` layers in the model (except the first and last)
to torchao's `Float8Linear` layer inplace.
Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.
Args:
model (`torch.nn.Module`):
The model to convert.
config (`torchao.float8.Float8LinearConfig`, *optional*):
The configuration for the FP8 training. Recommended to utilize
`torchao.float8.recipe_name_to_linear_config` to generate this.
In general, the default config should be sufficient.
`torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
sufficient.
module_filter_func (`Callable`, *optional*):
Optional function that must take in a module and layer name,
and returns a boolean indicating whether the module should be
converted to FP8. Defaults to `filter_linear_layers`. See
it for an example.
Optional function that must take in a module and layer name, and returns a boolean indicating whether the
module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example.
Example:
```python
from accelerate.utils.ao import convert_to_float8_training
model = MyModel()
model.to("cuda")
convert_to_float8_training(model)
@@ -109,4 +105,4 @@ def convert_to_float8_training(
first_linear, last_linear = find_first_last_linear_layers(model)
if module_filter_func is None:
module_filter_func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
convert_to_float8_training(model, config, module_filter_func)
convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
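
For context, a minimal direct-usage sketch of the helper this file defines, assuming `torchao` is installed and a CUDA device is available; `MySmallNet` is a placeholder model:

```python
import torch

from accelerate.utils.ao import convert_to_float8_training


class MySmallNet(torch.nn.Module):
    # Placeholder network: three Linear layers so there is a "middle" layer to convert.
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(128, 128)
        self.fc2 = torch.nn.Linear(128, 128)
        self.fc3 = torch.nn.Linear(128, 10)

    def forward(self, x):
        return self.fc3(torch.relu(self.fc2(torch.relu(self.fc1(x)))))


model = MySmallNet().to("cuda")

# With no module_filter_func, the default filter_linear_layers keeps the first
# and last Linear layers unquantized for stability (see the docstring above).
convert_to_float8_training(model)
```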
34 changes: 14 additions & 20 deletions src/accelerate/utils/dataclasses.py
@@ -20,17 +20,16 @@
import copy
import enum
import functools
import logging
import os
import warnings
import logging
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args

import torch

from .ao import filter_linear_layers
from .constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
@@ -49,6 +48,7 @@
)
from .versions import compare_versions, is_torch_version


if TYPE_CHECKING:
# Mock imports for type checking
from torchao.float8 import Float8LinearConfig
@@ -296,25 +296,20 @@ class AORecipeKwargs(KwargsHandler):
Args:
recipe_name (`str`, *optional*, default to `None`):
The name of the recipe to use for FP8 training. Should
be compatible with `torchao.float8.recipe_name_to_linear_config`.
The name of the recipe to use for FP8 training. Should be compatible with
`torchao.float8.recipe_name_to_linear_config`.
config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`):
The configuration for the FP8 training. In general, the default config
should be sufficient.
The configuration for the FP8 training. In general, the default config should be sufficient.
module_filter_func (`Callable`, *optional*, default to `None`):
Optional function that must take in a module and layer name,
and returns a boolean indicating whether the module should be
converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See
it for an example.
Optional function that must take in a module and layer name, and returns a boolean indicating whether the
module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an
example.
"""

recipe_name: str = None
config: "Float8LinearConfig" = None
module_filter_func: Callable = None

def __post_init__(self):
if self.module_filter_func is None:
self.module_filter_func = filter_linear_layers


@dataclass
class TERecipeKwargs(KwargsHandler):
@@ -354,6 +349,7 @@ class TERecipeKwargs(KwargsHandler):
override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`):
Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
"""

use_autocast_during_eval: bool = None
margin: int = None
interval: int = None
@@ -365,9 +361,7 @@ def __post_init__(self):
def __post_init__(self):
env_prefix = "ACCELERATE_FP8_"
if not is_transformer_engine_available():
raise ImportError(
"TransformerEngine is not available. Please install it or use a different backend."
)
raise ImportError("TransformerEngine is not available. Please install it or use a different backend.")
if self.use_autocast_during_eval is None:
self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
if self.margin is None:
@@ -399,6 +393,7 @@ class MSAMPRecipeKwargs(KwargsHandler):
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
training with `ms-amp`.
"""

opt_level: OptLevel = None

def __post_init__(self):
@@ -412,8 +407,7 @@ def __post_init__(self):
@dataclass
class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs):
"""
Deprecated. Please use one of the proper FP8 recipe
kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs`
Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs`
instead.
"""

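
As a usage note, a hedged sketch of passing a custom filter through `AORecipeKwargs`, now that the default filter is resolved later in `accelerate.utils.ao` rather than in `__post_init__`; the size-based predicate below is an illustrative assumption, not a library default:

```python
import torch

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs


def fp8_only_large_linears(module, layer_name):
    # Per the docstring: takes a module and its layer name and returns whether
    # that module should be converted to FP8. Here: only large Linear layers.
    if not isinstance(module, torch.nn.Linear):
        return False
    return module.in_features >= 64 and module.out_features >= 64


accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(module_filter_func=fp8_only_large_linears)],
)
```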
2 changes: 1 addition & 1 deletion src/accelerate/utils/imports.py
@@ -110,7 +110,7 @@ def is_lomo_available():


def is_fp8_available():
return is_msamp_available() or is_transformer_engine_available()
return is_msamp_available() or is_transformer_engine_available() or is_torchao_available()


def is_cuda_available():
