Commit
allow hybrid sharding
epwalsh committed Aug 26, 2024
1 parent f947be9 commit bf67b66
Showing 2 changed files with 80 additions and 8 deletions.
81 changes: 75 additions & 6 deletions src/olmo_core/distributed/utils.py
@@ -4,25 +4,31 @@

import os
from datetime import timedelta
from typing import TYPE_CHECKING, List, Optional, TypeVar
from typing import List, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from ..exceptions import OLMoEnvironmentError

if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh
from ..config import StrEnum
from ..exceptions import OLMoConfigurationError, OLMoEnvironmentError
from ..utils import get_default_device

OLMO_SHARED_FS_ENV_VAR = "OLMO_SHARED_FS"
OLMO_FS_LOCAL_RANK_ENV_VAR = "FS_LOCAL_RANK"
OLMO_LOCAL_RANK_ENV_VAR = "LOCAL_RANK"
OLMO_NUM_NODES_ENV_VAR = "NUM_NODES"
OLMO_LOCAL_WORLD_SIZE_ENV_VAR = "LOCAL_WORLD_SIZE"


def validate_env_vars():
    if not is_distributed():
        return

    if OLMO_LOCAL_RANK_ENV_VAR not in os.environ:
        raise OLMoEnvironmentError(f"Missing env var '{OLMO_LOCAL_RANK_ENV_VAR}'")
    elif (

    if (
        os.environ.get(OLMO_SHARED_FS_ENV_VAR) != "1"
        and os.environ.get(OLMO_FS_LOCAL_RANK_ENV_VAR) is None
    ):
@@ -31,6 +37,11 @@ def validate_env_vars():
f"If this is a shared filesystem you can set '{OLMO_SHARED_FS_ENV_VAR}=1' instead."
)

if OLMO_NUM_NODES_ENV_VAR not in os.environ and OLMO_LOCAL_WORLD_SIZE_ENV_VAR not in os.environ:
raise OLMoEnvironmentError(
f"Missing either '{OLMO_NUM_NODES_ENV_VAR}' or '{OLMO_LOCAL_WORLD_SIZE_ENV_VAR}' env vars"
)
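With this change a distributed run must provide either 'NUM_NODES' or 'LOCAL_WORLD_SIZE' in the environment, in addition to the existing 'LOCAL_RANK' and shared-filesystem checks. Launchers such as torchrun already export 'LOCAL_RANK' and 'LOCAL_WORLD_SIZE'; the sketch below shows a hypothetical manual setup for a job with 8 ranks per node on 2 nodes, with all values purely illustrative:

import os

from olmo_core.distributed.utils import validate_env_vars

# Hypothetical manual environment; a launcher like torchrun normally exports these.
os.environ["LOCAL_RANK"] = "0"           # this process's rank within its node
os.environ["LOCAL_WORLD_SIZE"] = "8"     # ranks (GPUs) per node
os.environ["NUM_NODES"] = "2"            # optional once LOCAL_WORLD_SIZE is set
os.environ["OLMO_SHARED_FS"] = "1"       # or set FS_LOCAL_RANK instead

validate_env_vars()  # no-op unless torch.distributed is initialized; otherwise checks the vars above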


def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minutes=30)):
"""
@@ -124,6 +135,38 @@ def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
        return 0


def get_local_world_size() -> int:
    """
    Get the local world size.

    .. warning::
        This relies on the 'LOCAL_WORLD_SIZE' env var.

    :returns: The local world size.
    """
    if not is_distributed():
        return 1
    else:
        return int(os.environ[OLMO_LOCAL_WORLD_SIZE_ENV_VAR])


def get_num_nodes() -> int:
    """
    Get the number of nodes.

    .. warning::
        This relies on either the 'NUM_NODES' or 'LOCAL_WORLD_SIZE' env var.

    :returns: The number of nodes.
    """
    if not is_distributed():
        return 1
    elif OLMO_NUM_NODES_ENV_VAR in os.environ:
        return int(os.environ[OLMO_NUM_NODES_ENV_VAR])
    else:
        return get_world_size() // get_local_world_size()
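When 'NUM_NODES' is absent, the node count falls back to the total world size divided by the local world size. A quick worked example with hypothetical numbers (a 4-node job with 8 ranks per node):

# Illustrative values only.
world_size = 32        # total ranks across the job
local_world_size = 8   # ranks per node, from LOCAL_WORLD_SIZE
num_nodes = world_size // local_world_size
assert num_nodes == 4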


V = TypeVar("V", bool, int, float, torch.Tensor)


@@ -218,3 +261,29 @@ def backend_supports_cpu():
        return True
    else:
        return False


class HybridShardMeshDimName(StrEnum):
    replicas = "replicas"
    shards = "shards"


def init_hybrid_shard_mesh(
    num_replicas: Optional[int] = None, device_type: Optional[str] = None
) -> DeviceMesh:
    """
    Initialize a device mesh for FSDP hybrid sharding.
    """
    num_replicas = num_replicas or get_num_nodes()
    device_type = device_type or get_default_device().type

    if get_world_size() % num_replicas != 0:
        raise OLMoConfigurationError(
            "hybrid mesh requires world size to be divisible by 'num_replicas'"
        )

    return init_device_mesh(
        device_type,
        (num_replicas, get_world_size() // num_replicas),
        mesh_dim_names=(HybridShardMeshDimName.replicas, HybridShardMeshDimName.shards),
    )
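The mesh is two-dimensional: the 'replicas' dimension groups ranks that hold full copies of the model, and the 'shards' dimension covers the ranks each copy is sharded across; with the defaults there is one replica group per node. A minimal sketch of how such a mesh could be consumed, assuming a recent PyTorch where FSDP accepts a 2-D device_mesh together with the HYBRID_SHARD strategy (the wrapped module here is just a placeholder):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

from olmo_core.distributed.utils import init_hybrid_shard_mesh

mesh = init_hybrid_shard_mesh()  # shape: (num_nodes, ranks_per_node) by default

# Dim 0 ('replicas') is used for replication, dim 1 ('shards') for parameter sharding.
model = FSDP(
    nn.Linear(1024, 1024),  # placeholder module
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
)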
7 changes: 5 additions & 2 deletions src/scripts/train/OLMo-7B.py
@@ -12,7 +12,7 @@
from olmo_core.config import Config, DType, StrEnum
from olmo_core.data import MemMapDatasetConfig
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.utils import get_rank
from olmo_core.distributed.utils import get_num_nodes, get_rank, init_hybrid_shard_mesh
from olmo_core.launch.beaker import BeakerEnvSecret, BeakerLaunchConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
@@ -178,7 +178,10 @@ def train(config: ExperimentConfig):

    # Build components.
    model = config.model.build(
        init_device="meta", device=get_default_device(), max_seq_len=config.dataset.sequence_length
        init_device="meta",
        device=get_default_device(),
        max_seq_len=config.dataset.sequence_length,
        dp_mesh=None if get_num_nodes() == 1 else init_hybrid_shard_mesh(),
    )
    optim = config.optim.build(model)
    dataset = config.dataset.build()
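The script only builds the hybrid mesh for multi-node runs; on a single node dp_mesh stays None and the data-parallel config's default sharding applies. If coarser replica groups were wanted, a hypothetical variation could pass num_replicas explicitly, which must still divide the world size:

# Hypothetical: half as many replicas, each sharded across two nodes' worth of ranks.
dp_mesh = (
    None
    if get_num_nodes() == 1
    else init_hybrid_shard_mesh(num_replicas=max(1, get_num_nodes() // 2))
)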
