From 0004a96bb4fbdb881d13ea3300b15a047c33b37a Mon Sep 17 00:00:00 2001 From: JunqiHu Date: Wed, 25 Oct 2023 14:13:53 +0800 Subject: [PATCH] [Distribute] Add elastic-grpc server. Signed-off-by: JunqiHu --- configure.py | 3 + tensorflow/BUILD | 6 + tensorflow/contrib/elastic_grpc_server/BUILD | 70 ++++ .../elastic_grpc_server_lib.cc | 317 ++++++++++++++++++ .../elastic_grpc_server_lib.h | 66 ++++ .../elastic_grpc_server_lib_test.cc | 76 +++++ .../elastic_grpc_server/elastic_service.cc | 157 +++++++++ .../elastic_grpc_server/elastic_service.h | 31 ++ tensorflow/core/BUILD | 23 ++ .../distributed_runtime/rpc/grpc_server_lib.h | 14 +- .../core/platform/default/build_config.bzl | 6 + .../platform/default/build_config_root.bzl | 8 + .../core/protobuf/elastic_training.proto | 76 +++++ tensorflow/python/BUILD | 3 +- 14 files changed, 848 insertions(+), 8 deletions(-) create mode 100644 tensorflow/contrib/elastic_grpc_server/BUILD create mode 100644 tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc create mode 100644 tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h create mode 100644 tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib_test.cc create mode 100644 tensorflow/contrib/elastic_grpc_server/elastic_service.cc create mode 100644 tensorflow/contrib/elastic_grpc_server/elastic_service.h create mode 100644 tensorflow/core/protobuf/elastic_training.proto diff --git a/configure.py b/configure.py index 362479981b2..6aeaf7d12af 100644 --- a/configure.py +++ b/configure.py @@ -1433,6 +1433,9 @@ def main(): set_build_var(environ_cp, 'TF_NEED_STAR', 'STAR', 'with_star_support', True, 'star') + set_build_var(environ_cp, 'TF_NEED_ELASTIC', 'ELASTIC TRAINING', 'with_elastic_support', + True, 'elastic') + set_build_var(environ_cp, 'TF_ENABLE_PMEM', 'PMEM', 'with_pmem_support', False, 'pmem') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 493247a2162..8b4190ea680 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -434,6 +434,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_elastic_support", + values = {"define": "with_elastic_support=true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_pmem_support", values = {"define": "with_pmem_support=true"}, diff --git a/tensorflow/contrib/elastic_grpc_server/BUILD b/tensorflow/contrib/elastic_grpc_server/BUILD new file mode 100644 index 00000000000..ea4b87e3b58 --- /dev/null +++ b/tensorflow/contrib/elastic_grpc_server/BUILD @@ -0,0 +1,70 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = [ + "//tensorflow:internal", +]) + +load( + "//tensorflow:tensorflow.bzl", "tf_cc_test", +) + +cc_library( + name = "elastic_grpc_server_lib", + srcs = select({"//tensorflow:with_elastic_support": ["elastic_service.cc", + "elastic_grpc_server_lib.cc"], + "//conditions:default": []}), + hdrs = ["elastic_service.h", + "elastic_grpc_server_lib.h"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:elastic_service_proto_cc", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:async_service_interface", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_master_service", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow:grpc", + "//tensorflow:grpc++", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", + "//tensorflow/core/distributed_runtime:graph_mgr", + "//tensorflow/core/distributed_runtime:local_master", + "//tensorflow/core/distributed_runtime:master", + "//tensorflow/core/distributed_runtime:master_env", + "//tensorflow/core/distributed_runtime:master_session", + "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:session_mgr", + "//tensorflow/core/distributed_runtime:worker_cache_wrapper", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_resource", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "elastic_grpc_test", + size = "small", + srcs = ["elastic_grpc_server_lib_test.cc"], + deps = [ + ":elastic_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "//tensorflow:grpc", + "//tensorflow:grpc++", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:worker_proto_cc", + ], + linkstatic = 1, +) diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc new file mode 100644 index 00000000000..d45d70d6c8c --- /dev/null +++ b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.cc @@ -0,0 +1,317 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +#include "tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h" + +#include +#include +#include +#include + +#include "include/json/json.h" +#include "grpc/support/alloc.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/server_builder.h" +#include "tensorflow/core/util/env_var.h" + +#include "tensorflow/contrib/elastic_grpc_server/elastic_service.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/local_master.h" +#include "tensorflow/core/distributed_runtime/master.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/distributed_runtime/worker_resource.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/protobuf/cluster.pb.h" + +namespace tensorflow { + +namespace { + +// static utility function +RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) { + return new RpcRendezvousMgr(env); +} + +} // namespace + +ElasticGrpcServer::ElasticGrpcServer(const ServerDef& server_def, Env* env) + : GrpcServer(server_def, env) {} + +ElasticGrpcServer::~ElasticGrpcServer() { + delete elastic_service_; +} + +Status ElasticGrpcServer::UpdateServerDef(const string& cluster_def_str, int& before_part_num, int& after_part_num) { + std::string tf_config; + ReadStringFromEnvVar("TF_CONFIG", "", &tf_config); + if (!tf_config.empty()) { + Json::Reader reader; + Json::Value tf_config_json; + if(!reader.parse(tf_config, tf_config_json)) { + return errors::Internal("PARSE TF_CONFIG ERROR"); + } + if ((tf_config_json["cluster"].isNull()) || + (tf_config_json["cluster"]["ps"].isNull())) { + return errors::Internal("PARSE PS FROM TF_CONFIG ERROR"); + } + + Json::Value cluster_json; + if (!reader.parse(cluster_def_str, cluster_json)) { + LOG(ERROR) << "cluster_def is not correct with " << cluster_def_str; + return errors::Internal("PARSE TF_CONFIG/cluster ERROR"); + } + + std::unordered_set ps_addrs_vec; + after_part_num = cluster_json["cluster"]["ps"].size(); + for (auto& value: cluster_json["cluster"]["ps"]) { + ps_addrs_vec.emplace(value.asString()); + } + + int job_size = server_def_.cluster().job_size(); + for (int j = 0; j < job_size; ++j) { + auto* job = server_def_.mutable_cluster()->mutable_job(j); + if (job->name() == "ps") { + before_part_num = job->tasks_size(); + if (before_part_num == after_part_num) { + return Status::OK(); + } else if (after_part_num > before_part_num) { + int idx = before_part_num; + LOG(INFO) << "SCALING UP, partition_num is: " << after_part_num; + std::unordered_set target_string_set; + for (auto& value: tf_config_json["cluster"]["ps"]) { + target_string_set.emplace(value.asString()); + } + for (auto ps_addr: ps_addrs_vec) { + if (target_string_set.find(ps_addr) == target_string_set.end()) { + job->mutable_tasks()->insert({idx, ps_addr}); + tf_config_json["cluster"]["ps"].append(ps_addr); + } + } + break; + } else { + LOG(INFO) << "SCALING DOWN, partition_num is: " << after_part_num; + for (int i = 0; i < before_part_num; ++i) { + string tmp_string = tf_config_json["cluster"]["ps"][i].asString(); + if (ps_addrs_vec.find(tmp_string) == ps_addrs_vec.end()) { + Json::Value ps_addr; + tf_config_json["cluster"]["ps"].removeIndex(i, &ps_addr); + job->mutable_tasks()->erase(i); + } + } + } + } + } + Json::FastWriter writer; + std::string new_tf_config = writer.write(tf_config_json); + LOG(INFO) << "new TF_CONFIG " << new_tf_config; + setenv("TF_CONFIG", new_tf_config.c_str(), 1); + } + return Status::OK(); +} + +Status ElasticGrpcServer::Update(const string& cluster_def_str) { + int before_part_num, after_part_num; + Status s = UpdateServerDef(cluster_def_str, before_part_num, after_part_num); + if (!s.ok()) { + LOG(ERROR) << s.error_message(); + return Status::OK(); + } + + if (after_part_num == before_part_num) { + return Status::OK(); + } + + WorkerCacheInterface* worker_cache; + WorkerCacheFactoryOptions worker_cache_factory_options(server_def_); + TF_RETURN_IF_ERROR( + WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); + CHECK_NE(nullptr, worker_cache); + ConfigProto config = server_def_.default_session_config(); + string unused; + string default_worker_name; + if (!DeviceNameUtils::SplitDeviceName(master_env()->local_devices[0]->name(), + &default_worker_name, &unused)) { + return errors::Internal("Could not parse worker name."); + } + std::unique_ptr dev_resolver( + new DeviceResolverDistributed(worker_env()->device_mgr, worker_cache, + default_worker_name)); + std::unique_ptr param_resolver( + new CollectiveParamResolverDistributed(config, worker_env()->device_mgr, + dev_resolver.get(), worker_cache, + default_worker_name)); + worker_env()->collective_executor_mgr = new RpcCollectiveExecutorMgr( + config, worker_env()->device_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, default_worker_name); + + if (worker_env()->session_mgr != nullptr) { + delete worker_env()->session_mgr; // Deletes graph_mgr's. + } + + // Set up worker environment. + worker_env()->session_mgr = new SessionMgr( + worker_env(), SessionMgr::WorkerNameFromServerDef(server_def_), + std::unique_ptr(worker_cache), + [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { + WorkerCacheFactoryOptions options(server_def); + return WorkerCacheFactory(options, worker_cache); + }); + master_env()->worker_cache = worker_cache; + // Finish setting up master environment. + + StatsPublisherFactory stats_factory = opts_.stats_factory; + master_env()->master_session_factory = + [config, stats_factory]( + SessionOptions options, const MasterEnv* env, + std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, + std::vector filtered_worker_list) { + options.config.MergeFrom(config); + return new MasterSession(options, env, std::move(remote_devs), + std::move(worker_cache), std::move(device_set), + std::move(filtered_worker_list), + stats_factory); + }; + master_env()->worker_cache_factory = + [this](const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache) { + return WorkerCacheFactory(options, worker_cache); + }; + return Status::OK(); +} + +void ElasticGrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) { + elastic_service_ = NewElasticGrpcService(this, builder); +} + +Status ElasticGrpcServer::Start() { + { + mutex_lock l(mu_); + switch (state_) { + case NEW: { + update_server_thread_.reset( + env_->StartThread(ThreadOptions(), "TF_elastic_service", + [this] { elastic_service_->HandleRPCsLoop(); })); + LOG(INFO) << "Started server with target: " << target(); + break; + } + case STARTED: + LOG(INFO) << "Server already started (target: " << target() << ")"; + return Status::OK(); + case STOPPED: + return errors::FailedPrecondition("Server has stopped."); + default: + LOG(FATAL); + } + } + return GrpcServer::Start(); +} + +Status ElasticGrpcServer::Join() { + GrpcServer::Join(); + mutex_lock l(mu_); + switch (state_) { + case NEW: + LOG(FATAL) << "Server shoud already closed"; + case STARTED: + case STOPPED: + update_server_thread_.reset(); + return Status::OK(); + default: + LOG(FATAL); + } +} + +/* static */ +Status ElasticGrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server) { + std::unique_ptr ret( + new ElasticGrpcServer(server_def, env == nullptr ? Env::Default() : env)); + ServiceInitFunction service_func = nullptr; + GrpcServerOptions options; + options.rendezvous_mgr_func = NewRpcRendezvousMgr; + Status s = ret->Init(options); + if (!s.ok()) { + LOG(ERROR) << s; + return s; + } + *out_server = std::move(ret); + return Status::OK(); +} + +/* static */ +Status ElasticGrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server) { + std::unique_ptr ret( + new ElasticGrpcServer(server_def, env == nullptr ? Env::Default() : env)); + GrpcServerOptions options; + options.rendezvous_mgr_func = NewRpcRendezvousMgr; + Status s = ret->Init(options); + if (!s.ok()) { + LOG(ERROR) << s; + return s; + } + *out_server = std::move(ret); + return Status::OK(); +} + +namespace { + +class ElasticGrpcServerFactory : public ServerFactory { + public: + bool AcceptsOptions(const ServerDef& server_def) override { + return server_def.protocol() == "elastic-grpc"; + } + + Status NewServer(const ServerDef& server_def, + std::unique_ptr* out_server) override { + return ElasticGrpcServer::Create(server_def, Env::Default(), out_server); + } +}; + +// Registers a `ServerFactory` for `ElasticGrpcServer` instances. +class ElasticGrpcServerRegistrar { + public: + ElasticGrpcServerRegistrar() { + gpr_allocation_functions alloc_fns; + memset(&alloc_fns, 0, sizeof(alloc_fns)); + alloc_fns.malloc_fn = port::Malloc; + alloc_fns.realloc_fn = port::Realloc; + alloc_fns.free_fn = port::Free; + gpr_set_allocation_functions(alloc_fns); + ServerFactory::Register("ELASTIC_GRPC_SERVER", new ElasticGrpcServerFactory()); + } +}; +static ElasticGrpcServerRegistrar registrar; + +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h new file mode 100644 index 00000000000..8853ceb2819 --- /dev/null +++ b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_ELASTIC_GRPC_SERVER_ELASTIC_GRPC_SERVER_LIB_H_ +#define TENSORFLOW_CONTRIB_ELASTIC_GRPC_SERVER_ELASTIC_GRPC_SERVER_LIB_H_ + +#include + +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/stats_publisher_interface.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +class ElasticGrpcServer : public GrpcServer { + public: + ElasticGrpcServer(const ServerDef& server_def, Env* env); + + virtual ~ElasticGrpcServer() override; + + static Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server); + static Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server); + + Status Update(const string& cluster_def_str); + + void MaybeMutateBuilder(::grpc::ServerBuilder* builder) override; + + Status Start() override; + + Status Join() override; + + private: + Status UpdateServerDef(const string& cluster_def_str, int& before_part_num, int& after_part_num); + + private: + // TensorFlow Eager implementation, and RPC polling thread. + AsyncServiceInterface* elastic_service_ = nullptr; + std::unique_ptr update_server_thread_ GUARDED_BY(mu_); + + std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_ELASTIC_GRPC_SERVER_ELASTIC_GRPC_SERVER_LIB_H_ \ No newline at end of file diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib_test.cc b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib_test.cc new file mode 100644 index 00000000000..edf6226080f --- /dev/null +++ b/tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +#include "tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +#include "gtest/gtest.h" + +namespace tensorflow { + +class ElasticGrpcServerTest : public ::testing::Test { + protected: + Status FillServerDef(const string& job_spec, ServerDef* options) { + options->set_protocol("elastic-grpc"); + options->set_job_name("chief"); + options->set_task_index(0); + + uint32 my_tasks_per_replica = 0; + for (const string& job_str : str_util::Split(job_spec, ',')) { + JobDef* job_def = options->mutable_cluster()->add_job(); + // Split each entry in the flag into 2 pieces, separated by "|". + const std::vector job_pieces = str_util::Split(job_str, '|'); + CHECK_EQ(2, job_pieces.size()) << job_str; + job_def->set_name(job_pieces[0]); + // Does a bit more validation of the tasks_per_replica. + const StringPiece spec = job_pieces[1]; + // job_str is of form |. + const std::vector host_ports = str_util::Split(spec, ';'); + uint32 tasks_per_replica = host_ports.size(); + for (size_t i = 0; i < host_ports.size(); ++i) { + (*job_def->mutable_tasks())[i] = host_ports[i]; + } + if (job_def->name() == options->job_name()) { + my_tasks_per_replica = tasks_per_replica; + } + LOG(INFO) << "Peer " << job_def->name() << " " << tasks_per_replica << " {" + << absl::StrJoin(host_ports, ", ") << "}"; + } + if (my_tasks_per_replica == 0) { + return errors::InvalidArgument("Invalid job specification"); + } + return Status::OK(); + } +}; + +//Test Update Logic +TEST_F(ElasticGrpcServerTest, UpdateServer) { + Status s; + std::unique_ptr grpc_server; + ServerDef server_def; + std::string job_spec = "worker|localhost:2222,ps|localhost:10086;localhost:10087;localhost:10088,chief|localhost:2220"; + TF_ASSERT_OK(FillServerDef(job_spec, &server_def)); + s = ElasticGrpcServer::Create(server_def, Env::Default(), &grpc_server); + if (!s.ok()) { + LOG(ERROR) << "Could not create server: " << s.error_message(); + } + TF_ASSERT_OK(grpc_server->Start()); + // TF_QCHECK_OK(grpc_server->Join()); + LOG(INFO) << "SCALING DOWN"; + std::string tf_config_str = "{\"cluster\": {\"worker\": [\"localhost:2222\"],\"ps\": [\"localhost:10086\", \"localhost:10087\"],\"chief\": [\"localhost:2220\"]]}}"; + grpc_server->Update(tf_config_str); + LOG(INFO) << "SCALING UP"; + tf_config_str = "{\"cluster\": {\"worker\": [\"localhost:2222\"],\"ps\": [\"localhost:10086\", \"localhost:10087\", \"localhost:10088\"],\"chief\": [\"localhost:2220\"]]}}"; + grpc_server->Update(tf_config_str); +} + +} \ No newline at end of file diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_service.cc b/tensorflow/contrib/elastic_grpc_server/elastic_service.cc new file mode 100644 index 00000000000..61aa6e662ec --- /dev/null +++ b/tensorflow/contrib/elastic_grpc_server/elastic_service.cc @@ -0,0 +1,157 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +#include "tensorflow/contrib/elastic_grpc_server/elastic_service.h" + +#include "tensorflow/contrib/elastic_grpc_server/elastic_grpc_server_lib.h" +#include "tensorflow/core/protobuf/elastic_training.grpc.pb.h" +#include "tensorflow/core/protobuf/elastic_training.pb.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" + + +#include +#include +#include +#include +#include "grpcpp/server_builder.h" + +using namespace des; + +using grpc::Server; +using grpc::ServerAsyncResponseWriter; +using grpc::ServerBuilder; +using grpc::ServerCompletionQueue; +using grpc::ServerContext; + +namespace tensorflow { + +class GrpcElasticService : public AsyncServiceInterface { + public: + GrpcElasticService(ElasticGrpcServer* elastic_grpc_server, + ::grpc::ServerBuilder* builder): + elastic_grpc_server_(elastic_grpc_server), builder_(builder) { + builder_->RegisterService(&elastic_service_); + cq_ = builder_->AddCompletionQueue(); + } + + ~GrpcElasticService() override { } + + void Shutdown() override { + cq_->Shutdown(); + } + + void HandleRPCsLoop() override { + new CallData(&elastic_service_, elastic_grpc_server_, cq_.get()); + void* tag; + bool ok; + while (true) { + // Block waiting to read the next event from the completion queue. The + // event is uniquely identified by its tag, which in this case is the + // memory address of a CallData instance. + // The return value of Next should always be checked. This return value + // tells us whether there is any kind of event or cq_ is shutting down. + GPR_ASSERT(cq_->Next(&tag, &ok)); + GPR_ASSERT(ok); + static_cast(tag)->Proceed(); + } + } + + private: + // Class encompasing the state and logic needed to serve a request. + class CallData { + public: + // Take in the "service" instance (in this case representing an asynchronous + // server) and the completion queue "cq" used for asynchronous communication + // with the gRPC runtime. + CallData(ElasticTrainingService::AsyncService* service, ElasticGrpcServer* elastic_grpc_server, + ServerCompletionQueue* cq) + : service_(service), elastic_grpc_server_(elastic_grpc_server), + cq_(cq), responder_(&ctx_), status_(CREATE) { + // Invoke the serving logic right away. + Proceed(); + } + + void Proceed() { + if (status_ == CREATE) { + // Make this instance progress to the PROCESS state. + status_ = PROCESS; + + // As part of the initial CREATE state, we *request* that the system + // start processing SayHello requests. In this request, "this" acts are + // the tag uniquely identifying the request (so that different CallData + // instances can serve different requests concurrently), in this case + // the memory address of this CallData instance. + service_->RequestUpdateServerDef(&ctx_, &request_, &responder_, + cq_, cq_, this); + } else if (status_ == PROCESS) { + // Spawn a new CallData instance to serve new clients while we process + // the one for this CallData. The instance will deallocate itself as + // part of its FINISH state. + new CallData(service_, elastic_grpc_server_, cq_); + + // The actual processing. + Status s = elastic_grpc_server_->Update(request_.cluster_def()); + if (s.ok()) { + reply_.set_code(Code::OK); + } else { + reply_.set_code(Code::INTERNAL); + reply_.set_msg(s.ToString()); + LOG(ERROR) << "error" << s.ToString(); + } + + // And we are done! Let the gRPC runtime know we've finished, using the + // memory address of this instance as the uniquely identifying tag for + // the event. + status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, this); + } else { + GPR_ASSERT(status_ == FINISH); + // Once in the FINISH state, deallocate ourselves (CallData). + delete this; + } + } + private: + ElasticGrpcServer* elastic_grpc_server_; + // The means of communication with the gRPC runtime for an asynchronous + // server. + ElasticTrainingService::AsyncService* service_; + // The producer-consumer queue where for asynchronous server notifications. + ServerCompletionQueue* cq_; + // Context for the rpc, allowing to tweak aspects of it such as the use + // of compression, authentication, as well as to send metadata back to the + // client. + ServerContext ctx_; + + // What we get from the client. + UpdateServerDefRequest request_; + // What we send back to the client. + UpdateServerDefResponse reply_; + + // The means to get back to the client. + ServerAsyncResponseWriter responder_; + + // Let's implement a tiny state machine with the following states. + enum CallStatus { CREATE, PROCESS, FINISH }; + CallStatus status_; // The current serving state. + }; + + ElasticGrpcServer* elastic_grpc_server_; + ::grpc::ServerBuilder* builder_; + ElasticTrainingService::AsyncService elastic_service_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_; +}; + +AsyncServiceInterface* NewElasticGrpcService( + ElasticGrpcServer* elastic_grpc_server, ::grpc::ServerBuilder* builder) { + return reinterpret_cast(new GrpcElasticService(elastic_grpc_server, builder)); +} +} \ No newline at end of file diff --git a/tensorflow/contrib/elastic_grpc_server/elastic_service.h b/tensorflow/contrib/elastic_grpc_server/elastic_service.h new file mode 100644 index 00000000000..9465a10c918 --- /dev/null +++ b/tensorflow/contrib/elastic_grpc_server/elastic_service.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The DeepRec Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_ELASTIC_GRPC_SERVER_ELASTIC_SERVICE_H_ +#define TENSORFLOW_CONTRIB_ELASTIC_GRPC_SERVER_ELASTIC_SERVICE_H_ + + +#include +#include "grpcpp/server_builder.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +class ElasticGrpcServer; + +namespace tensorflow { + +class AsyncServiceInterface; +AsyncServiceInterface* NewElasticGrpcService( + ElasticGrpcServer* elastic_grpc_server, ::grpc::ServerBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_ELASTIC_GRPC_SERVER_ELASTIC_SERVICE_H_ \ No newline at end of file diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 95bbbab5624..0531200e7ab 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -139,6 +139,7 @@ load( "tf_lib_proto_parsing_deps", "tf_proto_library", "tf_proto_library_cc", + "tf_proto_library_py", "tf_protos_all", "tf_protos_all_impl", "tf_protos_grappler", @@ -2475,6 +2476,28 @@ tf_proto_library_cc( ], ) +tf_proto_library_cc( + name = "elastic_service_proto", + srcs = ["protobuf/elastic_training.proto"], + has_services = 1, + cc_api_version = 2, + cc_grpc_version = 1, + cc_stubby_versions = ["2"], + protodeps = tf_additional_all_protos(), + visibility = [ + "//tensorflow:internal", + ], +) + +tf_proto_library_py( + name = "elastic_service_pb", + srcs = ["protobuf/elastic_training.proto"], + use_grpc_plugin = True, + visibility = [ + "//tensorflow:internal", + ], +) + LIB_INTERNAL_PRIVATE_HEADERS = [ "framework/resource_handle.h", "//tensorflow/core/platform:legacy_lib_internal_headers", diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 521c8f206f8..79d6b0cd65e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -127,14 +127,11 @@ class GrpcServer : public ServerInterface { const ServerDef& server_def() const { return server_def_; } GrpcWorker* worker_impl() const { return worker_impl_.get(); } - - private: - // The overall server configuration. - const ServerDef server_def_; + protected: + // The overall server configuration. It may be changed during scaling. + ServerDef server_def_; Env* env_; - - // The port to which this server is bound. - int bound_port_ = 0; + GrpcServerOptions opts_; // Guards state transitions. mutex mu_; @@ -151,6 +148,9 @@ class GrpcServer : public ServerInterface { enum State { NEW, STARTED, STOPPED }; State state_ GUARDED_BY(mu_); + private: + // The port to which this server is bound. + int bound_port_ = 0; // Implementation of a TensorFlow master, and RPC polling thread. MasterEnv master_env_; std::unique_ptr master_impl_; diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 406285e7f0f..75d3c671562 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -769,6 +769,12 @@ def tf_additional_star_lib_defines(): "//conditions:default": [], }) +def tf_additional_elastic_server_lib_defines(): + return select({ + "//tensorflow:with_elastic_support": ["TENSORFLOW_USE_ELASTIC_SERVER"], + "//conditions:default": [], + }) + def tf_additional_api_compatible_defines(): return select({ "//tensorflow:with_api_compatible": ["TF_API_COMPATIBLE_1150"], diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl index 71651faf0b1..38191dea3c4 100644 --- a/tensorflow/core/platform/default/build_config_root.bzl +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -77,6 +77,14 @@ def tf_additional_star_deps(): "//conditions:default": [], }) +def tf_additional_elastic_deps(): + return select({ + str(Label("//tensorflow:with_elastic_support")): [ + str(Label("//tensorflow/contrib/elastic_grpc_server:elastic_grpc_server_lib")), + ], + "//conditions:default": [], + }) + # Include specific extra dependencies when building statically, or # another set of dependencies otherwise. If "macos" is provided, that # dependency list is used when using the framework_shared_object config diff --git a/tensorflow/core/protobuf/elastic_training.proto b/tensorflow/core/protobuf/elastic_training.proto new file mode 100644 index 00000000000..ee0d0bd10e0 --- /dev/null +++ b/tensorflow/core/protobuf/elastic_training.proto @@ -0,0 +1,76 @@ +syntax = "proto3"; + +package des; + +enum Code { + OK = 0; + CANCELLED = 1; + UNKNOWN = 2; + INVALID_ARGUMENT = 3; + DEADLINE_EXCEEDED = 4; + NOT_FOUND = 5; + ALREADY_EXISTS = 6; + PERMISSION_DENIED = 7; + RESOURCE_EXHAUSTED = 8; + FAILED_PRECONDITION = 9; + ABORTED = 10; + OUT_OF_RANGE = 11; + UNIMPLEMENTED = 12; + INTERNAL = 13; + UNAVAILABLE = 14; + DATA_LOSS = 15; + UNAUTHENTICATED = 16; + REQUEST_STOP = 17; +} + +enum ElasticTrainingState { + READY = 0; + SCALING = 1; + All_SESSION_CLOSED = 2; +} + +enum ScalingAction { + NONE = 0; + SCALING_UP = 1; + SCALING_DOWN = 2; +} + +message IsReadyScalingRequest { + int32 task_index = 1; +} + +message IsReadyScalingResponse { + Code code = 1; + string msg = 2; + ScalingAction scaling_action = 3; + int32 ps_num = 4; // updated ps_num; +} + +message ReadyToUpdateRequest {}; +message ReadyToUpdateResponse {}; + +message UpdateServerDefRequest { + string cluster_def = 1;//serialized cluster_def +} + +message UpdateServerDefResponse { + Code code = 1; + string msg = 2; +} + +message FetchParamsRequest { + repeated string names = 1; // vec of partitioned variables or ev +} + +message FetchParamsResponse { + Code code = 1; + string msg = 2; + map param_partition_map = 3; // per partition num of variable +} + +service ElasticTrainingService { + rpc IsReadyScaling(IsReadyScalingRequest) returns (IsReadyScalingResponse); + rpc ReadyToUpdate(ReadyToUpdateRequest) returns (ReadyToUpdateResponse); + rpc UpdateServerDef(UpdateServerDefRequest) returns (UpdateServerDefResponse); + rpc FetchParamsMeta(FetchParamsRequest) returns (FetchParamsResponse); +} \ No newline at end of file diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 68649078f5c..a740e0916d9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -24,7 +24,7 @@ load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow/core/platform:default/build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused -load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static", "tf_additional_gdr_deps", "tf_additional_mpi_deps", "tf_additional_plugin_deps", "tf_additional_verbs_deps", "tf_additional_star_deps") +load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static", "tf_additional_gdr_deps", "tf_additional_mpi_deps", "tf_additional_plugin_deps", "tf_additional_verbs_deps", "tf_additional_star_deps", "tf_additional_elastic_deps") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py") load( "//third_party/ngraph:build_defs.bzl", @@ -5307,6 +5307,7 @@ tf_py_wrap_cc( tf_additional_verbs_deps() + tf_additional_mpi_deps() + tf_additional_gdr_deps() + + tf_additional_elastic_deps() + tf_additional_star_deps()) + if_ngraph([ "@ngraph_tf//:ngraph_tf", ]),