From eb8c535c1784d99a08c64451a78c1be995a0a815 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 24 Oct 2023 12:55:06 -0400
Subject: [PATCH] Fix (#2080)

---
 src/accelerate/checkpointing.py | 26 +++++++++-----------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 12f9aee55ab..7928adce33e 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -20,7 +20,6 @@
 import numpy as np
 import torch
 from torch.cuda.amp import GradScaler
-from torch.utils.data import BatchSampler
 
 from .utils import (
     MODEL_NAME,
@@ -102,15 +101,13 @@ def save_accelerator_state(
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         output_sampler_file = os.path.join(output_dir, sampler_name)
         # Only save if we have our custom sampler
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
             sampler = dataloader.sampler.sampler
-        else:
-            sampler = dataloader.batch_sampler.sampler
-        from .data_loader import SeedableRandomSampler
 
-        if isinstance(sampler, SeedableRandomSampler):
-            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+            if isinstance(sampler, SeedableRandomSampler):
+                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
         logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
 
     # GradScaler state
@@ -203,18 +200,13 @@ def load_accelerator_state(
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         input_sampler_file = os.path.join(input_dir, sampler_name)
         # Only load if we have our custom sampler
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
             sampler = dataloader.sampler.sampler
-        else:
-            sampler = dataloader.batch_sampler.sampler
-        from .data_loader import SeedableRandomSampler
 
-        if isinstance(sampler, SeedableRandomSampler):
-            if sampler_is_batch_sampler:
+            if isinstance(sampler, SeedableRandomSampler):
                 dataloader.sampler.sampler = torch.load(input_sampler_file)
-            else:
-                dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
         logger.info("All dataloader sampler states loaded successfully")
 
     # GradScaler state