Skip to content

Commit

Permalink
add get vocab
Browse files: browse the repository at this point in the history
  • Loading branch information
cahya-wirawan committed Aug 15, 2024
1 parent 9e398b1 commit a12d07f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
9 changes: 9 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use pyo3::prelude::*;
use std::str;
use rwkv_tokenizer;
Expand All @@ -24,6 +25,14 @@ impl WorldTokenizer {
/// Forwards `vec` to the wrapped tokenizer's `decode` and hands back the
/// `String` it produces.
pub(crate) fn decode(&self, vec: Vec<u16>) -> String {
    self.tokenizer.decode(vec)
}

/// Reports the value returned by the wrapped tokenizer's `vocab_size`.
pub(crate) fn vocab_size(&self) -> usize {
    self.tokenizer.vocab_size()
}

/// Returns the map produced by the wrapped tokenizer's `get_vocab`.
///
/// The keys are references, so the returned map borrows from `self`.
pub(crate) fn get_vocab(&self) -> HashMap<&Vec<u8>, usize> {
    self.tokenizer.get_vocab()
}
}


Expand Down
13 changes: 13 additions & 0 deletions rwkv-tokenizer/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod trie;
use std::{str, env};
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead};
use std::path::{Path};
Expand Down Expand Up @@ -66,6 +67,18 @@ impl WorldTokenizer {
return str::from_utf8(&*result).unwrap().to_string();
}

/// Returns the number of entries currently held in `self.tokens`.
pub fn vocab_size(&self) -> usize {
self.tokens.len()
}

/// Builds a lookup table from each entry of `self.tokens` to its position
/// in that sequence.
///
/// The keys are references into `self.tokens`, so the returned map borrows
/// from `self`. Entries are inserted in iteration order; if the same byte
/// sequence occurs more than once, the later position replaces the earlier.
pub fn get_vocab(&self) -> HashMap<&Vec<u8>, usize> {
    self.tokens
        .iter()
        .enumerate()
        .map(|(position, token)| (token, position))
        .collect()
}

fn hex_to_bytes(hex: &str) -> Option<Vec<u8>> {
let hex = hex.replace("\\x", "");
if hex.len() % 2 == 0 {
Expand Down

0 comments on commit a12d07f

Please sign in to comment.