Enabled 12GB VRAM training via optional activation checkpointing
stefan-baumann committed Jun 13, 2024
1 parent 5204e0a commit 302c272
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -47,6 +47,12 @@ python learn_delta.py device=cuda:0 model=sdxl prompts=people/age
```
This will save the delta at `outputs/learn_delta/people/age/runs/<date>/<time>/checkpoints/delta.pt`, which you can then directly use as shown in the example notebooks.

Training typically requires slightly more than 24GB of VRAM (26GB when training on an A100 as of June 13th, 2024, although this will likely change with newer versions of diffusers and PyTorch). If you want to train on GPUs with less memory, you can enable gradient checkpointing (usually called activation checkpointing, but we'll stick to diffusers terminology here) by launching the training as
```shell
python learn_delta.py device=cuda:0 model=sdxl prompts=people/age model.compile=False +model.gradient_checkpointing=True
```
In our experiments, this enabled training deltas with an 11.5GB VRAM budget, at the cost of slower training.
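The slowdown comes from the mechanism itself: activation checkpointing discards intermediate activations during the forward pass and recomputes them during backward, so peak memory drops while each step does extra forward work. A minimal, standalone PyTorch illustration of that trade-off (a toy block, not code from this repository):
```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy block standing in for one U-Net stage.
block = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
)
x = torch.randn(8, 1024, requires_grad=True)

# Without checkpointing, the intermediate activations inside `block` are stored
# for the backward pass. With checkpointing, they are discarded and recomputed
# on backward: lower peak memory, extra forward compute (hence slower training).
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```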

#### Naive CLIP Difference Method
The simplest method to obtain deltas is the naive CLIP difference-based one. With it, you can compute a delta in a few seconds on a decent GPU; it is substantially worse than the proper learned method, though.
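The repo ships its own script for this; purely to illustrate the general idea (the prompt pair, pooling, and text encoder below are assumptions, not the repo's exact implementation), such a naive delta is essentially the difference between the CLIP text embeddings of an attribute-modified prompt and a neutral one:
```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

@torch.no_grad()
def embed(prompt: str) -> torch.Tensor:
    # Per-token CLIP text embeddings, padded to the model's fixed length (77 tokens).
    tokens = tokenizer(prompt, padding="max_length", return_tensors="pt")
    return text_encoder(**tokens).last_hidden_state

# Hypothetical prompt pair for an "age" attribute; the naive delta is just the
# difference between the two prompt embeddings.
delta = embed("a photo of an old person") - embed("a photo of a young person")
```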

6 changes: 6 additions & 0 deletions attribute_control/model/model.py
@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Union, Tuple, Dict, Optional, List, Any
from pydoc import locate
import warnings

import torch
from torch import nn
@@ -183,6 +184,7 @@ def __init__(
        pipe_kwargs: dict = { },
        device: Union[str, torch.device] = 'cuda:0',
        compile: bool = False,
        gradient_checkpointing: bool = False,
    ) -> None:
        super().__init__(pipeline_type=pipeline_type, model_name=model_name, num_inference_steps=num_inference_steps, pipe_kwargs=pipe_kwargs, device=device, compile=compile)

@@ -191,6 +193,10 @@ def __init__(
        d_v_major, d_v_minor, *_ = diffusers.__version__.split('.')
        if int(d_v_major) > 0 or int(d_v_minor) >= 25:
            self.pipe.fuse_qkv_projections()
        if gradient_checkpointing:
            if compile:
                warnings.warn('Gradient checkpointing is typically not compatible with compiling the U-Net. This will likely lead to a crash.')
            self.pipe.unet.enable_gradient_checkpointing()
        if compile:
            assert int(d_v_major) > 0 or int(d_v_minor) >= 25, 'Use at least diffusers==0.25 to enable proper functionality of torch.compile().'
            self.pipe.unet.to(memory_format=torch.channels_last)
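The change above relies on diffusers' built-in switch; outside this repo's wrapper, the same effect on a plain SDXL pipeline would look roughly like this (the pipeline setup is illustrative, only the checkpointing call mirrors the code above):
```python
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative pipeline setup; the repository's model class handles this internally.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Same call as in the diff above: recompute U-Net activations during backward.
pipe.unet.enable_gradient_checkpointing()
pipe.unet.train()  # diffusers only applies checkpointing while the module is in training mode
```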
