Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ParameterServerController for parameter server python api #1051

Merged
merged 11 commits into from
Jan 11, 2017
Merged
1 change: 1 addition & 0 deletions demo/quick_start/cluster/cluster_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ log_file="$bin_dir/train.log"
pushd "$home_dir"
cfg=trainer_config.lr.py
paddle train \
--start_pserver=false \
--config=$cfg \
--save_dir=${model_dir} \
--trainer_count=4 \
Expand Down
6 changes: 4 additions & 2 deletions paddle/pserver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ set(PSERVER_SOURCES
BaseClient.cpp
ParameterClient2.cpp
ParameterServer2.cpp
SparseParameterDistribution.cpp)
SparseParameterDistribution.cpp
PServerUtil.cpp)

set(PSERVER_HEADERS
BaseClient.h
ParameterClient2.h
ParameterServer2.h
SparseParameterDistribution.h)
SparseParameterDistribution.h
PServerUtil.h)

add_library(paddle_pserver STATIC
${PSERVER_SOURCES})
Expand Down
101 changes: 101 additions & 0 deletions paddle/pserver/PServerUtil.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "PServerUtil.h"

namespace paddle {

PServerUtil::PServerUtil(const ParameterServerConfig& config) {
// round robin to load balance RDMA server ENGINE
std::vector<std::string> devices;
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = config.ports_num() + config.ports_num_for_sparse();

if (config.nics().empty()) {
pservers_.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (config.rdma_tcp() == "rdma") {
pservers_[i].reset(
new ParameterServer2(std::string(), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers_[i].reset(
new ParameterServer2(std::string(), config.port() + i));
}
CHECK(pservers_[i]->init()) << "Fail to initialize parameter server"
<< config.port() + i;
}
} else {
str::split(config.nics(), ',', &devices);
pservers_.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (config.rdma_tcp() == "rdma") {
pservers_[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers_[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), config.port() + i));
}
CHECK(pservers_[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< config.port() + i;
}
}
}
}

PServerUtil::~PServerUtil() { this->join(); }

ParameterServerConfig* PServerUtil::initConfig() {
ParameterServerConfig* config = new ParameterServerConfig();
config->set_nics(FLAGS_nics);
config->set_port(FLAGS_port);
config->set_ports_num(FLAGS_ports_num);
config->set_rdma_tcp(FLAGS_rdma_tcp);
return config;
}

PServerUtil* PServerUtil::createWithGflags() {
auto& pServerConfig = *paddle::PServerUtil::initConfig();
return create(pServerConfig);
}

PServerUtil* PServerUtil::create(const ParameterServerConfig& config) {
return new PServerUtil(config);
}

void PServerUtil::start() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
int i = 0;
for (const auto& pserver : pservers_) {
LOG(INFO) << "pserver started : " << i;
pserver->start();
i++;
}
}

void PServerUtil::join() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
int i = 0;
for (const auto& pserver : pservers_) {
LOG(INFO) << "pserver join : " << i;
pserver->join();
i++;
}
}

} // namespace paddle
70 changes: 70 additions & 0 deletions paddle/pserver/PServerUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我现在对util这样的命名特别紧张——我发现一般都是大家懒得想明白应该叫什么的时候就叫util了。

在这里,看上去是想叫 PServerController 什么的。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool,同意。

PServerController是个好名字。
感觉程序员起名字是一个非常痛苦的事情。一痛苦就会用一些比较常用的名字,比如Utils

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "ParameterServer2.h"
#include "ParameterServerConfig.pb.h"
#include "RDMANetwork.h"
#include "paddle/utils/StringUtil.h"

namespace paddle {

class PServerUtil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class PServerUtil => class PServerUtil final

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

public:
DISABLE_COPY(PServerUtil);

/**
* @brief Ctor, Create a PServerUtil from ParameterServerConfig.
*/
explicit PServerUtil(const ParameterServerConfig& config);

/**
* @brief Dtor.
*/
~PServerUtil();

/**
* @brief create PServerUtil from gflags, this is used for
* compatibility with the old usage of configuration by gflags.
*/
static PServerUtil* createWithGflags();

/**
* @brief create PServerUtil with ParameterServerConfig, remove gflags
* from ParameterServer. Init all pservers thread according to the config.
*/
static PServerUtil* create(const ParameterServerConfig& config);

/**
* @brief start all pserver thread in this PServerUtil.
*/
void start();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我们的naming convention指定的怎么样了? @reyoung

在这里constructor是camel形式,但是methods都是小写。显然不一致呀。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里因为构造函数和类名必须一致。类名必须是UpperCamelCase。比如 "SomeClass".而函数名是"lowerCamelCase",比如"someMethod"。

这样做的好处是,我们可以通过判断出一个东西是不是类型了。

比如

class SomeClass {
public:
    class Helper {
    };

   static Helper helper();
};

SomeClass::Helper 是类型,而SomeClass::helper是函数。


/**
* @brief join and wait for all pserver thread in this PServerUtil.
*/
void join();

private:
std::vector<std::shared_ptr<ParameterServer2>> pservers_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::vector<std::shared_ptr<ParameterServer2>> =>std::vector<std::unique_ptr<ParameterServer2>>

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


/**
* @brief create ParameterServerConfig from gflags, this is used for
* compatibility with the old usage of configuration by gflags.
*/
static ParameterServerConfig* initConfig();
};

} // namespace paddle
60 changes: 6 additions & 54 deletions paddle/pserver/ParameterServer2Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,66 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <fstream>
#include "paddle/utils/StringUtil.h"
#include "paddle/utils/Util.h"

#include "ParameterServer2.h"
#include "RDMANetwork.h"
#include "paddle/utils/Flags.h"
#include "PServerUtil.h"
#include "paddle/trainer/ParamUtil.h"

using namespace paddle; // NOLINT

int main(int argc, char** argv) {
initMain(argc, argv);

std::vector<std::string> devices;
std::vector<std::shared_ptr<ParameterServer2>> pservers;

// round robin to loadbalance RDMA server ENGINE
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
}
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}
CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}

for (auto& pserver : pservers) {
pserver->join();
}
std::unique_ptr<PServerUtil> pServerPtr(
paddle::PServerUtil::createWithGflags());
pServerPtr->start();
pServerPtr->join();

return 0;
}
54 changes: 4 additions & 50 deletions paddle/trainer/TrainerMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <fenv.h>
#include "paddle/pserver/ParameterServer2.h"
#include "paddle/pserver/PServerUtil.h"
#include "paddle/utils/Excepts.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/StringUtil.h"

#include "ParamUtil.h"
#include "Trainer.h"
#include "paddle/pserver/RDMANetwork.h"

DEFINE_bool(start_pserver, false, "Whether to start pserver");
DECLARE_int32(gpu_id);
Expand All @@ -39,54 +37,10 @@ int main(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);

std::vector<std::unique_ptr<ParameterServer2>> pservers;
std::vector<std::string> devices;

std::unique_ptr<PServerUtil> pServerPtr(nullptr);
if (FLAGS_start_pserver) {
// round robin to loadbalance RDMA server ENGINE
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i));
}

CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}

CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}
pServerPtr.reset(paddle::PServerUtil::createWithGflags());
pServerPtr->start();
}
Trainer trainer;
auto config = TrainerConfigHelper::createFromFlags();
Expand Down
3 changes: 2 additions & 1 deletion proto/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ set(proto_filenames
ModelConfig.proto
ParameterConfig.proto
ParameterService.proto
TrainerConfig.proto)
TrainerConfig.proto
ParameterServerConfig.proto)

set(PROTO_GEN)
set(PROTO_GEN_PY)
Expand Down
Loading