Handle model-only checkpoints with the trainer
epwalsh committed Jan 17, 2025
1 parent 9818232 commit 0c096e2
Showing 3 changed files with 58 additions and 18 deletions.
CHANGELOG.md: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
 - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
 - Added a callback for sending Slack notifications.
+- The trainer can load model-only checkpoints now.
 
 ### Changed
 
src/olmo_core/train/checkpoint.py: 46 additions & 14 deletions
@@ -19,10 +19,17 @@
 from ..config import Config
 from ..distributed.checkpoint import (
     async_save_model_and_optim_state,
+    get_checkpoint_metadata,
     load_model_and_optim_state,
     save_model_and_optim_state,
 )
-from ..distributed.utils import barrier, get_fs_local_rank, get_rank, is_distributed
+from ..distributed.utils import (
+    barrier,
+    get_fs_local_rank,
+    get_rank,
+    is_distributed,
+    scatter_object,
+)
 from ..exceptions import OLMoConfigurationError
 from ..io import (
     clear_directory,
@@ -146,8 +153,8 @@ def load(
         model: nn.Module,
         optim: Optimizer,
         *,
-        load_optimizer_state: bool = True,
-        load_trainer_state: bool = True,
+        load_optimizer_state: Optional[bool] = None,
+        load_trainer_state: Optional[bool] = None,
         key_mapping: Optional[Dict[str, str]] = None,
     ) -> Optional[Dict[str, Any]]:
         """
@@ -158,21 +165,44 @@ def load(
 
         # Maybe load trainer state.
         trainer_state: Optional[Dict[str, Any]] = None
-        if load_trainer_state:
+        if load_trainer_state is not False:
+            # Try loading the given rank's state first, then fall back to rank 0 train state if it
+            # doesn't exist, which can happen when we're restoring a checkpoint with a different world size.
+            for path in (f"{dir}/train/rank{get_rank()}.pt", f"{dir}/train/rank0.pt"):
+                try:
+                    trainer_state = torch.load(cached_path(path, quiet=True), weights_only=False)
+                except FileNotFoundError:
+                    pass
+
+            if load_trainer_state is True and trainer_state is None:
+                raise FileNotFoundError(f"Missing trainer state in checkpoint dir '{dir}'")
+
+        # Load model and optimizer state.
+        model_and_optim_dir: str = f"{dir}/model_and_optim"
+        if get_rank(self.process_group) == 0:
             try:
-                trainer_state = torch.load(
-                    cached_path(f"{dir}/train/rank{get_rank()}.pt", quiet=True), weights_only=False
-                )
+                metadata = get_checkpoint_metadata(model_and_optim_dir)
             except FileNotFoundError:
-                # Fall back to rank 0 train state.
-                # This can happen when we're restoring a checkpoint with a different world size.
-                trainer_state = torch.load(
-                    cached_path(f"{dir}/train/rank0.pt", quiet=True), weights_only=False
-                )
+                # Try base directory, which could be the case if user is trying to load model weights
+                # (possibly with optimizer state), and not an actual train checkpoint.
+                if trainer_state is None:
+                    metadata = get_checkpoint_metadata(dir)
+                    model_and_optim_dir = dir
+                else:
+                    raise
+            if load_optimizer_state is None:
+                for key in metadata.state_dict_metadata.keys():
+                    if key.startswith("optim."):
+                        load_optimizer_state = True
+                        break
+                else:
+                    load_optimizer_state = False
+
+        model_and_optim_dir = scatter_object(model_and_optim_dir, group=self.process_group)
+        load_optimizer_state = scatter_object(load_optimizer_state, group=self.process_group)
 
-        # Load model and optimizer state.
         load_model_and_optim_state(
-            f"{dir}/model_and_optim",
+            model_and_optim_dir,
             model,
             optim if load_optimizer_state else None,
             process_group=self.process_group,
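
Note on the hunk above: when load_optimizer_state is None, rank 0 of the process group inspects the checkpoint metadata for any key under the "optim." prefix, and scatter_object then broadcasts both the resolved checkpoint directory and the decision so that every rank agrees. A minimal sketch of that detection step, assuming metadata.state_dict_metadata is a mapping keyed by fully qualified state names such as "model.*" and "optim.*" (the helper name is illustrative, not part of the commit):

def has_optim_state(metadata) -> bool:
    # Equivalent to the for/else loop above: True as soon as any metadata key
    # belongs to the optimizer state, False if the loop finishes without a match.
    return any(key.startswith("optim.") for key in metadata.state_dict_metadata.keys())
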
@@ -233,6 +263,8 @@ def dir_is_checkpoint(cls, dir: PathOrStr) -> bool:
         Check if a directory is a checkpoint directory.
         """
         dir = normalize_path(dir)
+        if file_exists(f"{dir}/.metadata"):  # just model (and maybe optim state), no trainer state
+            return True
         paths_to_check = [
             f"{dir}/train/rank0.pt",
             f"{dir}/model_and_optim/.metadata",
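With the dir_is_checkpoint change above, a directory now also counts as a checkpoint when it has a top-level .metadata file, i.e. a bare distributed-checkpoint export of model (and possibly optimizer) state with no trainer state. A hedged sketch of the two accepted layouts, assuming the enclosing class is olmo_core.train.checkpoint.Checkpointer and the directory names are purely illustrative:

from olmo_core.train.checkpoint import Checkpointer

# Full train checkpoint (unchanged): trainer state plus a model_and_optim/ sub-directory.
#   step1000/train/rank0.pt
#   step1000/model_and_optim/.metadata
Checkpointer.dir_is_checkpoint("step1000")    # True, as before

# Model-only export (new): the .metadata file sits directly in the directory.
#   model_only/.metadata
Checkpointer.dir_is_checkpoint("model_only")  # now also True
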
src/olmo_core/train/trainer.py: 11 additions & 4 deletions
@@ -668,7 +668,11 @@ def load_state_dict(self, state_dict: TrainerStateDict):
         )
 
     def load_checkpoint(
-        self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True
+        self,
+        dir: PathOrStr,
+        *,
+        load_optimizer_state: Optional[bool] = None,
+        load_trainer_state: Optional[bool] = None,
     ):
         """
         Load a checkpoint.
@@ -698,8 +702,7 @@ def load_checkpoint(
             load_trainer_state=load_trainer_state,
             key_mapping=self.load_key_mapping,
         )
-        if load_trainer_state:
-            assert trainer_state is not None
+        if trainer_state is not None:
             self.load_state_dict(cast(TrainerStateDict, trainer_state))
 
         for callback in self.callbacks.values():
@@ -709,7 +712,11 @@ def load_checkpoint(
         log.info("Checkpoint successfully loaded")
 
     def maybe_load_checkpoint(
-        self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True
+        self,
+        dir: PathOrStr,
+        *,
+        load_optimizer_state: Optional[bool] = None,
+        load_trainer_state: Optional[bool] = None,
     ) -> bool:
         """
         Like :meth:`load_checkpoint()` but is a no-op if there is no checkpoint in the ``dir`` provided.
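Taken together, the trainer-facing change is that both flags become tri-state: None means detect from the checkpoint, True means require it, and False means skip it. A hedged usage sketch; the trainer object and checkpoint paths below are illustrative, not taken from this commit:

# Resuming a regular train checkpoint: the None defaults auto-detect optimizer state
# and load trainer state because it is present.
trainer.maybe_load_checkpoint("runs/my-run/step1000")

# Pointing at a model-only export: no train/ sub-directory and no "optim.*" keys in the
# metadata, so trainer state is skipped and the optimizer is left untouched.
trainer.maybe_load_checkpoint("exports/model-only")

# Strict resume: require trainer state and raise FileNotFoundError if it is missing.
trainer.load_checkpoint("runs/my-run/step1000", load_trainer_state=True)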
