A faster and more memory-efficient implementation of zero_to_fp32 #6658

Open
wants to merge 2 commits into master

Conversation

@xu-song (Contributor) commented Oct 23, 2024

This is a faster and more memory-efficient implementation of zero_to_fp32.

The previous version doubles the memory usage, which causes CPU OOM for very large models (e.g. Llama 405B).

# XXX: memory usage doubles here
state_dict[name] = torch.cat(
    tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
    0).narrow(0, 0, unpartitioned_numel).view(shape)

How does it work?

  1. Lazy loading: the checkpoint is loaded with mmap=True, so the weights are memory-mapped rather than all storages being read into memory.
  2. Lazy merge: GatheredTensor holds the mmapped weights and the tensor offset. It is a memory-efficient pseudo tensor. Only when .contiguous() is called does it load the relevant weights into memory and merge them into a single tensor.
  3. Timely memory release: the memory of a shard is released as soon as that shard is saved.

Throughout the process, only one shard of tensors is kept in memory (see the illustrative sketch below).
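
As an illustrative sketch (the internals below are assumptions; only the GatheredTensor name and its constructor arguments are taken from the PR diff), the lazy-merge idea looks roughly like this:

import math
import torch

class GatheredTensor:
    # Memory-efficient pseudo tensor: it records where a parameter lives inside the
    # mmapped fp32 flat groups and only materializes it when .contiguous() is called.
    def __init__(self, flat_groups, offset, partitioned_numel, shape):
        self.flat_groups = flat_groups              # one mmapped flat tensor per rank
        self.offset = offset                        # start offset inside each flat group
        self.partitioned_numel = partitioned_numel  # numel of this parameter's partition
        self.shape = shape
        self.dtype = flat_groups[0].dtype

    def contiguous(self):
        # Only now are the per-rank slices read from the mmapped storages,
        # concatenated, trimmed of padding, and reshaped to the full parameter.
        slices = [g.narrow(0, self.offset, self.partitioned_numel) for g in self.flat_groups]
        numel = math.prod(self.shape)
        return torch.cat(slices, dim=0).narrow(0, 0, numel).view(self.shape)

# Lazy loading (step 1): mmap=True keeps the checkpoint storages on disk until touched,
# e.g. the flat groups come from torch.load(ckpt_file, map_location="cpu", mmap=True).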

How much benefit in speed and memory?

memory benefit:

  • Theoretically, this PR reduces the memory cost from 2M to (1/n)M, where M is the memory cost of the full weights and n is the number of shards. For Llama 3.1-405B, n=191.
  • In my test with 1TB of CPU memory, converting Llama 3.1-405B from ZeRO-3 to fp32 previously hit OOM; after this optimization, the memory cost is about 200-300GB.

speed benefit:

  • The speed gain mainly comes from avoiding extra tensor copies.

@@ -305,6 +303,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
# tensor = GatheredTensor(fp32_flat_groups, offset, partitioned_numel, shape)
@tjruwase (Contributor) commented Oct 23, 2024

Remove commented code

@xu-song (Contributor, Author) replied:

done

tensor_slice.append(flat_tensor[start_offset:end_offset])

pad_flat_param_chunks.append(torch.concat(tensor_slice, 0))
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0).to(torch.float16)
@tjruwase (Contributor) commented Oct 23, 2024

Why is this cast to fp16? The checkpoint could be a different dtype like bf16.

@xu-song (Contributor, Author) replied:

The cast to fp16 has been removed.
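
For reference, the presumed fix simply drops the cast so the concatenated parameter keeps the checkpoint's own dtype (a sketch, not the exact diff):

pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)  # keep original dtype (fp32/bf16/...)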

return param

# The following part makes it compatible with `huggingface_hub.split_torch_state_dict_into_shards`
# https://github.com/huggingface/huggingface_hub/blob/src/huggingface_hub/serialization/_torch.py
@tjruwase (Contributor) commented:

Can you clarify why these HF_Hub APIs are needed, since the pseudo tensor is never exported into the output checkpoint file?

@xu-song (Contributor, Author) commented Oct 24, 2024

state_dict_split = split_torch_state_dict_into_shards(state_dict,
                                                       filename_pattern=filename_pattern,
                                                       max_shard_size=max_shard_size)

The HF_Hub API is used to split the weights into shards.

Our pseudo tensor should be compatible with get_torch_storage_id and get_torch_storage_size in
https://github.com/huggingface/huggingface_hub/blob/v0.26.1/src/huggingface_hub/serialization/_torch.py#L335
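
As a rough sketch of how the pseudo tensors could flow through this helper and then be materialized one shard at a time (the save loop, filename pattern and max_shard_size below are illustrative assumptions, not the exact PR code):

from huggingface_hub import split_torch_state_dict_into_shards
import torch

# state_dict maps parameter names to GatheredTensor pseudo tensors.
state_dict_split = split_torch_state_dict_into_shards(state_dict,
                                                      filename_pattern="pytorch_model{suffix}.bin",
                                                      max_shard_size="5GB")

for filename, tensor_names in state_dict_split.filename_to_tensors.items():
    # Materialize only the tensors belonging to this shard, save it, then free the memory.
    shard = {name: state_dict[name].contiguous() for name in tensor_names}
    torch.save(shard, filename)
    del shard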

@tjruwase (Contributor) replied:

My question is why is this compatibility a requirement? To my understanding, the output checkpoint file will contain torch.Tensor and not pseudo tensor. Am I wrong?

@tjruwase tjruwase requested review from tohtana and removed request for awan-10 October 23, 2024 16:31
@tjruwase (Contributor) commented:

@xylian86, FYI

@tjruwase (Contributor) commented:

@xu-song, just to clarify, we greatly appreciate this PR. The memory and speed benefits are very useful. My only concern is the HF_Hub-related changes, so hopefully those can be clarified.

Can you please add the observed speed and memory benefits of this optimization? Such details are generally useful for readers to better appreciate the value. Thanks!

@xu-song (Contributor, Author) commented Oct 25, 2024

> @xu-song, just to clarify, we greatly appreciate this PR. The memory and speed benefits are very useful. My only concern is the HF_Hub-related changes, so hopefully those can be clarified.
>
> Can you please add the observed speed and memory benefits of this optimization? Such details are generally useful for readers to better appreciate the value. Thanks!

@tjruwase Is there any alternative approach to sharding torch state_dict?

If so, the compatibility with huggingface_hub.split_torch_state_dict_into_shards can be dropped.

@tjruwase (Contributor) replied:

> @tjruwase Is there any alternative approach to sharding torch state_dict?
>
> If so, the compatibility with huggingface_hub.split_torch_state_dict_into_shards can be dropped.

Sorry, but I am a bit confused about the objective of this PR. The goal of zero_to_fp32 is to create a consolidated checkpoint state from the sharded checkpoints of ZeRO-* training, so I don't understand why state_dict sharding is a consideration here.

It seems that there are two parts to this PR.

  1. Speed and memory optimizations
  2. HF_hub compatibility involving state_dict sharding

Am I correct?

@xu-song (Contributor, Author) commented Oct 26, 2024

> > @tjruwase Is there any alternative approach to sharding torch state_dict?
> > If so, the compatibility with huggingface_hub.split_torch_state_dict_into_shards can be dropped.
>
> Sorry, but I am a bit confused about the objective of this PR. The goal of zero_to_fp32 is to create a consolidated checkpoint state from the sharded checkpoints of ZeRO-* training, so I don't understand why state_dict sharding is a consideration here.
>
> It seems that there are two parts to this PR.
>
>   1. Speed and memory optimizations
>   2. HF_hub compatibility involving state_dict sharding
>
> Am I correct?

state_dict_split = split_torch_state_dict_into_shards(state_dict,
                                                       filename_pattern=filename_pattern,
                                                       max_shard_size=max_shard_size)

To save memory, the tensors in state_dict are pseudo tensors instead of torch tensors.
