Skip to content

Commit

Permalink
add license and fix build rollout
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterSH6 committed Jan 17, 2025
1 parent b541273 commit 1e424a7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
14 changes: 14 additions & 0 deletions verl/models/transformers/llama.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, List, Union, Tuple, Unpack, Callable
Expand Down
14 changes: 14 additions & 0 deletions verl/models/transformers/qwen2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from typing import Optional, List, Union, Tuple, Unpack, Callable
Expand Down
6 changes: 4 additions & 2 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ def _build_model_optimizer(self,
def _build_rollout(self):
from torch.distributed.device_mesh import init_device_mesh
# TODO(sgm): support FSDP hybrid shard for larger model
dp = self.world_size // self.config.tensor_model_parallel_size
rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, self.config.tensor_model_parallel_size), mesh_dim_names=['dp', 'infer_tp'])
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])

if self.config.rollout.name == 'hf':
from verl.workers.rollout import HFRollout
Expand Down

0 comments on commit 1e424a7

Please sign in to comment.