sequence parallel for uneven heads (#6392)
In sequence parallel (Ulysses), the number of attention heads must be divisible by the
sequence parallel size, which prevents some models/workloads from using certain sequence
parallel sizes. This PR implements uneven all-to-all head splitting so that heads can be
distributed unevenly across the sequence-parallel ranks (see the sketch after the list below).

- Supports both batch-first (b, s, ...) and seq_len-first (s, b, ...) layouts.
- Added unit tests with numerical checks. Also tested locally with **7 heads with sp=4** and **20 heads with sp=8**; both passed.
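
A minimal sketch of the split rule (an illustration only; the actual shard sizes come from get_shard_size_list, and the [2, 2, 2, 1] example appears in the layer.py comments below):

def uneven_head_split(num_heads, sp_size):
    # floor share for every rank; the remainder heads go to the first ranks
    base, rem = divmod(num_heads, sp_size)
    return [base + 1 if rank < rem else base for rank in range(sp_size)]

assert uneven_head_split(7, 4) == [2, 2, 2, 1]
assert uneven_head_split(20, 8) == [3, 3, 3, 3, 2, 2, 2, 2]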

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Ma, Guokai <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
5 people authored Oct 25, 2024
1 parent 3d5cf73 commit 5fb71c0
Showing 4 changed files with 214 additions and 2 deletions.
4 changes: 3 additions & 1 deletion deepspeed/module_inject/tp_shard.py
@@ -24,7 +24,9 @@ def set_n_embd(num):

def get_num_kv_heads():
global num_kv_heads
return num_kv_heads
if 'num_kv_heads' in globals():
return num_kv_heads
return None


def get_num_attention_heads():
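
With this guard, get_num_kv_heads() simply returns None until set_num_kv_heads() has recorded a value, so callers can probe whether the uneven-heads path has been initialized. A minimal usage sketch with the helpers touched in this diff:

from deepspeed.module_inject.tp_shard import get_num_kv_heads, set_num_kv_heads

if get_num_kv_heads() is None:   # nothing recorded yet, e.g. the first all-to-all call
    set_num_kv_heads(7)          # remember the total number of KV heads once
assert get_num_kv_heads() == 7   # later calls reuse the recorded value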
126 changes: 126 additions & 0 deletions deepspeed/sequence/layer.py
@@ -10,6 +10,8 @@

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups


def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
@@ -38,8 +40,132 @@ def post_func(input):
return post_func


def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1"

if not (scatter_idx < 2):
input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size)
input = input.transpose(0, scatter_idx).contiguous()
local_heads = input_splits[groups._get_sequence_parallel_rank()]
output_splits = [local_heads] * seq_world_size

output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:])
output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype)
dist.all_to_all_single(output,input,output_split_sizes=output_splits,\
input_split_sizes=input_splits,group=group)
###[seq_ws*local_heads, ...] to [seq_ws, local_heads, ...]
output = output.view(seq_world_size, local_heads, *output.shape[1:])

### batch_dim_idx=0 [seq_ws,local_heads,seq_len,b,...] to [b, seq_ws, seq_len, local_heads ...]
### batch_dim_idx=1 [seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...]
if batch_dim_idx == 0:
order = [3, 0, 2, 1] + list(range(4, len(output.shape)))
output = output.permute(order).contiguous()
###[b, seq_ws*local_seq_len, local_heads,...]
output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size,
*output.shape[3:]).contiguous()
elif batch_dim_idx == 1:
output = output.transpose(1, 3).contiguous()
###[seq_ws*local_seq_len, b, local_heads,...]
output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
else:
# Handle 4D and 3D tensors uniformly by flattening the trailing dims to 3D.
input = input.reshape(input.shape[0], input.shape[1], -1)

if batch_dim_idx == 0: #b,s,h
input = input.permute(1, 2, 0).contiguous() #s,h,b
elif batch_dim_idx == 1: #s,b,h
input = input.transpose(1, 2).contiguous() #s,h,b
seq_len, h, batch_size = input.shape
num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size)
local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()]
h_dim = h // local_heads
local_seq_len = seq_len // seq_world_size

input = input.view(seq_len * h, batch_size)
local_seq_len_with_heads = int(input.shape[0] / seq_world_size) # dim size of local_seq_len*local_heads*hdim
input_splits = [local_seq_len_with_heads] * seq_world_size
coeff = local_seq_len_with_heads // local_heads #per head: dim size of local_seq_len*hdim

#scale factor for the uneven case: total_heads / local_heads
heads_scale_coeff = get_num_kv_heads() / local_heads

output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list]
output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
total_h = int(inp_shape[gather_idx] * heads_scale_coeff)
output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
dist.all_to_all_single(output,input,output_split_sizes=output_splits, \
input_split_sizes=input_splits,group=group)
##################
#suppose 7 heads divide into 4 ranks [2,2,2,1]
#chunk_num_heads_small=floor(7/4)=1
#chunk_num_heads_large=ceil(7/4)=2
#num_chunk_heads_large=len([2,2,2])=3, all2all_buffer_counts
#num_chunk_heads_small=len([1])=1, all2all_buffer_counts
#total_num_large_heads=sum([2,2,2])=7
#total_num_small_heads=sum([1])=1

chunk_num_heads_small = get_num_kv_heads() // seq_world_size # even heads compatible
chunk_num_heads_large = chunk_num_heads_small + 1
num_chunk_heads_large = get_num_kv_heads() % seq_world_size
num_chunk_heads_small = seq_world_size - num_chunk_heads_large
total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large
total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small

heads_large_combine_size = coeff * total_num_large_heads
heads_small_combine_size = coeff * total_num_small_heads
heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size],
dim=0)
heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim,
batch_size)
heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim,
batch_size)
if batch_dim_idx == 0:
#[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[batch,local_seq_len,all2all_buffer_counts*n_heads,dim]
order = [4, 1, 0, 2, 3]
heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
total_num_large_heads, h_dim)
heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
total_num_small_heads, h_dim)
elif batch_dim_idx == 1:
#[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[local_seq_len,batch,all2all_buffer_counts*n_heads,dim]
order = [1, 4, 0, 2, 3]
heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
total_num_large_heads, h_dim)
heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
total_num_small_heads, h_dim)

output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous()

inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
output_shape= inp_shape[: gather_idx] + \
[total_h,] + \
inp_shape[gather_idx + 1:]

output = output.view(output_shape)

return output


def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
seq_world_size = dist.get_world_size(group)
# we only need num_heads once
num_heads = input.shape[2]

if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
# Assuming here that the number of heads for q is consistent with kv
# If not, additional logic is required for cases like GQA
if get_num_kv_heads() is None:
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
# Record the total head count on the first call;
# later calls re-enter the uneven path via ``get_num_kv_heads() is not None``.
set_num_kv_heads(num_heads)
assert async_op == False, "uneven head sp does not support async op"
return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
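
To make the bookkeeping in uneven_heads_all2all concrete, here is a standalone sketch (plain Python, no communication, nothing imported from DeepSpeed) that reproduces the chunk accounting for the 7-heads / sp=4 example from the comments above:

seq_world_size = 4
total_heads = 7                                                  # split across ranks as [2, 2, 2, 1]

chunk_num_heads_small = total_heads // seq_world_size            # 1 head per small chunk
chunk_num_heads_large = chunk_num_heads_small + 1                # 2 heads per large chunk
num_chunk_heads_large = total_heads % seq_world_size             # 3 ranks contribute large chunks
num_chunk_heads_small = seq_world_size - num_chunk_heads_large   # 1 rank contributes a small chunk
total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large   # 6
total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small   # 1

# the large and small chunks are concatenated back into all 7 heads after the all-to-all
assert total_num_large_heads + total_num_small_heads == total_heads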
2 changes: 2 additions & 0 deletions deepspeed/utils/groups.py
@@ -484,6 +484,8 @@ def _get_sequence_parallel_rank():
global mpu
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
return mpu.get_sequence_parallel_rank()
if mesh_device is not None:
return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel"))
return 0


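
With the mesh_device fallback above, groups._get_sequence_parallel_rank() also resolves correctly when the engine is built from a (data_parallel, sequence_parallel) mesh instead of an mpu, which is how the new unit test below sets things up. A sketch of that setup, assuming a 4-rank launch (the Linear module is only a stand-in):

import torch
import deepspeed
from deepspeed.utils import groups

model = torch.nn.Linear(10, 10)                        # any module with parameters
engine, _, _, _ = deepspeed.initialize(model=model,
                                       config_params={"train_batch_size": 8},
                                       mesh_param=(2, 2))  # 2-way data parallel x 2-way sequence parallel
sp_rank = groups._get_sequence_parallel_rank()         # resolved through mesh_device
sp_group = engine.seq_parallel_group                   # the group passed to _SeqAllToAll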
84 changes: 83 additions & 1 deletion tests/unit/sequence_parallelism/test_ulysses.py
@@ -11,9 +11,12 @@
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll
from unit.util import skip_on_arch
from unit.simple_model import *
from deepspeed.utils import groups
from deepspeed.module_inject.tp_shard import get_shard_size_list
#Use mesh device to create data and sequence parallel group


#Use mesh device to create data and sequence parallel group
class TestUlyssesUtils(DistributedTest):
world_size = 4

@@ -75,3 +78,82 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_
# Check outputs are the same as input
for i in range(1, len(outputs)):
assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"


@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension
@pytest.mark.parametrize("num_heads", [3, 7])
@pytest.mark.parametrize("head_dim", [16])
class TestUlyssesAll2All_odd(DistributedTest):
world_size = 4

def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:

data_parallel_size = 2
seq_parallel_size = self.world_size // data_parallel_size
skip_on_arch(min_arch=8)

def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0):
d0 += offset_d0
d1 += offset_d1
h += offset_h
return d0 * 10 + h + d1 * 0.1

hidden_dim = 10
model = SimpleModel(hidden_dim)
ds_engine, _, _, _ = initialize(model=model,
config_params={"train_batch_size": 8},
mesh_param=(data_parallel_size, seq_parallel_size))

scatter_idx = 2
outputs = []
inputs = []
batch_dims = [0, 1]
seq_dims = [1, 0]

for idx, seq_dim in enumerate(seq_dims):
gather_idx = seq_dim
batch_dim_idx = batch_dims[idx]

#4D tensor : b,s,h,d or s,b,h,d
#create a hash tensor from pos_id, head_id, and batch_id
d0_indices = torch.arange(d0).reshape(-1, 1, 1, 1)
d1_indices = torch.arange(d1).reshape(1, -1, 1, 1)
h_indices = torch.arange(num_heads).reshape(1, 1, -1, 1)
input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
if batch_dim_idx == 1: #seq_len_dim : 0(d0)
input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices,
d0 * groups._get_sequence_parallel_rank(), 0)
elif batch_dim_idx == 0: #seq_len_dim : 1(d1)
input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0,
d1 * groups._get_sequence_parallel_rank())
inputs.append(input_tensor)

### first all2all: sequence parallel to head parallel
s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
batch_dim_idx)

# s2h_tensor check for the first all2all: compare with the expected ground truth
d0_indices = torch.arange(s2h_tensor.shape[0]).reshape(-1, 1, 1, 1)
d1_indices = torch.arange(s2h_tensor.shape[1]).reshape(1, -1, 1, 1)
h_indices = torch.arange(s2h_tensor.shape[2]).reshape(1, 1, -1, 1)
shard_list = get_shard_size_list(num_heads, groups._get_sequence_parallel_world_size())
head_offset = sum(shard_list[:groups._get_sequence_parallel_rank()])
s2h_truth = torch.zeros_like(s2h_tensor)
s2h_truth[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, 0, head_offset)

assert torch.allclose(s2h_truth,
s2h_tensor), f"s2h_tensor differs from the expected for sequence dim: {seq_dim}"
### second all2all: head parallel back to sequence parallel (the round trip should be a no-op)
h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
batch_dim_idx)
print(
f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
)
outputs.append(h2s_tensor)

# Check outputs for the second all2all
for i in range(0, len(outputs)):
assert torch.allclose(inputs[i],
outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}"
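
For reference, take one parametrization of the test above (d0=2, d1=4, num_heads=7, head_dim=16, batch-first, sp=2); assuming 7 heads split as [4, 3] over the two sequence-parallel ranks, the shapes exercised are:

# per sequence-parallel rank (assumption: 7 heads -> [4, 3] split on sp=2)
expected_s2h = {0: (2, 8, 4, 16),   # rank 0 gathers the full sequence, keeps 4 heads
                1: (2, 8, 3, 16)}   # rank 1 gathers the full sequence, keeps 3 heads
expected_h2s = (2, 4, 7, 16)        # the second all-to-all restores the original input shape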
