From 5dad728f5f08dacd1e06ecce60b37fa822b9ad3e Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 15 Apr 2024 14:09:51 +0800 Subject: [PATCH] add grad scale for optim_manager --- .github/workflows/build.yml | 7 ++++--- bmtrain/optim/optim_manager.py | 6 +++++- bmtrain/synchronize.py | 8 ++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5fe62ca0..969b5db8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,14 +6,15 @@ on: branches: - 'dev' - 'main' + push: + branches: + - 'dev' jobs: build-archive-wheel: uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main - secrets: - DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} - DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + secrets: inherit publish: needs: build-archive-wheel diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 7aa1bb81..1a98ed92 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -52,6 +52,7 @@ def __init__(self, loss_scale_steps : int = 1024, min_loss_scale = 1, max_loss_scale = float("inf"), + grad_scale : Optional[int] = None, ): if loss_scale is not None: self.loss_scale = loss_scale @@ -64,6 +65,9 @@ def __init__(self, self.loss_scale_steps = loss_scale_steps self.min_loss_scale = min_loss_scale self.max_loss_scale = max_loss_scale + if grad_scale is None: + grad_scale = config['zero_size'] + self.grad_scale = grad_scale self.optimizers = [] self.lr_schedulers = [] @@ -85,7 +89,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * (self.loss_scale / (config['world_size']//(config['tp_size']*config['pipe_size']))) # loss scale + return loss * ( self.loss_scale / self.grad_scale ) # loss scale def backward(self, loss : torch.Tensor): """ diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index d562cc21..2587e0ff 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -2,6 +2,7 @@ from . import distributed, nccl from .global_var import config import warnings +from typing import Optional def synchronize(): """ @@ -24,14 +25,17 @@ def wait_loader(): config['calc_stream'].record_event(config['load_event']) -def sum_loss(loss : torch.Tensor): +def sum_loss(loss : torch.Tensor, comm: Optional[nccl.Communicator] = None): """ Sum the loss across all workers. This is a helper function to reduce the loss across all workers. """ + if comm is None: + comm = config['comm'] warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning) - return distributed.all_reduce(loss, "sum") / config['world_size'] + + return distributed.all_reduce(loss, "avg", comm) def gather_result(result: torch.Tensor): warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)