From 19712a1c75bfc1da4a7f3ecca6915a86af671568 Mon Sep 17 00:00:00 2001
From: xusong28
Date: Wed, 23 Oct 2024 20:59:43 +0800
Subject: [PATCH 1/3] Faster and more memory-efficient impl of zero_to_fp32

---
 deepspeed/utils/zero_to_fp32.py | 99 +++++++++++++++++++++++++++------
 1 file changed, 83 insertions(+), 16 deletions(-)

diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index e69ecd9acb5a..e98f71fb73cc 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -21,7 +21,9 @@
 import math
 import os
 import re
+import gc
 import json
+import numpy as np
 from tqdm import tqdm
 from collections import OrderedDict
 from dataclasses import dataclass
@@ -146,8 +148,8 @@ def parse_model_states(files):
 def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
-    for f in files:
-        state_dict = torch.load(f, map_location=device)
+    for f in tqdm(files, desc='Loading checkpoint shards'):
+        state_dict = torch.load(f, map_location=device, mmap=True)
         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
         # and also handle the case where it was already removed by another helper script
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
@@ -184,12 +186,8 @@ def parse_optim_states(files, ds_checkpoint_dir):
     elif zero_stage == 3:
         # if there is more than one param group, there will be multiple flattened tensors - one
         # flattened tensor per group - for simplicity merge them into a single tensor
-        #
-        # XXX: could make the script more memory efficient for when there are multiple groups - it
-        # will require matching the sub-lists of param_shapes for each param group flattened tensor
-
         fp32_flat_groups = [
-            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+            state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))
         ]
 
     return zero_stage, world_size, fp32_flat_groups
@@ -305,6 +303,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
         if debug:
             print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
         state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+        # tensor = GatheredTensor(fp32_flat_groups, offset, partitioned_numel, shape)
         offset += unpartitioned_numel
 
     # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
@@ -398,9 +397,71 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
         print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
 
 
+class GatheredTensor:
+    """
+    A pseudo tensor that collects partitioned weights.
+    It is more memory efficient when there are multiple groups.
+    """
+    def __init__(self, flat_groups, offset, partitioned_numel, shape):
+        self.flat_groups = flat_groups
+        self.offset = offset
+        self.partitioned_numel = partitioned_numel
+        self.shape = shape
+
+    def contiguous(self):
+        """
+        Merge partitioned weights from flat_groups into a single tensor.
+        """
+        end_idx = self.offset + self.partitioned_numel
+        world_size = len(self.flat_groups)
+        pad_flat_param_chunks = []
+        for rank in range(world_size):
+            flat_group = self.flat_groups[rank]
+            tensor_slice = []
+            start_part = None
+            end_part = None
+            sum_idx = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in flat_group]))
+            for i in range(len(sum_idx)):
+                if sum_idx[i] <= self.offset < sum_idx[i + 1]:
+                    start_part = i
+                if sum_idx[i] < end_idx <= sum_idx[i + 1]:
+                    end_part = i
+                    break
+            for i in range(start_part, end_part + 1):
+                flat_tensor = flat_group[i]
+                start_offset = self.offset - sum_idx[i]
+                end_offset = min(end_idx, sum_idx[i + 1]) - sum_idx[i]
+                tensor_slice.append(flat_tensor[start_offset:end_offset])
+
+            pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
+        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0).to(torch.float16)
+        param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
+        return param
+
+    # The following part makes it compatible with `huggingface_hub.split_torch_state_dict_into_shards`
+    # https://github.com/huggingface/huggingface_hub/blob/src/huggingface_hub/serialization/_torch.py
+    def storage(self):
+        return self
+
+    def data_ptr(self):
+        return self.offset * 100000 + self.partitioned_numel
+
+    def size(self):
+        return self.shape.numel()
+
+    @property
+    def dtype(self):
+        return self.flat_groups[0][0].dtype
+
+    @property
+    def device(self):
+        return self.flat_groups[0][0].device
+
+
 def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     param_shapes = zero_model_states[0].param_shapes
-    avail_numel = fp32_flat_groups[0].numel() * world_size
+    avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
+
 
     # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
     # param, re-consolidating each param, while dealing with padding if any
@@ -424,7 +485,7 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     offset = 0
     total_numel = 0
     total_params = 0
-    for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
+    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel
         total_params += 1
@@ -435,10 +496,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
                 f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
             )
 
-        # XXX: memory usage doubles here
-        state_dict[name] = torch.cat(
-            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
-            0).narrow(0, 0, unpartitioned_numel).view(shape)
+        # memory efficient tensor
+        tensor = GatheredTensor(fp32_flat_groups, offset, partitioned_numel, shape)
+        state_dict[name] = tensor
         offset += partitioned_numel
 
     offset *= world_size
@@ -541,6 +601,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named
           ``latest`` in the checkpoint folder, e.g., ``global_step14``
         - ``exclude_frozen_parameters``: exclude frozen parameters
     """
+    # Dependency pre-check
     if safe_serialization:
         try:
@@ -571,15 +632,21 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         state_dict_split = StateDictSplit(is_sharded=False,
                                           filename_to_tensors={weights_name: list(state_dict.keys())})
 
-    # Save the model
+    # Save the model by shard
+    os.makedirs(output_dir, exist_ok=True)
     filename_to_tensors = state_dict_split.filename_to_tensors.items()
     for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
-        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+        shard = {tensor_name: state_dict[tensor_name].contiguous() for tensor_name in tensors}
        output_path = os.path.join(output_dir, shard_file)
         if safe_serialization:
             save_file(shard, output_path, metadata={"format": "pt"})
         else:
             torch.save(shard, output_path)
+        # release the memory of current shard
+        for tensor_name in shard:
+            del state_dict[tensor_name]
+        gc.collect()
+
 
     # Save index if sharded
     if state_dict_split.is_sharded:
@@ -671,4 +738,4 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
                                                max_shard_size=args.max_shard_size,
                                                safe_serialization=args.safe_serialization,
                                                tag=args.tag,
-                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
+                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
\ No newline at end of file

From 6d00a62acc7307788fecf004cb28278a91391446 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Thu, 24 Oct 2024 10:41:28 +0800
Subject: [PATCH 2/3] Update zero_to_fp32.py

---
 deepspeed/utils/zero_to_fp32.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index e98f71fb73cc..f2eb6e88f45e 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -303,7 +303,6 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
         if debug:
             print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
         state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
-        # tensor = GatheredTensor(fp32_flat_groups, offset, partitioned_numel, shape)
         offset += unpartitioned_numel
 
     # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
@@ -434,16 +433,24 @@ def contiguous(self):
                 tensor_slice.append(flat_tensor[start_offset:end_offset])
 
             pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
-        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0).to(torch.float16)
+        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
         param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
         return param
 
     # The following part makes it compatible with `huggingface_hub.split_torch_state_dict_into_shards`
-    # https://github.com/huggingface/huggingface_hub/blob/src/huggingface_hub/serialization/_torch.py
+    # https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L335
+    # 1. get_torch_storage_id
+    # tensor.storage().data_ptr() is part of torch_storage_id
+    # refer to https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L385
+    # 2. get_torch_storage_size
+    # tensor.size() is called by get_torch_storage_size
     def storage(self):
         return self
 
     def data_ptr(self):
+        """
+        a naive implementation of storage id, can be optimized.
+        """
         return self.offset * 100000 + self.partitioned_numel
 
     def size(self):
@@ -738,4 +745,4 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
                                                max_shard_size=args.max_shard_size,
                                                safe_serialization=args.safe_serialization,
                                                tag=args.tag,
-                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
\ No newline at end of file
+                                               exclude_frozen_parameters=args.exclude_frozen_parameters)

From c81dd8e39e2e1775e1c2fc714d7b555fd4cc7421 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Sun, 27 Oct 2024 14:55:22 +0800
Subject: [PATCH 3/3] remove hf_hub compatible feature

---
 deepspeed/utils/zero_to_fp32.py | 65 ++++++++++++---------------------
 1 file changed, 23 insertions(+), 42 deletions(-)

diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index f2eb6e88f45e..048ab9c373aa 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -406,6 +406,7 @@ def __init__(self, flat_groups, offset, partitioned_numel, shape):
         self.offset = offset
         self.partitioned_numel = partitioned_numel
         self.shape = shape
+        self.dtype = self.flat_groups[0][0].dtype
 
     def contiguous(self):
         """
@@ -414,55 +415,34 @@ def contiguous(self):
         end_idx = self.offset + self.partitioned_numel
         world_size = len(self.flat_groups)
         pad_flat_param_chunks = []
-        for rank in range(world_size):
-            flat_group = self.flat_groups[rank]
+        for rank_i in range(world_size):
+            # for each rank, we need to collect weights from related group/groups
+            flat_groups_at_rank_i = self.flat_groups[rank_i]
             tensor_slice = []
-            start_part = None
-            end_part = None
-            sum_idx = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in flat_group]))
-            for i in range(len(sum_idx)):
-                if sum_idx[i] <= self.offset < sum_idx[i + 1]:
-                    start_part = i
-                if sum_idx[i] < end_idx <= sum_idx[i + 1]:
-                    end_part = i
+            start_group_id = None
+            end_group_id = None
+            # get the start_group_id and end_group_id, can be moved outside if all ranks get the same offset
+            offset_of_groups = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in flat_groups_at_rank_i]))
+            for group_id in range(len(offset_of_groups)):
+                if offset_of_groups[group_id] <= self.offset < offset_of_groups[group_id + 1]:
+                    start_group_id = group_id
+                if offset_of_groups[group_id] < end_idx <= offset_of_groups[group_id + 1]:
+                    end_group_id = group_id
                     break
-            for i in range(start_part, end_part + 1):
-                flat_tensor = flat_group[i]
-                start_offset = self.offset - sum_idx[i]
-                end_offset = min(end_idx, sum_idx[i + 1]) - sum_idx[i]
+            # collect weights from related group/groups
+            for group_id in range(start_group_id, end_group_id + 1):
+                flat_tensor = flat_groups_at_rank_i[group_id]
+                start_offset = self.offset - offset_of_groups[group_id]
+                end_offset = min(end_idx, offset_of_groups[group_id + 1]) - offset_of_groups[group_id]
                 tensor_slice.append(flat_tensor[start_offset:end_offset])
 
-            pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
+            pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
+
+        # collect weights from all ranks
         pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
         param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
         return param
 
-    # The following part makes it compatible with `huggingface_hub.split_torch_state_dict_into_shards`
-    # https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L335
-    # 1. get_torch_storage_id
-    # tensor.storage().data_ptr() is part of torch_storage_id
-    # refer to https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L385
-    # 2. get_torch_storage_size
-    # tensor.size() is called by get_torch_storage_size
-    def storage(self):
-        return self
-
-    def data_ptr(self):
-        """
-        a naive implementation of storage id, can be optimized.
-        """
-        return self.offset * 100000 + self.partitioned_numel
-
-    def size(self):
-        return self.shape.numel()
-
-    @property
-    def dtype(self):
-        return self.flat_groups[0][0].dtype
-
-    @property
-    def device(self):
-        return self.flat_groups[0][0].device
+
 
 
 def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
@@ -630,7 +610,8 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
     weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
     if max_shard_size is not None:
         filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
-        state_dict_split = split_torch_state_dict_into_shards(state_dict,
+        mock_state_dict = {name: torch.empty(tensor.shape, dtype=tensor.dtype) for name, tensor in state_dict.items()}
+        state_dict_split = split_torch_state_dict_into_shards(mock_state_dict,
                                                               filename_pattern=filename_pattern,
                                                               max_shard_size=max_shard_size)
     else:
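
For context, a minimal usage sketch of the conversion entry point as it behaves after this series (illustrative only, not part of the patches). The paths, the "5GB" shard size, and the assumption that the second positional argument is the output directory are placeholders; the truncated hunks above only show `checkpoint_dir` in the signature, and `max_shard_size` is simply forwarded to `huggingface_hub.split_torch_state_dict_into_shards`, which accepts size strings.

    # Sketch: convert a ZeRO checkpoint into (optionally sharded) fp32 weights.
    # With these patches, ZeRO-3 weights are kept as lazy GatheredTensor objects,
    # materialized per shard via .contiguous(), and each shard's memory is released
    # (del + gc.collect()) right after it is saved.
    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    convert_zero_checkpoint_to_fp32_state_dict(
        "path/to/checkpoint_dir",   # directory holding the ZeRO checkpoint and its 'latest' tag file
        "path/to/output_dir",       # assumed output-directory argument; shard files are written here
        max_shard_size="5GB",       # forwarded to huggingface_hub's split_torch_state_dict_into_shards
        safe_serialization=True,    # write .safetensors via safetensors' save_file, else .bin via torch.save
        tag=None,                   # defaults to the tag recorded in the 'latest' file, e.g. global_step14
    )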