Skip to content

Commit

Permalink
add support for sd3.5 model
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf committed Oct 29, 2024
1 parent 0e440c3 commit 4d70e16
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
3 changes: 3 additions & 0 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_2B;
}
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_8B;
}
Expand Down
1 change: 1 addition & 0 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ enum SDVersion {
VERSION_FLUX_DEV,
VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B,
VERSION_SD3_5_2B,
VERSION_COUNT,
};

Expand Down
23 changes: 12 additions & 11 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ const char* model_version_to_str[] = {
"SD3 2B",
"Flux Dev",
"Flux Schnell",
"SD3.5 8B"};
"SD3.5 8B",
"SD3.5 2B"};

const char* sampling_methods_str[] = {
"Euler A",
Expand Down Expand Up @@ -288,7 +289,7 @@ class StableDiffusionGGML {
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
}
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
scale_factor = 0.3611;
Expand All @@ -311,7 +312,7 @@ class StableDiffusionGGML {
} else {
clip_backend = backend;
bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
use_t5xxl = true;
}
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
Expand All @@ -322,7 +323,7 @@ class StableDiffusionGGML {
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
Expand Down Expand Up @@ -520,7 +521,7 @@ class StableDiffusionGGML {
is_using_v_parameterization = true;
}

if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
Expand Down Expand Up @@ -948,7 +949,7 @@ class StableDiffusionGGML {
if (use_tiny_autoencoder) {
C = 4;
} else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
C = 32;
Expand Down Expand Up @@ -1281,7 +1282,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16;
Expand Down Expand Up @@ -1394,7 +1395,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
Expand All @@ -1420,15 +1421,15 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);

int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
ggml_set_f32(init_latent, 0.1159f);
Expand Down Expand Up @@ -1489,7 +1490,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 2;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
Expand Down
2 changes: 1 addition & 1 deletion vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock {
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
dd_config.z_channels = 16;
use_quant = false;
}
Expand Down

0 comments on commit 4d70e16

Please sign in to comment.