diff --git a/examples/llama2_4D_finetune/llama_train.py b/examples/llama2_4D_finetune/llama_train.py index 84092ee..839fa43 100644 --- a/examples/llama2_4D_finetune/llama_train.py +++ b/examples/llama2_4D_finetune/llama_train.py @@ -37,6 +37,20 @@ from data_loader import DataLoader +class Net(torch.nn.Module): + def __init__(self, path, torch_dtype): + super().__init__() + self.llama_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch_dtype) + self.loss_fn = torch.nn.CrossEntropyLoss() + + def forward(self, input_ids, labels): + logits = self.llama_model(input_ids).logits + logits = logits.flatten(end_dim=-2) + labels = labels.flatten() + loss = self.loss_fn(logits, labels) + return loss + + def estimate_llama2(config, bsz, sqence_length): embed = 4 * bsz * sqence_length * config.hidden_size ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length @@ -53,7 +67,7 @@ def run_llama2(args): local_rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) - device = f"cuda:{rank}" + device = f"cuda:{local_rank}" torch.cuda.set_device(device) dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"]) @@ -77,8 +91,8 @@ def run_llama2(args): "bfloat16": torch.bfloat16, }[args.dtype] - model = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b", torch_dtype=ptdtype) - llama_config = model.config + model = Net("openlm-research/open_llama_3b", torch_dtype=ptdtype) + llama_config = model.llama_model.config if rank == 0: print(model) print(llama_config) @@ -165,7 +179,7 @@ def estimate_loss(): losses = torch.zeros(args.eval_iters // factor).to(device) for k in range(args.eval_iters // factor): X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp) - loss = model(X, labels=Y).loss + loss = model(X, Y) if world_size > 1: losses[k] = loss.to_local().item() else: @@ -198,7 +212,7 @@ def estimate_loss(): start_epoch.record() if world_size > 1: model.zero_grad_buffer() - loss = model(X, labels=Y).loss + loss = model(X, Y) loss.backward() grad_norm = -1 if world_size == 1 and args.grad_clip > 0: diff --git a/examples/llama2_4D_finetune/sharding_plan.py b/examples/llama2_4D_finetune/sharding_plan.py index 89b0653..d66b1af 100644 --- a/examples/llama2_4D_finetune/sharding_plan.py +++ b/examples/llama2_4D_finetune/sharding_plan.py @@ -45,19 +45,19 @@ # forward resharding plan for the whole open llama model model_fwd_resharding_plan = { - "model.input": [[Replicate()]], - "model.embed_tokens.output": [[Shard(1)]], - "model.norm.input": [[Shard(1)]], - "model.output": { + "llama_model.model.input": [[Replicate()]], + "llama_model.model.embed_tokens.output": [[Shard(1)]], + "llama_model.model.norm.input": [[Shard(1)]], + "llama_model.model.output": { "last_hidden_state": [Replicate()], }, - **{rf"model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()}, + **{rf"llama_model.model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()}, } # model parameter sharding plan for the whole open llama model model_param_sharding_plan = { - "model.embed_tokens.weight": [Shard(1)], - **{rf"model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()}, + "llama_model.model.embed_tokens.weight": [Shard(1)], + **{rf"llama_model.model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()}, } llama2_plan = {"parameter": 
model_param_sharding_plan, "forward": model_fwd_resharding_plan} diff --git a/examples/mixtral_4D_benchmark/README.md b/examples/mixtral_4D_benchmark/README.md index 7425f1a..f14abab 100644 --- a/examples/mixtral_4D_benchmark/README.md +++ b/examples/mixtral_4D_benchmark/README.md @@ -11,7 +11,7 @@ from HuggingFace without any model code modifications. ### Single Machine 8 cards ``` -torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16 +torchrun --nproc-per-node=8 --standalone examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16 ``` This will start a 8-cards MFU benchmark for Mixtral with veScale with dp=1 and tp=8. diff --git a/examples/mixtral_4D_benchmark/mixtral_train.py b/examples/mixtral_4D_benchmark/mixtral_train.py index 32f77cf..164dfad 100644 --- a/examples/mixtral_4D_benchmark/mixtral_train.py +++ b/examples/mixtral_4D_benchmark/mixtral_train.py @@ -27,7 +27,7 @@ from vescale.optim.distributed_optimizer import DistributedOptimizer from vescale.initialize.deferred_init import deferred_init, is_deferred -from transformers.models.mixtral.modeling_mixtral import MixtralModel +from transformers.models.mixtral.modeling_mixtral import MixtralModel, MixtralSparseMoeBlock from transformers.models.mixtral.configuration_mixtral import MixtralConfig from sharding_plan import mixtral_plan @@ -84,7 +84,7 @@ def run_mixtral(args): accumulate_allreduce_grads_in_fp32=True, overlap_grad_reduce=False, use_distributed_optimizer=True, - whitelist_module_types=[MixtralSparseMoeBlock], + module_to_enforce=[MixtralSparseMoeBlock], ) doptim = DistributedOptimizer( diff --git a/examples/mixtral_4D_benchmark/sharding_plan.py b/examples/mixtral_4D_benchmark/sharding_plan.py index b8ae79e..8f0828c 100644 --- a/examples/mixtral_4D_benchmark/sharding_plan.py +++ b/examples/mixtral_4D_benchmark/sharding_plan.py @@ -21,49 +21,49 @@ param_sharding_plan = { - "embed_tokens.weight": [Replicate()], - r"layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm - r"layers.\d+.self_attn.q_proj.weight": [Shard(0)], - r"layers.\d+.self_attn.k_proj.weight": [Shard(0)], - r"layers.\d+.self_attn.v_proj.weight": [Shard(0)], + "model.embed_tokens.weight": [Replicate()], + r"model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)], + r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)], + r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)], # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. 
- r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], - r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], - r"layers.\d+.self_attn.o_proj.weight": [Shard(1)], - r"layers.\d+.post_attention_layernorm.weight": [Replicate()], - r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()], - r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], - r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], - r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], - "norm.weight": [Replicate()], + r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], + r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], + r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)], + r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()], + r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + "model.norm.weight": [Replicate()], } fwd_resharding_plan = { # TODO: buggy: attn mask is torch.Tensor, in training, it's a None r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]}, - "embed_tokens.input": [[Replicate()]], + "model.embed_tokens.input": [[Replicate()]], # No SP # r"layers.\d+.input_layernorm.input": [[Replicate()]], # r"layers.\d+.input_layernorm.output": [[Replicate()]], # SP - r"layers.\d+.input_layernorm.input": [[Shard(1)]], - r"layers.\d+.input_layernorm.output": [[Shard(1)]], - r"layers.\d+.self_attn.input": [[Replicate()]], - r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, - r"layers.\d+.self_attn.o_proj.output": [[Replicate()]], + r"model.layers.\d+.input_layernorm.input": [[Shard(1)]], + r"model.layers.\d+.input_layernorm.output": [[Shard(1)]], + r"model.layers.\d+.self_attn.input": [[Replicate()]], + r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, + r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]], # No SP # r"layers.\d+.post_attention_layernorm.input": [[Replicate()]], # r"layers.\d+.post_attention_layernorm.output": [[Replicate()]], # SP - r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]], - r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]], - r"layers.\d+.block_sparse_moe.input": [[Replicate()]], - r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], - r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, - r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], - r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], - r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], - "norm.input": [[Replicate()]], + r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]], + r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]], + r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, + r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], + r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], + 
r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], + "model.norm.input": [[Replicate()]], } mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/examples/mixtral_4D_training/mixtral_train.py b/examples/mixtral_4D_training/mixtral_train.py index cb33761..b0bf840 100644 --- a/examples/mixtral_4D_training/mixtral_train.py +++ b/examples/mixtral_4D_training/mixtral_train.py @@ -37,6 +37,20 @@ from data_loader import DataLoader +class Net(torch.nn.Module): + def __init__(self, mixtral_config): + super().__init__() + self.mixtral_model = MixtralForCausalLM(mixtral_config) + self.loss_fn = torch.nn.CrossEntropyLoss() + + def forward(self, input_ids, labels): + logits = self.mixtral_model(input_ids).logits + logits = logits.flatten(end_dim=-2) + labels = labels.flatten() + loss = self.loss_fn(logits, labels) + return loss + + def estimate_mixtral(config, bsz, sqence_length): embed = 4 * bsz * sqence_length * config.hidden_size # MixtralMoE consists of 3 linear layers. @@ -57,7 +71,7 @@ def run_mixtral(args): local_rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) - device = f"cuda:{rank}" + device = f"cuda:{local_rank}" torch.cuda.set_device(device) dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"]) @@ -90,7 +104,7 @@ def run_mixtral(args): ) if world_size > 1: - model = MixtralForCausalLM(mixtral_config) + model = Net(mixtral_config) model.to(ptdtype) model = parallelize_module( @@ -104,11 +118,11 @@ def run_mixtral(args): model, VESCALE_DEVICE_MESH["DP"], accumulate_allreduce_grads_in_fp32=False, - use_distributed_optimizer=True, - whitelist_module_types=[MixtralSparseMoeBlock], + use_distributed_optimizer=args.use_DO, + module_to_enforce=[MixtralSparseMoeBlock], ) else: - model = MixtralForCausalLM(mixtral_config).to(device) + model = Net(mixtral_config).to(device) model.to(ptdtype) print(f"rank {rank} cuda.rng_state {torch.cuda.get_rng_state().view(torch.int64)}") @@ -170,7 +184,7 @@ def estimate_loss(): losses = torch.zeros(args.eval_iters // factor).to(device) for k in range(args.eval_iters // factor): X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp) - loss = model(X, labels=Y).loss + loss = model(X, Y) if world_size > 1: losses[k] = loss.to_local().item() else: @@ -203,7 +217,7 @@ def estimate_loss(): start_epoch.record() if world_size > 1: model.zero_grad_buffer() - loss = model(X, labels=Y).loss + loss = model(X, Y) loss.backward() grad_norm = -1 if world_size == 1 and args.grad_clip > 0: @@ -274,11 +288,6 @@ def parse_args(): parser.add_argument("--num_hidden_layers", type=int, default=2) parser.add_argument("--num_attention_heads", type=int, default=8) parser.add_argument("--num_key_value_heads", type=int, default=8) - # parser.add_argument("--hidden_size", type=int, default=4096) - # parser.add_argument("--intermediate_size", type=int, default=14336) - # parser.add_argument("--num_hidden_layers", type=int, default=16) - # parser.add_argument("--num_attention_heads", type=int, default=32) - # parser.add_argument("--num_key_value_heads", type=int, default=8) # Optimizer related parser.add_argument("--use_DO", type=bool, default=True) diff --git a/examples/mixtral_4D_training/sharding_plan.py b/examples/mixtral_4D_training/sharding_plan.py index 523827d..5f1d9d2 100644 --- 
a/examples/mixtral_4D_training/sharding_plan.py +++ b/examples/mixtral_4D_training/sharding_plan.py @@ -21,49 +21,56 @@ param_sharding_plan = { - "model.embed_tokens.weight": [Replicate()], - r"model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm - r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)], - r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)], - r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)], + "mixtral_model.model.embed_tokens.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"mixtral_model.model.layers.\d+.self_attn.q_proj.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.self_attn.k_proj.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.self_attn.v_proj.weight": [Shard(0)], # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. - r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], - r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], - r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)], - r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()], - r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()], - r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], - r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], - r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], - "model.norm.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], + r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], + r"mixtral_model.model.layers.\d+.self_attn.o_proj.weight": [Shard(1)], + r"mixtral_model.model.layers.\d+.post_attention_layernorm.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + "mixtral_model.model.norm.weight": [Replicate()], } fwd_resharding_plan = { # TODO: buggy: attn mask is torch.Tensor, in training, it's a None r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]}, - "model.embed_tokens.input": [[Replicate()]], + "mixtral_model.model.embed_tokens.input": [[Replicate()]], # No SP # r"layers.\d+.input_layernorm.input": [[Replicate()]], # r"layers.\d+.input_layernorm.output": [[Replicate()]], # SP - r"model.layers.\d+.input_layernorm.input": [[Shard(1)]], - r"model.layers.\d+.input_layernorm.output": [[Shard(1)]], - r"model.layers.\d+.self_attn.input": [[Replicate()]], - r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, - r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]], + r"mixtral_model.model.layers.\d+.input_layernorm.input": [[Shard(1)]], + r"mixtral_model.model.layers.\d+.input_layernorm.output": [[Shard(1)]], + r"mixtral_model.model.layers.\d+.self_attn.input": [[Replicate()]], + r"mixtral_model.model.layers.\d+.self_attn.output": { + "attn_output": [Replicate()], + "attn_weights": None, + "past_key_value": None, + }, + r"mixtral_model.model.layers.\d+.self_attn.o_proj.output": [[Replicate()]], # No SP # r"model.layers.\d+.post_attention_layernorm.input": [[Replicate()]], 
# r"model.layers.\d+.post_attention_layernorm.output": [[Replicate()]], # SP - r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]], - r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]], - r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]], - r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], - r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, - r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], - r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], - r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], - "model.norm.input": [[Replicate()]], + r"mixtral_model.model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]], + r"mixtral_model.model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]], + r"mixtral_model.model.layers.\d+.block_sparse_moe.input": [[Replicate()]], + r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], + r"mixtral_model.model.layers.\d+.block_sparse_moe.output": { + "final_hidden_states": [Replicate()], + "router_logits": [Replicate()], + }, + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], + "mixtral_model.model.norm.input": [[Replicate()]], } mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/examples/mixtral_EP_training/README.md b/examples/mixtral_EP_training/README.md new file mode 100644 index 0000000..f8ad193 --- /dev/null +++ b/examples/mixtral_EP_training/README.md @@ -0,0 +1,56 @@ +# Expert Parallelism in veScale + +## Overview + +veScale provides an efficient framework for training Mixture of Experts (MoE) models using expert parallelism. Expert parallelism can be deployed with the `parallelize_experts()` function, which simplifies the process of distributing and managing workload during MoE training. + +### Function Signature + +```python +model = parallelize_experts( + module: nn.Module, + experts_expr: Union[str, List[str]], + experts_allocator: vescale.moe.ExpertsAllocator, + token_dispatcher: vescale.moe.TokenDispatcher, + config: Dict, +) +``` + +### Parameters +- **`module`**: The training model (an instance of `nn.Module`) to be parallelized. +- **`experts_expr`**: Specifies the paths to the expert modules. Can be a string or a list of strings. +- **`experts_allocator`**: An instance of `ExpertsAllocator`, used for managing expert parameter allocation. +- **`token_dispatcher`**: An instance of `TokenDispatcher`, responsible for token scheduling and distribution. +- **`config`**: A dictionary containing the MoE training configuration, including layer count, number of experts, and other relevant settings. + + +## Custom Scheduling + +veScale allows users to define custom scheduling strategies for expert parallelism by implementing the following components: + +- **`ExpertsAllocator`**: Manages expert parameter allocation. It can use `collect_performance()` to profile and dynamically adjust the DP x TP device mesh for each expert. By default, veScale shards all expert parameters across devices using tensor parallelism. + +- **`TokenDispatcher`**: Handles token distribution. 
Using `assign_task()`, it determines workload allocation (e.g., expert IDs and token weights) and adjusts scheduling with `collect_performance()`. The default implementation randomly assigns tokens to a single DP rank for the selected expert. + +## Optimizer Support + +Since veScale supports dynamic placement of expert parameters, a dedicated optimizer, `MoEOptimizer`, is required. This optimizer handles the redistribution of expert parameters and their states efficiently. +Future updates will integrate these functionalities into optimizers for static parameters to streamline the process. + + +## Getting Started + +### Data Preparation +Prepare the Shakespeare dataset by running: + +```bash +cd data/shakespeare/ +python3 prepare.py +cd ../.. +``` + +### Training Command + +``` +torchrun --standalone --nproc_per_node={GPU_CNT} mixtral_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters} +``` diff --git a/examples/mixtral_EP_training/data/shakespeare/prepare.py b/examples/mixtral_EP_training/data/shakespeare/prepare.py new file mode 100644 index 0000000..60e56e5 --- /dev/null +++ b/examples/mixtral_EP_training/data/shakespeare/prepare.py @@ -0,0 +1,54 @@ +################################################################################ +# Copyright (c) 2022 Andrej Karpathy + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
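For the `ExpertsAllocator` / `TokenDispatcher` extension points described in the Custom Scheduling section of the new README above, a minimal sketch is shown below. The class names `MyExpertsAllocator` / `MyTokenDispatcher`, the method argument lists, and the config values are assumptions for illustration only; only `parallelize_experts`, the two base classes, the `experts_expr` regex, and the config keys come from the patch itself.

```python
# Illustrative sketch only: subclass signatures below are assumed, not the
# actual vescale.moe interface. `model` is the user's MoE model (see the
# parallelize_experts signature in the README above).
from vescale.moe import parallelize_experts, ExpertsAllocator, TokenDispatcher


class MyExpertsAllocator(ExpertsAllocator):
    def collect_performance(self, perf_stats):
        # Inspect profiling results and (optionally) re-plan the DP x TP
        # device mesh used for each expert's parameters.
        ...


class MyTokenDispatcher(TokenDispatcher):
    def assign_task(self, expert_ids, token_weights):
        # Decide which DP rank(s) process which tokens for the selected experts,
        # instead of the default random single-rank assignment.
        ...

    def collect_performance(self, perf_stats):
        # Adjust the dispatch policy based on observed load imbalance.
        ...


model = parallelize_experts(
    module=model,
    experts_expr=r"mixtral_model.model.layers.\d+.block_sparse_moe.experts",
    experts_allocator=MyExpertsAllocator(),
    token_dispatcher=MyTokenDispatcher(),
    config={"num_layers": 2, "num_experts": 8, "num_devices": 8},  # example values
)
```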
+################################################################################ +import os +import requests +import tiktoken +import numpy as np + +# download the tiny shakespeare dataset +input_file_path = os.path.join(os.path.dirname(__file__), "input.txt") +if not os.path.exists(input_file_path): + data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" + with open(input_file_path, "w", encoding="utf-8") as f: + f.write(requests.get(data_url).text) + +with open(input_file_path, encoding="utf-8") as f: + data = f.read() +n = len(data) +train_data = data[: int(n * 0.9)] +val_data = data[int(n * 0.9) :] + +# encode with tiktoken gpt2 bpe +enc = tiktoken.get_encoding("gpt2") +train_ids = enc.encode_ordinary(train_data) +val_ids = enc.encode_ordinary(val_data) +print(f"train has {len(train_ids):,} tokens") +print(f"val has {len(val_ids):,} tokens") + +# export to bin files +train_ids = np.array(train_ids, dtype=np.uint16) +val_ids = np.array(val_ids, dtype=np.uint16) +train_ids.tofile(os.path.join(os.path.dirname(__file__), "train.bin")) +val_ids.tofile(os.path.join(os.path.dirname(__file__), "val.bin")) + +# train.bin has 301,966 tokens +# val.bin has 36,059 tokens diff --git a/examples/mixtral_EP_training/data/shakespeare/readme.md b/examples/mixtral_EP_training/data/shakespeare/readme.md new file mode 100644 index 0000000..1e6c457 --- /dev/null +++ b/examples/mixtral_EP_training/data/shakespeare/readme.md @@ -0,0 +1,9 @@ + +# tiny shakespeare + +Tiny shakespeare, of the good old char-rnn fame :) + +After running `prepare.py`: + +- train.bin has 301,966 tokens +- val.bin has 36,059 tokens diff --git a/examples/mixtral_EP_training/data_loader.py b/examples/mixtral_EP_training/data_loader.py new file mode 100644 index 0000000..f83c582 --- /dev/null +++ b/examples/mixtral_EP_training/data_loader.py @@ -0,0 +1,64 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import os +from typing import Optional + +import numpy as np +import torch + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale import distribute_tensor +from vescale.dtensor.placement_types import Replicate +from vescale.dtensor import empty as d_empty + + +class DataLoader: + def __init__(self, dataset: str, seqlen: int, mesh: Optional[DeviceMesh] = None, dp_rank: int = 0): + self.data_dir = os.path.join("data", dataset) + self.seqlen = seqlen + self.mesh = mesh + self.dp_rank = dp_rank + if mesh is not None: + self.device_type = mesh.device_type + else: + self.device_type = "cuda" + + def get_batch(self, split, bsz, lbsz): + # We recreate np.memmap every batch to avoid a memory leak, as per + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + if split == "train": + data = np.memmap(os.path.join(self.data_dir, "train.bin"), dtype=np.uint16, mode="r") + else: + data = np.memmap(os.path.join(self.data_dir, "val.bin"), dtype=np.uint16, mode="r") + if self.mesh is not None: + ix = d_empty((bsz,), device_mesh=self.mesh, placements=[Replicate()]) + else: + ix = torch.empty((bsz,), device="cuda") + ix = torch.randint_like(ix, len(data) - self.seqlen, dtype=torch.int64) + if self.mesh is not None: + ix = ix.to_local() + if self.mesh is None or self.mesh.get_rank() == 0: + print(f"sum(ix) {sum(ix)}") + ix = torch.split(ix, lbsz)[self.dp_rank] + x = torch.stack([torch.from_numpy((data[i : i + self.seqlen]).astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + self.seqlen]).astype(np.int64)) for i in ix]) + x, y = x.to(self.device_type), y.to(self.device_type) + if self.mesh is not None: + x = distribute_tensor(x, self.mesh["TP"], [Replicate()]) + y = distribute_tensor(y, self.mesh["TP"], [Replicate()]) + return x, y diff --git a/examples/mixtral_EP_training/exp.py b/examples/mixtral_EP_training/exp.py new file mode 100644 index 0000000..e2545ff --- /dev/null +++ b/examples/mixtral_EP_training/exp.py @@ -0,0 +1,95 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import os + + +def parse_train_loss(log_fn, name=None): + lines = open(log_fn).readlines() + train_losses = [] + for line in lines: + if "loss" in line and "iter" in line: + token = line.split()[line.split().index("loss") + 1] + train_loss = float(token) + train_losses.append(train_loss) + if name is None: + name = log_fn + print(f'"{name}": {train_losses},') + + +def parse_grad_norm(log_fn, name=None): + lines = open(log_fn).readlines() + grad_norms = [] + for line in lines: + if "|g|" in line: + token = line.split()[line.split().index("|g|") + 1] + grad_norm = float(token) + grad_norms.append(grad_norm) + if name is None: + name = log_fn + print(f'"{name}": {grad_norms},') + + +GPU_CNT = 4 +DP_SIZES = [1, 2] +SINGLE_GPU_RUN = "python3" +MULTI_GPU_RUN = f"torchrun --standalone --nproc_per_node={GPU_CNT}" +CODE = "mixtral_train.py" +LOG_PREFIX = "mixtral_new_MOE" +TRAIN_BIN_PATH = "data/shakespeare/train.bin" + + +def run_exps(max_iters, dtypes, run=True): + if not os.path.isfile(TRAIN_BIN_PATH): + os.system("cd data/shakespeare/ && python3 prepare.py && cd ../..") + os.makedirs("logs", exist_ok=True) + if run: + for dtype in dtypes: + dt = "bfloat16" if dtype == "bf16" else "float32" + cmd = f"{SINGLE_GPU_RUN} {CODE} --dp=1 --tp=1 --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log" + print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + dt = "bfloat16" if dtype == "bf16" else "float32" + cmd = f"{MULTI_GPU_RUN} {CODE} --dp={dp_size} --tp={tp_size} --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + print(f"run {cmd} > {log_fn} 2> {log_fn}.err") + os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") + + print("train_loss = {") + for dtype in dtypes: + parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + print("}") + + print("grad_norm = {") + for dtype in dtypes: + parse_grad_norm(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_grad_norm(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}") + print("}") + + +if __name__ == "__main__": + run_exps(1000, ["bf16"], run=True) diff --git a/examples/mixtral_EP_training/mixtral_train.py b/examples/mixtral_EP_training/mixtral_train.py new file mode 100644 index 0000000..09ef5d0 --- /dev/null +++ b/examples/mixtral_EP_training/mixtral_train.py @@ -0,0 +1,404 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import argparse +import os +import math +import inspect + +import torch +import torch.distributed as dist + +from vescale.dmodule import parallelize_module +from vescale.dtensor.placement_types import InterleavedShard +from vescale.moe import parallelize_experts, MoEOptimizer +from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP +from vescale.optim.distributed_optimizer import DistributedOptimizer +from vescale.optim.base_optimizer import BasicOptimizer +from vescale.devicemesh_api import VESCALE_DEVICE_MESH +from vescale.dtensor.random import manual_seed +from vescale import DTensor + +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralSparseMoeBlock, MixtralModel +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from sharding_plan import mixtral_plan + +from data_loader import DataLoader + + +class Net(torch.nn.Module): + def __init__(self, mixtral_config): + super().__init__() + self.mixtral_model = MixtralForCausalLM(mixtral_config) + self.loss_fn = torch.nn.CrossEntropyLoss() + + def forward(self, input_ids, labels): + logits = self.mixtral_model(input_ids).logits + logits = logits.flatten(end_dim=-2) + labels = labels.flatten() + loss = self.loss_fn(logits, labels) + return loss + + +def wrap_moe_block(forward_func): + old_func_dict = {} + + def _pre_forward_overload(): + nonlocal old_func_dict + old_func_dict = {} + old_func_dict["where"] = torch.where + old_func_dict["index_select"] = DTensor.__getitem__ + + def local_where(*args, **kwargs): + output = old_func_dict["where"](*args, **kwargs) + if isinstance(output, DTensor): + return output.to_local() + elif isinstance(output, torch.Tensor): + return output + elif isinstance(output, tuple): + output_list = [] + for t in output: + if isinstance(t, DTensor): + output_list.append(t.to_local()) + else: + output_list.append(t) + return tuple(output_list) + else: + raise NotImplementedError + + def local_index_select(*args, **kwargs): + return old_func_dict["index_select"](args[0].to_local(), *args[1:], **kwargs) + + torch.where = local_where + DTensor.__getitem__ = local_index_select + + def _post_forward_overload(): + nonlocal old_func_dict + torch.where = old_func_dict["where"] + DTensor.__getitem__ = old_func_dict["index_select"] + + def forward(*args, **kwargs): + _pre_forward_overload() + output = forward_func(*args, **kwargs) + _post_forward_overload() + return output + + return forward + + +def estimate_mixtral(config, bsz, sqence_length): + embed = 4 * bsz * sqence_length * config.hidden_size + # MixtralMoE consists of 3 linear layers. 
+ ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sqence_length + # GQA + head_size = config.hidden_size // config.num_attention_heads + attn_q = 2 * bsz * sqence_length * config.hidden_size * config.hidden_size + attn_kv = 2 * 2 * bsz * sqence_length * config.hidden_size * config.num_key_value_heads * head_size + attn_mask = 2 * sqence_length * config.hidden_size + attn_proj = 2 * config.hidden_size * config.hidden_size * bsz * sqence_length + attn = attn_q + attn_kv + attn_mask + attn_proj + return embed + (ff + attn) * config.num_hidden_layers + + +def run_mixtral(args): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + if world_size > 1: + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + device = f"cuda:{local_rank}" + torch.cuda.set_device(device) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"]) + device_mesh = VESCALE_DEVICE_MESH.get() + dp_rank = dist.get_rank() // args.tp + torch.random.manual_seed(0) + torch.cuda.random.manual_seed_all(0) + manual_seed(0, device_mesh) + else: + local_rank = 0 + rank = 0 + device = f"cuda:{0}" + device_mesh = None + torch.cuda.set_device(device) + dp_rank = 0 + torch.random.manual_seed(0) + torch.cuda.random.manual_seed_all(0) + ptdtype = { + "float32": torch.float, + "bfloat16": torch.bfloat16, + }[args.dtype] + + mixtral_config = MixtralConfig( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + num_hidden_layers=args.num_hidden_layers, + num_attention_heads=args.num_attention_heads, + num_key_value_heads=args.num_key_value_heads, + ) + + if world_size > 1: + model = Net(mixtral_config) + model.to(ptdtype) + + experts_module_name = r"mixtral_model.model.layers.\d+.block_sparse_moe.experts" + + factory = { + MixtralSparseMoeBlock: {torch.zeros: [InterleavedShard(0, args.bsz // args.dp)]}, + MixtralModel: True, + } + MixtralSparseMoeBlock.forward = wrap_moe_block(MixtralSparseMoeBlock.forward) + + model = parallelize_module( + model, + VESCALE_DEVICE_MESH["TP"], + mixtral_plan, + factory=factory, + ) + + param_to_ignore = [param_name for param_name, _ in model.named_parameters() if "experts" in param_name] + + model = DDP( + model, + VESCALE_DEVICE_MESH["DP"], + accumulate_allreduce_grads_in_fp32=False, + use_distributed_optimizer=args.use_DO, + param_to_ignore=param_to_ignore, + ) + + moe_config = { + "num_layers": mixtral_config.num_hidden_layers, + "num_experts": mixtral_config.num_local_experts, + "num_devices": torch.distributed.get_world_size(), + } + + model = parallelize_experts( + model, + experts_module_name, + config=moe_config, + ) + else: + model = Net(mixtral_config).to(device) + model.to(ptdtype) + print(f"rank {rank} cuda.rng_state {torch.cuda.get_rng_state().view(torch.int64)}") + + def configure_optimizers(model, weight_decay, learning_rate, betas): + # filter out those that do not require grad + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2 and "experts" not in n] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2 and "experts" not in n] + moe_params = [p for n, p in param_dict.items() if "experts" in n] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + num_moe_params = sum(p.numel() for p in moe_params) + # Create AdamW optimizer and use the fused version if it is available + fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and (world_size == 1 or device_mesh.device_type == "cuda") + extra_args = dict(fused=True) if use_fused else dict() + base_optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + if world_size == 1 or dist.get_rank() == 0: + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + print(f"experts parameter tensors: {len(moe_params)}, with {num_moe_params:,} parameters") + print(f"using fused AdamW: {use_fused}") + # + + + Initialize a ZeRO-2 optimizer using veScale API + if args.use_DO and world_size > 1: + optimizer = DistributedOptimizer( + base_optimizer, + models=[model], + clip_grad=args.grad_clip, + grad_to_fp32=False, + ) + moe_optimizer = MoEOptimizer( + torch.optim.AdamW, + clip_grad=args.grad_clip, + param_buffer=model.moe_param_buffer, + lr=learning_rate, + betas=betas, + weight_decay=weight_decay, + **extra_args, + ) + elif world_size > 1: + optimizer = BasicOptimizer(base_optimizer, models=model) + moe_optimizer = MoEOptimizer( + torch.optim.AdamW, + clip_grad=args.grad_clip, + param_buffer=model.moe_param_buffer, + lr=learning_rate, + betas=betas, + weight_decay=weight_decay, + **extra_args, + ) + else: + optimizer = base_optimizer + moe_optim_groups = [ + {"params": moe_params, "weight_decay": weight_decay}, + ] + moe_optimizer = torch.optim.AdamW(moe_optim_groups, lr=learning_rate, betas=betas, **extra_args) + return optimizer, moe_optimizer + + # TODO: wrap up them into a single optimizer + doptimizer, moe_optimizer = configure_optimizers(model, args.weight_decay, args.lr, (0.9, 0.95)) + + # learning rate decay scheduler (cosine with warmup) + def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < args.warmup_iters: + return args.lr * it / args.warmup_iters + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.warmup_iters) / (args.max_iters - args.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return args.min_lr + coeff * (args.lr - args.min_lr) + + @torch.no_grad() + def estimate_loss(): + out = {} + model.eval() + for split in ["train", "val"]: + factor = 1 + losses = torch.zeros(args.eval_iters // factor).to(device) + for k in range(args.eval_iters // factor): + X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp) + loss = model(X, Y) + if world_size > 1: + losses[k] = loss.to_local().item() + else: + losses[k] = loss.item() + if world_size > 1: + dist.all_reduce(losses) + out[split] = losses.mean() / world_size + model.train() + return out + + data_loader = DataLoader(args.dataset, args.seqlen, 
device_mesh, dp_rank) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + model.train() + for iter in range(args.max_iters): + # determine and set the learning rate for this iteration + lr = get_lr(iter) if args.decay_lr else args.lr + for param_group in doptimizer.param_groups if world_size == 1 else doptimizer.optimizer.param_groups: + param_group["lr"] = lr + for param_group in moe_optimizer.param_groups: + param_group["lr"] = lr + # load a batch of training data + X, Y = data_loader.get_batch("train", args.bsz, args.bsz // args.dp) + + start_epoch = torch.cuda.Event(enable_timing=True) + end_epoch = torch.cuda.Event(enable_timing=True) + start_epoch.record() + if world_size > 1: + model.zero_grad_buffer() + loss = model(X, Y) + loss.backward() + grad_norm = -1 + if world_size == 1 and args.grad_clip > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + if world_size > 1: + model.finish_grad_sync() + if world_size > 1 and args.grad_clip > 0: + grad_norm = doptimizer.step() + moe_optimizer.step() + else: + doptimizer.step() + moe_optimizer.step() + doptimizer.zero_grad(set_to_none=True) + moe_optimizer.zero_grad(set_to_none=True) + end_epoch.record() + torch.cuda.synchronize() + epoch_t = start_epoch.elapsed_time(end_epoch) + if world_size > 1: + loss_val = loss.to_local() + dist.all_reduce(loss_val) + loss_val = loss_val.item() / world_size + else: + loss_val = loss.item() + if world_size == 1 or dist.get_rank() == 0: + print(f"iter {iter} loss {loss_val:.6f} |g| {grad_norm:.6f} lr {lr:.6f} fwd/bwd_t {epoch_t:.2f}ms") + end.record() + torch.cuda.synchronize() + exec_t = start.elapsed_time(end) / 1000 / args.max_iters + # masure mfu + if rank == 0: + total_flops = { + "A100": { + "bfloat16": 312 * (10**12), + "float32": 19.5 * (10**12), + }, + "H100": { + "bfloat16": 1000 * (10**12), + "float32": 312 * (10**12), + }, + }["A100"][args.dtype] + if world_size > 1: + total_flops *= world_size + print(f"1 iter time: {exec_t}") + mixtral_flops = estimate_mixtral(mixtral_config, args.bsz, args.seqlen) + print(f"fwd llama2 flops: {mixtral_flops}") + # bwd ~= fwd * 2 + print("mfu:", mixtral_flops * 3 * 100 / exec_t / total_flops) + + if world_size > 1: + dist.barrier() + dist.destroy_process_group() + + +def parse_args(): + parser = argparse.ArgumentParser() + # Training Meta + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--max_iters", type=int, default=2) + parser.add_argument("--bsz", type=int, default=128) + parser.add_argument("--seqlen", type=int, default=256) + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--tp", type=int, default=8) + parser.add_argument("--dataset", type=str, default="shakespeare") + parser.add_argument("--eval_iters", type=int, default=1) + parser.add_argument("--eval_interval", type=int, default=400) + + # Model config + parser.add_argument("--vocab_size", type=int, default=50304) + parser.add_argument("--hidden_size", type=int, default=384) + parser.add_argument("--intermediate_size", type=int, default=1536) + parser.add_argument("--num_hidden_layers", type=int, default=2) + parser.add_argument("--num_attention_heads", type=int, default=8) + parser.add_argument("--num_key_value_heads", type=int, default=8) + + # Optimizer related + parser.add_argument("--use_DO", type=bool, default=True) + parser.add_argument("--decay_lr", type=bool, default=True) + parser.add_argument("--lr", type=float, default=3e-4) + 
parser.add_argument("--warmup_iters", type=int, default=100) + parser.add_argument("--min_lr", type=float, default=3e-5) + parser.add_argument("--grad_clip", type=float, default=0) + parser.add_argument("--weight_decay", type=float, default=0.1) + return parser + + +if __name__ == "__main__": + parser = parse_args() + args = parser.parse_args() + run_mixtral(args) diff --git a/examples/mixtral_EP_training/sharding_plan.py b/examples/mixtral_EP_training/sharding_plan.py new file mode 100644 index 0000000..00484df --- /dev/null +++ b/examples/mixtral_EP_training/sharding_plan.py @@ -0,0 +1,47 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +"""This file contain TP/SP sharding plans for Mixtral example code.""" + +from vescale.dtensor.placement_types import Replicate, Shard + +param_sharding_plan = { + "mixtral_model.model.embed_tokens.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"mixtral_model.model.layers.\d+.self_attn.q_proj.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.self_attn.k_proj.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.self_attn.v_proj.weight": [Shard(0)], + # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. 
+ r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.cos_cached": [Replicate()], + r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.sin_cached": [Replicate()], + r"mixtral_model.model.layers.\d+.self_attn.o_proj.weight": [Shard(1)], + r"mixtral_model.model.layers.\d+.post_attention_layernorm.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + "mixtral_model.model.norm.weight": [Replicate()], +} + +fwd_resharding_plan = { + "mixtral_model.model.embed_tokens.output": [[Shard(1)]], + r"mixtral_model.model.layers.\d+.self_attn.input": [[Replicate()]], + r"mixtral_model.model.layers.\d+.self_attn.o_proj.output": [[Shard(1)]], + "mixtral_model.lm_head.input": [[Replicate()]], +} + +mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/examples/nanogpt_4D_finetune/finetune_4D.py b/examples/nanogpt_4D_finetune/finetune_4D.py index d660b71..331b8b3 100644 --- a/examples/nanogpt_4D_finetune/finetune_4D.py +++ b/examples/nanogpt_4D_finetune/finetune_4D.py @@ -112,9 +112,10 @@ def main(): # ddp = world_size > 1 ddp = int(os.environ.get("RANK", -1)) != -1 if ddp: + local_rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) - device = f"cuda:{rank}" + device = f"cuda:{local_rank}" torch.cuda.set_device(device) init_process_group(backend=backend, world_size=world_size, rank=rank) # + + + VeScale API below diff --git a/requirements.txt b/requirements.txt index 450e8fe..9eca37a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ accelerate transformers==4.40.2 flash_attn matplotlib -mmh3 \ No newline at end of file +mmh3 diff --git a/test/emulator/test_mesh_collectives.py b/test/emulator/test_mesh_collectives.py index 4864d43..c4ed821 100644 --- a/test/emulator/test_mesh_collectives.py +++ b/test/emulator/test_mesh_collectives.py @@ -209,7 +209,9 @@ def test_mesh_all_to_all(self, mesh_dim, nelement): local_tensor_list = [torch.cat(local_tensor_list, dim=0)] local_tensor_list = list(torch.chunk(local_tensor_list[0], group_world_size, dim=0)) - vescale.dtensor._collective_utils.mesh_all_to_all(ground_truth_list, local_tensor_list, self.vescale_mesh, mesh_dim) + vescale.dtensor._collective_utils.mesh_all_to_all( + ground_truth_list, local_tensor_list, self.vescale_mesh, mesh_dim + ) mesh_all_to_all(outputs_list, data_list, self.mesh, mesh_dim) local_output = outputs_list[torch_rank] diff --git a/vescale/ddp/distributed_data_parallel.py b/vescale/ddp/distributed_data_parallel.py index f14c423..9f825d7 100644 --- a/vescale/ddp/distributed_data_parallel.py +++ b/vescale/ddp/distributed_data_parallel.py @@ -4,7 +4,7 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ -from typing import Dict, Union, List, Any +from typing import Dict, Union, List, Type import torch import torch.distributed.distributed_c10d as c10d @@ -14,6 +14,9 @@ from vescale.ddp.grad_buffer import GradBuffer +_DDP_IGNORE_TAG = "DDP_IGNORE" + + class DistributedDataParallel(torch.nn.Module): """ DDP wrapper which stores grads in contiguous buffers. 
Also has option of overlapping @@ -38,7 +41,9 @@ class DistributedDataParallel(torch.nn.Module): per bucket _if_ overlap_grad_reduce is True and pp_rank is 0. bucket_size (int): the size of single bucket, only useful when bucketing is enabled. By default, 40000000. - whitelist_module_types (List[Type]): Types of sparse submodules. By default, None. + module_to_enforce (List[Type]): Types of sparse submodules. By default, None. + param_to_ignore (List[str]): A list of fully qualified names of parameters to be ignored duing gradient + syncronization. By default, None. Returns: A :class:`DistributedDataParallel` object. @@ -57,7 +62,7 @@ class DistributedDataParallel(torch.nn.Module): ddp_module = DDP( module=mlp, data_pg_or_device_mesh=mesh, - whitelist_module_types=[MoEBlock] + module_to_enforce=[MoEBlock] ) # run the forward. ddp_module(torch.rand(xxx)) @@ -73,7 +78,8 @@ def __init__( use_distributed_optimizer: bool = False, disable_bucketing: bool = False, bucket_size: int = 40000000, # Unit: number of the elements - whitelist_module_types: List[Any] = None, + module_to_enforce: List[Type] = None, + param_to_ignore: List[str] = None, **kwargs, ): super().__init__() @@ -108,23 +114,33 @@ def __init__( bucket_size = None self.bucket_size = bucket_size + param_to_ignore = set() if param_to_ignore is None else set(param_to_ignore) + self.module = module self.grad_buffers = {} - self.expert_grads = [] self.grad_buffer_param_index_map = {} self.param_to_grad_buffer = {} + self.ignored_param = [] # Group parameters by their gradient type. grad_dtype_to_params = {} param_to_name = {} + for name, param in self.module.named_parameters(): if param.requires_grad: - param_to_name[param] = name dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - params = grad_dtype_to_params.get(dtype, []) - params.append(param) - grad_dtype_to_params[dtype] = params + if name in param_to_ignore: + if not hasattr(param, _DDP_IGNORE_TAG): + setattr(param, _DDP_IGNORE_TAG, True) + param.main_grad = None + self.ignored_param.append(param) + else: + assert not hasattr(param, _DDP_IGNORE_TAG), "registering a parameter that has been ignored by DDP" + param_to_name[param] = name + params = grad_dtype_to_params.get(dtype, []) + params.append(param) + grad_dtype_to_params[dtype] = params # Allocate the grad buffers and map the grads. # The grad buffer under the hood creates buckets as appropriate based on bucket_size. @@ -143,37 +159,24 @@ def __init__( for param in params: self.param_to_grad_buffer[param] = self.grad_buffers[dtype] - # Allocate separate buffer for MoE params' grads - # NOTE: maybe we shoule handle these code later when we need MOE parallel. - for param in self.module.parameters(): - if param.requires_grad and not getattr(param, "allreduce", True): - dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - param.main_grad = torch.zeros( - param.data.shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - self.expert_grads.append(param.main_grad) - # Register backward hook. # Accumulation function for the gradients need to be stored so they # don't go out of scope. self.grad_accs = [] - for param in self.module.parameters(): + for name, param in self.module.named_parameters(): if param.requires_grad: # Expand so we get access to grad_fn. param_tmp = param.expand_as(param) # Get the gradient accumulator function. 
grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer)) + grad_acc.register_hook(self._make_param_hook(name, param, self.param_to_grad_buffer)) self.grad_accs.append(grad_acc) # Register backward hook for submodules of sparse structure. - if whitelist_module_types is not None and self.overlap_grad_reduce: + if module_to_enforce is not None and self.overlap_grad_reduce: for submod in self.module.modules(): is_sparse = False - for t in whitelist_module_types: + for t in module_to_enforce: if isinstance(submod, t): is_sparse = True break @@ -181,6 +184,9 @@ def __init__( continue submod.register_forward_pre_hook(self._make_sparse_module_pre_hook(), prepend=True) + self.fx_grad_sharding = {} + self.model_parallel_device_mesh = None + def forward(self, *inputs, **kwargs): """ Calls the wrapped module's forward() method. @@ -189,6 +195,7 @@ def forward(self, *inputs, **kwargs): def _make_param_hook( self, + fqn: str, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer], ): @@ -198,6 +205,18 @@ def _make_param_hook( def param_hook(*unused): if param.requires_grad: + if hasattr(param, _DDP_IGNORE_TAG): + if isinstance(param.data, DTensor): + grad = param.grad._local_tensor.data + else: + grad = param.grad.data + param.grad = None + if param.main_grad is None: + param.main_grad = grad + else: + param.main_grad.add_(grad) + return + if self.overlap_grad_reduce: assert param.grad is not None, "param.grad being None is not safe when overlap_grad_reduce is True" model_parallel_device_mesh, placements = None, None @@ -206,8 +225,11 @@ def param_hook(*unused): param.main_grad.add_(param.grad._local_tensor.data) # add DTensor's data model_parallel_device_mesh = param.grad._spec.mesh placements = param.grad._spec.placements + self.model_parallel_device_mesh = model_parallel_device_mesh else: param.main_grad.add_(param.grad.data) + model_parallel_device_mesh = self.model_parallel_device_mesh + placements = self.fx_grad_sharding[fqn].grad_sharding if fqn in self.fx_grad_sharding else None param.grad = None if ( @@ -248,6 +270,10 @@ def sparse_module_pre_hook(module, args): return sparse_module_pre_hook + def load_fx_grad_sharding(self, grad_sharding, mesh): + self.fx_grad_sharding = grad_sharding + self.model_parallel_device_mesh = mesh + def start_grad_sync(self, *unused): """ Initiates grad sync (all-reduce or reduce-scatter) communication operations @@ -272,9 +298,6 @@ def finish_grad_sync(self): for grad_buffer in self.grad_buffers.values(): grad_buffer.finish_grad_sync() - for expert_grad in self.expert_grads: - expert_grad /= self.data_parallel_world_size - def zero_grad_buffer(self, zero_buffer: bool = True): """ Zeros out all grad buffers. 
Needs to be called at the beginning of each @@ -284,8 +307,8 @@ def zero_grad_buffer(self, zero_buffer: bool = True): """ for grad_buffer in self.grad_buffers.values(): grad_buffer.reset(zero_buffer) - for expert_grad in self.expert_grads: - expert_grad.zero_() + for param in self.ignored_param: + param.main_grad = None def state_dict(self, prefix="", keep_vars=False): """ diff --git a/vescale/dmodule/_factory.py b/vescale/dmodule/_factory.py index 6816a07..294e0b3 100644 --- a/vescale/dmodule/_factory.py +++ b/vescale/dmodule/_factory.py @@ -38,7 +38,7 @@ __all__ = ["wrap_factory_mode"] -_IS_DEBUG = True +_IS_DEBUG = False aten = torch.ops.aten diff --git a/vescale/dmp/dmp.py b/vescale/dmp/dmp.py index e54754d..5319b8b 100644 --- a/vescale/dmp/dmp.py +++ b/vescale/dmp/dmp.py @@ -28,7 +28,7 @@ __all__ = ["auto_parallelize_module", "set_plan_overriding_policy", "get_plan_overriding_policy"] -_IS_DEBUG = True +_IS_DEBUG = False _PARAM_PLAN_OVERRIDING_POLICY = "PARAM_PLAN_OVERRIDING_POLICY" _FWD_PLAN_OVERRIDING_POLICY = "FWD_PLAN_OVERRIDING_POLICY" diff --git a/vescale/dmp/policies/megatron.py b/vescale/dmp/policies/megatron.py index 91e8596..10e0cae 100644 --- a/vescale/dmp/policies/megatron.py +++ b/vescale/dmp/policies/megatron.py @@ -24,7 +24,7 @@ from .utils import validate_single_input -_IS_DEBUG = True +_IS_DEBUG = False register = REGISTRY.provide_register_for_policy("MEGATRON") diff --git a/vescale/dtensor/_collective_utils.py b/vescale/dtensor/_collective_utils.py index a813935..66ba6de 100644 --- a/vescale/dtensor/_collective_utils.py +++ b/vescale/dtensor/_collective_utils.py @@ -37,6 +37,12 @@ TORCH_VERSION_BIGGER_THAN_2_2 = torch.__version__ >= "2.2" +def mesh_wait(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, funcol.AsyncCollectiveTensor): + return funcol.wait_tensor(tensor) + return tensor + + # NOTE: upstream are working to migrate the following three collective # apis to be functional, pay attention to it. 
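The new `mesh_wait` helper added above mirrors the existing `wait` utility in this file: it resolves an `AsyncCollectiveTensor` returned by a functional collective into a plain tensor. A minimal usage sketch follows; the wrapper function and the `"sum"` reduce op are illustrative, not part of this patch.

```python
# Sketch: materializing the result of a functional collective with mesh_wait.
import torch
import torch.distributed._functional_collectives as funcol

from vescale.dtensor._collective_utils import mesh_wait


def all_reduce_and_materialize(tensor: torch.Tensor, group) -> torch.Tensor:
    # Functional collectives may hand back an AsyncCollectiveTensor placeholder;
    # mesh_wait blocks until the communication finishes (and is a no-op otherwise).
    result = funcol.all_reduce(tensor, reduceOp="sum", group=group)
    return mesh_wait(result)
```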
@@ -348,6 +354,49 @@ def mesh_all_reduce( return funcol.all_reduce(tensor, reduceOp=reduce_op.name, group=mesh._dim_group_infos[mesh_dim][1]) +def broadcast_across_mesh( + tensor: torch.Tensor, + sender: int, + shape: torch.Size, + dtype: torch.dtype, + mesh: DeviceMesh, + async_op=False, +) -> Optional[torch.Tensor]: + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(broadcast_across_mesh, tensor, sender, mesh) + + recv_list = mesh.mesh.flatten().tolist() + rank = torch.distributed.get_rank() + group = torch.distributed.group.WORLD + + comm_tensor = torch.empty(shape, dtype=dtype, device=mesh.device_type) + if rank == sender: + comm_tensor = tensor + + if TORCH_VERSION_BIGGER_THAN_2_2: + comm_tensor = funcol.broadcast(comm_tensor, sender, group) + else: + work = broadcast(comm_tensor, sender, group, async_op=async_op) + + if rank in recv_list: + if not async_op: + if TORCH_VERSION_BIGGER_THAN_2_2: + return funcol.wait_tensor(comm_tensor) + else: + return comm_tensor + else: + if TORCH_VERSION_BIGGER_THAN_2_2: + return comm_tensor + else: + from torch.distributed._functional_collectives_impl import _register_tensor_work + from torch.distributed._functional_collectives import _maybe_wrap_tensor + + _register_tensor_work(comm_tensor, work) + return _maybe_wrap_tensor(comm_tensor) + else: + return torch.tensor([], device=mesh.device_type, dtype=tensor.dtype) + + def wait(tensor: torch.Tensor) -> torch.Tensor: if isinstance(tensor, funcol.AsyncCollectiveTensor): return funcol.wait_tensor(tensor) diff --git a/vescale/dtensor/device_mesh.py b/vescale/dtensor/device_mesh.py index 9951271..a32537f 100644 --- a/vescale/dtensor/device_mesh.py +++ b/vescale/dtensor/device_mesh.py @@ -57,6 +57,8 @@ def create_child_mesh(self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_n cur_rank = device_mesh.get_rank() pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(-1, device_mesh.mesh.size(mesh_dim)) + res_sub_mesh = None + for mesh_1d in pg_ranks_by_dim: sub_mesh = DeviceMesh( device_mesh.device_type, @@ -67,10 +69,13 @@ def create_child_mesh(self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_n if cur_rank in mesh_1d: res_sub_mesh = sub_mesh - res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh + if res_sub_mesh is None: + return None + else: + res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] + # Assign the current DeviceMesh as the parent of the child DeviceMesh. 
+ self.child_to_parent_mapping[res_sub_mesh] = device_mesh + return res_sub_mesh def create_submesh_along_multi_dims( self, device_mesh: "DeviceMesh", mesh_dims: List[int], cur_rank: int = None diff --git a/vescale/dtensor/dtensor.py b/vescale/dtensor/dtensor.py index 719d0b9..90a4eaf 100644 --- a/vescale/dtensor/dtensor.py +++ b/vescale/dtensor/dtensor.py @@ -30,6 +30,7 @@ from vescale.dtensor.sharding_prop import ShardingPropagator from vescale.dtensor.redistribute import ( Redistribute, + CrossMeshRedistribute, redistribute_local_tensor, ) from vescale.dtensor._utils import compute_global_tensor_info, gather_local_tensor_shape @@ -520,6 +521,21 @@ def redistribute( return Redistribute.apply(self, device_mesh, placements, async_op) + def cross_mesh_redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + async_op: bool = True, + ) -> "DTensor": + device_mesh = device_mesh or self._spec.mesh + + # check new placements for not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, tensor_ndim=self.ndim) + + return CrossMeshRedistribute.apply(self, device_mesh, placements, async_op) + def requires_grad_(self, mode=True): self._local_tensor.requires_grad_(mode) return super().requires_grad_(mode) diff --git a/vescale/dtensor/ops/pointwise_ops.py b/vescale/dtensor/ops/pointwise_ops.py index 781b86b..66ab7fd 100644 --- a/vescale/dtensor/ops/pointwise_ops.py +++ b/vescale/dtensor/ops/pointwise_ops.py @@ -30,7 +30,7 @@ is_tensor_partial, register_op_strategy, ) -from vescale.dtensor.placement_types import DTensorSpec, Partial, Placement, Shard +from vescale.dtensor.placement_types import DTensorSpec, Partial, Placement, Shard, InterleavedShard aten = torch.ops.aten # leave the remaining pointwise_ops list here for convenience, @@ -473,7 +473,12 @@ def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = out_placements: List[Placement] = [] for placement in spec_to_follow.placements: - if isinstance(placement, Shard): + if isinstance(placement, InterleavedShard): + shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) + common_ndim = len(common_shape) + new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim + out_placements.append(InterleavedShard(new_shard_dim, placement.interleaved_size)) + elif isinstance(placement, Shard): shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) common_ndim = len(common_shape) new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim diff --git a/vescale/dtensor/ops/tensor_ops.py b/vescale/dtensor/ops/tensor_ops.py index 3e119d6..4696429 100644 --- a/vescale/dtensor/ops/tensor_ops.py +++ b/vescale/dtensor/ops/tensor_ops.py @@ -443,7 +443,9 @@ def _prop_select(op_schema: OpSchema) -> OutputSharding: for p in placements: # Using isinstance instead of is_shard so that mypy won't complain # about accessing dim attribute. 
- if isinstance(p, Shard) and p.dim > dim: + if isinstance(p, InterleavedShard) and p.dim > dim: + new_placements.append(InterleavedShard(p.dim - 1, p.interleaved_size)) + elif isinstance(p, Shard) and p.dim > dim: new_placements.append(Shard(p.dim - 1)) else: new_placements.append(p) @@ -1089,7 +1091,6 @@ def unbind_rule(op_schema: OpSchema) -> OutputSharding: placements=unshard_tensor_dim(input_spec.placements, dim=dim), tensor_meta=input_spec.tensor_meta, ) - if need_reshard: return OutputSharding( None, diff --git a/vescale/dtensor/redistribute.py b/vescale/dtensor/redistribute.py index ede8c0e..3fe2c81 100644 --- a/vescale/dtensor/redistribute.py +++ b/vescale/dtensor/redistribute.py @@ -21,11 +21,12 @@ mesh_reduce_scatter, mesh_scatter, mesh_all_to_all_single, + broadcast_across_mesh, wait, ) from vescale.dtensor.device_mesh import DeviceMesh from vescale.dtensor.op_schema import DTensorSpec -from vescale.dtensor.placement_types import InterleavedShard, Partial, Placement, Replicate, Shard +from vescale.dtensor.placement_types import InterleavedShard, Partial, Placement, Replicate, Shard, TensorMeta from vescale.dtensor._utils import compute_global_stride _PlacementItem = Tuple[int, Tuple[Placement, Placement]] @@ -556,3 +557,102 @@ def backward(ctx, grad_output: "dtensor.DTensor"): ) return (output_dtensor, None, None, None) + + +class CrossMeshRedistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh = None, + placements: Tuple[Placement] = None, + async_op: bool = True, + ): + previous_spec = input._spec + ctx.previous_spec = previous_spec + ctx.async_op = async_op + + # step 1: redistribute to [Replicate()] * mesh_dimension + placements1 = [Replicate()] * previous_spec.mesh.ndim + full_tensor = input.redistribute(placements=placements1) + + # step 2: broastcast across mesh + # select the sender through min-reduce + sender = previous_spec.mesh.mesh.min().to(input.device) + world_size = torch.distributed.get_world_size() + world_mesh = DeviceMesh(device_mesh.device_type, list(range(world_size))) + sender = mesh_all_reduce(sender, mesh=world_mesh, reduce_op=c10d.ReduceOp.MIN, mesh_dim=0) + # broadcast + local_full_tensor = broadcast_across_mesh( + full_tensor.to_local(), + sender.item(), + full_tensor.size(), + full_tensor.dtype, + device_mesh, + ) + + # step3: redistribute to the desired placements over the output mesh + placements2 = [Replicate()] * device_mesh.ndim + local_spec = DTensorSpec(device_mesh, placements2, tensor_meta=previous_spec.tensor_meta) + final_spec = DTensorSpec(device_mesh, placements, tensor_meta=previous_spec.tensor_meta) + output = redistribute_local_tensor(local_full_tensor, local_spec, final_spec, async_op) + output.requires_grad_(input.requires_grad) + + return dtensor.DTensor( + output, + device_mesh, + placements, + shape=input.shape, + dtype=input.dtype, + requires_grad=input.requires_grad, + stride=input.stride(), + ) + + @staticmethod + # type: ignore[override] + def backward(ctx, grad_output: "dtensor.DTensor"): + previous_spec = ctx.previous_spec + async_op = ctx.async_op + + # step 1: local replicate + target_spec = grad_output._spec + placements2 = [Replicate()] * target_spec.mesh.ndim + full_grad = grad_output.redistribute(placements=placements2).to_local() + + # step 2: broastcast across mesh + local_sender = target_spec.mesh.mesh.min().to(grad_output.device) + world_size, rank = 
torch.distributed.get_world_size(), torch.distributed.get_rank()
+        world_mesh = DeviceMesh(target_spec.mesh.device_type, list(range(world_size)))
+        if rank != local_sender:
+            full_grad = torch.zeros(target_spec.shape, dtype=grad_output.dtype, device=grad_output.device)
+        global_sender = mesh_all_reduce(local_sender, mesh=world_mesh, reduce_op=c10d.ReduceOp.MIN, mesh_dim=0)
+        full_grad = mesh_all_reduce(full_grad, mesh=world_mesh, reduce_op=c10d.ReduceOp.SUM, mesh_dim=0)
+
+        local_full_tensor = broadcast_across_mesh(
+            full_grad,
+            global_sender,
+            full_grad.size(),
+            full_grad.dtype,
+            previous_spec.mesh,
+        )
+
+        # step 3: redistribute to the desired placements over the output mesh
+        placements1 = [Replicate()] * previous_spec.mesh.ndim
+        local_spec = DTensorSpec(previous_spec.mesh, placements1, tensor_meta=previous_spec.tensor_meta)
+        final_spec = DTensorSpec(previous_spec.mesh, previous_spec.placements, tensor_meta=previous_spec.tensor_meta)
+        previous_local_tensor = redistribute_local_tensor(local_full_tensor, local_spec, final_spec, async_op)
+        previous_spec.tensor_meta = TensorMeta(
+            shape=grad_output.shape, stride=grad_output.stride(), dtype=grad_output.dtype
+        )
+
+        previous_dtensor = dtensor.DTensor(
+            previous_local_tensor,
+            previous_spec.mesh,
+            previous_spec.placements,
+            shape=grad_output.shape,
+            dtype=grad_output.dtype,
+            requires_grad=grad_output.requires_grad,
+            stride=grad_output.stride(),
+        )
+        return previous_dtensor, None, None, None
diff --git a/vescale/emulator/comm_api.py b/vescale/emulator/comm_api.py
index 85e21f7..ffaa766 100644
--- a/vescale/emulator/comm_api.py
+++ b/vescale/emulator/comm_api.py
@@ -270,7 +270,7 @@ def distribute_tensor(
         return results
 
     target_spec = DTensorSpec(mesh=device_mesh, placements=placements, tensor_meta=None)
-
+    placements: Tuple[Placement] = tuple([Replicate()] * device_mesh.ndim)
     tensor_meta = TensorMeta(shape=tensors[0].shape, stride=tensors[0].stride(), dtype=tensors[0].dtype)
     current_spec = DTensorSpec(mesh=device_mesh, placements=placements, tensor_meta=tensor_meta)
 
diff --git a/vescale/emulator/comm_primitive.py b/vescale/emulator/comm_primitive.py
index 22a0503..cf19126 100644
--- a/vescale/emulator/comm_primitive.py
+++ b/vescale/emulator/comm_primitive.py
@@ -326,6 +326,7 @@ def normalize_dim_for_shard(placement, tensor):
             return placement
         new_dim = placement.dim + tensor_ndim
         return Shard(new_dim)
+
     current_placement = normalize_dim_for_shard(current_placement, tensor=local_tensors[0])
     return _reshard_to_replicate_with_pad_one_dim(
         local_tensors,
diff --git a/vescale/moe/__init__.py b/vescale/moe/__init__.py
new file mode 100644
index 0000000..138b9ec
--- /dev/null
+++ b/vescale/moe/__init__.py
@@ -0,0 +1,23 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from .api import parallelize_experts, is_experts_parallized
+from .moe_optimizer import MoEOptimizer
+from .experts_allocator import ExpertsAllocator
+from .token_dispatcher import TokenDispatcher
+
+__all__ = ["parallelize_experts", "is_experts_parallized", "MoEOptimizer", "ExpertsAllocator", "TokenDispatcher"]
diff --git a/vescale/moe/_experts.py b/vescale/moe/_experts.py
new file mode 100644
index 0000000..db189b2
--- /dev/null
+++ b/vescale/moe/_experts.py
@@ -0,0 +1,103 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+from collections import OrderedDict
+import re
+import warnings
+from typing import List, Dict, Optional, Union
+
+import torch
+from torch import nn
+
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from vescale.moe.experts_allocator import ExpertsAllocator, BasicExpertsAllocator
+from vescale.moe.token_dispatcher import TokenDispatcher, BasicTokenDispatcher
+from vescale.moe._moe_tensor import moe_tensor_forward, register_tensor_operations_with_placeholder
+from vescale.moe._scheduler import MoEScheduler, MoETask
+from vescale.moe._utils import _TAG_EXPERTS_PARALLIZED
+
+
+__all__ = ["Experts"]
+
+
+class Experts:
+    @staticmethod
+    def is_experts_parallized(module: nn.Module) -> bool:
+        return hasattr(module, _TAG_EXPERTS_PARALLIZED)
+
+    @staticmethod
+    def set_experts_parallized(module: nn.Module) -> None:
+        if hasattr(module, _TAG_EXPERTS_PARALLIZED):
+            warnings.warn(f"resetting `{module.__class__}` as parallelized experts!", UserWarning)
+        setattr(module, _TAG_EXPERTS_PARALLIZED, True)
+
+    @staticmethod
+    @torch.no_grad()
+    def init_scheduler(
+        input_module: nn.Module,
+        experts: Optional[Union[str, List[str]]],
+        experts_allocator: Optional[ExpertsAllocator] = None,
+        token_dispatcher: Optional[TokenDispatcher] = None,
+        config: Optional[Dict] = None,
+    ) -> None:
+        if experts_allocator is None:
+            experts_allocator = BasicExpertsAllocator()
+        if token_dispatcher is None:
+            token_dispatcher = BasicTokenDispatcher()
+
+        experts = experts if isinstance(experts, list) else [experts]
+
+        if isinstance(input_module, DDP):
+            core_module = input_module.module
+        else:
+            core_module = input_module
+
+        moe_layer_list = []
+        for experts_pattern in experts:
+            for submod_fqn, submod in core_module.named_modules():
+                if re.fullmatch(experts_pattern, submod_fqn):
+                    if isinstance(submod, nn.ModuleList):
+                        for i in range(len(submod)):
+                            for nm, m in submod[i].named_modules():
+                                weight = submod_fqn + f".{i}." 
+ nm + if weight in core_module._param_sharding_plan: + m._weight_placement = core_module._param_sharding_plan[weight]["weight"].placements[ + 0 + ] + m._backward_hooks = OrderedDict() + m._forward_hooks = OrderedDict() + m._forward_pre_hooks = OrderedDict() + moe_layer_list.append(submod) + elif isinstance(submod, nn.Parameter): + # TODO: override bmm + raise NotImplementedError + else: + raise ValueError + + scheduler = MoEScheduler(experts_allocator, token_dispatcher, config) + for layer_id, moe_layer in enumerate(moe_layer_list): + for i, expert_model in enumerate(moe_layer): + task = MoETask(model=expert_model, layer_id=layer_id, expert_id=i) + expert_model._original_forward = expert_model.forward + expert_model.forward = moe_tensor_forward(expert_model, task, scheduler) + scheduler.init_param_buffer(moe_layer_list) + input_module.moe_param_buffer = scheduler.get_moe_param_buffer() + + @staticmethod + @torch.no_grad() + def init_forward(config) -> None: + register_tensor_operations_with_placeholder(config) diff --git a/vescale/moe/_moe_param_buffer.py b/vescale/moe/_moe_param_buffer.py new file mode 100644 index 0000000..b859bff --- /dev/null +++ b/vescale/moe/_moe_param_buffer.py @@ -0,0 +1,449 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +from dataclasses import dataclass +from typing import Dict, List, Tuple, Optional + +import torch +from torch import nn +import torch.distributed.distributed_c10d as c10d + +from vescale import DeviceMesh, DTensor, Replicate, Shard, from_local +from vescale.moe._utils import _MOE_DP +from vescale.dtensor._collective_utils import mesh_all_gather, mesh_reduce_scatter, mesh_wait + + +_MOE_BUFFER_TRANSPOSE_TAG = "_MOE_BUFFER_TRANSPOSE_TAG" +aten = torch.ops.aten + + +class MoEBufferTensor(torch.Tensor): + def __init__(self, tensor): + self.tensor = tensor + + def __repr__(self): + return f"MoEBufferTensor({self.tensor})" + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + print(func, args) + if len(args) == 1: + return MoEBufferTensor(func(args[0].tensor)) + + +class MoELayerParamBuffer: + def __init__(self, layer_id: int, num_experts: int, param_buffer: "MoEParamBuffer"): + self.layer_id = layer_id + self.num_experts = num_experts + self._param_buffer = param_buffer + + self._mesh_id: Dict[DeviceMesh, int] = {} + self._local_param_buffer: List[torch.Tensor] = [] + self._local_grad_buffer: List[torch.Tensor] = [] + self._global_param_buffer: List[torch.Tensor] = [] + self._global_grad_buffer: List[torch.Tensor] = [] + self._buffer_mesh: List[DeviceMesh] = [] + self._buffer_shape: List[Tuple[int]] = [] + self._buffer_stride: List[Tuple[int]] = [] + + self._mesh_id_to_expert_id: List[List[int]] = [] + self._expert_dp_mesh: List[DeviceMesh] = [] + self._expert_local_param_buffer: List[Optional[torch.Tensor]] = [] + self._expert_global_grad_buffer: List[Optional[torch.Tensor]] = [] + self._expert_stride: List[Tuple[int]] = [] + + self._initialized_flag: bool = False + self._optimizer: Optional[torch.optim.Optimizer] = None + self._dirty_flag: bool = False + + def mark_dirty(self): + self._dirty_flag = True + + def is_dirty(self): + return self._dirty_flag + + def reset_dirty(self): + self._dirty_flag = False + + def init_buffer(self, expert_list: List[nn.Module], experts_alloc: List[DeviceMesh]): + self._initialized_flag = True + + for expert, alloc in zip(expert_list, experts_alloc): + for m in expert.modules(): + if hasattr(m, "weight"): + placement = m.weight.placements[0] + m._moe_device_mesh = alloc + m._moe_placements = (placement, Shard(1 - placement.dim)) + m._moe_buffer_global_shape = m.weight.shape + assert placement.is_shard() + if placement.is_shard() and placement.dim == 1: + setattr(m, _MOE_BUFFER_TRANSPOSE_TAG, True) + param = m.weight.cross_mesh_redistribute( + alloc, + placements=(placement, Shard(1 - placement.dim)), + async_op=False, + ) + m.weight = nn.Parameter(param) + + tmp_buffer = [] + self._mesh_id = {} + self._buffer_mesh = [] + self._mesh_id_to_expert_id = [] + self._expert_dp_mesh = [] + + for expert_id, (expert, alloc) in enumerate(zip(expert_list, experts_alloc)): + dp_mesh = alloc[_MOE_DP] + self._expert_dp_mesh.append(dp_mesh) + if dp_mesh is None: + continue + + if dp_mesh not in self._mesh_id: + mesh_id = len(self._mesh_id) + self._mesh_id[dp_mesh] = mesh_id + self._mesh_id_to_expert_id.append([]) + tmp_buffer.append(None) + self._buffer_mesh.append(dp_mesh) + else: + mesh_id = self._mesh_id[dp_mesh] + self._mesh_id_to_expert_id[mesh_id].append(expert_id) + + for m in expert.modules(): + if hasattr(m, "weight"): + if hasattr(m, _MOE_BUFFER_TRANSPOSE_TAG): + tensor = 
m.weight._local_tensor.T + local_shape = (tensor.shape[1], tensor.shape[0]) + dp_shape = (tensor.shape[1] * dp_mesh.ndevice, tensor.shape[0]) + else: + tensor = m.weight._local_tensor + local_shape = (tensor.shape[0], tensor.shape[1]) + dp_shape = (tensor.shape[0], tensor.shape[1] * dp_mesh.ndevice) + prev_tensor = tmp_buffer[mesh_id] + if prev_tensor is None: + storage_offset = 0 + tmp_buffer[mesh_id] = tensor + else: + storage_offset = prev_tensor.shape[0] + tmp_buffer[mesh_id] = torch.cat([prev_tensor, tensor]) + m._moe_buffer_local_shape = local_shape + m._moe_buffer_dp_shape = dp_shape + m._moe_buffer_storage_offset = storage_offset + + self._local_param_buffer = [] + self._global_grad_buffer = [] + self._buffer_shape = [] + self._buffer_stride = [] + + for mesh_id, tensor in enumerate(tmp_buffer): + shape = (tensor.numel() * self._buffer_mesh[mesh_id].ndevice,) + self._local_param_buffer.append(tensor.T.flatten()) + self._global_grad_buffer.append(torch.empty(shape, device=tensor.device, dtype=tensor.dtype)) + self._buffer_shape.append(shape) + self._buffer_stride.append((1, tensor.shape[0])) + + self._expert_local_param_buffer = [] + self._expert_global_grad_buffer = [] + self._expert_stride = [] + + for expert_id in range(self.num_experts): + dp_mesh = self._expert_dp_mesh[expert_id] + if dp_mesh is None: + self._expert_local_param_buffer.append(None) + self._expert_global_grad_buffer.append(None) + self._expert_stride.append(None) + else: + mesh_id = self._mesh_id[dp_mesh] + self._expert_local_param_buffer.append(self._local_param_buffer[mesh_id]) + self._expert_global_grad_buffer.append(self._global_grad_buffer[mesh_id]) + self._expert_stride.append(self._buffer_stride[mesh_id]) + + self.run_all_gather() + + def set_optimizer(self, optimizer: torch.optim.Optimizer): + self._optimizer = optimizer + + def is_initialized(self): + return self._initialized_flag + + def refresh_buffer(self, expert_list: List[nn.Module], experts_alloc: List[Optional[DeviceMesh]]) -> None: + def collect_optimizer_states(): + tensor_state_keys = [] + scalar_state = {} + param = self._local_param_buffer[0] + for state_name, state in self._optimizer.state[param].items(): + if state.shape == param.shape: + tensor_state_keys.append(state_name) + else: + scalar_state[state_name] = state + return tensor_state_keys, scalar_state + + device = self._local_param_buffer[0].device + tensor_state_keys, scalar_state = collect_optimizer_states() + param_key = "params" + refresh_keys = [param_key] + tensor_state_keys + refresh_buffer = {} + for key in refresh_keys: + refresh_buffer[key] = [] + + self._mesh_id = {} + self._buffer_mesh = [] + self._mesh_id_to_expert_id = [] + self._expert_dp_mesh = [] + + for expert_id, (expert, alloc) in enumerate(zip(expert_list, experts_alloc)): + dp_mesh = alloc[_MOE_DP] + self._expert_dp_mesh.append(dp_mesh) + if dp_mesh is not None: + if dp_mesh not in self._mesh_id: + mesh_id = len(self._mesh_id) + self._mesh_id[dp_mesh] = mesh_id + self._mesh_id_to_expert_id.append([]) + for key in refresh_keys: + refresh_buffer[key].append(None) + self._buffer_mesh.append(dp_mesh) + else: + mesh_id = self._mesh_id[dp_mesh] + self._mesh_id_to_expert_id[mesh_id].append(expert_id) + expert_stride = self._expert_stride[expert_id] + + for nm, m in expert.named_modules(): + if hasattr(m, "weight"): + device_mesh = m._moe_device_mesh + placements = m._moe_placements + global_shape = m._moe_buffer_global_shape + current_refresh_buffer = {} + + param = self._expert_local_param_buffer[expert_id] + if param 
is None: + current_refresh_buffer[param_key] = torch.tensor([], device=device).view(0, 0) + for key in tensor_state_keys: + current_refresh_buffer[key] = torch.tensor([], device=device).view(0, 0) + stride = (0, 0) + else: + local_shape = m._moe_buffer_local_shape + storage_offset = m._moe_buffer_storage_offset + optimizer_state = self._optimizer.state[param] + if hasattr(m, _MOE_BUFFER_TRANSPOSE_TAG): + stride = (expert_stride[1], expert_stride[0]) + else: + stride = expert_stride + current_refresh_buffer[param_key] = param.as_strided(local_shape, stride, storage_offset) + for key in tensor_state_keys: + current_refresh_buffer[key] = optimizer_state[key].as_strided( + local_shape, stride, storage_offset + ) + + for key in refresh_keys: + tensor = current_refresh_buffer[key] + dtensor = from_local( + tensor, + device_mesh, + placements, + run_check=False, + shape=global_shape, + stride=stride, + ) + current_refresh_buffer[key] = dtensor.redistribute( + alloc, placements, async_op=False + )._local_tensor + + m._moe_device_mesh = alloc + + if dp_mesh is None: + delattr(m, "_moe_buffer_local_shape") + delattr(m, "_moe_buffer_dp_shape") + delattr(m, "_moe_buffer_storage_offset") + continue + + for key in refresh_keys: + tmp_buffer = refresh_buffer[key] + if hasattr(m, _MOE_BUFFER_TRANSPOSE_TAG): + tensor = current_refresh_buffer[key].T + else: + tensor = current_refresh_buffer[key] + + prev_tensor = tmp_buffer[mesh_id] + if prev_tensor is None: + storage_offset = 0 + tmp_buffer[mesh_id] = tensor + else: + storage_offset = prev_tensor.shape[0] + tmp_buffer[mesh_id] = torch.cat([prev_tensor, tensor]) + + tensor = current_refresh_buffer[param_key] + local_shape = (tensor.shape[0], tensor.shape[1]) + if hasattr(m, _MOE_BUFFER_TRANSPOSE_TAG): + dp_shape = (tensor.shape[0] * dp_mesh.ndevice, tensor.shape[1]) + else: + dp_shape = (tensor.shape[0], tensor.shape[1] * dp_mesh.ndevice) + + m._moe_buffer_local_shape = local_shape + m._moe_buffer_dp_shape = dp_shape + m._moe_buffer_storage_offset = storage_offset + + for param in self._local_param_buffer: + del self._optimizer.state[param] + + self._local_param_buffer.clear() + self._global_grad_buffer.clear() + self._buffer_shape.clear() + self._buffer_stride.clear() + + for mesh_id, tensor in enumerate(refresh_buffer[param_key]): + shape = (tensor.numel() * self._buffer_mesh[mesh_id].ndevice,) + self._local_param_buffer.append(tensor.T.flatten()) + self._global_grad_buffer.append(torch.empty(shape, device=tensor.device, dtype=tensor.dtype)) + self._buffer_shape.append(shape) + self._buffer_stride.append((1, tensor.shape[0])) + + self._expert_local_param_buffer.clear() + self._expert_global_grad_buffer.clear() + self._expert_stride.clear() + + for expert_id in range(self.num_experts): + dp_mesh = self._expert_dp_mesh[expert_id] + if dp_mesh is None: + self._expert_local_param_buffer.append(None) + self._expert_global_grad_buffer.append(None) + self._expert_stride.append(None) + else: + mesh_id = self._mesh_id[dp_mesh] + self._expert_local_param_buffer.append(self._local_param_buffer[mesh_id]) + self._expert_global_grad_buffer.append(self._global_grad_buffer[mesh_id]) + self._expert_stride.append(self._buffer_stride[mesh_id]) + + for i, param in enumerate(self._local_param_buffer): + self._optimizer.state[param] = {} + for state_key in tensor_state_keys: + self._optimizer.state[param][state_key] = refresh_buffer[state_key][i].T.flatten() + self._optimizer.state[param] |= scalar_state + + self.mark_dirty() + self.run_all_gather() + + def assign_param(self, 
expert_list: List[nn.Module]) -> None: + expert_global_param_buffer = [None] * self.num_experts + + for mesh_id in range(len(self._local_param_buffer)): + mesh = self._buffer_mesh[mesh_id] + if mesh.ndevice == 1: + global_tensor = self._global_param_buffer[mesh_id] + else: + global_tensor = mesh_wait(self._global_param_buffer[mesh_id]) + for expert_id in self._mesh_id_to_expert_id[mesh_id]: + expert_global_param_buffer[expert_id] = global_tensor + self._global_grad_buffer[mesh_id].zero_() + + self._param_buffer.finish_all_gather() + + for expert_id, expert in enumerate(expert_list): + param_buffer = expert_global_param_buffer[expert_id] + if param_buffer is None: + continue + grad_buffer = self._expert_global_grad_buffer[expert_id] + expert_stride = self._expert_stride[expert_id] + for m in expert.modules(): + if hasattr(m, "weight"): + if hasattr(m, _MOE_BUFFER_TRANSPOSE_TAG): + stride = (expert_stride[1], expert_stride[0]) + else: + stride = expert_stride + shape = m._moe_buffer_dp_shape + storage_offset = m._moe_buffer_storage_offset + tensor = param_buffer.as_strided(shape, stride, storage_offset) + param = nn.Parameter(tensor) + param.grad = grad_buffer.as_strided(shape, stride, storage_offset) + m.weight = param + + def get_local_param_buffer(self): + return self._local_param_buffer + + def setup_grad(self) -> None: + for mesh_id, local_grad in enumerate(self._local_grad_buffer): + mesh = self._buffer_mesh[mesh_id] + if mesh.ndevice == 1: + self._local_param_buffer[mesh_id].grad = local_grad + else: + self._local_param_buffer[mesh_id].grad = mesh_wait(local_grad) + + def run_all_gather(self) -> None: + self._global_param_buffer = [] + for mesh_id, local_tensor in enumerate(self._local_param_buffer): + mesh = self._buffer_mesh[mesh_id] + if mesh.ndevice == 1: + self._global_param_buffer.append(local_tensor) + else: + self._global_param_buffer.append( + mesh_all_gather(local_tensor, self._buffer_shape[mesh_id], self._buffer_mesh[mesh_id], 0, 0) + ) + + def run_reduce_scatter(self) -> None: + self._local_grad_buffer = [] + for mesh_id, global_tensor in enumerate(self._global_grad_buffer): + mesh = self._buffer_mesh[mesh_id] + if mesh.ndevice == 1: + self._local_grad_buffer.append(global_tensor) + else: + self._local_grad_buffer.append(mesh_reduce_scatter(global_tensor, mesh, c10d.ReduceOp.SUM, 0, 0)) + + +class MoEParamBuffer: + def __init__(self, num_layers: int, num_experts_list: List[int]): + self.num_layers = num_layers + self._current_layer_id = 0 + self._buffer_list = [MoELayerParamBuffer(i, num_experts_list[i], self) for i in range(num_layers)] + self._optimizer: Optional[torch.optim.Optimizer] = None + + def get_layer_param_buffer(self, layer_id: int) -> MoELayerParamBuffer: + return self._buffer_list[layer_id] + + def get_param_group(self) -> Dict: + params = [] + for buffer in self._buffer_list: + params.extend(buffer.get_local_param_buffer()) + return params + + def setup_grad(self) -> None: + is_dirty = False + for buffer in self._buffer_list: + buffer.setup_grad() + is_dirty |= buffer.is_dirty() + if is_dirty: + self._optimizer.param_groups[0]["params"] = self.get_param_group() + for buffer in self._buffer_list: + buffer.reset_dirty() + + def set_optimizer(self, optimizer) -> None: + self._optimizer = optimizer + for buffer in self._buffer_list: + buffer.set_optimizer(optimizer) + + def process_all_gather(self) -> None: + self._current_layer_id = 0 + self._process_all_gather(self._current_layer_id) + + def _process_all_gather(self, layer_id: int) -> None: + 
self._buffer_list[layer_id].run_all_gather() + + def finish_all_gather(self) -> None: + if self._current_layer_id < self.num_layers - 1: + self._current_layer_id += 1 + self._process_all_gather(self._current_layer_id) + + def process_reduce_scatter(self, layer_id: int) -> None: + self._buffer_list[layer_id].run_reduce_scatter() diff --git a/vescale/moe/_moe_tensor.py b/vescale/moe/_moe_tensor.py new file mode 100644 index 0000000..6d53d06 --- /dev/null +++ b/vescale/moe/_moe_tensor.py @@ -0,0 +1,95 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from typing import Any, List, Deque, Dict, Mapping, Optional, Tuple, Union +from dataclasses import dataclass +from vescale import Placement, DeviceMesh +from vescale.moe._scheduler import MoEScheduler, MoETask + +import torch +from torch import nn + + +class MoETensorPlaceholderState(int): + def __new__(cls, state): + obj = int.__new__(cls, state) + return obj + + +_MOE_STATE_EXPERT_OUTPUT = MoETensorPlaceholderState(0) +_MOE_STATE_WEIGHTED_EXPERT_OUTPUT = MoETensorPlaceholderState(1) +_MOE_STATE_ACCUMULATING_RESULTS = MoETensorPlaceholderState(2) + + +@dataclass +class MoETensorPlaceholder: + task: Optional[MoETask] = None + scheduler: Optional[MoEScheduler] = None + state: MoETensorPlaceholderState = _MOE_STATE_EXPERT_OUTPUT + + def __mul__(self, weight: torch.Tensor): + assert self.state == _MOE_STATE_EXPERT_OUTPUT + self.task.token_weight = weight + self.state = _MOE_STATE_WEIGHTED_EXPERT_OUTPUT + return self + + def __rmul__(self, weight: torch.Tensor): + assert self.state == _MOE_STATE_EXPERT_OUTPUT + self.task.token_weight = weight + self.state = _MOE_STATE_WEIGHTED_EXPERT_OUTPUT + return self + + def to(self, *args, **kwargs): + return self + + +def moe_tensor_forward(expert: nn.Module, task: MoETask, scheduler: MoEScheduler): + def forward(x): + task.hidden_state = x + task.device = x.device + return MoETensorPlaceholder(task=task, scheduler=scheduler) if expert.training else expert._original_forward(x) + + return forward + + +def register_tensor_operations_with_placeholder(config): + original_index_add_ = torch.Tensor.index_add_ + + def wrappped_index_add_(self, dim: int, index: torch.Tensor, value: Union[torch.Tensor, MoETensorPlaceholder]): + if isinstance(value, MoETensorPlaceholder): + assert value.state == _MOE_STATE_WEIGHTED_EXPERT_OUTPUT + if not hasattr(self, "_moe_tensor_placeholder"): + scheduler = value.scheduler + self._moe_tensor_placeholder = MoETensorPlaceholder( + task=None, state=_MOE_STATE_ACCUMULATING_RESULTS, scheduler=scheduler + ) + else: + assert self._moe_tensor_placeholder.state == _MOE_STATE_ACCUMULATING_RESULTS + scheduler = self._moe_tensor_placeholder.scheduler + assert scheduler is value.scheduler + value.task.token_id = index + value.task.output_tensor = self + 
scheduler.push_task(value.task) + if scheduler.num_tasks() == config["num_experts"]: + tensor = self._moe_tensor_placeholder.scheduler.launch() + self.copy_(tensor) + return self + else: + with torch._C.DisableTorchFunctionSubclass(): + return original_index_add_(self, dim, index, value) + + torch.Tensor.index_add_ = wrappped_index_add_ diff --git a/vescale/moe/_scheduler.py b/vescale/moe/_scheduler.py new file mode 100644 index 0000000..a6231e2 --- /dev/null +++ b/vescale/moe/_scheduler.py @@ -0,0 +1,277 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from typing import Dict, List, Optional, Union +import torch +from torch import nn +from vescale import DeviceMesh, DTensor, from_local +from vescale.moe.experts_allocator import ExpertsAllocator +from vescale.moe.token_dispatcher import TokenDispatcher +from vescale.moe._moe_param_buffer import MoEParamBuffer, MoELayerParamBuffer +from vescale.moe._utils import _MOE_DP, global_all_to_all_single +from vescale.dmodule._factory import FactoryDispatchModeOff +from dataclasses import dataclass + + +@dataclass +class MoETask: + model: Optional[nn.Module] = None + layer_id: int = -1 + device: Optional[torch.device] = None + expert_id: Optional[Union[torch.Tensor, int]] = -1 + hidden_state: Optional[Union[torch.Tensor, DTensor]] = None + token_id: Optional[Union[torch.Tensor, DTensor]] = None + token_weight: Optional[Union[torch.Tensor, DTensor]] = None + output_tensor: Optional[Union[torch.Tensor, DTensor]] = None + + +@dataclass +class MoELayerInfo: + layer_id: int = 0 + num_experts: int = 0 + num_devices: int = 0 + experts_alloc: Optional[List[Optional[DeviceMesh]]] = None + experts_dp_mesh: Optional[List[Optional[DeviceMesh]]] = None + device_mask: Optional[List[torch.Tensor]] = None + param_buffer: Optional[MoELayerParamBuffer] = None + + def set_new_experts_alloc(self, expert_list: List[nn.Module], experts_alloc_info: Dict) -> None: + experts_alloc: List[Optional[DeviceMesh]] = experts_alloc_info["experts_alloc"] + dp_size: torch.Tensor = experts_alloc_info["dp_size"] + if self.param_buffer.is_initialized(): + self.param_buffer.refresh_buffer(expert_list, experts_alloc) + else: + self.param_buffer.init_buffer(expert_list, experts_alloc) + self.experts_alloc = experts_alloc + self.experts_dp_mesh = [alloc[_MOE_DP] if alloc is not None else None for alloc in experts_alloc] + self._update_device_mask(dp_size) + + def _update_device_mask(self, dp_size: torch.Tensor) -> None: + experts_alloc = self.experts_alloc + num_experts, max_replica, num_devices = len(experts_alloc), dp_size.max().item(), self.num_devices + device_type = experts_alloc[0].device_type + device_mask = torch.zeros((num_experts, max_replica, num_devices), device=device_type, dtype=torch.bool) + for i, alloc in enumerate(experts_alloc): + 
mesh = alloc.mesh + for r in range(mesh.shape[1]): + device = mesh[:, r] + device_mask[i, r, device] = True + self.device_mask = device_mask + + +class MoEScheduler: + def __init__(self, experts_allocator: ExpertsAllocator, token_dispatcher: TokenDispatcher, config: Dict): + def wrap_as_list(value, length): + if isinstance(value, int): + return [value] * length + else: + assert len(value) == length + return value + + self.num_layers: int = config["num_layers"] + self.experts_allocator: ExpertsAllocator = experts_allocator + self.token_dispatcher: TokenDispatcher = token_dispatcher + + num_experts_list = wrap_as_list(config["num_experts"], self.num_layers) + num_devices_list = wrap_as_list(config["num_devices"], self.num_layers) + + self._param_buffer = MoEParamBuffer(self.num_layers, num_experts_list) + + self._layer_info: List[MoELayerInfo] = [ + MoELayerInfo( + layer_id=i, + num_experts=num_experts_list[i], + num_devices=num_devices_list[i], + param_buffer=self._param_buffer.get_layer_param_buffer(i), + ) + for i in range(self.num_layers) + ] + self._task_per_expert: List[MoETask] = [] + self._current_info: Optional[MoELayerInfo] = None + + def init_param_buffer(self, moe_layer_list) -> None: + assert len(moe_layer_list) == self.num_layers + for expert_list, layer_info in zip(moe_layer_list, self._layer_info): + experts_alloc_info = self.experts_allocator.allocate_experts_internal(layer_info.layer_id) + self.token_dispatcher.set_experts_alloc(experts_alloc_info) + layer_info.set_new_experts_alloc(expert_list, experts_alloc_info) + + def get_moe_param_buffer(self) -> MoEParamBuffer: + return self._param_buffer + + def push_task(self, task: MoETask): + self._task_per_expert.append(task) + + def num_tasks(self): + return len(self._task_per_expert) + + def _set_context(self, layer_id: int): + self._current_info = self._layer_info[layer_id] + + def _allocate_experts(self, layer_id: int, task_list: List[MoETask]): + experts_alloc_info = self.experts_allocator.allocate_experts_internal(layer_id) + expert_list = [task.model for task in task_list] + layer_info = self._current_info + + if experts_alloc_info is not None: + self.token_dispatcher.set_experts_alloc(experts_alloc_info) + layer_info.set_new_experts_alloc(expert_list, experts_alloc_info) + + layer_info.param_buffer.assign_param(expert_list) + + def _concat_task_per_expert(self, _task_per_expert: List[MoETask]): + device = _task_per_expert[0].device + + token_id_list, expert_id_list, hidden_state_list, token_weight_list = [], [], [], [] + for task in _task_per_expert: + token_num = task.token_id.shape[0] + token_id_list.append(task.token_id) + expert_id_list.append(torch.full((token_num,), task.expert_id, device=device)) + hidden_state_list.append(task.hidden_state) + token_weight_list.append(task.token_weight) + + task_full = MoETask( + layer_id=_task_per_expert[0].layer_id, + token_id=torch.cat(token_id_list), + expert_id=torch.cat(expert_id_list), + hidden_state=torch.cat(hidden_state_list), + token_weight=torch.cat(token_weight_list), + device=device, + ) + + return task_full + + def _distribute_workload(self, task_full: MoETask): + layer_id = task_full.layer_id + + eid, rid = self.token_dispatcher.dispatch_token(layer_id) + device_mask = self._current_info.device_mask + token_id, device_id = torch.where( + device_mask[eid, rid] + ) # TODO: implement a dedicated kernel for batched slice to avoid sync + + device_id, sort_idx = torch.sort(device_id) + token_id = token_id[sort_idx] + device_id_start = torch.searchsorted( + device_id, 
torch.arange(self._current_info.num_devices + 1, device=device_id.device) + ) + pre_split_sizes = device_id_start.diff() + + pre_expert_id = task_full.expert_id[token_id] + pre_hidden_state = task_full.hidden_state[token_id] + pre_token_id = task_full.token_id[token_id] + pre_token_weight = task_full.token_weight[token_id] + + process_split_sizes = torch.empty_like(pre_split_sizes) + torch.distributed.all_to_all_single(process_split_sizes, pre_split_sizes) + pre_split_sizes = pre_split_sizes.tolist() + process_split_sizes = process_split_sizes.tolist() + process_expert_id = global_all_to_all_single(pre_expert_id, pre_split_sizes, process_split_sizes) + process_hidden_state = global_all_to_all_single(pre_hidden_state, pre_split_sizes, process_split_sizes) + + return ( + pre_split_sizes, + pre_token_id, + pre_token_weight, + process_split_sizes, + process_expert_id, + process_hidden_state, + ) + + def _compute_local_experts(self, process_models, process_expert_id, process_hidden_state): + if process_hidden_state.numel() == 0: + return process_hidden_state + process_expert_id, index_sort = torch.sort(process_expert_id) + process_start = torch.searchsorted( + process_expert_id, torch.arange(self._current_info.num_experts + 1, device=process_hidden_state.device) + ) + process_hidden_state = process_hidden_state[index_sort] + result_hidden_state = torch.empty_like(process_hidden_state) + + process_start = process_start.tolist() + for expert_id, expert_model in enumerate(process_models): + if process_start[expert_id] < process_start[expert_id + 1]: + hidden_state = process_hidden_state[process_start[expert_id] : process_start[expert_id + 1]] + hidden_state = expert_model._original_forward(hidden_state) + result_hidden_state[index_sort[process_start[expert_id] : process_start[expert_id + 1]]] = hidden_state + + return result_hidden_state + + def _distribute_result(self, process_hidden_state, process_split_sizes, post_split_sizes): + return global_all_to_all_single(process_hidden_state, process_split_sizes, post_split_sizes) + + def _triger_param_comm_hook(self, layer_id: int): + def hook(*useless): + self._param_buffer.process_reduce_scatter(layer_id) + + return hook + + @FactoryDispatchModeOff() + def launch(self): + layer_id = self._task_per_expert[0].layer_id + self._set_context(layer_id) + + # step 1: call experts_allocator.allocate_experts() and reallocate experts + # we place it at the beginning for overlapping this process with `gather` in ZeRO-2 + self._allocate_experts(layer_id, self._task_per_expert) + + # step 2: call token_dispatcher + task_full = self._concat_task_per_expert(self._task_per_expert) + self.token_dispatcher.assign_task( + layer_id, + token_id=task_full.token_id, + expert_id=task_full.expert_id, + hidden_state=task_full.hidden_state, + token_weight=task_full.token_weight, + ) + task_full.hidden_state.register_hook(self._triger_param_comm_hook(layer_id)) + + # step 3: distribute tokens + ( + post_split_sizes, + post_token_id, + post_token_weight, + process_split_sizes, + process_expert_id, + process_hidden_state, + ) = self._distribute_workload(task_full) + + # step 4: processing experts + process_models = [task.model for task in self._task_per_expert] + process_hidden_state = self._compute_local_experts(process_models, process_expert_id, process_hidden_state) + + # step 5: accumulate the results + post_hidden_state = self._distribute_result(process_hidden_state, process_split_sizes, post_split_sizes) + post_hidden_state *= post_token_weight + + output_tensor = 
self._task_per_expert[0].output_tensor + device_mesh: DeviceMesh = self._task_per_expert[0].output_tensor.device_mesh + placements = self._task_per_expert[0].output_tensor.placements + output_local_tensor = output_tensor._local_tensor + output_local_tensor.index_add_(0, post_token_id, post_hidden_state) + + # step 6: collect entire workload distribution and call experts_allocator.assign_workload() + pass + + # step 7: clear the task list and return the result + self._task_per_expert = [] + return from_local( + output_local_tensor, + device_mesh, + placements, + run_check=False, + ) diff --git a/vescale/moe/_utils.py b/vescale/moe/_utils.py new file mode 100644 index 0000000..9c0180c --- /dev/null +++ b/vescale/moe/_utils.py @@ -0,0 +1,68 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +from typing import List + +_TAG_EXPERTS_PARALLIZED = "_EXPERTS_PARALLIZED" +_MOE_TP = "_MOE_TP" +_MOE_DP = "_MOE_DP" + + +def global_all_to_all_single( + tensor: torch.Tensor, + input_split_sizes: List[int], + output_split_sizes: List[int], + async_op: bool = False, +): + return _AllToAllSingle.apply(tensor, input_split_sizes, output_split_sizes, async_op) + + +class _AllToAllSingle(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input_tensor: torch.Tensor, + input_split_sizes: List[int], + output_split_sizes: List[int], + async_op: bool = False, + ): + ctx.input_split_sizes = input_split_sizes + ctx.output_split_sizes = output_split_sizes + ctx.async_op = async_op + output_tensor = input_tensor.new_empty( + size=[sum(output_split_sizes)] + list(input_tensor.size()[1:]), + dtype=input_tensor.dtype, + device=input_tensor.device, + ) + torch.distributed.all_to_all_single( + output_tensor, + input_tensor, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + async_op=async_op, + ) + return output_tensor + + @staticmethod + def backward(ctx, grad): + return ( + _AllToAllSingle.apply(grad, ctx.output_split_sizes, ctx.input_split_sizes, ctx.async_op), + None, + None, + None, + ) diff --git a/vescale/moe/api.py b/vescale/moe/api.py new file mode 100644 index 0000000..d6d99b7 --- /dev/null +++ b/vescale/moe/api.py @@ -0,0 +1,52 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from typing import Dict, List, Optional, Union +import warnings + +from torch import nn +from vescale.moe._experts import Experts +from vescale.moe.experts_allocator import ExpertsAllocator, BasicExpertsAllocator +from vescale.debug import DebugLogger +from vescale.moe.token_dispatcher import BasicTokenDispatcher, TokenDispatcher + +__all__ = ["parallelize_experts", "is_experts_parallized", "ExpertsAllocator"] + + +def parallelize_experts( + module: nn.Module, + experts_expr: Optional[Union[str, List[str]]] = None, + experts_allocator: Optional[ExpertsAllocator] = None, + token_dispatcher: Optional[TokenDispatcher] = None, + config: Optional[Dict] = None, +) -> nn.Module: + DebugLogger.update_vescale_debug_mode_from_env() + + if Experts.is_experts_parallized(module): + warnings.warn(f"{module} has already parallelized experts. Skip `parallelize_experts`", UserWarning) + return module + + Experts.init_scheduler(module, experts_expr, experts_allocator, token_dispatcher, config) + + Experts.init_forward(config) + + Experts.set_experts_parallized(module) + + return module + + +is_experts_parallized = Experts.is_experts_parallized diff --git a/vescale/moe/experts_allocator.py b/vescale/moe/experts_allocator.py new file mode 100644 index 0000000..5098390 --- /dev/null +++ b/vescale/moe/experts_allocator.py @@ -0,0 +1,82 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +from typing import Dict, Optional, Union, List +import torch +import torch.distributed as dist +from abc import ABC, abstractmethod +from vescale import DeviceMesh +from vescale.moe._utils import _MOE_DP, _MOE_TP + + +class ExpertsAllocator(ABC): + @abstractmethod + def __init__(self, model_config=None, env_config=None): + pass + + @abstractmethod + def collect_performance(self, perf, iter=-1): + pass + + @abstractmethod + def allocate_experts(self, layer_id, iter=-1): + pass + + def allocate_experts_internal(self, layer_id, iter=-1) -> Optional[Dict]: + experts_alloc = self.allocate_experts(layer_id, iter) + if experts_alloc is None: + return None + new_experts_alloc = [] + dp_size, tp_size = [], [] + for alloc in experts_alloc: + assert alloc.ndim == 2 + dp_size.append(alloc.mesh.shape[0]) + tp_size.append(alloc.mesh.shape[1]) + device_type = alloc.device_type + mesh_dim_names = (_MOE_TP, _MOE_DP) + new_alloc = DeviceMesh(device_type, alloc.mesh.t(), mesh_dim_names=mesh_dim_names) + new_experts_alloc.append(new_alloc) + dp_size = torch.tensor(dp_size, device=device_type) + tp_size = torch.tensor(tp_size, device=device_type) + experts_alloc_info = { + "experts_alloc": new_experts_alloc, + "dp_size": dp_size, + "tp_size": tp_size, + } + return experts_alloc_info + + +class BasicExpertsAllocator(ExpertsAllocator): + def __init__(self, exp_config=None, env_config=None): + self.experts_num = 8 + self.visit_flag = set() + + self.experts_allocation = [] # A list of DP * TP + world_size = dist.get_world_size() + devices = torch.arange(world_size) + for _ in range(self.experts_num): + self.experts_allocation.append(DeviceMesh("cuda", devices.reshape(1, -1), mesh_dim_names=("DP", "TP"))) + + def collect_performance(self, perf, iter=-1): + pass + + def allocate_experts(self, layer_id, iter=-1) -> Union[None, List[DeviceMesh]]: + if layer_id not in self.visit_flag: + self.visit_flag.add(layer_id) + return self.experts_allocation + else: + return None diff --git a/vescale/moe/moe_optimizer.py b/vescale/moe/moe_optimizer.py new file mode 100644 index 0000000..751f8a1 --- /dev/null +++ b/vescale/moe/moe_optimizer.py @@ -0,0 +1,107 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import logging +from typing import Optional, Iterable, Dict, Any, Union + +import torch +import torch.distributed.distributed_c10d as c10d + +from vescale.dtensor.dtensor import DTensor +from vescale.dtensor._collective_utils import mesh_reduce_scatter, mesh_wait +from vescale.optim.base_optimizer import OptimizerBase +from vescale.moe._moe_param_buffer import MoEParamBuffer + +try: + import amp_C + from apex.multi_tensor_apply import multi_tensor_applier +except ImportError: + USE_APEX_MULTI_APPLIER = False +else: + USE_APEX_MULTI_APPLIER = True + +logger = logging.getLogger(__name__) + + +class MoEOptimizer(OptimizerBase): + def __init__( + self, + optimizer: torch.optim.Optimizer, + param_buffer: MoEParamBuffer, + clip_grad: float = 0.0, + *args, + **kwargs, + ) -> None: + super().__init__(optimizer=optimizer) + self.clip_grad = clip_grad + self.param_buffer = param_buffer + self.optimizer = optimizer(param_buffer.get_param_group(), *args, **kwargs) + param_buffer.set_optimizer(self.optimizer) + + @torch.no_grad() + def step(self) -> Optional[float]: + self.param_buffer.setup_grad() + self.optimizer.step() + self.param_buffer.process_all_gather() + return 0 + + def zero_grad(self, set_to_none: bool = True) -> None: + self.optimizer.zero_grad(set_to_none=set_to_none) + + def state_dict(self): + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) + + def get_loss_scale(self): + return 1.0 + + def clip_grad_norm(self, clip_grad): + grad_list = [] + grad_norm = 0 + for pg in self.param_groups: + for p in pg["params"]: + if p.grad is not None: + grad_list.append(mesh_wait(p.grad._local_tensor)) + + if not USE_APEX_MULTI_APPLIER: + total_norm = 0 + for grad in grad_list: + total_norm += (grad**2).sum() + else: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + grad_norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [grad_list], + False, + ) + total_norm = grad_norm**2 + + torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.SUM) + total_norm = total_norm.sqrt().item() + + clip_coeff = clip_grad / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + if not USE_APEX_MULTI_APPLIER: + for g in grad_list: + g.data.mul_(clip_coeff) + else: + multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf, [grad_list, grad_list], clip_coeff) + return total_norm diff --git a/vescale/moe/token_dispatcher.py b/vescale/moe/token_dispatcher.py new file mode 100644 index 0000000..28710ff --- /dev/null +++ b/vescale/moe/token_dispatcher.py @@ -0,0 +1,49 @@ +from typing import Dict, List, Optional, Tuple +import torch +import torch.distributed as dist +from abc import ABC, abstractmethod +from vescale import DeviceMesh + + +class TokenDispatcher(ABC): + @abstractmethod + def __init__(self, exp_config=None, env_config=None): + pass + + @abstractmethod + def assign_task(self, layer_id, token_id, expert_id, hidden_state, token_weight): + pass + + @abstractmethod + def set_experts_alloc(self, experts_alloc): + pass + + @abstractmethod + def collect_performance(self, perf, iter=-1): + pass + + @abstractmethod + def dispatch_token(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + +class BasicTokenDispatcher(TokenDispatcher): + def __init__(self, exp_config=None, env_config=None): + 
self.experts_alloc: Optional[List[Optional[DeviceMesh]]] = None + self.expert_id: Optional[torch.Tensor] = None + self.num_replicate: Optional[torch.Tensor] = None + + def assign_task(self, layer_id, token_id, expert_id, hidden_state, token_weight): + self.expert_id = expert_id + + def set_experts_alloc(self, experts_alloc_info: Dict) -> None: + self.experts_alloc = experts_alloc_info["experts_alloc"] + self.num_replicate = experts_alloc_info["dp_size"] + + def collect_performance(self, perf, iter=-1): + pass + + def dispatch_token(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + num_replicate = self.num_replicate[self.expert_id] + replicate_id = torch.randint_like(num_replicate, 65535) % num_replicate + return self.expert_id, replicate_id diff --git a/vescale/optim/checkpoint_helper.py b/vescale/optim/checkpoint_helper.py new file mode 100644 index 0000000..5040289 --- /dev/null +++ b/vescale/optim/checkpoint_helper.py @@ -0,0 +1,30 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + + +import inspect +from .distributed_optimizer import DistributedOptimizer + + +def initialize_optimizer_state(optimizer: DistributedOptimizer): + optimizer._copy_model_grads_to_main_grads() + orig_optimizer = optimizer.optimizer + for group in orig_optimizer.param_groups: + param_list = inspect.signature(orig_optimizer._init_group).parameters + num_params = len(param_list) + args = [group] + [[] for i in range(num_params - 1)] + orig_optimizer._init_group(*args) diff --git a/vescale/optim/utils.py b/vescale/optim/utils.py index f4d75ac..8a10e28 100644 --- a/vescale/optim/utils.py +++ b/vescale/optim/utils.py @@ -59,5 +59,20 @@ def param_is_sharded_or_replicate_on_first_rank(param): return False +def zero_grad_group_helper(group, set_to_none: bool = True): + """Zero out the gradient for a group of parameters. + Note: copied from torch.optim.optimizer.""" + for param in group: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + def param_is_shared(param): return getattr(param, "shared", False)
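Editor's note (not part of the patch): the `zero_grad_group_helper` added to `vescale/optim/utils.py` mirrors `torch.optim.Optimizer.zero_grad` for a bare list of parameters. A small, self-contained usage sketch for reviewers:

```python
import torch

from vescale.optim.utils import zero_grad_group_helper

# Two toy parameters with populated gradients.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
loss = sum((p * p).sum() for p in params)
loss.backward()

# set_to_none=True frees the gradient tensors entirely; set_to_none=False keeps
# them allocated but zeroes them in place, matching torch.optim.Optimizer.zero_grad.
zero_grad_group_helper(params, set_to_none=True)
assert all(p.grad is None for p in params)
```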
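Editor's note (not part of the patch): a rough end-to-end sketch of how the new `vescale.moe` entry points appear intended to be wired together, based only on the signatures introduced in this diff. The module, the experts FQN pattern, and the hyper-parameters below are illustrative assumptions, not values taken from the patch:

```python
import torch
from torch import nn

from vescale.moe import MoEOptimizer, parallelize_experts


def setup_moe(dmodule: nn.Module, num_layers: int, num_experts: int, num_devices: int):
    # Tag the per-layer expert ModuleLists for expert parallelism.
    # parallelize_experts attaches the scheduler's MoEParamBuffer to the module
    # as `moe_param_buffer`.
    dmodule = parallelize_experts(
        dmodule,
        experts_expr=r"model\.layers\.\d+\.block_sparse_moe\.experts",  # hypothetical FQN pattern
        config={"num_layers": num_layers, "num_experts": num_experts, "num_devices": num_devices},
    )

    # MoEOptimizer receives the optimizer *class* and instantiates it over the
    # flattened expert parameter buffers; step() re-gathers the updated experts.
    moe_optim = MoEOptimizer(torch.optim.Adam, dmodule.moe_param_buffer, lr=1e-4)
    return dmodule, moe_optim
```

Note that the default `BasicExpertsAllocator` hard-codes eight experts, each placed on a 1 x world_size (DP x TP) mesh, so a custom `ExpertsAllocator` is needed for other expert counts or layouts.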