add multithreading encoding (batching)
cahya-wirawan committed Aug 16, 2024
1 parent a12d07f commit f873c94
Showing 8 changed files with 149 additions and 14 deletions.
59 changes: 55 additions & 4 deletions bindings/python/Cargo.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions bindings/python/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "pyrwkv_tokenizer"
-version = "0.8.5"
+version = "0.9.0"
 edition = "2021"
 authors = ["Cahya Wirawan <[email protected]>"]
 description = "A fast RWKV Tokenizer"
@@ -22,4 +22,5 @@ crate-type = ["cdylib"]
 
 [dependencies]
 pyo3 = "0.21.2"
-rwkv-tokenizer = "0.8.5"
+rwkv-tokenizer = { path = "../../rwkv-tokenizer" }
+rayon = "1.10.0"
2 changes: 1 addition & 1 deletion bindings/python/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "maturin"
 
 [project]
 name = "pyrwkv_tokenizer"
-version = "0.8.5"
+version = "0.9.0"
 requires-python = ">=3.8"
 description = "RWKV Tokenizer"
 readme = "README.md"
12 changes: 11 additions & 1 deletion bindings/python/pyrwkv_tokenizer/__init__.py
@@ -7,7 +7,7 @@
 from .pyrwkv_tokenizer import __doc__  # noqa: F401
 
 __all__ = __all__ + ["RWKVTokenizer"]
-__version__ = "0.8.5"
+__version__ = "0.9.0"
 
 
 class RWKVTokenizer:
@@ -22,6 +22,16 @@ def encode(self, text: str):
         tokens_ids = self.tokenizer.encode(text)
         return tokens_ids
 
+    def encode_batch(self, text_list: [str]):
+        tokens_ids = self.tokenizer.encode_batch(text_list)
+        return tokens_ids
+
+    def decode(self, tokens_ids):
+        text = self.tokenizer.decode(tokens_ids)
+        return text
+
+    def vocab_size(self):
+        return self.tokenizer.vocab_size()
 
     def get_vocab(self):
         return self.tokenizer.get_vocab()
6 changes: 5 additions & 1 deletion bindings/python/src/lib.rs
@@ -22,6 +22,10 @@ impl WorldTokenizer {
         self.tokenizer.encode(word)
     }
 
+    pub(crate) fn encode_batch(&self, word_list: Vec<String>) -> Vec<Vec<u16>> {
+        self.tokenizer.encode_batch(word_list)
+    }
+
     pub(crate) fn decode(&self, vec: Vec<u16>) -> String {
         return self.tokenizer.decode(vec);
     }
@@ -30,7 +34,7 @@ impl WorldTokenizer {
         return self.tokenizer.vocab_size();
     }
 
-    pub(crate) fn get_vocab(&self) -> HashMap<&Vec<u8>, usize> {
+    pub(crate) fn get_vocab(&self) -> HashMap<String, usize> {
        return self.tokenizer.get_vocab();
     }
 }
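The binding layer stays thin: each method delegates to the core crate, and pyo3 converts a Python `list[str]` into `Vec<String>` on the way in and a `Vec<Vec<u16>>` back into a nested Python list on the way out. A minimal standalone sketch of that list-in/list-out pattern; the `Echo` class and `echo_module` name are hypothetical, not part of this repository:

```rust
use pyo3::prelude::*;

// Hypothetical class, used only to illustrate the conversion pattern.
#[pyclass]
struct Echo;

#[pymethods]
impl Echo {
    #[new]
    fn new() -> Self {
        Echo
    }

    // Same list-in/list-out shape as encode_batch above: pyo3 handles
    // the Python <-> Rust conversions at the boundary.
    fn lengths_batch(&self, texts: Vec<String>) -> Vec<Vec<u16>> {
        texts.iter().map(|t| vec![t.len() as u16]).collect()
    }
}

// Hypothetical module name; registers the class with Python.
#[pymodule]
fn echo_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Echo>()
}
```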
54 changes: 53 additions & 1 deletion rwkv-tokenizer/Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion rwkv-tokenizer/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "rwkv-tokenizer"
-version = "0.8.5"
+version = "0.9.0"
 edition = "2021"
 authors = ["Cahya Wirawan <[email protected]>"]
 description = "A fast RWKV Tokenizer"
@@ -11,3 +11,4 @@ exclude = []
 [dependencies]
 regex = "1.10.4"
 unescape = "0.1.0"
+rayon = "1.10.0"
22 changes: 19 additions & 3 deletions rwkv-tokenizer/src/lib.rs
@@ -7,6 +7,8 @@ use std::path::{Path};
 use regex::Regex;
 use trie::Trie;
 use unescape::unescape;
+use rayon::prelude::*;
+
 
 #[derive(Debug)]
 pub struct WorldTokenizer {
@@ -58,6 +60,10 @@ impl WorldTokenizer {
         self.trie.tokenize(word)
     }
 
+    pub fn encode_batch(&self, word_list: Vec<String>) -> Vec<Vec<u16>> {
+        word_list.par_iter().map(|word| self.trie.tokenize(word)).collect()
+    }
+
     pub fn decode(&self, vec: Vec<u16>) -> String {
         let mut result: Vec<u8> = Vec::new();
         for index in vec.iter() {
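This `encode_batch` is where the multithreading happens: rayon's `par_iter()` fans the input strings out across a worker-thread pool, tokenizes each one independently, and `collect()` reassembles the results in input order. A self-contained sketch of the same pattern, with `str::len` standing in for the tokenizer:

```rust
use rayon::prelude::*;

fn main() {
    let texts: Vec<String> = (0..8).map(|i| format!("sample text {i}")).collect();
    // par_iter() splits the items across rayon's global thread pool; each
    // closure call runs on a worker thread, and collect() preserves the
    // original input order, exactly the shape encode_batch relies on.
    let lengths: Vec<usize> = texts.par_iter().map(|t| t.len()).collect();
    assert_eq!(lengths.len(), texts.len());
    println!("{lengths:?}");
}
```

No locking is needed because each call only reads the shared trie; `par_iter` merely requires the captured `&self` to be `Sync`.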
@@ -71,10 +77,11 @@ impl WorldTokenizer {
         self.tokens.len()
     }
 
-    pub fn get_vocab(&self) -> HashMap<&Vec<u8>, usize> {
-        let mut vocabularies: HashMap<&Vec<u8>, usize> = HashMap::new();
+    pub fn get_vocab(&self) -> HashMap<String, usize> {
+        let mut vocabularies: HashMap<String, usize> = HashMap::new();
         for (index, value) in self.tokens.iter().enumerate() {
-            vocabularies.insert(value, index);
+            let text: String = String::from_utf8((*value).to_owned()).unwrap_or_else(|_e| "Binary string (TODO)".to_string());
+            vocabularies.insert(text, index);
         }
         vocabularies
     }
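`get_vocab` now returns `HashMap<String, usize>` instead of borrowing the raw byte vectors. Tokens that are not valid UTF-8 all fall back to the same `"Binary string (TODO)"` placeholder, and duplicate keys collapse in a `HashMap`, which would explain why the test below counts 65044 entries rather than the full 65529. A small sketch of that collapsing behavior:

```rust
use std::collections::HashMap;

fn main() {
    // 0xFF and 0xFE never occur in valid UTF-8, so both byte tokens fall
    // back to the shared placeholder string and collapse into one entry.
    let tokens: Vec<Vec<u8>> = vec![b"hello".to_vec(), vec![0xFF], vec![0xFE]];
    let mut vocab: HashMap<String, usize> = HashMap::new();
    for (index, value) in tokens.iter().enumerate() {
        let text = String::from_utf8(value.clone())
            .unwrap_or_else(|_| "Binary string (TODO)".to_string());
        vocab.insert(text, index);
    }
    assert_eq!(vocab.len(), 2); // "hello" plus one shared placeholder
}
```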
@@ -838,4 +845,13 @@ Nórdicg: Ljœr ye caudran créneþ ý jor cẃran."#;
         let text = tokenizer.decode(token_ids);
         assert_eq!(text, LONG_UTF8_TEXT);
     }
+
+    #[test]
+    fn test_get_vocab() {
+        let tokenizer = WorldTokenizer::new(None).unwrap();
+        let vocab = tokenizer.get_vocab();
+        // The vocab size should be 65529, but binary (non-UTF-8) keys are not yet
+        // included, so it is currently only 65044. They will be added later.
+        assert_eq!(vocab.len(), 65044);
+    }
 }
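Putting it together, the new batch API is a drop-in companion to `encode`. A usage sketch against the public methods shown in this diff; the constructor and method signatures are taken from the tests and code above, and the example itself is illustrative, not from the repository:

```rust
use rwkv_tokenizer::WorldTokenizer;

fn main() {
    // new(None) loads the bundled vocabulary, as in the tests above.
    let tokenizer = WorldTokenizer::new(None).unwrap();

    let texts = vec!["Hello world!".to_string(), "How are you?".to_string()];
    // One Vec<u16> of token ids per input string, in input order.
    let batch_ids: Vec<Vec<u16>> = tokenizer.encode_batch(texts.clone());

    // Each id sequence round-trips through decode.
    for (ids, text) in batch_ids.iter().zip(&texts) {
        assert_eq!(tokenizer.decode(ids.clone()), *text);
    }
    println!("vocab size: {}", tokenizer.vocab_size());
}
```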
