Skip to content

Commit

Permalink
support tinyllama
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Jan 23, 2024
1 parent cc06bb0 commit 8ef110c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
10 changes: 10 additions & 0 deletions include/llm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,16 @@ class Llama2_7b : public Llm {
virtual bool is_stop(int token_id) override;
};

// TinyLlama chat model. Reuses the Llama2_7b inference path and overrides
// only the prompt construction (see TinyLlama::tokenizer in llm.cpp).
class TinyLlama : public Llama2_7b {
public:
    TinyLlama() {
        model_name_ = "TinyLlama";
        // TinyLlama-1.1B has 22 transformer layers (vs. 32 in Llama2-7b).
        layer_nums_ = 22;
        // Per-layer KV cache shape: {K/V pair, batch, heads, seq_len, head_dim}.
        // 4 KV heads x 64 head_dim; seq_len is 0 here because it grows during
        // decoding — presumably filled in by the runtime; confirm against the
        // base class's cache handling.
        key_value_shape_ = {2, 1, 4, 0, 64};
    }
private:
    // Builds the TinyLlama chat-template token sequence around the query.
    virtual std::vector<int> tokenizer(const std::string& query) override;
};
// Llm end

// Embedding start
Expand Down
21 changes: 20 additions & 1 deletion src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ Llm* Llm::createLLM(const std::string& path, std::string model_type) {
} else if (model_type.find("internlm") != std::string::npos) {
llm = new Llama2_7b;
llm->model_name_ = "Internlm_7b";
} else if (model_type.find("tinyllama") != std::string::npos) {
llm = new TinyLlama;
llm->model_name_ = "TinyLlama";
}
if (!llm) {
std::cerr << "model type can't judge!" << std::endl;
Expand Down Expand Up @@ -697,6 +700,22 @@ bool Llama2_7b::is_stop(int token_id) {
}
return token_id == 2;
}

// Wrap the encoded query in TinyLlama's fixed chat template:
//   <|system|>
//   You are a friendly chatbot who always responds in the style of a pirate</s>
//   <|user|>
//   {query}</s>
//   <|assistant|>
std::vector<int> TinyLlama::tokenizer(const std::string& query) {
    // Pre-tokenized template pieces (BOS=1, EOS=2); kept static so they are
    // built once and shared across calls.
    static const std::vector<int> kPromptPrefix = {
        1,  529, 29989,  5205, 29989, 29958,    13,  3492,   526,   263, 19780, 13563,
        7451, 1058,  2337, 10049, 29879,   297,   278,  3114,   310,   263, 21625,
        403,    2, 29871,    13, 29966, 29989,  1792, 29989, 29958,    13};
    static const std::vector<int> kPromptSuffix = {
        2, 29871, 13, 29966, 29989, 465, 22137, 29989, 29958, 13};
    auto tokens = tokenizer_encode(query);
    tokens.insert(tokens.begin(), kPromptPrefix.begin(), kPromptPrefix.end());
    tokens.insert(tokens.end(), kPromptSuffix.begin(), kPromptSuffix.end());
    return tokens;
}
// Llm end

// Embedding start
Expand Down Expand Up @@ -898,7 +917,7 @@ void TextVectorStore::bench() {
auto iptr = indices->readMap<int>();
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
printf("# [%d, %d] search took %lld ms.\n", n, d, duration.count());
std::cout << "bench search time (ms): " << duration.count();
vectors_ = nullptr;
}

Expand Down

0 comments on commit 8ef110c

Please sign in to comment.