Commit

Fix pp naming
TJ-Solergibert committed Sep 6, 2024
1 parent 1456446 commit 4d61489
Showing 3 changed files with 11 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/nanotron/serialize/main.py
@@ -236,6 +236,7 @@ def load(
    load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
    load_lr_scheduler(
        lr_scheduler=lr_scheduler,
+        parallel_context=parallel_context,
        root_folder=root_folder,
    )
    return checkpoint_metadata
12 changes: 5 additions & 7 deletions src/nanotron/serialize/optimizer.py
@@ -30,9 +30,9 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
    return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"


-def lr_scheduler_filename():
+def lr_scheduler_filename(parallel_context: ParallelContext):
    """The lr_scheduler is the same for all processes."""
-    return f"{ObjectType.LR_SCHEDULER.value}.pt"
+    return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"


def save_optimizer(
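A minimal sketch of the resulting file names, assuming ObjectType.LR_SCHEDULER.value is "lr_scheduler" (the enum value is not shown in this diff): the single shared lr_scheduler.pt becomes one file per pipeline-parallel rank.

# Sketch only, assuming ObjectType.LR_SCHEDULER.value == "lr_scheduler".
pp_size = 4  # hypothetical pipeline-parallel size
old_name = "lr_scheduler.pt"  # previously written by world rank 0 only
new_names = [f"lr_scheduler_pp-{pp_rank}-of-{pp_size}.pt" for pp_rank in range(pp_size)]
print(new_names)  # ['lr_scheduler_pp-0-of-4.pt', ..., 'lr_scheduler_pp-3-of-4.pt']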
@@ -109,17 +109,14 @@ def save_lr_scheduler(
    root_folder: Path,
):
    """Saves lr scheduler states"""
-    if dist.get_rank(parallel_context.world_pg) > 0:
-        # Only WORLD-RANK 0 saves the lr scheduler state
-        return

    root_folder = root_folder / "lr_scheduler"
    root_folder.mkdir(exist_ok=True, parents=True)

    # We dump the optimizer state using `torch.save`
    torch.save(
        lr_scheduler.state_dict(),
-        root_folder / lr_scheduler_filename(),
+        root_folder / lr_scheduler_filename(parallel_context),
    )


@@ -313,9 +310,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

def load_lr_scheduler(
    lr_scheduler,
+    parallel_context: ParallelContext,
    root_folder: Path,
):
    root_folder = root_folder / "lr_scheduler"

-    state_dict = torch.load(root_folder / lr_scheduler_filename())
+    state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context))
    lr_scheduler.load_state_dict(state_dict)
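Read together with save_lr_scheduler above, every rank that is asked to save now writes its own file under root_folder / "lr_scheduler" and reads the same per-rank file back on load. A minimal single-process sketch of that round trip, using a stand-in StepLR scheduler and a hypothetical checkpoint path:

# Hypothetical round trip mirroring save_lr_scheduler / load_lr_scheduler; the path,
# scheduler, and pp values are stand-ins, not names taken from the commit.
from pathlib import Path
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

pp_rank, pp_size = 0, 2  # normally dist.get_rank(parallel_context.pp_pg) and parallel_context.pp_pg.size()
root_folder = Path("/tmp/checkpoint") / "lr_scheduler"  # hypothetical checkpoint directory
root_folder.mkdir(exist_ok=True, parents=True)
filename = f"lr_scheduler_pp-{pp_rank}-of-{pp_size}.pt"  # assumes ObjectType.LR_SCHEDULER.value == "lr_scheduler"

torch.save(lr_scheduler.state_dict(), root_folder / filename)     # what save_lr_scheduler does per rank
lr_scheduler.load_state_dict(torch.load(root_folder / filename))  # what load_lr_scheduler does per rank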
9 changes: 5 additions & 4 deletions src/nanotron/trainer.py
@@ -206,6 +206,7 @@ def __init__(
        if self.init_checkpoint_path is not None:
            load_lr_scheduler(
                lr_scheduler=self.lr_scheduler,
+                parallel_context=self.parallel_context,
                root_folder=self.init_checkpoint_path,
            )

@@ -442,10 +443,10 @@ def train(
                self.save_checkpoint()

        dist.barrier()  # let's wait for everyone before leaving

        if self.config.checkpoints.save_final_state:
            self.save_checkpoint()

        self.post_training()

    def training_step(
@@ -864,8 +865,8 @@ def save_checkpoint(self) -> Path:
            ),  # We only save the weights on DP==0
            should_save_optimizer=True,
            should_save_lr_scheduler=bool(
-                dist.get_rank(self.parallel_context.world_pg) == 0
-            ),  # We only save the lr_scheduler on world_rank==0
+                dist.get_rank(self.parallel_context.dp_pg) == 0 and dist.get_rank(self.parallel_context.tp_pg) == 0
+            ),  # We only save the lr_scheduler on DP==0 && TP==0
            should_save_config=bool(
                dist.get_rank(self.parallel_context.world_pg) == 0
            ),  # We only save the config on world_rank==0
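With this gate, exactly one process per pipeline stage (DP rank 0, TP rank 0) saves the scheduler state, which is why lr_scheduler_filename now has to carry the pp rank. A small illustration over a hypothetical 2 x 2 x 4 (dp x tp x pp) grid:

# Illustration only: which (dp, tp, pp) coordinates pass the new should_save_lr_scheduler
# condition in a hypothetical 2 x 2 x 4 topology.
from itertools import product

dp_size, tp_size, pp_size = 2, 2, 4
writers = [
    (dp, tp, pp)
    for dp, tp, pp in product(range(dp_size), range(tp_size), range(pp_size))
    if dp == 0 and tp == 0
]
print(writers)       # [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)]
print(len(writers))  # 4 == pp_size: one lr_scheduler file per pipeline stage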
