diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py
index cd06237..4e8a5b1 100644
--- a/verl/models/transformers/llama.py
+++ b/verl/models/transformers/llama.py
@@ -1,3 +1,17 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import torch
 from typing import Optional, List, Union, Tuple, Unpack, Callable
diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py
index d4fdfea..05927b5 100644
--- a/verl/models/transformers/qwen2.py
+++ b/verl/models/transformers/qwen2.py
@@ -1,3 +1,17 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import torch
 from typing import Optional, List, Union, Tuple, Unpack, Callable
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 52dfe55..510d5db 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -243,8 +243,10 @@ def _build_model_optimizer(self,

     def _build_rollout(self):
         from torch.distributed.device_mesh import init_device_mesh
         # TODO(sgm): support FSDP hybrid shard for larger model
-        dp = self.world_size // self.config.tensor_model_parallel_size
-        rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, self.config.tensor_model_parallel_size), mesh_dim_names=['dp', 'infer_tp'])
+        infer_tp = self.config.rollout.tensor_model_parallel_size
+        dp = self.world_size // infer_tp
+        assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
+        rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])
         if self.config.rollout.name == 'hf':
             from verl.workers.rollout import HFRollout
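
For reference, the substantive change in `fsdp_workers.py` is the 2-D rollout device mesh: the inference tensor-parallel size is now read from `config.rollout` rather than the top-level config, and the patch asserts that the world size divides evenly into `dp × infer_tp` groups before building the mesh. Below is a minimal standalone sketch of that mesh construction; the concrete values (`world_size = 8`, `infer_tp = 2`) are illustrative assumptions, not taken from the PR, and the snippet assumes it is launched under `torchrun` with a matching number of ranks.

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustrative values only (not from the PR): 8 ranks, rollout TP degree of 2.
world_size = 8
infer_tp = 2

# Mirror the patch's guard: the world size must split evenly into TP groups.
assert world_size % infer_tp == 0, \
    f'rollout world_size: {world_size} is not divisible by infer_tp: {infer_tp}'
dp = world_size // infer_tp  # 4 data-parallel groups of 2 ranks each

# Build the 2-D mesh: the 'dp' dimension indexes replicas, 'infer_tp' indexes
# the tensor-parallel shards within each replica. Run under
# `torchrun --nproc_per_node=8 ...` so the process group can be initialized.
rollout_device_mesh = init_device_mesh(
    'cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])
print(rollout_device_mesh)
```

Reading `tensor_model_parallel_size` from the rollout sub-config presumably lets the rollout engine pick its own TP degree independently of the training layout, which the added assert then validates against the actual world size.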