Diff pruning and Training on LDM #6
Hi @FATE4869, we perform LDM pruning & fine-tuning based on the official repo. The pruning command we used:

```
python prune_ldm.py --sparsity 0.3 --pruner diff-pruning
```

The contents of prune_ldm.py:

```python
import sys
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan
import argparse
from ldm.modules.attention import CrossAttention
parser = argparse.ArgumentParser()
parser.add_argument("--sparsity", type=float, default=0.0)
parser.add_argument("--pruner", type=str, choices=["magnitude", "random", "taylor", "diff-pruning", "reinit", "diff0"], default="magnitude")
args = parser.parse_args()
#@title loading utils
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch_pruning as tp
def load_model_from_config(config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
def get_model():
    config = OmegaConf.load("configs/latent-diffusion/cin256-v2.yaml")
    model = load_model_from_config(config, "models/ldm/cin256-v2/model.ckpt")
    return model
from ldm.models.diffusion.ddim import DDIMSampler
model = get_model()
sampler = DDIMSampler(model)
import numpy as np
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
classes = [25, 187, 448, 992] # define classes to be sampled here
n_samples_per_class = 6
ddim_steps = 20
ddim_eta = 0.0
scale = 3.0 # for unconditional guidance
print(model)
print("Pruning ...")
model.eval()
if args.pruner == "magnitude":
imp = tp.importance.MagnitudeImportance()
elif args.pruner == "random":
imp = tp.importance.RandomImportance()
elif args.pruner == 'taylor':
imp = tp.importance.TaylorImportance(multivariable=True) # standard first-order taylor expansion
elif args.pruner == 'diff-pruning' or args.pruner == 'diff0':
imp = tp.importance.TaylorImportance(multivariable=False) # a modified version, estimating the accumulated error of weight removal
else:
raise ValueError(f"Unknown pruner '{args.pruner}'")
ignored_layers = [model.model.diffusion_model.out]
channel_groups = {}
iterative_steps = 1
uc = model.get_learned_conditioning(
    {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
)
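# Keep multi-head attention consistent: group the to_q/to_k/to_v output channels
# by the number of heads so pruning removes channels evenly across all heads.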
for m in model.model.diffusion_model.modules():
    if isinstance(m, CrossAttention):
        channel_groups[m.to_q] = m.heads
        channel_groups[m.to_k] = m.heads
        channel_groups[m.to_v] = m.heads
xc = torch.tensor(n_samples_per_class*[classes[0]])
c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
example_inputs = {"x": torch.randn(n_samples_per_class, 3, 64, 64).to(model.device), "timesteps": torch.full((n_samples_per_class,), 1, device=model.device, dtype=torch.long), "context": c}
base_macs, base_params = tp.utils.count_ops_and_params(model.model.diffusion_model, example_inputs)
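# Build the structural pruner over the U-Net's Conv2d/Linear layers; the final
# output layer is excluded via ignored_layers so the model's output shape is preserved.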
pruner = tp.pruner.MagnitudePruner(
    model.model.diffusion_model,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    channel_groups=channel_groups,
    ch_sparsity=args.sparsity,  # channel sparsity ratio, e.g. 0.3 removes 30% of the channels in each prunable layer
    ignored_layers=ignored_layers,
    root_module_types=[torch.nn.Conv2d, torch.nn.Linear],
    round_to=2
)
model.zero_grad()
import random
max_loss = -1
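# Importance estimation for taylor/diff-pruning: at each timestep t, sample random
# class labels, run DDIM sampling, compute the diffusion loss at t, and backpropagate
# so the pruner can accumulate gradient-based (Taylor) importance scores.
# diff-pruning stops early once the loss at t falls below 10% of the largest loss seen.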
for t in range(1000):
    if args.pruner not in ['diff-pruning', 'taylor', 'diff0']:
        break
    xc = torch.tensor(random.sample(range(1000), n_samples_per_class))
    #xc = torch.tensor(n_samples_per_class*[class_label])
    c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                     conditioning=c,
                                     batch_size=n_samples_per_class,
                                     shape=[3, 64, 64],
                                     verbose=False,
                                     unconditional_guidance_scale=scale,
                                     unconditional_conditioning=uc,
                                     eta=ddim_eta)
    encoded = model.encode_first_stage(samples_ddim)
    example_inputs = {"x": encoded.to(model.device), "timesteps": torch.full((n_samples_per_class,), t, device=model.device, dtype=torch.long), "context": c}
    loss = model.get_loss_at_t(example_inputs['x'], {model.cond_stage_key: xc.to(model.device)}, example_inputs['timesteps'])
    loss = loss[0]
    if loss > max_loss:
        max_loss = loss
    thres = 0.1 if args.pruner == 'diff-pruning' else 0.0
    if args.pruner == 'diff-pruning' or args.pruner == 'diff0':
        if loss / max_loss < thres:
            break
    print(t, (loss / max_loss).item(), loss.item(), max_loss.item())
    loss.backward()
pruner.step()
print("After pruning")
print(model)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model.model.diffusion_model, example_inputs)
print(f"MACs: {pruned_macs / base_macs * 100:.2f}%, {base_macs / 1e9:.2f}G => {pruned_macs / 1e9:.2f}G")
print(f"Params: {pruned_params / base_params * 100:.2f}%, {base_params / 1e6:.2f}M => {pruned_params / 1e6:.2f}M")
all_samples = list()
with torch.no_grad():
    with model.ema_scope():
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
        )
        for class_label in classes:
            print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
            xc = torch.tensor(n_samples_per_class*[class_label])
            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples_per_class,
                                             shape=[3, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                         min=0.0, max=1.0)
            all_samples.append(x_samples_ddim)
# display as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_samples_per_class)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img.save("samples.png")
print("Saving pruned model ...")
torch.save(model, "logs/pruned_model_{}_{}.pt".format(args.sparsity, args.pruner)) The LDM project is a bit complicated so we have not included LDM Pruning and finetuning in this repo. We also attached the pruning and fine-tuning code in code.zip |
And the code for sampling:

```
python sample_for_FID.py --output run/samples --batch_size 10
```

The contents of sample_for_FID.py:

```python
import sys, os
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--pruned_model", type=str, default=None)
parser.add_argument("--finetuned_ckpt", type=str, default=None)
parser.add_argument("--ipc", type=int, default=50)
parser.add_argument("--output", type=str, default='run')
parser.add_argument("--batch_size", type=int, default=50)
args = parser.parse_args()
#@title loading utils
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch_pruning as tp
from ldm.models.diffusion.ddim import DDIMSampler
def load_model_from_config(config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
def get_model():
    config = OmegaConf.load("configs/latent-diffusion/cin256-v2.yaml")
    model = load_model_from_config(config, "models/ldm/cin256-v2/model.ckpt")
    return model
if args.pruned_model is None:
    model = get_model()
else:
print("Loading model from ", args.pruned_model)
model = torch.load(args.pruned_model, map_location="cpu")
print("Loading finetuned parameters from ", args.finetuned_ckpt)
pl_sd = torch.load(args.finetuned_ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
print(model)
sampler = DDIMSampler(model)
num_params = sum(p.numel() for p in model.parameters())
print("Number of parameters: {}", num_params/1000000, "M")
import numpy as np
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
classes = range(1000) # define classes to be sampled here
n_samples_per_class = args.batch_size
n_batch_per_class = args.ipc // args.batch_size
ddim_steps = 250
ddim_eta = 0.0
scale = 3.0 # for unconditional guidance
all_samples = list()
from torchvision import utils as tvu
os.makedirs(args.output, exist_ok=True)
img_id = 0
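# For each ImageNet class, generate args.ipc images in batches of args.batch_size with
# classifier-free guidance, and save every sample as <class_label>_<img_id>.png for FID.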
with torch.no_grad():
    with model.ema_scope():
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
        )
        for _ in range(n_batch_per_class):
            for class_label in classes:
                print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
                xc = torch.tensor(n_samples_per_class*[class_label])
                c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                 conditioning=c,
                                                 batch_size=n_samples_per_class,
                                                 shape=[3, 64, 64],
                                                 verbose=False,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=uc,
                                                 eta=ddim_eta)
                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                             min=0.0, max=1.0)
                #all_samples.append(x_samples_ddim)
                for i in range(len(x_samples_ddim)):
                    tvu.save_image(
                        x_samples_ddim[i], os.path.join(args.output, f"{class_label}_{img_id}.png")
                    )
                    img_id += 1
```
Hi @VainF, thank you for providing these works. I got this error when running the code you provided above. Can you help me check it? Thanks a lot!!!
I guess the LDM repo has changed a lot in these months. Let me check and upload the whole LDM project.
Thank you!!!
Hi @jonathanyang0227, I uploaded the original code for LDM. It's a bit messy: https://github.com/VainF/Diff-Pruning/tree/main/ldm_exp

Sample images for FID:

```
# to generate the fid_stats_imagenet.npz file
python fid_score.py --save-stats ~/Datasets/imagenet/train run/fid_stats_imagenet --device cuda:0 --batch-size 64 --num_samples 50000 --res 256
# sample images from the pruned LDM
python sample_for_FID.py --pruned_model logs/pruned_model_0.3_diff-pruning.pt --finetuned_ckpt logs/2023-08-06T01-06-01_cin256-v2/checkpoints/epoch=000004.ckpt --ipc 50 --output PATH_TO_YOUR_IMAGES
# FID
python fid_score.py run/fid_stats_imagenet.npz PATH_TO_YOUR_IMAGES --device cuda:0 --batch-size 100
```

With --ipc 50 over the 1000 classes, the sampling step produces 50,000 images, matching the --num_samples 50000 used to build the reference statistics.
Thank you so much!!
Original issue (from @FATE4869): Hi, thank you for publishing this amazing work on structural pruning of diffusion models. I was wondering whether you are also going to publish the code for diff-pruning and training on LDM; the current ldm-prune code only supports random, magnitude, and reinit pruning. Thanks!