Commit d6a70a9: Add p_decay

AutonomicPerfectionist committed Feb 3, 2024
1 parent d23b996 commit d6a70a9
Showing 3 changed files with 46 additions and 25 deletions.
common/common.cpp: 11 changes (9 additions, 2 deletions)
@@ -448,12 +448,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.p_split = std::stof(argv[i]);
- }else if (arg == "--p-recovery" || arg == "-pr") {
+ } else if (arg == "--p-recovery" || arg == "-pr") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.p_recovery = std::stof(argv[i]);
} else if (arg == "--p-decay" || arg == "-pd") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.p_decay = std::stof(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
@@ -860,7 +866,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -pr N, --p-recovery N PipeInfer recovery probability (default: %.1f)\n", (double)params.p_recovery);
printf(" -pr N, --p-recovery N PipeInfer probability recovery (default: %.1f)\n", (double)params.p_recovery);
printf(" -pd N, --p-decay N PipeInfer probability decay (default: %.1f)\n", (double)params.p_decay);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
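A note on usage (not part of the diff): this gives the examples a fourth probability knob next to -pa, -ps, and -pr. A hypothetical invocation of the speculative example, with placeholder model paths and values, might look like: ./speculative -m target.gguf -md draft.gguf -pa 0.5 -ps 0.1 -pr 0.05 -pd 0.02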
common/common.h: 1 change (1 addition, 0 deletions)
@@ -58,6 +58,7 @@ struct gpt_params {
float p_accept = 0.5f; // speculative decoding accept probability
float p_split = 0.1f; // speculative decoding split probability
float p_recovery = 0.0f; // Cumulative probability that p_accept and p_split are increased by per-iteration.
+ float p_decay = 0.0f; // Cumulative probability that p_accept and p_split are decreased by per-iteration when drafting stops due to p_accept.
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
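Taken together, the four floats define a moving pair of thresholds for the drafting loop. A minimal C++ sketch of how they combine, mirroring the calc_p_adjust() helper added in speculative.cpp below (the struct and function names here are illustrative, not from the commit):

    #include <algorithm> // std::max

    struct thresholds { float stop_at; float split_at; };

    // iter counts speculation iterations; n_reject counts speculative runs
    // rejected by the target model (see the n_reject++ hunk below).
    thresholds effective(float p_accept, float p_split, float p_recovery, float p_decay,
                         int iter, int n_reject) {
        const float p_adjust = iter * p_recovery - std::max(n_reject * p_decay, 0.0f);
        return { p_accept + p_adjust,   // drafting stops when the top token's prob falls below this
                 p_split  + p_adjust }; // a branch splits when a runner-up's prob rises above this
    }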
examples/speculative/speculative.cpp: 59 changes (36 additions, 23 deletions)
@@ -54,23 +54,25 @@ void begin_async_run(const llama_sampling_params& sparams, int n_seq_dft,

bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llama_context *ctx_dft,
std::deque<int> &free_sequence_offsets, int max_seq, llama_batch &batch_tgt, int n_predict,
- int prefix_n_past, int n_past_dft, bool has_eos, llama_sampling_context *ctx_sampling,
+ int prefix_n_past, int n_past_dft, llama_sampling_context *ctx_sampling,
std::deque<struct seq_async_run> &tgt_cgraphs, const seq_async_run &current_run,
int &spec_past_tgt, int &spec_past_dft, int first_run, int orig_offset, int32_t &batch_id,
llama_batch &batch_dft, int &n_drafted, std::vector<seq_draft> &drafts, llama_token &id,
- llama_kv_cache_view &kvc, int iter);
+ llama_kv_cache_view &kvc, float p_adjust, int &n_reject);

void begin_non_spec_run(const gpt_params &params, int n_seq_dft, llama_context *ctx, int max_seq,
const std::vector<seq_draft> &drafts, llama_token id, int32_t &batch_id, int &n_past, int n_past_dft,
std::deque<struct seq_async_run> &dft_cgraphs, llama_kv_cache_view &kvc_view);

void
run_speculation_loop(const gpt_params &params, const float p_accept, llama_context *ctx_tgt, llama_context *ctx_dft,
- int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft,
- bool has_eos, llama_sampling_context *ctx_sampling, int & spec_past_tgt, int & spec_past_dft,
- bool & first_run, std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft,
- int &n_drafted, std::vector<seq_draft> &drafts, std::deque<struct seq_async_run> &tgt_cgraphs,
- seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id);
+ const int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft,
+ llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft, bool &first_run,
+ std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft, int &n_drafted,
+ std::vector<seq_draft> &drafts, std::deque<struct seq_async_run> &tgt_cgraphs,
+ seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id, int &n_rejected);

+ float calc_p_adjust(const gpt_params &params, int iter, int n_reject);
+
int main(int argc, char ** argv) {
gpt_params params;
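The declaration changes above summarize the plumbing: bool has_eos is dropped from start_async_spec_run and run_speculation_loop, a p_adjust/n_reject (or n_rejected) pair is threaded through instead, and the new calc_p_adjust helper is forward-declared here; its definition appears further down in this file.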
@@ -307,6 +309,9 @@ int main(int argc, char ** argv) {

bool first_run = true;
llama_token id;
+
+ int n_rejected = 0;
+
while (true) {


@@ -389,10 +394,10 @@ int main(int argc, char ** argv) {
first_run = true;

} else if (!tgt_cgraphs.empty()) {
- run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt, n_past_dft,
- has_eos, ctx_sampling,
+ run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt,
+ n_past_dft, ctx_sampling,
spec_past_tgt, spec_past_dft, first_run, free_sequence_offsets, batch_id, batch_dft,
- n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id);
+ n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id, n_rejected);
continue;
}

@@ -570,6 +575,8 @@ int main(int argc, char ** argv) {
continue;
}

+ n_rejected = 0;
+
check_for_cancel(ctx_tgt, n_past_tgt, tgt_cgraphs, generated, n_seq_dft);


@@ -656,9 +663,9 @@ int main(int argc, char ** argv) {
// bool is_waiting = false;

run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt, n_past_dft,
- has_eos, ctx_sampling,
+ ctx_sampling,
spec_past_tgt, spec_past_dft, first_run, free_sequence_offsets, batch_id, batch_dft,
- n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id);
+ n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id, n_rejected);


if (n_predict > params.n_predict || has_eos) {
@@ -733,10 +740,10 @@ int main(int argc, char ** argv) {
void
run_speculation_loop(const gpt_params &params, const float p_accept, llama_context *ctx_tgt, llama_context *ctx_dft,
const int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft,
- bool has_eos, llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft,
- bool & first_run, std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft,
- int &n_drafted, std::vector<seq_draft> &drafts, std::deque<struct seq_async_run> &tgt_cgraphs,
- seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id) {
+ llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft, bool &first_run,
+ std::deque<int> &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft, int &n_drafted,
+ std::vector<seq_draft> &drafts, std::deque<struct seq_async_run> &tgt_cgraphs,
+ seq_async_run &current_run, llama_kv_cache_view &kvc_view_dft, llama_token &id, int &n_rejected) {
bool is_waiting = llama_mpi_iprobe(ctx_tgt);
llama_swap_comm(ctx_tgt);
llama_sync_token(ctx_tgt, reinterpret_cast<llama_token *>(&is_waiting), 0);
@@ -749,7 +756,8 @@ run_speculation_loop(const gpt_params &params, const float p_accept, llama_conte
free_sequence_offsets.push_back(current_run.seq_offset);
}
int iter = 0;
- while((!is_waiting && (p_accept + iter * params.p_recovery) < 1.0)) {
+ float p_adjust = calc_p_adjust(params, iter, n_rejected);
+ while((!is_waiting && p_accept + (p_adjust = calc_p_adjust(params, iter, n_rejected)) < 1.0)) {
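A reading note on the rewritten condition (a sketch for clarity, not a proposed change): the assignment inside the test refreshes p_adjust on every pass, so the loop is equivalent to:

    while (!is_waiting) {
        p_adjust = calc_p_adjust(params, iter, n_rejected);
        if (p_accept + p_adjust >= 1.0f) break; // adjusted threshold saturated
        // ... launch the next speculative run, advancing iter ...
    }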



@@ -790,10 +798,10 @@ run_speculation_loop(const gpt_params &params, const float p_accept, llama_conte


if (start_async_spec_run(params, ctx_tgt, ctx_dft, free_sequence_offsets, max_seq,
- batch_tgt, n_predict, n_past_tgt, n_past_dft, has_eos, ctx_sampling,
+ batch_tgt, n_predict, n_past_tgt, n_past_dft, ctx_sampling,
tgt_cgraphs,
current_run, spec_past_tgt, spec_past_dft, first_run, orig_offset,
- batch_id, batch_dft, n_drafted, drafts, id, kvc_view_dft, iter)) {
+ batch_id, batch_dft, n_drafted, drafts, id, kvc_view_dft, p_adjust, n_reject)) {
LOG("Ending spec run because returned true\n");
break;
}
@@ -810,6 +818,10 @@ run_speculation_loop(const gpt_params &params, const float p_accept, llama_conte
}
}

+ float calc_p_adjust(const gpt_params &params, int iter, int n_reject) {
+ return iter * params.p_recovery - std::max(n_reject * params.p_decay, 0.0f);
+ }
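A worked example of calc_p_adjust(), with hypothetical flag values: given --p-recovery 0.05 and --p-decay 0.02, the third idle iteration (iter = 3) after two rejected runs (n_reject = 2) yields p_adjust = 3 * 0.05 - max(2 * 0.02, 0) = 0.11, so a base --p-accept of 0.5 acts as an effective stop threshold of 0.61. The std::max only keeps the decay term from going negative; with non-negative n_reject and p_decay it is a no-op.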

void begin_non_spec_run(const gpt_params &params, const int n_seq_dft, llama_context *ctx, const int max_seq,
const std::vector<seq_draft> &drafts, llama_token id, int32_t &batch_id, int &n_past,
int n_past_dft,
@@ -843,11 +855,11 @@ void begin_non_spec_run(const gpt_params &params, const int n_seq_dft, llama_con

bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llama_context *ctx_dft,
std::deque<int> &free_sequence_offsets, int max_seq, llama_batch &batch_tgt, int n_predict,
- int prefix_n_past, int n_past_dft, bool has_eos, llama_sampling_context *ctx_sampling,
+ int prefix_n_past, int n_past_dft, llama_sampling_context *ctx_sampling,
std::deque<struct seq_async_run> &tgt_cgraphs, const seq_async_run &current_run,
int &spec_past_tgt, int &spec_past_dft, int first_run, int orig_offset, int32_t &batch_id,
llama_batch &batch_dft, int &n_drafted, std::vector<seq_draft> &drafts, llama_token &id,
- llama_kv_cache_view &kvc, const int iter) {
+ llama_kv_cache_view &kvc, float p_adjust, int &n_reject) {
LOG("Doing speculative run, seq_offset = %d, spec_past_tgt = %d, spec_past_dft = %d, prefix_n_past = %d, n_past_dft = %d\n",
current_run.seq_offset, spec_past_tgt, spec_past_dft, prefix_n_past, n_past_dft);

@@ -958,7 +970,7 @@ bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llam
}


- if (cur_p[0].p < params.p_accept + params.p_recovery * iter) {
+ if (cur_p[0].p < params.p_accept + p_adjust) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p,
params.p_accept);
drafts[s].drafting = false;
@@ -970,7 +982,7 @@ bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llam

// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
- if (n_seq_cur < params.n_parallel - 1 && cur_p[f].p > params.p_split + params.p_recovery * iter) {
+ if (n_seq_cur < params.n_parallel - 1 && cur_p[f].p > params.p_split + p_adjust) {
n_seq_cur++;
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);

@@ -1093,6 +1105,7 @@ bool start_async_spec_run(const gpt_params &params, llama_context *ctx_tgt, llam
// fprintf(stderr, "\nNo tgt tokens, pushing seq offset %d to free seq offsets\n", current_run.seq_offset);
// fflush(stderr);
free_sequence_offsets.push_back(current_run.seq_offset);
+ n_reject++;
return true;
}
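This is where the decay signal originates: a speculative run that yields no accepted target tokens returns its sequence offset to the free pool and bumps n_reject, lowering p_adjust on later iterations; the counter is cleared again in main() (the n_rejected = 0; line added earlier), which reads as resetting decay once a run's tokens survive validation.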

Expand Down

0 comments on commit d6a70a9

Please sign in to comment.