A faster and more memory-efficient implementation of zero_to_fp32 #6658

Open · wants to merge 2 commits into base: master
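For context, a minimal usage sketch of the public entry point this PR optimizes. The paths are placeholders and the argument names follow the diff below; treat it as illustrative rather than the canonical CLI.

from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# Reconstruct the full fp32 state dict from a ZeRO checkpoint and write it out
# shard by shard. With this PR each parameter is gathered lazily (see the
# GatheredTensor class below), so the extra full-model copy made by the old
# implementation is avoided.
convert_zero_checkpoint_to_fp32_state_dict(
    "path/to/checkpoint_dir",         # hypothetical path to the ZeRO checkpoint
    "path/to/output_dir",             # hypothetical output directory
    tag=None,                         # falls back to the tag stored in the `latest` file
    exclude_frozen_parameters=False,
    safe_serialization=True,          # write safetensors shards (requires safetensors)
)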
104 changes: 89 additions & 15 deletions deepspeed/utils/zero_to_fp32.py
@@ -21,7 +21,9 @@
import math
import os
import re
import gc
import json
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
from dataclasses import dataclass
@@ -146,8 +148,8 @@ def parse_model_states(files):
def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
for f in tqdm(files, desc='Loading checkpoint shards'):
state_dict = torch.load(f, map_location=device, mmap=True)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
@@ -184,12 +186,8 @@ def parse_optim_states(files, ds_checkpoint_dir):
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor

fp32_flat_groups = [
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))
]

return zero_stage, world_size, fp32_flat_groups
@@ -398,9 +396,79 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


class GatheredTensor:
"""
A pseudo tensor that collects partitioned weights.
It is more memory efficient than eagerly concatenating the flat groups, especially when there are multiple param groups.
"""
def __init__(self, flat_groups, offset, partitioned_numel, shape):
self.flat_groups = flat_groups
self.offset = offset
self.partitioned_numel = partitioned_numel
self.shape = shape

def contiguous(self):
"""
Merge partitioned weights from flat_groups into a single tensor.
"""
end_idx = self.offset + self.partitioned_numel
world_size = len(self.flat_groups)
pad_flat_param_chunks = []
for rank in range(world_size):
flat_group = self.flat_groups[rank]
tensor_slice = []
start_part = None
end_part = None
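# prefix sums of the flat-tensor sizes mark each tensor's [start, end) range within this rank's group,
# used below to locate which flat tensors cover [offset, end_idx)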
sum_idx = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in flat_group]))
for i in range(len(sum_idx)):
if sum_idx[i] <= self.offset < sum_idx[i + 1]:
start_part = i
if sum_idx[i] < end_idx <= sum_idx[i + 1]:
end_part = i
break
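# slice out the piece of each covering flat tensor that falls inside [offset, end_idx)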
for i in range(start_part, end_part + 1):
flat_tensor = flat_group[i]
start_offset = self.offset - sum_idx[i]
end_offset = min(end_idx, sum_idx[i + 1]) - sum_idx[i]
tensor_slice.append(flat_tensor[start_offset:end_offset])

pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
return param

# The following methods make this class compatible with `huggingface_hub.split_torch_state_dict_into_shards`
# https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L335
# 1. get_torch_storage_id
# tensor.storage().data_ptr() is part of torch_storage_id
# refer to https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L385
# 2. get_torch_storage_size
# tensor.size() is called by get_torch_storage_size
def storage(self):
return self

def data_ptr(self):
"""
A naive implementation of a storage id; can be optimized.
"""
return self.offset * 100000 + self.partitioned_numel

def size(self):
return self.shape.numel()

@property
def dtype(self):
return self.flat_groups[0][0].dtype

@property
def device(self):
return self.flat_groups[0][0].device


def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size

# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
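# e.g. with world_size=2, a 9-element param is stored as two 5-element partitions (the last element of
# rank 1 is padding); after concatenation the gathered tensor is truncated back to 9 elements via the param's shape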

@@ -424,7 +492,7 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
offset = 0
total_numel = 0
total_params = 0
for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
@@ -435,10 +503,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)

# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
0).narrow(0, 0, unpartitioned_numel).view(shape)
# memory-efficient pseudo tensor: concatenation is deferred until .contiguous() is called at save time
tensor = GatheredTensor(fp32_flat_groups, offset, partitioned_numel, shape)
state_dict[name] = tensor
offset += partitioned_numel

offset *= world_size
@@ -541,6 +608,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
"""

# Dependency pre-check
if safe_serialization:
try:
@@ -571,15 +639,21 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
state_dict_split = StateDictSplit(is_sharded=False,
filename_to_tensors={weights_name: list(state_dict.keys())})

# Save the model
# Save the model by shard
os.makedirs(output_dir, exist_ok=True)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
shard = {tensor_name: state_dict[tensor_name].contiguous() for tensor_name in tensors}
output_path = os.path.join(output_dir, shard_file)
if safe_serialization:
save_file(shard, output_path, metadata={"format": "pt"})
else:
torch.save(shard, output_path)
# release the memory of the current shard
for tensor_name in shard:
del state_dict[tensor_name]
gc.collect()


# Save index if sharded
if state_dict_split.is_sharded: