Commit 6526327
feat[paramsync]: allow kv replicate in vllm generation
haolin-nju committed Jan 23, 2025
1 parent 3b6d92b commit 6526327
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions chatlearn/synchronizer/megatron_vllm.py
@@ -329,8 +329,9 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition):
         if "attention.query_key_value" in name or \
                 "self_attention.query_key_value" in name or \
                 "self_attention.linear_qkv" in name:
-            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
-            heads = self.src_module_args.args_dict["num_attention_heads"] // tp_size
+            src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
+            dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
+            heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
             hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]

             param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
@@ -348,22 +349,38 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition):
                 param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
                 del param_data_list
             else:
-                _num_query_groups = self.src_module_args.args_dict["num_query_groups"]//tp_size \
-                    if self.src_module_args.args_dict["group_query_attention"] else heads
-                if to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
+                num_query_groups = self.src_module_args.args_dict["num_query_groups"]
+                assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], (
+                    f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
+                    f"dst model ({self.dst_module_args.args_dict['num_query_groups']}). Please double-check your config."
+                )
+                if self.src_module_args.args_dict["group_query_attention"]:
+                    src_num_query_groups_per_replica = num_query_groups // src_tp_size
+                    if dst_tp_size >= num_query_groups:
+                        num_dst_kv_head_replicas = dst_tp_size // num_query_groups
+                    else:
+                        num_dst_kv_head_replicas = 1
+                else:
+                    src_num_query_groups_per_replica = heads
+                    num_dst_kv_head_replicas = 1
+
+                if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1:
                     if len(param_data_shape) == 1:
-                        param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head))
+                        param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head))
                     else:
                         param_data = param_data.view(
-                            (heads + 2 * _num_query_groups, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
+                            (heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
                 param_data_list = []
                 head_offset = heads // tp_divition
                 for idx in range(tp_divition):
                     q_start = idx * head_offset
                     q_end = q_start + head_offset
-                    k_start = (heads + idx) if _num_query_groups // tp_divition else heads
+                    if src_num_query_groups_per_replica // tp_divition and num_dst_kv_head_replicas == 1:
+                        k_start = heads + idx
+                    else:
+                        k_start = heads
                     k_end = k_start + 1
-                    v_start = k_start + _num_query_groups
+                    v_start = k_start + src_num_query_groups_per_replica
                     v_end = v_start + 1

                     q_proj = param_data[q_start:q_end].contiguous()
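The new variables amount to a small piece of arithmetic: how many KV groups each source (Megatron) tensor-parallel rank holds, and onto how many destination (vLLM) ranks each KV head must be replicated. A minimal standalone sketch of that arithmetic, with a hypothetical helper name (only the formulas mirror the diff):

    # Illustrative only: the helper name and standalone form are assumptions;
    # the formulas follow the variables introduced in this commit.
    def kv_replication_plan(num_attention_heads, num_query_groups,
                            group_query_attention, src_tp_size, dst_tp_size):
        # Query heads held by each source (Megatron) TP rank.
        heads = num_attention_heads // src_tp_size
        if group_query_attention:
            # KV groups held by each source TP rank.
            src_num_query_groups_per_replica = num_query_groups // src_tp_size
            if dst_tp_size >= num_query_groups:
                # Destination (vLLM) TP size is at least the number of KV groups:
                # each KV head lands on dst_tp_size // num_query_groups ranks.
                num_dst_kv_head_replicas = dst_tp_size // num_query_groups
            else:
                num_dst_kv_head_replicas = 1
        else:
            # Plain multi-head attention: one KV head per query head, no replication.
            src_num_query_groups_per_replica = heads
            num_dst_kv_head_replicas = 1
        return heads, src_num_query_groups_per_replica, num_dst_kv_head_replicas

    # Example: 64 query heads, 8 KV groups, trained at TP=8, served by vLLM at TP=16.
    print(kv_replication_plan(64, 8, True, src_tp_size=8, dst_tp_size=16))  # (8, 1, 2)

When num_dst_kv_head_replicas is greater than 1, the per-rank loop above picks the same KV slice (k_start = heads) for every destination slice, which is what allows vLLM generation to run at a larger TP size than the number of KV heads.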
