Skip to content

Commit

Permalink
Fix draft thread args and remove grads from mpi eval_init
Browse files Browse the repository at this point in the history
  • Loading branch information
AutonomicPerfectionist committed Feb 3, 2024
1 parent ecda8c9 commit 005f9cb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
46 changes: 32 additions & 14 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads.resize(split_arg.size());
for (size_t i = 0; i < split_arg.size(); ++i) {
params.n_threads[i] = std::stoi(split_arg[i]);
if (params.n_threads[i] <= 0) {
params.n_threads[i] = std::thread::hardware_concurrency();
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads[node] = std::stoi(split_arg[node]);
if (params.n_threads[node] <= 0) {
params.n_threads[node] = std::thread::hardware_concurrency();
}
}

Expand All @@ -184,29 +184,47 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_batch.resize(split_arg.size());
for (size_t i = 0; i < split_arg.size(); ++i) {
params.n_threads_batch[i] = std::stoi(split_arg[i]);
if (params.n_threads_batch[i] <= 0) {
params.n_threads_batch[i] = std::thread::hardware_concurrency();
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_batch[node] = std::stoi(split_arg[node]);
if (params.n_threads_batch[node] <= 0) {
params.n_threads_batch[node] = std::thread::hardware_concurrency();
}
}
} else if (arg == "-td" || arg == "--threads-draft") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_threads_draft = std::stoi(argv[i]);
if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency();
std::string arg_next = argv[i];

// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_draft.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_draft[node] = std::stoi(split_arg[node]);
if (params.n_threads_draft[node] <= 0) {
params.n_threads_draft[node] = std::thread::hardware_concurrency();
}
}
} else if (arg == "-tbd" || arg == "--threads-batch-draft") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_threads_batch_draft = std::stoi(argv[i]);
if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency();
std::string arg_next = argv[i];

// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
params.n_threads_batch_draft.resize(split_arg.size());
for (size_t node = 0; node < split_arg.size(); ++node) {
params.n_threads_batch_draft[node] = std::stoi(split_arg[node]);
if (params.n_threads_batch_draft[node] <= 0) {
params.n_threads_batch_draft[node] = std::thread::hardware_concurrency();
}
}
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
Expand Down
5 changes: 3 additions & 2 deletions ggml-mpi.c
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ void ggml_mpi_graph_compute_pre(
return;
}

GGML_ASSERT(inp0 == gf->nodes[0]);
// fprintf(stderr, "gf->nodes[0] == %s\n", ggml_get_name(gf->nodes[0]));
//
// GGML_ASSERT(inp0 == gf->nodes[0]);

// distribute the compute graph into slices across the MPI nodes
//
Expand Down Expand Up @@ -333,7 +335,6 @@ void ggml_mpi_graph_compute_pre(
// TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
for (int i = 1; i < idx_l1 - idx_l0; i++) {
gf->nodes[i] = gf->nodes[idx_l0 + i];
gf->grads[i] = gf->grads[idx_l0 + i];
}

// the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
Expand Down

0 comments on commit 005f9cb

Please sign in to comment.