diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index fb96dc9b605615..b87c88e627a78b 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -29,6 +29,7 @@ struct seq_draft { struct seq_async_run { struct ggml_cgraph * cgraph; + llama_batch batch; }; int main(int argc, char ** argv) { @@ -220,10 +221,10 @@ int main(int argc, char ** argv) { int s_keep = 0; if (!tgt_cgraphs.empty()) { - LOG("Finishing async decode\n"); + LOG("Finishing async decode, should_run_async = %d\n", should_run_async); struct seq_async_run run = tgt_cgraphs.back(); struct ggml_cgraph * cgraph = run.cgraph; - llama_finish_async_decode(*ctx_tgt, batch_tgt, cgraph); + llama_finish_async_decode(*ctx_tgt, run.batch, cgraph); tgt_cgraphs.pop_back(); } @@ -325,6 +326,7 @@ int main(int argc, char ** argv) { struct seq_async_run run; run.cgraph = llama_start_async_decode(*ctx_tgt, batch_tgt); + run.batch = batch_tgt; tgt_cgraphs.push_front(run); llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_past_tgt + 1); @@ -538,6 +540,7 @@ int main(int argc, char ** argv) { LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); struct seq_async_run run; + run.batch = batch_tgt; run.cgraph = llama_start_async_decode(*ctx_tgt, batch_tgt); tgt_cgraphs.push_front(run);