-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ParameterServerController for parameter server python api #1051
Changes from 7 commits
f3c61cb
f9a65b0
95f20b9
cfbb4c4
7783982
93e74f8
3f6c2b3
5aaaef4
b1eeb2e
d32c7a6
aa9f516
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "PServerController.h" | ||
|
||
namespace paddle { | ||
|
||
// Creates and initializes every ParameterServer2 described by `config`.
//
// The number of listening ports is ports_num() + ports_num_for_sparse(),
// starting at config.port(). When config.nics() is empty, one server is
// created per port with an empty address. Otherwise config.nics() is split
// on ',' and one server is created per (port, device) pair, stored at index
// `portIdx * devices.size() + deviceIdx`, using getIpAddr(device) as address.
//
// When config.rdma_tcp() == "rdma", each server additionally receives a CPU
// index that is handed out round robin over rdma::numCpus().
//
// Every server's init() is verified with CHECK, so a failing init aborts.
PServerController::PServerController(const ParameterServerConfig& config) {
  const bool useRdma = config.rdma_tcp() == "rdma";
  const int onlineCpus = rdma::numCpus();
  const int numPorts = config.ports_num() + config.ports_num_for_sparse();
  // round robin to load balance RDMA server ENGINE
  int rdmaCpu = 0;

  // Fills `slot` with a new server bound to addr:port. Shared by both
  // branches below so the rdma/tcp selection lives in exactly one place.
  auto createServer = [&](std::unique_ptr<ParameterServer2>& slot,
                          const std::string& addr,
                          int port) {
    if (useRdma) {
      slot.reset(new ParameterServer2(addr, port, rdmaCpu));
      rdmaCpu = (rdmaCpu + 1) % onlineCpus;
    } else {
      slot.reset(new ParameterServer2(addr, port));
    }
  };

  if (config.nics().empty()) {
    pservers_.resize(numPorts);
    for (int i = 0; i < numPorts; ++i) {
      const int port = config.port() + i;
      createServer(pservers_[i], std::string(), port);
      CHECK(pservers_[i]->init()) << "Fail to initialize parameter server "
                                  << port;
    }
    return;
  }

  std::vector<std::string> devices;
  str::split(config.nics(), ',', &devices);
  pservers_.resize(devices.size() * numPorts);
  for (int i = 0; i < numPorts; ++i) {
    const int port = config.port() + i;
    for (size_t j = 0; j < devices.size(); ++j) {
      std::unique_ptr<ParameterServer2>& slot =
          pservers_[i * devices.size() + j];
      createServer(slot, getIpAddr(devices[j]), port);
      CHECK(slot->init()) << "Fail to initialize parameter server "
                          << devices[j] << ":" << port;
    }
  }
}
|
||
// Waits for every owned parameter server by delegating to join().
PServerController::~PServerController() {
  join();
}
|
||
// Builds a ParameterServerConfig from the nics, rdma_tcp, port, ports_num and
// ports_num_for_sparse gflags and forwards it to create().
PServerController* PServerController::createByGflags() {
  ParameterServerConfig flagConfig;

  // Listening ports.
  flagConfig.set_port(FLAGS_port);
  flagConfig.set_ports_num(FLAGS_ports_num);
  flagConfig.set_ports_num_for_sparse(FLAGS_ports_num_for_sparse);

  // Network devices and transport.
  flagConfig.set_nics(FLAGS_nics);
  flagConfig.set_rdma_tcp(FLAGS_rdma_tcp);

  PServerController* controller = create(flagConfig);
  return controller;
}
|
||
// Allocates a controller for `config` with `new` and returns the raw pointer.
PServerController* PServerController::create(
    const ParameterServerConfig& config) {
  PServerController* controller = new PServerController(config);
  return controller;
}
|
||
// Logs the number of servers, then calls start() on each one in storage
// order, logging its index first.
void PServerController::start() {
  LOG(INFO) << "pserver sizes : " << pservers_.size();
  for (size_t idx = 0; idx < pservers_.size(); ++idx) {
    LOG(INFO) << "pserver started : " << idx;
    pservers_[idx]->start();
  }
}
|
||
// Logs the number of servers, then calls join() on each one in storage
// order, logging its index first.
void PServerController::join() {
  LOG(INFO) << "pserver sizes : " << pservers_.size();
  for (size_t idx = 0; idx < pservers_.size(); ++idx) {
    LOG(INFO) << "pserver join : " << idx;
    pservers_[idx]->join();
  }
}
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "ParameterServer2.h" | ||
#include "ParameterServerConfig.pb.h" | ||
#include "RDMANetwork.h" | ||
#include "paddle/utils/StringUtil.h" | ||
|
||
namespace paddle { | ||
|
||
class PServerController final { | ||
public: | ||
DISABLE_COPY(PServerController); | ||
|
||
/** | ||
* @brief Ctor, Create a PServerUtil from ParameterServerConfig. | ||
*/ | ||
explicit PServerController(const ParameterServerConfig& config); | ||
|
||
/** | ||
* @brief Dtor. | ||
*/ | ||
~PServerController(); | ||
|
||
/** | ||
* @brief create PServerUtil from gflags, this is used for | ||
* compatibility with the old usage of configuration by gflags. | ||
*/ | ||
static PServerController* createByGflags(); | ||
|
||
/** | ||
* @brief create PServerUtil with ParameterServerConfig, remove gflags | ||
* from ParameterServer. Init all pservers thread according to the config. | ||
*/ | ||
static PServerController* create(const ParameterServerConfig& config); | ||
|
||
/** | ||
* @brief start all pserver thread in this PServerUtil. | ||
*/ | ||
void start(); | ||
|
||
/** | ||
* @brief join and wait for all pserver thread in this PServerUtil. | ||
*/ | ||
void join(); | ||
|
||
private: | ||
std::vector<std::unique_ptr<ParameterServer2>> pservers_; | ||
}; | ||
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
syntax = "proto2"; | ||
|
||
package paddle; | ||
|
||
message ParameterClientConfig { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该有个注释说明这个proto message的用意。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done,写的比较简单 |
||
required int32 trainer_id = 1; | ||
} | ||
|
||
message ParameterServerConfig { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里应该有个注释说明这个proto message的用意。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
// The ports number for parameter send, | ||
// increment based on default port number | ||
required int32 ports_num = 1 [default = 1]; | ||
// The ports number for sparse parameter send, | ||
// increment based on default (port + ports_num) | ||
required int32 ports_num_for_sparse = 2 [default = 0]; | ||
// network device name for pservers | ||
required string nics = 3 [default = "xgbe0,xgbe1"]; | ||
required string rdma_tcp = 4 [default = "tcp"]; | ||
// Listening port for pserver | ||
required int32 port = 5 [default = 20134]; | ||
// number of gradient servers | ||
required int32 num_gradient_servers = 6 [default = 1]; | ||
// number of threads for sync op exec | ||
required int32 pserver_num_threads = 7 [default = 1]; | ||
// control config_.async_lagged_grad_discard_ratio() min value | ||
required double async_lagged_ratio_min = 8 [default = 1.0]; | ||
// if async_lagged_grad_discard_ratio is not set in trainer_config.conf | ||
// use it as default value | ||
required double async_lagged_ratio_default = 9 [default = 1.5]; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面有文件叫 PServerUtils.*,这里叫ParameterServer,显然不一致呀。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个配置文件确实是用来配置parameter server的,目前的pserverutil封装了几个parameter server线程,根据config来创建这些线程。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我的意思是到底应该叫 pserver 还是 parameter server 呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经按照命名规范修改为ParameterServerController