From 732f02f8f54359081c5c45fd5b42e4515dce4244 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Thu, 6 Jun 2024 12:11:31 -0700
Subject: [PATCH] fix: load vocab_size first then use it to decide model type
 for model sharing between llama3, llama2 and Yi. (#230)

---
 src/models/huggingface/llama.h | 4 +---
 src/server/main.cpp            | 7 ++-----
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/src/models/huggingface/llama.h b/src/models/huggingface/llama.h
index 1e00494d..4b8ceacc 100644
--- a/src/models/huggingface/llama.h
+++ b/src/models/huggingface/llama.h
@@ -416,11 +416,11 @@ REGISTER_MODEL_ARGS(llama, [&] {
   LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
 
   // decide model type based on vocab size
+  LOAD_ARG_OR(vocab_size, "vocab_size", 128256);
   if (args->vocab_size() == 128256) {
     // choose the right chat template
     SET_ARG(model_type, "llama3");
 
-    LOAD_ARG_OR(vocab_size, "vocab_size", 128256);
     LOAD_ARG_OR(hidden_size, "hidden_size", 8192);
     LOAD_ARG_OR(n_layers, "num_hidden_layers", 80);
     LOAD_ARG_OR(n_heads, "num_attention_heads", 64);
@@ -437,7 +437,6 @@ REGISTER_MODEL_ARGS(llama, [&] {
   } else if (args->vocab_size() == 64000) {
     // choose the right chat template
     SET_ARG(model_type, "Yi");
-    LOAD_ARG_OR(vocab_size, "vocab_size", 64000);
     LOAD_ARG_OR(hidden_size, "hidden_size", 7168);
     LOAD_ARG_OR(n_layers, "num_hidden_layers", 60);
     LOAD_ARG_OR(n_heads, "num_attention_heads", 56);
@@ -454,7 +453,6 @@ REGISTER_MODEL_ARGS(llama, [&] {
     SET_ARG(stop_token_ids, std::unordered_set<int32_t>({2, 6, 7, 8}));
   } else {
     // llama 2
-    LOAD_ARG_OR(vocab_size, "vocab_size", 32000);
     LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
     LOAD_ARG_OR(n_layers, "num_hidden_layers", 32);
     LOAD_ARG_OR(n_heads, "num_attention_heads", 32);
diff --git a/src/server/main.cpp b/src/server/main.cpp
index 1fa5c6ad..47285031 100644
--- a/src/server/main.cpp
+++ b/src/server/main.cpp
@@ -80,12 +80,9 @@ DEFINE_int32(num_speculative_tokens, 0, "number of speculative tokens");
 // NOLINTNEXTLINE
 static std::atomic<uint32_t> signal_received{0};
 void shutdown_handler(int signal) {
-  // force exit after receiving second signal
-  if (signal_received.fetch_add(1, std::memory_order_relaxed) >= 1) {
-    LOG(ERROR) << "Received signal again, force aborting...";
-    exit(1);
-  }
+  // TODO: gracefully shutdown the server
   LOG(WARNING) << "Received signal " << signal << ", stopping server...";
+  exit(1);
 }
 
 std::optional<std::vector<uint32_t>> parse_batch_sizes(
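Note on the llama.h change: the if/else chain dispatches on args->vocab_size(),
so LOAD_ARG_OR(vocab_size, ...) has to run before the first comparison.
Previously each branch loaded vocab_size after the branch had already been
chosen, so the dispatch read an unset value and would fall through to the
llama2 branch for every checkpoint. The sketch below is a minimal,
self-contained illustration of that load-then-dispatch order; it is not
ScaleLLM code, and load_or, ModelArgs, and the config map are hypothetical
stand-ins for the LOAD_ARG_OR macro and the parsed HuggingFace config.

// sketch.cpp -- illustrative only; names are hypothetical stand-ins.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct ModelArgs {
  int64_t vocab_size = 0;  // unset until loaded from the config
  std::string model_type;
};

// Read a key from the parsed config if present, else fall back to the
// default -- the contract LOAD_ARG_OR appears to implement.
int64_t load_or(const std::map<std::string, int64_t>& config,
                const std::string& key, int64_t def) {
  auto it = config.find(key);
  return it == config.end() ? def : it->second;
}

int main() {
  // Pretend we parsed a Yi checkpoint's config.json.
  const std::map<std::string, int64_t> config{{"vocab_size", 64000}};
  ModelArgs args;

  // The patched order: load vocab_size BEFORE dispatching on it.
  // Dispatching first (the old order) would compare against the unset
  // value 0 and misclassify every checkpoint as llama2.
  args.vocab_size = load_or(config, "vocab_size", 128256);

  if (args.vocab_size == 128256) {
    args.model_type = "llama3";
  } else if (args.vocab_size == 64000) {
    args.model_type = "Yi";
  } else {
    args.model_type = "llama2";
  }

  std::cout << args.model_type << '\n';  // prints: Yi
  return 0;
}

The main.cpp hunk is independent of the above: it drops the
force-exit-on-second-signal logic and exits on the first signal instead,
leaving graceful shutdown as a TODO.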