# hpu_config_web_dataset.yaml
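#
# Training configuration for Stable Diffusion v1 (LatentDiffusion) on Habana
# Gaudi (HPU) accelerators, streaming LAION image/text pairs from WebDataset
# tar shards. The three top-level sections follow the standard CompVis
# latent-diffusion layout: model, data, and lightning.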
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 32 # latent resolution (256-px images / f=8 VAE); 64 would correspond to 512-px training
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_fused_adamw: True
    use_ema: False
    use_autocast: False
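
    # LambdaLinearScheduler ramps the LR multiplier linearly from f_start to
    # f_max over warm_up_steps; with f_max == f_min == 1.0 the rate then stays
    # at base_learning_rate for the (effectively infinite) cycle length.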
    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        #warm_up_steps: [ 1 ] # NOTE: for resuming; use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]
        # f_max: [ 1.e-4 ]
        # f_min: [ 1.e-10 ]
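
    # Standard SD v1 denoising UNet: 4-channel latents in and out, 320 base
    # channels, cross-attention at the 4x/2x/1x downsampling levels.
    # context_dim: 768 matches the embedding width of the CLIP text encoder
    # configured as the cond_stage below.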
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        # from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False
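
    # Frozen kl-f8 autoencoder (VAE): 256-px RGB images encode to 32x32x4
    # latents. lossconfig is torch.nn.Identity because the VAE is restored
    # from ckpt_path and not trained here.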
    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        # from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
        monitor: val/rec_loss
        # first_stage_models/kl-f8/model.ckpt
        ckpt_path: "/software/lfs/data/pytorch/stable-diffusion/checkpoint/model.ckpt"
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
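
    # Frozen CLIP text encoder, placed on the HPU device; it stays untrained
    # per cond_stage_trainable: false above.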
    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        device: "hpu"
        # use_fp16: True
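
# Data pipeline: WebDataModuleFromConfig streams samples directly from .tar
# shards (the brace ranges below use webdataset's shard notation), so no
# per-file index is needed. shuffle is the size of webdataset's in-memory
# shuffle buffer.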
data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "/software/lfs/data/pytorch/stable-diffusion/laion2B-data"
    batch_size: 8
    #wrap: True
    #shuffle: 10000
    #min_size: 256
    num_workers: 4
    multinode: True
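    # Shard split: shards 000256..111463 train the model; the first 256
    # shards (000000..000255) are held out for validation.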
    train:
      shards: '{000256..111463}.tar' #231349}.tar' #'{000000..231317}.tar -'
      shuffle: 10000
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 256
    validation:
      shards: '{000000..000255}.tar' #'{231318..231349}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3
        - target: torchvision.transforms.CenterCrop
          params:
            size: 256
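
# Trainer settings. With batch_size 8 and accumulate_grad_batches 16, each
# device takes an optimizer step on an effective batch of 8 * 16 = 128
# samples (times the number of HPUs when running multinode).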
lightning:
  trainer:
    val_check_interval: 10240
    num_sanity_val_steps: 10
    benchmark: True
    accumulate_grad_batches: 16
    max_epochs: 1
    max_steps: 237000
    limit_val_batches: 2048
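
  # ImageLogger samples the model every 50000 batches and logs grids of up
  # to 10 images, applying classifier-free guidance (scale 3.0) against the
  # empty-prompt unconditional label.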
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 50000
        max_images: 10
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: True
          plot_progressive_rows: False
          plot_diffusion_rows: True
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]
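
  # Assumed to control how often the progress output refreshes (every 100
  # steps); this key appears to be consumed by the training script rather
  # than by PyTorch Lightning itself.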
  print_freq:
    refresh_rate: 100