From 348520de5dff4d37ce362d770b205301e0cabc3f Mon Sep 17 00:00:00 2001
From: BurkeHulk
Date: Tue, 17 Dec 2024 17:44:54 +0800
Subject: [PATCH] prepare allgather input in advance

---
 colossalai/zero/low_level/low_level_optim.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 8abaf8fc6b3f..6b69ab133b54 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -544,12 +544,14 @@ def step(self, closure=None):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
+        params_to_gather_buffer = dict()
 
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             working_params = self._working_param_groups[group_id]
             real_working_params[group_id] = []
             real_master_params[group_id] = []
+            params_to_gather_buffer[group_id] = []
             working_grads = []
             for working_param, master_param in zip(working_params, master_params):
                 # if a working param requires grad and has no grad
@@ -596,13 +598,20 @@ def step(self, closure=None):
             pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
         }
 
+        device = get_accelerator().get_current_device()
+        for group_id in range(self.num_param_groups):
+            master_working_param = self.optim.param_groups[group_id]["params"]
+            for idx, master_param in enumerate(master_working_param):
+                param_to_gather = master_param.to(device).to(self._dtype)
+                params_to_gather_buffer[group_id].append(param_to_gather)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, master_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                param_to_gather = master_param.to(device).to(self._dtype)
+                param_to_gather = params_to_gather_buffer[group_id][idx]
                 pg = self.param_to_pg[working_param]
                 padded_working_param = self._working_param_to_padded_working_param[working_param]
                 if self._overlap_allgather:
@@ -634,6 +643,8 @@ def step(self, closure=None):
                 if not tensor_bucket.is_empty():
                     tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
+        del params_to_gather_buffer
+
     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
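
For context, a minimal sketch of the staging pattern this patch applies: the device transfer and dtype cast of every master parameter are moved into a separate pass that runs before the update/all-gather loop, and the staged tensors are released once the gather has been issued. The names below (prepare_gather_inputs, consume_gather_inputs, all_gather_fn, target_dtype) are illustrative placeholders, not ColossalAI APIs.

def prepare_gather_inputs(param_groups, device, target_dtype):
    # Pass 1: stage every master parameter on the target device/dtype up front,
    # so copies and casts are not interleaved with the gather loop below.
    buffer = {}
    for group_id, params in enumerate(param_groups):
        buffer[group_id] = [p.detach().to(device).to(target_dtype) for p in params]
    return buffer

def consume_gather_inputs(param_groups, buffer, all_gather_fn):
    # Pass 2: the gather loop only reads pre-staged tensors; no per-parameter
    # .to(...) conversions remain on this path.
    for group_id, params in enumerate(param_groups):
        for idx in range(len(params)):
            all_gather_fn(buffer[group_id][idx])
    buffer.clear()  # release the staged copies once the gather has been issued

In the patch itself, the two passes correspond to the new staging loop and the existing update/all-gather loop in step(), with `del params_to_gather_buffer` freeing the buffer at the end.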