feat: add sd3.5 support (leejet#445)
leejet authored Oct 24, 2024
1 parent 14206fd commit ac54e00
Showing 13 changed files with 250 additions and 127 deletions.
22 changes: 12 additions & 10 deletions README.md
@@ -10,7 +10,7 @@ Inference of Stable Diffusion and Flux in pure C/C++

- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- Super lightweight and without external dependencies
- - SD1.x, SD2.x, SDXL and SD3 support
+ - SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support
- !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
- [Flux-dev/Flux-schnell Support](./docs/flux.md)

@@ -197,23 +197,24 @@ usage: ./bin/sd [arguments]
arguments:
-h, --help show this help message and exit
-M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)
- -t, --threads N number of threads to use during computation (default: -1).
+ -t, --threads N number of threads to use during computation (default: -1)
If threads <= 0, then threads will be set to the number of CPU physical cores
-m, --model [MODEL] path to full model
--diffusion-model path to the standalone diffusion model
--clip_l path to the clip-l text encoder
- --t5xxl path to the the t5xxl text encoder.
+ --clip_g path to the clip-g text encoder
+ --t5xxl path to the t5xxl text encoder
--vae [VAE] path to vae
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
--control-net [CONTROL_PATH] path to control net model
- --embd-dir [EMBEDDING_PATH] path to embeddings.
- --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.
- --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.
+ --embd-dir [EMBEDDING_PATH] path to embeddings
+ --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings
+ --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir
--normalize-input normalize PHOTOMAKER input id images
- --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
+ --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
- If not specified, the default is the type of the weight file.
+ If not specified, the default is the type of the weight file
--lora-model-dir [DIR] lora model directory
-i, --init-img [IMAGE] path to the input image, required by img2img
--control-image [IMAGE] path to image condition, control net
@@ -232,13 +233,13 @@ arguments:
--steps STEPS number of sample steps (default: 20)
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
- -b, --batch-count COUNT number of images to generate.
+ -b, --batch-count COUNT number of images to generate
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
--vae-tiling process vae in tiles to reduce memory usage
--vae-on-cpu keep vae in cpu (for low vram)
- --clip-on-cpu keep clip in cpu (for low vram).
+ --clip-on-cpu keep clip in cpu (for low vram)
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color Colors the logging tags according to level
@@ -253,6 +254,7 @@ arguments:
# ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v
# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v
# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v
+ # ./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
```

Using formats of different precisions will yield results of varying quality.
Binary file added assets/sd3.5_large.png
4 changes: 2 additions & 2 deletions conditioner.hpp
@@ -1001,8 +1001,8 @@ struct FluxCLIPEmbedder : public Conditioner {
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
- clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model");
- t5->get_param_tensors(tensors, "text_encoders.t5xxl");
+ clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+ t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
}

void alloc_params_buffer() {
79 changes: 38 additions & 41 deletions denoiser.hpp
@@ -49,7 +49,7 @@ struct ExponentialSchedule : SigmaSchedule {
// Calculate step size
float log_sigma_min = std::log(sigma_min);
float log_sigma_max = std::log(sigma_max);
- float step = (log_sigma_max - log_sigma_min) / (n - 1);
+ float step = (log_sigma_max - log_sigma_min) / (n - 1);

// Fill sigmas with exponential values
for (uint32_t i = 0; i < n; ++i) {
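
Reviewer note on the `ExponentialSchedule` hunk above: the step it computes spaces the sigmas evenly in log-space. The snippet below is a minimal sketch of that idea, not the project's code. `exponential_sigmas` is a hypothetical helper, and both the descending order (from `sigma_max` down to `sigma_min`) and the loop body are assumptions, because the hunk ends before the loop body is shown.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the project): returns n sigmas spaced
// evenly in log-space, descending from sigma_max to sigma_min.
std::vector<float> exponential_sigmas(uint32_t n, float sigma_min, float sigma_max) {
    assert(sigma_min > 0.0f && sigma_max >= sigma_min);
    std::vector<float> sigmas;
    if (n == 0) {
        return sigmas;
    }
    if (n == 1) {
        // A single step has no spacing; this also avoids dividing by zero.
        sigmas.push_back(sigma_max);
        return sigmas;
    }

    // Same step size as in the hunk above.
    float log_sigma_min = std::log(sigma_min);
    float log_sigma_max = std::log(sigma_max);
    float step          = (log_sigma_max - log_sigma_min) / (n - 1);

    sigmas.reserve(n);
    for (uint32_t i = 0; i < n; ++i) {
        // Assumed loop body: walk down from log(sigma_max) in equal log steps.
        sigmas.push_back(std::exp(log_sigma_max - step * i));
    }
    return sigmas;
}
```

With equal steps in log-space, the ratio between consecutive sigmas is constant, which is what makes the schedule "exponential".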
@@ -205,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {

/*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
- */
+ */
struct GITSSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
if (sigma_max <= 0.0f) {
@@ -221,7 +221,7 @@ struct GITSSchedule : SigmaSchedule {
// Calculate the index based on the coefficient
int index = static_cast<int>((coeff - 0.80f) / 0.05f);
// Ensure the index is within bounds
- index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
+ index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
const std::vector<std::vector<float>>& selected_noise = *GITS_NOISE[index];

if (n <= 20) {
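
Reviewer note on the `GITSSchedule` hunk above: the coefficient is turned into a table row by subtracting 0.80, dividing by 0.05, truncating, and clamping. The sketch below isolates that lookup so it can be exercised on its own. `gits_table_index` and its `table_size` parameter are hypothetical stand-ins for the project's `GITS_NOISE` table, whose actual size is not visible in this diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical helper: map a coefficient to a row of a lookup table whose
// rows start at 0.80 and advance in steps of 0.05. Out-of-range input is
// clamped to the first or last row, as in the hunk above.
int gits_table_index(float coeff, std::size_t table_size) {
    assert(table_size > 0);
    // Truncation toward zero picks the row at or below the coefficient.
    int index = static_cast<int>((coeff - 0.80f) / 0.05f);
    // Keep the index within [0, table_size - 1].
    return std::max(0, std::min(index, static_cast<int>(table_size - 1)));
}
```

Because the division is done in single precision and then truncated, a coefficient that sits exactly on a row boundary may resolve to the row below it; values between boundaries are unaffected.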
@@ -823,24 +823,24 @@ static void sample_k_diffusion(sample_method_t method,
} break;
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{
- int max_order = 4;
+ int max_order = 4;
ggml_tensor* x_next = x;
std::vector<ggml_tensor*> buffer_model;

for (int i = 0; i < steps; i++) {
- float sigma = sigmas[i];
+ float sigma = sigmas[i];
float sigma_next = sigmas[i + 1];

ggml_tensor* x_cur = x_next;
- float* vec_x_cur = (float*)x_cur->data;
- float* vec_x_next = (float*)x_next->data;
+ float* vec_x_cur = (float*)x_cur->data;
+ float* vec_x_next = (float*)x_next->data;

// Denoising step
ggml_tensor* denoised = model(x_cur, sigma, i + 1);
- float* vec_denoised = (float*)denoised->data;
+ float* vec_denoised = (float*)denoised->data;
// d_cur = (x_cur - denoised) / sigma
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur);
- float* vec_d_cur = (float*)d_cur->data;
+ float* vec_d_cur = (float*)d_cur->data;

for (int j = 0; j < ggml_nelements(d_cur); j++) {
vec_d_cur[j] = (vec_x_cur[j] - vec_denoised[j]) / sigma;
@@ -857,34 +857,31 @@ static void sample_k_diffusion(sample_method_t method,
break;

case 2: // Use one history point
- {
- float* vec_d_prev1 = (float*)buffer_model.back()->data;
- for (int j = 0; j < ggml_nelements(x_next); j++) {
- vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
- }
+ {
+ float* vec_d_prev1 = (float*)buffer_model.back()->data;
+ for (int j = 0; j < ggml_nelements(x_next); j++) {
+ vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2;
}
- break;
+ } break;

case 3: // Use two history points
- {
- float* vec_d_prev1 = (float*)buffer_model.back()->data;
- float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
- for (int j = 0; j < ggml_nelements(x_next); j++) {
- vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
- }
+ {
+ float* vec_d_prev1 = (float*)buffer_model.back()->data;
+ float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
+ for (int j = 0; j < ggml_nelements(x_next); j++) {
+ vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12;
}
- break;
+ } break;

case 4: // Use three history points
- {
- float* vec_d_prev1 = (float*)buffer_model.back()->data;
- float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
- float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
- for (int j = 0; j < ggml_nelements(x_next); j++) {
- vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
- }
+ {
+ float* vec_d_prev1 = (float*)buffer_model.back()->data;
+ float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data;
+ float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data;
+ for (int j = 0; j < ggml_nelements(x_next); j++) {
+ vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24;
}
- break;
+ } break;
}

// Manage buffer_model
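
Reviewer note on the iPNDM hunk above: the weights in cases 2 to 4 (3, -1 over 2; 23, -16, 5 over 12; 55, -59, 37, -9 over 24) are the classical Adams-Bashforth multistep coefficients. The following is a minimal scalar sketch of that update, assuming the order is limited by how many earlier derivatives are available. `ab_step` and its `history` argument are hypothetical names introduced for illustration; the project operates on whole tensors and keeps its history in `buffer_model`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scalar version of the multistep update shown in the hunk.
// `history` holds earlier derivatives, oldest first; the newest is at the back.
float ab_step(float x_cur, float d_cur, float sigma, float sigma_next,
              const std::vector<float>& history, int max_order = 4) {
    assert(max_order >= 1 && max_order <= 4);
    // Assumption: the usable order grows with the number of stored derivatives.
    int order = std::min(max_order, static_cast<int>(history.size()) + 1);
    float h = sigma_next - sigma;  // step in sigma (negative while denoising)
    std::size_t n = history.size();

    switch (order) {
        case 1:  // plain Euler step, no history needed
            return x_cur + h * d_cur;
        case 2:  // one history point
            return x_cur + h * (3 * d_cur - history[n - 1]) / 2;
        case 3:  // two history points
            return x_cur + h * (23 * d_cur - 16 * history[n - 1] + 5 * history[n - 2]) / 12;
        default:  // order 4: three history points
            return x_cur + h * (55 * d_cur - 59 * history[n - 1] + 37 * history[n - 2] -
                                9 * history[n - 3]) / 24;
    }
}
```

A quick sanity check on the coefficients: each set sums to its divisor, so when the derivative is constant every order collapses to the Euler step `x_cur + h * d_cur`.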
@@ -906,27 +903,27 @@ static void sample_k_diffusion(sample_method_t method,
ggml_tensor* x_next = x;

for (int i = 0; i < steps; i++) {
- float sigma = sigmas[i];
+ float sigma = sigmas[i];
float t_next = sigmas[i + 1];

// Denoising step
- ggml_tensor* denoised = model(x, sigma, i + 1);
- float* vec_denoised = (float*)denoised->data;
+ ggml_tensor* denoised = model(x, sigma, i + 1);
+ float* vec_denoised = (float*)denoised->data;
struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x);
- float* vec_d_cur = (float*)d_cur->data;
- float* vec_x = (float*)x->data;
+ float* vec_d_cur = (float*)d_cur->data;
+ float* vec_x = (float*)x->data;

// d_cur = (x - denoised) / sigma
for (int j = 0; j < ggml_nelements(d_cur); j++) {
vec_d_cur[j] = (vec_x[j] - vec_denoised[j]) / sigma;
}

- int order = std::min(max_order, i + 1);
- float h_n = t_next - sigma;
+ int order = std::min(max_order, i + 1);
+ float h_n = t_next - sigma;
float h_n_1 = (i > 0) ? (sigma - sigmas[i - 1]) : h_n;

switch (order) {
- case 1: // First Euler step
+ case 1: // First Euler step
for (int j = 0; j < ggml_nelements(x_next); j++) {
vec_x[j] += vec_d_cur[j] * h_n;
}
@@ -941,7 +938,7 @@ static void sample_k_diffusion(sample_method_t method,
}

case 3: {
- float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
+ float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
for (int j = 0; j < ggml_nelements(x_next); j++) {
@@ -951,8 +948,8 @@ static void sample_k_diffusion(sample_method_t method,
}

case 4: {
- float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
- float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
+ float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1;
+ float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2;
float* vec_d_prev1 = (float*)buffer_model.back()->data;
float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1;
float* vec_d_prev3 = (buffer_model.size() > 2) ? (float*)buffer_model[buffer_model.size() - 3]->data : vec_d_prev2;
20 changes: 20 additions & 0 deletions docs/sd3.md
@@ -0,0 +1,20 @@
# How to Use

## Download weights

- Download sd3.5_large from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
- Download clip_g from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_g.safetensors
- Download clip_l from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_l.safetensors
- Download t5xxl from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/t5xxl_fp16.safetensors


## Run

### SD3.5 Large
For example:

```
.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v
```

![](../assets/sd3.5_large.png)