From 8bda9750154d21f5d159c1506e71d498025656c7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 5 Jun 2024 06:58:38 +0000 Subject: [PATCH 01/10] fix rope precision for long context --- megatron/model/rotary_pos_embedding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/megatron/model/rotary_pos_embedding.py b/megatron/model/rotary_pos_embedding.py index 4d4497e0cd..d36cc71490 100644 --- a/megatron/model/rotary_pos_embedding.py +++ b/megatron/model/rotary_pos_embedding.py @@ -20,8 +20,9 @@ def __init__(self, dim, theta=10000): raise RuntimeError("einops is required for Rotary Embedding") def forward(self, max_seq_len, offset=0): - seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset - freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) + seq = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=torch.float) + offset + # Force float32 since bfloat16 loses precision on long contexts + freqs = einsum('i , j -> i j', seq, self.inv_freq.float()) # first part even vector components, second part odd vector components, # 2 * dim in dimension size emb = torch.cat((freqs, freqs), dim=-1) From 2e69e220b22009a918a9218081dccf8fcbd4c854 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 18 Jun 2024 13:37:49 +0000 Subject: [PATCH 02/10] enable o_compute aynsc --- megatron/core/tensor_parallel/layers.py | 42 ++-- megatron/model/transformer.py | 16 +- megatron/training.py | 248 +++++++++++++----------- 3 files changed, 175 insertions(+), 131 deletions(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 020d25915a..0ccb633ffe 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -237,16 +237,20 @@ def forward(self, position_ids): class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): """See linear_with_grad_accumulation_and_async_allreduce""" + @staticmethod @custom_fwd def forward(ctx, input, weight, bias, gradient_accumulation_fusion, - async_grad_allreduce, sequence_parallel): + async_grad_allreduce, sequence_parallel, bwd_stream): ctx.save_for_backward(input, weight) ctx.use_bias = bias is not None ctx.gradient_accumulation_fusion = gradient_accumulation_fusion ctx.async_grad_allreduce = async_grad_allreduce ctx.sequence_parallel = sequence_parallel - + if bwd_stream==None: + ctx.bwd_stream=torch.cuda.default_stream() + else: + ctx.bwd_stream=bwd_stream if sequence_parallel: world_size = get_tensor_model_parallel_world_size() dim_size = list(input.size()) @@ -359,8 +363,12 @@ def backward(ctx, grad_output): # grad_weight = None # else: # grad_weight = grad_output.t().matmul(total_input) - grad_weight = grad_output.t().matmul(total_input) - grad_bias = grad_output.sum(dim=0) if use_bias else None + + # to get grad_output + ctx.bwd_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(ctx.bwd_stream): + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.sequence_parallel: handle.wait() @@ -368,9 +376,9 @@ def backward(ctx, grad_output): if ctx.async_grad_allreduce: handle.wait() - - return grad_input, grad_weight, grad_bias, None, None, None - + # torch.cuda.default_stream().wait_stream(ctx.bwd_stream) + return grad_input, grad_weight, grad_bias, None, None, None,None +from typing import Any def linear_with_grad_accumulation_and_async_allreduce( input: torch.Tensor, weight: torch.Tensor, @@ -378,6 +386,7 @@ def 
linear_with_grad_accumulation_and_async_allreduce( gradient_accumulation_fusion: bool, async_grad_allreduce: bool, sequence_parallel: bool, + async_sp_all2all_stream:Any=None ) -> torch.Tensor: """Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop. @@ -438,6 +447,7 @@ def linear_with_grad_accumulation_and_async_allreduce( gradient_accumulation_fusion, async_grad_allreduce, sequence_parallel, + async_sp_all2all_stream ] if not linear_with_grad_accumulation_and_async_allreduce.warned: @@ -494,6 +504,7 @@ class ColumnParallelLinear(torch.nn.Module): config: ModelParallelConfig object """ + stream=None def __init__(self, input_size, output_size, *, config: ModelParallelConfig, @@ -591,7 +602,8 @@ def __init__(self, input_size, output_size, *, "cannot be enabled at the same time." ) - + + def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): @@ -628,6 +640,8 @@ def forward(self, input_parallel = input_ else: input_parallel = copy_to_tensor_model_parallel_region(input_) + + # Matrix multiply. output_parallel = linear_with_grad_accumulation_and_async_allreduce( input=input_parallel, @@ -635,7 +649,8 @@ def forward(self, bias=bias, gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=self.async_tensor_model_parallel_allreduce, - sequence_parallel=self.sequence_parallel + sequence_parallel=self.sequence_parallel, + async_sp_all2all_stream=None ) if self.gather_output and not self.is_expert_without_slicing: # All-gather across the partitions. @@ -690,9 +705,10 @@ def __init__(self, input_size: int, output_size: int, *, stride: int = 1, keep_master_weight_for_test: bool = False, skip_bias_add: bool = False, - moe=False, enable_expert_tensor_parallelism=False): + moe=False, enable_expert_tensor_parallelism=False,ds_sp_sync_stream=None): torch.nn.Module.__init__(self) - + self.ds_sp_sync_stream=ds_sp_sync_stream + # Keep input parameters self.input_size = input_size self.output_size = output_size @@ -768,6 +784,9 @@ def forward(self, input_): assert not self.sequence_parallel input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. + async_sp_all2all_stream=None + if self.ds_sp_sync_stream is not None: + async_sp_all2all_stream=self.ds_sp_sync_stream output_parallel = linear_with_grad_accumulation_and_async_allreduce( input=input_parallel, weight=self.weight, @@ -775,6 +794,7 @@ def forward(self, input_): gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=False, sequence_parallel=False, + async_sp_all2all_stream=async_sp_all2all_stream ) # All-reduce across all the partitions. diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e75f13a24f..198d47fba9 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -496,7 +496,14 @@ class ParallelAttention(MegatronModule): Self-attention layer takes input with size [s, b, h] and returns output of the same size. 
""" - + stream=None + @classmethod + def get_stream(cls): + if cls.stream==None: + cls.stream=torch.cuda.Stream() + # cls.stream=torch.cuda.current_stream() + return cls.stream + def __init__(self, config, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding): @@ -598,7 +605,7 @@ def __init__(self, config, layer_number, if self.enable_ds_sequence_parallel: assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 - self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group()) + self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_stream()) else: if self.use_flash_attn: self.core_attention_flash = local_attn @@ -606,6 +613,7 @@ def __init__(self, config, layer_number, self.core_attention = local_attn self.checkpoint_core_attention = config.recompute_granularity == 'selective' + # Output. self.dense = tensor_parallel.RowParallelLinear( projection_size, @@ -614,7 +622,9 @@ def __init__(self, config, layer_number, init_method=config.output_layer_init_method, bias=args.add_bias_linear, input_is_parallel=True, - skip_bias_add=True) + skip_bias_add=True, + ds_sp_sync_stream=self.get_stream() + ) def _checkpointed_attention_forward(self, query_layer, key_layer, diff --git a/megatron/training.py b/megatron/training.py index 19b8a6c71f..9ad0afe186 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -956,7 +956,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, iteration, ) - if iteration % args.tensorboard_log_interval == 0: + if iteration % args.tensorboard_log_interval == 9999999: # This logging write various optimizer states to tensorboard. This # feature may consume extra GPU memory thus is set at false by default. if args.log_optimizer_states_to_tensorboard and optimizer is not None: @@ -1188,129 +1188,143 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, if args.random_ltd: assert model[0].random_ltd_enabled() args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num() - - while iteration < args.train_iters and (args.train_tokens is None or \ - args.consumed_train_tokens < args.train_tokens): - update_num_microbatches(args.consumed_train_samples) - if args.deepspeed: - # inform deepspeed of any batch size changes - global_batch_size = mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() - model[0].set_train_batch_size(global_batch_size) - if args.curriculum_learning_legacy and not args.no_pipeline_parallel: - curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ - args.iteration + 1) - if iteration == 0 or curriculum_seqlen != args.curriculum_seqlen: - if args.use_rotary_position_embeddings: - update_rotary_pos_emb(curriculum_seqlen) - args.curriculum_seqlen = curriculum_seqlen - args.curr_iteration = iteration - loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ - train_step(forward_step_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config) - iteration += 1 - args.iteration = iteration - new_samples = mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() - args.consumed_train_samples += new_samples - # This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging. 
- args.actual_seq_length = args.seq_length - if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: - args.actual_seq_length = args.curriculum_seqlen - if args.random_ltd: - args.random_ltd_reserved_length = model[0].random_ltd_scheduler.get_current_seq() - if args.random_ltd_reserved_length < args.actual_seq_length: - args.actual_seq_length = (args.actual_seq_length * (args.num_layers - args.random_ltd_layer_num) + args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers - if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: - if hasattr(args, 'data_efficiency_curriculum_learning_numel'): - act_mbsz = args.data_efficiency_curriculum_learning_numel / args.curriculum_seqlen - act_token = act_mbsz * args.actual_seq_length - args.consumed_train_tokens += mpu.get_data_parallel_world_size() * \ - get_num_microbatches() * act_token + with torch.profiler.profile( + schedule=torch.profiler.schedule(skip_first=2, wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0617_cuda_log_official_noact_waitstream_on_ds'), + record_shapes=True, + profile_memory=True, + with_stack=True + ) as prof: + + while iteration < args.train_iters and (args.train_tokens is None or \ + args.consumed_train_tokens < args.train_tokens): + prof.step() + + update_num_microbatches(args.consumed_train_samples) + if args.deepspeed: + # inform deepspeed of any batch size changes + global_batch_size = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + model[0].set_train_batch_size(global_batch_size) + + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ + args.iteration + 1) + if iteration == 0 or curriculum_seqlen != args.curriculum_seqlen: + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(curriculum_seqlen) + args.curriculum_seqlen = curriculum_seqlen + args.curr_iteration = iteration + + + + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) + + iteration += 1 + args.iteration = iteration + new_samples = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + args.consumed_train_samples += new_samples + # This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging. 
+ args.actual_seq_length = args.seq_length + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + args.actual_seq_length = args.curriculum_seqlen + if args.random_ltd: + args.random_ltd_reserved_length = model[0].random_ltd_scheduler.get_current_seq() + if args.random_ltd_reserved_length < args.actual_seq_length: + args.actual_seq_length = (args.actual_seq_length * (args.num_layers - args.random_ltd_layer_num) + args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + if hasattr(args, 'data_efficiency_curriculum_learning_numel'): + act_mbsz = args.data_efficiency_curriculum_learning_numel / args.curriculum_seqlen + act_token = act_mbsz * args.actual_seq_length + args.consumed_train_tokens += mpu.get_data_parallel_world_size() * \ + get_num_microbatches() * act_token + else: + args.consumed_train_tokens += new_samples * args.actual_seq_length else: args.consumed_train_tokens += new_samples * args.actual_seq_length - else: - args.consumed_train_tokens += new_samples * args.actual_seq_length - - # Logging. - if args.deepspeed: - if hasattr(model[0].optimizer, 'cur_scale'): - loss_scale = model[0].optimizer.cur_scale + + # Logging. + if args.deepspeed: + if hasattr(model[0].optimizer, 'cur_scale'): + loss_scale = model[0].optimizer.cur_scale + else: + loss_scale = None else: - loss_scale = None - else: - loss_scale = optimizer.get_loss_scale().item() - params_norm = None - if args.log_params_norm: - params_norm = calc_params_l2_norm(model) - report_memory_flag = training_log(loss_dict, total_loss_dict, - optimizer.param_groups[0]['lr'], - iteration, loss_scale, - report_memory_flag, skipped_iter, - grad_norm, params_norm, num_zeros_in_grad, - model, optimizer) - - # Autoresume - if args.adlr_autoresume and \ - (iteration % args.adlr_autoresume_interval == 0): - check_adlr_autoresume_termination(iteration, model, optimizer, - opt_param_scheduler) - - # Evaluation - if args.eval_interval and iteration % args.eval_interval == 0 and \ - args.do_valid: - prefix = 'iteration {}'.format(iteration) - evaluate_and_print_results(prefix, forward_step_func, - valid_data_iterator, model, - iteration, process_non_loss_data_func, - config, False) - - # Checkpointing - saved_checkpoint = False - if args.exit_signal_handler: - signal_handler = get_signal_handler() - if any(signal_handler.signals_received()): - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - print_datetime('exiting program after receiving SIGTERM.') - sys.exit() - - if args.save and args.save_interval and \ - iteration % args.save_interval == 0: - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - saved_checkpoint = True - - # Exiting based on duration - if args.exit_duration_in_mins: - train_time = (time.time() - _TRAIN_START_TIME) / 60.0 - done_cuda = get_accelerator().IntTensor( - [train_time > args.exit_duration_in_mins]) - torch.distributed.all_reduce( - done_cuda, op=torch.distributed.ReduceOp.MAX) - done = done_cuda.item() - if done: - if not saved_checkpoint: + loss_scale = optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + report_memory_flag = training_log(loss_dict, total_loss_dict, + optimizer.param_groups[0]['lr'], + iteration, loss_scale, + report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, + model, optimizer) + + # Autoresume + if 
args.adlr_autoresume and \ + (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, optimizer, + opt_param_scheduler) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and \ + args.do_valid: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, + config, False) + + # Checkpointing + saved_checkpoint = False + if args.exit_signal_handler: + signal_handler = get_signal_handler() + if any(signal_handler.signals_received()): save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - print_datetime('exiting program after {} minutes'.format(train_time)) - sys.exit() + opt_param_scheduler) + print_datetime('exiting program after receiving SIGTERM.') + sys.exit() - # Exiting based on iterations - if args.exit_interval and iteration % args.exit_interval == 0: - if args.save and not saved_checkpoint: + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - torch.distributed.barrier() - print_datetime('exiting program at iteration {}'.format(iteration)) - sys.exit() + opt_param_scheduler) + saved_checkpoint = True + + # Exiting based on duration + if args.exit_duration_in_mins: + train_time = (time.time() - _TRAIN_START_TIME) / 60.0 + done_cuda = get_accelerator().IntTensor( + [train_time > args.exit_duration_in_mins]) + torch.distributed.all_reduce( + done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + if not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + print_datetime('exiting program after {} minutes'.format(train_time)) + sys.exit() + + # Exiting based on iterations + if args.exit_interval and iteration % args.exit_interval == 0: + if args.save and not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + torch.distributed.barrier() + print_datetime('exiting program at iteration {}'.format(iteration)) + sys.exit() return iteration From f78b8f51ffe3f526115fd0e5f613703c6135c1ab Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Jun 2024 07:12:53 +0000 Subject: [PATCH 03/10] enable qk_bwd_ayncall2all --- megatron/model/__init__.py | 2 + megatron/model/transformer.py | 72 ++++++++++++++++++++++++++--------- megatron/training.py | 4 +- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 2306749fcb..74b29afdde 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -3,6 +3,8 @@ from deepspeed.accelerator.real_accelerator import get_accelerator if get_accelerator().device_name() == 'cuda': from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm + from torch.nn import LayerNorm + from apex.normalization import MixedFusedRMSNorm as RMSNorm else: from .rmsnorm import RMSNorm diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 198d47fba9..71c3722ea6 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -566,13 +566,37 @@ def __init__(self, config, layer_number, # Strided linear layer. 
if attention_type == AttnType.self_attn: - self.query_key_value = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - projection_size + 2 * kv_projection_size, - config=config, - init_method=config.init_method, - bias=args.add_bias_linear, - gather_output=False) + if False: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + projection_size + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + + else: + self.query_linear = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + projection_size , + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + self.key_linear = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + kv_projection_size , + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + self.value_linear = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + kv_projection_size , + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) else: assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( @@ -605,7 +629,7 @@ def __init__(self, config, layer_number, if self.enable_ds_sequence_parallel: assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 - self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_stream()) + self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_stream(),q_linear=self.query_linear,k_linear=self.key_linear) else: if self.use_flash_attn: self.core_attention_flash = local_attn @@ -705,19 +729,29 @@ def forward(self, hidden_states, attention_mask, # ===================== if self.attention_type == AttnType.self_attn: - # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) + if False: + # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)] hidden_states 4096, 1, 2048 + mixed_x_layer, _ = self.query_key_value(hidden_states) #heads16 hidden 2048 num_per_head 128 + #[4096, 1,6144] -> 16,3,128 + # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (-1, (self.num_key_value_groups + 2), + self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn] + (query_layer, #[4096,1,16,128] + key_layer, #[4096,1,16,128] + value_layer) = self.split_tensor(mixed_x_layer) #[4096,1,16,128] + else: + query_layer,_ = self.query_linear(hidden_states) + query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1) - # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (-1, (self.num_key_value_groups + 2), - self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + key_layer,_ = self.key_linear(hidden_states) + key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1) - # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn] - (query_layer, - key_layer, - value_layer) = self.split_tensor(mixed_x_layer) 
+ value_layer,_ = self.value_linear(hidden_states) + value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1) # Repeat kv if self.use_gqa: diff --git a/megatron/training.py b/megatron/training.py index 9ad0afe186..1c40479067 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1190,8 +1190,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num() with torch.profiler.profile( - schedule=torch.profiler.schedule(skip_first=2, wait=1, warmup=1, active=2, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0617_cuda_log_official_noact_waitstream_on_ds'), + schedule=torch.profiler.schedule(skip_first=2000000, wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0617_cuda_log_asyncqkv_changetuple_falsebench'), record_shapes=True, profile_memory=True, with_stack=True From 0be1080d505afb5e5a8cf99bf9cb902611bb3be7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 06:09:40 +0000 Subject: [PATCH 04/10] fwd optim --- megatron/model/__init__.py | 2 -- megatron/model/transformer.py | 36 ++++++++++++++++++++++++++--------- megatron/training.py | 10 +++++++--- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 74b29afdde..2306749fcb 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -3,8 +3,6 @@ from deepspeed.accelerator.real_accelerator import get_accelerator if get_accelerator().device_name() == 'cuda': from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm - from torch.nn import LayerNorm - from apex.normalization import MixedFusedRMSNorm as RMSNorm else: from .rmsnorm import RMSNorm diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 71c3722ea6..4a2ca13cd9 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -497,12 +497,18 @@ class ParallelAttention(MegatronModule): and returns output of the same size. 
""" stream=None + stream2=None @classmethod def get_stream(cls): if cls.stream==None: cls.stream=torch.cuda.Stream() # cls.stream=torch.cuda.current_stream() return cls.stream + def get_stream2(cls): + if cls.stream2==None: + cls.stream2=torch.cuda.Stream() + # cls.stream=torch.cuda.current_stream() + return cls.stream2 def __init__(self, config, layer_number, attention_type=AttnType.self_attn, @@ -744,14 +750,24 @@ def forward(self, hidden_states, attention_mask, key_layer, #[4096,1,16,128] value_layer) = self.split_tensor(mixed_x_layer) #[4096,1,16,128] else: - query_layer,_ = self.query_linear(hidden_states) - query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1) - - key_layer,_ = self.key_linear(hidden_states) - key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1) - - value_layer,_ = self.value_linear(hidden_states) - value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1) + self.get_stream().wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(torch.cuda.current_stream()): + with torch.cuda.stream(self.get_stream()): + query_layer,_ = self.query_linear(hidden_states) + query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1) + fwd_query_layer_done_event = torch.cuda.Event() + fwd_query_layer_done_event.record(self.get_stream()) + key_layer,_ = self.key_linear(hidden_states) + key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1) + + fwd_key_layer_done_event = torch.cuda.Event() + fwd_key_layer_done_event.record(self.get_stream()) + # key_layer.done_event=fwd_key_layer_done_event + value_layer,_ = self.value_linear(hidden_states) + value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1) + # fwd_value_layer_done_event = torch.cuda.Event() + # fwd_value_layer_done_event.record(torch.cuda.current_stream()) + # torch.cuda.current_stream().wait_stream(self.get_stream()) # Repeat kv if self.use_gqa: @@ -851,7 +867,9 @@ def forward(self, hidden_states, attention_mask, if not self.use_flash_attn_triton: query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous() for x in (query_layer, key_layer, value_layer)] - + key_layer.done_event=fwd_key_layer_done_event + query_layer.done_event=fwd_query_layer_done_event + # value_layer.done_event=fwd_value_layer_done_event context_layer = self.dist_attn(query_layer, key_layer, value_layer) if not self.use_flash_attn_triton: diff --git a/megatron/training.py b/megatron/training.py index 1c40479067..7cf42ebe47 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -753,7 +753,11 @@ def train_step(forward_step_func, data_iterator, # Update learning rate. 
if args.deepspeed: skipped_iter = 0 - grad_norm = None + # grad_norm = None + if hasattr(model[0], 'get_global_grad_norm'): + grad_norm = model[0].get_global_grad_norm() + else: + grad_norm = None num_zeros_in_grad = None loss_reduced = {} @@ -1190,8 +1194,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num() with torch.profiler.profile( - schedule=torch.profiler.schedule(skip_first=2000000, wait=1, warmup=1, active=2, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0617_cuda_log_asyncqkv_changetuple_falsebench'), + schedule=torch.profiler.schedule(skip_first=10000, wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0621_cuda_log_fwd_enventq+asyncqkv_changetuple_finalbench_'), record_shapes=True, profile_memory=True, with_stack=True From bdc9f917a2abd54471e18d2f584331740e6d35dd Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 25 Jun 2024 10:54:47 +0800 Subject: [PATCH 05/10] fix arg --- megatron/model/transformer.py | 4 ++-- megatron/training.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4a2ca13cd9..c7d4b36e23 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -635,7 +635,7 @@ def __init__(self, config, layer_number, if self.enable_ds_sequence_parallel: assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 - self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_stream(),q_linear=self.query_linear,k_linear=self.key_linear) + self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_stream()) else: if self.use_flash_attn: self.core_attention_flash = local_attn @@ -751,8 +751,8 @@ def forward(self, hidden_states, attention_mask, value_layer) = self.split_tensor(mixed_x_layer) #[4096,1,16,128] else: self.get_stream().wait_stream(torch.cuda.current_stream()) - # with torch.cuda.stream(torch.cuda.current_stream()): with torch.cuda.stream(self.get_stream()): + # with torch.cuda.stream(self.get_stream()): query_layer,_ = self.query_linear(hidden_states) query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1) fwd_query_layer_done_event = torch.cuda.Event() diff --git a/megatron/training.py b/megatron/training.py index 7cf42ebe47..a6fc093504 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1194,8 +1194,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num() with torch.profiler.profile( - schedule=torch.profiler.schedule(skip_first=10000, wait=1, warmup=1, active=2, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0621_cuda_log_fwd_enventq+asyncqkv_changetuple_finalbench_'), + schedule=torch.profiler.schedule(skip_first=100000, wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0621_cuda_log_aysnqkv+optimicpu'), record_shapes=True, profile_memory=True, with_stack=True 
From e9f8e99e27bd628fb569504b9cfd44a9c3a4d65a Mon Sep 17 00:00:00 2001 From: inkcherry Date: Sun, 30 Jun 2024 06:54:35 +0000 Subject: [PATCH 06/10] use current_stream --- megatron/core/tensor_parallel/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 0ccb633ffe..552b889c0f 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -248,7 +248,7 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, ctx.async_grad_allreduce = async_grad_allreduce ctx.sequence_parallel = sequence_parallel if bwd_stream==None: - ctx.bwd_stream=torch.cuda.default_stream() + ctx.bwd_stream=torch.cuda.current_stream() else: ctx.bwd_stream=bwd_stream if sequence_parallel: From 231d2e9e7393ef7a28a93644184f55c84dbae402 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Sun, 30 Jun 2024 12:30:18 +0000 Subject: [PATCH 07/10] split qkv + sp overlap comm --- megatron/arguments.py | 10 +- megatron/core/tensor_parallel/layers.py | 50 ++--- megatron/model/transformer.py | 104 +++++----- megatron/training.py | 254 +++++++++++------------- 4 files changed, 196 insertions(+), 222 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index e7182c317e..f6b571fb44 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -95,7 +95,9 @@ def validate_args(args, defaults={}): if args.ds_sequence_parallel_size > 1: assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+" - + if args.ds_sequence_parallel_overlap_comm: + assert args.split_qkv_linear, \ + "ds_sequence_parallel_overlap_comm requires split_qkv_linear is True" model_parallel_size = args.pipeline_model_parallel_size * \ args.tensor_model_parallel_size * \ args.ds_sequence_parallel_size @@ -912,6 +914,9 @@ def _add_training_args(parser): group.add_argument('--disable-moe-top2-2nd-expert-sampling', action='store_false', help='Disable MoE top2 sampling of the 2nd expert. Instead of sampling, use argmax.', dest='moe_top2_2nd_expert_sampling') + group.add_argument('--split-qkv-linear', action='store_true', + help='Separate linear computations for query, key, and value.', + dest='split_qkv_linear') group.add_argument('--use-flash-attn', '--use-flash-attn-v1', dest='use_flash_attn_v1', action='store_true', help='use first version FlashAttention implementation of attention. ' 'https://arxiv.org/abs/2205.14135') @@ -963,6 +968,9 @@ def _add_training_args(parser): help='Enable DeepSpeed\'s sequence parallel. 
Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.') group.add_argument('--force-ds-sequence-parallel', action='store_true', help='use DeepSpeed sequence parallelism regardless of sequence parallel size.') + group.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true', + help='overlap comm for ds-sequence-parallel', + dest='ds_sequence_parallel_overlap_comm') group.add_argument('--no-gradient-accumulation-fusion', action='store_false', help='Disable fusing gradient accumulation to weight ' diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 5d46ea36c2..c08622a498 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -19,6 +19,7 @@ from megatron import get_args from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore from megatron.core.parallel_state import ( get_tensor_model_parallel_rank, @@ -244,20 +245,17 @@ def gradientUpdateFunction(total_input, grad_output, weight): class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): """See linear_with_grad_accumulation_and_async_allreduce""" - @staticmethod @custom_fwd def forward(ctx, input, weight, bias, gradient_accumulation_fusion, - async_grad_allreduce, sequence_parallel, bwd_stream): + async_grad_allreduce, sequence_parallel, bwd_stream=None): ctx.save_for_backward(input, weight) ctx.use_bias = bias is not None ctx.gradient_accumulation_fusion = gradient_accumulation_fusion ctx.async_grad_allreduce = async_grad_allreduce ctx.sequence_parallel = sequence_parallel - if bwd_stream==None: - ctx.bwd_stream=torch.cuda.current_stream() - else: - ctx.bwd_stream=bwd_stream + ctx.bwd_stream = bwd_stream + if sequence_parallel: world_size = get_tensor_model_parallel_world_size() dim_size = list(input.size()) @@ -316,6 +314,7 @@ def backward(ctx, grad_output): total_input = all_gather_buffer else: total_input = input + grad_input = grad_output.matmul(weight) if ctx.sequence_parallel: @@ -370,12 +369,16 @@ def backward(ctx, grad_output): # grad_weight = None # else: # grad_weight = grad_output.t().matmul(total_input) - from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore - ctx.bwd_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(ctx.bwd_stream): + if ctx.bwd_stream is not None: + ctx.bwd_stream.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(ctx.bwd_stream): + WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction) + ctx.bwd_stream.activation_buffer_list = [total_input, grad_output] + else: WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction) - grad_weight = None - grad_bias = grad_output.sum(dim=0) if use_bias else None + + grad_weight = None + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.sequence_parallel: handle.wait() @@ -383,9 +386,8 @@ def backward(ctx, grad_output): if ctx.async_grad_allreduce: handle.wait() - # torch.cuda.default_stream().wait_stream(ctx.bwd_stream) - return grad_input, grad_weight, grad_bias, None, None, None,None -from typing import Any + return grad_input, grad_weight, grad_bias, None, None, None, None + def linear_with_grad_accumulation_and_async_allreduce( input: torch.Tensor, weight: torch.Tensor, @@ -393,7 +395,7 @@ def linear_with_grad_accumulation_and_async_allreduce( gradient_accumulation_fusion: bool, 
async_grad_allreduce: bool, sequence_parallel: bool, - async_sp_all2all_stream:Any=None + async_sp_all2all_stream=None ) -> torch.Tensor: """Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop. @@ -511,7 +513,6 @@ class ColumnParallelLinear(torch.nn.Module): config: ModelParallelConfig object """ - stream=None def __init__(self, input_size, output_size, *, config: ModelParallelConfig, @@ -609,8 +610,6 @@ def __init__(self, input_size, output_size, *, "cannot be enabled at the same time." ) - - def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): @@ -647,8 +646,6 @@ def forward(self, input_parallel = input_ else: input_parallel = copy_to_tensor_model_parallel_region(input_) - - # Matrix multiply. output_parallel = linear_with_grad_accumulation_and_async_allreduce( input=input_parallel, @@ -656,8 +653,7 @@ def forward(self, bias=bias, gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=self.async_tensor_model_parallel_allreduce, - sequence_parallel=self.sequence_parallel, - async_sp_all2all_stream=None + sequence_parallel=self.sequence_parallel ) if self.gather_output and not self.is_expert_without_slicing: # All-gather across the partitions. @@ -712,9 +708,9 @@ def __init__(self, input_size: int, output_size: int, *, stride: int = 1, keep_master_weight_for_test: bool = False, skip_bias_add: bool = False, - moe=False, enable_expert_tensor_parallelism=False,ds_sp_sync_stream=None): + moe=False, enable_expert_tensor_parallelism=False, ds_sp_async_stream=None): torch.nn.Module.__init__(self) - self.ds_sp_sync_stream=ds_sp_sync_stream + self.ds_sp_async_stream = ds_sp_async_stream # Keep input parameters self.input_size = input_size @@ -791,9 +787,7 @@ def forward(self, input_): assert not self.sequence_parallel input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. - async_sp_all2all_stream=None - if self.ds_sp_sync_stream is not None: - async_sp_all2all_stream=self.ds_sp_sync_stream + output_parallel = linear_with_grad_accumulation_and_async_allreduce( input=input_parallel, weight=self.weight, @@ -801,7 +795,7 @@ def forward(self, input_): gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=False, sequence_parallel=False, - async_sp_all2all_stream=async_sp_all2all_stream + async_sp_all2all_stream=self.ds_sp_async_stream ) # All-reduce across all the partitions. diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c7d4b36e23..ddf900f561 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -496,20 +496,14 @@ class ParallelAttention(MegatronModule): Self-attention layer takes input with size [s, b, h] and returns output of the same size. 
""" - stream=None - stream2=None - @classmethod - def get_stream(cls): - if cls.stream==None: - cls.stream=torch.cuda.Stream() - # cls.stream=torch.cuda.current_stream() - return cls.stream - def get_stream2(cls): - if cls.stream2==None: - cls.stream2=torch.cuda.Stream() - # cls.stream=torch.cuda.current_stream() - return cls.stream2 + sp_stream=None + def get_sp_stream(self): + if not self.ds_sp_overlap: + return None + if ParallelAttention.sp_stream is None: + ParallelAttention.sp_stream=get_accelerator().Stream() + return ParallelAttention.sp_stream def __init__(self, config, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding): @@ -523,7 +517,8 @@ def __init__(self, config, layer_number, self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.use_gqa = (self.num_attention_heads != self.num_key_value_heads) - + self.split_qkv = args.split_qkv_linear + self.ds_sp_overlap = args.ds_sequence_parallel_overlap_comm self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \ args.use_flash_attn_builder) \ and attention_type == AttnType.self_attn \ @@ -572,7 +567,7 @@ def __init__(self, config, layer_number, # Strided linear layer. if attention_type == AttnType.self_attn: - if False: + if not self.split_qkv: self.query_key_value = tensor_parallel.ColumnParallelLinear( config.hidden_size, projection_size + 2 * kv_projection_size, @@ -582,27 +577,21 @@ def __init__(self, config, layer_number, gather_output=False) else: - self.query_linear = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - projection_size , - config=config, - init_method=config.init_method, - bias=args.add_bias_linear, - gather_output=False) - self.key_linear = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - kv_projection_size , - config=config, - init_method=config.init_method, - bias=args.add_bias_linear, - gather_output=False) - self.value_linear = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - kv_projection_size , - config=config, - init_method=config.init_method, - bias=args.add_bias_linear, - gather_output=False) + linear_configs = [ + ("query_linear", projection_size), + ("key_linear", kv_projection_size), + ("value_linear", kv_projection_size), + ] + + for attr_name, output_size in linear_configs: + setattr(self, attr_name, tensor_parallel.ColumnParallelLinear( + config.hidden_size, + output_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False + )) else: assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( @@ -633,9 +622,10 @@ def __init__(self, config, layer_number, self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \ or args.force_ds_sequence_parallel if self.enable_ds_sequence_parallel: + assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 - self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_stream()) + self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group(),sp_stream=self.get_sp_stream()) else: if self.use_flash_attn: self.core_attention_flash = local_attn @@ -643,7 +633,6 @@ def __init__(self, config, layer_number, self.core_attention = local_attn self.checkpoint_core_attention 
= config.recompute_granularity == 'selective' - # Output. self.dense = tensor_parallel.RowParallelLinear( projection_size, @@ -653,7 +642,7 @@ def __init__(self, config, layer_number, bias=args.add_bias_linear, input_is_parallel=True, skip_bias_add=True, - ds_sp_sync_stream=self.get_stream() + ds_sp_async_stream=self.get_sp_stream() ) @@ -735,7 +724,7 @@ def forward(self, hidden_states, attention_mask, # ===================== if self.attention_type == AttnType.self_attn: - if False: + if not self.split_qkv: # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)] hidden_states 4096, 1, 2048 mixed_x_layer, _ = self.query_key_value(hidden_states) #heads16 hidden 2048 num_per_head 128 #[4096, 1,6144] -> 16,3,128 @@ -746,28 +735,28 @@ def forward(self, hidden_states, attention_mask, mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn] - (query_layer, #[4096,1,16,128] - key_layer, #[4096,1,16,128] - value_layer) = self.split_tensor(mixed_x_layer) #[4096,1,16,128] + (query_layer, + key_layer, + value_layer) = self.split_tensor(mixed_x_layer) else: - self.get_stream().wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.get_stream()): - # with torch.cuda.stream(self.get_stream()): + assert self.ds_sp_overlap, """ + Currently, the split_qkv operation is only applicable + when ds_sp_overlap is enabled. + """ + self.get_sp_stream().wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(self.get_sp_stream()): query_layer,_ = self.query_linear(hidden_states) query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1) - fwd_query_layer_done_event = torch.cuda.Event() - fwd_query_layer_done_event.record(self.get_stream()) + fwd_query_layer_done_event = get_accelerator().Event() + fwd_query_layer_done_event.record(self.get_sp_stream()) key_layer,_ = self.key_linear(hidden_states) key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1) - fwd_key_layer_done_event = torch.cuda.Event() - fwd_key_layer_done_event.record(self.get_stream()) - # key_layer.done_event=fwd_key_layer_done_event + fwd_key_layer_done_event = get_accelerator().Event() + fwd_key_layer_done_event.record(self.get_sp_stream()) value_layer,_ = self.value_linear(hidden_states) value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1) - # fwd_value_layer_done_event = torch.cuda.Event() - # fwd_value_layer_done_event.record(torch.cuda.current_stream()) - # torch.cuda.current_stream().wait_stream(self.get_stream()) + # Repeat kv if self.use_gqa: @@ -867,9 +856,10 @@ def forward(self, hidden_states, attention_mask, if not self.use_flash_attn_triton: query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous() for x in (query_layer, key_layer, value_layer)] - key_layer.done_event=fwd_key_layer_done_event - query_layer.done_event=fwd_query_layer_done_event - # value_layer.done_event=fwd_value_layer_done_event + if self.ds_sp_overlap: + key_layer.done_event=fwd_key_layer_done_event + query_layer.done_event=fwd_query_layer_done_event + context_layer = self.dist_attn(query_layer, key_layer, value_layer) if not self.use_flash_attn_triton: diff --git a/megatron/training.py b/megatron/training.py index a6fc093504..19b8a6c71f 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -753,11 +753,7 @@ def train_step(forward_step_func, data_iterator, # Update learning rate. 
if args.deepspeed: skipped_iter = 0 - # grad_norm = None - if hasattr(model[0], 'get_global_grad_norm'): - grad_norm = model[0].get_global_grad_norm() - else: - grad_norm = None + grad_norm = None num_zeros_in_grad = None loss_reduced = {} @@ -960,7 +956,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, iteration, ) - if iteration % args.tensorboard_log_interval == 9999999: + if iteration % args.tensorboard_log_interval == 0: # This logging write various optimizer states to tensorboard. This # feature may consume extra GPU memory thus is set at false by default. if args.log_optimizer_states_to_tensorboard and optimizer is not None: @@ -1192,144 +1188,130 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, if args.random_ltd: assert model[0].random_ltd_enabled() args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num() - - with torch.profiler.profile( - schedule=torch.profiler.schedule(skip_first=100000, wait=1, warmup=1, active=2, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('/app/mingzhil/zhejiang/Megatron-DeepSpeed/0621_cuda_log_aysnqkv+optimicpu'), - record_shapes=True, - profile_memory=True, - with_stack=True - ) as prof: - - while iteration < args.train_iters and (args.train_tokens is None or \ - args.consumed_train_tokens < args.train_tokens): - prof.step() - - update_num_microbatches(args.consumed_train_samples) - if args.deepspeed: - # inform deepspeed of any batch size changes - global_batch_size = mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() - model[0].set_train_batch_size(global_batch_size) - - if args.curriculum_learning_legacy and not args.no_pipeline_parallel: - curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ - args.iteration + 1) - if iteration == 0 or curriculum_seqlen != args.curriculum_seqlen: - if args.use_rotary_position_embeddings: - update_rotary_pos_emb(curriculum_seqlen) - args.curriculum_seqlen = curriculum_seqlen - args.curr_iteration = iteration - - - loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ - train_step(forward_step_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config) - - iteration += 1 - args.iteration = iteration - new_samples = mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() - args.consumed_train_samples += new_samples - # This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging. 
- args.actual_seq_length = args.seq_length - if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: - args.actual_seq_length = args.curriculum_seqlen - if args.random_ltd: - args.random_ltd_reserved_length = model[0].random_ltd_scheduler.get_current_seq() - if args.random_ltd_reserved_length < args.actual_seq_length: - args.actual_seq_length = (args.actual_seq_length * (args.num_layers - args.random_ltd_layer_num) + args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers - if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: - if hasattr(args, 'data_efficiency_curriculum_learning_numel'): - act_mbsz = args.data_efficiency_curriculum_learning_numel / args.curriculum_seqlen - act_token = act_mbsz * args.actual_seq_length - args.consumed_train_tokens += mpu.get_data_parallel_world_size() * \ - get_num_microbatches() * act_token - else: - args.consumed_train_tokens += new_samples * args.actual_seq_length + while iteration < args.train_iters and (args.train_tokens is None or \ + args.consumed_train_tokens < args.train_tokens): + update_num_microbatches(args.consumed_train_samples) + if args.deepspeed: + # inform deepspeed of any batch size changes + global_batch_size = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + model[0].set_train_batch_size(global_batch_size) + + if args.curriculum_learning_legacy and not args.no_pipeline_parallel: + curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ + args.iteration + 1) + if iteration == 0 or curriculum_seqlen != args.curriculum_seqlen: + if args.use_rotary_position_embeddings: + update_rotary_pos_emb(curriculum_seqlen) + args.curriculum_seqlen = curriculum_seqlen + args.curr_iteration = iteration + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) + iteration += 1 + args.iteration = iteration + new_samples = mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + args.consumed_train_samples += new_samples + # This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging. + args.actual_seq_length = args.seq_length + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + args.actual_seq_length = args.curriculum_seqlen + if args.random_ltd: + args.random_ltd_reserved_length = model[0].random_ltd_scheduler.get_current_seq() + if args.random_ltd_reserved_length < args.actual_seq_length: + args.actual_seq_length = (args.actual_seq_length * (args.num_layers - args.random_ltd_layer_num) + args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers + if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: + if hasattr(args, 'data_efficiency_curriculum_learning_numel'): + act_mbsz = args.data_efficiency_curriculum_learning_numel / args.curriculum_seqlen + act_token = act_mbsz * args.actual_seq_length + args.consumed_train_tokens += mpu.get_data_parallel_world_size() * \ + get_num_microbatches() * act_token else: args.consumed_train_tokens += new_samples * args.actual_seq_length - - # Logging. - if args.deepspeed: - if hasattr(model[0].optimizer, 'cur_scale'): - loss_scale = model[0].optimizer.cur_scale - else: - loss_scale = None + else: + args.consumed_train_tokens += new_samples * args.actual_seq_length + + # Logging. 
+ if args.deepspeed: + if hasattr(model[0].optimizer, 'cur_scale'): + loss_scale = model[0].optimizer.cur_scale else: - loss_scale = optimizer.get_loss_scale().item() - params_norm = None - if args.log_params_norm: - params_norm = calc_params_l2_norm(model) - report_memory_flag = training_log(loss_dict, total_loss_dict, - optimizer.param_groups[0]['lr'], - iteration, loss_scale, - report_memory_flag, skipped_iter, - grad_norm, params_norm, num_zeros_in_grad, - model, optimizer) - - # Autoresume - if args.adlr_autoresume and \ - (iteration % args.adlr_autoresume_interval == 0): - check_adlr_autoresume_termination(iteration, model, optimizer, - opt_param_scheduler) - - # Evaluation - if args.eval_interval and iteration % args.eval_interval == 0 and \ - args.do_valid: - prefix = 'iteration {}'.format(iteration) - evaluate_and_print_results(prefix, forward_step_func, - valid_data_iterator, model, - iteration, process_non_loss_data_func, - config, False) - - # Checkpointing - saved_checkpoint = False - if args.exit_signal_handler: - signal_handler = get_signal_handler() - if any(signal_handler.signals_received()): - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - print_datetime('exiting program after receiving SIGTERM.') - sys.exit() - - if args.save and args.save_interval and \ - iteration % args.save_interval == 0: + loss_scale = None + else: + loss_scale = optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + report_memory_flag = training_log(loss_dict, total_loss_dict, + optimizer.param_groups[0]['lr'], + iteration, loss_scale, + report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, + model, optimizer) + + # Autoresume + if args.adlr_autoresume and \ + (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, optimizer, + opt_param_scheduler) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and \ + args.do_valid: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, + config, False) + + # Checkpointing + saved_checkpoint = False + if args.exit_signal_handler: + signal_handler = get_signal_handler() + if any(signal_handler.signals_received()): save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - saved_checkpoint = True - - # Exiting based on duration - if args.exit_duration_in_mins: - train_time = (time.time() - _TRAIN_START_TIME) / 60.0 - done_cuda = get_accelerator().IntTensor( - [train_time > args.exit_duration_in_mins]) - torch.distributed.all_reduce( - done_cuda, op=torch.distributed.ReduceOp.MAX) - done = done_cuda.item() - if done: - if not saved_checkpoint: - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - print_datetime('exiting program after {} minutes'.format(train_time)) - sys.exit() - - # Exiting based on iterations - if args.exit_interval and iteration % args.exit_interval == 0: - if args.save and not saved_checkpoint: + opt_param_scheduler) + print_datetime('exiting program after receiving SIGTERM.') + sys.exit() + + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + saved_checkpoint = True + + # Exiting based on duration + if args.exit_duration_in_mins: + train_time = (time.time() - 
_TRAIN_START_TIME) / 60.0 + done_cuda = get_accelerator().IntTensor( + [train_time > args.exit_duration_in_mins]) + torch.distributed.all_reduce( + done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + if not saved_checkpoint: save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - torch.distributed.barrier() - print_datetime('exiting program at iteration {}'.format(iteration)) + opt_param_scheduler) + print_datetime('exiting program after {} minutes'.format(train_time)) sys.exit() + # Exiting based on iterations + if args.exit_interval and iteration % args.exit_interval == 0: + if args.save and not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + torch.distributed.barrier() + print_datetime('exiting program at iteration {}'.format(iteration)) + sys.exit() + return iteration From 35d6d54d2d86b212ef4d589e97ead53bd0c88fe5 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 5 Jul 2024 16:46:57 +0800 Subject: [PATCH 08/10] revert rope change --- megatron/model/rotary_pos_embedding.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/megatron/model/rotary_pos_embedding.py b/megatron/model/rotary_pos_embedding.py index d36cc71490..4d4497e0cd 100644 --- a/megatron/model/rotary_pos_embedding.py +++ b/megatron/model/rotary_pos_embedding.py @@ -20,9 +20,8 @@ def __init__(self, dim, theta=10000): raise RuntimeError("einops is required for Rotary Embedding") def forward(self, max_seq_len, offset=0): - seq = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=torch.float) + offset - # Force float32 since bfloat16 loses precision on long contexts - freqs = einsum('i , j -> i j', seq, self.inv_freq.float()) + seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset + freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size emb = torch.cat((freqs, freqs), dim=-1) From 08fedf4040ec7c460505222620ff35f932ae176e Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 30 Aug 2024 05:00:34 +0000 Subject: [PATCH 09/10] fix merge --- megatron/model/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index cc8581792a..617881e3a4 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -875,7 +875,7 @@ def forward(self, hidden_states, attention_mask, if not self.use_flash_attn_triton: query_layer, key_layer, value_layer = [rearrange(x, 's b ... 
-> b s ...').contiguous() for x in (query_layer, key_layer, value_layer)] - batch_dim_idx = 0 + batch_dim_idx = 0 context_layer = self.dist_attn(query_layer, key_layer, value_layer, batch_dim_idx) From 8516213c3dca0b26e97aac92cbd4f5010b0e6d30 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 14 Nov 2024 09:35:03 +0000 Subject: [PATCH 10/10] merge branch --- megatron/core/tensor_parallel/layers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index a9e7beec53..58e1d1f976 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -377,17 +377,17 @@ def backward(ctx, grad_output): ctx.bwd_stream.wait_stream(get_accelerator().current_stream()) with get_accelerator().stream(ctx.bwd_stream): WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction) - ctx.bwd_stream.activation_buffer_list = [total_input, grad_output] grad_weight = None - if args.enable_zbh1_pipeline: - from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore + elif args.enable_zbh1_pipeline: WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction) grad_weight = None else: grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None - + if ctx.bwd_stream is not None: + total_input.record_stream(ctx.bwd_stream) + grad_output.record_stream(ctx.bwd_stream) if ctx.sequence_parallel: handle.wait() return sub_grad_input, grad_weight, grad_bias, None, None, None
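
Note on the overlap pattern these patches build toward (PATCH 07/10 and 10/10 in the backward pass, PATCH 04/10 and 07/10 in the forward q/k/v path): grad_input is kept on the current stream because the next layer's backward and the async all2all/all-reduce depend on it, while grad_weight is queued on a side stream so it can run concurrently with that communication; record_stream (or keeping explicit references) protects the tensors the side stream still needs. The sketch below is illustrative only, not the Megatron/DeepSpeed API: the function name is made up, inputs are assumed to be 2D [tokens, features], and the side stream is assumed to be a torch.cuda.Stream created once and reused.

    import torch

    def backward_with_side_stream(total_input, grad_output, weight, side_stream):
        # grad_input stays on the current stream: the next layer's backward
        # (and any async collective issued on it) depends on this result.
        grad_input = grad_output.matmul(weight)

        # The side stream must observe everything already queued on the
        # current stream before it reads grad_output / total_input.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            # Weight gradient overlaps with communication on the main stream.
            grad_weight = grad_output.t().matmul(total_input)

        # Tell the caching allocator these tensors are still in use by the
        # side stream, so their memory is not reused too early.
        total_input.record_stream(side_stream)
        grad_output.record_stream(side_stream)
        return grad_input, grad_weight

The forward-path changes follow the same idea: the split query/key/value projections run on the sequence-parallel stream, CUDA events are recorded after the q and k projections and attached to the tensors handed to DistributedAttention, so the q/k all2all can start as soon as its input is ready while the value projection is still executing.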