From 50142678f6f560a615b01be362c4dad7e954d74a Mon Sep 17 00:00:00 2001 From: neon Date: Wed, 22 Feb 2023 10:48:03 +0100 Subject: [PATCH 01/13] adapted workspace --- CODE_OF_CONDUCT.md | 6 +- CONTRIBUTING.md | 6 + README.md | 106 ++++++----- diffusion/__init__.py | 16 +- diffusion/diffusion_utils.py | 10 +- diffusion/gaussian_diffusion.py | 300 ++++++++++++++++---------------- diffusion/respace.py | 6 +- download.py | 1 - environment.yml | 6 +- models.py | 72 +++++--- sample.py | 14 +- train.py | 4 +- 12 files changed, 297 insertions(+), 250 deletions(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 3232ed66..ccc3c990 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -23,13 +23,13 @@ include: Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or -advances + advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic -address, without explicit permission + address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a -professional setting + professional setting ## Our Responsibilities diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b45bfbaa..5cb72e81 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,11 +1,14 @@ # Contributing to DiT + We want to make contributing to this project as easy and transparent as possible. ## Our Development Process + Work on the `DiT` repo has mostly concluded. ## Pull Requests + We actively welcome your pull requests. 1. Fork the repo and create your branch from `main`. @@ -16,12 +19,14 @@ We actively welcome your pull requests. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). ## Contributor License Agreement ("CLA") + In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects. Complete your CLA here: ## Issues + We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. @@ -30,5 +35,6 @@ disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. ## License + By contributing to `DiT`, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/README.md b/README.md index edbc1a21..5e2fca63 100644 --- a/README.md +++ b/README.md @@ -4,29 +4,32 @@ ![DiT samples](visuals/sample_grid_0.png) -This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring -diffusion models with transformers (DiTs). You can find more visualizations on our [project page](https://www.wpeebles.com/DiT). +This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring +diffusion models with transformers (DiTs). You can find more visualizations on +our [project page](https://www.wpeebles.com/DiT). > [**Scalable Diffusion Models with Transformers**](https://www.wpeebles.com/DiT)
> [William Peebles](https://www.wpeebles.com), [Saining Xie](https://www.sainingxie.com) >
UC Berkeley, New York University
-We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on -latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass +We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on +latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or -increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our -DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, +increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our +DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter. This repository contains: * 🪐 A simple PyTorch [implementation](models.py) of DiT * ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256) -* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running pre-trained DiT-XL/2 models +* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) + and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running + pre-trained DiT-XL/2 models * 🛸 A DiT [training script](train.py) using PyTorch DDP -An implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx). - +An implementation of DiT directly in Hugging Face `diffusers` can also be +found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx). ## Setup @@ -37,25 +40,30 @@ git clone https://github.com/facebookresearch/DiT.git cd DiT ``` -We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want -to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file. +We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want +to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the +file. ```bash conda env create -f environment.yml conda activate DiT ``` - ## Sampling [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) + ![More DiT samples](visuals/sample_grid_1.png) -**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT model will be +**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights +for our pre-trained DiT model will be automatically downloaded depending on the model you use. 
The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the
classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

```bash
python sample.py --image-size 512 --seed 1
```

For convenience, our pre-trained DiT models can be downloaded directly here as well:

@@ -65,66 +73,79 @@ For convenience, our pre-trained DiT models can be downloaded directly here as w
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 |
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 |

**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)),
you can add the `--ckpt` argument to use your own checkpoint instead. For example, to sample from the EMA weights of a
custom 256x256 DiT-L/4 model, run:

```bash
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
```

## Training DiT

We provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional
DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training
with `N` GPUs on one node:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
```

### PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces
the original JAX results up to several hundred thousand training iterations. Across our experiments, the
PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models up to
reasonable random variation. Some data points:

| DiT Model | Train Steps | FID-50K (JAX Training) | FID-50K (PyTorch Training) | PyTorch Global Training Seed |
|-----------|-------------|------------------------|----------------------------|------------------------------|
| XL/2      | 400K        | 19.5                   | **18.1**                   | 42                           |
| B/4       | 400K        | **68.4**               | 68.9                       | 42                           |
| B/4       | 400K        | 68.4                   | **68.3**                   | 100                          |

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that
FID here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).

**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's
defaults. We've enabled them at the top of `train.py` and `sample.py` because it makes training and sampling way way
way faster on A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences
compared to the above results.

### Enhancements

Training (and sampling) could likely be sped-up significantly by:

- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model
- [ ] using `torch.compile` in PyTorch 2.0

Basic features that would be nice to add:

- [ ] Monitor FID and other metrics
- [ ] Generate and save samples from the EMA model periodically
- [ ] Resume training from a checkpoint
- [ ] AMP/bfloat16 support

## Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models.
There may be minor differences in results stemming from sampling with different floating point precisions. We
re-evaluated our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX
(2.21 FID versus 2.27 in the paper).

## BibTeX

```bibtex
@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}
```

## Acknowledgments

We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for
helpful discussions. William Peebles is supported by the NSF Graduate Research Fellowship.

This codebase borrows from OpenAI's diffusion repos, most notably [ADM](https://github.com/openai/guided-diffusion).

## License

The code and model weights are licensed under CC-BY-NC. See [`LICENSE.txt`](LICENSE.txt) for details.
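The README above exposes a classifier-free guidance scale for sampling (`--cfg-scale`, with `cfg-scale=1` meaning no
guidance). As a quick, self-contained illustration of the combination rule such a scale controls, here is a minimal
sketch; the function and tensor names are hypothetical, and this is the generic classifier-free guidance formula, not
necessarily the repository's exact `forward_with_cfg` implementation.

```python
import torch


def classifier_free_guidance(cond_eps: torch.Tensor, uncond_eps: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Combine conditional and unconditional noise predictions.

    cfg_scale == 1.0 returns the purely conditional prediction (i.e. no guidance);
    larger scales push samples further toward the class condition.
    """
    return uncond_eps + cfg_scale * (cond_eps - uncond_eps)


# Hypothetical usage with two noise predictions of shape (N, C, H, W):
cond_eps = torch.randn(4, 4, 32, 32)
uncond_eps = torch.randn(4, 4, 32, 32)
guided_eps = classifier_free_guidance(cond_eps, uncond_eps, cfg_scale=4.0)
```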
diff --git a/diffusion/__init__.py b/diffusion/__init__.py index 8c536a98..5b809dcc 100644 --- a/diffusion/__init__.py +++ b/diffusion/__init__.py @@ -8,14 +8,14 @@ def create_diffusion( - timestep_respacing, - noise_schedule="linear", - use_kl=False, - sigma_small=False, - predict_xstart=False, - learn_sigma=True, - rescale_learned_sigmas=False, - diffusion_steps=1000 + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000 ): betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) if use_kl: diff --git a/diffusion/diffusion_utils.py b/diffusion/diffusion_utils.py index e493a6a3..3f7eb242 100644 --- a/diffusion/diffusion_utils.py +++ b/diffusion/diffusion_utils.py @@ -28,11 +28,11 @@ def normal_kl(mean1, logvar1, mean2, logvar2): ] return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + th.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) ) diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index ccbcefec..e5663401 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -7,7 +7,7 @@ import math import numpy as np -import torch as th +import torch import enum from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl @@ -69,13 +69,13 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time """ if beta_schedule == "quad": betas = ( - np.linspace( - beta_start ** 0.5, - beta_end ** 0.5, - num_diffusion_timesteps, - dtype=np.float64, - ) - ** 2 + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 ) elif beta_schedule == "linear": betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) @@ -151,12 +151,12 @@ class GaussianDiffusion: """ def __init__( - self, - *, - betas, - model_mean_type, - model_var_type, - loss_type + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type ): self.model_mean_type = model_mean_type @@ -186,7 +186,7 @@ def __init__( # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( - betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.posterior_log_variance_clipped = np.log( @@ -194,10 +194,10 @@ def __init__( ) if len(self.posterior_variance) > 1 else np.array([]) self.posterior_mean_coef1 = ( - betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( - (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) def q_mean_variance(self, x_start, t): @@ -222,11 +222,11 @@ def q_sample(self, x_start, t, noise=None): :return: A noisy version of x_start. 
""" if noise is None: - noise = th.randn_like(x_start) + noise = torch.randn_like(x_start) assert noise.shape == x_start.shape return ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): @@ -236,18 +236,18 @@ def q_posterior_mean_variance(self, x_start, x_t, t): """ assert x_start.shape == x_t.shape posterior_mean = ( - _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( - posterior_mean.shape[0] - == posterior_variance.shape[0] - == posterior_log_variance_clipped.shape[0] - == x_start.shape[0] + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped @@ -284,13 +284,13 @@ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, mod if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: assert model_output.shape == (B, C * 2, *x.shape[2:]) - model_output, model_var_values = th.split(model_output, C, dim=1) + model_output, model_var_values = torch.split(model_output, C, dim=1) min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. 
frac = (model_var_values + 1) / 2 model_log_variance = frac * max_log + (1 - frac) * min_log - model_variance = th.exp(model_log_variance) + model_variance = torch.exp(model_log_variance) else: model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so @@ -334,14 +334,14 @@ def process_xstart(x): def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ @@ -374,14 +374,14 @@ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): return out def p_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. @@ -407,26 +407,26 @@ def p_sample( denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) - noise = th.randn_like(x) + noise = torch.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) - sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} def p_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, ): """ Generate samples from the model. 
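The `nonzero_mask` expression in `p_sample` above zeroes out the added noise on the final denoising step (t == 0)
while keeping the update fully batched. A minimal stand-alone sketch of that masking, using hypothetical shapes rather
than this repository's tensors:

```python
import torch

# Hypothetical batch: three latents, the second of which is at its final step (t == 0).
t = torch.tensor([3, 0, 7])
mean = torch.zeros(3, 4, 8, 8)
log_variance = torch.zeros(3, 4, 8, 8)
noise = torch.randn_like(mean)

# (N,) -> (N, 1, 1, 1) so the mask broadcasts over the remaining dimensions;
# entries with t == 0 receive the posterior mean only, with no added noise.
nonzero_mask = (t != 0).float().view(-1, 1, 1, 1)
sample = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
```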
@@ -448,30 +448,30 @@ def p_sample_loop( """ final = None for sample in self.p_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - device=device, - progress=progress, + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, ): final = sample return final["sample"] def p_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, ): """ Generate samples from the model and yield intermediate samples from @@ -486,7 +486,7 @@ def p_sample_loop_progressive( if noise is not None: img = noise else: - img = th.randn(*shape, device=device) + img = torch.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] if progress: @@ -496,8 +496,8 @@ def p_sample_loop_progressive( indices = tqdm(indices) for i in indices: - t = th.tensor([i] * shape[0], device=device) - with th.no_grad(): + t = torch.tensor([i] * shape[0], device=device) + with torch.no_grad(): out = self.p_sample( model, img, @@ -511,15 +511,15 @@ def p_sample_loop_progressive( img = out["sample"] def ddim_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - eta=0.0, + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. @@ -543,15 +543,15 @@ def ddim_sample( alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( - eta - * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) - * th.sqrt(1 - alpha_bar / alpha_bar_prev) + eta + * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * torch.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. - noise = th.randn_like(x) + noise = torch.randn_like(x) mean_pred = ( - out["pred_xstart"] * th.sqrt(alpha_bar_prev) - + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + out["pred_xstart"] * torch.sqrt(alpha_bar_prev) + + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) @@ -560,15 +560,15 @@ def ddim_sample( return {"sample": sample, "pred_xstart": out["pred_xstart"]} def ddim_reverse_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - eta=0.0, + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. @@ -587,28 +587,28 @@ def ddim_reverse_sample( # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - out["pred_xstart"] - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. 
reversed - mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + mean_pred = out["pred_xstart"] * torch.sqrt(alpha_bar_next) + torch.sqrt(1 - alpha_bar_next) * eps return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, ): """ Generate samples from the model using DDIM. @@ -616,32 +616,32 @@ def ddim_sample_loop( """ final = None for sample in self.ddim_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - device=device, - progress=progress, - eta=eta, + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, ): final = sample return final["sample"] def ddim_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, ): """ Use DDIM to sample from the model and yield intermediate samples from @@ -654,7 +654,7 @@ def ddim_sample_loop_progressive( if noise is not None: img = noise else: - img = th.randn(*shape, device=device) + img = torch.randn(*shape, device=device) indices = list(range(self.num_timesteps))[::-1] if progress: @@ -664,8 +664,8 @@ def ddim_sample_loop_progressive( indices = tqdm(indices) for i in indices: - t = th.tensor([i] * shape[0], device=device) - with th.no_grad(): + t = torch.tensor([i] * shape[0], device=device) + with torch.no_grad(): out = self.ddim_sample( model, img, @@ -709,7 +709,7 @@ def _vb_terms_bpd( # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) - output = th.where((t == 0), decoder_nll, kl) + output = torch.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): @@ -727,7 +727,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): if model_kwargs is None: model_kwargs = {} if noise is None: - noise = th.randn_like(x_start) + noise = torch.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) terms = {} @@ -752,10 +752,10 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): ]: B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) - model_output, model_var_values = th.split(model_output, C, dim=1) + model_output, model_var_values = torch.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. 
- frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, @@ -795,7 +795,7 @@ def _prior_bpd(self, x_start): :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] - t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 @@ -825,11 +825,11 @@ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: - t_batch = th.tensor([t] * batch_size, device=device) - noise = th.randn_like(x_start) + t_batch = torch.tensor([t] * batch_size, device=device) + noise = torch.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep - with th.no_grad(): + with torch.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, @@ -843,9 +843,9 @@ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) - vb = th.stack(vb, dim=1) - xstart_mse = th.stack(xstart_mse, dim=1) - mse = th.stack(mse, dim=1) + vb = torch.stack(vb, dim=1) + xstart_mse = torch.stack(xstart_mse, dim=1) + mse = torch.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd @@ -867,7 +867,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
""" - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] - return res + th.zeros(broadcast_shape, device=timesteps.device) + return res + torch.zeros(broadcast_shape, device=timesteps.device) diff --git a/diffusion/respace.py b/diffusion/respace.py index 0a2cc043..1ce11bcd 100644 --- a/diffusion/respace.py +++ b/diffusion/respace.py @@ -30,7 +30,7 @@ def space_timesteps(num_timesteps, section_counts): """ if isinstance(section_counts, str): if section_counts.startswith("ddim"): - desired_count = int(section_counts[len("ddim") :]) + desired_count = int(section_counts[len("ddim"):]) for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) @@ -87,12 +87,12 @@ def __init__(self, use_timesteps, **kwargs): super().__init__(**kwargs) def p_mean_variance( - self, model, *args, **kwargs + self, model, *args, **kwargs ): # pylint: disable=signature-differs return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) def training_losses( - self, model, *args, **kwargs + self, model, *args, **kwargs ): # pylint: disable=signature-differs return super().training_losses(self._wrap_model(model), *args, **kwargs) diff --git a/download.py b/download.py index de22d459..fa20d726 100644 --- a/download.py +++ b/download.py @@ -11,7 +11,6 @@ import torch import os - pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'} diff --git a/environment.yml b/environment.yml index b5abcab9..df9e1471 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,6 @@ dependencies: - torchvision - pytorch-cuda=11.7 - pip: - - timm - - diffusers - - accelerate + - timm + - diffusers + - accelerate diff --git a/models.py b/models.py index c90eeba7..3c671c60 100644 --- a/models.py +++ b/models.py @@ -28,6 +28,7 @@ class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ + def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( @@ -41,7 +42,7 @@ def __init__(self, hidden_size, frequency_embedding_size=256): def timestep_embedding(t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. + :param: t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. @@ -50,8 +51,7 @@ def timestep_embedding(t, dim, max_period=10000): # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=t.device) + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: @@ -68,6 +68,7 @@ class LabelEmbedder(nn.Module): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. """ + def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = dropout_prob > 0 @@ -102,6 +103,7 @@ class DiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
""" + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -126,6 +128,7 @@ class FinalLayer(nn.Module): """ The final layer of DiT. """ + def __init__(self, hidden_size, patch_size, out_channels): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -146,18 +149,19 @@ class DiT(nn.Module): """ Diffusion model with a Transformer backbone. """ + def __init__( - self, - input_size=32, - patch_size=2, - in_channels=4, - hidden_size=1152, - depth=28, - num_heads=16, - mlp_ratio=4.0, - class_dropout_prob=0.1, - num_classes=1000, - learn_sigma=True, + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, ): super().__init__() self.learn_sigma = learn_sigma @@ -186,6 +190,7 @@ def _basic_init(module): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: @@ -238,13 +243,13 @@ def forward(self, x, t, y): y: (N,) tensor of class labels """ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 - t = self.t_embedder(t) # (N, D) - y = self.y_embedder(y, self.training) # (N, D) - c = t + y # (N, D) + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) for block in self.blocks: - x = block(x, c) # (N, T, D) - x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) - x = self.unpatchify(x) # (N, out_channels, H, W) + x = block(x, c) # (N, T, D) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_cfg(self, x, t, y, cfg_scale): @@ -296,7 +301,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb @@ -309,13 +314,13 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) + omega = 1. 
/ 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb @@ -328,43 +333,54 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): def DiT_XL_2(**kwargs): return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + def DiT_XL_4(**kwargs): return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) + def DiT_XL_8(**kwargs): return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) + def DiT_L_2(**kwargs): return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + def DiT_L_4(**kwargs): return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) + def DiT_L_8(**kwargs): return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) + def DiT_B_2(**kwargs): return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + def DiT_B_4(**kwargs): return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) + def DiT_B_8(**kwargs): return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) + def DiT_S_2(**kwargs): return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) + def DiT_S_4(**kwargs): return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + def DiT_S_8(**kwargs): return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) DiT_models = { - 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, - 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, - 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, - 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, + 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, + 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, + 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, + 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, } diff --git a/sample.py b/sample.py index 82238f16..fe23322c 100644 --- a/sample.py +++ b/sample.py @@ -7,17 +7,19 @@ Sample new images from a pre-trained DiT. 
""" import torch +import argparse + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torchvision.utils import save_image -from diffusion import create_diffusion from diffusers.models import AutoencoderKL + +from diffusion import create_diffusion from download import find_model from models import DiT_models -import argparse -def main(args): +def sample(args): # Setup PyTorch: torch.manual_seed(args.seed) torch.set_grad_enabled(False) @@ -42,8 +44,8 @@ def main(args): diffusion = create_diffusion(str(args.num_sampling_steps)) vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) - # Labels to condition the model with (feel free to change): - class_labels = [207, 360, 387, 974, 88, 979, 417, 279] + # Labels to condition the model with (balloon, banjo, electric guitar, velvet) + class_labels = [417, 420, 546, 885] # Create sampling noise: n = len(class_labels) @@ -79,4 +81,4 @@ def main(args): parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") args = parser.parse_args() - main(args) + sample(args) diff --git a/train.py b/train.py index 7cfee808..edc89f7f 100644 --- a/train.py +++ b/train.py @@ -8,6 +8,7 @@ A minimal training script for DiT using PyTorch DDP. """ import torch + # the first flag below was False when we tested this script but True makes A100 training a lot faster: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -223,7 +224,8 @@ def main(args): avg_loss = torch.tensor(running_loss / log_steps, device=device) dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) avg_loss = avg_loss.item() / dist.get_world_size() - logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") + logger.info( + f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") # Reset monitoring variables: running_loss = 0 log_steps = 0 From fbb6ba423d3339d56b09a81cc4a90dc3dfe75053 Mon Sep 17 00:00:00 2001 From: neon Date: Wed, 22 Feb 2023 10:55:30 +0100 Subject: [PATCH 02/13] minor refactorings --- modules/__init__.py | 0 {diffusion => modules/diffusion}/__init__.py | 0 .../diffusion}/diffusion_utils.py | 26 +- .../diffusion}/gaussian_diffusion.py | 0 {diffusion => modules/diffusion}/respace.py | 0 .../diffusion}/timestep_sampler.py | 14 +- modules/encoders/__init__.py | 0 modules/encoders/modules.py | 240 +++++++ modules/encoders/x_transformer.py | 642 ++++++++++++++++++ sample.py | 7 +- train.py | 18 +- 11 files changed, 916 insertions(+), 31 deletions(-) create mode 100644 modules/__init__.py rename {diffusion => modules/diffusion}/__init__.py (100%) rename {diffusion => modules/diffusion}/diffusion_utils.py (76%) rename {diffusion => modules/diffusion}/gaussian_diffusion.py (100%) rename {diffusion => modules/diffusion}/respace.py (100%) rename {diffusion => modules/diffusion}/timestep_sampler.py (92%) create mode 100644 modules/encoders/__init__.py create mode 100644 modules/encoders/modules.py create mode 100644 modules/encoders/x_transformer.py diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diffusion/__init__.py b/modules/diffusion/__init__.py similarity index 100% rename from diffusion/__init__.py rename to modules/diffusion/__init__.py diff --git a/diffusion/diffusion_utils.py b/modules/diffusion/diffusion_utils.py similarity index 76% rename from 
diffusion/diffusion_utils.py rename to modules/diffusion/diffusion_utils.py index 3f7eb242..08b3ae5c 100644 --- a/diffusion/diffusion_utils.py +++ b/modules/diffusion/diffusion_utils.py @@ -3,7 +3,7 @@ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -import torch as th +import torch import numpy as np @@ -15,7 +15,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): """ tensor = None for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, th.Tensor): + if isinstance(obj, torch.Tensor): tensor = obj break assert tensor is not None, "at least one argument must be a Tensor" @@ -23,7 +23,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for th.exp(). logvar1, logvar2 = [ - x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2) ] @@ -31,8 +31,8 @@ def normal_kl(mean1, logvar1, mean2, logvar2): -1.0 + logvar2 - logvar1 - + th.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) @@ -41,7 +41,7 @@ def approx_standard_normal_cdf(x): A fast approximation of the cumulative distribution function of the standard normal. """ - return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) def continuous_gaussian_log_likelihood(x, *, means, log_scales): @@ -53,9 +53,9 @@ def continuous_gaussian_log_likelihood(x, *, means, log_scales): :return: a tensor like x of log probabilities (in nats). 
""" centered_x = x - means - inv_stdv = th.exp(-log_scales) + inv_stdv = torch.exp(-log_scales) normalized_x = centered_x * inv_stdv - log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + log_probs = torch.distributions.Normal(torch.zeros_like(x), torch.ones_like(x)).log_prob(normalized_x) return log_probs @@ -71,18 +71,18 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): """ assert x.shape == means.shape == log_scales.shape centered_x = x - means - inv_stdv = th.exp(-log_scales) + inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1.0 / 255.0) cdf_plus = approx_standard_normal_cdf(plus_in) min_in = inv_stdv * (centered_x - 1.0 / 255.0) cdf_min = approx_standard_normal_cdf(min_in) - log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) cdf_delta = cdf_plus - cdf_min - log_probs = th.where( + log_probs = torch.where( x < -0.999, log_cdf_plus, - th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), ) assert log_probs.shape == x.shape return log_probs diff --git a/diffusion/gaussian_diffusion.py b/modules/diffusion/gaussian_diffusion.py similarity index 100% rename from diffusion/gaussian_diffusion.py rename to modules/diffusion/gaussian_diffusion.py diff --git a/diffusion/respace.py b/modules/diffusion/respace.py similarity index 100% rename from diffusion/respace.py rename to modules/diffusion/respace.py diff --git a/diffusion/timestep_sampler.py b/modules/diffusion/timestep_sampler.py similarity index 92% rename from diffusion/timestep_sampler.py rename to modules/diffusion/timestep_sampler.py index a3f36984..9bcf9852 100644 --- a/diffusion/timestep_sampler.py +++ b/modules/diffusion/timestep_sampler.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod import numpy as np -import torch as th +import torch import torch.distributed as dist @@ -53,9 +53,9 @@ def sample(self, batch_size, device): w = self.weights() p = w / np.sum(w) indices_np = np.random.choice(len(p), size=(batch_size,), p=p) - indices = th.from_numpy(indices_np).long().to(device) + indices = torch.from_numpy(indices_np).long().to(device) weights_np = 1 / (len(p) * p[indices_np]) - weights = th.from_numpy(weights_np).float().to(device) + weights = torch.from_numpy(weights_np).float().to(device) return indices, weights @@ -80,20 +80,20 @@ def update_with_local_losses(self, local_ts, local_losses): :param local_losses: a 1D Tensor of losses. """ batch_sizes = [ - th.tensor([0], dtype=th.int32, device=local_ts.device) + torch.tensor([0], dtype=torch.int32, device=local_ts.device) for _ in range(dist.get_world_size()) ] dist.all_gather( batch_sizes, - th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device), ) # Pad all_gather batches to be the maximum batch size. 
batch_sizes = [x.item() for x in batch_sizes] max_bs = max(batch_sizes) - timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] - loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes] dist.all_gather(timestep_batches, local_ts) dist.all_gather(loss_batches, local_losses) timesteps = [ diff --git a/modules/encoders/__init__.py b/modules/encoders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/encoders/modules.py b/modules/encoders/modules.py new file mode 100644 index 00000000..b47b20dd --- /dev/null +++ b/modules/encoders/modules.py @@ -0,0 +1,240 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from x_transformer import Encoder, TransformerWrapper + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" + + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda", use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text) # .to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class 
FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim == 2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) diff --git a/modules/encoders/x_transformer.py b/modules/encoders/x_transformer.py new file mode 100644 index 00000000..1487033f --- /dev/null +++ b/modules/encoders/x_transformer.py @@ -0,0 +1,642 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + + return inner + + +def not_equals(val): + def inner(x): + return x != val + + return inner + + +def equals(val): + def inner(x): + return x == val + + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + 
project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots 
= einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - 
len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = 
nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out diff --git a/sample.py b/sample.py index fe23322c..5caf101d 100644 --- a/sample.py +++ b/sample.py @@ -9,15 +9,16 @@ import torch import argparse -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True from torchvision.utils import save_image from diffusers.models import AutoencoderKL -from diffusion import create_diffusion +from modules.diffusion import create_diffusion from download import find_model from models import DiT_models +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + def sample(args): # Setup PyTorch: diff --git a/train.py b/train.py index edc89f7f..fd5b89ee 100644 --- a/train.py +++ b/train.py @@ -8,30 +8,32 @@ A minimal training script for DiT using PyTorch DDP. 
""" import torch +import argparse +import logging +import os -# the first flag below was False when we tested this script but True makes A100 training a lot faster: -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True +import numpy as np import torch.distributed as dist + from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from torchvision.datasets import ImageFolder from torchvision import transforms -import numpy as np from collections import OrderedDict from PIL import Image from copy import deepcopy from glob import glob from time import time -import argparse -import logging -import os from models import DiT_models -from diffusion import create_diffusion +from modules.diffusion import create_diffusion from diffusers.models import AutoencoderKL +# the first flag below was False when we tested this script but True makes A100 training a lot faster: +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + ################################################################################# # Training Helper Functions # From 9c5e16abf6f4345df24744a0ccfd82b10eecb9a7 Mon Sep 17 00:00:00 2001 From: neon Date: Wed, 22 Feb 2023 11:22:08 +0100 Subject: [PATCH 03/13] added a gradio interface --- environment.yml | 1 + sample_gradio.py | 100 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 sample_gradio.py diff --git a/environment.yml b/environment.yml index df9e1471..d46c3387 100644 --- a/environment.yml +++ b/environment.yml @@ -11,3 +11,4 @@ dependencies: - timm - diffusers - accelerate + - gradio diff --git a/sample_gradio.py b/sample_gradio.py new file mode 100644 index 00000000..d261a987 --- /dev/null +++ b/sample_gradio.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Sample new images from a pre-trained DiT. +""" +import torch +import argparse +import gradio as gr +import torchvision + +from torchvision.utils import save_image, make_grid +from diffusers.models import AutoencoderKL + +from modules.diffusion import create_diffusion +from download import find_model +from models import DiT_models + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def sample(class_idx, cfg_scale, num_sampling_steps): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + if args.ckpt is None: + assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." + assert args.image_size in [256, 512] + assert args.num_classes == 1000 + + # Load model: + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: + ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + model.eval() # important! 
+ diffusion = create_diffusion(str(num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + + # Labels to condition the model with (balloon, banjo, electric guitar, velvet) + class_labels = [class_idx] * 4 + + # Create sampling noise: + n = len(class_labels) + z = torch.randn(n, 4, latent_size, latent_size, device=device) + y = torch.tensor(class_labels, device=device) + + # Setup classifier-free guidance: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=cfg_scale) + + # Sample images: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae.decode(samples / 0.18215).sample + + # Save and display images: + samples = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1)) + samples = torchvision.transforms.ToPILImage()(samples) + return samples + # save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") + + args = parser.parse_args() + demo = gr.Interface( + fn=sample, + inputs=[ + gr.Slider(minimum=1, maximum=1000, value=417, step=1, label="Imagenet class index"), + gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="Cfg scale"), + gr.Slider(minimum=5, maximum=500, value=250, step=1, label="Sampling steps") + ], + outputs=[ + gr.Image() + ] + ) + demo.launch() From cf9a4b9e4e7f3fe200354adfca745ba0f481a66f Mon Sep 17 00:00:00 2001 From: Slava Date: Wed, 22 Feb 2023 12:24:28 +0100 Subject: [PATCH 04/13] final minor tweaks --- README.md | 28 ++-- models.py => modules/dit_builder.py | 189 +-------------------------- modules/dit_clipped.py | 133 +++++++++++++++++++ modules/utils.py | 191 ++++++++++++++++++++++++++++ sample.py | 2 +- sample_gradio.py | 4 +- train.py | 2 +- 7 files changed, 346 insertions(+), 203 deletions(-) rename models.py => modules/dit_builder.py (50%) create mode 100644 modules/dit_clipped.py create mode 100644 modules/utils.py diff --git a/README.md b/README.md index 5e2fca63..a914c631 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ achieving a state-of-the-art FID of 2.27 on the latter. This repository contains: -* 🪐 A simple PyTorch [implementation](models.py) of DiT +* 🪐 A simple PyTorch [implementation](modules/dit_builder.py) of DiT * ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256) * 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running @@ -59,11 +59,14 @@ automatically downloaded depending on the model you use. The script has various and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. 
For example, to sample from our 512x512 DiT-XL/2 model, you can use:
 
-```python
-python
-sample.py - -image - size -512 - -seed -1
+```bash
+python sample.py --image-size 512 --seed 1
+```
+
+**New gradio interface!**
+
+```bash
+python sample_gradio.py
 ```
 
 For convenience, our pre-trained DiT models can be downloaded directly here as well:
@@ -78,11 +81,8 @@ you can add the `--ckpt` argument to use your own checkpoint instead.
 
 For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:
 
-```python
-python
-sample.py - -model -DiT - L / 4 - -image - size -256 - -ckpt / path / to / model.pt
+```bash
+python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
 ```
 
 ## Training DiT
@@ -92,10 +92,8 @@ DiT models, but it can be easily modified to support other types of conditioning
 with `N` GPUs on one node:
 
-```python
-torchrun - -nnodes = 1 - -nproc_per_node = N -train.py - -model -DiT - XL / 2 - -data - path / path / to / imagenet / train
+```bash
+torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
 ```
 
 ### PyTorch Training Results
diff --git a/models.py b/modules/dit_builder.py
similarity index 50%
rename from models.py
rename to modules/dit_builder.py
index 3c671c60..e7bc1e30 100644
--- a/models.py
+++ b/modules/dit_builder.py
@@ -10,139 +10,14 @@
 # --------------------------------------------------------
 
 import torch
-import torch.nn as nn
-import numpy as np
-import math
-from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
-
-
-def modulate(x, shift, scale):
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-
-#################################################################################
-#               Embedding Layers for Timesteps and Class Labels                 #
-#################################################################################
-
-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-    """
-
-    def __init__(self, hidden_size, frequency_embedding_size=256):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-            nn.SiLU(),
-            nn.Linear(hidden_size, hidden_size, bias=True),
-        )
-        self.frequency_embedding_size = frequency_embedding_size
-
-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-        :param: t: a 1-D Tensor of N indices, one per batch element.
-                          These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-        return embedding
-
-    def forward(self, t):
-        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        t_emb = self.mlp(t_freq)
-        return t_emb
-
-
-class LabelEmbedder(nn.Module):
-    """
-    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
- """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. - """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = force_drop_ids == 1 - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - def forward(self, labels, train, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (train and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -################################################################################# -# Core DiT Model # -################################################################################# - -class DiTBlock(nn.Module): - """ - A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. - """ +import torch.nn as nn - def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): - super().__init__() - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) - x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x +from timm.models.vision_transformer import PatchEmbed +from modules.utils import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed -class FinalLayer(nn.Module): - """ - The final layer of DiT. 
- """ - - def __init__(self, hidden_size, patch_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x +from dit_clipped import DiT_Clipped class DiT(nn.Module): @@ -271,61 +146,6 @@ def forward_with_cfg(self, x, t, y, cfg_scale): return torch.cat([eps, rest], dim=1) -################################################################################# -# Sine/Cosine Positional Embedding Functions # -################################################################################# -# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2. - omega = 1. 
/ 10000 ** omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - ################################################################################# # DiT Configs # ################################################################################# @@ -383,4 +203,5 @@ def DiT_S_8(**kwargs): 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, + 'DiT_Clipped': DiT_Clipped } diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py new file mode 100644 index 00000000..ed46ad40 --- /dev/null +++ b/modules/dit_clipped.py @@ -0,0 +1,133 @@ +import torch + +import torch.nn as nn + +from timm.models.vision_transformer import PatchEmbed + +from modules.utils import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed + + +class DiT_Clipped(nn.Module): + """ + Diffusion model with a Transformer backbone and clip encoder. + """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + 
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, t, y): + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. 
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) diff --git a/modules/utils.py b/modules/utils.py new file mode 100644 index 00000000..0312eec2 --- /dev/null +++ b/modules/utils.py @@ -0,0 +1,191 @@ +import torch +import math + +import torch.nn as nn +import numpy as np + +from timm.models.vision_transformer import Attention, Mlp + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param: t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. 
+ :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. 
+ """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x diff --git a/sample.py b/sample.py index 5caf101d..6737e775 100644 --- a/sample.py +++ b/sample.py @@ -14,7 +14,7 @@ from modules.diffusion import create_diffusion from download import find_model -from models import DiT_models +from modules.dit_builder import DiT_models torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True diff --git a/sample_gradio.py b/sample_gradio.py index d261a987..a29edfe0 100644 --- a/sample_gradio.py +++ b/sample_gradio.py @@ -11,12 +11,12 @@ import gradio as gr import torchvision -from torchvision.utils import save_image, make_grid +from torchvision.utils import make_grid from diffusers.models import AutoencoderKL from modules.diffusion import create_diffusion from download import find_model -from models import DiT_models +from modules.dit_builder import DiT_models torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True diff --git a/train.py b/train.py index fd5b89ee..e880c609 100644 --- a/train.py +++ b/train.py @@ -26,7 +26,7 @@ from glob import glob from time import time -from models import DiT_models +from modules.dit_builder import DiT_models from modules.diffusion import create_diffusion from diffusers.models import AutoencoderKL From 7a9e5a37265af3875f4edbc357cb3933cebacc5f Mon Sep 17 00:00:00 2001 From: Slava Date: Wed, 22 Feb 2023 13:45:21 +0100 Subject: [PATCH 05/13] added openclip head --- modules/dit_builder.py | 12 ++++- modules/dit_clipped.py | 15 +++++- modules/encoders/modules.py | 6 +-- sample_gradio.py | 53 +++++++++++-------- sample_gradio_vanilla.py | 102 ++++++++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 30 deletions(-) create mode 100644 sample_gradio_vanilla.py diff --git a/modules/dit_builder.py b/modules/dit_builder.py index e7bc1e30..886c2896 100644 --- a/modules/dit_builder.py +++ b/modules/dit_builder.py @@ -17,7 +17,7 @@ from modules.utils import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed -from dit_clipped import DiT_Clipped +from .dit_clipped import DiT_Clipped class DiT(nn.Module): @@ -116,6 +116,10 @@ def forward(self, x, t, y): x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N,) tensor of class labels + + typical values: + D: 1152 (576 * 2) + N: 8 (4 * 2) """ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(t) # (N, D) @@ -198,10 +202,14 @@ def DiT_S_8(**kwargs): return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) +def DiT_clipper_builder(**kwargs): + return DiT_Clipped(depth=28, hidden_size=768, patch_size=2, num_heads=16, **kwargs) + + DiT_models = { 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, - 'DiT_Clipped': DiT_Clipped + 
'DiT_Clipped': DiT_clipper_builder } diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index ed46ad40..8d3a808c 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -4,6 +4,7 @@ from timm.models.vision_transformer import PatchEmbed +from modules.encoders.modules import FrozenCLIPTextEmbedder from modules.utils import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed @@ -24,6 +25,7 @@ def __init__( class_dropout_prob=0.1, num_classes=1000, learn_sigma=True, + clip_version='ViT-L/14' ): super().__init__() self.learn_sigma = learn_sigma @@ -43,8 +45,15 @@ def __init__( DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) ]) self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + + self.encoder = FrozenCLIPTextEmbedder(clip_version) + self.initialize_weights() + def encode(self, text_prompt): + c = self.encoder.encode(text_prompt) + return c + def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): @@ -103,10 +112,14 @@ def forward(self, x, t, y): x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N,) tensor of class labels + + typical values: + D: 1152 (576 * 2) + N: 8 (4 * 2) """ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(t) # (N, D) - y = self.y_embedder(y, self.training) # (N, D) + # y = self.y_embedder(y, self.training) # (N, D) c = t + y # (N, D) for block in self.blocks: x = block(x, c) # (N, T, D) diff --git a/modules/encoders/modules.py b/modules/encoders/modules.py index b47b20dd..b1d9ce27 100644 --- a/modules/encoders/modules.py +++ b/modules/encoders/modules.py @@ -6,7 +6,7 @@ from transformers import CLIPTokenizer, CLIPTextModel import kornia -from x_transformer import Encoder, TransformerWrapper +from .x_transformer import Encoder, TransformerWrapper class AbstractEncoder(nn.Module): @@ -234,7 +234,5 @@ def forward(self, x): if __name__ == "__main__": - from ldm.util import count_params - model = FrozenCLIPEmbedder() - count_params(model, verbose=True) + print(model) diff --git a/sample_gradio.py b/sample_gradio.py index a29edfe0..71417e04 100644 --- a/sample_gradio.py +++ b/sample_gradio.py @@ -22,30 +22,14 @@ torch.backends.cudnn.allow_tf32 = True -def sample(class_idx, cfg_scale, num_sampling_steps): +def sample(prompt, class_idx, cfg_scale, num_sampling_steps): # Setup PyTorch: torch.manual_seed(args.seed) torch.set_grad_enabled(False) - device = "cuda" if torch.cuda.is_available() else "cpu" - - if args.ckpt is None: - assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." - assert args.image_size in [256, 512] - assert args.num_classes == 1000 - # Load model: - latent_size = args.image_size // 8 - model = DiT_models[args.model]( - input_size=latent_size, - num_classes=args.num_classes - ).to(device) - # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: - ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" - state_dict = find_model(ckpt_path) - model.load_state_dict(state_dict) - model.eval() # important! 
+ model.to(device) + model.eval() diffusion = create_diffusion(str(num_sampling_steps)) - vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) # Labels to condition the model with (balloon, banjo, electric guitar, velvet) class_labels = [class_idx] * 4 @@ -53,11 +37,11 @@ def sample(class_idx, cfg_scale, num_sampling_steps): # Create sampling noise: n = len(class_labels) z = torch.randn(n, 4, latent_size, latent_size, device=device) - y = torch.tensor(class_labels, device=device) + y = model.encode(prompt)[0].repeat((4, 1)).to(device) # Setup classifier-free guidance: z = torch.cat([z, z], 0) - y_null = torch.tensor([1000] * n, device=device) + y_null = model.encode("")[0].repeat((4, 1)).to(device) # negative y = torch.cat([y, y_null], 0) model_kwargs = dict(y=y, cfg_scale=cfg_scale) @@ -66,7 +50,14 @@ def sample(class_idx, cfg_scale, num_sampling_steps): model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device ) samples, _ = samples.chunk(2, dim=0) # Remove null class samples + + # OOMs :cry: + model.cpu() + torch.cuda.empty_cache() + + vae.to(device) samples = vae.decode(samples / 0.18215).sample + vae.cpu() # Save and display images: samples = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1)) @@ -77,7 +68,7 @@ def sample(class_idx, cfg_scale, num_sampling_steps): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) parser.add_argument("--num-classes", type=int, default=1000) @@ -86,12 +77,28 @@ def sample(class_idx, cfg_scale, num_sampling_steps): help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + if args.ckpt: + ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() + demo = gr.Interface( fn=sample, inputs=[ + gr.Text(label="Text Prompt"), gr.Slider(minimum=1, maximum=1000, value=417, step=1, label="Imagenet class index"), gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="Cfg scale"), - gr.Slider(minimum=5, maximum=500, value=250, step=1, label="Sampling steps") + gr.Slider(minimum=5, maximum=500, value=50, step=1, label="Sampling steps") ], outputs=[ gr.Image() diff --git a/sample_gradio_vanilla.py b/sample_gradio_vanilla.py new file mode 100644 index 00000000..d6d30e49 --- /dev/null +++ b/sample_gradio_vanilla.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Sample new images from a pre-trained DiT. 
+""" +import torch +import argparse +import gradio as gr +import torchvision + +from torchvision.utils import make_grid +from diffusers.models import AutoencoderKL + +from modules.diffusion import create_diffusion +from download import find_model +from modules.dit_builder import DiT_models + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def sample(class_idx, cfg_scale, num_sampling_steps): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # if args.ckpt is None: + # assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." + # assert args.image_size in [256, 512] + # assert args.num_classes == 1000 + + # Load model: + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: + if args.ckpt: + ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + + model.eval() + diffusion = create_diffusion(str(num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + + # Labels to condition the model with (balloon, banjo, electric guitar, velvet) + class_labels = [class_idx] * 4 + + # Create sampling noise: + n = len(class_labels) + z = torch.randn(n, 4, latent_size, latent_size, device=device) + y = torch.tensor(class_labels, device=device) + + # Setup classifier-free guidance: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=cfg_scale) + + # Sample images: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae.decode(samples / 0.18215).sample + + # Save and display images: + samples = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1)) + samples = torchvision.transforms.ToPILImage()(samples) + return samples + # save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") + + args = parser.parse_args() + demo = gr.Interface( + fn=sample, + inputs=[ + gr.Slider(minimum=1, maximum=1000, value=417, step=1, label="Imagenet class index"), + gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="Cfg scale"), + gr.Slider(minimum=5, maximum=500, value=250, step=1, label="Sampling steps") + ], + outputs=[ + gr.Image() + ] + ) + demo.launch() From 75eff2936983d8fae14dcd949d3489f822fa078c Mon Sep 17 00:00:00 2001 From: neon Date: Wed, 22 Feb 2023 18:32:55 +0100 Subject: [PATCH 06/13] added training possibility --- modules/dit_clipped.py | 43 
+++++++++++++++++----- modules/training_utils.py | 76 +++++++++++++++++++++++++++++++++++++++ modules/utils.py | 23 +++++++++++- train.py | 76 +-------------------------------------- train_pl_laion.py | 71 ++++++++++++++++++++++++++++++++++++ 5 files changed, 205 insertions(+), 84 deletions(-) create mode 100644 modules/training_utils.py create mode 100644 train_pl_laion.py diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index 8d3a808c..1a482df1 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -1,14 +1,15 @@ import torch import torch.nn as nn +import pytorch_lightning as pl from timm.models.vision_transformer import PatchEmbed from modules.encoders.modules import FrozenCLIPTextEmbedder -from modules.utils import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed +from modules.utils import TimestepEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed, process_input -class DiT_Clipped(nn.Module): +class DiT_Clipped(pl.LightningModule): """ Diffusion model with a Transformer backbone and clip encoder. """ @@ -23,7 +24,6 @@ def __init__( num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, - num_classes=1000, learn_sigma=True, clip_version='ViT-L/14' ): @@ -36,7 +36,6 @@ def __init__( self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) num_patches = self.x_embedder.num_patches # Will use fixed sin-cos embedding: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) @@ -73,9 +72,6 @@ def _basic_init(module): nn.init.xavier_uniform_(w.view([w.shape[0], -1])) nn.init.constant_(self.x_embedder.proj.bias, 0) - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) @@ -119,7 +115,6 @@ def forward(self, x, t, y): """ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(t) # (N, D) - # y = self.y_embedder(y, self.training) # (N, D) c = t + y # (N, D) for block in self.blocks: x = block(x, c) # (N, T, D) @@ -144,3 +139,35 @@ def forward_with_cfg(self, x, t, y, cfg_scale): half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) return torch.cat([eps, rest], dim=1) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0) + return optimizer + + def training_step(self, train_batch, batch_idx): + text, img = process_input(train_batch) + with torch.no_grad(): + x = self.vae.encode(img.to(self.device)).latent_dist.sample().mul_(0.18215) + t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) + + y = self.encode(text).squeeze(1).to(self.device) + + model_kwargs = dict(y=y) + loss_dict = self.diffusion.training_losses(self, x, t, model_kwargs) + loss = loss_dict["loss"].mean() + self.log("train_loss", loss) + return loss + + # def validation_step(self, val_batch, batch_idx): + # x, y = val_batch + # with torch.no_grad(): + # x = self.vae.encode(x).latent_dist.sample().mul_(0.18215) + # t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) + # model_kwargs = dict(y=y) + # loss_dict = self.diffusion.training_losses(self, x, t, 
model_kwargs) + # loss = loss_dict["loss"].mean() + # self.log("val_loss", loss) + + def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs): + # update_ema(self.ema, self.module) + loss.backward() diff --git a/modules/training_utils.py b/modules/training_utils.py new file mode 100644 index 00000000..3534f702 --- /dev/null +++ b/modules/training_utils.py @@ -0,0 +1,76 @@ +from collections import OrderedDict +from PIL import Image + +import torch +import logging + +import numpy as np +import torch.distributed as dist + + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_model.to(model.device) + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + ema_model.cpu() + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format='[\033[34m%(asctime)s\033[0m] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
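For reference, the `update_ema` helper added in `modules/training_utils.py` implements the usual exponential-moving-average step, `ema = decay * ema + (1 - decay) * param`. Below is a minimal, self-contained sketch of that pattern; `SmallNet`, `ema_step`, and the toy loop are illustrative stand-ins, not code from this patch.

```python
# Standalone sketch of the EMA update pattern used by update_ema() above.
# SmallNet and the toy optimisation loop are hypothetical.
from copy import deepcopy

import torch
import torch.nn as nn


class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


@torch.no_grad()
def ema_step(ema_model: nn.Module, model: nn.Module, decay: float = 0.9999):
    # Interpolate every EMA parameter towards the current online parameter.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p.detach(), alpha=1 - decay)


model = SmallNet()
ema = deepcopy(model)
for p in ema.parameters():
    p.requires_grad_(False)  # the EMA copy is never trained directly

opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(3):  # toy optimisation steps
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    ema_step(ema, model)  # keep the shadow weights trailing the online weights
```

In the patch, `update_ema` does the same interpolation over `named_parameters()` and additionally moves the EMA copy onto the online model's device before the update and back to the CPU afterwards, keeping it out of GPU memory between steps.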
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) diff --git a/modules/utils.py b/modules/utils.py index 0312eec2..85bd333e 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -1,16 +1,37 @@ +from io import BytesIO + import torch import math +import requests import torch.nn as nn import numpy as np +from PIL import Image from timm.models.vision_transformer import Attention, Mlp - ################################################################################# # Sine/Cosine Positional Embedding Functions # ################################################################################# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +from torchvision.transforms import transforms + +from modules.training_utils import center_crop_arr + +transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 256)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) +]) + + +def process_input(data_dict): + text, imgs = data_dict["TEXT"], data_dict["URL"] + imgs = [Image.open(BytesIO(requests.get(img).content)) for img in imgs] + imgs = [transform(img) for img in imgs] + return text, torch.stack(imgs) + def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ diff --git a/train.py b/train.py index e880c609..ed13268c 100644 --- a/train.py +++ b/train.py @@ -9,10 +9,8 @@ """ import torch import argparse -import logging import os -import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP @@ -20,8 +18,6 @@ from torch.utils.data.distributed import DistributedSampler from torchvision.datasets import ImageFolder from torchvision import transforms -from collections import OrderedDict -from PIL import Image from copy import deepcopy from glob import glob from time import time @@ -29,83 +25,13 @@ from modules.dit_builder import DiT_models from modules.diffusion import create_diffusion from diffusers.models import AutoencoderKL +from modules.training_utils import * # the first flag below was False when we tested this script but True makes A100 training a lot faster: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -################################################################################# -# Training Helper Functions # -################################################################################# - -@torch.no_grad() -def update_ema(ema_model, model, decay=0.9999): - """ - Step the EMA model towards the current model. 
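The `transform` pipeline added to `modules/utils.py` crops to 256 px, randomly flips, and normalises pixels to [-1, 1] before the VAE encodes them. A rough usage sketch follows; torchvision's `Resize`/`CenterCrop` stand in for the ADM-style `center_crop_arr`, and the generated placeholder image is hypothetical.

```python
# Rough usage sketch of the preprocessing pipeline defined above.
# Resize + CenterCrop approximate center_crop_arr(); the Normalize step
# maps pixels into [-1, 1], the range expected before vae.encode().
from PIL import Image
from torchvision import transforms

image_size = 256
preprocess = transforms.Compose([
    transforms.Resize(image_size),               # shorter side -> 256
    transforms.CenterCrop(image_size),           # 256 x 256 crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),                       # [0, 1], C x H x W
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # -> [-1, 1]
])

img = Image.new("RGB", (640, 480), color=(128, 64, 32))  # placeholder image
x = preprocess(img)                 # torch.Size([3, 256, 256])
batch = x.unsqueeze(0)              # add the batch dimension expected by the VAE
print(batch.shape, float(batch.min()), float(batch.max()))
```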
- """ - ema_params = OrderedDict(ema_model.named_parameters()) - model_params = OrderedDict(model.named_parameters()) - - for name, param in model_params.items(): - # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed - ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) - - -def requires_grad(model, flag=True): - """ - Set requires_grad flag for all parameters in a model. - """ - for p in model.parameters(): - p.requires_grad = flag - - -def cleanup(): - """ - End DDP training. - """ - dist.destroy_process_group() - - -def create_logger(logging_dir): - """ - Create a logger that writes to a log file and stdout. - """ - if dist.get_rank() == 0: # real logger - logging.basicConfig( - level=logging.INFO, - format='[\033[34m%(asctime)s\033[0m] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] - ) - logger = logging.getLogger(__name__) - else: # dummy logger (does nothing) - logger = logging.getLogger(__name__) - logger.addHandler(logging.NullHandler()) - return logger - - -def center_crop_arr(pil_image, image_size): - """ - Center cropping implementation from ADM. - https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 - """ - while min(*pil_image.size) >= 2 * image_size: - pil_image = pil_image.resize( - tuple(x // 2 for x in pil_image.size), resample=Image.BOX - ) - - scale = image_size / min(*pil_image.size) - pil_image = pil_image.resize( - tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC - ) - - arr = np.array(pil_image) - crop_y = (arr.shape[0] - image_size) // 2 - crop_x = (arr.shape[1] - image_size) // 2 - return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) - - ################################################################################# # Training Loop # ################################################################################# diff --git a/train_pl_laion.py b/train_pl_laion.py new file mode 100644 index 00000000..a5ab817c --- /dev/null +++ b/train_pl_laion.py @@ -0,0 +1,71 @@ +import argparse + +import pytorch_lightning as pl + +from torch.utils.data import DataLoader +from datasets import load_dataset + +from modules.dit_builder import DiT_models +from modules.diffusion import create_diffusion +from diffusers.models import AutoencoderKL +from modules.training_utils import * + + +def train_pl(args): + print("Starting training..") + laion_dataset = load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus")["train"] + + device = torch.device(0) + + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + ) + # ema = deepcopy(model).cpu() + # requires_grad(ema, False) + + diffusion = create_diffusion(timestep_respacing="") + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + # training only + model.diffusion = diffusion + # model.ema = ema + model.vae = vae + + loader_train = DataLoader( + laion_dataset, + batch_size=args.global_batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True + ) + # update_ema(ema, model.module, decay=0) + model.train().to(device) + # ema.eval() + + torch.set_float32_matmul_precision("medium") + trainer = pl.Trainer( + auto_lr_find=True, + enable_checkpointing=True, + detect_anomaly=True, + log_every_n_steps=50, + accelerator='gpu', + devices=1, + 
max_epochs=args.epochs, + precision=16 + ) + trainer.fit(model, loader_train) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--results-dir", type=str, default="results") + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--global-batch-size", type=int, default=2) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training + parser.add_argument("--num-workers", type=int, default=4) + parsed_args = parser.parse_args() + train_pl(parsed_args) From 10a26dedff7e0857ff26723d9064bd9ae00b8989 Mon Sep 17 00:00:00 2001 From: neon Date: Wed, 22 Feb 2023 18:41:32 +0100 Subject: [PATCH 07/13] cpu optimizations --- train_pl_laion.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train_pl_laion.py b/train_pl_laion.py index a5ab817c..bcdf0cde 100644 --- a/train_pl_laion.py +++ b/train_pl_laion.py @@ -43,7 +43,7 @@ def train_pl(args): model.train().to(device) # ema.eval() - torch.set_float32_matmul_precision("medium") + torch.set_float32_matmul_precision("high") trainer = pl.Trainer( auto_lr_find=True, enable_checkpointing=True, @@ -52,7 +52,8 @@ def train_pl(args): accelerator='gpu', devices=1, max_epochs=args.epochs, - precision=16 + precision=16, + move_metrics_to_cpu=True ) trainer.fit(model, loader_train) @@ -61,7 +62,7 @@ def train_pl(args): parser = argparse.ArgumentParser() parser.add_argument("--results-dir", type=str, default="results") parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") - parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--image-size", type=int, choices=[128, 256, 512], default=256) parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--global-batch-size", type=int, default=2) parser.add_argument("--global-seed", type=int, default=0) From 43908bcc99c2f473ffc19e876619f3e78b539be4 Mon Sep 17 00:00:00 2001 From: neon Date: Thu, 23 Feb 2023 12:44:45 +0100 Subject: [PATCH 08/13] finetuned parameters, added xformers --- modules/dit_clipped.py | 6 ++- modules/utils.py | 84 +++++++++++++++++++++++++++++++++++++----- sample_gradio.py | 17 +++------ train_pl_laion.py | 9 ++--- 4 files changed, 86 insertions(+), 30 deletions(-) diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index 1a482df1..dd801e7c 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -150,7 +150,7 @@ def training_step(self, train_batch, batch_idx): x = self.vae.encode(img.to(self.device)).latent_dist.sample().mul_(0.18215) t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) - y = self.encode(text).squeeze(1).to(self.device) + y = self.encode(text).squeeze(1) model_kwargs = dict(y=y) loss_dict = self.diffusion.training_losses(self, x, t, model_kwargs) @@ -169,5 +169,7 @@ def training_step(self, train_batch, batch_idx): # self.log("val_loss", loss) def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs): - # update_ema(self.ema, self.module) loss.backward() + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad(set_to_none=True) diff --git a/modules/utils.py b/modules/utils.py index 85bd333e..d57a2ddc 
100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -7,13 +7,21 @@ import torch.nn as nn import numpy as np from PIL import Image +from inspect import isfunction +from timm.models.vision_transformer import Attention -from timm.models.vision_transformer import Attention, Mlp +try: + import xformers + import xformers.ops + from xformers.components.feedforward import MLP + from xformers.components import Activation + + XFORMERS_AVAILABLE = True +except Exception as e: + print(f"No xformers installation found, {e}") + XFORMERS_AVAILABLE = False + from timm.models.vision_transformer import Mlp -################################################################################# -# Sine/Cosine Positional Embedding Functions # -################################################################################# -# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py from torchvision.transforms import transforms from modules.training_utils import center_crop_arr @@ -26,6 +34,16 @@ ]) +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + def process_input(data_dict): text, imgs = data_dict["TEXT"], data_dict["URL"] imgs = [Image.open(BytesIO(requests.get(img).content)) for img in imgs] @@ -162,9 +180,49 @@ def forward(self, labels, train, force_drop_ids=None): return embeddings -################################################################################# -# Core DiT Model # -################################################################################# +class MemoryEfficientCrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + class DiTBlock(nn.Module): """ @@ -174,11 +232,17 @@ class DiTBlock(nn.Module): def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + + self.attn = MemoryEfficientCrossAttention(hidden_size, heads=num_heads) \ + if XFORMERS_AVAILABLE else Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = 
lambda: nn.GELU(approximate="tanh") - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.mlp = MLP(dim_model=hidden_size, hidden_layer_multiplier=int(mlp_ratio), + activation=Activation("gelu"), dropout=0) if XFORMERS_AVAILABLE else \ + Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) diff --git a/sample_gradio.py b/sample_gradio.py index 71417e04..9332bd66 100644 --- a/sample_gradio.py +++ b/sample_gradio.py @@ -22,7 +22,7 @@ torch.backends.cudnn.allow_tf32 = True -def sample(prompt, class_idx, cfg_scale, num_sampling_steps): +def sample(prompt, cfg_scale, num_sampling_steps): # Setup PyTorch: torch.manual_seed(args.seed) torch.set_grad_enabled(False) @@ -31,17 +31,13 @@ def sample(prompt, class_idx, cfg_scale, num_sampling_steps): model.eval() diffusion = create_diffusion(str(num_sampling_steps)) - # Labels to condition the model with (balloon, banjo, electric guitar, velvet) - class_labels = [class_idx] * 4 - - # Create sampling noise: - n = len(class_labels) - z = torch.randn(n, 4, latent_size, latent_size, device=device) - y = model.encode(prompt)[0].repeat((4, 1)).to(device) + bsize = 1 + z = torch.randn(bsize, 4, latent_size, latent_size, device=device) + y = model.encode(prompt).squeeze(1).to(device) # Setup classifier-free guidance: z = torch.cat([z, z], 0) - y_null = model.encode("")[0].repeat((4, 1)).to(device) # negative + y_null = model.encode("").squeeze(1).to(device) # negative y = torch.cat([y, y_null], 0) model_kwargs = dict(y=y, cfg_scale=cfg_scale) @@ -71,7 +67,6 @@ def sample(prompt, class_idx, cfg_scale, num_sampling_steps): parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) - parser.add_argument("--num-classes", type=int, default=1000) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") @@ -83,7 +78,6 @@ def sample(prompt, class_idx, cfg_scale, num_sampling_steps): latent_size = args.image_size // 8 model = DiT_models[args.model]( input_size=latent_size, - num_classes=args.num_classes ).to(device) if args.ckpt: ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" @@ -96,7 +90,6 @@ def sample(prompt, class_idx, cfg_scale, num_sampling_steps): fn=sample, inputs=[ gr.Text(label="Text Prompt"), - gr.Slider(minimum=1, maximum=1000, value=417, step=1, label="Imagenet class index"), gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="Cfg scale"), gr.Slider(minimum=5, maximum=500, value=50, step=1, label="Sampling steps") ], diff --git a/train_pl_laion.py b/train_pl_laion.py index bcdf0cde..ee063a88 100644 --- a/train_pl_laion.py +++ b/train_pl_laion.py @@ -21,14 +21,11 @@ def train_pl(args): model = DiT_models[args.model]( input_size=latent_size, ) - # ema = deepcopy(model).cpu() - # requires_grad(ema, False) diffusion = create_diffusion(timestep_respacing="") vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) # training only model.diffusion = diffusion - # model.ema = ema model.vae = vae loader_train = DataLoader( @@ -52,8 +49,7 @@ def train_pl(args): 
accelerator='gpu', devices=1, max_epochs=args.epochs, - precision=16, - move_metrics_to_cpu=True + precision=16 if args.precision == "fp16" else 32, ) trainer.fit(model, loader_train) @@ -64,9 +60,10 @@ def train_pl(args): parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") parser.add_argument("--image-size", type=int, choices=[128, 256, 512], default=256) parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--global-batch-size", type=int, default=2) + parser.add_argument("--global-batch-size", type=int, default=4) parser.add_argument("--global-seed", type=int, default=0) parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--precision", type=str, choices=["fp16", "fp32"], default="fp16") parsed_args = parser.parse_args() train_pl(parsed_args) From 4f78e549371253643270823835338ac300032aa4 Mon Sep 17 00:00:00 2001 From: neon Date: Thu, 23 Feb 2023 15:11:23 +0100 Subject: [PATCH 09/13] training stabilized for low-vram systems --- environment.yml | 1 + modules/diffusion/respace.py | 4 +-- modules/dit_clipped.py | 7 +++-- modules/encoders/modules.py | 52 +++++++----------------------------- modules/utils.py | 21 +++++++++++++-- train_pl_laion.py | 7 ++--- 6 files changed, 40 insertions(+), 52 deletions(-) diff --git a/environment.yml b/environment.yml index d46c3387..5c9b35fa 100644 --- a/environment.yml +++ b/environment.yml @@ -12,3 +12,4 @@ dependencies: - diffusers - accelerate - gradio + - open_clip_torch diff --git a/modules/diffusion/respace.py b/modules/diffusion/respace.py index 1ce11bcd..2169b2c7 100644 --- a/modules/diffusion/respace.py +++ b/modules/diffusion/respace.py @@ -4,7 +4,7 @@ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py import numpy as np -import torch as th +import torch from .gaussian_diffusion import GaussianDiffusion @@ -122,7 +122,7 @@ def __init__(self, model, timestep_map, original_num_steps): self.original_num_steps = original_num_steps def __call__(self, x, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) new_ts = map_tensor[ts] # if self.rescale_timesteps: # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index dd801e7c..dcb1f4d6 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -49,9 +49,11 @@ def __init__( self.initialize_weights() + @torch.no_grad() def encode(self, text_prompt): + self.encoder.cpu() c = self.encoder.encode(text_prompt) - return c + return c.to(self.device) def initialize_weights(self): # Initialize transformer layers: @@ -147,7 +149,8 @@ def configure_optimizers(self): def training_step(self, train_batch, batch_idx): text, img = process_input(train_batch) with torch.no_grad(): - x = self.vae.encode(img.to(self.device)).latent_dist.sample().mul_(0.18215) + self.vae.cpu().to(torch.float32) + x = self.vae.encode(img).latent_dist.sample().mul_(0.18215).to(self.device) t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) y = self.encode(text).squeeze(1) diff --git a/modules/encoders/modules.py b/modules/encoders/modules.py index b1d9ce27..b0a49e43 100644 --- a/modules/encoders/modules.py +++ b/modules/encoders/modules.py @@ -1,7 +1,7 @@ 
import torch import torch.nn as nn from functools import partial -import clip +import open_clip from einops import rearrange, repeat from transformers import CLIPTokenizer, CLIPTextModel import kornia @@ -35,7 +35,7 @@ def forward(self, batch, key=None): class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cpu"): super().__init__() self.device = device self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, @@ -53,7 +53,7 @@ def encode(self, x): class BERTTokenizer(AbstractEncoder): """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): + def __init__(self, device="cpu", vq_interface=True, max_length=77): super().__init__() from transformers import BertTokenizerFast self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") @@ -82,7 +82,7 @@ class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda", use_tokenizer=True, embedding_dropout=0.0): + device="cpu", use_tokenizer=True, embedding_dropout=0.0): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: @@ -139,7 +139,7 @@ def encode(self, x): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version) @@ -170,9 +170,10 @@ class FrozenCLIPTextEmbedder(nn.Module): Uses the CLIP transformer encoder for text. """ - def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + def __init__(self, version='ViT-B-32-quickgelu', device="cpu", max_length=77, n_repeat=1, normalize=True): super().__init__() - self.model, _ = clip.load(version, jit=False, device="cpu") + self.model = open_clip.create_model(version, pretrained='laion400m_e32', jit=False, device="cpu") + self.tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu') self.device = device self.max_length = max_length self.n_repeat = n_repeat @@ -184,7 +185,7 @@ def freeze(self): param.requires_grad = False def forward(self, text): - tokens = clip.tokenize(text).to(self.device) + tokens = self.tokenizer(text).to(self.device) z = self.model.encode_text(tokens) if self.normalize: z = z / torch.linalg.norm(z, dim=1, keepdim=True) @@ -198,41 +199,6 @@ def encode(self, text): return z -class FrozenClipImageEmbedder(nn.Module): - """ - Uses the CLIP image encoder. 
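`FrozenCLIPTextEmbedder` now wraps `open_clip` instead of the original `clip` package. A minimal sketch of the same text-embedding path, using the model and pretrained tags from this patch (the example prompts are arbitrary):

```python
# Minimal sketch of the open_clip text-embedding path wrapped by
# FrozenCLIPTextEmbedder above (same model / pretrained tags as the patch).
import open_clip
import torch

model = open_clip.create_model("ViT-B-32-quickgelu", pretrained="laion400m_e32", jit=False, device="cpu")
tokenizer = open_clip.get_tokenizer("ViT-B-32-quickgelu")

prompts = ["a photo of a forest", ""]   # second entry: empty / negative prompt
tokens = tokenizer(prompts)             # (2, 77) token ids
with torch.no_grad():
    z = model.encode_text(tokens)       # pooled text features
z = z / torch.linalg.norm(z, dim=1, keepdim=True)  # unit-normalise, as in the embedder
print(z.shape)
```

Each prompt yields a single pooled vector (512-dimensional for ViT-B/32); the later patches in this series switch to the Hugging Face `FrozenCLIPEmbedder`, whose 77x768 hidden-state sequence feeds the cross-attention context instead.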
- """ - - def __init__( - self, - model, - jit=False, - device='cuda' if torch.cuda.is_available() else 'cpu', - antialias=False, - ): - super().__init__() - self.model, _ = clip.load(name=model, device=device, jit=jit) - - self.antialias = antialias - - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) - - def preprocess(self, x): - # normalize to [0,1] - x = kornia.geometry.resize(x, (224, 224), - interpolation='bicubic', align_corners=True, - antialias=self.antialias) - x = (x + 1.) / 2. - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x - - def forward(self, x): - # x is assumed to be in range [-1,1] - return self.model.encode_image(self.preprocess(x)) - - if __name__ == "__main__": model = FrozenCLIPEmbedder() print(model) diff --git a/modules/utils.py b/modules/utils.py index d57a2ddc..a8c1019f 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -1,3 +1,4 @@ +import time from io import BytesIO import torch @@ -46,7 +47,23 @@ def default(val, d): def process_input(data_dict): text, imgs = data_dict["TEXT"], data_dict["URL"] - imgs = [Image.open(BytesIO(requests.get(img).content)) for img in imgs] + for i, img in enumerate(imgs): + r = None + for _ in range(3): # 3 tryouts + try: + r = requests.get(img).content + break + except Exception as e: + print(e, img) + time.sleep(0.5) + + r = BytesIO(r) + try: + imgs[i] = Image.open(r) + except Exception as e: + print(e, img) + imgs[i] = Image.open("forest.jpg").convert('RGB') + text[i] = "forest" imgs = [transform(img) for img in imgs] return text, torch.stack(imgs) @@ -240,7 +257,7 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = MLP(dim_model=hidden_size, hidden_layer_multiplier=int(mlp_ratio), - activation=Activation("gelu"), dropout=0) if XFORMERS_AVAILABLE else \ + activation=Activation("gelu"), dropout=0) if XFORMERS_AVAILABLE else \ Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation = nn.Sequential( diff --git a/train_pl_laion.py b/train_pl_laion.py index ee063a88..947ef6ee 100644 --- a/train_pl_laion.py +++ b/train_pl_laion.py @@ -23,7 +23,7 @@ def train_pl(args): ) diffusion = create_diffusion(timestep_respacing="") - vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() # training only model.diffusion = diffusion model.vae = vae @@ -31,10 +31,10 @@ def train_pl(args): loader_train = DataLoader( laion_dataset, batch_size=args.global_batch_size, - shuffle=False, + shuffle=True, num_workers=args.num_workers, pin_memory=True, - drop_last=True + drop_last=True, ) # update_ema(ema, model.module, decay=0) model.train().to(device) @@ -51,6 +51,7 @@ def train_pl(args): max_epochs=args.epochs, precision=16 if args.precision == "fp16" else 32, ) + trainer.fit(model, loader_train) From 89d8687fa796870f9ec24c753cdde1734e55db9b Mon Sep 17 00:00:00 2001 From: neon Date: Thu, 23 Feb 2023 20:04:00 +0100 Subject: [PATCH 10/13] training stabilized for low-vram systems, changed dataset --- modules/dit_builder.py | 2 +- modules/dit_clipped.py | 16 +++++++++----- modules/encoders/modules.py | 6 +++++ modules/utils.py | 14 +++++++----- train_pl_laion.py 
=> train_pl.py | 38 +++++++++++++++++++++++++++----- 5 files changed, 59 insertions(+), 17 deletions(-) rename train_pl_laion.py => train_pl.py (63%) diff --git a/modules/dit_builder.py b/modules/dit_builder.py index 886c2896..22c56040 100644 --- a/modules/dit_builder.py +++ b/modules/dit_builder.py @@ -203,7 +203,7 @@ def DiT_S_8(**kwargs): def DiT_clipper_builder(**kwargs): - return DiT_Clipped(depth=28, hidden_size=768, patch_size=2, num_heads=16, **kwargs) + return DiT_Clipped(depth=28, hidden_size=512, patch_size=2, num_heads=16, **kwargs) DiT_models = { diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index dcb1f4d6..c6762153 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -6,7 +6,7 @@ from timm.models.vision_transformer import PatchEmbed from modules.encoders.modules import FrozenCLIPTextEmbedder -from modules.utils import TimestepEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed, process_input +from modules.utils import TimestepEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed, process_input_laion class DiT_Clipped(pl.LightningModule): @@ -25,7 +25,7 @@ def __init__( mlp_ratio=4.0, class_dropout_prob=0.1, learn_sigma=True, - clip_version='ViT-L/14' + clip_version='ViT-B-32-quickgelu' ): super().__init__() self.learn_sigma = learn_sigma @@ -51,7 +51,6 @@ def __init__( @torch.no_grad() def encode(self, text_prompt): - self.encoder.cpu() c = self.encoder.encode(text_prompt) return c.to(self.device) @@ -147,13 +146,18 @@ def configure_optimizers(self): return optimizer def training_step(self, train_batch, batch_idx): - text, img = process_input(train_batch) + # text, img = process_input_laion(train_batch) + y, img = torch.stack(train_batch["y"][0]).permute(1, 0), \ + torch.stack([torch.stack([torch.stack(y) for y in x]) for x in train_batch["img"]]).permute(3, 0, 1, 2) + + y, img = y.to(self.device).to(self.dtype), img.cpu().to(torch.float32) + with torch.no_grad(): self.vae.cpu().to(torch.float32) - x = self.vae.encode(img).latent_dist.sample().mul_(0.18215).to(self.device) + x = self.vae.encode(img).latent_dist.sample().mul_(0.18215).to(self.device).to(self.dtype) t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) - y = self.encode(text).squeeze(1) + # y = self.encode(text).squeeze(1) model_kwargs = dict(y=y) loss_dict = self.diffusion.training_losses(self, x, t, model_kwargs) diff --git a/modules/encoders/modules.py b/modules/encoders/modules.py index b0a49e43..b94ec552 100644 --- a/modules/encoders/modules.py +++ b/modules/encoders/modules.py @@ -198,6 +198,12 @@ def encode(self, text): z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) return z + def to(self, device, dtype=None, non_blocking=None): + super(FrozenCLIPTextEmbedder, self).to(device) + self.model.to(device) + self.device = device + + if __name__ == "__main__": model = FrozenCLIPEmbedder() diff --git a/modules/utils.py b/modules/utils.py index a8c1019f..ada65faf 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -45,10 +45,14 @@ def default(val, d): return d() if isfunction(d) else d -def process_input(data_dict): - text, imgs = data_dict["TEXT"], data_dict["URL"] +def process_input_diff(data_dict): + texts, imgs = data_dict["prompt"], data_dict["image"] + return texts, torch.stack([transform(img.convert('RGB')) for img in imgs]) + + +def process_input_laion(data_dict): + texts, imgs = data_dict["TEXT"], data_dict["URL"] for i, img in enumerate(imgs): - r = None for _ in range(3): # 3 tryouts try: r = requests.get(img).content @@ -63,9 
+67,9 @@ def process_input(data_dict): except Exception as e: print(e, img) imgs[i] = Image.open("forest.jpg").convert('RGB') - text[i] = "forest" + texts[i] = "forest" imgs = [transform(img) for img in imgs] - return text, torch.stack(imgs) + return texts, torch.stack(imgs) def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): diff --git a/train_pl_laion.py b/train_pl.py similarity index 63% rename from train_pl_laion.py rename to train_pl.py index 947ef6ee..c13af69d 100644 --- a/train_pl_laion.py +++ b/train_pl.py @@ -1,9 +1,12 @@ import argparse +import os.path import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader -from datasets import load_dataset +from datasets import load_dataset, load_from_disk +from torchvision.transforms import transforms from modules.dit_builder import DiT_models from modules.diffusion import create_diffusion @@ -11,16 +14,33 @@ from modules.training_utils import * +def m(x): + img = transform(x["image"].convert('RGB')).cpu() + t = model.encode(x["prompt"]).squeeze(1).cpu() + return {"y": t, "img": img} + + def train_pl(args): + global model print("Starting training..") - laion_dataset = load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus")["train"] - device = torch.device(0) latent_size = args.image_size // 8 model = DiT_models[args.model]( input_size=latent_size, ) + if not os.path.exists("pl_dataset"): + dataset = load_dataset("poloclub/diffusiondb", name="2m_first_5k")["train"] + model.encoder.to(device) + dataset = dataset.map(m, remove_columns=['image', 'prompt', 'seed', 'step', 'cfg', 'sampler', 'width', 'height', + 'user_name', 'timestamp', 'image_nsfw', 'prompt_nsfw'], batch_size=100, + drop_last_batch=True) + dataset.save_to_disk("pl_dataset") + exit() + else: + dataset = load_from_disk("pl_dataset") # already preloaded + + del model.encoder diffusion = create_diffusion(timestep_respacing="") vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() @@ -29,13 +49,13 @@ def train_pl(args): model.vae = vae loader_train = DataLoader( - laion_dataset, + dataset, batch_size=args.global_batch_size, shuffle=True, - num_workers=args.num_workers, pin_memory=True, drop_last=True, ) + # update_ema(ema, model.module, decay=0) model.train().to(device) # ema.eval() @@ -67,4 +87,12 @@ def train_pl(args): parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--precision", type=str, choices=["fp16", "fp32"], default="fp16") parsed_args = parser.parse_args() + + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, parsed_args.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + train_pl(parsed_args) From 9c8271f046216262dece6fe4166b73570f3511e7 Mon Sep 17 00:00:00 2001 From: neon Date: Thu, 2 Mar 2023 16:10:39 +0100 Subject: [PATCH 11/13] training stabilized for low-vram systems, changed dataset and architecture slightly --- modules/dit_builder.py | 2 +- modules/dit_clipped.py | 45 +++---- modules/encoders/modules.py | 8 +- modules/image_cap_dataset.py | 47 +++++++ modules/utils.py | 242 +++++++++++++++++++++++++++++++---- sample_gradio.py | 23 ++-- train_pl.py | 35 ++--- 7 files changed, 320 insertions(+), 82 deletions(-) create mode 100644 modules/image_cap_dataset.py diff --git a/modules/dit_builder.py b/modules/dit_builder.py index 22c56040..f7e9001d 100644 --- a/modules/dit_builder.py +++ 
b/modules/dit_builder.py @@ -203,7 +203,7 @@ def DiT_S_8(**kwargs): def DiT_clipper_builder(**kwargs): - return DiT_Clipped(depth=28, hidden_size=512, patch_size=2, num_heads=16, **kwargs) + return DiT_Clipped(depth=16, hidden_size=768, patch_size=2, num_heads=12, **kwargs) DiT_models = { diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index c6762153..1e10e589 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -5,7 +5,7 @@ from timm.models.vision_transformer import PatchEmbed -from modules.encoders.modules import FrozenCLIPTextEmbedder +from modules.encoders.modules import FrozenCLIPEmbedder from modules.utils import TimestepEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed, process_input_laion @@ -20,12 +20,13 @@ def __init__( patch_size=2, in_channels=4, hidden_size=1152, + context_dim=768, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, learn_sigma=True, - clip_version='ViT-B-32-quickgelu' + clip_version='openai/clip-vit-large-patch14' ): super().__init__() self.learn_sigma = learn_sigma @@ -41,16 +42,18 @@ def __init__( self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) self.blocks = nn.ModuleList([ - DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + DiTBlock(hidden_size, num_heads, context_dim=context_dim, mlp_ratio=mlp_ratio) for _ in range(depth) ]) self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) - self.encoder = FrozenCLIPTextEmbedder(clip_version) + self.encoder = FrozenCLIPEmbedder(clip_version) self.initialize_weights() @torch.no_grad() - def encode(self, text_prompt): + def encode(self, text_prompt, device=None): + device = device if device is not None else self.device + self.encoder.to(device) c = self.encoder.encode(text_prompt) return c.to(self.device) @@ -103,23 +106,21 @@ def unpatchify(self, x): imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) return imgs - def forward(self, x, t, y): + def forward(self, x, t, context): """ Forward pass of DiT. 
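To make the shapes in the new `forward(x, t, context)` signature concrete, here is a small sketch of how latents become tokens. It instantiates timm's `PatchEmbed` directly with the sizes used in this patch (4-channel 32x32 latents from a 256 px image, patch size 2, hidden size 768); the random inputs are stand-ins.

```python
# Shape walk-through for the DiT_Clipped input path (sizes from this patch).
import torch
from timm.models.vision_transformer import PatchEmbed

latent_size, patch_size, in_channels, hidden_size = 32, 2, 4, 768   # 256 px / 8 = 32
x = torch.randn(2, in_channels, latent_size, latent_size)           # (N, C, H, W) latents
x_embedder = PatchEmbed(latent_size, patch_size, in_channels, hidden_size, bias=True)

tokens = x_embedder(x)             # (N, T, D) with T = (32 / 2) ** 2 = 256 patch tokens
context = torch.randn(2, 77, 768)  # stand-in for FrozenCLIPEmbedder output (N, 77, context_dim)
print(tokens.shape, context.shape)
```

The timestep embedding still conditions every block through adaLN, while `context` enters through the cross-attention path added to `DiTBlock` in this patch.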
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps - y: (N,) tensor of class labels - - typical values: - D: 1152 (576 * 2) - N: 8 (4 * 2) + context: (N, context_length, context_dim) embedding context """ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(t) # (N, D) - c = t + y # (N, D) for block in self.blocks: - x = block(x, c) # (N, T, D) - x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = block(x, t, context) # (N, T, D) + + # left context in, but it's not used atm + x = self.final_layer(x, t, context) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) return x @@ -146,20 +147,20 @@ def configure_optimizers(self): return optimizer def training_step(self, train_batch, batch_idx): - # text, img = process_input_laion(train_batch) - y, img = torch.stack(train_batch["y"][0]).permute(1, 0), \ - torch.stack([torch.stack([torch.stack(y) for y in x]) for x in train_batch["img"]]).permute(3, 0, 1, 2) - - y, img = y.to(self.device).to(self.dtype), img.cpu().to(torch.float32) - + img, context = train_batch with torch.no_grad(): + context = self.encode(context, device=torch.device("cpu")) + context, img = context.to(self.device).to(self.dtype), img.cpu().to(torch.float32) self.vae.cpu().to(torch.float32) x = self.vae.encode(img).latent_dist.sample().mul_(0.18215).to(self.device).to(self.dtype) + t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) - # y = self.encode(text).squeeze(1) + # I'm paranoid + context.requires_grad = True + x.requires_grad = True - model_kwargs = dict(y=y) + model_kwargs = dict(context=context) loss_dict = self.diffusion.training_losses(self, x, t, model_kwargs) loss = loss_dict["loss"].mean() self.log("train_loss", loss) diff --git a/modules/encoders/modules.py b/modules/encoders/modules.py index b94ec552..ccb927fe 100644 --- a/modules/encoders/modules.py +++ b/modules/encoders/modules.py @@ -136,7 +136,7 @@ def encode(self, x): return self(x) -class FrozenCLIPEmbedder(AbstractEncoder): +class FrozenCLIPEmbedder(AbstractEncoder): # h """Uses the CLIP transformer encoder for text (from Hugging Face)""" def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77): @@ -164,6 +164,11 @@ def forward(self, text): def encode(self, text): return self(text) + def to(self, device, dtype=None, non_blocking=None): + super(FrozenCLIPEmbedder, self).to(device) + self.transformer.to(device) + self.device = device + class FrozenCLIPTextEmbedder(nn.Module): """ @@ -204,7 +209,6 @@ def to(self, device, dtype=None, non_blocking=None): self.device = device - if __name__ == "__main__": model = FrozenCLIPEmbedder() print(model) diff --git a/modules/image_cap_dataset.py b/modules/image_cap_dataset.py new file mode 100644 index 00000000..7cb295c9 --- /dev/null +++ b/modules/image_cap_dataset.py @@ -0,0 +1,47 @@ +import os +import random + +from datasets import load_dataset + +from torch.utils.data import Dataset +from PIL import Image +from modules.training_utils import center_crop_arr +from torchvision.transforms import transforms + + +class ImageCaptionDataset(Dataset): + def __init__(self, hf_dataset_name, token, transform=None, target_transform=None, res=256): + self.hf_dataset = load_dataset(hf_dataset_name, use_auth_token=token)["test"] + self.token = token + + self.transform = transform if transform is not None else transforms.Compose([ 
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, res)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + self.target_transform = target_transform + + def __len__(self): + return len(self.hf_dataset) + + def get_item(self, idx): + if random.random() < 0.5: + image, prompt = self.hf_dataset[idx]["image_0"], self.hf_dataset[idx]["caption_0"] + else: + image, prompt = self.hf_dataset[idx]["image_1"], self.hf_dataset[idx]["caption_1"] + + image = image.rotate(90).convert("RGB") + + prompt = prompt.lower() + + if self.transform: + image = self.transform(image) + return image, prompt + + def __getitem__(self, idx): + try: + image, prompt = self.get_item(idx) + except Exception as e: + print(e) + image, prompt = self.get_item(0) + return image, prompt diff --git a/modules/utils.py b/modules/utils.py index ada65faf..4b7db503 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -1,15 +1,16 @@ +import math import time +from inspect import isfunction from io import BytesIO -import torch -import math +import numpy as np import requests - +import torch import torch.nn as nn -import numpy as np +import torch.nn.functional as F from PIL import Image -from inspect import isfunction -from timm.models.vision_transformer import Attention +from einops import rearrange, repeat +from torch import einsum try: import xformers @@ -23,16 +24,17 @@ XFORMERS_AVAILABLE = False from timm.models.vision_transformer import Mlp -from torchvision.transforms import transforms - -from modules.training_utils import center_crop_arr -transform = transforms.Compose([ - transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 256)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) -]) +# transform = transforms.Compose([ +# transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 256)), +# transforms.RandomHorizontalFlip(), +# transforms.ToTensor(), +# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) +# ]) +# def m(x): +# img = transform(x["image"].convert('RGB')).cpu() +# t = model.encode(x["prompt"]).squeeze(1).cpu() +# return {"y": t, "img": img} def exists(val): @@ -202,7 +204,7 @@ def forward(self, labels, train, force_drop_ids=None): class MemoryEfficientCrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, qkv_bias=False): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -210,9 +212,9 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. 
self.heads = heads self.dim_head = dim_head - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, inner_dim, bias=qkv_bias) self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op = None @@ -245,19 +247,203 @@ def forward(self, x, context=None, mask=None): return self.to_out(out) +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., qkv_bias=False): + super().__init__() + self.dim_head = 40 + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, inner_dim, bias=qkv_bias) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + try: + return self.fast_forward(x, context, mask) + except: + return self.slow_forward(x, context, mask) + + def fast_forward(self, x, context=None, mask=None, dtype=None): + # return self.light_forward(x, context=context, mask=mask, dtype=dtype) + h = self.heads + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) # (8, 4096, 40) + sim *= self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + del mask + + sim[sim.shape[0] // 2:] = sim[sim.shape[0] // 2:].softmax(dim=-1) + sim[:sim.shape[0] // 2] = sim[:sim.shape[0] // 2].softmax(dim=-1) + + sim = einsum('b i j, b j d -> b i d', sim, v) + sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) + del h, v + + return self.to_out(sim) + + def slow_forward(self, x, context=None, mask=None): + h = self.heads + device = x.device + dtype = x.dtype + q_proj = self.to_q(x) + context = default(context, x) + k_proj = self.to_k(context) + v_proj = self.to_v(context) + + del context, x + try: + stats = torch.cuda.memory_stats(device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + # mem counted before q k v are generated because they're gonna be stored on cpu + allocatable_mem = int(mem_free_total // 2) + 1 if dtype == torch.float16 else \ + int(mem_free_total // 4) + 1 + required_mem = int( + q_proj.shape[0] * q_proj.shape[1] * q_proj.shape[2] * 4 * 2 * 50) if dtype == torch.float16 \ + else int(q_proj.shape[0] * q_proj.shape[1] * q_proj.shape[2] * 8 * 2 * 50) # the last 50 is for speed + chunk_split = (required_mem // allocatable_mem) * 2 if required_mem > allocatable_mem else 1 + except Exception as e: + chunk_split = 1 + # print(e) + + # print(f"allocatable_mem: {allocatable_mem}, required_mem: {required_mem}, chunk_split: {chunk_split}") + # print(q.shape) torch.Size([1, 4096, 320]) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_proj, k_proj, v_proj)) + del q_proj, k_proj, v_proj + torch.cuda.empty_cache() + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=torch.device("cpu")) + mp = q.shape[1] // chunk_split + for i in range(0, q.shape[1], mp): + q, k = q.to(device), k.to(device) + s1 = einsum('b i d, b j d -> b i j', q[:, i:i + mp], k) + q, k = q.cpu(), k.cpu() + s1 *= self.scale + s2 = F.softmax(s1, dim=-1) + del s1 + r1[:, i:i + mp] = einsum('b i j, b j d -> b i d', s2, v).cpu() + del s2 + r2 = rearrange(r1.to(device), '(b h) n d -> b n (h d)', h=h).to(device) + del r1, q, k, v + + return self.to_out(r2) + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
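`CrossAttention.slow_forward` above bounds peak memory by processing the query rows in chunks, so the full attention score matrix is never materialised at once. A self-contained sketch of that trick; the chunk size is illustrative, and the 40-dim heads and 77 context tokens mirror the comments in this file.

```python
# Standalone sketch of query-chunked attention: scores are computed for one
# block of query rows at a time, softmaxed, and multiplied by V immediately.
import torch


def chunked_attention(q, k, v, chunk_size=1024):
    # q, k, v: (batch * heads, tokens, dim_head)
    scale = q.shape[-1] ** -0.5
    out = torch.empty(q.shape[0], q.shape[1], v.shape[-1], dtype=q.dtype, device=q.device)
    for i in range(0, q.shape[1], chunk_size):
        scores = torch.einsum("b i d, b j d -> b i j", q[:, i:i + chunk_size] * scale, k)
        out[:, i:i + chunk_size] = torch.einsum("b i j, b j d -> b i d", scores.softmax(dim=-1), v)
    return out


q = torch.randn(8, 4096, 40)   # e.g. 8 = batch * heads, 4096 latent tokens, dim_head 40
k = torch.randn(8, 77, 40)     # 77 CLIP context tokens
v = torch.randn(8, 77, 40)
print(chunked_attention(q, k, v).shape)   # torch.Size([8, 4096, 40])
```

When xformers is importable, the patch skips this path entirely and calls `xformers.ops.memory_efficient_attention`, which performs an equivalent fused computation.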
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int = 64, + dropout=0.0, + context_dim=None, + gated_ff: bool = True, + checkpoint: bool = True, + qkv_bias=False + ): + super().__init__() + AttentionBuilder = MemoryEfficientCrossAttention if XFORMERS_AVAILABLE else CrossAttention + self.attn1 = AttentionBuilder( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, qkv_bias=qkv_bias) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = AttentionBuilder( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, qkv_bias=False) + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states + + class DiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
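Since `DiTBlock` keeps the adaLN-Zero conditioning while delegating attention to `BasicTransformerBlock`, a compact sketch of the adaLN-Zero mechanics may help. The sizes and the `Linear` stand-ins are toys, and `modulate` here follows the usual DiT definition of `x * (1 + scale) + shift`.

```python
# Sketch of adaLN-Zero conditioning as used in DiTBlock.forward: the timestep
# embedding c produces shift / scale / gate vectors; each sub-block's input is
# modulated and its residual contribution gated. In the patch the LayerNorms
# live inside BasicTransformerBlock, so modulate() is applied to x directly.
import torch
import torch.nn as nn


def modulate(x, shift, scale):
    # standard DiT formulation
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


N, T, D = 2, 256, 768
x = torch.randn(N, T, D)      # patch tokens
c = torch.randn(N, D)         # timestep embedding

adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(D, 6 * D, bias=True))
attn = nn.Linear(D, D)        # stand-in for the attention sub-block
mlp = nn.Linear(D, D)         # stand-in for the feed-forward sub-block

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * attn(modulate(x, shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * mlp(modulate(x, shift_mlp, scale_mlp))
print(x.shape)                # torch.Size([2, 256, 768])
```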
""" - def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + def __init__(self, hidden_size, num_heads, context_dim=None, mlp_ratio=4.0, **block_kwargs): super().__init__() - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = MemoryEfficientCrossAttention(hidden_size, heads=num_heads) \ - if XFORMERS_AVAILABLE else Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.attn = BasicTransformerBlock(dim=hidden_size, n_heads=num_heads, context_dim=context_dim, qkv_bias=True) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = MLP(dim_model=hidden_size, hidden_layer_multiplier=int(mlp_ratio), @@ -269,10 +455,10 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) - def forward(self, x, c): + def forward(self, x, c, context=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) - x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(x, shift_msa, scale_msa), context=context) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(x, shift_mlp, scale_mlp)) return x @@ -290,7 +476,7 @@ def __init__(self, hidden_size, patch_size, out_channels): nn.Linear(hidden_size, 2 * hidden_size, bias=True) ) - def forward(self, x, c): + def forward(self, x, c, context=None): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) diff --git a/sample_gradio.py b/sample_gradio.py index 9332bd66..68091561 100644 --- a/sample_gradio.py +++ b/sample_gradio.py @@ -6,6 +6,8 @@ """ Sample new images from a pre-trained DiT. 
""" +import random + import torch import argparse import gradio as gr @@ -22,9 +24,9 @@ torch.backends.cudnn.allow_tf32 = True -def sample(prompt, cfg_scale, num_sampling_steps): +def sample(prompt, cfg_scale, num_sampling_steps, seed): # Setup PyTorch: - torch.manual_seed(args.seed) + torch.manual_seed(seed) torch.set_grad_enabled(False) model.to(device) @@ -33,11 +35,11 @@ def sample(prompt, cfg_scale, num_sampling_steps): bsize = 1 z = torch.randn(bsize, 4, latent_size, latent_size, device=device) - y = model.encode(prompt).squeeze(1).to(device) + y = model.encode(prompt).squeeze(1).repeat(bsize, 1, 1).to(device) # Setup classifier-free guidance: z = torch.cat([z, z], 0) - y_null = model.encode("").squeeze(1).to(device) # negative + y_null = model.encode("").squeeze(1).repeat(bsize, 1, 1).to(device) # negative y = torch.cat([y, y_null], 0) model_kwargs = dict(y=y, cfg_scale=cfg_scale) @@ -80,18 +82,23 @@ def sample(prompt, cfg_scale, num_sampling_steps): input_size=latent_size, ).to(device) if args.ckpt: - ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" + print(f"Loading {args.ckpt}") + ckpt_path = args.ckpt state_dict = find_model(ckpt_path) - model.load_state_dict(state_dict) + if 'pytorch-lightning_version' in state_dict.keys(): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=False) vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() demo = gr.Interface( fn=sample, inputs=[ - gr.Text(label="Text Prompt"), + gr.Text(label="Text Prompt", value="an apple"), gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="Cfg scale"), - gr.Slider(minimum=5, maximum=500, value=50, step=1, label="Sampling steps") + gr.Slider(minimum=5, maximum=500, value=50, step=1, label="Sampling steps"), + gr.Slider(minimum=1, maximum=9223372036854775807, value=4, step=1, + label="Seed"), ], outputs=[ gr.Image() diff --git a/train_pl.py b/train_pl.py index c13af69d..63c742ee 100644 --- a/train_pl.py +++ b/train_pl.py @@ -1,23 +1,18 @@ import argparse -import os.path import pytorch_lightning as pl -import torch +import torchvision +from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data import DataLoader -from datasets import load_dataset, load_from_disk from torchvision.transforms import transforms from modules.dit_builder import DiT_models from modules.diffusion import create_diffusion from diffusers.models import AutoencoderKL -from modules.training_utils import * - -def m(x): - img = transform(x["image"].convert('RGB')).cpu() - t = model.encode(x["prompt"]).squeeze(1).cpu() - return {"y": t, "img": img} +from modules.image_cap_dataset import ImageCaptionDataset +from modules.training_utils import * def train_pl(args): @@ -29,24 +24,17 @@ def train_pl(args): model = DiT_models[args.model]( input_size=latent_size, ) - if not os.path.exists("pl_dataset"): - dataset = load_dataset("poloclub/diffusiondb", name="2m_first_5k")["train"] - model.encoder.to(device) - dataset = dataset.map(m, remove_columns=['image', 'prompt', 'seed', 'step', 'cfg', 'sampler', 'width', 'height', - 'user_name', 'timestamp', 'image_nsfw', 'prompt_nsfw'], batch_size=100, - drop_last_batch=True) - dataset.save_to_disk("pl_dataset") - exit() - else: - dataset = load_from_disk("pl_dataset") # already preloaded - - del model.encoder + dataset = ImageCaptionDataset(args.hf_dataset_name, args.token, res=args.image_size) + + # del model.encoder diffusion = create_diffusion(timestep_respacing="") vae = 
AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() # training only model.diffusion = diffusion model.vae = vae + model_ckpt = ModelCheckpoint(dirpath="ckpts/", monitor=None, save_top_k=-1, save_last=True, every_n_train_steps=1000 + ) loader_train = DataLoader( dataset, @@ -70,6 +58,7 @@ def train_pl(args): devices=1, max_epochs=args.epochs, precision=16 if args.precision == "fp16" else 32, + callbacks=[model_ckpt] ) trainer.fit(model, loader_train) @@ -77,6 +66,10 @@ def train_pl(args): if __name__ == "__main__": parser = argparse.ArgumentParser() + + parser.add_argument("--hf_dataset_name", type=str, default="facebook/winoground") + parser.add_argument("--token", type=str) + parser.add_argument("--results-dir", type=str, default="results") parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") parser.add_argument("--image-size", type=int, choices=[128, 256, 512], default=256) From ba0729b210905d8975adc6106a3d6ca1ca1f307f Mon Sep 17 00:00:00 2001 From: neon Date: Fri, 10 Mar 2023 09:40:39 +0100 Subject: [PATCH 12/13] working on training issues, architectural decisions --- modules/dit_builder.py | 2 +- modules/dit_clipped.py | 19 +- modules/image_cap_dataset.py | 1077 +++++++++++++++++++++++++++++++++- sample_gradio.py | 4 +- train_pl.py | 48 +- 5 files changed, 1126 insertions(+), 24 deletions(-) diff --git a/modules/dit_builder.py b/modules/dit_builder.py index f7e9001d..15bcd262 100644 --- a/modules/dit_builder.py +++ b/modules/dit_builder.py @@ -203,7 +203,7 @@ def DiT_S_8(**kwargs): def DiT_clipper_builder(**kwargs): - return DiT_Clipped(depth=16, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + return DiT_Clipped(depth=16, hidden_size=768, patch_size=2, num_heads=14, **kwargs) DiT_models = { diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index 1e10e589..a4448514 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -1,3 +1,5 @@ +import random + import torch import torch.nn as nn @@ -48,6 +50,8 @@ def __init__( self.encoder = FrozenCLIPEmbedder(clip_version) + self.secondary_device = torch.device("cpu") + self.initialize_weights() @torch.no_grad() @@ -143,15 +147,16 @@ def forward_with_cfg(self, x, t, y, cfg_scale): return torch.cat([eps, rest], dim=1) def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0) + optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-5) return optimizer def training_step(self, train_batch, batch_idx): img, context = train_batch + with torch.no_grad(): - context = self.encode(context, device=torch.device("cpu")) - context, img = context.to(self.device).to(self.dtype), img.cpu().to(torch.float32) - self.vae.cpu().to(torch.float32) + context = self.encode(context, device=self.secondary_device) + context, img = context.to(self.device).to(self.dtype), img.to(self.secondary_device).to(torch.float32) + self.vae.to(self.secondary_device).to(torch.float32) x = self.vae.encode(img).latent_dist.sample().mul_(0.18215).to(self.device).to(self.dtype) t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) @@ -162,8 +167,12 @@ def training_step(self, train_batch, batch_idx): model_kwargs = dict(context=context) loss_dict = self.diffusion.training_losses(self, x, t, model_kwargs) - loss = loss_dict["loss"].mean() + loss = loss_dict["loss"].mean() # vb mse loss + self.log("train_loss", loss) + self.log("train_bv", loss_dict["vb"].mean()) + self.log("train_mse", 
loss_dict["mse"].mean()) + return loss # def validation_step(self, val_batch, batch_idx): diff --git a/modules/image_cap_dataset.py b/modules/image_cap_dataset.py index 7cb295c9..395c6254 100644 --- a/modules/image_cap_dataset.py +++ b/modules/image_cap_dataset.py @@ -9,7 +9,7 @@ from torchvision.transforms import transforms -class ImageCaptionDataset(Dataset): +class HuggingfaceImageDataset(Dataset): def __init__(self, hf_dataset_name, token, transform=None, target_transform=None, res=256): self.hf_dataset = load_dataset(hf_dataset_name, use_auth_token=token)["test"] self.token = token @@ -45,3 +45,1078 @@ def __getitem__(self, idx): print(e) image, prompt = self.get_item(0) return image, prompt + + +class HuggingfaceImageNetDataset(Dataset): + def __init__(self, token, transform=None, target_transform=None, res=256): + self.hf_dataset = load_dataset("imagenet-1k", use_auth_token=token)["train"] + self.token = token + self.id2class = {0: 'tench, Tinca tinca', + 1: 'goldfish, Carassius auratus', + 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', + 3: 'tiger shark, Galeocerdo cuvieri', + 4: 'hammerhead, hammerhead shark', + 5: 'electric ray, crampfish, numbfish, torpedo', + 6: 'stingray', + 7: 'cock', + 8: 'hen', + 9: 'ostrich, Struthio camelus', + 10: 'brambling, Fringilla montifringilla', + 11: 'goldfinch, Carduelis carduelis', + 12: 'house finch, linnet, Carpodacus mexicanus', + 13: 'junco, snowbird', + 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 15: 'robin, American robin, Turdus migratorius', + 16: 'bulbul', + 17: 'jay', + 18: 'magpie', + 19: 'chickadee', + 20: 'water ouzel, dipper', + 21: 'kite', + 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', + 23: 'vulture', + 24: 'great grey owl, great gray owl, Strix nebulosa', + 25: 'European fire salamander, Salamandra salamandra', + 26: 'common newt, Triturus vulgaris', + 27: 'eft', + 28: 'spotted salamander, Ambystoma maculatum', + 29: 'axolotl, mud puppy, Ambystoma mexicanum', + 30: 'bullfrog, Rana catesbeiana', + 31: 'tree frog, tree-frog', + 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 33: 'loggerhead, loggerhead turtle, Caretta caretta', + 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', + 35: 'mud turtle', + 36: 'terrapin', + 37: 'box turtle, box tortoise', + 38: 'banded gecko', + 39: 'common iguana, iguana, Iguana iguana', + 40: 'American chameleon, anole, Anolis carolinensis', + 41: 'whiptail, whiptail lizard', + 42: 'agama', + 43: 'frilled lizard, Chlamydosaurus kingi', + 44: 'alligator lizard', + 45: 'Gila monster, Heloderma suspectum', + 46: 'green lizard, Lacerta viridis', + 47: 'African chameleon, Chamaeleo chamaeleon', + 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', + 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', + 50: 'American alligator, Alligator mississipiensis', + 51: 'triceratops', + 52: 'thunder snake, worm snake, Carphophis amoenus', + 53: 'ringneck snake, ring-necked snake, ring snake', + 54: 'hognose snake, puff adder, sand viper', + 55: 'green snake, grass snake', + 56: 'king snake, kingsnake', + 57: 'garter snake, grass snake', + 58: 'water snake', + 59: 'vine snake', + 60: 'night snake, Hypsiglena torquata', + 61: 'boa constrictor, Constrictor constrictor', + 62: 'rock python, rock snake, Python sebae', + 63: 'Indian cobra, Naja naja', + 64: 'green mamba', + 65: 'sea snake', + 66: 'horned viper, cerastes, sand viper, horned 
asp, Cerastes cornutus', + 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', + 69: 'trilobite', + 70: 'harvestman, daddy longlegs, Phalangium opilio', + 71: 'scorpion', + 72: 'black and gold garden spider, Argiope aurantia', + 73: 'barn spider, Araneus cavaticus', + 74: 'garden spider, Aranea diademata', + 75: 'black widow, Latrodectus mactans', + 76: 'tarantula', + 77: 'wolf spider, hunting spider', + 78: 'tick', + 79: 'centipede', + 80: 'black grouse', + 81: 'ptarmigan', + 82: 'ruffed grouse, partridge, Bonasa umbellus', + 83: 'prairie chicken, prairie grouse, prairie fowl', + 84: 'peacock', + 85: 'quail', + 86: 'partridge', + 87: 'African grey, African gray, Psittacus erithacus', + 88: 'macaw', + 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 90: 'lorikeet', + 91: 'coucal', + 92: 'bee eater', + 93: 'hornbill', + 94: 'hummingbird', + 95: 'jacamar', + 96: 'toucan', + 97: 'drake', + 98: 'red-breasted merganser, Mergus serrator', + 99: 'goose', + 100: 'black swan, Cygnus atratus', + 101: 'tusker', + 102: 'echidna, spiny anteater, anteater', + 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', + 104: 'wallaby, brush kangaroo', + 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', + 106: 'wombat', + 107: 'jellyfish', + 108: 'sea anemone, anemone', + 109: 'brain coral', + 110: 'flatworm, platyhelminth', + 111: 'nematode, nematode worm, roundworm', + 112: 'conch', + 113: 'snail', + 114: 'slug', + 115: 'sea slug, nudibranch', + 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 117: 'chambered nautilus, pearly nautilus, nautilus', + 118: 'Dungeness crab, Cancer magister', + 119: 'rock crab, Cancer irroratus', + 120: 'fiddler crab', + 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', + 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', + 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', + 124: 'crayfish, crawfish, crawdad, crawdaddy', + 125: 'hermit crab', + 126: 'isopod', + 127: 'white stork, Ciconia ciconia', + 128: 'black stork, Ciconia nigra', + 129: 'spoonbill', + 130: 'flamingo', + 131: 'little blue heron, Egretta caerulea', + 132: 'American egret, great white heron, Egretta albus', + 133: 'bittern', + 134: 'crane', + 135: 'limpkin, Aramus pictus', + 136: 'European gallinule, Porphyrio porphyrio', + 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 138: 'bustard', + 139: 'ruddy turnstone, Arenaria interpres', + 140: 'red-backed sandpiper, dunlin, Erolia alpina', + 141: 'redshank, Tringa totanus', + 142: 'dowitcher', + 143: 'oystercatcher, oyster catcher', + 144: 'pelican', + 145: 'king penguin, Aptenodytes patagonica', + 146: 'albatross, mollymawk', + 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', + 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 149: 'dugong, Dugong dugon', + 150: 'sea lion', + 151: 'Chihuahua', + 152: 'Japanese spaniel', + 153: 'Maltese dog, Maltese terrier, Maltese', + 154: 'Pekinese, Pekingese, Peke', + 155: 'Shih-Tzu', + 156: 'Blenheim spaniel', + 157: 'papillon', + 158: 'toy terrier', + 159: 'Rhodesian ridgeback', + 160: 'Afghan hound, Afghan', + 161: 'basset, basset hound', + 162: 'beagle', + 163: 'bloodhound, sleuthhound', + 164: 'bluetick', + 165: 'black-and-tan coonhound', + 166: 'Walker hound, 
Walker foxhound', + 167: 'English foxhound', + 168: 'redbone', + 169: 'borzoi, Russian wolfhound', + 170: 'Irish wolfhound', + 171: 'Italian greyhound', + 172: 'whippet', + 173: 'Ibizan hound, Ibizan Podenco', + 174: 'Norwegian elkhound, elkhound', + 175: 'otterhound, otter hound', + 176: 'Saluki, gazelle hound', + 177: 'Scottish deerhound, deerhound', + 178: 'Weimaraner', + 179: 'Staffordshire bullterrier, Staffordshire bull terrier', + 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', + 181: 'Bedlington terrier', + 182: 'Border terrier', + 183: 'Kerry blue terrier', + 184: 'Irish terrier', + 185: 'Norfolk terrier', + 186: 'Norwich terrier', + 187: 'Yorkshire terrier', + 188: 'wire-haired fox terrier', + 189: 'Lakeland terrier', + 190: 'Sealyham terrier, Sealyham', + 191: 'Airedale, Airedale terrier', + 192: 'cairn, cairn terrier', + 193: 'Australian terrier', + 194: 'Dandie Dinmont, Dandie Dinmont terrier', + 195: 'Boston bull, Boston terrier', + 196: 'miniature schnauzer', + 197: 'giant schnauzer', + 198: 'standard schnauzer', + 199: 'Scotch terrier, Scottish terrier, Scottie', + 200: 'Tibetan terrier, chrysanthemum dog', + 201: 'silky terrier, Sydney silky', + 202: 'soft-coated wheaten terrier', + 203: 'West Highland white terrier', + 204: 'Lhasa, Lhasa apso', + 205: 'flat-coated retriever', + 206: 'curly-coated retriever', + 207: 'golden retriever', + 208: 'Labrador retriever', + 209: 'Chesapeake Bay retriever', + 210: 'German short-haired pointer', + 211: 'vizsla, Hungarian pointer', + 212: 'English setter', + 213: 'Irish setter, red setter', + 214: 'Gordon setter', + 215: 'Brittany spaniel', + 216: 'clumber, clumber spaniel', + 217: 'English springer, English springer spaniel', + 218: 'Welsh springer spaniel', + 219: 'cocker spaniel, English cocker spaniel, cocker', + 220: 'Sussex spaniel', + 221: 'Irish water spaniel', + 222: 'kuvasz', + 223: 'schipperke', + 224: 'groenendael', + 225: 'malinois', + 226: 'briard', + 227: 'kelpie', + 228: 'komondor', + 229: 'Old English sheepdog, bobtail', + 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', + 231: 'collie', + 232: 'Border collie', + 233: 'Bouvier des Flandres, Bouviers des Flandres', + 234: 'Rottweiler', + 235: 'German shepherd, German shepherd dog, German police dog, alsatian', + 236: 'Doberman, Doberman pinscher', + 237: 'miniature pinscher', + 238: 'Greater Swiss Mountain dog', + 239: 'Bernese mountain dog', + 240: 'Appenzeller', + 241: 'EntleBucher', + 242: 'boxer', + 243: 'bull mastiff', + 244: 'Tibetan mastiff', + 245: 'French bulldog', + 246: 'Great Dane', + 247: 'Saint Bernard, St Bernard', + 248: 'Eskimo dog, husky', + 249: 'malamute, malemute, Alaskan malamute', + 250: 'Siberian husky', + 251: 'dalmatian, coach dog, carriage dog', + 252: 'affenpinscher, monkey pinscher, monkey dog', + 253: 'basenji', + 254: 'pug, pug-dog', + 255: 'Leonberg', + 256: 'Newfoundland, Newfoundland dog', + 257: 'Great Pyrenees', + 258: 'Samoyed, Samoyede', + 259: 'Pomeranian', + 260: 'chow, chow chow', + 261: 'keeshond', + 262: 'Brabancon griffon', + 263: 'Pembroke, Pembroke Welsh corgi', + 264: 'Cardigan, Cardigan Welsh corgi', + 265: 'toy poodle', + 266: 'miniature poodle', + 267: 'standard poodle', + 268: 'Mexican hairless', + 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', + 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', + 271: 'red wolf, maned wolf, Canis rufus, Canis niger', + 272: 'coyote, prairie wolf, brush wolf, Canis latrans', + 273: 'dingo, warrigal, 
warragal, Canis dingo', + 274: 'dhole, Cuon alpinus', + 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 276: 'hyena, hyaena', + 277: 'red fox, Vulpes vulpes', + 278: 'kit fox, Vulpes macrotis', + 279: 'Arctic fox, white fox, Alopex lagopus', + 280: 'grey fox, gray fox, Urocyon cinereoargenteus', + 281: 'tabby, tabby cat', + 282: 'tiger cat', + 283: 'Persian cat', + 284: 'Siamese cat, Siamese', + 285: 'Egyptian cat', + 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', + 287: 'lynx, catamount', + 288: 'leopard, Panthera pardus', + 289: 'snow leopard, ounce, Panthera uncia', + 290: 'jaguar, panther, Panthera onca, Felis onca', + 291: 'lion, king of beasts, Panthera leo', + 292: 'tiger, Panthera tigris', + 293: 'cheetah, chetah, Acinonyx jubatus', + 294: 'brown bear, bruin, Ursus arctos', + 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus', + 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 297: 'sloth bear, Melursus ursinus, Ursus ursinus', + 298: 'mongoose', + 299: 'meerkat, mierkat', + 300: 'tiger beetle', + 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 302: 'ground beetle, carabid beetle', + 303: 'long-horned beetle, longicorn, longicorn beetle', + 304: 'leaf beetle, chrysomelid', + 305: 'dung beetle', + 306: 'rhinoceros beetle', + 307: 'weevil', + 308: 'fly', + 309: 'bee', + 310: 'ant, emmet, pismire', + 311: 'grasshopper, hopper', + 312: 'cricket', + 313: 'walking stick, walkingstick, stick insect', + 314: 'cockroach, roach', + 315: 'mantis, mantid', + 316: 'cicada, cicala', + 317: 'leafhopper', + 318: 'lacewing, lacewing fly', + 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + 320: 'damselfly', + 321: 'admiral', + 322: 'ringlet, ringlet butterfly', + 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 324: 'cabbage butterfly', + 325: 'sulphur butterfly, sulfur butterfly', + 326: 'lycaenid, lycaenid butterfly', + 327: 'starfish, sea star', + 328: 'sea urchin', + 329: 'sea cucumber, holothurian', + 330: 'wood rabbit, cottontail, cottontail rabbit', + 331: 'hare', + 332: 'Angora, Angora rabbit', + 333: 'hamster', + 334: 'porcupine, hedgehog', + 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', + 336: 'marmot', + 337: 'beaver', + 338: 'guinea pig, Cavia cobaya', + 339: 'sorrel', + 340: 'zebra', + 341: 'hog, pig, grunter, squealer, Sus scrofa', + 342: 'wild boar, boar, Sus scrofa', + 343: 'warthog', + 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 345: 'ox', + 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 347: 'bison', + 348: 'ram, tup', + 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', + 350: 'ibex, Capra ibex', + 351: 'hartebeest', + 352: 'impala, Aepyceros melampus', + 353: 'gazelle', + 354: 'Arabian camel, dromedary, Camelus dromedarius', + 355: 'llama', + 356: 'weasel', + 357: 'mink', + 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', + 359: 'black-footed ferret, ferret, Mustela nigripes', + 360: 'otter', + 361: 'skunk, polecat, wood pussy', + 362: 'badger', + 363: 'armadillo', + 364: 'three-toed sloth, ai, Bradypus tridactylus', + 365: 'orangutan, orang, orangutang, Pongo pygmaeus', + 366: 'gorilla, Gorilla gorilla', + 367: 'chimpanzee, chimp, Pan troglodytes', + 368: 'gibbon, Hylobates lar', + 369: 'siamang, Hylobates syndactylus, 
Symphalangus syndactylus', + 370: 'guenon, guenon monkey', + 371: 'patas, hussar monkey, Erythrocebus patas', + 372: 'baboon', + 373: 'macaque', + 374: 'langur', + 375: 'colobus, colobus monkey', + 376: 'proboscis monkey, Nasalis larvatus', + 377: 'marmoset', + 378: 'capuchin, ringtail, Cebus capucinus', + 379: 'howler monkey, howler', + 380: 'titi, titi monkey', + 381: 'spider monkey, Ateles geoffroyi', + 382: 'squirrel monkey, Saimiri sciureus', + 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', + 384: 'indri, indris, Indri indri, Indri brevicaudatus', + 385: 'Indian elephant, Elephas maximus', + 386: 'African elephant, Loxodonta africana', + 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 389: 'barracouta, snoek', + 390: 'eel', + 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', + 392: 'rock beauty, Holocanthus tricolor', + 393: 'anemone fish', + 394: 'sturgeon', + 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 396: 'lionfish', + 397: 'puffer, pufferfish, blowfish, globefish', + 398: 'abacus', + 399: 'abaya', + 400: "academic gown, academic robe, judge's robe", + 401: 'accordion, piano accordion, squeeze box', + 402: 'acoustic guitar', + 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 404: 'airliner', + 405: 'airship, dirigible', + 406: 'altar', + 407: 'ambulance', + 408: 'amphibian, amphibious vehicle', + 409: 'analog clock', + 410: 'apiary, bee house', + 411: 'apron', + 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', + 413: 'assault rifle, assault gun', + 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 415: 'bakery, bakeshop, bakehouse', + 416: 'balance beam, beam', + 417: 'balloon', + 418: 'ballpoint, ballpoint pen, ballpen, Biro', + 419: 'Band Aid', + 420: 'banjo', + 421: 'bannister, banister, balustrade, balusters, handrail', + 422: 'barbell', + 423: 'barber chair', + 424: 'barbershop', + 425: 'barn', + 426: 'barometer', + 427: 'barrel, cask', + 428: 'barrow, garden cart, lawn cart, wheelbarrow', + 429: 'baseball', + 430: 'basketball', + 431: 'bassinet', + 432: 'bassoon', + 433: 'bathing cap, swimming cap', + 434: 'bath towel', + 435: 'bathtub, bathing tub, bath, tub', + 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', + 437: 'beacon, lighthouse, beacon light, pharos', + 438: 'beaker', + 439: 'bearskin, busby, shako', + 440: 'beer bottle', + 441: 'beer glass', + 442: 'bell cote, bell cot', + 443: 'bib', + 444: 'bicycle-built-for-two, tandem bicycle, tandem', + 445: 'bikini, two-piece', + 446: 'binder, ring-binder', + 447: 'binoculars, field glasses, opera glasses', + 448: 'birdhouse', + 449: 'boathouse', + 450: 'bobsled, bobsleigh, bob', + 451: 'bolo tie, bolo, bola tie, bola', + 452: 'bonnet, poke bonnet', + 453: 'bookcase', + 454: 'bookshop, bookstore, bookstall', + 455: 'bottlecap', + 456: 'bow', + 457: 'bow tie, bow-tie, bowtie', + 458: 'brass, memorial tablet, plaque', + 459: 'brassiere, bra, bandeau', + 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 461: 'breastplate, aegis, egis', + 462: 'broom', + 463: 'bucket, pail', + 464: 'buckle', + 465: 'bulletproof vest', + 466: 'bullet train, bullet', + 467: 'butcher shop, meat market', + 468: 'cab, hack, taxi, taxicab', + 469: 'caldron, cauldron', + 470: 'candle, taper, wax light', + 471: 'cannon', + 472: 
'canoe', + 473: 'can opener, tin opener', + 474: 'cardigan', + 475: 'car mirror', + 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig', + 477: "carpenter's kit, tool kit", + 478: 'carton', + 479: 'car wheel', + 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', + 481: 'cassette', + 482: 'cassette player', + 483: 'castle', + 484: 'catamaran', + 485: 'CD player', + 486: 'cello, violoncello', + 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 488: 'chain', + 489: 'chainlink fence', + 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', + 491: 'chain saw, chainsaw', + 492: 'chest', + 493: 'chiffonier, commode', + 494: 'chime, bell, gong', + 495: 'china cabinet, china closet', + 496: 'Christmas stocking', + 497: 'church, church building', + 498: 'cinema, movie theater, movie theatre, movie house, picture palace', + 499: 'cleaver, meat cleaver, chopper', + 500: 'cliff dwelling', + 501: 'cloak', + 502: 'clog, geta, patten, sabot', + 503: 'cocktail shaker', + 504: 'coffee mug', + 505: 'coffeepot', + 506: 'coil, spiral, volute, whorl, helix', + 507: 'combination lock', + 508: 'computer keyboard, keypad', + 509: 'confectionery, confectionary, candy store', + 510: 'container ship, containership, container vessel', + 511: 'convertible', + 512: 'corkscrew, bottle screw', + 513: 'cornet, horn, trumpet, trump', + 514: 'cowboy boot', + 515: 'cowboy hat, ten-gallon hat', + 516: 'cradle', + 517: 'crane', + 518: 'crash helmet', + 519: 'crate', + 520: 'crib, cot', + 521: 'Crock Pot', + 522: 'croquet ball', + 523: 'crutch', + 524: 'cuirass', + 525: 'dam, dike, dyke', + 526: 'desk', + 527: 'desktop computer', + 528: 'dial telephone, dial phone', + 529: 'diaper, nappy, napkin', + 530: 'digital clock', + 531: 'digital watch', + 532: 'dining table, board', + 533: 'dishrag, dishcloth', + 534: 'dishwasher, dish washer, dishwashing machine', + 535: 'disk brake, disc brake', + 536: 'dock, dockage, docking facility', + 537: 'dogsled, dog sled, dog sleigh', + 538: 'dome', + 539: 'doormat, welcome mat', + 540: 'drilling platform, offshore rig', + 541: 'drum, membranophone, tympan', + 542: 'drumstick', + 543: 'dumbbell', + 544: 'Dutch oven', + 545: 'electric fan, blower', + 546: 'electric guitar', + 547: 'electric locomotive', + 548: 'entertainment center', + 549: 'envelope', + 550: 'espresso maker', + 551: 'face powder', + 552: 'feather boa, boa', + 553: 'file, file cabinet, filing cabinet', + 554: 'fireboat', + 555: 'fire engine, fire truck', + 556: 'fire screen, fireguard', + 557: 'flagpole, flagstaff', + 558: 'flute, transverse flute', + 559: 'folding chair', + 560: 'football helmet', + 561: 'forklift', + 562: 'fountain', + 563: 'fountain pen', + 564: 'four-poster', + 565: 'freight car', + 566: 'French horn, horn', + 567: 'frying pan, frypan, skillet', + 568: 'fur coat', + 569: 'garbage truck, dustcart', + 570: 'gasmask, respirator, gas helmet', + 571: 'gas pump, gasoline pump, petrol pump, island dispenser', + 572: 'goblet', + 573: 'go-kart', + 574: 'golf ball', + 575: 'golfcart, golf cart', + 576: 'gondola', + 577: 'gong, tam-tam', + 578: 'gown', + 579: 'grand piano, grand', + 580: 'greenhouse, nursery, glasshouse', + 581: 'grille, radiator grille', + 582: 'grocery store, grocery, food market, market', + 583: 'guillotine', + 584: 'hair slide', + 585: 'hair spray', + 586: 'half track', + 587: 'hammer', + 588: 'hamper', + 589: 'hand blower, blow dryer, blow drier, 
hair dryer, hair drier', + 590: 'hand-held computer, hand-held microcomputer', + 591: 'handkerchief, hankie, hanky, hankey', + 592: 'hard disc, hard disk, fixed disk', + 593: 'harmonica, mouth organ, harp, mouth harp', + 594: 'harp', + 595: 'harvester, reaper', + 596: 'hatchet', + 597: 'holster', + 598: 'home theater, home theatre', + 599: 'honeycomb', + 600: 'hook, claw', + 601: 'hoopskirt, crinoline', + 602: 'horizontal bar, high bar', + 603: 'horse cart, horse-cart', + 604: 'hourglass', + 605: 'iPod', + 606: 'iron, smoothing iron', + 607: "jack-o'-lantern", + 608: 'jean, blue jean, denim', + 609: 'jeep, landrover', + 610: 'jersey, T-shirt, tee shirt', + 611: 'jigsaw puzzle', + 612: 'jinrikisha, ricksha, rickshaw', + 613: 'joystick', + 614: 'kimono', + 615: 'knee pad', + 616: 'knot', + 617: 'lab coat, laboratory coat', + 618: 'ladle', + 619: 'lampshade, lamp shade', + 620: 'laptop, laptop computer', + 621: 'lawn mower, mower', + 622: 'lens cap, lens cover', + 623: 'letter opener, paper knife, paperknife', + 624: 'library', + 625: 'lifeboat', + 626: 'lighter, light, igniter, ignitor', + 627: 'limousine, limo', + 628: 'liner, ocean liner', + 629: 'lipstick, lip rouge', + 630: 'Loafer', + 631: 'lotion', + 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', + 633: "loupe, jeweler's loupe", + 634: 'lumbermill, sawmill', + 635: 'magnetic compass', + 636: 'mailbag, postbag', + 637: 'mailbox, letter box', + 638: 'maillot', + 639: 'maillot, tank suit', + 640: 'manhole cover', + 641: 'maraca', + 642: 'marimba, xylophone', + 643: 'mask', + 644: 'matchstick', + 645: 'maypole', + 646: 'maze, labyrinth', + 647: 'measuring cup', + 648: 'medicine chest, medicine cabinet', + 649: 'megalith, megalithic structure', + 650: 'microphone, mike', + 651: 'microwave, microwave oven', + 652: 'military uniform', + 653: 'milk can', + 654: 'minibus', + 655: 'miniskirt, mini', + 656: 'minivan', + 657: 'missile', + 658: 'mitten', + 659: 'mixing bowl', + 660: 'mobile home, manufactured home', + 661: 'Model T', + 662: 'modem', + 663: 'monastery', + 664: 'monitor', + 665: 'moped', + 666: 'mortar', + 667: 'mortarboard', + 668: 'mosque', + 669: 'mosquito net', + 670: 'motor scooter, scooter', + 671: 'mountain bike, all-terrain bike, off-roader', + 672: 'mountain tent', + 673: 'mouse, computer mouse', + 674: 'mousetrap', + 675: 'moving van', + 676: 'muzzle', + 677: 'nail', + 678: 'neck brace', + 679: 'necklace', + 680: 'nipple', + 681: 'notebook, notebook computer', + 682: 'obelisk', + 683: 'oboe, hautboy, hautbois', + 684: 'ocarina, sweet potato', + 685: 'odometer, hodometer, mileometer, milometer', + 686: 'oil filter', + 687: 'organ, pipe organ', + 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 689: 'overskirt', + 690: 'oxcart', + 691: 'oxygen mask', + 692: 'packet', + 693: 'paddle, boat paddle', + 694: 'paddlewheel, paddle wheel', + 695: 'padlock', + 696: 'paintbrush', + 697: "pajama, pyjama, pj's, jammies", + 698: 'palace', + 699: 'panpipe, pandean pipe, syrinx', + 700: 'paper towel', + 701: 'parachute, chute', + 702: 'parallel bars, bars', + 703: 'park bench', + 704: 'parking meter', + 705: 'passenger car, coach, carriage', + 706: 'patio, terrace', + 707: 'pay-phone, pay-station', + 708: 'pedestal, plinth, footstall', + 709: 'pencil box, pencil case', + 710: 'pencil sharpener', + 711: 'perfume, essence', + 712: 'Petri dish', + 713: 'photocopier', + 714: 'pick, plectrum, plectron', + 715: 'pickelhaube', + 716: 'picket fence, paling', + 717: 'pickup, pickup truck', + 718: 'pier', + 
719: 'piggy bank, penny bank', + 720: 'pill bottle', + 721: 'pillow', + 722: 'ping-pong ball', + 723: 'pinwheel', + 724: 'pirate, pirate ship', + 725: 'pitcher, ewer', + 726: "plane, carpenter's plane, woodworking plane", + 727: 'planetarium', + 728: 'plastic bag', + 729: 'plate rack', + 730: 'plow, plough', + 731: "plunger, plumber's helper", + 732: 'Polaroid camera, Polaroid Land camera', + 733: 'pole', + 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', + 735: 'poncho', + 736: 'pool table, billiard table, snooker table', + 737: 'pop bottle, soda bottle', + 738: 'pot, flowerpot', + 739: "potter's wheel", + 740: 'power drill', + 741: 'prayer rug, prayer mat', + 742: 'printer', + 743: 'prison, prison house', + 744: 'projectile, missile', + 745: 'projector', + 746: 'puck, hockey puck', + 747: 'punching bag, punch bag, punching ball, punchball', + 748: 'purse', + 749: 'quill, quill pen', + 750: 'quilt, comforter, comfort, puff', + 751: 'racer, race car, racing car', + 752: 'racket, racquet', + 753: 'radiator', + 754: 'radio, wireless', + 755: 'radio telescope, radio reflector', + 756: 'rain barrel', + 757: 'recreational vehicle, RV, R.V.', + 758: 'reel', + 759: 'reflex camera', + 760: 'refrigerator, icebox', + 761: 'remote control, remote', + 762: 'restaurant, eating house, eating place, eatery', + 763: 'revolver, six-gun, six-shooter', + 764: 'rifle', + 765: 'rocking chair, rocker', + 766: 'rotisserie', + 767: 'rubber eraser, rubber, pencil eraser', + 768: 'rugby ball', + 769: 'rule, ruler', + 770: 'running shoe', + 771: 'safe', + 772: 'safety pin', + 773: 'saltshaker, salt shaker', + 774: 'sandal', + 775: 'sarong', + 776: 'sax, saxophone', + 777: 'scabbard', + 778: 'scale, weighing machine', + 779: 'school bus', + 780: 'schooner', + 781: 'scoreboard', + 782: 'screen, CRT screen', + 783: 'screw', + 784: 'screwdriver', + 785: 'seat belt, seatbelt', + 786: 'sewing machine', + 787: 'shield, buckler', + 788: 'shoe shop, shoe-shop, shoe store', + 789: 'shoji', + 790: 'shopping basket', + 791: 'shopping cart', + 792: 'shovel', + 793: 'shower cap', + 794: 'shower curtain', + 795: 'ski', + 796: 'ski mask', + 797: 'sleeping bag', + 798: 'slide rule, slipstick', + 799: 'sliding door', + 800: 'slot, one-armed bandit', + 801: 'snorkel', + 802: 'snowmobile', + 803: 'snowplow, snowplough', + 804: 'soap dispenser', + 805: 'soccer ball', + 806: 'sock', + 807: 'solar dish, solar collector, solar furnace', + 808: 'sombrero', + 809: 'soup bowl', + 810: 'space bar', + 811: 'space heater', + 812: 'space shuttle', + 813: 'spatula', + 814: 'speedboat', + 815: "spider web, spider's web", + 816: 'spindle', + 817: 'sports car, sport car', + 818: 'spotlight, spot', + 819: 'stage', + 820: 'steam locomotive', + 821: 'steel arch bridge', + 822: 'steel drum', + 823: 'stethoscope', + 824: 'stole', + 825: 'stone wall', + 826: 'stopwatch, stop watch', + 827: 'stove', + 828: 'strainer', + 829: 'streetcar, tram, tramcar, trolley, trolley car', + 830: 'stretcher', + 831: 'studio couch, day bed', + 832: 'stupa, tope', + 833: 'submarine, pigboat, sub, U-boat', + 834: 'suit, suit of clothes', + 835: 'sundial', + 836: 'sunglass', + 837: 'sunglasses, dark glasses, shades', + 838: 'sunscreen, sunblock, sun blocker', + 839: 'suspension bridge', + 840: 'swab, swob, mop', + 841: 'sweatshirt', + 842: 'swimming trunks, bathing trunks', + 843: 'swing', + 844: 'switch, electric switch, electrical switch', + 845: 'syringe', + 846: 'table lamp', + 847: 'tank, army tank, armored combat vehicle, armoured combat 
vehicle', + 848: 'tape player', + 849: 'teapot', + 850: 'teddy, teddy bear', + 851: 'television, television system', + 852: 'tennis ball', + 853: 'thatch, thatched roof', + 854: 'theater curtain, theatre curtain', + 855: 'thimble', + 856: 'thresher, thrasher, threshing machine', + 857: 'throne', + 858: 'tile roof', + 859: 'toaster', + 860: 'tobacco shop, tobacconist shop, tobacconist', + 861: 'toilet seat', + 862: 'torch', + 863: 'totem pole', + 864: 'tow truck, tow car, wrecker', + 865: 'toyshop', + 866: 'tractor', + 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', + 868: 'tray', + 869: 'trench coat', + 870: 'tricycle, trike, velocipede', + 871: 'trimaran', + 872: 'tripod', + 873: 'triumphal arch', + 874: 'trolleybus, trolley coach, trackless trolley', + 875: 'trombone', + 876: 'tub, vat', + 877: 'turnstile', + 878: 'typewriter keyboard', + 879: 'umbrella', + 880: 'unicycle, monocycle', + 881: 'upright, upright piano', + 882: 'vacuum, vacuum cleaner', + 883: 'vase', + 884: 'vault', + 885: 'velvet', + 886: 'vending machine', + 887: 'vestment', + 888: 'viaduct', + 889: 'violin, fiddle', + 890: 'volleyball', + 891: 'waffle iron', + 892: 'wall clock', + 893: 'wallet, billfold, notecase, pocketbook', + 894: 'wardrobe, closet, press', + 895: 'warplane, military plane', + 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 897: 'washer, automatic washer, washing machine', + 898: 'water bottle', + 899: 'water jug', + 900: 'water tower', + 901: 'whiskey jug', + 902: 'whistle', + 903: 'wig', + 904: 'window screen', + 905: 'window shade', + 906: 'Windsor tie', + 907: 'wine bottle', + 908: 'wing', + 909: 'wok', + 910: 'wooden spoon', + 911: 'wool, woolen, woollen', + 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', + 913: 'wreck', + 914: 'yawl', + 915: 'yurt', + 916: 'web site, website, internet site, site', + 917: 'comic book', + 918: 'crossword puzzle, crossword', + 919: 'street sign', + 920: 'traffic light, traffic signal, stoplight', + 921: 'book jacket, dust cover, dust jacket, dust wrapper', + 922: 'menu', + 923: 'plate', + 924: 'guacamole', + 925: 'consomme', + 926: 'hot pot, hotpot', + 927: 'trifle', + 928: 'ice cream, icecream', + 929: 'ice lolly, lolly, lollipop, popsicle', + 930: 'French loaf', + 931: 'bagel, beigel', + 932: 'pretzel', + 933: 'cheeseburger', + 934: 'hotdog, hot dog, red hot', + 935: 'mashed potato', + 936: 'head cabbage', + 937: 'broccoli', + 938: 'cauliflower', + 939: 'zucchini, courgette', + 940: 'spaghetti squash', + 941: 'acorn squash', + 942: 'butternut squash', + 943: 'cucumber, cuke', + 944: 'artichoke, globe artichoke', + 945: 'bell pepper', + 946: 'cardoon', + 947: 'mushroom', + 948: 'Granny Smith', + 949: 'strawberry', + 950: 'orange', + 951: 'lemon', + 952: 'fig', + 953: 'pineapple, ananas', + 954: 'banana', + 955: 'jackfruit, jak, jack', + 956: 'custard apple', + 957: 'pomegranate', + 958: 'hay', + 959: 'carbonara', + 960: 'chocolate sauce, chocolate syrup', + 961: 'dough', + 962: 'meat loaf, meatloaf', + 963: 'pizza, pizza pie', + 964: 'potpie', + 965: 'burrito', + 966: 'red wine', + 967: 'espresso', + 968: 'cup', + 969: 'eggnog', + 970: 'alp', + 971: 'bubble', + 972: 'cliff, drop, drop-off', + 973: 'coral reef', + 974: 'geyser', + 975: 'lakeside, lakeshore', + 976: 'promontory, headland, head, foreland', + 977: 'sandbar, sand bar', + 978: 'seashore, coast, seacoast, sea-coast', + 979: 'valley, vale', + 980: 'volcano', + 981: 'ballplayer, baseball player', + 982: 'groom, bridegroom', + 983: 
'scuba diver', + 984: 'rapeseed', + 985: 'daisy', + 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + 987: 'corn', + 988: 'acorn', + 989: 'hip, rose hip, rosehip', + 990: 'buckeye, horse chestnut, conker', + 991: 'coral fungus', + 992: 'agaric', + 993: 'gyromitra', + 994: 'stinkhorn, carrion fungus', + 995: 'earthstar', + 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', + 997: 'bolete', + 998: 'ear, spike, capitulum', + 999: 'toilet tissue, toilet paper, bathroom tissue'} + + self.transform = transform if transform is not None else transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, res)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + self.target_transform = target_transform + + def __len__(self): + return len(self.hf_dataset) + + def get_item(self, idx): + image, prompt = self.hf_dataset[idx]["image"], self.id2class[self.hf_dataset[idx]["label"]] + + image = image.convert("RGB") + + prompt = prompt.split(",")[0].lower() + + if self.transform: + image = self.transform(image) + return image, prompt + + def __getitem__(self, idx): + try: + image, prompt = self.get_item(idx) + except Exception as e: + print(e) + image, prompt = self.get_item(0) + return image, prompt + + +class DummyDataset(Dataset): + def __init__(self, transform=None, target_transform=None, res=256): + + self.transform = transform if transform is not None else transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, res)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + self.target_transform = target_transform + + def __len__(self): + return 100 + + def get_item(self, idx): + if random.random() < 0.3: + image = Image.open("apple.jpg") + + prompt = "an apple" + elif 0.3 < random.random() < 0.6: + image = Image.open("orange.jpg") + + prompt = "an orange" + else: + image = Image.open("forest.jpg") + + prompt = "a forest" + + if self.transform: + image = self.transform(image) + return image, prompt + + def __getitem__(self, idx): + try: + image, prompt = self.get_item(idx) + except Exception as e: + print(e) + image, prompt = self.get_item(0) + return image, prompt diff --git a/sample_gradio.py b/sample_gradio.py index 68091561..4f6a3026 100644 --- a/sample_gradio.py +++ b/sample_gradio.py @@ -96,8 +96,8 @@ def sample(prompt, cfg_scale, num_sampling_steps, seed): inputs=[ gr.Text(label="Text Prompt", value="an apple"), gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="Cfg scale"), - gr.Slider(minimum=5, maximum=500, value=50, step=1, label="Sampling steps"), - gr.Slider(minimum=1, maximum=9223372036854775807, value=4, step=1, + gr.Slider(minimum=5, maximum=1000, value=50, step=1, label="Sampling steps"), + gr.Slider(minimum=1, maximum=9223372036854775807, value=5782510030869745000, step=1, label="Seed"), ], outputs=[ diff --git a/train_pl.py b/train_pl.py index 63c742ee..0e14a7cf 100644 --- a/train_pl.py +++ b/train_pl.py @@ -1,40 +1,54 @@ import argparse +import random -import pytorch_lightning as pl import torchvision -from pytorch_lightning.callbacks import ModelCheckpoint +import torch + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data import DataLoader from torchvision.transforms import transforms +from download import find_model from modules.dit_builder import DiT_models from 
modules.diffusion import create_diffusion from diffusers.models import AutoencoderKL -from modules.image_cap_dataset import ImageCaptionDataset -from modules.training_utils import * +from modules.image_cap_dataset import HuggingfaceImageDataset, DummyDataset, HuggingfaceImageNetDataset +from modules.training_utils import center_crop_arr def train_pl(args): - global model print("Starting training..") device = torch.device(0) + secondary_device = torch.device("cpu") latent_size = args.image_size // 8 model = DiT_models[args.model]( input_size=latent_size, ) - dataset = ImageCaptionDataset(args.hf_dataset_name, args.token, res=args.image_size) - - # del model.encoder + # dataset = HuggingfaceImageDataset(args.hf_dataset_name, args.token, res=args.image_size) + dataset = HuggingfaceImageNetDataset(args.token, res=args.image_size) + # dataset = torchvision.datasets.CocoCaptions(args.coco_dataset_path + "/train2017/", + # args.coco_dataset_path + "/annotations_train/captions_train2017.json", + # transform=transform, + # target_transform=lambda x: random.choice(x).replace("\n", "").lower()) + # dataset = DummyDataset(res=args.image_size) + + state_dict = find_model("pretrained_models/last.ckpt") + if 'pytorch-lightning_version' in state_dict.keys(): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=False) diffusion = create_diffusion(timestep_respacing="") - vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(secondary_device) # training only model.diffusion = diffusion model.vae = vae - model_ckpt = ModelCheckpoint(dirpath="ckpts/", monitor=None, save_top_k=-1, save_last=True, every_n_train_steps=1000 - ) + model.secondary_device = secondary_device + model_ckpt = ModelCheckpoint(dirpath="ckpts/", monitor="train_loss", save_top_k=2, save_last=True, + every_n_train_steps=3_000) loader_train = DataLoader( dataset, @@ -50,9 +64,9 @@ def train_pl(args): torch.set_float32_matmul_precision("high") trainer = pl.Trainer( - auto_lr_find=True, + auto_lr_find=False, enable_checkpointing=True, - detect_anomaly=True, + detect_anomaly=False, log_every_n_steps=50, accelerator='gpu', devices=1, @@ -67,8 +81,10 @@ def train_pl(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--hf_dataset_name", type=str, default="facebook/winoground") - parser.add_argument("--token", type=str) + parser.add_argument("--hf_dataset_name", type=str, default="facebook/winoground", required=False) + parser.add_argument("--token", type=str, required=False) + + parser.add_argument("--coco_dataset_path", type=str, required=False) parser.add_argument("--results-dir", type=str, default="results") parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") @@ -81,6 +97,8 @@ def train_pl(args): parser.add_argument("--precision", type=str, choices=["fp16", "fp32"], default="fp16") parsed_args = parser.parse_args() + assert (parsed_args.hf_dataset_name is not None and parsed_args.token is not None) or (parsed_args.coco_dataset_path) is not None, "Either hf token and dataset name or coco dataset path is required" + transform = transforms.Compose([ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, parsed_args.image_size)), transforms.RandomHorizontalFlip(), From ed63602fcb0441a216108913f2390182d372a31d Mon Sep 17 00:00:00 2001 From: neon Date: Tue, 21 Mar 2023 17:35:25 +0100 Subject: [PATCH 13/13] upgraded to lightning 
2.0, architectural changes, added link for a pretrained model --- README.md | 80 ++++------------------- modules/dit_clipped.py | 63 +++++++++++++----- modules/encoders/modules.py | 76 +++++++++++----------- sample_gradio_table.py | 125 ++++++++++++++++++++++++++++++++++++ train_pl.py | 46 +++++++------ 5 files changed, 250 insertions(+), 140 deletions(-) create mode 100644 sample_gradio_table.py diff --git a/README.md b/README.md index a914c631..aa2929c4 100644 --- a/README.md +++ b/README.md @@ -57,92 +57,36 @@ conda activate DiT for our pre-trained DiT model will be automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from -our 512x512 DiT-XL/2 model, you can use: +our 512x512 DiT-clipped model, you can use the new gradio interface: ```bash -python sample.py --image-size 512 --seed 1 -``` - -**New gradio interface!** - -```bash -python sample_gradio.py +python sample_gradio.py --ckpt pretrained_models/last.ckpt ``` For convenience, our pre-trained DiT models can be downloaded directly here as well: -| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops | -|---------------|------------------|---------|-----------------|--------| -| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 | -| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 | - -**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), -you can add the `--ckpt` -argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom -256x256 DiT-L/4 model, run: - -```bash -python sample.py --model DiT-L/4 --image-size 256 --ckpt/path/to/model.pt -``` +| DiT Model | Image Resolution | +|------------------------------------------------------------------------------|------------------| +| [DiT_clipped](https://www.mediafire.com/file/trqvosl8947s88z/last.ckpt/file) | 256x256 | ## Training DiT -We provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional -DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training +We provide a training script for DiT in [`train_pl.py`](train_pl.py). This script can be used to train class-conditional +DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-clipped (256x256) +training with `N` GPUs on one node: ```bash -torchrun --nnodes=1 --nproc_per_node = N train.py --model DiT-XL/2 --data -path/path/to/imagenet/train +python train_pl.py --coco_dataset_path (...)/datasets/fast-ai-coco ``` -### PyTorch Training Results - -We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script -to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our -experiments, the PyTorch-trained models give -similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. -Some data points: - -| DiT Model | Train Steps | FID-50K
(JAX Training) | FID-50K
(PyTorch Training) | PyTorch Global Training Seed | -|------------|-------------|----------------------------|--------------------------------|------------------------------| -| XL/2 | 400K | 19.5 | ** -18.1** | 42 | -| B/4 | 400K | ** -68.4** | 68.9 | 42 | -| B/4 | 400K | 68.4 | ** -68.3** | 100 | - -These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID -here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`). - -**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's -defaults. -We've enabled them at the top of `train.py` and `sample.py` because it makes training and sampling way way way faster on -A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to -the above results. - ### Enhancements -Training (and sampling) could likely be sped-up significantly by: - -- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model -- [ ] using `torch.compile` in PyTorch 2.0 - -Basic features that would be nice to add: - -- [ ] Monitor FID and other metrics -- [ ] Generate and save samples from the EMA model periodically -- [ ] Resume training from a checkpoint -- [ ] AMP/bfloat16 support - -## Differences from JAX +Improvements to the project could be as follows: -Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. -There may be minor differences in results stemming from sampling with different floating point precisions. We -re-evaluated -our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID -versus 2.27 in the paper). +- [ ] Improve generation quality by training the checkpoint further +- [ ] Adding more DiT_clipped architectures with more params and better training them ## BibTeX diff --git a/modules/dit_clipped.py b/modules/dit_clipped.py index a4448514..5c65185d 100644 --- a/modules/dit_clipped.py +++ b/modules/dit_clipped.py @@ -1,17 +1,18 @@ -import random +import gc import torch import torch.nn as nn -import pytorch_lightning as pl +import lightning as L from timm.models.vision_transformer import PatchEmbed +from tqdm import tqdm from modules.encoders.modules import FrozenCLIPEmbedder from modules.utils import TimestepEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed, process_input_laion -class DiT_Clipped(pl.LightningModule): +class DiT_Clipped(L.LightningModule): """ Diffusion model with a Transformer backbone and clip encoder. 
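+    Implemented as a Lightning module: text conditioning comes from a frozen CLIP
+    text encoder and is passed to the transformer blocks as cross-attention context,
+    while the VAE encodes images to latents on a secondary (CPU) device during training.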
""" @@ -28,7 +29,8 @@ def __init__( mlp_ratio=4.0, class_dropout_prob=0.1, learn_sigma=True, - clip_version='openai/clip-vit-large-patch14' + clip_version='openai/clip-vit-large-patch14', + compile_components=False ): super().__init__() self.learn_sigma = learn_sigma @@ -43,21 +45,43 @@ def __init__( # Will use fixed sin-cos embedding: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) - self.blocks = nn.ModuleList([ - DiTBlock(hidden_size, num_heads, context_dim=context_dim, mlp_ratio=mlp_ratio) for _ in range(depth) - ]) self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) self.encoder = FrozenCLIPEmbedder(clip_version) self.secondary_device = torch.device("cpu") + self.blocks = [ + DiTBlock(hidden_size, num_heads, context_dim=context_dim, mlp_ratio=mlp_ratio) for _ in range(depth) + ] + self.initialize_weights() + if compile_components: + self.compile_components() + self.blocks = nn.ModuleList(self.blocks) + + def compile_components(self): + bar = tqdm(total=5, desc="Compiling components..") + bar.update(0) + self.x_embedder = torch.compile(self.x_embedder) + bar.update(1) + self.t_embedder = torch.compile(self.t_embedder) + bar.update(2) + self.final_layer = torch.compile(self.final_layer) + bar.update(3) + self.encoder = torch.compile(self.encoder) + bar.update(4) + self.blocks = [torch.compile(x) for x in self.blocks] + bar.update(5) + bar.close() + print("Compiling completed.") @torch.no_grad() def encode(self, text_prompt, device=None): device = device if device is not None else self.device - self.encoder.to(device) + if self.encoder.device != device: + self.encoder.to(device) + self.encoder.device = device c = self.encoder.encode(text_prompt) return c.to(self.device) @@ -124,6 +148,7 @@ def forward(self, x, t, context): # left context in, but it's not used atm x = self.final_layer(x, t, context) # (N, T, patch_size ** 2 * out_channels) + del t x = self.unpatchify(x) # (N, out_channels, H, W) return x @@ -148,15 +173,20 @@ def forward_with_cfg(self, x, t, y, cfg_scale): def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-5) - return optimizer + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, max_lr=4e-4, total_steps=self.trainer.estimated_stepping_batches + ) + return {"optimizer": optimizer, + "lr_scheduler": scheduler} def training_step(self, train_batch, batch_idx): img, context = train_batch with torch.no_grad(): context = self.encode(context, device=self.secondary_device) - context, img = context.to(self.device).to(self.dtype), img.to(self.secondary_device).to(torch.float32) - self.vae.to(self.secondary_device).to(torch.float32) + context, img = context.to(self.dtype), img.to(self.secondary_device).to(torch.float32) + if self.vae.device != self.secondary_device: + self.vae.to(self.secondary_device).to(torch.float32) x = self.vae.encode(img).latent_dist.sample().mul_(0.18215).to(self.device).to(self.dtype) t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device) @@ -167,11 +197,14 @@ def training_step(self, train_batch, batch_idx): model_kwargs = dict(context=context) loss_dict = self.diffusion.training_losses(self, x, t, model_kwargs) + + del x, t, context, model_kwargs + torch.cuda.empty_cache() + gc.collect() + loss = loss_dict["loss"].mean() # vb mse loss self.log("train_loss", loss) - self.log("train_bv", loss_dict["vb"].mean()) - self.log("train_mse", loss_dict["mse"].mean()) return loss @@ -185,8 +218,8 @@ def 
training_step(self, train_batch, batch_idx): # loss = loss_dict["loss"].mean() # self.log("val_loss", loss) - def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs): + def backward(self, loss, *args, **kwargs) -> None: loss.backward() - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): optimizer.zero_grad(set_to_none=True) diff --git a/modules/encoders/modules.py b/modules/encoders/modules.py index ccb927fe..d8866e39 100644 --- a/modules/encoders/modules.py +++ b/modules/encoders/modules.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from functools import partial -import open_clip +# import open_clip from einops import rearrange, repeat from transformers import CLIPTokenizer, CLIPTextModel import kornia @@ -170,43 +170,43 @@ def to(self, device, dtype=None, non_blocking=None): self.device = device -class FrozenCLIPTextEmbedder(nn.Module): - """ - Uses the CLIP transformer encoder for text. - """ - - def __init__(self, version='ViT-B-32-quickgelu', device="cpu", max_length=77, n_repeat=1, normalize=True): - super().__init__() - self.model = open_clip.create_model(version, pretrained='laion400m_e32', jit=False, device="cpu") - self.tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu') - self.device = device - self.max_length = max_length - self.n_repeat = n_repeat - self.normalize = normalize - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - tokens = self.tokenizer(text).to(self.device) - z = self.model.encode_text(tokens) - if self.normalize: - z = z / torch.linalg.norm(z, dim=1, keepdim=True) - return z - - def encode(self, text): - z = self(text) - if z.ndim == 2: - z = z[:, None, :] - z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) - return z - - def to(self, device, dtype=None, non_blocking=None): - super(FrozenCLIPTextEmbedder, self).to(device) - self.model.to(device) - self.device = device +# class FrozenCLIPTextEmbedder(nn.Module): +# """ +# Uses the CLIP transformer encoder for text. +# """ +# +# def __init__(self, version='ViT-B-32-quickgelu', device="cpu", max_length=77, n_repeat=1, normalize=True): +# super().__init__() +# self.model = open_clip.create_model(version, pretrained='laion400m_e32', jit=False, device="cpu") +# self.tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu') +# self.device = device +# self.max_length = max_length +# self.n_repeat = n_repeat +# self.normalize = normalize +# +# def freeze(self): +# self.model = self.model.eval() +# for param in self.parameters(): +# param.requires_grad = False +# +# def forward(self, text): +# tokens = self.tokenizer(text).to(self.device) +# z = self.model.encode_text(tokens) +# if self.normalize: +# z = z / torch.linalg.norm(z, dim=1, keepdim=True) +# return z +# +# def encode(self, text): +# z = self(text) +# if z.ndim == 2: +# z = z[:, None, :] +# z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) +# return z +# +# def to(self, device, dtype=None, non_blocking=None): +# super(FrozenCLIPTextEmbedder, self).to(device) +# self.model.to(device) +# self.device = device if __name__ == "__main__": diff --git a/sample_gradio_table.py b/sample_gradio_table.py new file mode 100644 index 00000000..fa5fff40 --- /dev/null +++ b/sample_gradio_table.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
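+#
+# This script sweeps a grid of classifier-free guidance scales and sampling-step
+# counts for a single prompt, rendering every combination into one matplotlib
+# figure so their effect on sample quality can be compared side by side.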
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Sample new images from a pre-trained DiT. +""" +import random + +import numpy as np +import torch +import argparse +import gradio as gr +import torchvision +from matplotlib import pyplot as plt + +from torchvision.utils import make_grid +from diffusers.models import AutoencoderKL +from PIL import Image + +from modules.diffusion import create_diffusion +from download import find_model +from modules.dit_builder import DiT_models + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def sample(prompt, init_scale, max_scale, num_scale_imgs, num_sampling_steps, max_sampling_steps, n_steps_images, seed): + # Setup PyTorch: + torch.manual_seed(seed) + torch.set_grad_enabled(False) + + fig, axs = plt.subplots(num_scale_imgs, n_steps_images, constrained_layout=True) + model.eval() + + bsize = 1 + z = torch.randn(bsize, 4, latent_size, latent_size, device=device) + y = model.encode(prompt).squeeze(1).repeat(bsize, 1, 1).to(device) + + # Setup classifier-free guidance: + z = torch.cat([z, z], 0) + y_null = model.encode("").squeeze(1).repeat(bsize, 1, 1).to(device) # negative + y = torch.cat([y, y_null], 0) + + for i in range(1, num_scale_imgs + 1): + for j in range(1, n_steps_images + 1): + model.to(device) + cfg_scale = round(float((max_scale - init_scale) / num_scale_imgs * i), 1) + steps = int((max_sampling_steps - num_sampling_steps) / n_steps_images * j) + + diffusion = create_diffusion(str(steps)) + + model_kwargs = dict(y=y, cfg_scale=cfg_scale) + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + + # OOMs :cry: + model.cpu() + torch.cuda.empty_cache() + + vae.to(device) + samples = vae.decode(samples / 0.18215).sample + vae.cpu() + samples = samples.cpu()[0].permute(2, 1, 0).numpy() + 0.25 + # Save and display images: + axs[i - 1, j - 1].imshow(np.rot90(samples), k=3) + axs[i - 1, j - 1].set_title('scale: {} steps: {}'.format(cfg_scale, steps)) + fig.savefig('tmp.png') + return Image.open("tmp.png") + # save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT_Clipped") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") + + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + ).to(device) + if args.ckpt: + print(f"Loading {args.ckpt}") + ckpt_path = args.ckpt + state_dict = find_model(ckpt_path) + if 'pytorch-lightning_version' in state_dict.keys(): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict, strict=False) + + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cpu() + + demo = gr.Interface( + fn=sample, + inputs=[ + gr.Text(label="Text Prompt", value="an apple"), + + 
gr.Slider(minimum=1, maximum=20, value=3, step=0.1, label="Initial cfg scale"),
+            gr.Slider(minimum=1, maximum=20, value=13, step=0.1, label="Max cfg scale"),
+            gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of cfg scale values"),
+
+            gr.Slider(minimum=5, maximum=1000, value=30, step=1, label="Initial sampling steps"),
+            gr.Slider(minimum=5, maximum=1000, value=128, step=1, label="Max sampling steps"),
+            gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of sampling step values"),
+
+            gr.Slider(minimum=1, maximum=9223372036854775807, value=5782510030869745000, step=1,
+                      label="Seed"),
+        ],
+        outputs=[
+            gr.Image()
+        ]
+    )
+    demo.launch()
diff --git a/train_pl.py b/train_pl.py
index 0e14a7cf..25080d5b 100644
--- a/train_pl.py
+++ b/train_pl.py
@@ -4,9 +4,9 @@
 import torchvision
 import torch
 
-import pytorch_lightning as pl
+import lightning as L
+from lightning.pytorch.callbacks import ModelCheckpoint
 
-from pytorch_lightning.callbacks import ModelCheckpoint
 from torch.utils.data import DataLoader
 from torchvision.transforms import transforms
 
@@ -21,21 +21,24 @@ def train_pl(args):
     print("Starting training..")
 
-    device = torch.device(0)
     secondary_device = torch.device("cpu")
 
     latent_size = args.image_size // 8
     model = DiT_models[args.model](
         input_size=latent_size,
+        compile_components=False
     )
-    # dataset = HuggingfaceImageDataset(args.hf_dataset_name, args.token, res=args.image_size)
-    dataset = HuggingfaceImageNetDataset(args.token, res=args.image_size)
-    # dataset = torchvision.datasets.CocoCaptions(args.coco_dataset_path + "/train2017/",
-    #                                             args.coco_dataset_path + "/annotations_train/captions_train2017.json",
-    #                                             transform=transform,
-    #                                             target_transform=lambda x: random.choice(x).replace("\n", "").lower())
+
+    # dataset1 = HuggingfaceImageNetDataset(args.token, res=args.image_size)
+    dataset = torchvision.datasets.CocoCaptions(args.coco_dataset_path + "/train2017/",
+                                                args.coco_dataset_path + "/annotations_train/captions_train2017.json",
+                                                transform=transform,
+                                                target_transform=lambda x: random.choice(x).replace("\n", "").lower())
+    # dataset3 = HuggingfaceImageDataset(args.hf_dataset_name, args.token, res=args.image_size)
     # dataset = DummyDataset(res=args.image_size)
+    # dataset = torch.utils.data.ConcatDataset([dataset1, dataset3])
+
     state_dict = find_model("pretrained_models/last.ckpt")
     if 'pytorch-lightning_version' in state_dict.keys():
         state_dict = state_dict["state_dict"]
@@ -47,8 +50,10 @@ def train_pl(args):
     model.diffusion = diffusion
     model.vae = vae
     model.secondary_device = secondary_device
+
+    grad_batch_accum = 6
     model_ckpt = ModelCheckpoint(dirpath="ckpts/", monitor="train_loss", save_top_k=2, save_last=True,
-                                 every_n_train_steps=3_000)
+                                 every_n_train_steps=3_000 // grad_batch_accum)
 
     loader_train = DataLoader(
         dataset,
@@ -59,20 +64,22 @@ def train_pl(args):
     )
 
     # update_ema(ema, model.module, decay=0)
-    model.train().to(device)
+    model.train()
     # ema.eval()
 
     torch.set_float32_matmul_precision("high")
-    trainer = pl.Trainer(
-        auto_lr_find=False,
+    trainer = L.Trainer(
         enable_checkpointing=True,
         detect_anomaly=False,
-        log_every_n_steps=50,
-        accelerator='gpu',
-        devices=1,
+        log_every_n_steps=50 // grad_batch_accum,
+        accelerator="auto",
+        devices="auto",
         max_epochs=args.epochs,
-        precision=16 if args.precision == "fp16" else 32,
-        callbacks=[model_ckpt]
+        precision="16-mixed" if args.precision == "fp16" else "32-true",
+        callbacks=[model_ckpt],
+        # StochasticWeightAveraging(swa_lrs=1e-2)],
+        accumulate_grad_batches=grad_batch_accum,
+        # 
move_metrics_to_cpu=True ) trainer.fit(model, loader_train) @@ -97,7 +104,8 @@ def train_pl(args): parser.add_argument("--precision", type=str, choices=["fp16", "fp32"], default="fp16") parsed_args = parser.parse_args() - assert (parsed_args.hf_dataset_name is not None and parsed_args.token is not None) or (parsed_args.coco_dataset_path) is not None, "Either hf token and dataset name or coco dataset path is required" + assert (parsed_args.hf_dataset_name is not None and parsed_args.token is not None) or ( + parsed_args.coco_dataset_path) is not None, "Either hf token and dataset name or coco dataset path is required" transform = transforms.Compose([ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, parsed_args.image_size)),