[moe] feat: enabling expert parallelism in veScale (#59)
## Overview

veScale provides an efficient framework for training Mixture-of-Experts (MoE)
models with expert parallelism. Expert parallelism is enabled through the
`parallelize_experts()` function, which handles distributing the experts and
managing the MoE workload during training.

### Function Signature

```python
def parallelize_experts(
    module: nn.Module,
    experts_expr: Union[str, List[str]],
    experts_allocator: vescale.moe.ExpertsAllocator,
    token_dispatcher: vescale.moe.TokenDispatcher,
    config: Dict,
) -> nn.Module
```

### Parameters
- **`module`**: The training model (an instance of `nn.Module`) to be
parallelized.
- **`experts_expr`**: Specifies the paths to the expert modules. Can be
a string or a list of strings.
- **`experts_allocator`**: An instance of `ExpertsAllocator`, used for
managing expert parameter allocation.
- **`token_dispatcher`**: An instance of `TokenDispatcher`, responsible
for token scheduling and distribution.
- **`config`**: A dictionary containing the MoE training configuration, such
as the number of layers, the number of experts, and other relevant settings.

`parallelize_experts()` returns the parallelized model as an `nn.Module`.
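
Below is a minimal usage sketch showing how these pieces might fit together. The import path of `parallelize_experts`, the default `ExpertsAllocator`/`TokenDispatcher` constructors, the expert-path regex, and the config keys are illustrative assumptions rather than the exact veScale API.

```python
# Hypothetical usage sketch: import paths, config keys, and the expert-path
# regex are assumptions for illustration, not the verified veScale API.
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

from vescale.moe import ExpertsAllocator, TokenDispatcher, parallelize_experts  # assumed imports

# Build a small Mixtral model; each decoder layer carries 8 local experts.
config = MixtralConfig(num_hidden_layers=2, num_local_experts=8, hidden_size=1024, intermediate_size=3584)
model = MixtralForCausalLM(config)

moe_config = {
    "num_layers": config.num_hidden_layers,   # assumed key: number of MoE layers
    "num_experts": config.num_local_experts,  # assumed key: experts per layer
}

model = parallelize_experts(
    module=model,
    experts_expr=r"model.layers.\d+.block_sparse_moe.experts",  # path pattern to the expert modules
    experts_allocator=ExpertsAllocator(),  # default policy: shard expert parameters with TP
    token_dispatcher=TokenDispatcher(),    # default policy: one DP rank per (token, expert) pair
    config=moe_config,
)
```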


## Custom Scheduling

veScale allows users to define custom scheduling strategies for expert
parallelism by implementing the following components (a minimal sketch of
custom subclasses follows the list):

- **`ExpertsAllocator`**: Manages expert parameter allocation. It can
use `collect_performance()` to profile and dynamically adjust the DP x
TP device mesh for each expert. By default, veScale shards all expert
parameters across devices using tensor parallelism.

- **`TokenDispatcher`**: Handles token distribution. Using
`assign_task()`, it determines workload allocation (e.g., expert IDs and
token weights) and adjusts scheduling with `collect_performance()`. The
default implementation randomly assigns tokens to a single DP rank for
the selected expert.
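
As an example, the sketch below shows what custom subclasses might look like. The base classes are the `vescale.moe` components described above, but the method signatures, argument formats, and return values are assumptions made for illustration.

```python
# Illustrative subclasses only; method signatures and data formats are assumed.
from vescale.moe import ExpertsAllocator, TokenDispatcher


class ProfiledExpertsAllocator(ExpertsAllocator):
    def collect_performance(self, perf_stats):
        # Cache per-expert runtime statistics (assumed format) so that a later
        # allocation step can rebalance the DP x TP mesh of overloaded experts.
        self._stats = perf_stats
        return super().collect_performance(perf_stats)


class WeightedTokenDispatcher(TokenDispatcher):
    def assign_task(self, expert_ids, token_weights):
        # Decide which DP rank processes each (token, expert) pair. The default
        # policy picks a random DP rank for the selected expert; a custom policy
        # could instead balance the load using token_weights.
        return super().assign_task(expert_ids, token_weights)

    def collect_performance(self, perf_stats):
        # Feed runtime measurements back into the dispatch policy (assumed hook).
        return super().collect_performance(perf_stats)
```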

## Optimizer Support

Since veScale supports dynamic placement of expert parameters, a dedicated
optimizer, `MoEOptimizer`, is required. It efficiently handles the
redistribution of expert parameters and their optimizer states. Future updates
will integrate this functionality into the optimizers for static parameters to
streamline the process.
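
The sketch below, continuing from the example above, shows one hypothetical way to wire `MoEOptimizer` into a training step. The constructor arguments, the import path, and the split between expert and static parameters are assumptions; consult the veScale source for the actual interface.

```python
# Hypothetical wiring; MoEOptimizer's real constructor and import path may differ.
import torch
from vescale.moe import MoEOptimizer  # assumed import path

# Dynamically placed expert parameters go to MoEOptimizer; the remaining
# (static) parameters keep using an ordinary optimizer.
expert_params = [p for n, p in model.named_parameters() if ".experts." in n]
static_params = [p for n, p in model.named_parameters() if ".experts." not in n]

moe_optimizer = MoEOptimizer(expert_params, lr=3e-4)       # assumed signature
base_optimizer = torch.optim.AdamW(static_params, lr=3e-4)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
labels = input_ids.clone()

loss = model(input_ids, labels=labels).loss  # standard HF causal-LM loss
loss.backward()
for opt in (moe_optimizer, base_optimizer):
    opt.step()
    opt.zero_grad()
```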


## Getting Started

### Data Preparation
Prepare the Shakespeare dataset by running:

```bash
cd data/shakespeare/
python3 prepare.py
cd ../..
```

### Training Command

```bash
torchrun --standalone --nproc_per_node={GPU_CNT} mixtral_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters}
```
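
For example, a single-node run on 8 GPUs with a 2 x 4 DP x TP mesh (the values below are illustrative) would be:

```bash
torchrun --standalone --nproc_per_node=8 mixtral_train.py --dp=2 --tp=4 --max_iters=100
```
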
chwan1016 authored Dec 27, 2024
1 parent b4b1686 commit ac76ffa
Showing 41 changed files with 2,444 additions and 132 deletions.
24 changes: 19 additions & 5 deletions examples/llama2_4D_finetune/llama_train.py
@@ -37,6 +37,20 @@
from data_loader import DataLoader


class Net(torch.nn.Module):
    def __init__(self, path, torch_dtype):
        super().__init__()
        self.llama_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch_dtype)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, labels):
        logits = self.llama_model(input_ids).logits
        logits = logits.flatten(end_dim=-2)
        labels = labels.flatten()
        loss = self.loss_fn(logits, labels)
        return loss


def estimate_llama2(config, bsz, sqence_length):
embed = 4 * bsz * sqence_length * config.hidden_size
ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length
@@ -53,7 +67,7 @@ def run_llama2(args):
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
device = f"cuda:{rank}"
device = f"cuda:{local_rank}"
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"])
@@ -77,8 +91,8 @@ def run_llama2(args):
"bfloat16": torch.bfloat16,
}[args.dtype]

model = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b", torch_dtype=ptdtype)
llama_config = model.config
model = Net("openlm-research/open_llama_3b", torch_dtype=ptdtype)
llama_config = model.llama_model.config
if rank == 0:
print(model)
print(llama_config)
@@ -165,7 +179,7 @@ def estimate_loss():
losses = torch.zeros(args.eval_iters // factor).to(device)
for k in range(args.eval_iters // factor):
X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp)
loss = model(X, labels=Y).loss
loss = model(X, Y)
if world_size > 1:
losses[k] = loss.to_local().item()
else:
@@ -198,7 +212,7 @@ def estimate_loss():
start_epoch.record()
if world_size > 1:
model.zero_grad_buffer()
loss = model(X, labels=Y).loss
loss = model(X, Y)
loss.backward()
grad_norm = -1
if world_size == 1 and args.grad_clip > 0:
14 changes: 7 additions & 7 deletions examples/llama2_4D_finetune/sharding_plan.py
@@ -45,19 +45,19 @@

# forward resharding plan for the whole open llama model
model_fwd_resharding_plan = {
"model.input": [[Replicate()]],
"model.embed_tokens.output": [[Shard(1)]],
"model.norm.input": [[Shard(1)]],
"model.output": {
"llama_model.model.input": [[Replicate()]],
"llama_model.model.embed_tokens.output": [[Shard(1)]],
"llama_model.model.norm.input": [[Shard(1)]],
"llama_model.model.output": {
"last_hidden_state": [Replicate()],
},
**{rf"model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()},
**{rf"llama_model.model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()},
}

# model parameter sharding plan for the whole open llama model
model_param_sharding_plan = {
"model.embed_tokens.weight": [Shard(1)],
**{rf"model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()},
"llama_model.model.embed_tokens.weight": [Shard(1)],
**{rf"llama_model.model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()},
}

llama2_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan}
2 changes: 1 addition & 1 deletion examples/mixtral_4D_benchmark/README.md
@@ -11,7 +11,7 @@ from HuggingFace without any model code modifications.

### Single Machine 8 cards
```
torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
torchrun --nproc-per-node=8 --standalone examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
```
This will start a 8-cards MFU benchmark for Mixtral with veScale with dp=1 and tp=8.

4 changes: 2 additions & 2 deletions examples/mixtral_4D_benchmark/mixtral_train.py
@@ -27,7 +27,7 @@
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.initialize.deferred_init import deferred_init, is_deferred

from transformers.models.mixtral.modeling_mixtral import MixtralModel
from transformers.models.mixtral.modeling_mixtral import MixtralModel, MixtralSparseMoeBlock
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from sharding_plan import mixtral_plan

@@ -84,7 +84,7 @@ def run_mixtral(args):
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
whitelist_module_types=[MixtralSparseMoeBlock],
module_to_enforce=[MixtralSparseMoeBlock],
)

doptim = DistributedOptimizer(
58 changes: 29 additions & 29 deletions examples/mixtral_4D_benchmark/sharding_plan.py
@@ -21,49 +21,49 @@


param_sharding_plan = {
"embed_tokens.weight": [Replicate()],
r"layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
r"layers.\d+.self_attn.q_proj.weight": [Shard(0)],
r"layers.\d+.self_attn.k_proj.weight": [Shard(0)],
r"layers.\d+.self_attn.v_proj.weight": [Shard(0)],
"model.embed_tokens.weight": [Replicate()],
r"model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
# TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
r"layers.\d+.self_attn.o_proj.weight": [Shard(1)],
r"layers.\d+.post_attention_layernorm.weight": [Replicate()],
r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
"norm.weight": [Replicate()],
r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
"model.norm.weight": [Replicate()],
}

fwd_resharding_plan = {
# TODO: buggy: attn mask is torch.Tensor, in training, it's a None
r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
"embed_tokens.input": [[Replicate()]],
"model.embed_tokens.input": [[Replicate()]],
# No SP
# r"layers.\d+.input_layernorm.input": [[Replicate()]],
# r"layers.\d+.input_layernorm.output": [[Replicate()]],
# SP
r"layers.\d+.input_layernorm.input": [[Shard(1)]],
r"layers.\d+.input_layernorm.output": [[Shard(1)]],
r"layers.\d+.self_attn.input": [[Replicate()]],
r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
r"layers.\d+.self_attn.o_proj.output": [[Replicate()]],
r"model.layers.\d+.input_layernorm.input": [[Shard(1)]],
r"model.layers.\d+.input_layernorm.output": [[Shard(1)]],
r"model.layers.\d+.self_attn.input": [[Replicate()]],
r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
# No SP
# r"layers.\d+.post_attention_layernorm.input": [[Replicate()]],
# r"layers.\d+.post_attention_layernorm.output": [[Replicate()]],
# SP
r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
r"layers.\d+.block_sparse_moe.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
"norm.input": [[Replicate()]],
r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
"model.norm.input": [[Replicate()]],
}

mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
33 changes: 21 additions & 12 deletions examples/mixtral_4D_training/mixtral_train.py
@@ -37,6 +37,20 @@
from data_loader import DataLoader


class Net(torch.nn.Module):
    def __init__(self, mixtral_config):
        super().__init__()
        self.mixtral_model = MixtralForCausalLM(mixtral_config)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, labels):
        logits = self.mixtral_model(input_ids).logits
        logits = logits.flatten(end_dim=-2)
        labels = labels.flatten()
        loss = self.loss_fn(logits, labels)
        return loss


def estimate_mixtral(config, bsz, sqence_length):
embed = 4 * bsz * sqence_length * config.hidden_size
# MixtralMoE consists of 3 linear layers.
@@ -57,7 +71,7 @@ def run_mixtral(args):
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
device = f"cuda:{rank}"
device = f"cuda:{local_rank}"
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"])
@@ -90,7 +104,7 @@ def run_mixtral(args):
)

if world_size > 1:
model = MixtralForCausalLM(mixtral_config)
model = Net(mixtral_config)
model.to(ptdtype)

model = parallelize_module(
@@ -104,11 +118,11 @@
model,
VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=False,
use_distributed_optimizer=True,
whitelist_module_types=[MixtralSparseMoeBlock],
use_distributed_optimizer=args.use_DO,
module_to_enforce=[MixtralSparseMoeBlock],
)
else:
model = MixtralForCausalLM(mixtral_config).to(device)
model = Net(mixtral_config).to(device)
model.to(ptdtype)
print(f"rank {rank} cuda.rng_state {torch.cuda.get_rng_state().view(torch.int64)}")

@@ -170,7 +184,7 @@ def estimate_loss():
losses = torch.zeros(args.eval_iters // factor).to(device)
for k in range(args.eval_iters // factor):
X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp)
loss = model(X, labels=Y).loss
loss = model(X, Y)
if world_size > 1:
losses[k] = loss.to_local().item()
else:
@@ -203,7 +217,7 @@ def estimate_loss():
start_epoch.record()
if world_size > 1:
model.zero_grad_buffer()
loss = model(X, labels=Y).loss
loss = model(X, Y)
loss.backward()
grad_norm = -1
if world_size == 1 and args.grad_clip > 0:
@@ -274,11 +288,6 @@ def parse_args():
parser.add_argument("--num_hidden_layers", type=int, default=2)
parser.add_argument("--num_attention_heads", type=int, default=8)
parser.add_argument("--num_key_value_heads", type=int, default=8)
# parser.add_argument("--hidden_size", type=int, default=4096)
# parser.add_argument("--intermediate_size", type=int, default=14336)
# parser.add_argument("--num_hidden_layers", type=int, default=16)
# parser.add_argument("--num_attention_heads", type=int, default=32)
# parser.add_argument("--num_key_value_heads", type=int, default=8)

# Optimizer related
parser.add_argument("--use_DO", type=bool, default=True)
65 changes: 36 additions & 29 deletions examples/mixtral_4D_training/sharding_plan.py
@@ -21,49 +21,56 @@


param_sharding_plan = {
"model.embed_tokens.weight": [Replicate()],
r"model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
"mixtral_model.model.embed_tokens.weight": [Replicate()],
r"mixtral_model.model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
r"mixtral_model.model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
r"mixtral_model.model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
r"mixtral_model.model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
# TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
"model.norm.weight": [Replicate()],
r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
r"mixtral_model.model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
r"mixtral_model.model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
"mixtral_model.model.norm.weight": [Replicate()],
}

fwd_resharding_plan = {
# TODO: buggy: attn mask is torch.Tensor, in training, it's a None
r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
"model.embed_tokens.input": [[Replicate()]],
"mixtral_model.model.embed_tokens.input": [[Replicate()]],
# No SP
# r"layers.\d+.input_layernorm.input": [[Replicate()]],
# r"layers.\d+.input_layernorm.output": [[Replicate()]],
# SP
r"model.layers.\d+.input_layernorm.input": [[Shard(1)]],
r"model.layers.\d+.input_layernorm.output": [[Shard(1)]],
r"model.layers.\d+.self_attn.input": [[Replicate()]],
r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
r"mixtral_model.model.layers.\d+.input_layernorm.input": [[Shard(1)]],
r"mixtral_model.model.layers.\d+.input_layernorm.output": [[Shard(1)]],
r"mixtral_model.model.layers.\d+.self_attn.input": [[Replicate()]],
r"mixtral_model.model.layers.\d+.self_attn.output": {
"attn_output": [Replicate()],
"attn_weights": None,
"past_key_value": None,
},
r"mixtral_model.model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
# No SP
# r"model.layers.\d+.post_attention_layernorm.input": [[Replicate()]],
# r"model.layers.\d+.post_attention_layernorm.output": [[Replicate()]],
# SP
r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
"model.norm.input": [[Replicate()]],
r"mixtral_model.model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
r"mixtral_model.model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
r"mixtral_model.model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
r"mixtral_model.model.layers.\d+.block_sparse_moe.output": {
"final_hidden_states": [Replicate()],
"router_logits": [Replicate()],
},
r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
"mixtral_model.model.norm.input": [[Replicate()]],
}

mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}