
Commit

fix merge conflicts
epwalsh committed Jan 22, 2025
2 parents 12cffc9 + 212108f commit 13a4c42
Showing 20 changed files with 434 additions and 158 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -35,7 +35,9 @@ This major release introduces a few breaking changes. As such, we've provided an
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- Added `SkipStepAdamW` optimizer.
- The trainer can now load model-only checkpoints.
- Added the option to throttle checkpoint uploads to one rank from each node at a time.

### Changed

@@ -45,6 +47,7 @@ This major release introduces a few breaking changes. As such, we've provided an
### Fixed

- Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch.
- Fixed a bug where GCS uploads would not retry on transient failures.

## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27

54 changes: 48 additions & 6 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -66,6 +66,8 @@ def save_state_dict(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
):
"""
Save an arbitrary state dictionary to a distributed format that can be loaded again with
@@ -79,12 +81,20 @@ def save_state_dict(
:param state_dict: The state dict to save.
:param process_group: The process group to use for distributed collectives.
:param save_overwrite: Overwrite existing files.
:param thread_count: Set this to override the number of threads used while writing data.
:param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one
rank from each node will upload data at a time.
"""
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
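
For reference, a minimal usage sketch of `save_state_dict()` with the new `thread_count` and `throttle_uploads` options; the destination URL and the state dict contents below are hypothetical, and an initialized `torch.distributed` process group is assumed:

```python
import torch

from olmo_core.distributed.checkpoint import save_state_dict

# Hypothetical state to save; in practice this would come from your model/trainer.
state_dict = {"step": torch.tensor(1000), "bias": torch.zeros(8)}

save_state_dict(
    "s3://my-bucket/checkpoints/step1000",  # hypothetical URL; a local path also works
    state_dict,
    save_overwrite=True,
    thread_count=4,         # override the default number of writer threads
    throttle_uploads=True,  # with a URL target, one rank per node uploads at a time
)
```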
@@ -97,6 +107,8 @@ def async_save_state_dict(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> Future[None]:
"""
An async version of :func:`save_state_dict()`.
@@ -107,7 +119,12 @@ def async_save_state_dict(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
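
A sketch of the async variant, which returns a future that should be waited on before relying on the checkpoint; the exact future type is not shown in this diff, so the `result()` call below is an assumption:

```python
import torch

from olmo_core.distributed.checkpoint import async_save_state_dict

state_dict = {"step": torch.tensor(1000), "bias": torch.zeros(8)}  # as in the sketch above

future = async_save_state_dict(
    "s3://my-bucket/checkpoints/step1000",  # hypothetical URL
    state_dict,
    throttle_uploads=True,
)
# ... keep training while the checkpoint is written/uploaded in the background ...
future.result()  # assumed concurrent.futures-style API; wait for the save to finish
```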
@@ -121,16 +138,20 @@ def load_state_dict(
process_group: Optional[dist.ProcessGroup] = None,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
thread_count: Optional[int] = None,
):
"""
Load an arbitrary state dict in-place from a checkpoint saved with :func:`save_state_dict()`.
:param dir: Path/URL to the checkpoint saved via :func:`save_state_dict()`.
:param state_dict: The state dict to load the state into.
:param process_group: The process group to use for distributed collectives.
:param thread_count: Set the number of threads used for certain operations.
"""
dir = normalize_path(dir)
reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir)
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)
dist_cp.load(
state_dict,
checkpoint_id=dir,
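
A matching load-side sketch for `load_state_dict()`: the state dict is updated in place, and `pre_download`/`work_dir` control whether remote files are fetched to local scratch space first (the URL and paths below are hypothetical):

```python
import torch

from olmo_core.distributed.checkpoint import load_state_dict

# The target state dict must already have the right keys/shapes; it is loaded in place.
state_dict = {"step": torch.tensor(0), "bias": torch.zeros(8)}

load_state_dict(
    "s3://my-bucket/checkpoints/step1000",  # hypothetical URL
    state_dict,
    pre_download=True,           # fetch remote checkpoint files before loading
    work_dir="/tmp/ckpt-cache",  # hypothetical local scratch directory
    thread_count=4,
)
```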
@@ -148,6 +169,8 @@ def save_model_and_optim_state(
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
flatten_optimizer_state: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> None:
"""
Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -173,6 +196,9 @@ def save_model_and_optim_state(
:param flatten_optimizer_state: Flatten the optimizer state before saving. This should match
the setting used when loading the state dict and is needed in a distributed setting when
the params in some param groups may differ between ranks, such as with pipeline parallelism.
:param thread_count: Set this to override the number of threads used while writing data.
:param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one
rank from each node will upload data at a time.
:raises FileExistsError: If the checkpoint dir exists and is non-empty unless ``save_overwrite=True``.
"""
@@ -186,7 +212,12 @@ def save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
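
A sketch of the model/optimizer save path with the new options, assuming the existing `(dir, model, optim)` calling convention used by this module; the URL and toy model are placeholders:

```python
import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import save_model_and_optim_state

model = nn.Linear(8, 8)  # toy stand-in for a (possibly sharded) training model
optim = torch.optim.AdamW(model.parameters())

save_model_and_optim_state(
    "gs://my-bucket/checkpoints/step1000",  # hypothetical URL
    model,
    optim,
    save_overwrite=True,
    thread_count=4,
    throttle_uploads=True,  # one uploading rank per node when the target is a URL
)
```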
@@ -201,6 +232,8 @@ def async_save_model_and_optim_state(
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
flatten_optimizer_state: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> Future[None]:
"""
An async version of :func:`save_model_and_optim_state()`.
@@ -217,7 +250,12 @@ def async_save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
@@ -235,6 +273,7 @@ def load_model_and_optim_state(
work_dir: Optional[PathOrStr] = None,
strict: bool = True,
flatten_optimizer_state: bool = False,
thread_count: Optional[int] = None,
):
"""
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -276,12 +315,15 @@ def load_model_and_optim_state(
:param flatten_optimizer_state: Flatten the optimizer state when loading. This should match
the setting used when saving the state dict and is needed in a distributed setting when
the params in some param groups may differ between ranks, such as with pipeline parallelism.
:param thread_count: Set the number of threads used for certain operations.
"""
dir = normalize_path(dir)
state_dict = _prepare_state_dict(
model, optim, process_group=process_group, flatten_optimizer_state=flatten_optimizer_state
)
reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir)
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)

if key_mapping is not None:
metadata = reader.read_metadata()
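
And the corresponding load-side sketch for `load_model_and_optim_state()`, again assuming `(dir, model, optim)` positional arguments; the URL and scratch directory are placeholders:

```python
import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import load_model_and_optim_state

model = nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters())

load_model_and_optim_state(
    "gs://my-bucket/checkpoints/step1000",  # hypothetical URL
    model,
    optim,
    pre_download=True,           # download remote files to a local work dir first
    work_dir="/tmp/ckpt-cache",  # hypothetical scratch directory
    thread_count=4,
)
```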
36 changes: 28 additions & 8 deletions src/olmo_core/distributed/checkpoint/filesystem.py
@@ -7,10 +7,12 @@
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, cast

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.filesystem import WriteResult
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
@@ -25,6 +27,7 @@
from torch.futures import Future

from olmo_core.aliases import PathOrStr
from olmo_core.distributed.utils import do_n_at_a_time
from olmo_core.exceptions import OLMoCheckpointError
from olmo_core.io import (
get_bytes_range,
@@ -154,12 +157,16 @@ def __init__(
self,
path: PathOrStr,
thread_count: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
throttle_uploads: bool = False,
) -> None:
super().__init__()
if thread_count is not None and thread_count <= 0:
raise ValueError("thread count must be at least 1")
self.path = normalize_path(path)
self.thread_count = thread_count or get_default_thread_count()
self.process_group = process_group
self.throttle_uploads = throttle_uploads
self.save_id = generate_uuid()

def reset(self, checkpoint_id: Optional[PathOrStr] = None) -> None:
@@ -201,22 +208,35 @@ def gen_file_name() -> str:
file_count += 1
return file_name

with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
futures = []
for bucket in _split_by_size_and_type(self.thread_count, plan.items):
def write_items(buckets: List[List[WriteItem]]) -> List[WriteResult]:
results: List[WriteResult] = []
for bucket in buckets:
file_name = gen_file_name()
path = f"{self.path}/{file_name}"
futures.append(executor.submit(_write_items, path, file_name, bucket, planner))

results = []
for f in as_completed(futures):
try:
results += f.result()
results.extend(_write_items(path, file_name, bucket, planner))
except BaseException:
# NOTE: we might get an error here that can't be pickled, which causes a different failure
# later when PyTorch tries to reduce that error across ranks. So here we just make
# sure we're raising a simple error type that can be pickled.
raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
return results

results: List[WriteResult]
if self.throttle_uploads and is_url(self.path):
buckets = _split_by_size_and_type(1, plan.items)
results = do_n_at_a_time(
partial(write_items, buckets), process_group=self.process_group
)
else:
buckets = _split_by_size_and_type(self.thread_count, plan.items)
results = []
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
futures = []
for bucket in buckets:
futures.append(executor.submit(write_items, [bucket]))
for f in as_completed(futures):
results.extend(f.result())

fut: Future[List[WriteResult]] = Future()
fut.set_result(results)
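
For completeness, a sketch of constructing the writer directly with the new arguments; the path is a placeholder and `None` means the default process group:

```python
from olmo_core.distributed.checkpoint.filesystem import RemoteFileSystemWriter

writer = RemoteFileSystemWriter(
    "gs://my-bucket/checkpoints/step1000",  # hypothetical URL
    thread_count=2,
    process_group=None,     # default process group
    throttle_uploads=True,  # buckets are written one node-rank at a time via do_n_at_a_time
)
# `writer` is then passed as `storage_writer=...` to torch.distributed.checkpoint save calls,
# as the helpers in checkpoint/__init__.py do above.
```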
37 changes: 36 additions & 1 deletion src/olmo_core/distributed/utils.py
@@ -3,9 +3,10 @@
"""

import logging
import math
import os
from datetime import timedelta
from typing import List, Optional, TypeVar
from typing import Callable, List, Optional, TypeVar, cast

import torch
import torch.distributed as dist
@@ -92,6 +93,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")

if backend_supports_cuda(backend):
# Set CUDA device.
@@ -420,3 +422,36 @@ def get_local_tensor(x: torch.Tensor) -> torch.Tensor:
return x.to_local()
else:
return x


def do_n_at_a_time(
f: Callable[[], T],
*,
n: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
world_size: Optional[int] = None,
local_rank: Optional[int] = None,
) -> T:
"""
Call a function ``f`` in a distributed context from at most ``n`` ranks at a time.
All ranks will eventually call the given function exactly once, at which point this function
will return.
:param f: The function to call from each rank.
:param n: The level of concurrency, i.e. how many ranks are allowed to call ``f`` at once.
This defaults to the number of nodes, in which case one rank from each node will
call ``f`` at a time.
:param process_group: The process group to use.
"""
world_size = world_size if world_size is not None else get_world_size(process_group)
local_rank = local_rank if local_rank is not None else get_rank(process_group)
n = n if n is not None else get_num_nodes()
group_count = math.ceil(world_size / n)
group_rank = local_rank % group_count
result: Optional[T] = None
for active_group in range(group_count):
if group_rank == active_group:
result = f()
barrier(process_group)
return cast(T, result)
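
A small usage sketch for `do_n_at_a_time()`; the `upload` function is hypothetical and an initialized `torch.distributed` process group is assumed. This mirrors how the checkpoint writer above wraps its upload work in `functools.partial`:

```python
from functools import partial

from olmo_core.distributed.utils import do_n_at_a_time


def upload(path: str) -> None:
    # Hypothetical stand-in for a bandwidth-heavy operation, e.g. uploading shard files.
    ...


# At most one rank per node runs `upload` at a time (n defaults to the number of nodes).
do_n_at_a_time(partial(upload, "s3://my-bucket/shard"))

# Or cap concurrency explicitly, e.g. at most two ranks at once across the whole job:
do_n_at_a_time(partial(upload, "s3://my-bucket/shard"), n=2)
```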
1 change: 1 addition & 0 deletions src/olmo_core/internal/common.py
@@ -103,6 +103,7 @@ def build_launch_config(
"conda shell.bash activate base",
# "pip install 'ai2-olmo-eval @ git+https://[email protected]/allenai/OLMo-in-loop-evals.git@epwalsh/debug'",
"pip install -e '.[all]'",
"pip install --upgrade beaker-py",
# Quickly try a new version of PyTorch like this
# "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121",
"pip freeze",
1 change: 1 addition & 0 deletions src/olmo_core/internal/experiment.py
@@ -127,6 +127,7 @@ def build_common_components(
root_dir=root_dir,
cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
cluster=cluster,
nccl_debug=False,
)

beaker_user = get_beaker_username()
