Commit 732f02f
fix: load vocab_size first then use it to decide model type for model sharing between llama3, llama2 and Yi. (#230)

guocuimi authored Jun 6, 2024
1 parent 917c416 commit 732f02f
Showing 2 changed files with 3 additions and 8 deletions.

src/models/huggingface/llama.h (1 addition, 3 deletions)

@@ -416,11 +416,11 @@ REGISTER_MODEL_ARGS(llama, [&] {
   LOAD_ARG_OR(hidden_act, "hidden_act", "silu");

   // decide model type based on vocab size
+  LOAD_ARG_OR(vocab_size, "vocab_size", 128256);
   if (args->vocab_size() == 128256) {
     // choose the right chat template
     SET_ARG(model_type, "llama3");

-    LOAD_ARG_OR(vocab_size, "vocab_size", 128256);
     LOAD_ARG_OR(hidden_size, "hidden_size", 8192);
     LOAD_ARG_OR(n_layers, "num_hidden_layers", 80);
     LOAD_ARG_OR(n_heads, "num_attention_heads", 64);
@@ -437,7 +437,6 @@ REGISTER_MODEL_ARGS(llama, [&] {
   } else if (args->vocab_size() == 64000) {
     // choose the right chat template
     SET_ARG(model_type, "Yi");
-    LOAD_ARG_OR(vocab_size, "vocab_size", 64000);
     LOAD_ARG_OR(hidden_size, "hidden_size", 7168);
     LOAD_ARG_OR(n_layers, "num_hidden_layers", 60);
     LOAD_ARG_OR(n_heads, "num_attention_heads", 56);
@@ -454,7 +453,6 @@ REGISTER_MODEL_ARGS(llama, [&] {
     SET_ARG(stop_token_ids, std::unordered_set<int32_t>({2, 6, 7, 8}));
   } else {
     // llama 2
-    LOAD_ARG_OR(vocab_size, "vocab_size", 32000);
     LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
     LOAD_ARG_OR(n_layers, "num_hidden_layers", 32);
     LOAD_ARG_OR(n_heads, "num_attention_heads", 32);
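
Why the move matters: llama 2, llama 3, and Yi checkpoints all reuse the same llama model registration, so the loader tells them apart purely by vocab_size. Before this fix, each branch loaded vocab_size only after the branch had already been chosen, so args->vocab_size() was compared while still unset; loading it once, before the if/else chain, makes the comparison see the real config value. Below is a minimal standalone sketch of the pattern, with a plain std::map standing in for the parsed config and for ScaleLLM's LOAD_ARG_OR/SET_ARG macros (the names detect_model and ModelArgs here are illustrative, not ScaleLLM's API):

// Standalone sketch (not ScaleLLM's actual macros): decide the model type
// from vocab_size, reading vocab_size from the config *before* branching.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct ModelArgs {
  int64_t vocab_size = 0;
  std::string model_type;
};

ModelArgs detect_model(const std::map<std::string, int64_t>& config) {
  ModelArgs args;
  // Equivalent of LOAD_ARG_OR(vocab_size, "vocab_size", 128256): take the
  // config value if present, otherwise fall back to the default.
  auto it = config.find("vocab_size");
  args.vocab_size = (it != config.end()) ? it->second : 128256;

  // Now the branch sees the real value instead of an unset default.
  if (args.vocab_size == 128256) {
    args.model_type = "llama3";
  } else if (args.vocab_size == 64000) {
    args.model_type = "Yi";
  } else {
    args.model_type = "llama2";
  }
  return args;
}

int main() {
  std::cout << detect_model({{"vocab_size", 32000}}).model_type << '\n';  // llama2
  std::cout << detect_model({{"vocab_size", 64000}}).model_type << '\n';  // Yi
  std::cout << detect_model({}).model_type << '\n';  // llama3 (default)
}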

src/server/main.cpp (2 additions, 5 deletions)

@@ -80,12 +80,9 @@ DEFINE_int32(num_speculative_tokens, 0, "number of speculative tokens");
 // NOLINTNEXTLINE
 static std::atomic<uint32_t> signal_received{0};
 void shutdown_handler(int signal) {
-  // force exit after receiving second signal
-  if (signal_received.fetch_add(1, std::memory_order_relaxed) >= 1) {
-    LOG(ERROR) << "Received signal again, force aborting...";
-    exit(1);
-  }
+  // TODO: gracefully shutdown the server
   LOG(WARNING) << "Received signal " << signal << ", stopping server...";
+  exit(1);
 }

 std::optional<std::vector<uint32_t>> parse_batch_sizes(
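
The handler change fixes a usability bug: the old version only exited on the second signal, so a single Ctrl-C logged "stopping server..." and then did nothing. The new version exits immediately and leaves graceful shutdown as a TODO. For context, here is a self-contained sketch of how such a handler is typically registered. The registration code is illustrative, not copied from ScaleLLM's main(), and it uses write()/_exit() instead of glog because only async-signal-safe functions may be called inside a signal handler:

// Illustrative registration of an immediate-exit shutdown handler.
#include <csignal>
#include <unistd.h>

void shutdown_handler(int /*signal*/) {
  // TODO: gracefully shut down the server instead of exiting immediately.
  const char msg[] = "Received signal, stopping server...\n";
  // write() and _exit() are async-signal-safe; LOG(WARNING) is not.
  write(STDERR_FILENO, msg, sizeof(msg) - 1);
  _exit(1);
}

int main() {
  std::signal(SIGINT, shutdown_handler);   // Ctrl-C
  std::signal(SIGTERM, shutdown_handler);  // kill <pid>
  for (;;) pause();                        // stand-in for the server loop
}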
