diff --git a/examples/main.cpp b/examples/main.cpp index b2a1ddf3..6d68f5fe 100644 --- a/examples/main.cpp +++ b/examples/main.cpp @@ -80,6 +80,12 @@ const char* sample_method_str[] = { "dpm++2m", "dpm++2mv2"}; +// Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h +const char* schedule_str[] = { + "default", + "discrete", + "karras"}; + struct Option { int n_threads = -1; std::string mode = TXT2IMG; @@ -92,6 +98,7 @@ struct Option { int w = 512; int h = 512; SampleMethod sample_method = EULER_A; + Schedule schedule = DEFAULT; int sample_steps = 20; float strength = 0.75f; RNGType rng_type = CUDA_RNG; @@ -111,6 +118,7 @@ struct Option { printf(" width: %d\n", w); printf(" height: %d\n", h); printf(" sample_method: %s\n", sample_method_str[sample_method]); + printf(" schedule: %s\n", schedule_str[schedule]); printf(" sample_steps: %d\n", sample_steps); printf(" strength: %.2f\n", strength); printf(" rng: %s\n", rng_type_to_str[rng_type]); @@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); + printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n"); printf(" -v, --verbose print extra info\n"); } @@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) { invalid_arg = true; break; } + } else if (arg == "--schedule") { + if (++i >= argc) { + invalid_arg = true; + break; + } + const char* schedule_selected = argv[i]; + int schedule_found = -1; + for (int d = 0; d < N_SCHEDULES; d++) { + if (!strcmp(schedule_selected, schedule_str[d])) { + schedule_found = d; + } + } + if (schedule_found == -1) { + invalid_arg = true; + break; + } + opt->schedule = (Schedule)schedule_found; } else if (arg == "-s" || arg == "--seed") { if (++i >= argc) { invalid_arg = true; @@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) { } StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type); - if (!sd.load_from_file(opt.model_path)) { + if (!sd.load_from_file(opt.model_path, opt.schedule)) { return 1; } @@ -413,4 +439,4 @@ int main(int argc, const char* argv[]) { printf("save result image to '%s'\n", opt.output_path.c_str()); return 0; -} \ No newline at end of file +} diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 9263ce71..c39bdff3 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2654,32 +2654,12 @@ struct AutoEncoderKL { // Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py -struct DiscreteSchedule { +struct SigmaSchedule { float alphas_cumprod[TIMESTEPS]; float sigmas[TIMESTEPS]; float log_sigmas[TIMESTEPS]; - std::vector get_sigmas(uint32_t n) { - std::vector result; - - int t_max = TIMESTEPS - 1; - - if (n == 0) { - return result; - } else if (n == 1) { - result.push_back(t_to_sigma(t_max)); - result.push_back(0); - return result; - } - - float step = static_cast(t_max) / static_cast(n - 1); - for (int i = 0; i < n; ++i) { - float t = t_max - step * i; - result.push_back(t_to_sigma(t)); - } - result.push_back(0); - return result; - } + virtual std::vector get_sigmas(uint32_t n) = 0; float sigma_to_t(float sigma) { float log_sigma = std::log(sigma); @@ -2714,11 +2694,59 @@ struct DiscreteSchedule { float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]; return std::exp(log_sigma); } +}; +struct DiscreteSchedule : SigmaSchedule { + std::vector get_sigmas(uint32_t n) { + std::vector result; + + int t_max = TIMESTEPS - 1; + + if (n == 0) { + return result; + } else if (n == 1) { + result.push_back(t_to_sigma(t_max)); + result.push_back(0); + return result; + } + + float step = static_cast(t_max) / static_cast(n - 1); + for (int i = 0; i < n; ++i) { + float t = t_max - step * i; + result.push_back(t_to_sigma(t)); + } + result.push_back(0); + return result; + } +}; + +struct KarrasSchedule : SigmaSchedule { + std::vector get_sigmas(uint32_t n) { + // These *COULD* be function arguments here, + // but does anybody ever bother to touch them? + float sigma_min = 0.1; + float sigma_max = 10.; + float rho = 7.; + + std::vector result(n + 1); + + float min_inv_rho = pow(sigma_min, (1. / rho)); + float max_inv_rho = pow(sigma_max, (1. / rho)); + for (int i = 0; i < n; i++) { + // Eq. (5) from Karras et al 2022 + result[i] = pow(max_inv_rho + (float)i / ((float)n - 1.) * (min_inv_rho - max_inv_rho), rho); + } + result[n] = 0.; + return result; + } +}; + +struct Denoiser { + std::shared_ptr schedule = std::make_shared(); virtual std::vector get_scalings(float sigma) = 0; }; -struct CompVisDenoiser : public DiscreteSchedule { +struct CompVisDenoiser : public Denoiser { float sigma_data = 1.0f; std::vector get_scalings(float sigma) { @@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule { } }; -struct CompVisVDenoiser : public DiscreteSchedule { +struct CompVisVDenoiser : public Denoiser { float sigma_data = 1.0f; std::vector get_scalings(float sigma) { @@ -2764,7 +2792,7 @@ class StableDiffusionGGML { UNetModel diffusion_model; AutoEncoderKL first_stage_model; - std::shared_ptr denoiser = std::make_shared(); + std::shared_ptr denoiser = std::make_shared(); StableDiffusionGGML() = default; @@ -2798,7 +2826,7 @@ class StableDiffusionGGML { } } - bool load_from_file(const std::string& file_path) { + bool load_from_file(const std::string& file_path, Schedule schedule) { LOG_INFO("loading model from '%s'", file_path.c_str()); std::ifstream file(file_path, std::ios::binary); @@ -3093,10 +3121,29 @@ class StableDiffusionGGML { LOG_INFO("running in eps-prediction mode"); } + if (schedule != DEFAULT) { + switch (schedule) { + case DISCRETE: + LOG_INFO("running with discrete schedule"); + denoiser->schedule = std::make_shared(); + break; + case KARRAS: + LOG_INFO("running with Karras schedule"); + denoiser->schedule = std::make_shared(); + break; + case DEFAULT: + // Don't touch anything. + break; + default: + LOG_ERROR("Unknown schedule %i", schedule); + abort(); + } + } + for (int i = 0; i < TIMESTEPS; i++) { - denoiser->alphas_cumprod[i] = alphas_cumprod[i]; - denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]); - denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]); + denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i]; + denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]); + denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]); } return true; @@ -3445,7 +3492,7 @@ class StableDiffusionGGML { c_in = scaling[1]; } - float t = denoiser->sigma_to_t(sigma); + float t = denoiser->schedule->sigma_to_t(sigma); ggml_set_f32(timesteps, t); set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels); @@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads, rng_type); } -bool StableDiffusion::load_from_file(const std::string& file_path) { - return sd->load_from_file(file_path); +bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) { + return sd->load_from_file(file_path, s); } std::vector StableDiffusion::txt2img(const std::string& prompt, @@ -4061,7 +4108,7 @@ std::vector StableDiffusion::txt2img(const std::string& prompt, struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1); ggml_tensor_set_f32_randn(x_t, sd->rng); - std::vector sigmas = sd->denoiser->get_sigmas(sample_steps); + std::vector sigmas = sd->denoiser->schedule->get_sigmas(sample_steps); LOG_INFO("start sampling"); struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas); @@ -4117,7 +4164,7 @@ std::vector StableDiffusion::img2img(const std::vector& init_i } LOG_INFO("img2img %dx%d", width, height); - std::vector sigmas = sd->denoiser->get_sigmas(sample_steps); + std::vector sigmas = sd->denoiser->schedule->get_sigmas(sample_steps); size_t t_enc = static_cast(sample_steps * strength); LOG_INFO("target t_enc is %zu steps", t_enc); std::vector sigma_sched; diff --git a/stable-diffusion.h b/stable-diffusion.h index 81186338..b0706180 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -25,6 +25,13 @@ enum SampleMethod { N_SAMPLE_METHODS }; +enum Schedule { + DEFAULT, + DISCRETE, + KARRAS, + N_SCHEDULES +}; + class StableDiffusionGGML; class StableDiffusion { @@ -36,7 +43,7 @@ class StableDiffusion { bool vae_decode_only = false, bool free_params_immediately = false, RNGType rng_type = STD_DEFAULT_RNG); - bool load_from_file(const std::string& file_path); + bool load_from_file(const std::string& file_path, Schedule d = DEFAULT); std::vector txt2img( const std::string& prompt, const std::string& negative_prompt,