From b2b65c595d0d6a9f73bd4cac5a6ccfe8eed59cc6 Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Mon, 20 Jan 2025 22:29:01 +0530
Subject: [PATCH] Fix command args for MPI based deepspeed launchers

---
 src/accelerate/utils/constants.py |  1 +
 src/accelerate/utils/launch.py    | 16 ++++++++++++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py
index a6d7d262678..3f5a296032a 100644
--- a/src/accelerate/utils/constants.py
+++ b/src/accelerate/utils/constants.py
@@ -42,6 +42,7 @@
 )
 FSDP_MODEL_NAME = "pytorch_model_fsdp"
 DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
+DEEPSPEED_MULTINODE_MPI_LAUNCHERS = ["openmpi", "mvapich", "mpich"]
 TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
 ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index c6f3d60031d..011e99a46ab 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -34,7 +34,7 @@
     is_torch_xla_available,
     is_xpu_available,
 )
-from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
+from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, DEEPSPEED_MULTINODE_MPI_LAUNCHERS
 from ..utils.other import is_port_in_use, merge_dicts
 from .dataclasses import DistributedType, SageMakerDistributedType
 
@@ -320,6 +320,17 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
         # set to default pdsh
         args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]
 
+    if args.deepspeed_multinode_launcher in DEEPSPEED_MULTINODE_MPI_LAUNCHERS:
+        # MPI based launchers do not support --include and --exclude
+        if args.deepspeed_exclusion_filter is not None:
+            raise ValueError(
+                f"--deepspeed_exclusion_filter is not supported with --deepspeed_multinode_launcher {args.deepspeed_multinode_launcher}"
+            )
+        if args.deepspeed_inclusion_filter is not None:
+            raise ValueError(
+                f"--deepspeed_inclusion_filter is not supported with --deepspeed_multinode_launcher {args.deepspeed_multinode_launcher}"
+            )
+
     if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
         cmd = ["deepspeed", "--no_local_rank"]
         cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)])
@@ -337,7 +348,8 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict
                     str(args.deepspeed_inclusion_filter),
                 ]
             )
-        else:
+        elif args.deepspeed_multinode_launcher not in DEEPSPEED_MULTINODE_MPI_LAUNCHERS:
+            # MPI based launchers do not support --num_gpus
             cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)])
         if main_process_ip:
             cmd.extend(["--master_addr", str(main_process_ip)])