From fbeb5a7bc27c020a0c8e30ebad44fc6fe48571b9 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 16 Jan 2025 10:55:29 -0500
Subject: [PATCH] Clean up the torchao FP8 integration

---
 benchmarks/fp8/torchao/non_distributed.py | 35 +++++++++++++----------
 src/accelerate/accelerator.py             | 24 ++++++++++------
 src/accelerate/utils/__init__.py          |  8 +++---
 src/accelerate/utils/ao.py                | 28 ++++++++----------
 src/accelerate/utils/dataclasses.py       | 34 +++++++++-------------
 src/accelerate/utils/imports.py           |  2 +-
 6 files changed, 67 insertions(+), 64 deletions(-)

diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py
index 81eb0d2bc73..e2426f162d5 100644
--- a/benchmarks/fp8/torchao/non_distributed.py
+++ b/benchmarks/fp8/torchao/non_distributed.py
@@ -18,9 +18,10 @@
 This particular script verifies this for single GPU training.
 """
 
+from functools import partial
+
 import evaluate
 import torch
-from functools import partial
 from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
@@ -28,6 +29,7 @@
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
 
 from accelerate import Accelerator
+from accelerate.state import AcceleratorState
 from accelerate.utils import AORecipeKwargs, set_seed
 
 
@@ -169,8 +171,10 @@ def train_baseline():
 
 def train_integration():
     set_seed(42)
-    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
     accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
+    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+        MODEL_NAME, accelerator=accelerator
+    )
     model = accelerator.prepare(model)
     base_model_results = evaluate_model(model, eval_dataloader, METRIC)
     model.train()
@@ -196,17 +200,18 @@ def train_integration():
 
 
 if __name__ == "__main__":
-    # baseline_not_trained, baseline_trained = train_baseline()
+    baseline_not_trained, baseline_trained = train_baseline()
+    AcceleratorState._reset_state(True)
     accelerator_not_trained, accelerator_trained = train_integration()
-    # assert (
-    #     baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
-    # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
-    # assert (
-    #     baseline_not_trained["f1"] == accelerator_not_trained["f1"]
-    # ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
-    # assert (
-    #     baseline_trained["accuracy"] == accelerator_trained["accuracy"]
-    # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
-    # assert (
-    #     baseline_trained["f1"] == accelerator_trained["f1"]
-    # ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+    assert (
+        baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+    ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+    assert (
+        baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+    ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+    assert (
+        baseline_trained["accuracy"] == 
accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 923f2693764..d3ba6f4f74a 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -29,11 +29,12 @@ from types import MethodType from typing import Any, Callable, Union -from accelerate.utils.imports import is_torchao_available import torch import torch.utils.hooks as hooks from huggingface_hub import split_torch_state_dict_into_shards +from accelerate.utils.imports import is_torchao_available + from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .logging import get_logger @@ -49,10 +50,8 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, WEIGHTS_PATTERN_NAME, - AutocastKwargs, AORecipeKwargs, - TERecipeKwargs, - MSAMPRecipeKwargs, + AutocastKwargs, DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, @@ -66,18 +65,20 @@ KwargsHandler, LoggerType, MegatronLMPlugin, + MSAMPRecipeKwargs, PrecisionType, ProfileKwargs, ProjectConfiguration, RNGType, + TERecipeKwargs, TorchDynamoPlugin, apply_fp8_autowrap, check_os_kernel, clean_state_dict_for_safetensors, compare_versions, convert_model, - convert_to_float8_training, convert_outputs_to_fp32, + convert_to_float8_training, ensure_weights_retied, extract_model_from_parallel, gather, @@ -442,7 +443,9 @@ def __init__( elif is_msamp_available(): self.msamp_recipe_handler = MSAMPRecipeKwargs() else: - raise ImportError("Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed.") + raise ImportError( + "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed." + ) self.delayed_fp8_autocast = False if self.has_fp8_handler: @@ -1627,8 +1630,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e def _prepare_ao(self, *args): if not is_torchao_available(): raise ImportError("`torchao` was not found on your system. 
Please ensure that `torchao` is installed")
-        for model in self._models:
-            convert_to_float8_training(model, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func)
+        for arg in args:
+            if isinstance(arg, torch.nn.Module):
+                convert_to_float8_training(
+                    arg,
+                    config=self.ao_recipe_handler.config,
+                    module_filter_func=self.ao_recipe_handler.module_filter_func,
+                )
         return args
 
     def _prepare_te(self, *args):
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 6219adbfee0..502c9c04b88 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -33,8 +33,8 @@
     XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
 )
 from .dataclasses import (
-    AutocastKwargs,
     AORecipeKwargs,
+    AutocastKwargs,
     BnbQuantizationConfig,
     ComputeEnvironment,
     CustomDtype,
@@ -52,15 +52,15 @@
     KwargsHandler,
     LoggerType,
     MegatronLMPlugin,
+    MSAMPRecipeKwargs,
     PrecisionType,
     ProfileKwargs,
     ProjectConfiguration,
     RNGType,
     SageMakerDistributedType,
     TensorInformation,
-    TorchDynamoPlugin,
     TERecipeKwargs,
-    MSAMPRecipeKwargs,
+    TorchDynamoPlugin,
     add_model_config_to_megatron_parser,
 )
 from .environment import (
@@ -81,7 +81,6 @@
 )
 from .imports import (
     deepspeed_required,
-    torchao_required,
     get_ccl_version,
     is_4bit_bnb_available,
     is_8bit_bnb_available,
@@ -129,6 +128,7 @@
     is_wandb_available,
     is_weights_only_available,
     is_xpu_available,
+    torchao_required,
 )
 from .modeling import (
     align_module_device,
diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py
index 1d21738c495..e0a2cf93d73 100644
--- a/src/accelerate/utils/ao.py
+++ b/src/accelerate/utils/ao.py
@@ -27,9 +27,7 @@ def find_first_last_linear_layers(model: torch.nn.Module):
     """
     Finds the first and last linear layer names in a model.
 
-    This is needed during FP8 to avoid issues with
-    instability by keeping the first and last layers
-    unquantized.
+    This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
 
     Ref: https://x.com/xariusrke/status/1826669142604141052
     """
@@ -72,31 +70,29 @@ def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name)
 
 @torchao_required
 def convert_to_float8_training(
-        model: torch.nn.Module,
-        config=None,
-        module_filter_func=None,
-    ):
+    model: torch.nn.Module,
+    config=None,
+    module_filter_func=None,
+):
     """
-    Converts all `nn.Linear` layers in the model (except the first and last)
-    to torchao's `Float8Linear` layer inplace.
+    Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer in place.
 
     Args:
         model (`torch.nn.Module`):
             The model to convert.
         config (`torchao.float8.Float8LinearConfig`, *optional*):
             The configuration for the FP8 training. Recommended to utilize
-            `torchao.float8.recipe_name_to_linear_config` to generate this.
-            In general, the default config should be sufficient.
+            `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
+            sufficient.
        module_filter_func (`Callable`, *optional*):
-            Optional function that must take in a module and layer name,
-            and returns a boolean indicating whether the module should be
-            converted to FP8. Defaults to `filter_linear_layers`. See
-            it for an example.
+            Optional function that takes in a module and layer name and returns a boolean indicating whether the
+            module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example. 
Example:
 
     ```python
     from accelerate.utils.ao import convert_to_float8_training
 
+    model = MyModel()
     model.to("cuda")
     convert_to_float8_training(model)
@@ -109,4 +105,4 @@ def convert_to_float8_training(
     first_linear, last_linear = find_first_last_linear_layers(model)
     if module_filter_func is None:
         module_filter_func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
-    convert_to_float8_training(model, config, module_filter_func)
+    convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index f1240b7aeb8..9aaf35308ff 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -20,17 +20,16 @@
 import copy
 import enum
 import functools
+import logging
 import os
 import warnings
-import logging
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args
 
 import torch
 
-from .ao import filter_linear_layers
 from .constants import (
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
@@ -49,6 +48,7 @@
 )
 from .versions import compare_versions, is_torch_version
 
+
 if TYPE_CHECKING:
     # Mock imports for type checking
     from torchao.float8 import Float8LinearConfig
@@ -296,25 +296,20 @@ class AORecipeKwargs(KwargsHandler):
 
     Args:
         recipe_name (`str`, *optional*, default to `None`):
-            The name of the recipe to use for FP8 training. Should
-            be compatible with `torchao.float8.recipe_name_to_linear_config`.
+            The name of the recipe to use for FP8 training. Should be compatible with
+            `torchao.float8.recipe_name_to_linear_config`.
         config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`):
-            The configuration for the FP8 training. In general, the default config
-            should be sufficient.
+            The configuration for the FP8 training. In general, the default config should be sufficient.
         module_filter_func (`Callable`, *optional*, default to `None`):
-            Optional function that must take in a module and layer name,
-            and returns a boolean indicating whether the module should be
-            converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See
-            it for an example.
+            Optional function that takes in a module and layer name and returns a boolean indicating whether the
+            module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an
+            example.
     """
+
     recipe_name: str = None
     config: "Float8LinearConfig" = None
     module_filter_func: Callable = None
 
-    def __post_init__(self):
-        if self.module_filter_func is None:
-            self.module_filter_func = filter_linear_layers
-
 
 @dataclass
 class TERecipeKwargs(KwargsHandler):
@@ -354,6 +349,7 @@ class TERecipeKwargs(KwargsHandler):
         override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`):
             Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
     """
+
     use_autocast_during_eval: bool = None
     margin: int = None
     interval: int = None
@@ -365,9 +361,7 @@ class TERecipeKwargs(KwargsHandler):
     def __post_init__(self):
         env_prefix = "ACCELERATE_FP8_"
         if not is_transformer_engine_available():
-            raise ImportError(
-                "TransformerEngine is not available. Please install it or use a different backend." 
- ) + raise ImportError("TransformerEngine is not available. Please install it or use a different backend.") if self.use_autocast_during_eval is None: self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") if self.margin is None: @@ -399,6 +393,7 @@ class MSAMPRecipeKwargs(KwargsHandler): Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `ms-amp`. """ + opt_level: OptLevel = None def __post_init__(self): @@ -412,8 +407,7 @@ def __post_init__(self): @dataclass class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): """ - Deprecated. Please use one of the proper FP8 recipe - kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` + Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` instead. """ diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 3ba98691902..7653f36e60d 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -110,7 +110,7 @@ def is_lomo_available(): def is_fp8_available(): - return is_msamp_available() or is_transformer_engine_available() + return is_msamp_available() or is_transformer_engine_available() or is_torchao_available() def is_cuda_available():
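For reference, below is a usage sketch (not part of the patch) of the torchao FP8 path these changes enable, mirroring `benchmarks/fp8/torchao/non_distributed.py` above. It assumes `torchao` is installed and an FP8-capable CUDA GPU is available; the toy `Sequential` model is a hypothetical stand-in for a real network.

```python
import torch

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

# Hypothetical toy model. By default the first and last linear layers are
# kept unquantized for stability (see `filter_linear_layers` in utils/ao.py).
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
)

# Passing AORecipeKwargs selects the torchao backend explicitly; with
# mixed_precision="fp8" and no handler, a backend is auto-detected from
# whatever FP8 library is installed (see the accelerator.py hunk above).
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])

# prepare() now routes nn.Module arguments through _prepare_ao(), which calls
# convert_to_float8_training() to swap eligible nn.Linear modules for
# torchao's Float8Linear in place.
model = accelerator.prepare(model)
```

The default `AORecipeKwargs()` is generally sufficient; a custom `Float8LinearConfig` or `module_filter_func` can be passed through the same handler, per the `AORecipeKwargs` docstring above.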