Skip to content

Commit

Permalink
Add support for pipeline parallelism (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh authored Dec 19, 2024
1 parent 72d8a83 commit 193fe15
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 47 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- You can pass `None` for `lm_logits` in `ICLMetric.update()` to support pipeline parallelism.

## [v0.4.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.4.0) - 2024-12-18

### Added
Expand Down
121 changes: 75 additions & 46 deletions src/olmo_eval/metrics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from typing import Any, Dict, List, Optional
import logging
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torchmetrics import Metric

from .util import all_gather_object

LOG_2_OF_E = 1.44269504089


log = logging.getLogger(__name__)


T = TypeVar("T")


def dist_combine_lists(x: List[T]) -> List[T]:
    """Gather ``x`` from every rank and concatenate the results into one flat list.

    In a non-distributed context this is equivalent to ``list(x)``.
    """
    combined: List[T] = []
    for rank_items in all_gather_object(x):
        combined.extend(rank_items)
    return combined


class ICLMetric(Metric):
# update method does not require access to global metric state
full_state_update: bool = False
Expand All @@ -18,20 +32,33 @@ def __init__(self, metric_type="acc") -> None:

self.metric_type = metric_type

self.add_state("loglikelihoods", default=[], dist_reduce_fx=None)
self.add_state("celosses", default=[], dist_reduce_fx=None)
self.add_state("bpbs", default=[], dist_reduce_fx=None)
self.add_state("labels", default=[], dist_reduce_fx=None)
self.add_state("loglikelihoods", default=[], dist_reduce_fx=dist_combine_lists)
self.add_state("celosses", default=[], dist_reduce_fx=dist_combine_lists)
self.add_state("bpbs", default=[], dist_reduce_fx=dist_combine_lists)
self.add_state("labels", default=[], dist_reduce_fx=dist_combine_lists)

def reset(self):
    """Clear all accumulated per-example metric state.

    Each state list holds ``(doc_id, cont_id, value)`` triples; entries may be
    all-``None`` placeholders appended by ranks without logits under pipeline
    parallelism.
    """
    # Rows are (doc_id, cont_id, metric_value) — Nones are PP placeholders.
    _FloatRow = Tuple[Optional[int], Optional[int], Optional[float]]
    _IntRow = Tuple[Optional[int], Optional[int], Optional[int]]
    self.loglikelihoods: List[_FloatRow] = []
    self.celosses: List[_FloatRow] = []
    self.bpbs: List[_FloatRow] = []
    self.labels: List[_IntRow] = []

def reset(
def update(
self,
batch: Dict[str, Any],
lm_logits: Optional[torch.Tensor] = None,
dc_lm_logits: Optional[torch.Tensor] = None,
):
self.loglikelihoods = []
self.celosses = []
self.bpbs = []
self.labels = []
        # NOTE: `lm_logits` may be None on some ranks when using pipeline parallelism.
        # We still append placeholder entries to these state lists so that every rank
        # participates in torchmetrics' state synchronization; skipping the append on
        # some ranks leaves the states out of sync.
if lm_logits is None:
self.loglikelihoods.append((None, None, None))
self.celosses.append((None, None, None))
self.bpbs.append((None, None, None))
self.labels.append((None, None, None))
return

def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
lm_logits = F.log_softmax(lm_logits, dim=-1)

if self.metric_type == "pmi_dc":
Expand All @@ -40,6 +67,9 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
), "PMI_DC acc type selected but no domain conditional logits provided"

for idx, (doc_id, cont_id) in enumerate(zip(batch["doc_id"], batch["cont_id"])):
doc_id = int(doc_id)
cont_id = int(cont_id)

# [cont_len]: continuation is padded for batching
cont_tokens = batch["continuation"][idx][: batch["cont_len"][idx]]
# get logits from LM for the continuation: [cont_len, vocab]
Expand Down Expand Up @@ -95,66 +125,64 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
else:
raise ValueError(self.metric_type)

# because metric states cannot be dict/list of tuples, store this tuple as tensor: (doc_id, cont_id, metric_state)
self.loglikelihoods.append(
torch.Tensor((doc_id, cont_id, log_likelihood)).to(
batch["continuation"][idx].device
)
)
self.celosses.append(
torch.Tensor((doc_id, cont_id, celoss)).to(batch["continuation"][idx].device)
)
self.bpbs.append(
torch.Tensor((doc_id, cont_id, bpb)).to(batch["continuation"][idx].device)
)
self.labels.append(
torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(
batch["label_id"][idx].device
)
)
self.loglikelihoods.append((doc_id, cont_id, float(log_likelihood)))
self.labels.append((doc_id, cont_id, int(batch["label_id"][idx])))
self.celosses.append((doc_id, cont_id, float(celoss)))
self.bpbs.append((doc_id, cont_id, float(bpb)))

def compute(self) -> Dict[str, torch.Tensor]:
# Task "suffix" -> tensor

# states should have been synced from all accelerators at this point
# account for duplicates here because of DistributedSampler compensating for drop_last=False
loglikelihood_dict: Dict[int, Dict[int, float]] = {}
label_dict: Dict[int, int] = {}
celoss_dict: Dict[int, Dict[int, float]] = {}
bpb_dict: Dict[int, Dict[int, float]] = {}
label_dict = {}

# collect labels
for doc_id, cont_id, label_id in self.labels:
if doc_id.item() not in label_dict:
label_dict[doc_id.item()] = label_id.item()
if doc_id is None or cont_id is None or label_id is None:
continue

if doc_id not in label_dict:
label_dict[doc_id] = label_id

# collect loglikelihoods
for doc_id, cont_id, loglikelihood in self.loglikelihoods:
if int(doc_id.item()) not in loglikelihood_dict:
loglikelihood_dict[int(doc_id.item())] = {}
if doc_id is None or cont_id is None or loglikelihood is None:
continue

if doc_id not in loglikelihood_dict:
loglikelihood_dict[doc_id] = {}

if int(cont_id.item()) not in loglikelihood_dict[int(doc_id.item())]:
loglikelihood_dict[int(doc_id.item())][int(cont_id.item())] = loglikelihood
if cont_id not in loglikelihood_dict[doc_id]:
loglikelihood_dict[doc_id][cont_id] = loglikelihood

# collect celosses
for doc_id, cont_id, celoss in self.celosses:
if int(doc_id.item()) not in celoss_dict:
celoss_dict[int(doc_id.item())] = {}
for doc_id, cont_id, celoss_val in self.celosses:
if doc_id is None or cont_id is None or celoss_val is None:
continue

if doc_id not in celoss_dict:
celoss_dict[doc_id] = {}

if int(cont_id.item()) not in celoss_dict[int(doc_id.item())]:
celoss_dict[int(doc_id.item())][int(cont_id.item())] = celoss
if cont_id not in celoss_dict[doc_id]:
celoss_dict[doc_id][cont_id] = celoss_val

# collect bpbs
for doc_id, cont_id, bpb in self.bpbs:
if int(doc_id.item()) not in bpb_dict:
bpb_dict[int(doc_id.item())] = {}
for doc_id, cont_id, bpb_val in self.bpbs:
if doc_id is None or cont_id is None or bpb_val is None:
continue

if int(cont_id.item()) not in bpb_dict[int(doc_id.item())]:
bpb_dict[int(doc_id.item())][int(cont_id.item())] = bpb
if doc_id not in bpb_dict:
bpb_dict[doc_id] = {}

if cont_id not in bpb_dict[doc_id]:
bpb_dict[doc_id][cont_id] = bpb_val

# compute acc
correct = []
loglikelihood = []
celoss = []
bpb = []
soft_score = []
Expand Down Expand Up @@ -185,6 +213,7 @@ def compute(self) -> Dict[str, torch.Tensor]:

if skip_document:
continue

if self.metric_type == "ce_loss":
celoss.append(celosses[0]) # Only one answer is scored
elif self.metric_type == "bpb":
Expand Down
38 changes: 37 additions & 1 deletion src/olmo_eval/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional
from typing import Generator, List, Optional, TypeVar

import datasets
import importlib_resources
import torch.distributed as dist
from importlib_resources.abc import Traversable


Expand Down Expand Up @@ -78,3 +79,38 @@ def load_oe_eval_requests(path: str, name: Optional[str] = None, split: Optional
with open(config_file, "r") as file:
config = json.load(file)
return config, requests


def is_distributed() -> bool:
    """Return ``True`` when a ``torch.distributed`` process group is initialized."""
    if not dist.is_available():
        # Build of torch without distributed support — never distributed.
        return False
    return dist.is_initialized()


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """
    Get the world size of the given (or default) distributed process group.

    :param group: Optional process group to query; defaults to the default group.
    :returns: The number of ranks in the group.

    .. warning::
        This will always return 1 if a distributed group has not been initialized.
    """
    if is_distributed():
        return dist.get_world_size(group)
    else:
        # A single non-distributed process is a world of one. The previous
        # implementation returned 0 here, contradicting the docstring above and
        # the conventional meaning of "world size".
        return 1


T = TypeVar("T")


def all_gather_object(obj: T, group: Optional[dist.ProcessGroup] = None) -> List[T]:
    """
    All-gather an object using pickle to all ranks in a process group.

    :param obj: Any picklable object to share.
    :param group: Optional process group; defaults to the default group.
    :returns: A list with one entry per rank (just ``[obj]`` when not distributed).
    """
    if not is_distributed():
        # Single-process case: nothing to gather.
        return [obj]

    # Pre-size the output buffer; dist.all_gather_object overwrites each slot.
    gathered: List[T] = [obj for _ in range(get_world_size(group))]
    dist.all_gather_object(gathered, obj, group=group)
    return gathered

0 comments on commit 193fe15

Please sign in to comment.