From 3124fa067b1d18e933461ed8da5a4eebea47a9d2 Mon Sep 17 00:00:00 2001 From: Urs Ganse Date: Sun, 11 Aug 2024 03:38:53 +0200 Subject: [PATCH] Support BREAK pseudo-token --- conditioner.hpp | 41 +++++++++++++++++++++++++++++++++++++---- util.cpp | 9 ++++++--- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 43d0a6d5..a7135a4d 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -248,13 +248,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { const std::string& curr_text = item.first; float curr_weight = item.second; // printf(" %s: %f \n", curr_text.c_str(), curr_weight); - std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); int32_t clean_index = 0; + if(curr_text == "BREAK" && curr_weight == -1.0f) { + // Pad token array up to chunk size at this point. + // TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future? + // Also, this is 75 instead of 77 to leave room for BOS and EOS tokens. + int padding_size = 75 - (tokens_acc % 75); + for (int j = 0; j < padding_size; j++) { + clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID); + clean_index++; + } + + // After padding, continue to the next iteration to process the following text as a new segment + tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); + weights.insert(weights.end(), padding_size, curr_weight); + continue; + } + + // Regular token, process normally + std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); for (uint32_t i = 0; i < curr_tokens.size(); i++) { int token_id = curr_tokens[i]; - if (token_id == image_token) + if (token_id == image_token) { class_token_index.push_back(clean_index - 1); - else { + } else { clean_input_ids.push_back(token_id); clean_index++; } @@ -354,6 +371,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; + + if(curr_text == "BREAK" && curr_weight == -1.0f) { + // Pad token array up to chunk size at this point. + // TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future? + // Also, this is 75 instead of 77 to leave room for BOS and EOS tokens. + size_t current_size = tokens.size(); + size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding + + if (padding_size > 0) { + LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size); + tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID); + weights.insert(weights.end(), padding_size, 1.0f); + } + continue; // Skip to the next item after handling BREAK + } + std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); @@ -1203,4 +1236,4 @@ struct FluxCLIPEmbedder : public Conditioner { } }; -#endif \ No newline at end of file +#endif diff --git a/util.cpp b/util.cpp index 5de5ce26..35ba93c2 100644 --- a/util.cpp +++ b/util.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -606,7 +607,7 @@ std::vector> parse_prompt_attention(const std::str float round_bracket_multiplier = 1.1f; float square_bracket_multiplier = 1 / 1.1f; - std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)"); + std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:B]+|:|\bB)"); std::regex re_break(R"(\s*\bBREAK\b\s*)"); auto multiply_range = [&](int start_position, float multiplier) { @@ -615,7 +616,7 @@ std::vector> parse_prompt_attention(const std::str } }; - std::smatch m; + std::smatch m,m2; std::string remaining_text = text; while (std::regex_search(remaining_text, m, re_attention)) { @@ -639,6 +640,8 @@ std::vector> parse_prompt_attention(const std::str square_brackets.pop_back(); } else if (text == "\\(") { res.push_back({text.substr(1), 1.0f}); + } else if (std::regex_search(text, m2, re_break)) { + res.push_back({"BREAK", -1.0f}); } else { res.push_back({text, 1.0f}); } @@ -669,4 +672,4 @@ std::vector> parse_prompt_attention(const std::str } return res; -} \ No newline at end of file +}