Bye bye torch <1
muellerzr committed Jan 8, 2025
1 parent 54370d4 commit d18b853
Showing 8 changed files with 12 additions and 26 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -75,7 +75,7 @@
         "packaging>=20.0",
         "psutil",
         "pyyaml",
-        "torch>=1.10.0",
+        "torch>=2.0.0",
         "huggingface_hub>=0.21.0",
         "safetensors>=0.4.3",
     ],

2 changes: 0 additions & 2 deletions src/accelerate/accelerator.py
@@ -1608,8 +1608,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled.
         if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
-            if not is_torch_version(">=", "2.0"):
-                raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
             model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
         return model
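
With torch pinned at >= 2.0 in setup.py, `torch.compile` always exists, so the removed guard was dead code. A minimal sketch of the now-unconditional path; the `backend` value is illustrative only (Accelerate builds the real kwargs from its DynamoPlugin):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    # Stand-in for **self.state.dynamo_plugin.to_kwargs();
    # "inductor" is the default torch.compile backend in 2.x.
    compiled = torch.compile(model, backend="inductor")
    out = compiled(torch.randn(2, 8))  # first call triggers compilation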
4 changes: 1 addition & 3 deletions src/accelerate/big_modeling.py
@@ -41,7 +41,6 @@
     is_mlu_available,
     is_musa_available,
     is_npu_available,
-    is_torch_version,
     is_xpu_available,
     load_checkpoint_in_model,
     offload_state_dict,
@@ -114,8 +113,7 @@ def init_on_device(device: torch.device, include_buffers: bool = None):
     if include_buffers is None:
         include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)

-    # TODO(shingjan): remove the torch version check once older versions are deprecated
-    if is_torch_version(">=", "2.0") and include_buffers:
+    if include_buffers:
         with device:
             yield
         return
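
The simplified branch leans on a PyTorch 2.0 feature: a `torch.device` acts as a context manager that reroutes parameter and buffer allocation, which is what made the version probe unnecessary. A small sketch:

    import torch
    import torch.nn as nn

    # Since PyTorch 2.0, entering a device redirects tensor creation to it,
    # so modules can be built on "meta" without allocating real memory.
    with torch.device("meta"):
        layer = nn.BatchNorm1d(4)

    print(layer.weight.device)        # meta
    print(layer.running_mean.device)  # meta (buffers are captured too)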
6 changes: 1 addition & 5 deletions src/accelerate/commands/launch.py
@@ -37,14 +37,12 @@
     _filter_args,
     check_cuda_p2p_ib_support,
     convert_dict_to_env_variables,
-    is_bf16_available,
     is_deepspeed_available,
     is_mlu_available,
     is_musa_available,
     is_npu_available,
     is_rich_available,
     is_sagemaker_available,
-    is_torch_version,
     is_torch_xla_available,
     is_xpu_available,
     patch_environment,
@@ -1055,9 +1053,7 @@ def _validate_launch_command(args):
             mp_from_config_flag = True
         else:
-            if args.use_cpu or (args.use_xpu and torch.xpu.is_available()):
-                native_amp = is_torch_version(">=", "1.10")
-            else:
-                native_amp = is_bf16_available(True)
+            native_amp = True
         if (
             args.mixed_precision == "bf16"
             and not native_amp
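
`native_amp = True` holds because the removed branch probed torch >= 1.10 on CPU/XPU (and bf16 availability elsewhere), while native autocast ships in every torch >= 2.0 build. An illustrative use of native AMP; the device type, shapes, and dtype here are arbitrary:

    import torch

    # torch.autocast is available on every version Accelerate now supports.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        y = torch.randn(4, 4) @ torch.randn(4, 4)
    print(y.dtype)  # torch.bfloat16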
9 changes: 4 additions & 5 deletions src/accelerate/data_loader.py
@@ -718,12 +718,11 @@ def __init__(
         **kwargs,
     ):
         shuffle = False
-        if is_torch_version(">=", "1.11.0"):
-            from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
+        from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

-            # We need to save the shuffling state of the DataPipe
-            if isinstance(dataset, ShufflerIterDataPipe):
-                shuffle = dataset._shuffle_enabled
+        # We need to save the shuffling state of the DataPipe
+        if isinstance(dataset, ShufflerIterDataPipe):
+            shuffle = dataset._shuffle_enabled
         super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.split_batches = split_batches
         if shuffle:
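
The import no longer needs a version gate because `torch.utils.data.datapipes` has shipped since torch 1.11, below the new 2.0 floor. A minimal sketch of the object the isinstance check looks for; `IterableWrapper` is just a convenient way to build one:

    from torch.utils.data.datapipes.iter import IterableWrapper
    from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

    # .shuffle() wraps an IterDataPipe in a ShufflerIterDataPipe, the type
    # whose shuffling state the DataLoader wrapper above preserves.
    pipe = IterableWrapper(range(10)).shuffle()
    assert isinstance(pipe, ShufflerIterDataPipe)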
6 changes: 1 addition & 5 deletions src/accelerate/utils/operations.py
@@ -32,7 +32,6 @@
     is_torch_xla_available,
     is_xpu_available,
 )
-from .versions import is_torch_version


 if is_torch_xla_available():
@@ -320,10 +319,7 @@ def _tpu_gather_one(tensor):

 def _gpu_gather(tensor):
     state = PartialState()
-    if is_torch_version(">=", "1.13"):
-        gather_op = torch.distributed.all_gather_into_tensor
-    else:
-        gather_op = torch.distributed._all_gather_base
+    gather_op = torch.distributed.all_gather_into_tensor

     def _gpu_gather_one(tensor):
         if tensor.ndim == 0:
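
`torch.distributed.all_gather_into_tensor` is the public name (torch >= 1.13) of the former private `_all_gather_base`, so the fallback could be dropped outright. A sketch of the calling convention, assuming a process group is already initialized (e.g. under torchrun):

    import torch
    import torch.distributed as dist

    def gather_all(local: torch.Tensor) -> torch.Tensor:
        # The output buffer must hold world_size copies of the local tensor.
        world_size = dist.get_world_size()
        output = torch.empty(world_size * local.numel(), dtype=local.dtype, device=local.device)
        dist.all_gather_into_tensor(output, local)
        return output.reshape(world_size, *local.shape)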
2 changes: 1 addition & 1 deletion src/accelerate/utils/other.py
@@ -54,7 +54,7 @@ def is_compiled_module(module):
     """
     Check whether the module was compiled with torch.compile()
     """
-    if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
+    if not hasattr(torch, "_dynamo"):
         return False
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
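
On a stock torch >= 2.0 build `torch._dynamo` always exists, so the remaining hasattr check only guards unusual builds and the function effectively reduces to the isinstance test. Illustration:

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)
    compiled = torch.compile(model)

    # torch.compile wraps modules in OptimizedModule, which is
    # exactly what is_compiled_module() detects.
    print(isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule))  # True
    print(isinstance(model, torch._dynamo.eval_frame.OptimizedModule))     # False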
7 changes: 3 additions & 4 deletions tests/test_big_modeling.py
@@ -43,7 +43,7 @@
     slow,
     torch_device,
 )
-from accelerate.utils import is_torch_version, offload_state_dict
+from accelerate.utils import offload_state_dict


 logger = logging.getLogger(__name__)
@@ -166,9 +166,8 @@ def test_init_empty_weights(self):
         with init_empty_weights(include_buffers=True):
             module = nn.BatchNorm1d(4)
         # nn.Module.register_parameter/buffer shouldn't be changed with torch >= 2.0
-        if is_torch_version(">=", "2.0"):
-            assert register_parameter_func == nn.Module.register_parameter
-            assert register_buffer_func == nn.Module.register_buffer
+        assert register_parameter_func == nn.Module.register_parameter
+        assert register_buffer_func == nn.Module.register_buffer
         assert module.weight.device == torch.device("meta")
         assert module.running_mean.device == torch.device("meta")
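
The assertions can run unconditionally because on torch >= 2.0 `init_empty_weights(include_buffers=True)` uses the device context manager rather than monkey-patching `nn.Module.register_parameter`/`register_buffer`. What the test exercises, in isolation:

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights

    with init_empty_weights(include_buffers=True):
        module = nn.BatchNorm1d(4)

    # Parameters and buffers land on "meta"; no real memory is allocated.
    assert module.weight.device == torch.device("meta")
    assert module.running_mean.device == torch.device("meta")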
