Diff pruning and Training on LDM #6

Open
FATE4869 opened this issue Jan 4, 2024 · 7 comments

FATE4869 commented Jan 4, 2024

Hi, thank you for publishing this great work on structural pruning for diffusion models. I was wondering whether you also plan to release the code for diff-pruning and training on LDM. The current ldm-prune code only supports random, magnitude, and reinit pruning. Thanks!

VainF (Owner) commented Jan 5, 2024

Hi @FATE4869, we perform LDM pruning & fine-tuning on top of the official latent-diffusion repo. The pruning command we used:

python prune_ldm.py --sparsity 0.3 --pruner diff-pruning

The contents of prune_ldm.py:

import sys
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan  # ensures the taming package is importable for the VQGAN first stage
import argparse
from ldm.modules.attention import CrossAttention

parser = argparse.ArgumentParser()
parser.add_argument("--sparsity", type=float, default=0.0)
parser.add_argument("--pruner", type=str, choices=["magnitude", "random", "taylor", "diff-pruning", "reinit", "diff0"], default="magnitude")
args = parser.parse_args()

#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

import torch_pruning as tp

def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model


def get_model():
    config = OmegaConf.load("configs/latent-diffusion/cin256-v2.yaml")  
    model = load_model_from_config(config, "models/ldm/cin256-v2/model.ckpt")
    return model

from ldm.models.diffusion.ddim import DDIMSampler

model = get_model()
sampler = DDIMSampler(model)

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid


classes = [25, 187, 448, 992]   # define classes to be sampled here
n_samples_per_class = 6

ddim_steps = 20
ddim_eta = 0.0
scale = 3.0   # for unconditional guidance

print(model)

print("Pruning ...")
model.eval()

if args.pruner == "magnitude":
    imp = tp.importance.MagnitudeImportance()
elif args.pruner == "random":
    imp = tp.importance.RandomImportance()
elif args.pruner == 'taylor':
    imp = tp.importance.TaylorImportance(multivariable=True) # standard first-order taylor expansion
elif args.pruner == 'diff-pruning' or args.pruner == 'diff0':
    imp = tp.importance.TaylorImportance(multivariable=False) # a modified version, estimating the accumulated error of weight removal
else:
    raise ValueError(f"Unknown pruner '{args.pruner}'")

ignored_layers = [model.model.diffusion_model.out]  # keep the final output conv un-pruned
channel_groups = {}
iterative_steps = 1
uc = model.get_learned_conditioning(
    {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}
)  # class index 1000 is the "unconditional" token used for classifier-free guidance


# group the q/k/v projections by attention heads so pruning keeps head dimensions consistent
for m in model.model.diffusion_model.modules():
    if isinstance(m, CrossAttention):
        channel_groups[m.to_q] = m.heads
        channel_groups[m.to_k] = m.heads
        channel_groups[m.to_v] = m.heads


xc = torch.tensor(n_samples_per_class*[classes[0]])
c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
example_inputs = {"x": torch.randn(n_samples_per_class, 3, 64, 64).to(model.device),
                  "timesteps": torch.full((n_samples_per_class,), 1, device=model.device, dtype=torch.long),
                  "context": c}
base_macs, base_params = tp.utils.count_ops_and_params(model.model.diffusion_model, example_inputs)
pruner = tp.pruner.MagnitudePruner(
    model.model.diffusion_model,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    channel_groups=channel_groups,
    ch_sparsity=args.sparsity,  # channel sparsity ratio, e.g. 0.3 removes roughly 30% of the channels
    ignored_layers=ignored_layers,
    root_module_types=[torch.nn.Conv2d, torch.nn.Linear],
    round_to=2
)
model.zero_grad()

import random
max_loss = -1
for t in range(1000):
    if args.pruner not in ['diff-pruning', 'taylor', 'diff0']:
        break  # magnitude/random/reinit need no gradients; skip the loss-accumulation loop
    xc = torch.tensor(random.sample(range(1000), n_samples_per_class))
    #xc = torch.tensor(n_samples_per_class*[class_label])
    c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                    conditioning=c,
                                    batch_size=n_samples_per_class,
                                    shape=[3, 64, 64],
                                    verbose=False,
                                    unconditional_guidance_scale=scale,
                                    unconditional_conditioning=uc, 
                                    eta=ddim_eta)

    encoded = model.encode_first_stage(samples_ddim)
    example_inputs = {"x": encoded.to(model.device),
                      "timesteps": torch.full((n_samples_per_class,), t, device=model.device, dtype=torch.long),
                      "context": c}
    # get_loss_at_t is a helper from the authors' modified LDM codebase (ldm_exp); it is not part of the stock CompVis repo
    loss = model.get_loss_at_t(example_inputs['x'], {model.cond_stage_key: xc.to(model.device)}, example_inputs['timesteps'])
    loss = loss[0]
    if loss > max_loss:
        max_loss = loss
    # diff-pruning stops accumulating gradients once the loss falls below 10% of the
    # largest loss seen so far; taylor and diff0 accumulate over all 1000 timesteps
    thres = 0.1 if args.pruner == 'diff-pruning' else 0.0
    if args.pruner == 'diff-pruning' or args.pruner == 'diff0':
        if loss / max_loss < thres:
            break
    print(t, (loss / max_loss).item(), loss.item(), max_loss.item())
    loss.backward()
pruner.step() 

print("After pruning")
print(model)

pruned_macs, pruned_params = tp.utils.count_ops_and_params(model.model.diffusion_model, example_inputs)
print(f"MACs: {pruned_macs / base_macs * 100:.2f}%, {base_macs / 1e9:.2f}G => {pruned_macs / 1e9:.2f}G")
print(f"Params: {pruned_params / base_params * 100:.2f}%, {base_params / 1e6:.2f}M => {pruned_params / 1e6:.2f}M")

all_samples = list()

with torch.no_grad():
    with model.ema_scope():
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
            )
        
        for class_label in classes:
            print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
            xc = torch.tensor(n_samples_per_class*[class_label])
            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
            
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples_per_class,
                                             shape=[3, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc, 
                                             eta=ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, 
                                         min=0.0, max=1.0)
            all_samples.append(x_samples_ddim)


# display as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_samples_per_class)

# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img.save("samples.png")



print("Saving pruned model ...")
torch.save(model, "logs/pruned_model_{}_{}.pt".format(args.sparsity, args.pruner))

The LDM project is a bit complicated, so we have not included LDM pruning and fine-tuning in this repo. We have also attached the pruning and fine-tuning code in code.zip.
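
For reference, here is a minimal reload sketch (this is not the attached code.zip; the optimizer and learning rate below are placeholders). Since prune_ldm.py saves the whole model object with torch.save, the pruned architecture can be restored with torch.load as long as the ldm and taming packages are importable:

import sys, torch
sys.path.append(".")
sys.path.append('./taming-transformers')

# restore the pruned model saved by prune_ldm.py (path matches the sparsity/pruner used above)
pruned = torch.load("logs/pruned_model_0.3_diff-pruning.pt", map_location="cpu")
pruned.cuda()
pruned.train()

# placeholder optimizer: fine-tune only the pruned U-Net, leaving the first-stage
# VQGAN and the class embedder untouched
optimizer = torch.optim.AdamW(pruned.model.diffusion_model.parameters(), lr=1e-5)

Fine-tuned weights saved from such a run can later be loaded back on top of the pruned model via --finetuned_ckpt, as the sampling script below does.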

VainF (Owner) commented Jan 5, 2024

And the code for sampling:

python sample_for_FID.py --output run/samples --batch_size 10

The contents of sample_for_FID.py:

import sys, os
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan 
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--pruned_model", type=str, default=None)
parser.add_argument("--finetuned_ckpt", type=str, default=None)
parser.add_argument("--ipc", type=int, default=50)
parser.add_argument("--output", type=str, default='run')
parser.add_argument("--batch_size", type=int, default=50)

args = parser.parse_args()

#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

import torch_pruning as tp

from ldm.models.diffusion.ddim import DDIMSampler

def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model

def get_model():
    config = OmegaConf.load("configs/latent-diffusion/cin256-v2.yaml")  
    model = load_model_from_config(config, "models/ldm/cin256-v2/model.ckpt")
    return model

if args.pruned_model is None:
    model = get_model()
else:
    print("Loading model from ", args.pruned_model)
    model = torch.load(args.pruned_model, map_location="cpu")
    print("Loading finetuned parameters from ", args.finetuned_ckpt)
    pl_sd = torch.load(args.finetuned_ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    m, u = model.load_state_dict(sd, strict=False)
model.cuda()
print(model)
sampler = DDIMSampler(model)

num_params = sum(p.numel() for p in model.parameters())
print("Number of parameters: {}", num_params/1000000, "M")

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid


classes = range(1000)   # sample all 1000 ImageNet classes
n_samples_per_class = args.batch_size
n_batch_per_class = args.ipc // args.batch_size  # --ipc = images per class

ddim_steps = 250
ddim_eta = 0.0
scale = 3.0   # for unconditional guidance

all_samples = list()

from torchvision import utils as tvu
os.makedirs(args.output, exist_ok=True)

img_id = 0
with torch.no_grad():
    with model.ema_scope():
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
        )
        
        for _ in range(n_batch_per_class):
            for class_label in classes:
                print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
                xc = torch.tensor(n_samples_per_class*[class_label])
                c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
                
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                conditioning=c,
                                                batch_size=n_samples_per_class,
                                                shape=[3, 64, 64],
                                                verbose=False,
                                                unconditional_guidance_scale=scale,
                                                unconditional_conditioning=uc, 
                                                eta=ddim_eta)

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, 
                                            min=0.0, max=1.0)
                #all_samples.append(x_samples_ddim)
                for i in range(len(x_samples_ddim)):
                    tvu.save_image(
                        x_samples_ddim[i], os.path.join(args.output, f"{class_label}_{img_id}.png")
                    )
                    img_id += 1
                

jonathanyang0227 commented Jan 9, 2024

Hi @VainF, thank you for providing this work. I got the error below when running the code you provided; could you help me check it? Thanks a lot!
(screenshot of the error attached: Screenshot 2024-01-09 at 12.25.50 PM)

VainF (Owner) commented Jan 9, 2024

I guess the ldm repo has changed a lot over the past few months. Let me check and upload the whole LDM project.

jonathanyang0227 commented:

> I guess the ldm repo has changed a lot over the past few months. Let me check and upload the whole LDM project.

Thank you!!!

VainF (Owner) commented Jan 13, 2024

Hi @jonathanyang0227, I uploaded the original code for LDM. It's a bit messy.

https://github.com/VainF/Diff-Pruning/tree/main/ldm_exp

Sample images for FID:

# to generate the fid_stats_imagenet.npz file
python fid_score.py --save-stats ~/Datasets/imagenet/train run/fid_stats_imagenet --device cuda:0 --batch-size 64 --num_samples 50000 --res 256

# sample images from the pruned LDM
python sample_for_FID.py --pruned_model logs/pruned_model_0.3_diff-pruning.pt --finetuned_ckpt logs/2023-08-06T01-06-01_cin256-v2/checkpoints/epoch=000004.ckpt --ipc 50 --output PATH_TO_YOUR_IMAGES

# FID
python fid_score.py run/fid_stats_imagenet.npz PATH_TO_YOUR_IMAGES  --device cuda:0 --batch-size 100 
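
As a rough cross-check (not part of this repo; the torchmetrics dependency and the folder paths below are assumptions), the FID of the sampled folder can also be estimated directly from two image folders:

import os, torch
from PIL import Image
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance

to_tensor = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

def folder_batches(path, batch_size=50):
    # yield float image batches in [0, 1]; assumes the folder contains only image files
    files = sorted(os.listdir(path))
    for i in range(0, len(files), batch_size):
        imgs = [to_tensor(Image.open(os.path.join(path, f)).convert("RGB")) for f in files[i:i + batch_size]]
        yield torch.stack(imgs)

fid = FrechetInceptionDistance(feature=2048, normalize=True)  # normalize=True expects inputs in [0, 1]
for batch in folder_batches("PATH_TO_REFERENCE_IMAGES"):
    fid.update(batch, real=True)
for batch in folder_batches("PATH_TO_YOUR_IMAGES"):
    fid.update(batch, real=False)
print("FID:", fid.compute().item())

Note that this compares against raw reference images rather than the precomputed fid_stats_imagenet.npz, so the number will not match fid_score.py exactly.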

jonathanyang0227 commented:

Thank you so much!
