Skip to content

Commit

Permalink
refact: hybrid_engine dir to sharding_manager for more general repres…
Browse files Browse the repository at this point in the history
…entation (#103)
  • Loading branch information
PeterSH6 authored Jan 14, 2025
1 parent e230de8 commit 6a9f6e1
Show file tree
Hide file tree
Showing 6 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ def _build_model_optimizer(self,
def _build_rollout(self):
if self.config.rollout.name == 'hf':
from verl.workers.rollout import HFRollout
from verl.workers.hybrid_engine import BaseShardingManager
from verl.workers.sharding_manager import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?
elif self.config.rollout.name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.hybrid_engine import FSDPVLLMShardingManager
from verl.workers.sharding_manager import FSDPVLLMShardingManager
log_gpu_memory_usage('Before building vllm rollout', logger=None)
rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
config=self.config.rollout,
Expand Down
4 changes: 2 additions & 2 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.workers.actor.megatron_actor import MegatronPPOActor
from verl.workers.critic.megatron_critic import MegatronPPOCritic
from verl.workers.hybrid_engine import AllGatherPPModel
from verl.workers.sharding_manager import AllGatherPPModel
from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel

from verl.single_controller.base.decorator import register, Dispatch
Expand Down Expand Up @@ -216,7 +216,7 @@ def megatron_actor_model_provider(pre_process, post_process):
def _build_rollout(self):
if self.config.rollout.name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.hybrid_engine import MegatronVLLMShardingManager
from verl.workers.sharding_manager import MegatronVLLMShardingManager
from verl.utils.model import normalize_pp_vpp_params

# NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 6a9f6e1

Please sign in to comment.