Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding clip encoder for generation #28

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ include:
Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
professional setting

## Our Responsibilities

Expand Down
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Contributing to DiT

We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process

Work on the `DiT` repo has mostly concluded.

## Pull Requests

We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
Expand All @@ -16,12 +19,14 @@ We actively welcome your pull requests.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Expand All @@ -30,5 +35,6 @@ disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License

By contributing to `DiT`, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
112 changes: 38 additions & 74 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,32 @@

![DiT samples](visuals/sample_grid_0.png)

This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring
diffusion models with transformers (DiTs). You can find more visualizations on our [project page](https://www.wpeebles.com/DiT).
This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring
diffusion models with transformers (DiTs). You can find more visualizations on
our [project page](https://www.wpeebles.com/DiT).

> [**Scalable Diffusion Models with Transformers**](https://www.wpeebles.com/DiT)<br>
> [William Peebles](https://www.wpeebles.com), [Saining Xie](https://www.sainingxie.com)
> <br>UC Berkeley, New York University<br>

We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on
latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass
We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on
latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass
complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or
increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our
DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks,
increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our
DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks,
achieving a state-of-the-art FID of 2.27 on the latter.

This repository contains:

* 🪐 A simple PyTorch [implementation](models.py) of DiT
* 🪐 A simple PyTorch [implementation](modules/dit_builder.py) of DiT
* ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)
* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running pre-trained DiT-XL/2 models
* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT)
and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running
pre-trained DiT-XL/2 models
* 🛸 A DiT [training script](train.py) using PyTorch DDP

An implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx).

An implementation of DiT directly in Hugging Face `diffusers` can also be
found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx).

## Setup

Expand All @@ -37,93 +40,53 @@ git clone https://github.com/facebookresearch/DiT.git
cd DiT
```

We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the
file.

```bash
conda env create -f environment.yml
conda activate DiT
```


## Sampling [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb)

![More DiT samples](visuals/sample_grid_1.png)

**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT model will be
**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights
for our pre-trained DiT model will be
automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256
and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from
our 512x512 DiT-XL/2 model, you can use:
our 512x512 DiT-clipped model, you can use the new gradio interface:

```python
python sample.py --image-size 512 --seed 1
```bash
python sample_gradio.py --ckpt pretrained_models/last.ckpt
```

For convenience, our pre-trained DiT models can be downloaded directly here as well:

| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
|---------------|------------------|---------|-----------------|--------|
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 |
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 |


**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), you can add the `--ckpt`
argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom
256x256 DiT-L/4 model, run:

```python
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
```

| DiT Model | Image Resolution |
|------------------------------------------------------------------------------|------------------|
| [DiT_clipped](https://www.mediafire.com/file/trqvosl8947s88z/last.ckpt/file) | 256x256 |

## Training DiT

We provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional
DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with `N` GPUs on
We provide a training script for DiT in [`train_pl.py`](train_pl.py). This script can be used to train class-conditional
DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-clipped (256x256)
training
with `N` GPUs on
one node:

```python
torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
```bash
python train_pl.py --coco_dataset_path (...)/datasets/fast-ai-coco
```

### PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script
to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give
similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. Some data points:

| DiT Model | Train Steps | FID-50K<br> (JAX Training) | FID-50K<br> (PyTorch Training) | PyTorch Global Training Seed |
|------------|-------------|----------------------------|--------------------------------|------------------------------|
| XL/2 | 400K | 19.5 | **18.1** | 42 |
| B/4 | 400K | **68.4** | 68.9 | 42 |
| B/4 | 400K | 68.4 | **68.3** | 100 |

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID
here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).

**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults.
We've enabled them at the top of `train.py` and `sample.py` because it makes training and sampling way way way faster on
A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to
the above results.

### Enhancements
Training (and sampling) could likely be sped-up significantly by:
- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model
- [ ] using `torch.compile` in PyTorch 2.0

Basic features that would be nice to add:
- [ ] Monitor FID and other metrics
- [ ] Generate and save samples from the EMA model periodically
- [ ] Resume training from a checkpoint
- [ ] AMP/bfloat16 support


## Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models.
There may be minor differences in results stemming from sampling with different floating point precisions. We re-evaluated
our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID
versus 2.27 in the paper).
Improvements to the project could be as follows:

- [ ] Improve generation quality by training the checkpoint further
- [ ] Adding more DiT_clipped architectures with more params and better training them

## BibTeX

Expand All @@ -136,13 +99,14 @@ versus 2.27 in the paper).
}
```


## Acknowledgments
We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions.

We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for
helpful discussions.
William Peebles is supported by the NSF Graduate Research Fellowship.

This codebase borrows from OpenAI's diffusion repos, most notably [ADM](https://github.com/openai/guided-diffusion).


## License

The code and model weights are licensed under CC-BY-NC. See [`LICENSE.txt`](LICENSE.txt) for details.
1 change: 0 additions & 1 deletion download.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
import os


pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}


Expand Down
8 changes: 5 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ dependencies:
- torchvision
- pytorch-cuda=11.7
- pip:
- timm
- diffusers
- accelerate
- timm
- diffusers
- accelerate
- gradio
- open_clip_torch
Loading