Skip to content

Commit

Permalink
feat: implement the complete bpe function (#119)
Browse files Browse the repository at this point in the history
* implement the complete bpe function
---------

Co-authored-by: leejet <[email protected]>
  • Loading branch information
Cyberhan123 and leejet authored Dec 23, 2023
1 parent 8f6b4a3 commit 0e64238
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ IndentCaseLabels: false
ColumnLimit: 0
AccessModifierOffset: -4
NamespaceIndentation: All
FixNamespaceComments: false
FixNamespaceComments: false
AlignAfterOpenBracket: true
AlignConsecutiveAssignments: true
IndentCaseLabels: true
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- The current implementation of ggml_conv_2d is slow and has high memory usage
- Implement Winograd Convolution 2D for 3x3 kernel filtering
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
- [ ] Implement BPE Tokenizer
- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler
- [ ] k-quants support

Expand Down
99 changes: 90 additions & 9 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
}

// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
// TODO: implement bpe
class CLIPTokenizer {
private:
SDVersion version = VERSION_1_x;
Expand All @@ -547,6 +546,21 @@ class CLIPTokenizer {
return text;
}

// Collect the set of adjacent symbol pairs (bigrams) in a tokenized word.
// E.g. {"h","e","l","l","o"} -> {(h,e), (e,l), (l,l), (l,o)}.
// Used by bpe() to find merge candidates; mirrors get_pairs() in
// OpenAI CLIP's simple_tokenizer.py.
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
    std::set<std::pair<std::u32string, std::u32string>> pairs;
    // Start at 1 and pair each subword with its predecessor; size_t index
    // avoids the signed/unsigned comparison with subwords.size().
    // An empty or single-element input naturally yields an empty set.
    for (size_t i = 1; i < subwords.size(); i++) {
        pairs.emplace(subwords[i - 1], subwords[i]);
    }
    return pairs;
}

public:
CLIPTokenizer(SDVersion version = VERSION_1_x)
: version(version) {}
Expand All @@ -565,7 +579,9 @@ class CLIPTokenizer {
merges.push_back(merges_utf32_str.substr(start, pos - start));
start = pos + 1;
}
merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
// LOG_DEBUG("merges size %llu", merges.size());
GGML_ASSERT(merges.size() == 48895);
merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
for (const auto& merge : merges) {
size_t space_pos = merge.find(' ');
Expand Down Expand Up @@ -596,14 +612,79 @@ class CLIPTokenizer {
}
};

std::u32string bpe(std::u32string token) {
std::u32string word = token + utf8_to_utf32("</w>");
if (encoder.find(word) != encoder.end()) {
return word;
} else if (encoder.find(token) != encoder.end()) {
return token;
std::u32string bpe(const std::u32string& token) {
std::vector<std::u32string> word;

for (int i = 0; i < token.size() - 1; i++) {
word.emplace_back(1, token[i]);
}
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));

std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);

if (pairs.empty()) {
return token + utf8_to_utf32("</w>");
}
return utf8_to_utf32(UNK_TOKEN);

while (true) {
auto min_pair_iter = std::min_element(pairs.begin(),
pairs.end(),
[&](const std::pair<std::u32string, std::u32string>& a,
const std::pair<std::u32string, std::u32string>& b) {
if (bpe_ranks.find(a) == bpe_ranks.end()) {
return false;
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
return true;
}
return bpe_ranks.at(a) < bpe_ranks.at(b);
});

const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;

if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
break;
}

std::u32string first = bigram.first;
std::u32string second = bigram.second;
std::vector<std::u32string> new_word;
int32_t i = 0;

while (i < word.size()) {
auto it = std::find(word.begin() + i, word.end(), first);
if (it == word.end()) {
new_word.insert(new_word.end(), word.begin() + i, word.end());
break;
}
new_word.insert(new_word.end(), word.begin() + i, it);
i = static_cast<int32_t>(std::distance(word.begin(), it));

if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
new_word.push_back(first + second);
i += 2;
} else {
new_word.push_back(word[i]);
i += 1;
}
}

word = new_word;

if (word.size() == 1) {
break;
}
pairs = get_pairs(word);
}

std::u32string result;
for (int i = 0; i < word.size(); i++) {
result += word[i];
if (i != word.size() - 1) {
result += utf8_to_utf32(" ");
}
}

return result;
}

std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
Expand Down
4 changes: 4 additions & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <string>
#include <vector>

#include "ggml/ggml.h"

enum RNGType {
STD_DEFAULT_RNG,
CUDA_RNG
Expand Down Expand Up @@ -42,10 +44,12 @@ class StableDiffusion {
bool free_params_immediately = false,
std::string lora_model_dir = "",
RNGType rng_type = STD_DEFAULT_RNG);

bool load_from_file(const std::string& model_path,
const std::string& vae_path,
ggml_type wtype,
Schedule d = DEFAULT);

std::vector<uint8_t*> txt2img(
std::string prompt,
std::string negative_prompt,
Expand Down

0 comments on commit 0e64238

Please sign in to comment.