diff --git a/CHANGELOG.md b/CHANGELOG.md index c85f4245..a82f5f75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`. - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint. - Added a callback for sending Slack notifications. +- The trainer can load model-only checkpoints now. ### Changed diff --git a/src/olmo_core/train/checkpoint.py b/src/olmo_core/train/checkpoint.py index b262fe81..9eb5a373 100644 --- a/src/olmo_core/train/checkpoint.py +++ b/src/olmo_core/train/checkpoint.py @@ -19,10 +19,17 @@ from ..config import Config from ..distributed.checkpoint import ( async_save_model_and_optim_state, + get_checkpoint_metadata, load_model_and_optim_state, save_model_and_optim_state, ) -from ..distributed.utils import barrier, get_fs_local_rank, get_rank, is_distributed +from ..distributed.utils import ( + barrier, + get_fs_local_rank, + get_rank, + is_distributed, + scatter_object, +) from ..exceptions import OLMoConfigurationError from ..io import ( clear_directory, @@ -146,8 +153,8 @@ def load( model: nn.Module, optim: Optimizer, *, - load_optimizer_state: bool = True, - load_trainer_state: bool = True, + load_optimizer_state: Optional[bool] = None, + load_trainer_state: Optional[bool] = None, key_mapping: Optional[Dict[str, str]] = None, ) -> Optional[Dict[str, Any]]: """ @@ -158,21 +165,44 @@ def load( # Maybe load trainer state. trainer_state: Optional[Dict[str, Any]] = None - if load_trainer_state: + if load_trainer_state is not False: + # Try loading the given rank's state first, then fall back to rank 0 train state if it + # doesn't exist, which can happen when we're restoring a checkpoint with a different world size. + for path in (f"{dir}/train/rank{get_rank()}.pt", f"{dir}/train/rank0.pt"): + try: + trainer_state = torch.load(cached_path(path, quiet=True), weights_only=False) + except FileNotFoundError: + pass + + if load_trainer_state is True and trainer_state is None: + raise FileNotFoundError(f"Missing trainer state in checkpoint dir '{dir}'") + + # Load model and optimizer state. + model_and_optim_dir: str = f"{dir}/model_and_optim" + if get_rank(self.process_group) == 0: try: - trainer_state = torch.load( - cached_path(f"{dir}/train/rank{get_rank()}.pt", quiet=True), weights_only=False - ) + metadata = get_checkpoint_metadata(model_and_optim_dir) except FileNotFoundError: - # Fall back to rank 0 train state. - # This can happen when we're restoring a checkpoint with a different world size. - trainer_state = torch.load( - cached_path(f"{dir}/train/rank0.pt", quiet=True), weights_only=False - ) + # Try base directory, which could be the case if user is trying to load model weights + # (possibly with optimizer state), and not an actual train checkpoint. + if trainer_state is None: + metadata = get_checkpoint_metadata(dir) + model_and_optim_dir = dir + else: + raise + if load_optimizer_state is None: + for key in metadata.state_dict_metadata.keys(): + if key.startswith("optim."): + load_optimizer_state = True + break + else: + load_optimizer_state = False + + model_and_optim_dir = scatter_object(model_and_optim_dir, group=self.process_group) + load_optimizer_state = scatter_object(load_optimizer_state, group=self.process_group) - # Load model and optimizer state. load_model_and_optim_state( - f"{dir}/model_and_optim", + model_and_optim_dir, model, optim if load_optimizer_state else None, process_group=self.process_group, @@ -233,6 +263,8 @@ def dir_is_checkpoint(cls, dir: PathOrStr) -> bool: Check if a directory is a checkpoint directory. """ dir = normalize_path(dir) + if file_exists(f"{dir}/.metadata"): # just model (and maybe optim state), no trainer state + return True paths_to_check = [ f"{dir}/train/rank0.pt", f"{dir}/model_and_optim/.metadata", diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 0c8d17aa..213f3277 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -668,7 +668,11 @@ def load_state_dict(self, state_dict: TrainerStateDict): ) def load_checkpoint( - self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True + self, + dir: PathOrStr, + *, + load_optimizer_state: Optional[bool] = None, + load_trainer_state: Optional[bool] = None, ): """ Load a checkpoint. @@ -698,8 +702,7 @@ def load_checkpoint( load_trainer_state=load_trainer_state, key_mapping=self.load_key_mapping, ) - if load_trainer_state: - assert trainer_state is not None + if trainer_state is not None: self.load_state_dict(cast(TrainerStateDict, trainer_state)) for callback in self.callbacks.values(): @@ -709,7 +712,11 @@ def load_checkpoint( log.info("Checkpoint successfully loaded") def maybe_load_checkpoint( - self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True + self, + dir: PathOrStr, + *, + load_optimizer_state: Optional[bool] = None, + load_trainer_state: Optional[bool] = None, ) -> bool: """ Like :meth:`load_checkpoint()` but is a no-op if there is no checkpoint in the ``dir`` provided.