diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py
index 286008ac..346ad573 100644
--- a/src/nanotron/serialize/main.py
+++ b/src/nanotron/serialize/main.py
@@ -236,6 +236,7 @@ def load(
     load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
     load_lr_scheduler(
         lr_scheduler=lr_scheduler,
+        parallel_context=parallel_context,
         root_folder=root_folder,
     )
     return checkpoint_metadata
diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py
index 68a3b1a0..72ed7282 100644
--- a/src/nanotron/serialize/optimizer.py
+++ b/src/nanotron/serialize/optimizer.py
@@ -30,9 +30,9 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
     return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
 
 
-def lr_scheduler_filename():
-    """The lr_scheduler is the same for all processes."""
-    return f"{ObjectType.LR_SCHEDULER.value}.pt"
+def lr_scheduler_filename(parallel_context: ParallelContext):
+    """The lr_scheduler state is saved per pipeline-parallel rank."""
+    return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"
 
 
 def save_optimizer(
@@ -109,9 +109,6 @@ def save_lr_scheduler(
     root_folder: Path,
 ):
     """Saves lr scheduler states"""
-    if dist.get_rank(parallel_context.world_pg) > 0:
-        # Only WORLD-RANK 0 saves the lr scheduler state
-        return
     root_folder = root_folder / "lr_scheduler"
     root_folder.mkdir(exist_ok=True, parents=True)
 
@@ -119,7 +116,7 @@
     # We dump the optimizer state using `torch.save`
     torch.save(
         lr_scheduler.state_dict(),
-        root_folder / lr_scheduler_filename(),
+        root_folder / lr_scheduler_filename(parallel_context),
     )
 
 
@@ -313,9 +310,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
 
 def load_lr_scheduler(
     lr_scheduler,
+    parallel_context: ParallelContext,
     root_folder: Path,
 ):
     root_folder = root_folder / "lr_scheduler"
 
-    state_dict = torch.load(root_folder / lr_scheduler_filename())
+    state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context))
     lr_scheduler.load_state_dict(state_dict)
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index bef629c1..e5f6bde3 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -206,6 +206,7 @@ def __init__(
         if self.init_checkpoint_path is not None:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
+                parallel_context=self.parallel_context,
                 root_folder=self.init_checkpoint_path,
             )
 
@@ -442,10 +443,10 @@ def train(
                 self.save_checkpoint()
 
         dist.barrier()  # let's wait for everyone before leaving
-        
+
         if self.config.checkpoints.save_final_state:
             self.save_checkpoint()
-        
+
         self.post_training()
 
     def training_step(
@@ -864,8 +865,8 @@ def save_checkpoint(self) -> Path:
             ),  # We only save the weights on DP==0
             should_save_optimizer=True,
             should_save_lr_scheduler=bool(
-                dist.get_rank(self.parallel_context.world_pg) == 0
-            ),  # We only save the lr_scheduler on world_rank==0
+                dist.get_rank(self.parallel_context.dp_pg) == 0 and dist.get_rank(self.parallel_context.tp_pg) == 0
+            ),  # We only save the lr_scheduler on DP==0 && TP==0
             should_save_config=bool(
                 dist.get_rank(self.parallel_context.world_pg) == 0
             ),  # We only save the config on world_rank==0
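
For context, the net effect of the patch is that each pipeline-parallel rank (with DP==0 and TP==0) now writes and reads its own lr_scheduler file under `<checkpoint>/lr_scheduler/`, instead of a single file written by world rank 0. Below is a minimal sketch of the resulting file layout, not part of the patch: plain integers stand in for nanotron's `ParallelContext` / `dist.get_rank`, and it assumes `ObjectType.LR_SCHEDULER.value == "lr_scheduler"`.

```python
from pathlib import Path

# Hypothetical stand-in for lr_scheduler_filename(parallel_context): the real
# function derives the pp rank and size from parallel_context.pp_pg.
def lr_scheduler_filename(pp_rank: int, pp_size: int) -> str:
    return f"lr_scheduler_pp-{pp_rank}-of-{pp_size}.pt"

root_folder = Path("checkpoints/1000") / "lr_scheduler"
# With pp_size=4, ranks 0..3 each save/load their own state file:
print([str(root_folder / lr_scheduler_filename(r, 4)) for r in range(4)])
# -> four paths, lr_scheduler_pp-0-of-4.pt through lr_scheduler_pp-3-of-4.pt
```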