[enhance] make input datatype ready for allgather #6162

Open · wants to merge 2 commits into base: main
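The change below pre-converts the master parameters to the working device and dtype in a separate pass and caches the results in `params_to_gather_buffer`, so the all-gather loop consumes tensors that are already in the expected datatype. As a minimal, self-contained sketch of that pattern (assuming a plain `torch.distributed` setup; `gather_params`, `master_shards`, and `gather_buffer` are hypothetical names, not ColossalAI API):

```python
import os

import torch
import torch.distributed as dist


def gather_params(master_shards, device, dtype):
    # Pass 1: move/cast every master shard once, up front, into a buffer.
    gather_buffer = [shard.to(device).to(dtype) for shard in master_shards]

    # Pass 2: all-gather from the buffer; every input already has the target dtype.
    gathered = []
    for shard in gather_buffer:
        out = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
        dist.all_gather(out, shard)
        gathered.append(torch.cat(out))

    # Free the converted copies once the gathers are done.
    del gather_buffer
    return gathered


if __name__ == "__main__":
    # Single-process gloo group so the sketch runs on CPU without extra hardware.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    shards = [torch.randn(4) for _ in range(3)]
    full_params = gather_params(shards, device="cpu", dtype=torch.float16)
    print([p.dtype for p in full_params])  # all torch.float16
    dist.destroy_process_group()
```

The single-process gloo group exists only to make the sketch runnable; in the PR itself the converted tensors feed ColossalAI's `TensorBucket` all-gather path rather than a direct `dist.all_gather` call.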
colossalai/zero/low_level/low_level_optim.py (13 changes: 12 additions & 1 deletion)
@@ -544,12 +544,14 @@ def step(self, closure=None):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
+        params_to_gather_buffer = dict()

         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             working_params = self._working_param_groups[group_id]
             real_working_params[group_id] = []
             real_master_params[group_id] = []
+            params_to_gather_buffer[group_id] = []
             working_grads = []
             for working_param, master_param in zip(working_params, master_params):
                 # if a working param requires grad and has no grad
@@ -596,13 +598,20 @@ def step(self, closure=None):
             pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
         }

+        device = get_accelerator().get_current_device()
+        for group_id in range(self.num_param_groups):
+            master_working_param = self.optim.param_groups[group_id]["params"]
+            for idx, master_param in enumerate(master_working_param):
+                param_to_gather = master_param.to(device).to(self._dtype)
+                params_to_gather_buffer[group_id].append(param_to_gather)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, master_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                param_to_gather = master_param.to(device).to(self._dtype)
+                param_to_gather = params_to_gather_buffer[group_id][idx]
                 pg = self.param_to_pg[working_param]
                 padded_working_param = self._working_param_to_padded_working_param[working_param]
                 if self._overlap_allgather:
@@ -634,6 +643,8 @@ def step(self, closure=None):
                 if not tensor_bucket.is_empty():
                     tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)

+        del params_to_gather_buffer
+
     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
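One design note, offered as a reading of the diff rather than a stated rationale: `params_to_gather_buffer` holds device- and dtype-converted copies of every master parameter for the duration of the update and all-gather loops, trading some extra memory for having the conversions finished before the gathers begin; the trailing `del params_to_gather_buffer` releases those copies once the remaining tensor buckets have been flushed.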