Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pipeline parallelism #3

Merged
merged 10 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- You can pass `None` for `lm_logits` in `ICLMetric.update()` to support pipeline parallelism.

## [v0.4.0](https://github.com/allenai/OLMo-in-loop-evals/releases/tag/v0.4.0) - 2024-12-18

### Added
Expand Down
121 changes: 75 additions & 46 deletions src/olmo_eval/metrics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from typing import Any, Dict, List, Optional
import logging
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torchmetrics import Metric

from .util import all_gather_object

LOG_2_OF_E = 1.44269504089


log = logging.getLogger(__name__)


T = TypeVar("T")


def dist_combine_lists(x: List[T]) -> List[T]:
    """Gather ``x`` from every rank and flatten the results into one list."""
    flattened: List[T] = []
    for rank_items in all_gather_object(x):
        flattened.extend(rank_items)
    return flattened


class ICLMetric(Metric):
# update method does not require access to global metric state
full_state_update: bool = False
Expand All @@ -18,20 +32,33 @@ def __init__(self, metric_type="acc") -> None:

self.metric_type = metric_type

self.add_state("loglikelihoods", default=[], dist_reduce_fx=None)
self.add_state("celosses", default=[], dist_reduce_fx=None)
self.add_state("bpbs", default=[], dist_reduce_fx=None)
self.add_state("labels", default=[], dist_reduce_fx=None)
self.add_state("loglikelihoods", default=[], dist_reduce_fx=dist_combine_lists)
self.add_state("celosses", default=[], dist_reduce_fx=dist_combine_lists)
self.add_state("bpbs", default=[], dist_reduce_fx=dist_combine_lists)
self.add_state("labels", default=[], dist_reduce_fx=dist_combine_lists)

def reset(self):
    """Clear all accumulated per-example metric state."""
    # Each entry is a (doc_id, cont_id, value) tuple; every field is Optional
    # because ranks without logits append all-None placeholder tuples.
    FloatEntry = Tuple[Optional[int], Optional[int], Optional[float]]
    IntEntry = Tuple[Optional[int], Optional[int], Optional[int]]
    self.loglikelihoods: List[FloatEntry] = []
    self.celosses: List[FloatEntry] = []
    self.bpbs: List[FloatEntry] = []
    self.labels: List[IntEntry] = []

def reset(
def update(
self,
batch: Dict[str, Any],
lm_logits: Optional[torch.Tensor] = None,
dc_lm_logits: Optional[torch.Tensor] = None,
):
self.loglikelihoods = []
self.celosses = []
self.bpbs = []
self.labels = []
# NOTE: `lm_logits` may be None on some ranks when using pipeline parallelism. We still
# need to append a placeholder entry to these state lists so that torchmetrics can
# synchronize the states consistently across all ranks.
if lm_logits is None:
self.loglikelihoods.append((None, None, None))
self.celosses.append((None, None, None))
self.bpbs.append((None, None, None))
self.labels.append((None, None, None))
return

def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
lm_logits = F.log_softmax(lm_logits, dim=-1)

if self.metric_type == "pmi_dc":
Expand All @@ -40,6 +67,9 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
), "PMI_DC acc type selected but no domain conditional logits provided"

for idx, (doc_id, cont_id) in enumerate(zip(batch["doc_id"], batch["cont_id"])):
doc_id = int(doc_id)
cont_id = int(cont_id)

# [cont_len]: continuation is padded for batching
cont_tokens = batch["continuation"][idx][: batch["cont_len"][idx]]
# get logits from LM for the continuation: [cont_len, vocab]
Expand Down Expand Up @@ -95,66 +125,64 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
else:
raise ValueError(self.metric_type)

# because metric states cannot be dict/list of tuples, store this tuple as tensor: (doc_id, cont_id, metric_state)
self.loglikelihoods.append(
torch.Tensor((doc_id, cont_id, log_likelihood)).to(
batch["continuation"][idx].device
)
)
self.celosses.append(
torch.Tensor((doc_id, cont_id, celoss)).to(batch["continuation"][idx].device)
)
self.bpbs.append(
torch.Tensor((doc_id, cont_id, bpb)).to(batch["continuation"][idx].device)
)
self.labels.append(
torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(
batch["label_id"][idx].device
)
)
self.loglikelihoods.append((doc_id, cont_id, float(log_likelihood)))
self.labels.append((doc_id, cont_id, int(batch["label_id"][idx])))
self.celosses.append((doc_id, cont_id, float(celoss)))
self.bpbs.append((doc_id, cont_id, float(bpb)))

def compute(self) -> Dict[str, torch.Tensor]:
# Task "suffix" -> tensor

# states should have been synced from all accelerators at this point
# account for duplicates here because of DistributedSampler compensating for drop_last=False
loglikelihood_dict: Dict[int, Dict[int, float]] = {}
label_dict: Dict[int, int] = {}
celoss_dict: Dict[int, Dict[int, float]] = {}
bpb_dict: Dict[int, Dict[int, float]] = {}
label_dict = {}

# collect labels
for doc_id, cont_id, label_id in self.labels:
if doc_id.item() not in label_dict:
label_dict[doc_id.item()] = label_id.item()
if doc_id is None or cont_id is None or label_id is None:
continue

if doc_id not in label_dict:
label_dict[doc_id] = label_id

# collect loglikelihoods
for doc_id, cont_id, loglikelihood in self.loglikelihoods:
if int(doc_id.item()) not in loglikelihood_dict:
loglikelihood_dict[int(doc_id.item())] = {}
if doc_id is None or cont_id is None or loglikelihood is None:
continue

if doc_id not in loglikelihood_dict:
loglikelihood_dict[doc_id] = {}

if int(cont_id.item()) not in loglikelihood_dict[int(doc_id.item())]:
loglikelihood_dict[int(doc_id.item())][int(cont_id.item())] = loglikelihood
if cont_id not in loglikelihood_dict[doc_id]:
loglikelihood_dict[doc_id][cont_id] = loglikelihood

# collect celosses
for doc_id, cont_id, celoss in self.celosses:
if int(doc_id.item()) not in celoss_dict:
celoss_dict[int(doc_id.item())] = {}
for doc_id, cont_id, celoss_val in self.celosses:
if doc_id is None or cont_id is None or celoss_val is None:
continue

if doc_id not in celoss_dict:
celoss_dict[doc_id] = {}

if int(cont_id.item()) not in celoss_dict[int(doc_id.item())]:
celoss_dict[int(doc_id.item())][int(cont_id.item())] = celoss
if cont_id not in celoss_dict[doc_id]:
celoss_dict[doc_id][cont_id] = celoss_val

# collect bpbs
for doc_id, cont_id, bpb in self.bpbs:
if int(doc_id.item()) not in bpb_dict:
bpb_dict[int(doc_id.item())] = {}
for doc_id, cont_id, bpb_val in self.bpbs:
if doc_id is None or cont_id is None or bpb_val is None:
continue

if int(cont_id.item()) not in bpb_dict[int(doc_id.item())]:
bpb_dict[int(doc_id.item())][int(cont_id.item())] = bpb
if doc_id not in bpb_dict:
bpb_dict[doc_id] = {}

if cont_id not in bpb_dict[doc_id]:
bpb_dict[doc_id][cont_id] = bpb_val

# compute acc
correct = []
loglikelihood = []
celoss = []
bpb = []
soft_score = []
Expand Down Expand Up @@ -185,6 +213,7 @@ def compute(self) -> Dict[str, torch.Tensor]:

if skip_document:
continue

if self.metric_type == "ce_loss":
celoss.append(celosses[0]) # Only one answer is scored
elif self.metric_type == "bpb":
Expand Down
38 changes: 37 additions & 1 deletion src/olmo_eval/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional
from typing import Generator, List, Optional, TypeVar

import datasets
import importlib_resources
import torch.distributed as dist
from importlib_resources.abc import Traversable


Expand Down Expand Up @@ -78,3 +79,38 @@ def load_oe_eval_requests(path: str, name: Optional[str] = None, split: Optional
with open(config_file, "r") as file:
config = json.load(file)
return config, requests


def is_distributed() -> bool:
    """
    Check if in a distributed context.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """
    Get the world size of the given distributed process group.

    :param group: The process group to query; defaults to the default group.

    .. warning::
        This will always return 1 if a distributed group has not been initialized.
    """
    if is_distributed():
        return dist.get_world_size(group)
    else:
        # Fix: previously returned 0, contradicting the documented contract above.
        # A non-distributed run is a single-process "world", so the size is 1,
        # matching `torch.distributed` conventions.
        return 1


T = TypeVar("T")


def all_gather_object(obj: T, group: Optional[dist.ProcessGroup] = None) -> List[T]:
    """
    All-gather an object using pickle to all ranks in a process group.

    Returns one entry per rank; in a non-distributed context this is just ``[obj]``.
    """
    if not is_distributed():
        return [obj]

    # Pre-fill the output buffer (one slot per rank); the collective overwrites it.
    gathered: List[T] = [obj for _ in range(get_world_size(group))]
    dist.all_gather_object(gathered, obj, group=group)
    return gathered
Loading