Skip to content

Commit

Permalink
feat: add option to switch the sigma schedule (#51)
Browse files Browse the repository at this point in the history
Concretely, this allows switching to the "Karras" schedule from the
Karras et al 2022 paper, equivalent to the samplers marked as "Karras"
in the AUTOMATIC1111 WebUI. This choice is in principle orthogonal to
the sampler choice and can be given independently.
  • Loading branch information
ursg authored Sep 8, 2023
1 parent b6899e8 commit 968fbf0
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 37 deletions.
30 changes: 28 additions & 2 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ const char* sample_method_str[] = {
"dpm++2m",
"dpm++2mv2"};

// Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h
const char* schedule_str[] = {
"default",
"discrete",
"karras"};

struct Option {
int n_threads = -1;
std::string mode = TXT2IMG;
Expand All @@ -92,6 +98,7 @@ struct Option {
int w = 512;
int h = 512;
SampleMethod sample_method = EULER_A;
Schedule schedule = DEFAULT;
int sample_steps = 20;
float strength = 0.75f;
RNGType rng_type = CUDA_RNG;
Expand All @@ -111,6 +118,7 @@ struct Option {
printf(" width: %d\n", w);
printf(" height: %d\n", h);
printf(" sample_method: %s\n", sample_method_str[sample_method]);
printf(" schedule: %s\n", schedule_str[schedule]);
printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]);
Expand Down Expand Up @@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
printf(" -v, --verbose print extra info\n");
}

Expand Down Expand Up @@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
invalid_arg = true;
break;
}
} else if (arg == "--schedule") {
if (++i >= argc) {
invalid_arg = true;
break;
}
const char* schedule_selected = argv[i];
int schedule_found = -1;
for (int d = 0; d < N_SCHEDULES; d++) {
if (!strcmp(schedule_selected, schedule_str[d])) {
schedule_found = d;
}
}
if (schedule_found == -1) {
invalid_arg = true;
break;
}
opt->schedule = (Schedule)schedule_found;
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) {
}

StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
if (!sd.load_from_file(opt.model_path)) {
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
return 1;
}

Expand Down Expand Up @@ -413,4 +439,4 @@ int main(int argc, const char* argv[]) {
printf("save result image to '%s'\n", opt.output_path.c_str());

return 0;
}
}
115 changes: 81 additions & 34 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2654,32 +2654,12 @@ struct AutoEncoderKL {

// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py

struct DiscreteSchedule {
struct SigmaSchedule {
float alphas_cumprod[TIMESTEPS];
float sigmas[TIMESTEPS];
float log_sigmas[TIMESTEPS];

std::vector<float> get_sigmas(uint32_t n) {
std::vector<float> result;

int t_max = TIMESTEPS - 1;

if (n == 0) {
return result;
} else if (n == 1) {
result.push_back(t_to_sigma(t_max));
result.push_back(0);
return result;
}

float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
for (int i = 0; i < n; ++i) {
float t = t_max - step * i;
result.push_back(t_to_sigma(t));
}
result.push_back(0);
return result;
}
virtual std::vector<float> get_sigmas(uint32_t n) = 0;

float sigma_to_t(float sigma) {
float log_sigma = std::log(sigma);
Expand Down Expand Up @@ -2714,11 +2694,59 @@ struct DiscreteSchedule {
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
return std::exp(log_sigma);
}
};

struct DiscreteSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n) {
std::vector<float> result;

int t_max = TIMESTEPS - 1;

if (n == 0) {
return result;
} else if (n == 1) {
result.push_back(t_to_sigma(t_max));
result.push_back(0);
return result;
}

float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
for (int i = 0; i < n; ++i) {
float t = t_max - step * i;
result.push_back(t_to_sigma(t));
}
result.push_back(0);
return result;
}
};

struct KarrasSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n) {
// These *COULD* be function arguments here,
// but does anybody ever bother to touch them?
float sigma_min = 0.1;
float sigma_max = 10.;
float rho = 7.;

std::vector<float> result(n + 1);

float min_inv_rho = pow(sigma_min, (1. / rho));
float max_inv_rho = pow(sigma_max, (1. / rho));
for (int i = 0; i < n; i++) {
// Eq. (5) from Karras et al 2022
result[i] = pow(max_inv_rho + (float)i / ((float)n - 1.) * (min_inv_rho - max_inv_rho), rho);
}
result[n] = 0.;
return result;
}
};

struct Denoiser {
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
virtual std::vector<float> get_scalings(float sigma) = 0;
};

struct CompVisDenoiser : public DiscreteSchedule {
struct CompVisDenoiser : public Denoiser {
float sigma_data = 1.0f;

std::vector<float> get_scalings(float sigma) {
Expand All @@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule {
}
};

struct CompVisVDenoiser : public DiscreteSchedule {
struct CompVisVDenoiser : public Denoiser {
float sigma_data = 1.0f;

std::vector<float> get_scalings(float sigma) {
Expand Down Expand Up @@ -2764,7 +2792,7 @@ class StableDiffusionGGML {
UNetModel diffusion_model;
AutoEncoderKL first_stage_model;

std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();

StableDiffusionGGML() = default;

Expand Down Expand Up @@ -2798,7 +2826,7 @@ class StableDiffusionGGML {
}
}

bool load_from_file(const std::string& file_path) {
bool load_from_file(const std::string& file_path, Schedule schedule) {
LOG_INFO("loading model from '%s'", file_path.c_str());

std::ifstream file(file_path, std::ios::binary);
Expand Down Expand Up @@ -3093,10 +3121,29 @@ class StableDiffusionGGML {
LOG_INFO("running in eps-prediction mode");
}

if (schedule != DEFAULT) {
switch (schedule) {
case DISCRETE:
LOG_INFO("running with discrete schedule");
denoiser->schedule = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras schedule");
denoiser->schedule = std::make_shared<KarrasSchedule>();
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown schedule %i", schedule);
abort();
}
}

for (int i = 0; i < TIMESTEPS; i++) {
denoiser->alphas_cumprod[i] = alphas_cumprod[i];
denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]);
denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i];
denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]);
}

return true;
Expand Down Expand Up @@ -3445,7 +3492,7 @@ class StableDiffusionGGML {
c_in = scaling[1];
}

float t = denoiser->sigma_to_t(sigma);
float t = denoiser->schedule->sigma_to_t(sigma);
ggml_set_f32(timesteps, t);
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);

Expand Down Expand Up @@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads,
rng_type);
}

bool StableDiffusion::load_from_file(const std::string& file_path) {
return sd->load_from_file(file_path);
bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
return sd->load_from_file(file_path, s);
}

std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
Expand Down Expand Up @@ -4061,7 +4108,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
ggml_tensor_set_f32_randn(x_t, sd->rng);

std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);

LOG_INFO("start sampling");
struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
Expand Down Expand Up @@ -4117,7 +4164,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
}
LOG_INFO("img2img %dx%d", width, height);

std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
size_t t_enc = static_cast<size_t>(sample_steps * strength);
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
Expand Down
9 changes: 8 additions & 1 deletion stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ enum SampleMethod {
N_SAMPLE_METHODS
};

enum Schedule {
DEFAULT,
DISCRETE,
KARRAS,
N_SCHEDULES
};

class StableDiffusionGGML;

class StableDiffusion {
Expand All @@ -36,7 +43,7 @@ class StableDiffusion {
bool vae_decode_only = false,
bool free_params_immediately = false,
RNGType rng_type = STD_DEFAULT_RNG);
bool load_from_file(const std::string& file_path);
bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
std::vector<uint8_t> txt2img(
const std::string& prompt,
const std::string& negative_prompt,
Expand Down

0 comments on commit 968fbf0

Please sign in to comment.