feat: Control Net support + Textual Inversion (embeddings) (#131)

* add controlnet to pipeline

* add cli params

* control strength cli param

* cli param to keep controlnet on cpu

* add Textual Inversion

* add canny preprocessor

* refactor: change ggml_type_sizef to ggml_row_size

* process the hint only once

* ignore case in embedding names

---------

Co-authored-by: leejet <[email protected]>
FSSRepo and leejet authored Jan 29, 2024
1 parent c6071fa commit 36ec16a
Showing 20 changed files with 1,823 additions and 589 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -10,5 +10,4 @@ test/
 *.gguf
 output*.png
 models*
-!taesd-model.gguf
 *.log
4 changes: 0 additions & 4 deletions CMakeLists.txt
@@ -28,17 +28,13 @@ option(SD_CUBLAS "sd: cuda backend" OFF)
 option(SD_HIPBLAS "sd: rocm backend" OFF)
 option(SD_METAL "sd: metal backend" OFF)
 option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
-option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)
 
 if(SD_CUBLAS)
     message("Use CUBLAS as backend stable-diffusion")
     set(GGML_CUBLAS ON)
     add_definitions(-DSD_USE_CUBLAS)
-    if(SD_FAST_SOFTMAX)
-        set(GGML_CUDA_FAST_SOFTMAX ON)
-    endif()
 endif()
 
 if(SD_METAL)
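Note: with SD_FAST_SOFTMAX removed, a CUDA build needs only the remaining SD_CUBLAS switch. A minimal configure/build sketch (standard CMake workflow; directory names assumed, not part of this commit):

```
mkdir build && cd build
cmake .. -DSD_CUBLAS=ON
cmake --build . --config Release
```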
8 changes: 6 additions & 2 deletions README.md
@@ -31,6 +31,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd)
 - Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN)
 - VAE tiling processing for reduce memory usage
+- Control Net support with SD 1.5
 - Sampling method
     - `Euler A`
     - `Euler`
@@ -53,9 +54,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - [ ] More sampling methods
 - [ ] Make inference faster
     - The current implementation of ggml_conv_2d is slow and has high memory usage
-    - Implement Winograd Convolution 2D for 3x3 kernel filtering
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] Implement Textual Inversion (embeddings)
 - [ ] Implement Inpainting support
 - [ ] k-quants support
 
@@ -159,16 +158,20 @@ arguments:
 -m, --model [MODEL] path to model
 --vae [VAE] path to vae
 --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
+--control-net [CONTROL_PATH] path to control net model
+--embd-dir [EMBEDDING_PATH] path to embeddings.
 --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
 --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
 If not specified, the default is the type of the weight file.
 --lora-model-dir [DIR] lora model directory
 -i, --init-img [IMAGE] path to the input image, required by img2img
+--control-image [IMAGE] path to image condition, control net
 -o, --output OUTPUT path to write result image to (default: ./output.png)
 -p, --prompt [PROMPT] the prompt to render
 -n, --negative-prompt PROMPT the negative prompt (default: "")
 --cfg-scale SCALE unconditional guidance scale: (default: 7.0)
 --strength STRENGTH strength for noising/unnoising (default: 0.75)
+--control-strength STRENGTH strength to apply Control Net (default: 0.9)
 1.0 corresponds to full destruction of information in init image
 -H, --height H image height, in pixel space (default: 512)
 -W, --width W image width, in pixel space (default: 512)
@@ -182,6 +185,7 @@ arguments:
 --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
 <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
 --vae-tiling process vae in tiles to reduce memory usage
+--control-net-cpu keep controlnet in cpu (for low vram)
 -v, --verbose print extra info
 ```
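Note: the new options combine like this. A hypothetical invocation (model, control net, and image paths are placeholders, not taken from this commit):

```
./bin/sd -m ../models/v1-5-pruned-emaonly.safetensors \
  -p "a person dancing, best quality" \
  --control-net ../models/control_sd15_canny.safetensors \
  --control-image ./canny_edges.png --control-strength 0.9 \
  --embd-dir ../models/embeddings --control-net-cpu
```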
Binary file added assets/control.png
Binary file added assets/control_2.png
Binary file added assets/control_3.png
149 changes: 124 additions & 25 deletions clip.hpp
@@ -2,6 +2,7 @@
 #define __CLIP_HPP__
 
 #include "ggml_extend.hpp"
+#include "model.h"
 
 /*================================================== CLIPTokenizer ===================================================*/
 
@@ -67,6 +68,9 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
 }
 
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
+
+typedef std::function<bool(std::string&, std::vector<int32_t>&)> on_new_token_cb_t;
+
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
@@ -234,8 +238,11 @@ class CLIPTokenizer {
         return result;
     }
 
-    std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
-        std::vector<int32_t> tokens = encode(text);
+    std::vector<int> tokenize(std::string text,
+                              on_new_token_cb_t on_new_token_cb,
+                              size_t max_length = 0,
+                              bool padding = false) {
+        std::vector<int32_t> tokens = encode(text, on_new_token_cb);
         tokens.insert(tokens.begin(), BOS_TOKEN_ID);
         if (max_length > 0) {
             if (tokens.size() > max_length - 1) {
@@ -255,7 +262,7 @@
         return tokens;
     }
 
-    std::vector<int> encode(std::string text) {
+    std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
         std::string original_text = text;
         std::vector<int32_t> bpe_tokens;
         text = whitespace_clean(text);
@@ -268,6 +275,10 @@
         std::string str = text;
         std::vector<std::string> token_strs;
         while (std::regex_search(str, matches, pat)) {
+            bool skip = on_new_token_cb(str, bpe_tokens);
+            if (skip) {
+                continue;
+            }
             for (auto& token : matches) {
                 std::string token_str = token.str();
                 std::u32string utf32_token;
@@ -444,13 +455,13 @@ struct ResidualAttentionBlock {
     struct ggml_tensor* ln2_b;  // [hidden_size, ]
 
     size_t calculate_mem_size(ggml_type wtype) {
-        double mem_size = 0;
-        mem_size += 4 * hidden_size * hidden_size * ggml_type_sizef(wtype);        // q_w/k_w/v_w/out_w
-        mem_size += 8 * hidden_size * ggml_type_sizef(GGML_TYPE_F32);              // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b
-        mem_size += 2 * hidden_size * intermediate_size * ggml_type_sizef(wtype);  // fc1_w/fc2_w
-        mem_size += intermediate_size * ggml_type_sizef(GGML_TYPE_F32);            // fc1_b
-        mem_size += hidden_size * ggml_type_sizef(GGML_TYPE_F32);                  // fc2_b
-        return static_cast<size_t>(mem_size);
+        size_t mem_size = 0;
+        mem_size += 4 * ggml_row_size(wtype, hidden_size * hidden_size);        // q_w/k_w/v_w/out_w
+        mem_size += 8 * ggml_row_size(GGML_TYPE_F32, hidden_size);              // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b
+        mem_size += 2 * ggml_row_size(wtype, hidden_size * intermediate_size);  // fc1_w/fc2_w
+        mem_size += ggml_row_size(GGML_TYPE_F32, intermediate_size);            // fc1_b
+        mem_size += ggml_row_size(GGML_TYPE_F32, hidden_size);                  // fc2_b
+        return mem_size;
     }
 
     void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
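Note on the ggml_type_sizef to ggml_row_size refactor above: ggml_type_sizef returned fractional bytes per element as a double, so callers accumulated sizes in floating point and cast back; ggml_row_size returns the exact integer byte count for n elements of a possibly block-quantized type, counted in whole blocks. A self-contained sketch (the ggml calls are real; the printed values are worked examples):

```
#include "ggml.h"
#include <cstdio>

int main() {
    // q4_0 packs 32 elements into an 18-byte block, so 768 elements
    // occupy 768 / 32 * 18 = 432 bytes; f32 is simply 768 * 4 = 3072.
    printf("%zu\n", ggml_row_size(GGML_TYPE_Q4_0, 768));  // 432
    printf("%zu\n", ggml_row_size(GGML_TYPE_F32, 768));   // 3072
    return 0;
}
```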
@@ -597,13 +608,17 @@ struct CLIPTextModel {
     struct ggml_tensor* position_ids;
     struct ggml_tensor* token_embed_weight;
     struct ggml_tensor* position_embed_weight;
+    struct ggml_tensor* token_embed_custom;
 
     // transformer
     std::vector<ResidualAttentionBlock> resblocks;
     struct ggml_tensor* final_ln_w;
     struct ggml_tensor* final_ln_b;
 
     struct ggml_tensor* text_projection;
+    std::string embd_dir;
+    int32_t num_custom_embeddings = 0;
+    std::vector<std::string> readed_embeddings;
 
     CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                   int clip_skip = -1,
@@ -642,18 +657,21 @@
     }
 
     size_t calculate_mem_size(ggml_type wtype) {
-        double mem_size = 0;
-        mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32);  // position_ids
-        mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype);                       // token_embed_weight
-        mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype);          // position_embed_weight
+        size_t mem_size = 0;
+        mem_size += ggml_row_size(GGML_TYPE_I32, hidden_size * max_position_embeddings);  // position_ids
+        mem_size += ggml_row_size(wtype, hidden_size * vocab_size);                       // token_embed_weight
+        mem_size += ggml_row_size(wtype, hidden_size * max_position_embeddings);          // position_embed_weight
+        if(version == OPENAI_CLIP_VIT_L_14) {
+            mem_size += ggml_row_size(wtype, hidden_size * max_position_embeddings);  // token_embed_custom
+        }
         for (int i = 0; i < num_hidden_layers; i++) {
             mem_size += resblocks[i].calculate_mem_size(wtype);
         }
-        mem_size += 2 * hidden_size * ggml_type_sizef(GGML_TYPE_F32);  // final_ln_w/b
+        mem_size += 2 * ggml_row_size(GGML_TYPE_F32, hidden_size);  // final_ln_w/b
         if (version == OPEN_CLIP_VIT_BIGG_14) {
-            mem_size += hidden_size * projection_dim * ggml_type_sizef(GGML_TYPE_F32);  // text_projection
+            mem_size += ggml_row_size(GGML_TYPE_F32, hidden_size * projection_dim);  // text_projection
         }
-        return static_cast<size_t>(mem_size);
+        return mem_size;
     }
 
     void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -670,14 +688,48 @@
         }
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, size_t max_token_idx = 0, bool return_pooled = false) {
+    bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t> &bpe_tokens) {
+        // the order matters
+        ModelLoader model_loader;
+        if(!model_loader.init_from_file(embd_path)) {
+            LOG_ERROR("embedding '%s' failed", embd_name.c_str());
+            return false;
+        }
+        struct ggml_init_params params;
+        params.mem_size   = 32 * 1024;  // max for custom embeddings 32 KB
+        params.mem_buffer = NULL;
+        params.no_alloc   = false;
+        struct ggml_context* embd_ctx = ggml_init(params);
+        struct ggml_tensor* embd      = NULL;
+        auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
+            if(tensor_storage.ne[0] != hidden_size) {
+                LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
+                return false;
+            }
+            embd        = ggml_new_tensor_2d(embd_ctx, token_embed_weight->type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+            *dst_tensor = embd;
+            return true;
+        };
+        model_loader.load_tensors(on_load, NULL);
+        ggml_backend_tensor_set(token_embed_custom, embd->data, num_custom_embeddings * hidden_size * ggml_type_size(token_embed_custom->type), ggml_nbytes(embd));
+        readed_embeddings.push_back(embd_name);
+        for(int i = 0; i < embd->ne[1]; i++) {
+            bpe_tokens.push_back(vocab_size + num_custom_embeddings);
+            // LOG_DEBUG("new custom token: %i", vocab_size + num_custom_embeddings);
+            num_custom_embeddings++;
+        }
+        LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
+        return true;
+    }
+
+    struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* tkn_embeddings, uint32_t max_token_idx = 0, bool return_pooled = false) {
         // input_ids: [N, n_token]
         GGML_ASSERT(input_ids->ne[0] <= position_ids->ne[0]);
 
         // token_embedding + position_embedding
         struct ggml_tensor* x;
         x = ggml_add(ctx0,
-                     ggml_get_rows(ctx0, token_embed_weight, input_ids),
+                     ggml_get_rows(ctx0, tkn_embeddings == NULL ? token_embed_weight : tkn_embeddings, input_ids),
                      ggml_get_rows(ctx0,
                                    position_embed_weight,
                                    ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0)));  // [N, n_token, hidden_size]
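Note: each loaded embedding occupies embd->ne[1] consecutive rows of token_embed_custom, and each row is addressed by a fresh token id starting at vocab_size (49408 for the CLIP tokenizer used by SD 1.x). A hedged sketch of just that id assignment (illustrative helper, not from this commit):

```
#include <cstdint>
#include <vector>

// mirrors the loop at the end of load_embedding: one new token id per vector
static void assign_custom_token_ids(int32_t vocab_size,
                                    int32_t& num_custom_embeddings,
                                    int64_t n_vectors,
                                    std::vector<int32_t>& bpe_tokens) {
    for (int64_t i = 0; i < n_vectors; i++) {
        bpe_tokens.push_back(vocab_size + num_custom_embeddings++);
    }
}
// a 4-vector embedding loaded first claims ids 49408..49411; a 2-vector one
// loaded next claims 49412..49413, which is why "the order matters" above.
```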
@@ -723,6 +775,10 @@
 
         final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
 
+        if(version == OPENAI_CLIP_VIT_L_14) {
+            token_embed_custom = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings);
+        }
+
         if (version == OPEN_CLIP_VIT_BIGG_14) {
             text_projection = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
         }
@@ -805,11 +861,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
         }
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, size_t max_token_idx = 0, bool return_pooled = false) {
+    struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, struct ggml_tensor* embeddings, size_t max_token_idx = 0, bool return_pooled = false) {
         if (return_pooled) {
-            return text_model2.forward(ctx0, input_ids2, max_token_idx, return_pooled);
+            return text_model2.forward(ctx0, input_ids2, NULL, max_token_idx, return_pooled);
         }
-        auto hidden_states = text_model.forward(ctx0, input_ids);  // [N, n_token, hidden_size]
+        auto hidden_states = text_model.forward(ctx0, input_ids, embeddings);  // [N, n_token, hidden_size]
         // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
         if (version == VERSION_XL) {
             hidden_states = ggml_reshape_4d(ctx0,
@@ -820,7 +876,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
                                             hidden_states->ne[3]);
             hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 2, 0, 1, 3));
 
-            auto hidden_states2 = text_model2.forward(ctx0, input_ids2);  // [N, n_token, hidden_size2]
+            auto hidden_states2 = text_model2.forward(ctx0, input_ids2, NULL);  // [N, n_token, hidden_size2]
             hidden_states2 = ggml_reshape_4d(ctx0,
                                              hidden_states2,
                                              hidden_states2->ne[0],
@@ -857,12 +913,36 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
             LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
         }
 
+        auto on_new_token_cb = [&] (std::string& str, std::vector<int32_t> &bpe_tokens) -> bool {
+            size_t word_end       = str.find(",");
+            std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
+            embd_name             = trim(embd_name);
+            std::string embd_path = get_full_path(text_model.embd_dir, embd_name + ".pt");
+            if(embd_path.size() == 0) {
+                embd_path = get_full_path(text_model.embd_dir, embd_name + ".ckpt");
+            }
+            if(embd_path.size() == 0) {
+                embd_path = get_full_path(text_model.embd_dir, embd_name + ".safetensors");
+            }
+            if(embd_path.size() > 0) {
+                if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) {
+                    if(word_end != std::string::npos) {
+                        str = str.substr(word_end);
+                    } else {
+                        str = "";
+                    }
+                    return true;
+                }
+            }
+            return false;
+        };
+
         std::vector<int> tokens;
         std::vector<float> weights;
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight            = item.second;
-            std::vector<int> curr_tokens = tokenizer.encode(curr_text);
+            std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
             tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
             weights.insert(weights.end(), curr_tokens.size(), curr_weight);
         }
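Note: the callback treats everything up to the first comma of the not-yet-tokenized remainder as a candidate embedding name, probes embd_dir for .pt, .ckpt, then .safetensors, and on success consumes the name from the string so normal BPE never sees it. A standalone sketch of the name extraction (hypothetical helper, not part of the commit):

```
#include <string>

// everything before the first comma, whitespace-trimmed, is the candidate
// embedding name, e.g. "my-style, highly detailed" -> "my-style"
static std::string leading_embedding_name(const std::string& str) {
    size_t word_end  = str.find(',');
    std::string name = (word_end == std::string::npos) ? str : str.substr(0, word_end);
    size_t b = name.find_first_not_of(" \t");
    if (b == std::string::npos) {
        return "";
    }
    size_t e = name.find_last_not_of(" \t");
    return name.substr(b, e - b + 1);
}
```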
@@ -951,7 +1031,26 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
             }
         }
 
-        struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, max_token_idx, return_pooled);
+        struct ggml_tensor* embeddings = NULL;
+
+        if(text_model.num_custom_embeddings > 0 && version != VERSION_XL) {
+            embeddings = ggml_new_tensor_2d(ctx0, wtype, text_model.hidden_size, text_model.vocab_size + text_model.num_custom_embeddings /* custom placeholder */);
+            ggml_allocr_alloc(allocr, embeddings);
+            if (!ggml_allocr_is_measure(allocr)) {
+                // really bad, there is memory inflexibility (this is for host<->device memory conflicts)
+                void* freeze_data = malloc(ggml_nbytes(text_model.token_embed_weight));
+                ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_weight, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight));
+                ggml_backend_tensor_set(embeddings, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight));
+                free(freeze_data);
+                // concatenate custom embeddings
+                void* custom_data = malloc(ggml_nbytes(text_model.token_embed_custom));
+                ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_custom, custom_data, 0, ggml_nbytes(text_model.token_embed_custom));
+                ggml_backend_tensor_set(embeddings, custom_data, ggml_nbytes(text_model.token_embed_weight), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype));
+                free(custom_data);
+            }
+        }
+
+        struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, embeddings, max_token_idx, return_pooled);
 
         ggml_build_forward_expand(gf, hidden_states);
         ggml_free(ctx0);