
Commit

control flow updates
Roger Waleffe authored and committed Nov 20, 2023
1 parent 05be4d6 commit 3e9b0b4
Showing 6 changed files with 116 additions and 61 deletions.
4 changes: 3 additions & 1 deletion src/cpp/include/nn/model.h
@@ -58,7 +58,9 @@ class Model : public torch::nn::Module {
bool first_epoch_;
std::atomic<bool> epoch_complete_;
std::mutex *pg_lock_;
std::mutex *update_feeders_lock_;
int last_compute_worker_;
std::atomic<bool> already_notified_;

Model(shared_ptr<GeneralEncoder> encoder, shared_ptr<Decoder> decoder, shared_ptr<LossFunction> loss,
shared_ptr<Reporter> reporter = nullptr, LearningTask learning_task = LearningTask::LINK_PREDICTION,
@@ -137,7 +139,7 @@ class Model : public torch::nn::Module {

void distModelAverage();

void distNotifyCompleteAndWait(bool eval = false);
void distNotifyCompleteAndWait(bool eval = false, bool wait = true);
};

shared_ptr<Model> initModelFromConfig(shared_ptr<ModelConfig> model_config, std::vector<torch::Device> devices, int num_relations, int num_partitions, bool train,
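Note on the header changes above: Model gains an update_feeders_lock_ mutex and an already_notified_ flag, and distNotifyCompleteAndWait takes a second parameter, wait, which defaults to true. A minimal sketch of how a caller might drive the two paths of the widened signature; the hook name and arguments are hypothetical and not part of this commit (the real call sites are in model.cpp and pipeline_cpu.cpp below):

    // Sketch only: hypothetical end-of-epoch hook, not code from this commit.
    // Assumes nn/model.h and <memory> are included and "model" is initialized.
    void onEpochBoundary(std::shared_ptr<Model> model, bool eval, bool can_block) {
        if (can_block) {
            // Existing behavior: notify the feeders and block until the epoch completes.
            model->distNotifyCompleteAndWait(eval);
        } else {
            // New path: send the completion notification and return immediately;
            // the caller then polls model->epoch_complete_ instead of blocking here.
            model->distNotifyCompleteAndWait(eval, /*wait=*/false);
        }
    }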
4 changes: 2 additions & 2 deletions src/cpp/src/common/util.cpp
@@ -203,8 +203,8 @@ torch::Tensor receive_tensor(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_i
}

torch::Tensor sizes = metadata.narrow(0, 0, dim);
std::cout<<sizes<<"\n";
std::cout<<dtype_label<<"\n\n";
// std::cout<<sizes<<"\n";
// std::cout<<dtype_label<<"\n\n";
int64_t *data_ptr = (int64_t *)sizes.data_ptr();
int64_t *end = (int64_t *)data_ptr + sizes.size(0);
std::vector<int64_t> sizes_vec = std::vector<int64_t>(data_ptr, end);
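Note: the only change in this file is silencing two per-tensor debug prints; batch.cpp below and most of pipeline_gpu.cpp do the same for per-batch timing prints. A sketch of one alternative that keeps the prints available without editing each call site, gating them behind a compile-time flag; the MARIUS_VERBOSE_TIMING name is an assumption, not a flag defined in this repository:

    // Sketch only: compile-time gated debug printing (MARIUS_VERBOSE_TIMING is assumed).
    #include <iostream>

    #ifdef MARIUS_VERBOSE_TIMING
    #define DEBUG_LOG(expr) (std::cout << expr << "\n")
    #else
    #define DEBUG_LOG(expr) ((void)0)
    #endif

    // The commented-out sites above would then read, e.g.:
    // DEBUG_LOG(sizes);
    // DEBUG_LOG(dtype_label);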
4 changes: 2 additions & 2 deletions src/cpp/src/data/batch.cpp
@@ -215,10 +215,10 @@ void Batch::remoteReceive(shared_ptr<c10d::ProcessGroupGloo> pg, int worker_id,

y_pred_ = receive_tensor(pg, worker_id, tag);
t.stop();
std::cout<<"batch recv: "<<t.getDuration()<<"\n";
// std::cout<<"batch recv: "<<t.getDuration()<<"\n";

t_full.stop();
std::cout<<"batch recv full: "<<t_full.getDuration()<<"\n";
// std::cout<<"batch recv full: "<<t_full.getDuration()<<"\n";
}

void Batch::accumulateGradients(float learning_rate) {
119 changes: 78 additions & 41 deletions src/cpp/src/nn/model.cpp
@@ -873,6 +873,8 @@ void Model::distPrepareForTraining(bool eval) {
return;
}

already_notified_ = false;

std::cout<<"distPrepareForTraining\n";

// set batch_worker_, compute_worker_, children, parents, num_batch_, num_compute_ here based on config,
@@ -948,9 +950,9 @@

feeders_ = feeders;

if (compute_worker_ and !eval) {
createComputePG(feeders, global_to_compute_worker, compute_worker_to_global);
}
// if (compute_worker_ and !eval) {
// createComputePG(feeders, global_to_compute_worker, compute_worker_to_global);
// }

std::thread(&Model::distListenForComplete, this, eval).detach();

@@ -970,11 +972,23 @@


// TODO: batch construction (i.e., everybody) waits for sync (with some sort of barrier)?


auto work = pg_gloo_->pg->barrier(); // TBD on if this actually works when we have batch construction workers or not
//// while (!work->isCompleted()) { std::cout<<"barrier waiting\n"; }
if (!work->wait()) {
throw work->exception();
}

std::cout<<"distPrepareForTraining barrier complete\n";


// exit(0);
}
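Note: distPrepareForTraining now ends with a barrier over the full Gloo process group, so no rank starts feeding or training until every rank has finished setup; the in-code comment flags that this is still unverified when batch construction workers are present. A sketch of the same pattern with a bounded wait, so a missing or stuck rank surfaces as an error instead of an indefinite hang; the 60-second value is an arbitrary assumption, and the timeout overload of wait() is the one hinted at in the commented-out code further down:

    // Sketch only: barrier with a bounded wait (timeout value is arbitrary).
    // Mirrors the error-handling pattern used elsewhere in this file.
    auto work = pg_gloo_->pg->barrier();
    if (!work->wait(std::chrono::milliseconds(60000))) {
        // Depending on the backend, a timed-out wait either returns false or throws itself.
        throw work->exception();
    }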

void Model::updateFeeders(int x, bool eval) {
std::cout<<"update feeders: "<<x<<"\n";
update_feeders_lock_->lock();

if (compute_workers_[0] == pg_gloo_->pg->getRank()) {
for (int i = 0; i < pg_gloo_->pg->getSize(); i++) {
@@ -984,15 +998,24 @@

std::cout<<pg_gloo_->pg->getRank()<<" sending "<<x<<" to "<<i<<"\n";

// bool success = false;

// while (!success) {
// try {
std::vector<torch::Tensor> vec;
torch::Tensor x_tens = torch::zeros({1}, torch::kInt32) + x;
vec.push_back(x_tens);
auto work = pg_gloo_->pg->send(vec, i, 2 + eval);
if (!work->wait()) {
if (!work->wait()) { //std::chrono::milliseconds(1000)
throw work->exception();
}
// success = true;
// } catch (...) {
// std::cout<<"Caught ERROR with update feeders\n\n";
// }
// }

std::cout<<"done sending\n";
std::cout<<"done sending "<<x<<" to "<<i<<"\n";
}
}

@@ -1020,9 +1043,9 @@
epoch_complete_ = true;
}

if (compute_worker_ and !eval and !epoch_complete_)
createComputePG(feeders_, all_workers_, compute_workers_);
std::cout<<"done update feeders\n";
// if (compute_worker_ and !eval and !epoch_complete_)
// createComputePG(feeders_, all_workers_, compute_workers_);

update_feeders_lock_->unlock();
std::cout<<"done update feeders:"<<x<<"\n";
}
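Note: updateFeeders is now serialized with update_feeders_lock_, locked on entry and unlocked just before the final print. Because work->wait() can throw while the mutex is held, the early-exit path would leave it locked; a minimal sketch of the same guard with std::lock_guard, which releases the mutex on every exit path and keeps the rest of the body unchanged:

    // Sketch only: scoped locking for updateFeeders; body otherwise as in the diff above.
    void Model::updateFeeders(int x, bool eval) {
        std::lock_guard<std::mutex> guard(*update_feeders_lock_);
        // ... notify the other workers, update feeders_, and set epoch_complete_
        //     once the last feeder has reported in (unchanged from the diff) ...
    }   // guard releases update_feeders_lock_ here, including when an exception is thrown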

void Model::distListenForComplete(bool eval) {
@@ -1039,7 +1064,9 @@
// std::cout<<"x: "<<x<<"\n";

auto work = pg_gloo_->pg->recvAnysource(vec, 2 + eval);
// auto work = pg_gloo_->pg->recv(vec, compute_workers_[0], 2);
// std::cout<<"compute_worker[0]: "<< compute_workers_[0]<<"\n";
// auto work = pg_gloo_->pg->recv(vec, compute_workers_[0], 2 + eval);
// std::cout<<"distListenForComplete waiting\n";
if (!work->wait()) {
throw work->exception();
}
@@ -1128,7 +1155,7 @@

//}

void Model::distNotifyCompleteAndWait(bool eval) {
void Model::distNotifyCompleteAndWait(bool eval, bool wait) {
if (pg_gloo_ == nullptr) {
return;
}
@@ -1137,41 +1164,49 @@

// called on everything

// if batch construction worker, notify all of completion
if (batch_worker_) {
torch::Tensor x = torch::zeros({1}, torch::kInt32) + pg_gloo_->pg->getRank();
std::vector<torch::Tensor> transfer_vec;
transfer_vec.push_back(x);

auto compute_worker_id = compute_workers_[0];
// for (auto compute_worker_id : compute_workers_) {
// std::cout<<compute_worker_id<<"\n";
if (compute_worker_id == pg_gloo_->pg->getRank()) {
// std::cout<<"direct update feeders\n";
updateFeeders(compute_worker_id, eval); //TODO: this can actually cause update feeders to be called in parallel with the thread, we should have a lock
// continue;
} else {
std::cout<<pg_gloo_->pg->getRank()<<" direct sending "<<x.item<int>()<<" to "<<compute_worker_id<<"\n";

auto work = pg_gloo_->pg->send(transfer_vec, compute_worker_id, 2 + eval);
if (!work->wait()) {
throw work->exception();
if (!already_notified_) {
already_notified_ = true;

// if batch construction worker, notify all of completion
if (batch_worker_) {
torch::Tensor x = torch::zeros({1}, torch::kInt32) + pg_gloo_->pg->getRank();
std::vector<torch::Tensor> transfer_vec;
transfer_vec.push_back(x);

auto compute_worker_id = compute_workers_[0];
// for (auto compute_worker_id : compute_workers_) {
// std::cout<<compute_worker_id<<"\n";
if (compute_worker_id == pg_gloo_->pg->getRank()) {
// std::cout<<"direct update feeders\n";
updateFeeders(compute_worker_id, eval); //TODO: this can actually cause update feeders to be called in parallel with the thread, we should have a lock
// continue;
} else {
std::cout<<pg_gloo_->pg->getRank()<<" direct sending "<<x.item<int>()<<" to "<<compute_worker_id<<"\n";

auto work = pg_gloo_->pg->send(transfer_vec, compute_worker_id, 2 + eval);
if (!work->wait()) {
throw work->exception();
}
std::cout<<"done direct sending\n";
}
std::cout<<"done direct sending\n";
}

// }
// }

// if (!compute_worker_) {
// std::cout<<"distNotifyCompleteAndWait barrier\n";
// auto work = pg_gloo_->pg->barrier();
// if (!work->wait()) {
// throw work->exception();
// }
// epoch_complete_ = true;
// }
} else {
// if (!compute_worker_) {
// std::cout<<"distNotifyCompleteAndWait barrier\n";
// auto work = pg_gloo_->pg->barrier();
// if (!work->wait()) {
// throw work->exception();
// }
// epoch_complete_ = true;
// }
} else {

}
}

if (!wait) {
return;
}
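Note: taken together, the changes above make distNotifyCompleteAndWait idempotent per epoch (the notification block runs only while already_notified_ is false) and optionally non-blocking (wait == false returns right after notifying). Condensed, the new control flow reads roughly as follows; this is a paraphrase of the diff, not additional code in the commit:

    // Condensed paraphrase of the new control flow (not code added by this commit).
    void Model::distNotifyCompleteAndWait(bool eval, bool wait) {
        if (pg_gloo_ == nullptr) return;

        if (!already_notified_) {
            already_notified_ = true;
            if (batch_worker_) {
                // Tell compute_workers_[0] that this rank is done feeding batches,
                // either by calling updateFeeders directly or by sending our rank over Gloo.
            }
        }

        if (!wait) return;            // notify-only fast path

        // ... original blocking logic continues unchanged below ...
    }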


@@ -1334,7 +1369,9 @@ shared_ptr<Model> initModelFromConfig(shared_ptr<ModelConfig> model_config, std:
model->compute_worker_ = compute_worker;
model->first_epoch_ = true;
model->pg_lock_ = new std::mutex();
model->update_feeders_lock_ = new std::mutex();
model->last_compute_worker_ = -1;
model->already_notified_ = false;

std::cout<<"init model"<<"\n"; // init some model listening queues here if desired for distributed training

22 changes: 19 additions & 3 deletions src/cpp/src/pipeline/pipeline_cpu.cpp
@@ -104,10 +104,26 @@ PipelineCPU::~PipelineCPU() {
}

bool Pipeline::isDone() {
if (compute_worker_ and compute_worker_needs_remote_) {
return model_->epoch_complete_;
// if (compute_worker_ and compute_worker_needs_remote_) {
// return model_->epoch_complete_;
// } else {
// return (batches_in_flight_ <= 0) && dataloader_->epochComplete();
// }

if (batch_worker_) {
bool done_locally = (batches_in_flight_ <= 0) && dataloader_->epochComplete();
if (done_locally and compute_worker_ and compute_worker_needs_remote_) {
// model_->distNotifyCompleteAndWait(train_, false);
if (!model_->already_notified_) {
model_->updateFeeders(model_->pg_gloo_->pg->getRank(), !train_); // TODO; not general, need something like distNotifyComplete
model_->already_notified_ = true;
}
return model_->epoch_complete_;
} else {
return done_locally;
}
} else {
return (batches_in_flight_ <= 0) && dataloader_->epochComplete();
return model_->epoch_complete_;
}
}
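Note: isDone() now distinguishes batch workers from pure compute workers. A batch worker that is locally drained and also feeds a remote compute pipeline sends its own completion notice (via updateFeeders) and then reports done only once model_->epoch_complete_ flips; a pure compute worker just polls epoch_complete_. The check-then-set on already_notified_ is two separate operations on the atomic, so if isDone() can be reached from more than one thread, both could pass the check; a sketch of a single test-and-set under that assumption:

    // Sketch only: notify-once as a single atomic test-and-set, assuming the
    // surrounding isDone() logic stays as in the diff above.
    bool expected = false;
    if (model_->already_notified_.compare_exchange_strong(expected, true)) {
        // Exactly one caller wins and sends the completion notice for this rank.
        model_->updateFeeders(model_->pg_gloo_->pg->getRank(), !train_);
    }
    return model_->epoch_complete_;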

24 changes: 12 additions & 12 deletions src/cpp/src/pipeline/pipeline_gpu.cpp
@@ -80,7 +80,7 @@ void RemoteLoadWorker::run() {
((PipelineCPU *)pipeline_)->loaded_batches_->blocking_push(batch);
}
t.stop();
std::cout<<"remote load: "<<t.getDuration()<<"\n";
// std::cout<<"remote load: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
@@ -93,7 +93,7 @@ void RemoteToDeviceWorker::run() {
t.start();
auto tup = ((PipelineGPU *)pipeline_)->loaded_batches_->blocking_pop();
t.stop();
std::cout<<"remote to block: "<<t.getDuration()<<"\n";
// std::cout<<"remote to block: "<<t.getDuration()<<"\n";
t.start();
bool popped = std::get<0>(tup);
shared_ptr<Batch> batch = std::get<1>(tup);
@@ -118,7 +118,7 @@
batch->creator_id_ = pipeline_->model_->pg_gloo_->pg->getRank();
batch->remoteTo(pipeline_->model_->pg_gloo_->pg, child, tag);
t.stop();
std::cout<<"remote to: "<<t.getDuration()<<"\n";
// std::cout<<"remote to: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
@@ -135,7 +135,7 @@
t.start();
auto tup = ((PipelineGPU *)pipeline_)->loaded_batches_->blocking_pop();
t.stop();
std::cout<<"batch to block: "<<t.getDuration()<<"\n";
// std::cout<<"batch to block: "<<t.getDuration()<<"\n";
t.start();
bool popped = std::get<0>(tup);
shared_ptr<Batch> batch = std::get<1>(tup);
@@ -145,7 +145,7 @@

batchToDevice(pipeline_, batch);
t.stop();
std::cout<<"batch to: "<<t.getDuration()<<"\n";
// std::cout<<"batch to: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
@@ -161,7 +161,7 @@ void ComputeWorkerGPU::run() {
t.start();
auto tup = ((PipelineGPU *)pipeline_)->device_loaded_batches_[gpu_id_]->blocking_pop();
t.stop();
std::cout<<"compute block: "<<t.getDuration()<<"\n";
// std::cout<<"compute block: "<<t.getDuration()<<"\n";
t.start();
bool popped = std::get<0>(tup);
shared_ptr<Batch> batch = std::get<1>(tup);
@@ -253,7 +253,7 @@ void ComputeWorkerGPU::run() {
}
}
t.stop();
std::cout<<"compute: "<<t.getDuration()<<"\n";
// std::cout<<"compute: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
@@ -289,7 +289,7 @@ void BatchToHostWorker::run() {
t.start();
auto tup = ((PipelineGPU *)pipeline_)->device_update_batches_[gpu_id_]->blocking_pop();
t.stop();
std::cout<<"batch to host block: "<<t.getDuration()<<"\n";
// std::cout<<"batch to host block: "<<t.getDuration()<<"\n";
t.start();
bool popped = std::get<0>(tup);
shared_ptr<Batch> batch = std::get<1>(tup);
@@ -316,7 +316,7 @@

((PipelineGPU *)pipeline_)->update_batches_->blocking_push(batch);
t.stop();
std::cout<<"batch to host: "<<t.getDuration()<<"\n";
// std::cout<<"batch to host: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
@@ -329,7 +329,7 @@ void RemoteToHostWorker::run() {
t.start();
auto tup = ((PipelineGPU *)pipeline_)->update_batches_->blocking_pop();
t.stop();
std::cout<<"remote to host block: "<<t.getDuration()<<"\n";
// std::cout<<"remote to host block: "<<t.getDuration()<<"\n";
t.start();
bool popped = std::get<0>(tup);
shared_ptr<Batch> batch = std::get<1>(tup);
@@ -368,7 +368,7 @@

batch->remoteTo(pipeline_->model_->pg_gloo_->pg, parent, tag);
t.stop();
std::cout<<"remote to host: "<<t.getDuration()<<"\n";
// std::cout<<"remote to host: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
@@ -411,7 +411,7 @@ void RemoteListenForUpdatesWorker::run() {

((PipelineGPU *)pipeline_)->update_batches_->blocking_push(batch);
t.stop();
std::cout<<"remote listen: "<<t.getDuration()<<"\n";
// std::cout<<"remote listen: "<<t.getDuration()<<"\n";
}
nanosleep(&sleep_time_, NULL);
}
