forked from deforum/stable-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDeforum_Stable_Diffusion.py
1903 lines (1614 loc) · 78.3 KB
/
Deforum_Stable_Diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# %%
# !! {"metadata":{
# !! "id": "c442uQJ_gUgy"
# !! }}
"""
# **Deforum Stable Diffusion v0.5**
[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).
Notebook by [deforum](https://discord.gg/upmXXsrwZc)
"""
# %%
# !! {"metadata":{
# !! "id": "LBamKxcmNI7-"
# !! }}
"""
By using this Notebook, you agree to the following Terms of Use, and license:
**Stability.AI Model Terms of Use**
This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
The CreativeML OpenRAIL License specifies:
You can't use the model to deliberately produce nor share illegal or harmful outputs or content
CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
Please read the full license here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
"""
# %%
# !! {"metadata":{
# !! "id": "T4knibRpAQ06"
# !! }}
"""
# Setup
"""
# %%
# !! {"metadata":{
# !! "id": "2g-f7cQmf2Nt",
# !! "cellView": "form"
# !! }}
#@markdown **NVIDIA GPU**
import subprocess

# Query nvidia-smi for the name, total memory and free memory of each GPU,
# formatted as CSV without a header row, and echo the raw answer.
gpu_query = subprocess.run(
    ['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'],
    stdout=subprocess.PIPE,
)
sub_p_res = gpu_query.stdout.decode('utf-8')
print(sub_p_res)
# %%
# !! {"metadata":{
# !! "cellView": "form",
# !! "id": "TxIOPT0G5Lx1"
# !! }}
#@markdown **Model and Output Paths**
# ask for the link
print("Local Path Variables:\n")

# Defaults used when Google Drive is not mounted (or mounting fails).
models_path = "/content/models" #@param {type:"string"}
output_path = "/content/output" #@param {type:"string"}

#@markdown **Google Drive Path Variables (Optional)**
mount_google_drive = True #@param {type:"boolean"}
force_remount = False

if mount_google_drive:
    from google.colab import drive # type: ignore
    try:
        drive_path = "/content/drive"
        drive.mount(drive_path,force_remount=force_remount)
        models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
        output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"}
        models_path = models_path_gdrive
        output_path = output_path_gdrive
    except Exception as mount_error:
        # Catch Exception rather than everything so that a keyboard/kernel
        # interrupt still stops the cell, and show what actually went wrong
        # before falling back to the local defaults set above.
        print("...error mounting drive or with drive path variables")
        print(f"...{type(mount_error).__name__}: {mount_error}")
        print("...reverting to default path variables")

import os
# Create whichever pair of directories ended up selected.
os.makedirs(models_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

print(f"models_path: {models_path}")
print(f"output_path: {output_path}")
# %%
# !! {"metadata":{
# !! "id": "VRNl2mfepEIe",
# !! "cellView": "form"
# !! }}
#@markdown **Setup Environment**

setup_environment = True #@param {type:"boolean"}
print_subprocess = False #@param {type:"boolean"}

if setup_environment:
    import subprocess, time
    print("Setting up environment...")
    start_time = time.time()
    # Each entry is one command line, run in order: pinned pip installs first,
    # then the source checkouts that are later added to sys.path.
    all_process = [
        ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
        ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],
        ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],
        ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],
        ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],
        ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'],
        ['git', 'clone', 'https://github.com/shariqfarooq123/AdaBins.git'],
        ['git', 'clone', 'https://github.com/isl-org/MiDaS.git'],
        ['git', 'clone', 'https://github.com/MSFTserver/pytorch3d-lite.git'],
    ]
    for process in all_process:
        completed = subprocess.run(process, stdout=subprocess.PIPE)
        running = completed.stdout.decode('utf-8')
        if print_subprocess:
            print(running)
        if completed.returncode != 0:
            # Previously a failed install/clone went unnoticed until a later
            # import broke. Only warn (do not abort): a failing `git clone`
            # may simply mean the checkout exists from an earlier run.
            print(f"Warning: `{' '.join(process)}` exited with code {completed.returncode}")

    print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
    # Replace the package's __init__.py with an empty file.
    with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:
        f.write('')

    end_time = time.time()
    print(f"Environment set up in {end_time-start_time:.0f} seconds")
# %%
# !! {"metadata":{
# !! "id": "81qmVZbrm4uu",
# !! "cellView": "form"
# !! }}
#@markdown **Python Definitions**
import json
from IPython import display
import gc, math, os, pathlib, subprocess, sys, time
import cv2
import numpy as np
import pandas as pd
import random
import requests
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from skimage.exposure import match_histograms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from types import SimpleNamespace
from torch import autocast
import re
from scipy.ndimage import gaussian_filter
sys.path.extend([
'src/taming-transformers',
'src/clip',
'stable-diffusion/',
'k-diffusion',
'pytorch3d-lite',
'AdaBins',
'MiDaS',
])
import py3d_tools as p3d
from helpers import DepthModel, sampler_fn
from k_diffusion.external import CompVisDenoiser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
def sanitize(prompt):
    """Reduce *prompt* to ASCII letters and underscores.

    Every character that is not an ASCII letter or a space is dropped; each
    remaining space is then turned into an underscore.
    """
    allowed = frozenset('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    kept = [ch for ch in prompt if ch in allowed]
    return ''.join(kept).replace(' ', '_')
from functools import reduce
def construct_RotationMatrixHomogenous(rotation_angles):
    """Return a 4x4 homogeneous matrix whose upper-left 3x3 is a rotation.

    rotation_angles: a list of exactly three numbers. It is handed to
        cv2.Rodrigues, which interprets a 3-vector as a rotation vector.

    Returns a float64 array: identity everywhere except the 3x3 rotation block.
    Raises AssertionError when the argument is not a 3-element list.
    """
    assert isinstance(rotation_angles, list) and len(rotation_angles) == 3
    RH = np.eye(4, 4)
    # Use the matrix cv2.Rodrigues returns instead of asking it to write into
    # RH[0:3, 0:3]: that slice is a non-contiguous view and is not a reliable
    # output buffer for the OpenCV binding. The explicit float dtype also lets
    # integer angle lists through.
    rotation_3x3, _jacobian = cv2.Rodrigues(np.array(rotation_angles, dtype=np.float64))
    RH[0:3, 0:3] = rotation_3x3
    return RH
# https://en.wikipedia.org/wiki/Rotation_matrix
def getRotationMatrixManual(rotation_angles):
    """Compose a 4x4 homogeneous rotation from three angles given in degrees.

    rotation_angles is indexed as [around x, around y, around z]. The result
    is the product Rx . Ry . Rz of the three single-axis rotations.
    """
    in_radians = [np.deg2rad(angle) for angle in rotation_angles]
    about_x, about_y, about_z = in_radians[0], in_radians[1], in_radians[2]

    # Rotation around the x axis
    sin_x, cos_x = np.sin(about_x), np.cos(about_x)
    rot_x = np.array([
        [1., 0., 0., 0.],
        [0., cos_x, -sin_x, 0.],
        [0., sin_x, cos_x, 0.],
        [0., 0., 0., 1.],
    ])

    # Rotation around the y axis
    sin_y, cos_y = np.sin(about_y), np.cos(about_y)
    rot_y = np.array([
        [cos_y, 0., sin_y, 0.],
        [0., 1., 0., 0.],
        [-sin_y, 0., cos_y, 0.],
        [0., 0., 0., 1.],
    ])

    # Rotation around the z axis (in-image-plane)
    sin_z, cos_z = np.sin(about_z), np.cos(about_z)
    rot_z = np.array([
        [cos_z, -sin_z, 0., 0.],
        [sin_z, cos_z, 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
    ])

    return np.matmul(np.matmul(rot_x, rot_y), rot_z)
def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength):
    """Turn two batches of 4 corner points into float32 (x, y) arrays.

    Only the first batch entry of ptsIn/ptsOut is used, and only the first two
    coordinates of its first four points.

    Returns (pin, pout):
        pin  - input corners shifted by (W/2, H/2)
        pout - output corners shifted by (1, 1) and scaled by sidelength/2
    """
    # Explicit row/column index lists keep the out-of-range behaviour of
    # element-by-element indexing.
    first_four_xy = np.ix_([0, 1, 2, 3], [0, 1])
    src_xy = ptsIn[0, :][first_four_xy]
    dst_xy = ptsOut[0, :][first_four_xy]

    pin = (src_xy + [W/2., H/2.]).astype(np.float32)
    pout = ((dst_xy + [1., 1.]) * (0.5*sidelength)).astype(np.float32)
    return pin, pout
def warpMatrix(W, H, theta, phi, gamma, scale, fV):
    """Estimate the 3x3 perspective warp for a W x H image rotated in 3D.

    theta, phi, gamma: rotation angles in degrees around z, x and y
        respectively (they are forwarded to getRotationMatrixManual as
        [phi, gamma, theta]).
    scale: multiplier applied to the output side length.
    fV: field of view in degrees (halved and converted to radians below).

    Returns (M33, sideLength): the matrix from cv2.getPerspectiveTransform
    mapping the image corners to their projected positions, and the side
    length that the projected coordinates were scaled to.
    """
    fVhalf = np.deg2rad(fV/2.)
    d = np.sqrt(W*W+H*H)
    sideLength = scale*d/np.cos(fVhalf)
    h = d/(2.0*np.sin(fVhalf))
    n = h-(d/2.0)
    f = h+(d/2.0)

    # Translation along Z-axis by -h
    T = np.eye(4,4)
    T[2,3] = -h

    # Rotation matrices around x,y,z
    R = getRotationMatrixManual([phi, gamma, theta])

    # Projection Matrix
    P = np.eye(4,4)
    P[0,0] = 1.0/np.tan(fVhalf)
    P[1,1] = P[0,0]
    P[2,2] = -(f+n)/(f-n)
    P[2,3] = -(2.0*f*n)/(f-n)
    P[3,2] = -1.0

    # Full transform: rotate, then translate, then project.
    F = np.matmul(np.matmul(P, T), R)

    # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way.
    # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3);
    ptsIn = np.array([[
        [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.]
    ]])
    ptsOut = cv2.perspectiveTransform(ptsIn, F)

    ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength)

    # check float32 otherwise OpenCV throws an error
    assert(ptsInPt2f.dtype == np.float32)
    assert(ptsOutPt2f.dtype == np.float32)
    M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f)

    return M33, sideLength
def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
    """Warp the previous frame with this frame's 2D animation parameters.

    prev_img_cv2: image array handed to cv2.warpPerspective.
    args: provides W and H (used for the rotation centre and the flip).
    anim_args: provides flip_2d_perspective and border.
    keys: per-frame series (angle, zoom, translation_x/y and, when the
        perspective flip is enabled, perspective_flip_theta/phi/gamma/fv).
    frame_idx: index into those series.

    Returns the warped image, same width/height as prev_img_cv2.
    """
    angle = keys.angle_series[frame_idx]
    zoom = keys.zoom_series[frame_idx]
    translation_x = keys.translation_x_series[frame_idx]
    translation_y = keys.translation_y_series[frame_idx]

    center = (args.W // 2, args.H // 2)
    trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])
    rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
    # Promote both 2x3 affine matrices to 3x3 so they can be multiplied.
    trans_mat = np.vstack([trans_mat, [0,0,1]])
    rot_mat = np.vstack([rot_mat, [0,0,1]])

    if anim_args.flip_2d_perspective:
        perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx]
        perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx]
        perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx]
        perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx]
        M, sl = warpMatrix(args.W, args.H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv)
        post_trans_mat = np.float32([[1, 0, (args.W-sl)/2], [0, 1, (args.H-sl)/2]])
        post_trans_mat = np.vstack([post_trans_mat, [0,0,1]])
        bM = np.matmul(M, post_trans_mat)
        # np.matmul takes only two operands; a third positional argument is
        # its `out` buffer. Passing trans_mat there dropped the translation
        # and overwrote trans_mat, so chain the products explicitly.
        xform = np.matmul(np.matmul(bM, rot_mat), trans_mat)
    else:
        xform = np.matmul(rot_mat, trans_mat)

    return cv2.warpPerspective(
        prev_img_cv2,
        xform,
        (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
        borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
    )
def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx):
    """Warp the previous frame in 3D using this frame's translation/rotation.

    Reads translation_{x,y,z}_series and rotation_3d_{x,y,z}_series from keys
    at frame_idx, builds a rotation matrix on the module-level device and
    delegates the actual warp to transform_image_3d. The CUDA cache is emptied
    before the result is returned.
    """
    TRANSLATION_SCALE = 1.0/200.0 # matches Disco

    shift_x = keys.translation_x_series[frame_idx]
    shift_y = keys.translation_y_series[frame_idx]
    shift_z = keys.translation_z_series[frame_idx]
    # x and z are negated, y is not.
    translate_xyz = [
        -shift_x * TRANSLATION_SCALE,
        shift_y * TRANSLATION_SCALE,
        -shift_z * TRANSLATION_SCALE,
    ]

    degree_series = (
        keys.rotation_3d_x_series,
        keys.rotation_3d_y_series,
        keys.rotation_3d_z_series,
    )
    rotate_xyz = [math.radians(series[frame_idx]) for series in degree_series]

    euler = torch.tensor(rotate_xyz, device=device)
    rot_mat = p3d.euler_angles_to_matrix(euler, "XYZ").unsqueeze(0)

    result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args)
    torch.cuda.empty_cache()
    return result
def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:
    """Return *sample* plus standard-normal noise scaled by *noise_amt*.

    The noise has the sample's shape and is created on the sample's device;
    the input tensor itself is not modified.
    """
    gaussian = torch.randn(sample.shape, device=sample.device)
    scaled_noise = gaussian * noise_amt
    return sample + scaled_noise
def get_output_folder(output_path, batch_folder):
    """Create and return the output directory for the current month.

    The directory is <output_path>/<YYYY-MM>, with <batch_folder> appended
    when it is not the empty string. It is created if missing.
    """
    segments = [output_path, time.strftime('%Y-%m')]
    if batch_folder != "":
        segments.append(batch_folder)
    out_path = os.path.join(*segments)
    os.makedirs(out_path, exist_ok=True)
    return out_path
def load_img(path, shape, use_alpha_as_mask=False):
    """Load an image from disk or an http(s) URL as a normalised tensor.

    path: file path, or a URL starting with http:// or https://.
    shape: size passed to PIL's resize.
    use_alpha_as_mask: read the alpha channel of the image as the mask image.

    Returns (image, mask_image): image is a float16 tensor with a leading
    batch axis, channels before height/width, values scaled to [-1, 1];
    mask_image is the alpha channel as an 'L' PIL image, or None when
    use_alpha_as_mask is False.
    """
    if path.startswith(('http://', 'https://')):
        source = requests.get(path, stream=True).raw
    else:
        source = path
    picture = Image.open(source)

    # Keep the alpha channel through the resize only when it will be used.
    picture = picture.convert('RGBA' if use_alpha_as_mask else 'RGB')
    picture = picture.resize(shape, resample=Image.LANCZOS)

    mask_image = None
    if use_alpha_as_mask:
        # Split alpha channel into a mask_image
        _red, _green, _blue, alpha = picture.split()
        mask_image = alpha.convert('L')
        picture = picture.convert('RGB')

    pixels = np.array(picture).astype(np.float16) / 255.0
    pixels = pixels[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(pixels)
    image = 2.*image - 1.

    return image, mask_image
def load_mask_latent(mask_input, shape):
    """Load a mask and resize it to the trailing two dimensions of *shape*.

    mask_input (str or PIL Image.Image): Path/URL to the mask image or a PIL Image object
    shape (list-like len(4)): shape of the image to match, usually latent_image.shape

    Returns the resized mask as an 'L' mode PIL image.
    Raises Exception when mask_input is neither a string nor a PIL image.
    """
    if isinstance(mask_input, str): # mask input is probably a file name
        if mask_input.startswith(('http://', 'https://')):
            source = requests.get(mask_input, stream=True).raw
        else:
            source = mask_input
        mask_image = Image.open(source).convert('RGBA')
    elif isinstance(mask_input, Image.Image):
        mask_image = mask_input
    else:
        raise Exception("mask_input must be a PIL image or a file name")

    # PIL wants (width, height); shape ends with (..., height, width).
    target_w_h = (shape[-1], shape[-2])
    resized = mask_image.resize(target_w_h, resample=Image.LANCZOS)
    return resized.convert("L")
def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):
    """Load a mask image and turn it into a 4-channel tensor clamped to [0, 1].

    mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
    mask_shape (list-like len(4)): shape of the image to match, usually latent_image.shape
    mask_brightness_adjust (non-negative float): amount to adjust brightness of the image,
        0 is black, 1 is no adjustment, >1 is brighter
    mask_contrast_adjust (non-negative float): amount to adjust contrast of the image,
        0 is a flat grey image, 1 is no adjustment, >1 is more contrast

    Returns a float32 torch tensor: the single-channel mask repeated 4 times
    with a leading batch axis of size 1.

    NOTE: inversion is controlled by the module-level `args.invert_mask`, not
    by a parameter of this function.
    """
    mask = load_mask_latent(mask_input, mask_shape)

    # Mask brightness/contrast adjustments
    if mask_brightness_adjust != 1:
        mask = TF.adjust_brightness(mask, mask_brightness_adjust)
    if mask_contrast_adjust != 1:
        mask = TF.adjust_contrast(mask, mask_contrast_adjust)

    # Mask image to array
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask,(4,1,1))
    mask = np.expand_dims(mask,axis=0)
    mask = torch.from_numpy(mask)

    if args.invert_mask:
        # Mirror the values around 0.5.
        mask = ( (mask - 0.5) * -1) + 0.5

    # Clamp with torch rather than np.clip: the value is a torch tensor here,
    # and this guarantees that a tensor is what gets returned.
    mask = torch.clamp(mask, 0, 1)
    return mask
def maintain_colors(prev_img, color_match_sample, mode):
    """Match the colour histogram of prev_img to color_match_sample.

    mode selects the colour space in which matching happens:
        'Match Frame 0 RGB' - match directly on the given arrays
        'Match Frame 0 HSV' - convert RGB->HSV, match, convert back
        anything else       - convert RGB->LAB, match, convert back
    """
    # NOTE(review): `multichannel` was superseded by `channel_axis` in newer
    # scikit-image releases - confirm against the installed version.
    if mode == 'Match Frame 0 RGB':
        return match_histograms(prev_img, color_match_sample, multichannel=True)

    if mode == 'Match Frame 0 HSV':
        to_space, from_space = cv2.COLOR_RGB2HSV, cv2.COLOR_HSV2RGB
    else: # Match Frame 0 LAB
        to_space, from_space = cv2.COLOR_RGB2LAB, cv2.COLOR_LAB2RGB

    prev_converted = cv2.cvtColor(prev_img, to_space)
    reference_converted = cv2.cvtColor(color_match_sample, to_space)
    matched = match_histograms(prev_converted, reference_converted, multichannel=True)
    return cv2.cvtColor(matched, from_space)
#
# Callback functions
#
class SamplerCallback(object):
    """Per-step callback handed to the samplers.

    After construction, `self.callback` is the function to pass to the
    sampler: `img_callback_` for the "plms"/"ddim" samplers, `k_callback_`
    for everything else. Each step it can threshold the working tensor in
    place, re-impose the masked region of the init latent, and save or
    display a decoded preview.

    Relies on the module-level `device` (noise creation, step tensors) and
    `model` (decoding previews).
    """

    def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        """
        args: settings object; reads sampler, dynamic_threshold,
            static_threshold, n_samples, save_sample_per_step,
            show_sample_per_step, outdir, timestring and seed.
        mask: tensor compared against the mask schedule, or None.
        init_latent: latent that is re-noised and pasted into masked areas.
        sigmas: noise levels; normalised and sorted into the mask schedule.
            An empty sequence disables the mask. None leaves the schedule
            unset (mask_schedule stays None).
        sampler: required for "plms"/"ddim" when a mask is given.
        verbose: when true, `verbose_print` is `print`; otherwise a no-op.
        """
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose

        self.batch_size = args.n_samples
        self.save_sample_per_step = args.save_sample_per_step
        self.show_sample_per_step = args.show_sample_per_step
        # One directory per sample in the batch for the per-step previews.
        self.paths_to_image_steps = [os.path.join( args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples) ]

        if self.save_sample_per_step:
            for path in self.paths_to_image_steps:
                os.makedirs(path, exist_ok=True)

        self.step_index = 0

        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=device)

        self.mask_schedule = None
        if sigmas is not None and len(sigmas) > 0:
            self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas))
        elif sigmas is not None:
            # Empty sigmas. (The original tested len(sigmas) here without the
            # None guard, so the default sigmas=None raised TypeError.)
            self.mask = None # no mask needed if no steps (usually happens because strength==1.0)

        if self.sampler_name in ["plms","ddim"]:
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
            # Callback function formatted for compvis latent diffusion samplers
            self.callback = self.img_callback_
        else:
            # Default callback function uses k-diffusion sampler variables
            self.callback = self.k_callback_

        self.verbose_print = print if verbose else lambda *_args, **_kwargs: None

    def view_sample_step(self, latents, path_name_modifier=''):
        """Decode *latents* and save and/or display them, if enabled.

        Files are named '<path_name_modifier>_<step index>.png' inside each
        sample's directory. Does nothing when both per-step options are off.
        """
        if not (self.save_sample_per_step or self.show_sample_per_step):
            return
        samples = model.decode_first_stage(latents)
        if self.save_sample_per_step:
            fname = f'{path_name_modifier}_{self.step_index:05}.png'
            for i, sample in enumerate(samples):
                # Map from [-1, 1] to [0, 1] on the CPU before building the image.
                sample = sample.detach().double().cpu().add(1).div(2).clamp(0, 1)
                grid = make_grid(sample, 4).cpu()
                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))
        if self.show_sample_per_step:
            print(path_name_modifier)
            self.display_images(samples)
        return

    def display_images(self, images):
        """Show a batch of decoded images (values in [-1, 1]) as one grid."""
        images = images.detach().double().cpu().add(1).div(2).clamp(0, 1)
        grid = make_grid(images, 4).cpu()
        display.display(TF.to_pil_image(grid))
        return

    # The callback function is applied to the image at each step
    def dynamic_thresholding_(self, img, threshold):
        """Clamp and rescale *img* in place by a percentile of its magnitude.

        threshold: percentile (0-100) of |img| computed per batch entry; the
        largest of those values, but never less than 1.0, becomes the bound
        `s`. img is clamped to [-s, s] and then divided by s.
        """
        # Dynamic thresholding from Imagen paper (May 2022)
        magnitudes = np.abs(img.detach().cpu().numpy())
        s = np.percentile(magnitudes, threshold, axis=tuple(range(1,img.ndim)))
        s = float(np.max(np.append(s,1.0)))
        img.clamp_(-1*s, s)
        img.div_(s)

    # Callback for samplers in the k-diffusion repo, called thus:
    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
    def k_callback_(self, args_dict):
        """Per-step callback for k-diffusion samplers; edits args_dict['x'] in place."""
        self.step_index = args_dict['i']
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold)
        if self.mask is not None:
            # Init latent noised to the current sigma.
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            # Where the mask is non-zero and at or above this step's schedule
            # value, use the noised init latent; elsewhere keep the sampler's x.
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0 )
            new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1)
            args_dict['x'].copy_(new_img)

        self.view_sample_step(args_dict['denoised'], "x0_pred")

    # Callback for Compvis samplers
    # Function that is called on the image (img) and step (i) at each step
    def img_callback_(self, img, i):
        """Per-step callback for the plms/ddim samplers; edits img in place."""
        self.step_index = i
        # Thresholding functions
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(img, -1*self.static_threshold, self.static_threshold)
        if self.mask is not None:
            # The sampler's encode step index is counted from the other end.
            i_inv = len(self.sigmas) - i - 1
            init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(device), noise=self.noise)
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0 )
            new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1)
            img.copy_(new_img)

        self.view_sample_step(img, "x")
def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
    """Convert an (H, W, C) array with 0-255 values into a (1, C, H, W) tensor.

    Values are rescaled from [0, 255] to [-1, 1] and stored as float16.
    """
    unit_range = sample.astype(float) / 255.0
    signed = (unit_range * 2) - 1
    batched_chw = signed[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(batched_chw.astype(np.float16))
def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:
    """Convert a channel-first tensor in [-1, 1] into an (H, W, C) array.

    The tensor is squeezed, moved to the CPU, rearranged to channel-last,
    mapped from [-1, 1] to [0, 255] (clipping out-of-range values) and cast
    to *type*.
    """
    chw = sample.squeeze().cpu().numpy()
    hwc = rearrange(chw, "c h w -> h w c").astype(np.float32)
    unit_range = ((hwc * 0.5) + 0.5).clip(0, 1)
    scaled = unit_range * 255
    return scaled.astype(type)
def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args):
    """Re-project an image through a rotated/translated perspective camera.

    adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion

    prev_img_cv2: (H, W, C) image array; the result has the same dtype.
    depth_tensor: per-pixel depth, flattened to H*W values below.
    rot_mat: rotation passed to the new camera as R.
    translate: translation passed to the new camera as T.
    anim_args: provides near_plane, far_plane, fov, sampling_mode, padding_mode.

    Uses the module-level `device` for all tensors it creates.
    """
    w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]

    aspect_ratio = float(w)/float(h)
    near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov
    persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device)
    persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device)

    # range of [-1,1] is important to torch grid_sample's padding handling
    # indexing='ij' is the behaviour the call already relied on; naming it
    # avoids the warning torch emits when the argument is omitted.
    y,x = torch.meshgrid(
        torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),
        torch.linspace(-1.,1.,w,dtype=torch.float32,device=device),
        indexing='ij',
    )
    z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device)
    xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1)

    # Project the same points through both cameras; the difference in x/y is
    # how far each pixel moves.
    xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]
    xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]

    offset_xy = xyz_new_cam_xy - xyz_old_cam_xy
    # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation.
    identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0)
    # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs.
    coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)
    offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)

    image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device)
    # NOTE(review): the small constant added to the image is inherited from the
    # upstream implementation; its purpose is not evident here - confirm.
    new_image = torch.nn.functional.grid_sample(
        image_tensor.add(1/512 - 0.0001).unsqueeze(0),
        offset_coords_2d,
        mode=anim_args.sampling_mode,
        padding_mode=anim_args.padding_mode,
        align_corners=False
    )

    # convert back to cv2 style numpy array
    result = rearrange(
        new_image.squeeze().clamp(0,255),
        'c h w -> h w c'
    ).cpu().numpy().astype(prev_img_cv2.dtype)
    return result
def check_is_number(value):
    """Return a truthy ``re.Match`` if *value* is a plain decimal number, else None.

    Accepted: an optional sign, optional integer digits and an optional
    fractional part ('.' followed by digits), e.g. "3", "-0.5", ".25", "+7".
    Rejected: the empty string, a trailing dot ("1."), and a sign with no
    digits ("+", "-").
    """
    # The look-ahead demands at least one digit somewhere. The previous
    # `(?=.)` only demanded a non-empty string, so a lone "+" or "-" matched
    # even though float() cannot parse it. Capture groups are unchanged.
    float_pattern = r'^(?=.*[0-9])([+-]?([0-9]*)(\.([0-9]+))?)$'
    return re.match(float_pattern, value)
# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25')
# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py
# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py
def get_uc_and_c(prompts, model, args, frame = 0):
    """Build the (unconditional, conditional) conditioning pair for a batch.

    Only prompts[0] is used. It is split into negative and positive weighted
    sub-prompts; the negatives produce `uc` (falling back to the empty prompt
    when there are none) and the positives produce `c` (falling back to the
    whole prompt).
    """
    prompt = prompts[0] # they are the same in a batch anyway

    # get weighted sub-prompts
    skip_normalize = not args.normalize_prompt_weights
    negatives, positives = split_weighted_subprompts(prompt, frame, skip_normalize)

    unconditional = get_learned_conditioning(model, negatives, "", args, -1)
    conditional = get_learned_conditioning(model, positives, prompt, args, 1)

    return (unconditional, conditional)
def get_learned_conditioning(model, weighted_subprompts, text, args, sign = 1):
    """Return the conditioning for *text* or for a weighted mix of sub-prompts.

    With no sub-prompts, *text* is encoded as-is for args.n_samples copies.
    Otherwise each (subtext, subweight) pair is encoded and the results are
    summed, each scaled by its weight. *sign* only affects the weight shown
    by the tokenization log.
    """
    if len(weighted_subprompts) < 1:
        log_tokenization(text, model, args.log_weighted_subprompts, sign)
        return model.get_learned_conditioning(args.n_samples * [text])

    c = None
    for subtext, subweight in weighted_subprompts:
        log_tokenization(subtext, model, args.log_weighted_subprompts, sign * subweight)
        encoded = model.get_learned_conditioning(args.n_samples * [subtext])
        if c is None:
            # First sub-prompt: scale in place and use it as the accumulator.
            encoded *= subweight
            c = encoded
        else:
            c.add_(encoded, alpha=subweight)
    return c
def parse_weight(match, frame = 0)->float:
    """Turn the 'weight' group of a sub-prompt match into a float.

    match: regex match exposing a named group "weight".
    frame: current frame number, available to weight expressions as `t`.

    Returns 1.0 when no weight was captured or when a backtick-quoted
    expression is too short to contain anything; the number itself for a
    numeric weight; otherwise the value of the expression between the
    surrounding backticks.
    """
    w_raw = match.group("weight")
    if w_raw is None:
        return 1.0
    if check_is_number(w_raw):
        return float(w_raw)

    if len(w_raw) < 3:
        print('the value inside `-characters cannot represent a math function')
        return 1.0

    # Imported only when an expression actually has to be evaluated, so plain
    # numeric weights do not depend on numexpr being installed.
    import numexpr
    # SECURITY NOTE: the expression comes straight from the prompt text and is
    # evaluated by numexpr - do not feed prompts from untrusted sources.
    # `t` is passed explicitly instead of relying on numexpr picking up this
    # function's local variables.
    return float(numexpr.evaluate(w_raw[1:-1], local_dict={'t': frame}))
def normalize_prompt_weights(parsed_prompts):
    """Rescale (prompt, weight) pairs so that the weights sum to 1.

    An empty input is returned unchanged. When the weights add up to zero a
    warning is printed and every prompt receives the same weight instead.
    """
    if len(parsed_prompts) == 0:
        return parsed_prompts

    weight_sum = sum(entry[1] for entry in parsed_prompts)

    normalized = []
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / max(len(parsed_prompts), 1)
        for entry in parsed_prompts:
            normalized.append((entry[0], equal_weight))
    else:
        for entry in parsed_prompts:
            normalized.append((entry[0], entry[1] / weight_sum))
    return normalized
# Raw string: the previous non-raw literal relied on invalid escape sequences
# (\:, \d, \., \S, \s) being passed through unchanged, which newer Python
# versions warn about. The compiled pattern is identical.
_PROMPT_PARSER = re.compile(r"""
    (?P<prompt>         # capture group for 'prompt'
    (?:\\:|[^:])+       # match one or more non ':' characters or escaped colons '\:'
    )                   # end 'prompt'
    (?:                 # non-capture group
    :+                  # match one or more ':' characters
    (?P<weight>((       # capture group for 'weight'
    -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
    )|(                 # or
    `[\S\s]*?`          # a math function
    )))?                # end weight capture group, make optional
    \s*                 # strip spaces after weight
    |                   # OR
    $                   # else, if no ':' then match end of line
    )                   # end non-capture group
    """, re.VERBOSE)


def split_weighted_subprompts(text, frame = 0, skip_normalize=False):
    """Split 'prompt:weight' text into negative and positive sub-prompts.

    Grabs all text up to the first occurrence of ':' as a sub-prompt and takes
    the value following ':' as its weight (1.0 when no value is given), then
    repeats until no text remains. An escaped colon '\\:' inside a sub-prompt
    is turned back into ':'.

    frame: forwarded to parse_weight for weights written as expressions.
    skip_normalize: when true the weights are returned as parsed; otherwise
        each list is normalised separately.

    Returns (negative_prompts, positive_prompts), two lists of
    (sub-prompt, weight) tuples. Negative weights are stored with the sign
    flipped; sub-prompts whose weight is exactly 0 are dropped.
    """
    negative_prompts = []
    positive_prompts = []
    for match in _PROMPT_PARSER.finditer(text):
        w = parse_weight(match, frame)
        subprompt = match.group("prompt").replace("\\:", ":")
        if w < 0:
            # negating the sign as we'll feed this to uc
            negative_prompts.append((subprompt, -w))
        elif w > 0:
            positive_prompts.append((subprompt, w))

    if skip_normalize:
        return (negative_prompts, positive_prompts)
    return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts))
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
    """Print how *text* is tokenized, colouring tokens with ANSI escapes.

    Nothing happens unless *log* is true. Tokens whose position is below the
    text encoder's max_length are listed as used (cycling through six
    colours); the rest are listed separately as discarded.
    """
    if not log:
        return

    encoder = model.cond_stage_model
    raw_tokens = encoder.tokenizer._tokenize(text)

    used_parts = []
    discarded_parts = []
    for position, raw_token in enumerate(raw_tokens):
        token = raw_token.replace('</w>', ' ')
        # alternate color (advances only with tokens that are actually used)
        s = (len(used_parts) % 6) + 1
        painted = f"\x1b[0;3{s};40m{token}"
        if position < encoder.max_length:
            used_parts.append(painted)
        else:  # over max token length
            discarded_parts.append(painted)

    usedTokens = len(used_parts)
    tokenized = "".join(used_parts)
    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
    if discarded_parts:
        discarded = "".join(discarded_parts)
        print(
            f">> Tokens Discarded ({len(raw_tokens)-usedTokens}):\n{discarded}\x1b[0m"
        )
def generate(args, frame = 0, return_latent=False, return_sample=False, return_c=False):
    """Run one sampling pass for `args.prompt` and return the results.

    Uses the module-level `model` and `device`. Seeds via
    `seed_everything(args.seed)` and creates `args.outdir` if missing.

    The starting latent is taken from the first available of:
    `args.init_latent`, `args.init_sample` (encoded by the first stage), or
    `args.init_image` (loaded with `load_img` when `args.use_init` is set).
    With none of these, sampling starts from random noise.

    Side effect: `args.strength` is reset to 0 when `args.use_init` is false,
    `args.strength > 0` and `args.strength_0_no_init` is true.

    Args:
        args: settings namespace; fields read here include sampler, steps,
            scale, strength, n_samples, W, H, C, f, ddim_eta, precision,
            prompt_weighting, init_c and the mask-related options.
        frame: forwarded to `get_uc_and_c` when `args.prompt_weighting` is on.
        return_latent: also append a clone of the sampled latents.
        return_sample: also append a clone of the decoded samples, taken
            before they are clamped to [0, 1].
        return_c: also append a clone of the conditioning `c`.

    Returns:
        A list holding, in order: the latent clone (if `return_latent`), the
        decoded-sample clone (if `return_sample`), the conditioning clone (if
        `return_c`), followed by one PIL image per decoded sample.

    Raises:
        AssertionError: `args.prompt` is None, or the mask/init requirements
            checked below are not met.
        Warning: `args.use_alpha_as_mask` is set and the prepared mask is all
            0 or all 1.
        Exception: `args.sampler` is not one of the handled sampler names, or
            an overlay is requested with no init image available.
    """
    seed_everything(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)
    model_wrap = CompVisDenoiser(model)
    batch_size = args.n_samples
    prompt = args.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]
    precision_scope = autocast if args.precision == "autocast" else nullcontext

    # Pick the starting latent: explicit latent > sample to encode > image file.
    init_latent = None
    mask_image = None
    init_image = None
    if args.init_latent is not None:
        init_latent = args.init_latent
    elif args.init_sample is not None:
        with precision_scope("cuda"):
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))
    elif args.use_init and args.init_image is not None and args.init_image != '':
        init_image, mask_image = load_img(args.init_image,
                                          shape=(args.W, args.H),
                                          use_alpha_as_mask=args.use_alpha_as_mask)
        init_image = init_image.to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        with precision_scope("cuda"):
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    if not args.use_init and args.strength > 0 and args.strength_0_no_init:
        print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.")
        print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n")
        args.strength = 0

    # Mask functions
    if args.use_mask:
        assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel"
        assert args.use_init, "use_mask==True: use_init is required for a mask"
        assert init_latent is not None, "use_mask==True: An latent init image is required for a mask"

        # A mask extracted by load_img takes precedence over args.mask_file.
        mask = prepare_mask(args.mask_file if mask_image is None else mask_image,
                            init_latent.shape,
                            args.mask_contrast_adjust,
                            args.mask_brightness_adjust)

        # NOTE: Warning is raised as an exception here, so generation stops.
        if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask:
            raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")

        mask = mask.to(device)
        mask = repeat(mask, '1 ... -> b ...', b=batch_size)
    else:
        mask = None

    assert not ( (args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True"

    # Step count scaled by (1 - strength).
    t_enc = int((1.0-args.strength) * args.steps)

    # Noise schedule for the k-diffusion samplers (used for masking);
    # only the last t_enc + 1 sigmas are kept.
    k_sigmas = model_wrap.get_sigmas(args.steps)
    k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:]

    if args.sampler in ['plms','ddim']:
        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)

    callback = SamplerCallback(args=args,
                               mask=mask,
                               init_latent=init_latent,
                               sigmas=k_sigmas,
                               sampler=sampler,
                               verbose=False).callback

    results = []
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for prompts in data:
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    if args.prompt_weighting:
                        uc, c = get_uc_and_c(prompts, model, args, frame)
                    else:
                        uc = model.get_learned_conditioning(batch_size * [""])
                        c = model.get_learned_conditioning(prompts)

                    if args.scale == 1.0:
                        uc = None
                    # Identity check: init_c may be a tensor, for which `!=`
                    # is not a reliable "is it set?" test.
                    if args.init_c is not None:
                        c = args.init_c

                    if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=model_wrap,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=device,
                            cb=callback)
                    else:
                        # args.sampler == 'plms' or args.sampler == 'ddim':
                        if init_latent is not None and args.strength > 0:
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
                        else:
                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)

                        if args.sampler == 'ddim':
                            samples = sampler.decode(z_enc,
                                                     c,
                                                     t_enc,
                                                     unconditional_guidance_scale=args.scale,
                                                     unconditional_conditioning=uc,
                                                     img_callback=callback)
                        elif args.sampler == 'plms': # no "decode" function in plms, so use "sample"
                            shape = [args.C, args.H // args.f, args.W // args.f]
                            samples, _ = sampler.sample(S=args.steps,
                                                        conditioning=c,
                                                        batch_size=args.n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=args.scale,
                                                        unconditional_conditioning=uc,
                                                        eta=args.ddim_eta,
                                                        x_T=z_enc,
                                                        img_callback=callback)
                        else:
                            raise Exception(f"Sampler {args.sampler} not recognised.")

                    if return_latent:
                        results.append(samples.clone())

                    x_samples = model.decode_first_stage(samples)

                    if args.use_mask and args.overlay_mask:
                        # Overlay the masked image after the image is generated
                        if args.init_sample is not None:
                            img_original = args.init_sample
                        elif init_image is not None:
                            img_original = init_image
                        else:
                            raise Exception("Cannot overlay the masked image without an init image to overlay")

                        mask_fullres = prepare_mask(args.mask_file if mask_image is None else mask_image,
                                                    img_original.shape,
                                                    args.mask_contrast_adjust,
                                                    args.mask_brightness_adjust)
                        mask_fullres = mask_fullres[:,:3,:,:]
                        mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=batch_size)

                        # Zero everything below the mask's maximum, then blur.
                        mask_fullres[mask_fullres < mask_fullres.max()] = 0
                        mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur)
                        mask_fullres = torch.Tensor(mask_fullres).to(device)

                        # Blend: original where the mask is 1, generated where it is 0.
                        x_samples = img_original * mask_fullres + x_samples * ((mask_fullres * -1.0) + 1)

                    if return_sample:
                        results.append(x_samples.clone())

                    # Map from [-1, 1] to [0, 1] before converting to 8-bit images.
                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if return_c:
                        results.append(c.clone())

                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        image = Image.fromarray(x_sample.astype(np.uint8))
                        results.append(image)
    return results
# %%
# !! {"metadata":{
# !! "cellView": "form",
# !! "id": "CIUJ7lWI4v53"
# !! }}
#@markdown **Select and Load Model**
# Model selection settings. Each `#@param` comment declares a notebook form
# field and must stay on the same line as its assignment.
# Choosing "custom" for either value makes the matching custom_*_path below
# the path that is used.
model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
model_checkpoint = "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","waifu-diffusion-v1-3.ckpt"]
# The waifu-diffusion menu entry is an alias: swap in the checkpoint file name
# under which it is registered in model_map.
if model_checkpoint == "waifu-diffusion-v1-3.ckpt":
    model_checkpoint = "model-epoch05-float16.ckpt"
# Only consulted when "custom" is selected above.
custom_config_path = "" #@param {type:"string"}
custom_checkpoint_path = "" #@param {type:"string"}
load_on_run_all = True #@param {type: 'boolean'}
half_precision = True # check
check_sha256 = True #@param {type:"boolean"}
# Known checkpoints, keyed by checkpoint file name. Each entry carries:
#   'sha256'         - expected hex digest of the checkpoint file
#   'url'            - direct download URL
#   'requires_login' - whether the download needs Hugging Face credentials
# Every URL uses the '/resolve/main/' form: '/blob/main/' links serve the web
# viewer page instead of the file itself. Each CompVis entry points at the
# repository matching its own version.
model_map = {
    "sd-v1-4-full-ema.ckpt": {
        'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-4.ckpt": {
        'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
        'requires_login': True,
    },
    "sd-v1-3-full-ema.ckpt": {
        'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-3.ckpt": {
        'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',
        'requires_login': True,
    },
    "sd-v1-2-full-ema.ckpt": {
        'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-2.ckpt": {
        'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',
        'requires_login': True,
    },
    "sd-v1-1-full-ema.ckpt": {
        'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',
        'requires_login': True,
    },
    "sd-v1-1.ckpt": {
        'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',
        'requires_login': True,
    },
    "robo-diffusion-v1.ckpt": {
        'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',
        'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',
        'requires_login': False,
    },
    "model-epoch05-float16.ckpt": {
        'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',
        'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',
        'requires_login': False,
    },
}
# Config path: the custom path when "custom" is selected, otherwise the named
# config file inside models_path.
if model_config == "custom":
    ckpt_config_path = custom_config_path
else:
    ckpt_config_path = os.path.join(models_path, model_config)

if not os.path.exists(ckpt_config_path):
    # Nothing at the resolved location: fall back to the default
    # v1-inference.yaml path under ./stable-diffusion.
    ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
else:
    print(f"{ckpt_config_path} exists")
print(f"Using config: {ckpt_config_path}")

# Checkpoint path (or download): same selection rule as the config above.
if model_checkpoint == "custom":
    ckpt_path = custom_checkpoint_path
else:
    ckpt_path = os.path.join(models_path, model_checkpoint)
# Starts out true; presumably cleared by later validation (not visible here).
ckpt_valid = True
if os.path.exists(ckpt_path):
print(f"{ckpt_path} exists")
elif 'url' in model_map[model_checkpoint]:
url = model_map[model_checkpoint]['url']
# CLI dialogue to authenticate download
if model_map[model_checkpoint]['requires_login']:
print("This model requires an authentication token")
print("Please ensure you have accepted its terms of service before continuing.")
username = input("What is your huggingface username?:")
token = input("What is your huggingface token?:")
_, path = url.split("https://")
url = f"https://{username}:{token}@{path}"