Skip to content

Commit

Permalink
RemoteRendezvous supports FlowControl.
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty committed May 14, 2024
1 parent e10d441 commit 31f240f
Show file tree
Hide file tree
Showing 21 changed files with 903 additions and 30 deletions.
213 changes: 212 additions & 1 deletion tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

namespace {
uint64 kGlobalStepId = 0x100000000000000uLL;
int64 kFlowControlMaxSize = 16;
} // namespace anonymous

static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
Expand Down Expand Up @@ -127,6 +129,23 @@ void BaseRendezvousMgr::FuseRecvLocalAsync(
rendez->FuseRecvLocalAsync(parsed_keys, std::move(done_cb));
}

void BaseRendezvousMgr::FlowControlRecvLocalAsync(int64 step_id,
const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
Rendezvous::DoneCallback done) {
auto rendez = FindOrCreate(step_id);
using namespace std::placeholders;
Rendezvous::DoneCallback done_cb = std::bind(
[rendez](Rendezvous::DoneCallback done,
// Begin unbound arguments.
const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
rendez->Unref();
done(s, send_args, recv_args, v, dead);
},
std::move(done), _1, _2, _3, _4, _5);
rendez->FlowControlRecvLocalAsync(tag, parsed, std::move(done_cb));
}

void BaseRendezvousMgr::Cleanup(int64 step_id) {
Rendezvous* rendez = nullptr;
{
Expand Down Expand Up @@ -174,7 +193,17 @@ BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
: env_(env),
step_id_(step_id),
local_(NewLocalRendezvous()),
session_(nullptr) {}
session_(nullptr),
flow_control_num_(0) {
Status s = ReadInt64FromEnvVar("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE",
kFlowControlMaxSize, &flow_control_max_size_);
if (!s.ok()) {
LOG(ERROR) << "Read REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE env error: "
<< s.error_message();
}
VLOG(2) << "BaseRemoteRendezvous set flow control max size: "
<< flow_control_max_size_;
}

BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty());
Expand Down Expand Up @@ -221,6 +250,16 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
std::move(fuse_call.done));
}

std::vector<DeferredFlowControlCall> deferred_flow_control_calls;
{
mutex_lock l(mu_);
std::swap(deferred_flow_control_calls, deferred_flow_control_calls_);
}
for (auto& fc_call : deferred_flow_control_calls) {
FlowControlRecvLocalAsyncInternal(fc_call.tag, fc_call.parsed,
std::move(fc_call.done));
}

return Status::OK();
}

Expand Down Expand Up @@ -271,6 +310,43 @@ Status BaseRemoteRendezvous::Send(const ParsedKey& parsed,
return local_->Send(parsed, args, val, mu, is_dead);
}

Status BaseRemoteRendezvous::FlowControlSend(const StringPiece& tag,
const ParsedKey& parsed,
const Args& args,
const Tensor& val,
const bool is_dead,
const int64 timeout_millis) {
VLOG(1) << "BaseRemoteRendezvous FlowControlSend " << this << " "
<< parsed.FullKey();
const std::string tag_string(tag.data(), tag.size());
{
mutex_lock l(mu_);
while(status_.ok() && flow_control_num_ >= flow_control_max_size_) {
if (flow_control_cv_.wait_for(
l, std::chrono::milliseconds(timeout_millis)) == \
std::cv_status::timeout) {
return errors::DeadlineExceeded("FlowControlSend has timed out.");
}
}

if (!status_.ok()) return status_;
DCHECK(is_initialized_locked());
if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
return errors::InvalidArgument(
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
session_->worker_name);
}

flow_control_num_++;
if (flow_control_counters_.count(tag_string) == 0) {
flow_control_counters_[tag_string] = 0;
}
flow_control_counters_[tag_string]++;
}
// Buffers "val" and "device_context" in local_.
return local_->Send(parsed, args, val, is_dead);
}

Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) {
// Cache session pointer to avoid repeatedly taking & releasing the lock
Expand Down Expand Up @@ -413,6 +489,63 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
}
}

void BaseRemoteRendezvous::FlowControlRecvAsync(const StringPiece& tag,
const ParsedKey& parsed,
const Args& recv_args,
DoneCallback done) {
VLOG(1) << "RemoteRendezvous FlowControlRecvAsync " << this
<< " " << tag << " " << parsed.FullKey();

Status s = ValidateDevices(parsed, false /*!is_src*/);
if (s.ok() && !is_initialized()) {
s.Update(errors::Internal(
"FlowControlRecvAsync called when uninitialized (key:",
parsed.FullKey(), ")."));
}
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
return;
}

// Are src and dst in the same worker?
if (IsSameWorker(parsed.src, parsed.dst)) {
// Recv the tensor from local_.
local_->RecvAsync(
parsed, recv_args,
[this, tag, parsed, done](
const Status& status, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
<< parsed.FullKey();
Tensor* out = new Tensor;
StatusCallback final_callback = [done, send_args, recv_args, out,
is_dead](const Status& s) {
done(s, send_args, recv_args, *out, is_dead);
delete out;
};

if (status.ok()) {
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
std::move(final_callback));
const std::string tag_string(tag.data(), tag.size());
{
mutex_lock l(mu_);
flow_control_num_--;
DCHECK(flow_control_counters_.count(tag_string) != 0);
flow_control_counters_[tag_string]--;
}
flow_control_cv_.notify_one();
} else {
final_callback(status);
}
});
return;
} else {
FlowControlRecvFromRemoteAsync(tag, parsed, recv_args, std::move(done));
}

}

void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) {
{
Expand Down Expand Up @@ -600,13 +733,71 @@ void BaseRemoteRendezvous::FuseRecvLocalAsyncInternal(
}
}

void BaseRemoteRendezvous::FlowControlRecvLocalAsync(const StringPiece& tag,
const ParsedKey& parsed,
DoneCallback done) {
{
mutex_lock l(mu_);
if (!is_initialized_locked()) {
// FlowControlRecvLocalAsync can be called (due to an incoming RecvTensor
// RPC from a remote worker) before the RunStep (or PartialRunStep) RPC
// from the master arrives. RecvLocalAsync thus buffers the arguments
// until after the RemoteRendezvous is Initialize()'d, when it completes
// the rendezvous logic. At some point after Initialize() is called, a
// Tensor is produced locally that will then be sent in response to the
// incoming RPC.
DeferredFlowControlCall call(tag, parsed, std::move(done));
deferred_flow_control_calls_.push_back(call);
return;
}
}
FlowControlRecvLocalAsyncInternal(tag, parsed, std::move(done));
}

void BaseRemoteRendezvous::FlowControlRecvLocalAsyncInternal(
const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) {
Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false);
return;
}

using namespace std::placeholders;
Rendezvous::DoneCallback done_cb = std::bind(
[this, tag](Rendezvous::DoneCallback done,
// Begin unbound arguments.
const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
done(s, send_args, recv_args, v, dead);
if (s.ok()) {
const std::string tag_string(tag.data(), tag.size());
{
mutex_lock l(mu_);
flow_control_num_--;
DCHECK(flow_control_counters_.count(tag_string) != 0);
flow_control_counters_[tag_string]--;
}
flow_control_cv_.notify_one();
}
},
std::move(done), _1, _2, _3, _4, _5);

local_->RecvAsync(parsed, Args(), std::move(done_cb));
}

void BaseRemoteRendezvous::FuseRecvFromRemoteAsync(
const std::vector<Rendezvous::ParsedKey>& parsed_keys,
const Rendezvous::Args& args,
FuseDoneCallback done) {
CHECK(false) << "FuseRecvFromRemoteAsync Unimplemented";
}

void BaseRemoteRendezvous::FlowControlRecvFromRemoteAsync(
const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args, DoneCallback done) {
CHECK(false) << "FlowControlRecvFromRemoteAsync Unimplemented.";
}

void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args,
RefDoneCallback done) {
Expand Down Expand Up @@ -636,6 +827,19 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
}
}

int64 BaseRemoteRendezvous::GetAllFlowControlItemNum() {
mutex_lock l(mu_);
return flow_control_num_;
}

int64 BaseRemoteRendezvous::GetFlowControlItemNum(StringPiece tag) {
const std::string tag_string(tag.data(), tag.size());
mutex_lock l(mu_);
if (flow_control_counters_.count(tag_string) == 0)
return 0;
return flow_control_counters_[tag_string];
}

void BaseRemoteRendezvous::StartAbort(const Status& s) {
CHECK(!s.ok());
// Use a "derived" status as the status for the rendezvous. Derived
Expand All @@ -656,7 +860,10 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
}
active_.clear();
}
flow_control_num_ = 0;
flow_control_counters_.clear();
}
flow_control_cv_.notify_all();
}

void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
Expand Down Expand Up @@ -707,4 +914,8 @@ BaseRemoteRendezvous::DeferredFuseCall::DeferredFuseCall(
const std::vector<ParsedKey>& parsed_keys, FuseDoneCallback done)
: parsed_keys(parsed_keys), done(std::move(done)) {}

BaseRemoteRendezvous::DeferredFlowControlCall::DeferredFlowControlCall(
const StringPiece& tag, const ParsedKey& parsed, DoneCallback done)
: tag(tag), parsed(parsed), done(std::move(done)) {}

} // end namespace tensorflow
Loading

0 comments on commit 31f240f

Please sign in to comment.