A faster and more memory-efficient implementation of zero_to_fp32
#6658
base: master
Conversation
deepspeed/utils/zero_to_fp32.py
Outdated
```
@@ -305,6 +303,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
    if debug:
        print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
    state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
    # tensor = GatheredTensor(fp32_flat_groups, offset, partitioned_numel, shape)
```
Remove commented code
done
deepspeed/utils/zero_to_fp32.py
Outdated
```
tensor_slice.append(flat_tensor[start_offset:end_offset])

pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0).to(torch.float16)
```
Why is this cast to fp16? The checkpoint could be a different dtype like bf16.
The cast to fp16 has been removed.
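For illustration, a minimal sketch of the dtype-preserving behavior after the fix (the chunk values here are made up; only the `torch.cat` dtype behavior matters):

```python
import torch

# Before (from the diff above): the merged flat parameter was always cast
# to fp16, which would down-convert fp32 or bf16 checkpoints.
# pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0).to(torch.float16)

# After removing the cast, torch.cat simply preserves the checkpoint's
# own dtype, whatever it is.
pad_flat_param_chunks = [
    torch.ones(4, dtype=torch.bfloat16),
    torch.zeros(4, dtype=torch.bfloat16),
]
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
assert pad_flat_param.dtype == torch.bfloat16  # bf16 in, bf16 out
```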
deepspeed/utils/zero_to_fp32.py
Outdated
```
return param

# The following part makes it compatible with `huggingface_hub.split_torch_state_dict_into_shards`
# https://github.com/huggingface/huggingface_hub/blob/src/huggingface_hub/serialization/_torch.py
```
Can you clarify why these HF_Hub APIs are needed, since the pseudo tensor is never exported into the output checkpoint file?
DeepSpeed/deepspeed/utils/zero_to_fp32.py, lines 565 to 566 in 6e6563d:

```python
state_dict_split = split_torch_state_dict_into_shards(state_dict,
                                                      filename_pattern=filename_pattern,
```
The HF_Hub API is used to split the weights into shards. Our pseudo tensor should be compatible with `get_torch_storage_id` and `get_torch_storage_size` in
https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L335
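For context, a hedged sketch of how the splitter is typically called (the filename pattern and shard size here are illustrative, not the PR's exact values). The splitter only needs each tensor's storage id and storage size to plan the shards; it never reads the values, which is why a pseudo tensor that mimics those two properties is enough:

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

state_dict = {"a.weight": torch.zeros(1024, 1024), "b.weight": torch.zeros(2048)}

# Internally this walks the state_dict and calls get_torch_storage_id /
# get_torch_storage_size on each tensor to group them into shards that
# stay under max_shard_size; tensor values are never touched here.
split = split_torch_state_dict_into_shards(
    state_dict,
    filename_pattern="pytorch_model{suffix}.bin",  # illustrative pattern
    max_shard_size="5GB",
)
for filename, tensor_names in split.filename_to_tensors.items():
    print(filename, tensor_names)
```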
My question is: why is this compatibility a requirement? To my understanding, the output checkpoint file will contain torch.Tensor, not the pseudo tensor. Am I wrong?
@xylian86, FYI
@xu-song, just to clarify, we greatly appreciate this PR. The memory and speed benefits are very useful. My only concern is the HF_Hub-related changes, so hopefully those can be clarified. Can you please add the observed speed and memory benefits of this optimization? Such details are generally useful for readers to better appreciate the value. Thanks!
@tjruwase Is there any alternative approach to sharding a torch state_dict? If there is, the compatibility feature for huggingface_hub could be dropped.
Sorry, but I am a bit confused about the objective of this PR. The goal of zero_to_fp32 is to create a consolidated checkpoint state from the sharded checkpoints of ZeRO-* training, so I don't understand why state_dict sharding is a consideration here. It seems that there are two parts to this PR.

Am I correct?
DeepSpeed/deepspeed/utils/zero_to_fp32.py, lines 565 to 567 in 54903e0

To save memory, the tensors in state_dict are pseudo tensors instead of torch tensors.
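A hedged sketch of the save loop this enables (the function name and loop structure are illustrative, not the PR's exact code): only the tensors assigned to the current shard are materialized before writing.

```python
import torch

def save_shards(state_dict, filename_to_tensors):
    """Write one output shard at a time. `state_dict` maps names to
    pseudo tensors; .contiguous() is what actually reads the mmapped
    weights and returns a real torch.Tensor, so at most one shard's
    worth of real tensors exists in memory at any moment."""
    for shard_file, tensor_names in filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        torch.save(shard, shard_file)
        del shard  # release the materialized tensors before the next shard
```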
It is a faster and more memory-efficient implementation of zero_to_fp32. The previous version doubles the memory usage, which causes CPU OOM for very large models (e.g. llama-405B).
DeepSpeed/deepspeed/utils/zero_to_fp32.py
Lines 438 to 441 in b647fb2
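To see where the doubling comes from, a simplified illustration (`full_single_fp32_vector` is the name used in the diff earlier; the sizes are toy values):

```python
import torch

# The flat groups together hold the full weights: M bytes.
fp32_flat_groups = [torch.zeros(1_000_000) for _ in range(4)]

# torch.cat allocates a second full-size buffer while the sources are
# still alive, so peak memory is roughly 2M before anything can be freed.
full_single_fp32_vector = torch.cat(fp32_flat_groups, dim=0)
```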
How does it work?

- Checkpoint shards are loaded with mmap=True, thus the weights are mmapped rather than loading all the storages into memory.
- GatheredTensor contains the mmapped weights and the tensor offset. It is a memory-efficient pseudo tensor. Only when .contiguous() is called does it load the related weights into memory and merge them into a single tensor (a sketch follows below).
- Throughout the process, only one shard of tensors is kept in memory.
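A minimal sketch of such a pseudo tensor, assuming `flat_groups` is a list of per-rank mmapped flat tensors (the real GatheredTensor in this PR also handles padding and the different ZeRO stages):

```python
import math
import torch

class GatheredTensor:
    """Records where a parameter lives inside the mmapped flat groups;
    nothing is read from disk until .contiguous() is called."""

    def __init__(self, flat_groups, offset, partitioned_numel, shape):
        self.flat_groups = flat_groups      # per-rank mmapped flat tensors
        self.offset = offset                # start of this param in each partition
        self.partitioned_numel = partitioned_numel
        self.shape = tuple(shape)
        self.dtype = flat_groups[0].dtype   # inspected by the HF_Hub splitter

    def contiguous(self):
        # Only now are the relevant slices pulled from the mmapped files,
        # concatenated, trimmed of ZeRO padding, and reshaped.
        slices = [flat[self.offset:self.offset + self.partitioned_numel]
                  for flat in self.flat_groups]
        return torch.cat(slices, dim=0)[:math.prod(self.shape)].view(self.shape)
```

With shards opened via torch.load(path, map_location="cpu", mmap=True), the concatenation in .contiguous() only faults in the file pages that the requested slices cover, so the rest of the checkpoint never enters memory.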
How much benefit in speed and memory?

Memory benefit: from 2M to (1/n)M, where M is the memory cost of the full weights and n is num_shards; for llama3.1-405B, n=191 (a back-of-the-envelope check follows below). Previously, converting llama3.1-405B from zero3 to fp32 got OOM; after this optimization, the memory cost is about 200GB-300GB.

Speed benefit:
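A back-of-the-envelope check of the memory numbers above (assuming fp32 at 4 bytes per parameter; the observed 200GB-300GB presumably also includes buffers materialized while writing the output shards, not just the (1/n)M working set):

```python
params = 405e9                    # llama3.1-405B
M_tib = params * 4 / 2**40        # ~1.47 TiB for the full fp32 weights
old_peak = 2 * M_tib              # ~2.9 TiB peak for the previous implementation
n = 191                           # num_shards reported for llama3.1-405B
per_shard_gib = M_tib / n * 1024  # ~7.9 GiB of flat-group data held at a time
print(f"M = {M_tib:.2f} TiB, old peak = {old_peak:.2f} TiB, "
      f"per-shard working set = {per_shard_gib:.1f} GiB")
```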