Commit
Mostly working async
AutonomicPerfectionist committed Nov 17, 2023
1 parent 6903128 commit 8b0aa86
Showing 4 changed files with 177 additions and 71 deletions.
179 changes: 122 additions & 57 deletions examples/speculative/speculative.cpp
@@ -30,6 +30,9 @@ struct seq_draft {
struct seq_async_run {
struct ggml_cgraph * cgraph;
llama_batch batch;
std::vector<seq_draft> drafts;
int run_id;
int n_past_tgt;
};

int main(int argc, char ** argv) {
@@ -174,6 +177,7 @@ int main(int argc, char ** argv) {
int n_drafted = 0;
int n_accept = 0;

const int ASYNC_RUN_ID = n_seq_dft+1;
int n_past_tgt = inp.size();
int n_past_dft = inp.size();

@@ -205,6 +209,8 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0;

int run_id = 0;

while (true) {
// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
@@ -226,31 +232,47 @@ int main(int argc, char ** argv) {
struct ggml_cgraph * cgraph = run.cgraph;
llama_finish_async_decode(*ctx_tgt, run.batch, cgraph);
tgt_cgraphs.pop_back();
run_id = run.run_id;
if (run_id == ASYNC_RUN_ID) {
llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt);

}
}

llama_token id;
std::string token_str;
while (true) {
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);



// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);

id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
// Swap to pipeline roots
llama_swap_comm(ctx_tgt);
LOG("Swapped comm to pipeline roots, id %d\n", llama_node_id(ctx_tgt));

llama_sync_token(ctx_tgt, &id, 0);


LOG("Is async: %d\n", !should_run_async);
LOG("Sampling index: %d\n", drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);

//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());

const int n_vocab = llama_n_vocab(llama_get_model(ctx_tgt));
float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

LOG("logits:\n");
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
LOG("\t%d: %.4f\n", token_id, logits[token_id]);
}

// Root of WORLD
std::string token_str;

if (llama_node_id(ctx_tgt) == 0) {
std::string token_str = llama_token_to_piece(ctx_tgt, id);
token_str = llama_token_to_piece(ctx_tgt, id);
LOG("Sampled token: %d ('%s'), n_past_tgt: %d\n", id, token_str.c_str(), n_past_tgt);
printf("%s", token_str.c_str());
fflush(stdout);
}
@@ -267,7 +289,7 @@ int main(int argc, char ** argv) {
++n_predict;

// check if the target token matches any of the drafts
{
if(should_run_async){ // Only running this when should_run_async starts out okay but still goes off the rails eventually
bool matches = false;

for (int s = 0; s < n_seq_dft; ++s) {
@@ -296,76 +318,116 @@ int main(int argc, char ** argv) {
}


if (llama_node_id(ctx_tgt) < 0) {
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());

}

break;
}

// TODO: simplify
{
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
if (llama_node_id(ctx_tgt) < 0) {
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());

// Pipeline syncing cache ops
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
}

llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
}

if (should_run_async) {
LOG("Beginning async decode\n");
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, id, n_past_tgt, {0}, true);
// batch_tgt.n_tokens = 1
// TODO: simplify
{
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);

// Pipeline syncing cache ops
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);

for (int i = 0; i < n_seq_dft; i++) {
if (run_id == ASYNC_RUN_ID) {
llama_kv_cache_seq_rm(ctx_tgt, i + 0, n_past_tgt, -1);
} else {
// llama_kv_cache_seq_rm(ctx_tgt, i + ASYNC_RUN_ID, n_past_tgt, -1);

struct seq_async_run run;
run.cgraph = llama_start_async_decode(*ctx_tgt, batch_tgt);
run.batch = batch_tgt;
tgt_cgraphs.push_front(run);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_past_tgt + 1);
}
// llama_kv_cache_seq_rm (ctx_tgt, i+run_id, n_past_tgt, -1);

}
// llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, run_id, -1, n_past_tgt);
// llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int i = 1; i < n_seq_dft; i++) {
// llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, -1, n_past_tgt);
llama_kv_cache_seq_rm (ctx_tgt, i+run_id, -1, n_past_tgt);

}
llama_kv_cache_seq_rm (ctx_tgt, run_id, n_past_tgt, n_past_tgt+2);
// llama_kv_cache_seq_rm (ctx_tgt, 0, n_past_tgt, n_past_tgt+2);


should_run_async = !should_run_async;
llama_kv_cache_seq_cp (ctx_tgt, run_id, 0, -1, n_past_tgt);

}

if (should_run_async) {
// LOG("Beginning async decode\n");
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, id, n_past_tgt, { 0 }, true);
llama_batch_add(batch_tgt, id, n_past_tgt, {ASYNC_RUN_ID}, true);
// batch_tgt.n_tokens = 1

for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();

for (int i = 0; i < n_seq_dft; i++) {
llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, n_past_tgt, -1);
}
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(id);
drafts[0].i_batch_tgt.push_back(0);

llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
// batch_dft.n_tokens == 1 now
llama_kv_cache_seq_cp (ctx_tgt, run_id, ASYNC_RUN_ID, -1, n_past_tgt);

// Pipeline sync on draft pipeline
// llama_kv_cache_seq_keep(ctx_tgt, s_keep);
// llama_kv_cache_seq_cp (ctx_tgt, s_keep+run_id, ASYNC_RUN_ID, -1, n_past_tgt);
// llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int i = 1; i < n_seq_dft; i++) {
// llama_kv_cache_seq_rm (ctx_tgt, i+ASYNC_RUN_ID, -1, n_past_tgt);
}
// llama_kv_cache_seq_rm (ctx_tgt, ASYNC_RUN_ID, n_past_tgt, n_past_tgt+2);

// Remove all tokens from all sequences after n_past_dft
llama_kv_cache_seq_rm(ctx_dft, -1, n_past_dft, -1);

// Kick off drafting pipeline but don't need it just yet
LOG("Beginning async draft\n");
dft_cgraphs.push_front(llama_start_async_decode(*ctx_dft, batch_dft));
//llama_decode(ctx_dft, batch_dft);
// DON'T FORGET THE MATCHING DECODE WHEN NEEDED

++n_past_dft;
struct seq_async_run run;
run.cgraph = llama_start_async_decode(*ctx_tgt, batch_tgt);
run.batch = batch_tgt;
run.run_id = ASYNC_RUN_ID;
run.n_past_tgt = n_past_tgt;
tgt_cgraphs.push_front(run);
// llama_kv_cache_seq_rm(ctx_tgt, ASYNC_RUN_ID, n_past_tgt, n_past_tgt + 2);

break;
}

should_run_async = !should_run_async;

llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, id, n_past_tgt, { 0 }, true);

for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();
}
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(id);
drafts[0].i_batch_tgt.push_back(0);

llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
// batch_dft.n_tokens == 1 now

// Pipeline sync on draft pipeline

// Remove all tokens from all sequences after n_past_dft
llama_kv_cache_seq_rm(ctx_dft, -1, n_past_dft, -1);

// Kick off drafting pipeline but don't need it just yet
LOG("Beginning async draft\n");
dft_cgraphs.push_front(llama_start_async_decode(*ctx_dft, batch_dft));
//llama_decode(ctx_dft, batch_dft);
// DON'T FORGET THE MATCHING DECODE WHEN NEEDED

++n_past_dft;

if (n_predict > params.n_predict || has_eos) {
break;
}
@@ -487,6 +549,7 @@ int main(int argc, char ** argv) {
}

// add drafted token for each sequence
// TODO commenting this out fixes async
for (int is = 0; is < (int) sa.size(); ++is) {
const llama_token id = cur_p[is].id;

@@ -499,7 +562,7 @@ int main(int argc, char ** argv) {
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s,s+ASYNC_RUN_ID }, true);

// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
@@ -530,9 +593,9 @@ int main(int argc, char ** argv) {

// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
// llama_kv_cache_seq_keep(ctx_tgt, 0); // Needed to get to "Here's the code:"
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_kv_cache_seq_cp(ctx_tgt, run_id, s+run_id, -1, n_past_tgt);
}

++n_past_tgt;
@@ -541,6 +604,8 @@ int main(int argc, char ** argv) {
LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
struct seq_async_run run;
run.batch = batch_tgt;
run.run_id = 0;
run.n_past_tgt = n_past_tgt;
run.cgraph = llama_start_async_decode(*ctx_tgt, batch_tgt);
tgt_cgraphs.push_front(run);

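For reference, the speculative.cpp changes revolve around a small queue of in-flight target decodes: a run is pushed at the front of tgt_cgraphs when its graph is started and drained from the back when its logits are needed, so the oldest outstanding decode finishes first. Below is a minimal standalone sketch of that queue discipline, assuming the llama_start_async_decode / llama_finish_async_decode API introduced on this branch; the start_run and finish_oldest_run helpers are hypothetical names and the drafts member of seq_async_run is omitted, so this illustrates the queueing pattern only, not the commit's exact code.

#include <deque>

#include "llama.h"

// Abridged version of the struct added in speculative.cpp.
struct seq_async_run {
    struct ggml_cgraph * cgraph;     // graph of the decode that is still running
    llama_batch          batch;      // batch the graph was built from
    int                  run_id;     // 0 for the regular path, ASYNC_RUN_ID for the extra async run
    int                  n_past_tgt; // target-context position the run was started at
};

static std::deque<seq_async_run> tgt_cgraphs;

// Start a decode without waiting for it and remember it at the front of the queue.
static void start_run(llama_context * ctx_tgt, llama_batch batch_tgt, int run_id, int n_past_tgt) {
    seq_async_run run;
    run.batch      = batch_tgt;
    run.run_id     = run_id;
    run.n_past_tgt = n_past_tgt;
    run.cgraph     = llama_start_async_decode(*ctx_tgt, batch_tgt);
    tgt_cgraphs.push_front(run);
}

// Block on the oldest in-flight decode so its logits can be sampled, then drop it from the queue.
static seq_async_run finish_oldest_run(llama_context * ctx_tgt) {
    seq_async_run run = tgt_cgraphs.back();
    llama_finish_async_decode(*ctx_tgt, run.batch, run.cgraph);
    tgt_cgraphs.pop_back();
    return run;
}
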
52 changes: 43 additions & 9 deletions ggml-mpi.c
@@ -128,6 +128,21 @@ int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) {
return (ctx_mpi->rank + 1) % ctx_mpi->size;
}

void ggml_mpi_sync_pipelined_recv(
struct ggml_mpi_context * ctx_mpi,
void * val,
int count,
MPI_Datatype datatype,
int tag
) {
if(ctx_mpi->comm == MPI_COMM_NULL) {
return;
}
MPI_Recv(val, count, datatype, ctx_mpi->rank - 1, tag, ctx_mpi->comm, MPI_STATUS_IGNORE);

}


void ggml_mpi_sync_pipelined(
struct ggml_mpi_context * ctx_mpi,
void * val,
@@ -141,7 +156,7 @@ void ggml_mpi_sync_pipelined(
if (ctx_mpi->rank != 0) {
MPI_Recv(val, count, datatype, ctx_mpi->rank - 1, tag, ctx_mpi->comm, MPI_STATUS_IGNORE);
}
if(ctx_mpi->rank < ctx_mpi->size - 1) {
if(ctx_mpi->rank < ctx_mpi->size) {
const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm);
GGML_ASSERT(retval == MPI_SUCCESS);

@@ -151,16 +166,23 @@ bool ggml_mpi_eval_init(
bool ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int32_t * n_tokens,
int32_t ** tokens,
int32_t ** pos,
int32_t ** n_seq_ids,
int32_t *** seq_id,
int8_t ** logits) {
int8_t ** logits,
bool receive_only) {
if(ctx_mpi->comm == MPI_COMM_NULL) {
return false;
}
int32_t old_n_tokens = *n_tokens;

ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, 0);
if (receive_only) {
ggml_mpi_sync_pipelined_recv(ctx_mpi, n_tokens, 1, MPI_INT, 0);

} else {
ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, 0);
}

// If what was passed in differs from what was broadcast,
// we can't guarantee the allocated sizes are correct
@@ -169,14 +191,22 @@ bool ggml_mpi_eval_init(
if (old_n_tokens != *n_tokens) {
*pos = realloc(*pos, *n_tokens * sizeof(int32_t));
*n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t ));
*logits = realloc(*logits, *n_tokens * sizeof(int32_t));
*tokens = realloc(*tokens, *n_tokens * sizeof(int32_t ));
}

if (receive_only) {
ggml_mpi_sync_pipelined_recv(ctx_mpi, *tokens, *n_tokens, MPI_INT32_T, 0);

} else {
ggml_mpi_sync_pipelined(ctx_mpi, *tokens, *n_tokens, MPI_INT32_T, 0);
}

// MPI_Bcast(&total_n_seq_ids, 1, MPI_INT32_T, 0, ctx_mpi->comm);
ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, 0);

if (receive_only) {
ggml_mpi_sync_pipelined_recv(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, 0);
} else {
ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, 0);
}
// We need to know the total number of sequence
// ids, so we count them all up
int32_t total_n_seq_ids = 0;
@@ -201,9 +231,13 @@ bool ggml_mpi_eval_init(
}


ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, 0);
ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0);
//MPI_Bcast(*logits, *n_tokens, MPI_INT8_T, 0, ctx_mpi->comm);
if (receive_only) {
ggml_mpi_sync_pipelined_recv(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, 0);
ggml_mpi_sync_pipelined_recv(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0);
} else {
ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, 0);
ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, 0);
}
int32_t ** new_seq_id = calloc(*n_tokens, sizeof(int32_t*));
current_index = 0;
for (int32_t i = 0; i < *n_tokens; i++) {
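The ggml-mpi.c changes build on a pipelined sync primitive: rank 0 originates a value, every other rank first receives it from its predecessor, and each rank then forwards it to the next node so the value flows down the pipeline. The new ggml_mpi_sync_pipelined_recv variant performs only the receive half, and ggml_mpi_eval_init now takes a receive_only flag to choose between the two. The following is a minimal standalone sketch of the forwarding pattern, using plain MPI_Send instead of the buffered MPI_Bsend in the real code, and stopping at the last rank rather than wrapping back to rank 0 as the changed guard (rank < size) allows; it is an illustration of the idea, not this file's implementation.

#include <mpi.h>

static int next_node(int rank, int size) {
    return (rank + 1) % size;
}

// Pipelined broadcast: rank 0 already holds *val, every later rank receives it
// from its predecessor and passes it on, so the value travels 0 -> 1 -> ... -> size-1.
static void pipelined_sync(MPI_Comm comm, void * val, int count, MPI_Datatype datatype, int tag) {
    if (comm == MPI_COMM_NULL) {
        return;
    }

    int rank = 0;
    int size = 0;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);

    if (rank != 0) {
        // a receive-only variant of the sync would stop after this call
        MPI_Recv(val, count, datatype, rank - 1, tag, comm, MPI_STATUS_IGNORE);
    }

    if (rank < size - 1) {
        MPI_Send(val, count, datatype, next_node(rank, size), tag, comm);
    }
}
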
