Generalizable multi gpu to run e.g. Llama 65b #238

Open · thejaminator wants to merge 45 commits into main

Commits (45)
79a0be0  add llama map (thejaminator, May 1, 2023)
d692b7f  add typechecking if (thejaminator, May 1, 2023)
91d98d2  allocate the device properly (thejaminator, May 1, 2023)
aa5ee9e  print to debug (thejaminator, May 1, 2023)
7cbfb9a  change to device config (thejaminator, May 1, 2023)
cfb3200  more logs (thejaminator, May 1, 2023)
8cb4dec  print the value (thejaminator, May 1, 2023)
9e8e321  fix not returning configs (thejaminator, May 1, 2023)
36ea2e6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 1, 2023)
a59d0a2  test the effect of not returning the past key values (thejaminator, May 1, 2023)
5fc1b5f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 1, 2023)
c2ee397  add kwargs (thejaminator, May 1, 2023)
6a8de4d  Merge remote-tracking branch 'origin/main' into hardcoded-llama65-map (thejaminator, May 2, 2023)
6abddbf  add device map 0 (thejaminator, May 2, 2023)
4953544  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 2, 2023)
a2ceeb9  fix 8bit mem (thejaminator, May 2, 2023)
5409b01  make pyright happy (thejaminator, May 2, 2023)
0db53b9  implement multi gpu (thejaminator, May 2, 2023)
50edafe  Merge remote-tracking branch 'origin/main' into generalizable-multi-gpu (thejaminator, May 2, 2023)
5ee3c3a  add cli (thejaminator, May 2, 2023)
8d6279c  redirect only later (thejaminator, May 2, 2023)
47529b6  add logs and remove llama (thejaminator, May 2, 2023)
4a49aa0  fix keyword (thejaminator, May 2, 2023)
91a06b6  try out lm head (thejaminator, May 3, 2023)
d74c9b6  shift it to 0.8 instead (thejaminator, May 3, 2023)
e6eb9c1  try hardcoded map (thejaminator, May 3, 2023)
06b1a11  decrease further for gpu 1 (thejaminator, May 3, 2023)
64919b3  fix import (thejaminator, May 3, 2023)
fe331bc  remove syntax (thejaminator, May 3, 2023)
0ed7f31  try comparing to hardcoding (thejaminator, May 3, 2023)
69bbf64  Revert "try comparing to hardcoding" (thejaminator, May 3, 2023)
051e2fd  Merge remote-tracking branch 'origin/main' into generalizable-multi-gpu (thejaminator, May 3, 2023)
8c6386c  add comment on future improvement (thejaminator, May 3, 2023)
0182a64  print (thejaminator, May 3, 2023)
df1c0ff  load in 8bit correctly (thejaminator, May 3, 2023)
6d9e9ea  add comment (thejaminator, May 3, 2023)
02602cb  try passing float16? (thejaminator, May 3, 2023)
d3a8f29  prevent mem issues? (thejaminator, May 3, 2023)
bf827ea  add logs (thejaminator, May 3, 2023)
301e6e2  try only adding load_in_8bit if we really need to (thejaminator, May 3, 2023)
6b6bb6f  catch max mem (thejaminator, May 3, 2023)
a5b3d5f  Revert "try only adding load_in_8bit if we really need to" (thejaminator, May 3, 2023)
99db2a0  try out means of memory (thejaminator, May 3, 2023)
55b18ab  remove debug print (thejaminator, May 3, 2023)
fa52400  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 3, 2023)
11 changes: 11 additions & 0 deletions README.md
@@ -38,6 +38,17 @@ The following runs `elicit` on the Cartesian product of the listed models and datasets
elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
```

## Running big models
For big models that cannot fit on a single GPU, you'll need to use multiple
GPUs per model.

For example, to run a single 8-bit llama-65b model on two A40s with
~50 GB of memory each:

```
elk elicit huggyllama/llama-65b imdb --num_gpus 2 --gpus_per_model 2 --int8
```
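
The `--gpus_per_model` flag controls how many of the selected GPUs are grouped together for each model replica, so passing more GPUs than `--gpus_per_model` should run several replicas in parallel, one extraction worker per group. As an illustrative variant (not part of this diff), the same model on four A40s would use two replicas:

```
elk elicit huggyllama/llama-65b imdb --num_gpus 4 --gpus_per_model 2 --int8
```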

## Caching

The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
47 changes: 27 additions & 20 deletions elk/extraction/extraction.py
@@ -1,7 +1,6 @@
"""Functions for extracting the hidden states of a model."""
import logging
import os
from contextlib import nullcontext, redirect_stdout
from dataclasses import InitVar, dataclass, replace
from itertools import zip_longest
from typing import Any, Iterable, Literal
@@ -34,13 +33,16 @@
float_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
prevent_name_conflicts,
select_split,
select_train_val_splits,
select_usable_devices,
)
from ..utils.multi_gpu import (
ModelDevices,
instantiate_model_with_devices,
select_devices_multi_gpus,
)
from .dataset_name import (
DatasetDictWithName,
@@ -149,29 +151,33 @@ def explode(self) -> list["Extract"]:
def extract_hiddens(
cfg: "Extract",
*,
device: str | torch.device = "cpu",
devices: ModelDevices,
split_type: Literal["train", "val"] = "train",
rank: int = 0,
world_size: int = 1,
) -> Iterable[dict]:
"""Run inference on a model with a set of prompts, yielding the hidden states."""
first_device = (
devices if not isinstance(devices, ModelDevices) else devices.first_device
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

is_verbose = rank == 0

# Silence datasets logging messages from all but the first process
if rank != 0:
if not is_verbose:
filterwarnings("ignore")
logging.disable(logging.CRITICAL)

ds_names = cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

# We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its
# welcome message on every rank
with redirect_stdout(None) if rank != 0 else nullcontext():
model = instantiate_model(cfg.model, device=device, load_in_8bit=cfg.int8)
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
)
model = instantiate_model_with_devices(
cfg=cfg, device_config=devices, is_verbose=is_verbose
)
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=is_verbose
)

is_enc_dec = model.config.is_encoder_decoder
if is_enc_dec and cfg.use_encoder_states:
@@ -225,15 +231,15 @@ def extract_hiddens(
num_variants,
num_choices,
model.config.hidden_size,
device=device,
device=first_device,
dtype=torch.int16,
)
for layer_idx in layer_indices
}
lm_logits = torch.empty(
num_variants,
num_choices,
device=device,
device=first_device,
dtype=torch.float32,
)
text_questions = []
@@ -254,8 +260,7 @@
add_special_tokens=True,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
).to(device)

).to(first_device)
input_ids = assert_type(Tensor, encoding.input_ids)
if is_enc_dec:
answer = assert_type(Tensor, encoding.labels)
@@ -265,8 +270,7 @@
# Don't include [CLS] and [SEP] in the answer
add_special_tokens=False,
return_tensors="pt",
).to(device)

).to(first_device)
answer = assert_type(Tensor, encoding2.input_ids)
input_ids = torch.cat([input_ids, answer], dim=-1)

@@ -413,13 +417,16 @@ def extract(
disable_cache: bool = False,
highlight_color: Color = "cyan",
num_gpus: int = -1,
gpus_per_model: int = 1,
min_gpu_mem: int | None = None,
split_type: Literal["train", "val", None] = None,
) -> DatasetDictWithName:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""
info, features = hidden_features(cfg)

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
devices: list[ModelDevices] = select_devices_multi_gpus(
gpus_per_model=gpus_per_model, num_gpus=num_gpus, min_memory=min_gpu_mem
)
limits = cfg.max_examples
splits = assert_type(SplitDict, info.splits)

@@ -455,7 +462,7 @@
),
gen_kwargs=dict(
cfg=[cfg] * len(devices),
device=devices,
devices=devices,
rank=list(range(len(devices))),
split_type=[ty] * len(devices),
world_size=[len(devices)] * len(devices),
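
The `elk.utils.multi_gpu` module imported above is not shown in this diff. As a rough sketch only, based on how the names are used here (the real field names and logic may differ), `ModelDevices` and `select_devices_multi_gpus` could look like:

```python
from dataclasses import dataclass

from elk.utils.gpu_utils import select_usable_devices


@dataclass
class ModelDevices:
    """GPUs assigned to a single model replica."""

    used_devices: list[str]  # e.g. ["cuda:0", "cuda:1"]

    @property
    def first_device(self) -> str:
        # Inputs and per-prompt buffers are placed on the first GPU of the group.
        return self.used_devices[0]


def select_devices_multi_gpus(
    gpus_per_model: int, num_gpus: int, min_memory: int | None
) -> list[ModelDevices]:
    # Pick usable GPUs as before, then chunk them into groups of `gpus_per_model`;
    # each group backs one model replica and one extraction worker. Assumes the
    # number of selected GPUs is a multiple of `gpus_per_model`.
    flat = select_usable_devices(num_gpus, min_memory=min_memory)
    return [
        ModelDevices(flat[i : i + gpus_per_model])
        for i in range(0, len(flat), gpus_per_model)
    ]
```
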
2 changes: 1 addition & 1 deletion elk/extraction/generator.py
@@ -30,7 +30,7 @@ def create_config_id(
config_kwargs["gen_kwargs"] = {
k: v[0]
for k, v in config_kwargs.get("gen_kwargs", {}).items()
if k not in ("device", "rank", "world_size")
if k not in ("devices", "rank", "world_size")
}
return super().create_config_id(config_kwargs, custom_features)

2 changes: 2 additions & 0 deletions elk/run.py
@@ -49,6 +49,7 @@ class Run(ABC, Serializable):
num_gpus: int = -1
out_dir: Path | None = None
disable_cache: bool = field(default=False, to_dict=False)
gpus_per_model: int = 1

def execute(
self,
@@ -61,6 +62,7 @@ def execute(
disable_cache=self.disable_cache,
highlight_color=highlight_color,
num_gpus=self.num_gpus,
gpus_per_model=self.gpus_per_model,
min_gpu_mem=self.min_gpu_mem,
split_type=split_type,
)
15 changes: 15 additions & 0 deletions elk/utils/gpu_utils.py
@@ -164,3 +164,18 @@ def select_usable_devices(
print(f"Using {len(selection)} of {num_visible} GPUs: {selection}")

return [f"cuda:{i}" for i in selection]


def get_available_memory_for_devices() -> dict[str, int]:
# PyNVML and PyTorch device indices should agree when CUDA_VISIBLE_DEVICES is
# not set. We need them to agree so that the PyNVML indices match the PyTorch
# indices, and we don't have to do any complex error-prone conversions.
num_visible = torch.cuda.device_count()
num_installed = pynvml.nvmlDeviceGetCount()
assert num_installed == num_visible, "PyNVML and PyTorch disagree on GPU count"
output = {}
# Get free memory for each GPU
for i in range(num_installed):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
output[f"cuda:{i}"] = int(pynvml.nvmlDeviceGetMemoryInfo(handle).free)
return output
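
Illustration only (not part of this diff): the free-memory map above could be used to build the `max_memory` mapping that `transformers` accepts when dispatching a model across GPUs with `device_map="auto"`. The 0.8 headroom factor is an assumption, echoing the "shift it to 0.8 instead" commit:

```python
# Reserve ~20% of each GPU's free memory for activations and buffers.
free = get_available_memory_for_devices()  # e.g. {"cuda:0": 50_000_000_000, ...}
max_memory = {
    int(name.removeprefix("cuda:")): int(free_bytes * 0.8)
    for name, free_bytes in free.items()
}
# model = AutoModelForCausalLM.from_pretrained(
#     "huggyllama/llama-65b", device_map="auto", max_memory=max_memory
# )
```
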
78 changes: 47 additions & 31 deletions elk/utils/hf_utils.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import transformers
from transformers import (
@@ -20,44 +22,59 @@
_AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES


def determine_dtypes(
model_str: str,
is_cpu: bool,
load_in_8bit: bool,
) -> torch.dtype | str:
model_cfg = AutoConfig.from_pretrained(model_str)

# When the torch_dtype is None, this generally means the model is fp32, because
# the config was probably created before the `torch_dtype` field was added.
fp32_weights = model_cfg.torch_dtype in (None, torch.float32)

# Required by `bitsandbytes` to load in 8-bit.
if load_in_8bit:
# Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
# is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
# we can't guarantee that there won't be overflow if we downcast to fp16.
if fp32_weights:
raise ValueError("Cannot load in 8-bit if weights are fp32")

torch_dtype = torch.float16

# CPUs generally don't support anything other than fp32.
elif is_cpu:
torch_dtype = torch.float32

# If the model is fp32 but bf16 is available, convert to bf16.
# Usually models with fp32 weights were actually trained in bf16, and
# converting them doesn't hurt performance.
elif fp32_weights and torch.cuda.is_bf16_supported():
torch_dtype = torch.bfloat16
print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
else:
torch_dtype = "auto"
return torch_dtype


def instantiate_model(
model_str: str,
device: str | torch.device = "cpu",
load_in_8bit: bool,
is_cpu: bool,
torch_dtype: Optional[torch.dtype] = None,
**kwargs,
) -> PreTrainedModel:
"""Instantiate a model string with the appropriate `Auto` class."""
device = torch.device(device)
kwargs["device_map"] = {"": device}
Review comment (thejaminator): `kwargs["device_map"] = {"": device}` is now passed by the caller instead, because when instantiating an empty model we can't pass a device map; otherwise the weights would actually be loaded and the model wouldn't be empty anymore.
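
For context on the comment above, here is a hedged sketch (not code from this PR) of how a caller can plan a device map from an empty model with `accelerate`, which is exactly the case where passing `device_map` at instantiation time would defeat the purpose:

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("huggyllama/llama-65b")

# No weights are materialized here; the model lives on the "meta" device.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config)

# Plan the layer-to-GPU split without ever loading the checkpoint.
device_map = infer_auto_device_map(
    empty_model, max_memory={0: "45GiB", 1: "45GiB"}
)
```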


with prevent_name_conflicts():
model_cfg = AutoConfig.from_pretrained(model_str)

# When the torch_dtype is None, this generally means the model is fp32, because
# the config was probably created before the `torch_dtype` field was added.
fp32_weights = model_cfg.torch_dtype in (None, torch.float32)

# Required by `bitsandbytes` to load in 8-bit.
if kwargs.get("load_in_8bit"):
# Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
# is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
# we can't guarantee that there won't be overflow if we downcast to fp16.
if fp32_weights:
raise ValueError("Cannot load in 8-bit if weights are fp32")

kwargs["torch_dtype"] = torch.float16

# CPUs generally don't support anything other than fp32.
elif device.type == "cpu":
kwargs["torch_dtype"] = torch.float32

# If the model is fp32 but bf16 is available, convert to bf16.
# Usually models with fp32 weights were actually trained in bf16, and
# converting them doesn't hurt performance.
elif fp32_weights and torch.cuda.is_bf16_supported():
kwargs["torch_dtype"] = torch.bfloat16
print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
else:
kwargs["torch_dtype"] = "auto"
# If a torch_dtype was not specified, try to infer it.
kwargs["torch_dtype"] = torch_dtype or determine_dtypes(
model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit
)
Review comment (thejaminator): made this change because the code previously set `torch_dtype` in `kwargs` even when the caller of `instantiate_model` had already passed one, which was confusing.

# Add load_in_8bit to kwargs
kwargs["load_in_8bit"] = load_in_8bit

archs = model_cfg.architectures
if not isinstance(archs, list):
@@ -70,7 +87,6 @@ def instantiate_model(
if arch_str.endswith(suffix):
model_cls = getattr(transformers, arch_str)
return model_cls.from_pretrained(model_str, **kwargs)

return AutoModel.from_pretrained(model_str, **kwargs)


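A rough usage sketch of the refactored `instantiate_model`: the `device_map` and `max_memory` kwargs below are assumptions about what a caller such as `instantiate_model_with_devices` might pass through, not code from this PR.

```python
import torch

from elk.utils.hf_utils import instantiate_model

model = instantiate_model(
    "huggyllama/llama-65b",
    load_in_8bit=True,
    is_cpu=False,
    torch_dtype=torch.float16,  # passing a dtype skips determine_dtypes
    device_map="auto",          # weight placement is now decided by the caller
    max_memory={0: "45GiB", 1: "45GiB"},
)
```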