Skip to content

Commit

Permalink
Fix incorrect initial s_keep for multi-sequence runs
Browse files: browse the repository at this point in the history
  • Loading branch information
AutonomicPerfectionist committed Jan 10, 2024
1 parent 0082332 commit 6b55f04
Showing 1 changed file with 35 additions and 25 deletions.
60 changes: 35 additions & 25 deletions examples/speculative/speculative.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -284,21 +284,24 @@ int main(int argc, char ** argv) {
LOG("Finishing async decode, is async = %d, old seq_offset = %d, new seq offset = %d, batch id = %d\n", run.run_id == ASYNC_RUN_ID, seq_offset, run.seq_offset, run.batch.batch_id);
struct ggml_cgraph * cgraph = run.cgraph;

run_id = run.run_id;
drafts = run.drafts;
run_speculative = run.speculative;
run_max_n_past = run.n_past_max;
// ctx_sampling = run.ctx_sampling;
run_n_past_tgt = run.n_past_tgt;
run_n_past_dft = run.n_past_dft;
// n_past_dft = run.n_past_dft;
seq_offset = run.seq_offset;


LOG("Checking run, last generated: %d, first draft: %d\n", generated.back(), drafts[s_keep].tokens[0]);
LOG("Checking run, last generated: %d, first draft: %d\n", generated.back(), run.drafts[run.s_keep].tokens[0]);
// if(run.n_past_max >= n_past_tgt && (!run_speculative || (n_past_tgt-run_n_past_tgt >= 0 && generated.at(generated.size() - (n_past_tgt-run_n_past_tgt+1)) == drafts[s_keep].tokens[0]))) {

if(!run.canceled) {

run_id = run.run_id;
drafts = run.drafts;
run_speculative = run.speculative;
run_max_n_past = run.n_past_max;
// ctx_sampling = run.ctx_sampling;
run_n_past_tgt = run.n_past_tgt;
run_n_past_dft = run.n_past_dft;
// n_past_dft = run.n_past_dft;
seq_offset = run.seq_offset;
s_keep = run.s_keep;

//drafts[0].tokens.erase(drafts[0].tokens.begin());
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
Expand All @@ -316,8 +319,8 @@ int main(int argc, char ** argv) {
// }
llama_finish_async_decode(*ctx_tgt, run.batch, cgraph);
tgt_cgraphs.pop_back();
if (run_speculative) {
free_sequence_offsets.push_back(seq_offset);
if (run.speculative) {
free_sequence_offsets.push_back(run.seq_offset);
}
// fprintf(stderr, "Incorrect starting token\n");
continue;
Expand Down Expand Up @@ -359,7 +362,9 @@ int main(int argc, char ** argv) {
int old_n_past_dft = n_past_dft;


std::vector<int> keeps = seq_ids;
std::deque<int> keeps(seq_ids.begin(), seq_ids.end());
keeps.erase(std::find(keeps.begin(), keeps.end(),s_keep));
keeps.push_front(s_keep);
while (!keeps.empty()) {

LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d, run_n_past_tgt = %3d, n_past_tgt = %3d, seq_offset = %d, keeps[0] = %d\n", s_keep, i_dft, drafts[keeps[0]].i_batch_tgt[i_dft], run_n_past_tgt, n_past_tgt, seq_offset, keeps[0]);
Expand Down Expand Up @@ -500,7 +505,7 @@ int main(int argc, char ** argv) {
}
// LOG("Copying tgt sequence %d to %d from positions %d to %d\n", s_keep+seq_offset, 0, run_n_past_tgt, n_past_tgt);
// llama_kv_cache_seq_cp_back (ctx_tgt, s_keep+seq_offset, 0, run_n_past_tgt, n_past_tgt);
llama_kv_cache_seq_cp (ctx_tgt, s_keep+seq_offset, 0, old_n_past_tgt, n_past_tgt);
llama_kv_cache_seq_cp (ctx_tgt, s_keep+seq_offset, 0, run_n_past_tgt, n_past_tgt);

// if (llama_node_id(ctx_tgt) == 0) {
// llama_kv_cache_view_update(ctx_tgt, &kvc_view);
Expand All @@ -521,11 +526,11 @@ int main(int argc, char ** argv) {

// LOG("Copying dft sequence %d to %d from positions %d to %d\n", s_keep+seq_offset, 0, run_n_past_dft, n_past_dft);

llama_kv_cache_seq_cp (ctx_dft, s_keep+seq_offset, 0, old_n_past_dft, n_past_dft);
llama_kv_cache_seq_cp (ctx_dft, s_keep+seq_offset, 0, run_n_past_dft, n_past_dft);
for (int i = 0; i < n_seq_dft; i++) {
// LOG("Removing tgt sequence %d from positions %d to %d\n", i+seq_offset, -1, -1);

llama_kv_cache_seq_rm (ctx_tgt, i+seq_offset, old_n_past_tgt, -1);
llama_kv_cache_seq_rm (ctx_tgt, i+seq_offset, run_n_past_tgt, -1);

// if (llama_node_id(ctx_tgt) == 0) {
// llama_kv_cache_view_update(ctx_tgt, &kvc_view);
Expand All @@ -536,16 +541,16 @@ int main(int argc, char ** argv) {

// LOG("Removing dft sequence %d from positions %d to %d\n", i+seq_offset, -1, -1);

llama_kv_cache_seq_rm (ctx_dft, i+seq_offset, old_n_past_dft, -1);
llama_kv_cache_seq_rm (ctx_dft, i+seq_offset, run_n_past_dft, -1);
}


for (int i = 1; i < max_seq; i++) {
// LOG("Copying tgt sequence %d to %d from positions %d to %d\n", 0, i, -1, n_past_tgt);
// LOG("Copying dft sequence %d to %d from positions %d to %d\n", 0, i, -1, n_past_dft);

llama_kv_cache_seq_rm(ctx_tgt, i, old_n_past_tgt, n_past_tgt);
llama_kv_cache_seq_rm(ctx_dft, i, old_n_past_dft, n_past_dft);
llama_kv_cache_seq_rm(ctx_tgt, i, run_n_past_tgt, n_past_tgt);
llama_kv_cache_seq_rm(ctx_dft, i, run_n_past_dft, n_past_dft);
//
// if (llama_node_id(ctx_tgt) == 0) {
//// llama_kv_cache_view_update(ctx_tgt, &kvc_view);
Expand All @@ -554,7 +559,7 @@ int main(int argc, char ** argv) {
// printf("Removed %d, n_past_tgt: %d, run_n_past_tgt: %d, run_max_n_past: %d, old_n_past: %d\n", i+seq_offset, n_past_tgt, run_n_past_tgt, run_max_n_past, old_n_past_tgt);
// }

llama_kv_cache_seq_cp(ctx_tgt, 0, i, old_n_past_tgt, n_past_tgt);
llama_kv_cache_seq_cp(ctx_tgt, 0, i, run_n_past_tgt, n_past_tgt);

// if (llama_node_id(ctx_tgt) == 0) {
//// llama_kv_cache_view_update(ctx_tgt, &kvc_view);
Expand All @@ -563,7 +568,7 @@ int main(int argc, char ** argv) {
// printf("Copied 0 to %d, n_past_tgt: %d, run_n_past_tgt: %d, run_max_n_past: %d, old_n_past: %d\n", i, n_past_tgt, run_n_past_tgt, run_max_n_past, old_n_past_tgt);
// }

llama_kv_cache_seq_cp(ctx_dft, 0, i, old_n_past_dft, n_past_dft);
llama_kv_cache_seq_cp(ctx_dft, 0, i, run_n_past_dft, n_past_dft);
}

// if (llama_node_id(ctx_tgt) == 0) {
Expand Down Expand Up @@ -615,6 +620,7 @@ int main(int argc, char ** argv) {
++n_past_tgt;
struct seq_async_run run;
run.canceled = false;
run.s_keep = 0;
// if (!free_sequence_offsets.empty()) {
// run.seq_offset = free_sequence_offsets.front();
// printf("Popping %d from seq offsets\n", run.seq_offset);
Expand Down Expand Up @@ -652,7 +658,6 @@ int main(int argc, char ** argv) {
run.drafts[s].prefix_tokens = std::vector<llama_token>(0);
}
run.i_dft = offset - 1;
run.s_keep = s_keep;
run.run_id = ASYNC_RUN_ID;
run.n_past_tgt = n_past_tgt-1;
run.prefix_n_past_tgt = n_past_tgt-1;
Expand Down Expand Up @@ -768,7 +773,7 @@ int main(int argc, char ** argv) {
LOG("Copying tgt sequence %d to %d from positions %d to %d\n", (first_run) ? 0 : orig_offset,
i + seq_offset, -1, (first_run) ? spec_past_tgt : spec_past_tgt);

llama_kv_cache_seq_cp(ctx_tgt, (first_run) ? 0 : orig_offset, i + seq_offset, -1, (first_run) ? spec_past_tgt : spec_past_tgt);
llama_kv_cache_seq_cp(ctx_tgt, (first_run) ? 0 : orig_offset, i + seq_offset, -1, (first_run) ? spec_past_tgt : spec_past_tgt+1);
// if (llama_node_id(ctx_tgt) == 0) {
// llama_kv_cache_view_update(ctx_tgt, &kvc_view);
// dump_kv_cache_view_seqs(kvc_view, 20);
Expand Down Expand Up @@ -1086,7 +1091,7 @@ int main(int argc, char ** argv) {
run.drafts[s].prefix_tokens = drafts[s].prefix_tokens;
}
run.i_dft = offset;
run.s_keep = s_keep;
run.s_keep = 0;
run.batch = llama_batch_init(params.n_ctx, 0, max_seq);
run.batch.batch_id = batch_id;
run.batch.n_tokens = batch_tgt.n_tokens;
Expand Down Expand Up @@ -1182,9 +1187,11 @@ void check_for_cancel(llama_context *ctx_tgt, int n_past_tgt, std::deque<struct
bool correct_prefix = true;

if (run.speculative && n_past_tgt >= run.prefix_n_past_tgt) {
for (int draft_id = 0; draft_id < n_seq_dft; draft_id++) {
for (int draft_id = n_seq_dft - 1; draft_id >= 0; draft_id--) {
if (!run.drafts[draft_id].tokens.empty()) {
correct_prefix = true;
} else {
continue;
}
size_t draft_index = 0;
int prev_token = -1;
Expand Down Expand Up @@ -1217,6 +1224,9 @@ void check_for_cancel(llama_context *ctx_tgt, int n_past_tgt, std::deque<struct
draft_index++;
index = run.prefix_n_past_tgt + draft_index;
}
if (correct_prefix) {
run.s_keep = draft_id;
}
}
}

Expand Down

0 comments on commit 6b55f04

Please sign in to comment.