Skip to content

Commit

Permalink
Fix command args for MPI based deepspeed launchers
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Jan 20, 2025
1 parent 78b8126 commit b2b65c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/accelerate/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
FSDP_MODEL_NAME = "pytorch_model_fsdp"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
DEEPSPEED_MULTINODE_MPI_LAUNCHERS = ["openmpi", "mvapich", "mpich"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
Expand Down
16 changes: 14 additions & 2 deletions src/accelerate/utils/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
is_torch_xla_available,
is_xpu_available,
)
from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, DEEPSPEED_MULTINODE_MPI_LAUNCHERS
from ..utils.other import is_port_in_use, merge_dicts
from .dataclasses import DistributedType, SageMakerDistributedType

Expand Down Expand Up @@ -320,6 +320,17 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
# set to default pdsh
args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]

if args.deepspeed_multinode_launcher in DEEPSPEED_MULTINODE_MPI_LAUNCHERS:
# MPI based launchers do not support --include and --exclude
if args.deepspeed_exclusion_filter is not None:
raise ValueError(
f"--deepspeed_exclusion_filter is not supported with --deepspeed_multinode_launcher {args.deepspeed_multinode_launcher}"
)
if args.deepspeed_inclusion_filter is not None:
raise ValueError(
f"--deepspeed_inclusion_filter is not supported with --deepspeed_multinode_launcher {args.deepspeed_multinode_launcher}"
)

if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
cmd = ["deepspeed", "--no_local_rank"]
cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)])
Expand All @@ -337,7 +348,8 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
str(args.deepspeed_inclusion_filter),
]
)
else:
elif args.deepspeed_multinode_launcher not in DEEPSPEED_MULTINODE_MPI_LAUNCHERS:
# MPI based launchers do not support --num_gpus
cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)])
if main_process_ip:
cmd.extend(["--master_addr", str(main_process_ip)])
Expand Down

0 comments on commit b2b65c5

Please sign in to comment.