Skip to content

Commit

Permalink
Another partially working version
Browse files Browse the repository at this point in the history
  • Loading branch information
AutonomicPerfectionist committed Nov 17, 2023
1 parent 8b0aa86 commit 046ef60
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ int main(int argc, char ** argv) {
if (run_id == ASYNC_RUN_ID) {
llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt);

} else {
llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt);

}
}

Expand Down Expand Up @@ -338,29 +341,35 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);

for (int i = 0; i < n_seq_dft; i++) {
if (run_id == ASYNC_RUN_ID) {
llama_kv_cache_seq_rm(ctx_tgt, i + 0, n_past_tgt, -1);
} else {
// llama_kv_cache_seq_rm(ctx_tgt, i + ASYNC_RUN_ID, n_past_tgt, -1);
// llama_kv_cache_seq_rm (ctx_tgt, s_keep+run_id, n_past_tgt, -1);
llama_kv_cache_seq_rm (ctx_tgt, 0, n_past_tgt, -1); // Forces "create" and nothing else effects it


for (int i = 0; i < n_seq_dft; i++) {
// if (run_id == ASYNC_RUN_ID) {
// llama_kv_cache_seq_rm(ctx_tgt, i + 0, n_past_tgt, -1);
// } else {
//// llama_kv_cache_seq_rm(ctx_tgt, i + ASYNC_RUN_ID, n_past_tgt, -1);
//
// }
if (i != s_keep) {
// llama_kv_cache_seq_rm(ctx_tgt, i + 0, -1, n_past_tgt);
}
// llama_kv_cache_seq_rm (ctx_tgt, i+run_id, n_past_tgt, -1);

}
// llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, run_id, -1, n_past_tgt);
// llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, run_id, -1, n_past_tgt);
// llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int i = 1; i < n_seq_dft; i++) {
// llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, -1, n_past_tgt);
llama_kv_cache_seq_rm (ctx_tgt, i+run_id, -1, n_past_tgt);
// llama_kv_cache_seq_rm (ctx_tgt, i+run_id, -1, n_past_tgt);

}
llama_kv_cache_seq_rm (ctx_tgt, run_id, n_past_tgt, n_past_tgt+2);
// llama_kv_cache_seq_rm (ctx_tgt, run_id, n_past_tgt, n_past_tgt+2);
// llama_kv_cache_seq_rm (ctx_tgt, 0, n_past_tgt, n_past_tgt+2);


llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt);
// llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt);

}

Expand All @@ -372,10 +381,10 @@ int main(int argc, char ** argv) {


for (int i = 0; i < n_seq_dft; i++) {
llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, n_past_tgt, -1);
// llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, n_past_tgt, -1);
}

llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt);
// llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt);

// llama_kv_cache_seq_keep(ctx_tgt, s_keep);
// llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, ASYNC_RUN_ID, -1, n_past_tgt);
Expand Down Expand Up @@ -562,7 +571,7 @@ int main(int argc, char ** argv) {
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s,s+ASYNC_RUN_ID }, true);
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);

// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
Expand Down

0 comments on commit 046ef60

Please sign in to comment.