Skip to content

Commit

Permalink
Support BREAK pseudo-token
Browse files Browse the repository at this point in the history
  • Loading branch information
ursg authored and daniandtheweb committed Sep 28, 2024
1 parent 14206fd commit a966b4a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
41 changes: 37 additions & 4 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
const std::string& curr_text = item.first;
float curr_weight = item.second;
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
int32_t clean_index = 0;
if(curr_text == "BREAK" && curr_weight == -1.0f) {
// Pad token array up to chunk size at this point.
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
int padding_size = 75 - (tokens_acc % 75);
for (int j = 0; j < padding_size; j++) {
clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID);
clean_index++;
}

// After padding, continue to the next iteration to process the following text as a new segment
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
weights.insert(weights.end(), padding_size, curr_weight);
continue;
}

// Regular token, process normally
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
int token_id = curr_tokens[i];
if (token_id == image_token)
if (token_id == image_token) {
class_token_index.push_back(clean_index - 1);
else {
} else {
clean_input_ids.push_back(token_id);
clean_index++;
}
Expand Down Expand Up @@ -354,6 +371,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;

if(curr_text == "BREAK" && curr_weight == -1.0f) {
// Pad token array up to chunk size at this point.
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
size_t current_size = tokens.size();
size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding

if (padding_size > 0) {
LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID);
weights.insert(weights.end(), padding_size, 1.0f);
}
continue; // Skip to the next item after handling BREAK
}

std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
Expand Down Expand Up @@ -1203,4 +1236,4 @@ struct FluxCLIPEmbedder : public Conditioner {
}
};

#endif
#endif
5 changes: 4 additions & 1 deletion util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <codecvt>
#include <fstream>
#include <locale>
#include <regex>
#include <sstream>
#include <string>
#include <thread>
Expand Down Expand Up @@ -606,7 +607,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
float round_bracket_multiplier = 1.1f;
float square_bracket_multiplier = 1 / 1.1f;

std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:]+|:)");
std::regex re_break(R"(\s*\bBREAK\b\s*)");

auto multiply_range = [&](int start_position, float multiplier) {
Expand Down Expand Up @@ -639,6 +640,8 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
square_brackets.pop_back();
} else if (text == "\\(") {
res.push_back({text.substr(1), 1.0f});
} else if (std::regex_search(text, re_break)) {
res.push_back({"BREAK", -1.0f});
} else {
res.push_back({text, 1.0f});
}
Expand Down

0 comments on commit a966b4a

Please sign in to comment.