diff --git a/common/common.cpp b/common/common.cpp
index a7ffe01677c3c..99fa55fefde31 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -448,12 +448,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.p_split = std::stof(argv[i]);
-        }else if (arg == "--p-recovery" || arg == "-pr") {
+        } else if (arg == "--p-recovery" || arg == "-pr") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.p_recovery = std::stof(argv[i]);
+        } else if (arg == "--p-decay" || arg == "-pd") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.p_decay = std::stof(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -860,7 +866,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
     printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
-    printf("  -pr N, --p-recovery N PipeInfer recovery probability (default: %.1f)\n", (double)params.p_recovery);
+    printf("  -pr N, --p-recovery N PipeInfer probability recovery (default: %.1f)\n", (double)params.p_recovery);
+    printf("  -pd N, --p-decay N    PipeInfer probability decay (default: %.1f)\n", (double)params.p_decay);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
     printf("  --image IMAGE_FILE    path to an image file. use with multimodal models\n");
diff --git a/common/common.h b/common/common.h
index 635a5e2269aac..6053bb0d2ea3e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -58,6 +58,7 @@ struct gpt_params {
     float   p_accept   = 0.5f; // speculative decoding accept probability
     float   p_split    = 0.1f; // speculative decoding split probability
     float   p_recovery = 0.0f; // Cumulative probability that p_accept and p_split are increased by per-iteration.
+    float   p_decay    = 0.0f; // Cumulative probability that p_accept and p_split are decreased by per speculative run rejected after drafting stopped due to p_accept.
     int32_t n_gpu_layers       = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu           = 0;  // the GPU that is used for scratch and small tensors
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 50b3d4ffc216b..591b3f21fdfe6 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -54,11 +54,11 @@ void begin_async_run(const llama_sampling_params& sparams, int n_seq_dft,
 bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llama_context *ctx_dft,
                           std::deque<int> &free_sequence_offsets, int max_seq, llama_batch &batch_tgt, int n_predict,
-                          int prefix_n_past, int n_past_dft, bool has_eos, llama_sampling_context *ctx_sampling,
+                          int prefix_n_past, int n_past_dft, llama_sampling_context *ctx_sampling,
                           std::deque<seq_async_run> &tgt_cgraphs, const seq_async_run &current_run,
                           int &spec_past_tgt, int &spec_past_dft, int first_run, int orig_offset,
                           int32_t &batch_id, llama_batch &batch_dft, int &n_drafted, std::vector<seq_draft> &drafts, llama_token &id,
-                          llama_kv_cache_view &kvc, int iter);
+                          llama_kv_cache_view &kvc, float p_adjust, int &n_reject);
 
 void begin_non_spec_run(const gpt_params &params, int n_seq_dft, llama_context *ctx,
                         int max_seq, const std::vector<seq_draft> &drafts, llama_token id,
                         int32_t &batch_id, int &n_past, int n_past_dft,
@@ -66,11 +66,13 @@
 
 void
 run_speculation_loop(const gpt_params &params, const float p_accept, llama_context *ctx_tgt, llama_context *ctx_dft,
-                     int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft,
-                     bool has_eos, llama_sampling_context *ctx_sampling, int & spec_past_tgt, int & spec_past_dft,
-                     bool & first_run, std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft,
-                     int &n_drafted, std::vector<seq_draft> &drafts, std::deque<seq_async_run> &tgt_cgraphs,
-                     seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id);
+                     const int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft,
+                     llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft, bool &first_run,
+                     std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft, int &n_drafted,
+                     std::vector<seq_draft> &drafts, std::deque<seq_async_run> &tgt_cgraphs,
+                     seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id, int &n_rejected);
+
+float calc_p_adjust(const gpt_params &params, int iter, int n_reject);
 
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -307,6 +309,9 @@ int main(int argc, char ** argv) {
 
     bool first_run = true;
 
     llama_token id;
 
+
+    int n_rejected = 0;
+
     while (true) {
@@ -389,10 +394,10 @@
 
             first_run = true;
 
         } else if (!tgt_cgraphs.empty()) {
-            run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt, n_past_dft,
-                                 has_eos, ctx_sampling,
+            run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt,
+                                 n_past_dft, ctx_sampling,
                                  spec_past_tgt, spec_past_dft, first_run, free_sequence_offsets, batch_id, batch_dft,
-                                 n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id);
+                                 n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id, n_rejected);
             continue;
         }
@@ -570,6 +575,8 @@
             continue;
         }
 
+        n_rejected = 0;
+
         check_for_cancel(ctx_tgt, n_past_tgt, tgt_cgraphs, generated, n_seq_dft);
 
 
@@ -656,9 +663,9 @@ int main(int argc, char ** argv) {
 
             // bool is_waiting = false;
             run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt, n_past_dft,
-                                 has_eos, ctx_sampling,
+                                 ctx_sampling,
                                  spec_past_tgt, spec_past_dft, first_run, free_sequence_offsets, batch_id, batch_dft,
-                                 n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id);
+                                 n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id, n_rejected);
 
 
             if (n_predict > params.n_predict || has_eos) {
@@ -733,10 +740,10 @@ int main(int argc, char ** argv) {
 void
 run_speculation_loop(const gpt_params &params, const float p_accept, llama_context *ctx_tgt, llama_context *ctx_dft,
                      const int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft,
-                     bool has_eos, llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft,
-                     bool & first_run, std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft,
-                     int &n_drafted, std::vector<seq_draft> &drafts, std::deque<seq_async_run> &tgt_cgraphs,
-                     seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id) {
+                     llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft, bool &first_run,
+                     std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft, int &n_drafted,
+                     std::vector<seq_draft> &drafts, std::deque<seq_async_run> &tgt_cgraphs,
+                     seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id, int &n_rejected) {
     bool is_waiting = llama_mpi_iprobe(ctx_tgt);
     llama_swap_comm(ctx_tgt);
     llama_sync_token(ctx_tgt, reinterpret_cast<llama_token *>(&is_waiting), 0);
@@ -749,7 +756,8 @@
         free_sequence_offsets.push_back(current_run.seq_offset);
     }
     int iter = 0;
-    while((!is_waiting && (p_accept + iter * params.p_recovery) < 1.0)) {
+    float p_adjust = calc_p_adjust(params, iter, n_rejected);
+    while((!is_waiting && p_accept + (p_adjust = calc_p_adjust(params, iter, n_rejected)) < 1.0)) {
 
 
 
@@ -790,10 +798,10 @@ run_speculation_loop(const gpt_params &params, const float p_accept, llama_conte
 
 
             if (start_async_spec_run(params, ctx_tgt, ctx_dft, free_sequence_offsets, max_seq,
-                                     batch_tgt, n_predict, n_past_tgt, n_past_dft, has_eos, ctx_sampling,
+                                     batch_tgt, n_predict, n_past_tgt, n_past_dft, ctx_sampling,
                                      tgt_cgraphs, current_run, spec_past_tgt, spec_past_dft, first_run, orig_offset,
-                                     batch_id, batch_dft, n_drafted, drafts, id, kvc_view_dft, iter)) {
+                                     batch_id, batch_dft, n_drafted, drafts, id, kvc_view_dft, p_adjust, n_rejected)) {
                 LOG("Ending spec run because returned true\n");
                 break;
             }
 
@@ -810,6 +818,10 @@ run_speculation_loop(const gpt_params &params, const float p_accept, llama_conte
     }
 }
 
+float calc_p_adjust(const gpt_params &params, int iter, int n_reject) {
+    return iter * params.p_recovery - std::max(n_reject * params.p_decay, 0.0f);
+}
+
 void begin_non_spec_run(const gpt_params &params, const int n_seq_dft, llama_context *ctx, const int max_seq,
                         const std::vector<seq_draft> &drafts, llama_token id,
                         int32_t &batch_id, int &n_past, int n_past_dft,
@@ -843,11 +855,11 @@ void begin_non_spec_run(const gpt_params &params, const int n_seq_dft, llama_con
 
 bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llama_context *ctx_dft,
                           std::deque<int> &free_sequence_offsets, int max_seq, llama_batch &batch_tgt, int n_predict,
-                          int prefix_n_past, int n_past_dft, bool has_eos, llama_sampling_context *ctx_sampling,
+                          int prefix_n_past, int n_past_dft, llama_sampling_context *ctx_sampling,
                           std::deque<seq_async_run> &tgt_cgraphs, const seq_async_run &current_run, int &spec_past_tgt, int
                           &spec_past_dft, int first_run, int orig_offset, int32_t &batch_id, llama_batch &batch_dft,
                           int &n_drafted, std::vector<seq_draft> &drafts, llama_token &id,
-                          llama_kv_cache_view &kvc, const int iter) {
+                          llama_kv_cache_view &kvc, float p_adjust, int &n_reject) {
     LOG("Doing speculative run, seq_offset = %d, spec_past_tgt = %d, spec_past_dft = %d, prefix_n_past = %d, n_past_dft = %d\n",
         current_run.seq_offset, spec_past_tgt, spec_past_dft, prefix_n_past, n_past_dft);
 
@@ -958,7 +970,7 @@ bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llam
             }
 
-            if (cur_p[0].p < params.p_accept + params.p_recovery * iter) {
-                LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, params.p_accept);
+            if (cur_p[0].p < params.p_accept + p_adjust) {
+                LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, params.p_accept + p_adjust);
                 drafts[s].drafting = false;
 
 
@@ -970,7 +982,7 @@ bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llam
 
                 // attempt to split the branch if the probability is high enough
                 for (int f = 1; f < 8; ++f) {
-                    if (n_seq_cur < params.n_parallel - 1 && cur_p[f].p > params.p_split + params.p_recovery * iter) {
+                    if (n_seq_cur < params.n_parallel - 1 && cur_p[f].p > params.p_split + p_adjust) {
                         n_seq_cur++;
 
                         LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
@@ -1093,6 +1105,7 @@ bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llam
 //        fprintf(stderr, "\nNo tgt tokens, pushing seq offset %d to free seq offsets\n", current_run.seq_offset);
 //        fflush(stderr);
         free_sequence_offsets.push_back(current_run.seq_offset);
+        n_reject++;
        return true;
     }
 
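For context on how the two knobs interact: calc_p_adjust returns iter * p_recovery - max(n_reject * p_decay, 0), so every speculative run rejected with no accepted target tokens lowers the effective p_accept/p_split thresholds by p_decay, while each subsequent speculation iteration raises them again by p_recovery until p_accept + p_adjust reaches 1.0 and the loop stops. The snippet below is a minimal standalone sketch of that schedule, not part of the patch; gpt_params_sketch, the driver loop, and the 0.05/0.10 values are hypothetical stand-ins for illustration only.

// Standalone sketch (hypothetical, not part of the patch) of the
// p_recovery / p_decay threshold schedule introduced by this diff.
#include <algorithm>
#include <cstdio>

struct gpt_params_sketch {
    float p_accept   = 0.5f;  // base speculative-accept threshold
    float p_recovery = 0.05f; // added back per speculation iteration (hypothetical value)
    float p_decay    = 0.10f; // subtracted per rejected run (hypothetical value)
};

// Mirrors calc_p_adjust from the diff: recovery raises the effective
// thresholds as iterations pass; decay lowers them in proportion to the
// number of rejected runs. The std::max clamp keeps the decay term from
// going negative.
static float calc_p_adjust(const gpt_params_sketch & params, int iter, int n_reject) {
    return iter * params.p_recovery - std::max(n_reject * params.p_decay, 0.0f);
}

int main() {
    gpt_params_sketch params;
    const int n_rejected = 2; // pretend the last two speculative runs were rejected

    // After rejections, the effective accept threshold starts below p_accept
    // (so drafting is allowed to run deeper) and climbs back by p_recovery
    // per iteration until p_accept + p_adjust reaches 1.0, the same condition
    // that ends the loop in run_speculation_loop.
    for (int iter = 0; iter < 32; ++iter) {
        const float p_adjust = calc_p_adjust(params, iter, n_rejected);
        if (params.p_accept + p_adjust >= 1.0f) {
            printf("iter %2d: threshold saturated, speculation loop would stop\n", iter);
            break;
        }
        printf("iter %2d: effective p_accept = %.3f\n", iter, params.p_accept + p_adjust);
    }
    return 0;
}

Under this reading, -pd trades confidence for depth after a miss: a rejected run temporarily loosens the accept/split thresholds so drafting speculates more aggressively, and -pr gradually tightens them back over the following iterations.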