Skip to content

Commit

Permalink
[misc] feat: support mfu calculation (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
vermouth1992 authored Jan 19, 2025
1 parent 1ec5eb5 commit 41f645d
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 2 deletions.
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ actor_rollout_ref:
param_offload: False
grad_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
log_prob_micro_batch_size: 128
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
Expand Down Expand Up @@ -94,6 +96,7 @@ critic:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: 64
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
Expand Down
3 changes: 3 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,9 @@ def fit(self):
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
Expand Down
123 changes: 123 additions & 0 deletions verl/utils/flops_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import PretrainedConfig, Qwen2Config, LlamaConfig

VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)


def get_device_flops(unit="T"):

def unit_convert(number, level):
units = ["B", "K", "M", "G", "T", "P"]
if number <= 0:
return number
ptr = 0
while ptr < len(units) and units[ptr] != level:
number /= 1000
ptr += 1
return number

device_name = torch.cuda.get_device_name()
flops = float("inf") # INF flops for unkown gpu type
if "H100" in device_name or "H800" in device_name:
flops = 989e12
elif "A100" in device_name or "A800" in device_name:
flops = 312e12
elif "L40" in device_name:
flops = 181.05e12
elif "L20" in device_name:
flops = 119.5e12
elif "H20" in device_name:
flops = 148e12
elif "910B" in device_name:
flops = 354e12
flops_unit = unit_convert(flops, unit)
return flops_unit


class FlopsCounter:
"""
Used to count mfu during training loop
Example:
flops_counter = FlopsCounter(config)
flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
"""

def __init__(self, config: PretrainedConfig):
if not isinstance(config, VALID_CONFIG_TYPE):
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. "
f"MFU will always be zero.")

self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops}
self.config = config

def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
return 0

def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
hidden_size = self.config.hidden_size
vocab_size = self.config.vocab_size
num_hidden_layers = self.config.num_hidden_layers
num_key_value_heads = self.config.num_key_value_heads
num_attention_heads = self.config.num_attention_heads
intermediate_size = self.config.intermediate_size

head_dim = hidden_size // num_attention_heads
q_size = num_attention_heads * head_dim
k_size = num_key_value_heads * head_dim
v_size = num_key_value_heads * head_dim

# non-attn per layer parm
# Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp
mlp_N = hidden_size * intermediate_size * 3
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
emd_and_lm_head_N = vocab_size * hidden_size * 2
# non-attn all_layer parm
dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
# non-attn all_layer & all_token fwd & bwd flops
dense_N_flops = 6 * dense_N * tokens_sum

# attn all_layer & all_token fwd & bwd flops
seqlen_square_sum = 0
for seqlen in batch_seqlens:
seqlen_square_sum += seqlen * seqlen
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers

# all_layer & all_token fwd & bwd flops
flops_all_token = dense_N_flops + attn_qkv_flops
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
return flops_achieved

def estimate_flops(self, batch_seqlens, delta_time):
"""
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
Args:
batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch.
delta_time (float): The time taken to process the batch, in seconds.
Returns:
estimated_flops (float): The estimated FLOPS based on the input tokens and time.
promised_flops (float): The expected FLOPS of the current device.
"""
tokens_sum = sum(batch_seqlens)
func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
promised_flops = get_device_flops()
return estimated_flops, promised_flops
28 changes: 26 additions & 2 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@
load_fsdp_param_and_grad
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.flops_counter import FlopsCounter
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from codetiming import Timer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))

Expand Down Expand Up @@ -341,6 +344,9 @@ def init_model(self):
self.config.ref.use_remove_padding = use_remove_padding
self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)

torch.cuda.empty_cache()

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
Expand All @@ -362,7 +368,13 @@ def update_actor(self, data: DataProto):
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
# perform training
metrics = self.actor.update_policy(data=data)
with Timer(name='update_policy', logger=None) as timer:
metrics = self.actor.update_policy(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info['global_token_num']
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics['mfu/actor'] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size

self.actor_lr_scheduler.step()
lr = self.actor_lr_scheduler.get_last_lr()[0]
metrics['actor/lr'] = lr
Expand Down Expand Up @@ -580,6 +592,8 @@ def _build_critic_model_optimizer(self, config):
if self.rank == 0:
print_model_size(critic_module)

self.critic_model_config = critic_model_config

fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get('mixed_precision', None)
if mixed_precision_config is not None:
Expand Down Expand Up @@ -643,6 +657,9 @@ def init_model(self):
self.critic = DataParallelPPOCritic(config=self.config,
critic_module=self.critic_module,
critic_optimizer=self.critic_optimizer)

self.flops_counter = FlopsCounter(self.critic_model_config)

torch.cuda.empty_cache()

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
Expand Down Expand Up @@ -681,7 +698,14 @@ def update_critic(self, data: DataProto):
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
metrics = self.critic.update_critic(data=data)

with Timer(name='update_critic', logger=None) as timer:
metrics = self.critic.update_critic(data=data)
delta_time = timer.last

global_num_tokens = data.meta_info['global_token_num']
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

self.critic_lr_scheduler.step()
lr = self.critic_lr_scheduler.get_last_lr()[0]
Expand Down

0 comments on commit 41f645d

Please sign in to comment.