diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index a0e5d06169edac..4e62e64cfc48e8 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -236,6 +236,9 @@ int main(int argc, char ** argv) { if (run_id == ASYNC_RUN_ID) { llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt); + } else { + llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt); + } } @@ -338,29 +341,35 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); llama_kv_cache_seq_keep(ctx_dft, 0); - for (int i = 0; i < n_seq_dft; i++) { - if (run_id == ASYNC_RUN_ID) { - llama_kv_cache_seq_rm(ctx_tgt, i + 0, n_past_tgt, -1); - } else { -// llama_kv_cache_seq_rm(ctx_tgt, i + ASYNC_RUN_ID, n_past_tgt, -1); +// llama_kv_cache_seq_rm (ctx_tgt, s_keep+run_id, n_past_tgt, -1); + llama_kv_cache_seq_rm (ctx_tgt, 0, n_past_tgt, -1); // Forces "create" and nothing else effects it + + for (int i = 0; i < n_seq_dft; i++) { +// if (run_id == ASYNC_RUN_ID) { +// llama_kv_cache_seq_rm(ctx_tgt, i + 0, n_past_tgt, -1); +// } else { +//// llama_kv_cache_seq_rm(ctx_tgt, i + ASYNC_RUN_ID, n_past_tgt, -1); +// +// } + if (i != s_keep) { +// llama_kv_cache_seq_rm(ctx_tgt, i + 0, -1, n_past_tgt); } -// llama_kv_cache_seq_rm (ctx_tgt, i+run_id, n_past_tgt, -1); } // llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, run_id, -1, n_past_tgt); +// llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, run_id, -1, n_past_tgt); // llama_kv_cache_seq_keep(ctx_tgt, 0); for (int i = 1; i < n_seq_dft; i++) { // llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, -1, n_past_tgt); - llama_kv_cache_seq_rm (ctx_tgt, i+run_id, -1, n_past_tgt); +// llama_kv_cache_seq_rm (ctx_tgt, i+run_id, -1, n_past_tgt); } - llama_kv_cache_seq_rm (ctx_tgt, run_id, n_past_tgt, n_past_tgt+2); +// llama_kv_cache_seq_rm (ctx_tgt, run_id, n_past_tgt, n_past_tgt+2); // llama_kv_cache_seq_rm (ctx_tgt, 0, n_past_tgt, n_past_tgt+2); - llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt); +// llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt); } @@ -372,10 +381,10 @@ int main(int argc, char ** argv) { for (int i = 0; i < n_seq_dft; i++) { - llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, n_past_tgt, -1); +// llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, n_past_tgt, -1); } - llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt); +// llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt); // llama_kv_cache_seq_keep(ctx_tgt, s_keep); // llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, ASYNC_RUN_ID, -1, n_past_tgt); @@ -562,7 +571,7 @@ int main(int argc, char ** argv) { // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); - llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s,s+ASYNC_RUN_ID }, true); + llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); // add the token to the batch for batched decoding with the draft model drafts[s].i_batch_dft = batch_dft.n_tokens;