forked from secretflow/scql
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmux_link_factory.cc
308 lines (264 loc) · 10.7 KB
/
mux_link_factory.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
// Copyright 2023 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine/link/mux_link_factory.h"
#include "brpc/closure_guard.h"
#include "bthread/bthread.h"
#include "bthread/condition_variable.h"
#include "spdlog/spdlog.h"
namespace scql::engine {
// Checks the outcome of a Push RPC and throws when it did not succeed.
//   cntl         - the brpc::Controller used for the call.
//   response     - the link::pb::MuxPushResponse filled in by the call.
//   request_info - printable description of the request, used in the message.
//
// NOTE: LINKID_NOT_FOUND is deliberately thrown as a NetworkError: the peer's
// Context corresponding to link_id_ may become available later, while
// ConnectToMesh() in lib-yacl only catches NetworkError and retries.
// Any other peer-side failure is thrown through YACL_THROW.
//
// The pre-formatted message is always passed as an argument to a "{}" format
// string and never as the format string itself: it embeds peer-supplied text
// (response.error_msg()) and the request key, which may contain '{' or '}'.
#ifndef THROW_IF_RPC_NOT_OK
#define THROW_IF_RPC_NOT_OK(cntl, response, request_info) \
  do { \
    if ((cntl).Failed()) { \
      YACL_THROW_NETWORK_ERROR("send failed: {}, rpc failed={}, message={}.", \
                               request_info, (cntl).ErrorCode(), \
                               (cntl).ErrorText()); \
    } else if ((response).error_code() != link::pb::ErrorCode::SUCCESS) { \
      std::string error_info = fmt::format( \
          "send failed: {}, peer failed code={}, message={}.", request_info, \
          (response).error_code(), (response).error_msg()); \
      if ((response).error_code() == link::pb::ErrorCode::LINKID_NOT_FOUND) { \
        YACL_THROW_NETWORK_ERROR("{}", error_info); \
      } \
      YACL_THROW("{}", error_info); \
    } \
  } while (false)
#endif
// Creates a yacl link Context for `desc` as seen from party `self_rank`.
//
// For every other party a MuxLinkChannel is built on top of an rpc channel
// obtained from channel_manager_, all of those channels are attached to a
// fresh Listener registered in listener_manager_ under desc.id, and finally
// the Context is constructed from them.
//
// Throws (via YACL_ENFORCE) when self_rank is out of range or when an rpc
// channel cannot be created.
std::shared_ptr<yacl::link::Context> MuxLinkFactory::CreateContext(
    const yacl::link::ContextDesc& desc, size_t self_rank) {
  const size_t world_size = desc.parties.size();
  YACL_ENFORCE(self_rank < world_size,
               "invalid arg: self rank={} not small than world_size={}",
               self_rank, world_size);

  // Step 1: one channel per remote party; the slot at self_rank stays null.
  std::vector<std::shared_ptr<yacl::link::IChannel>> peer_channels(world_size);
  for (size_t peer = 0; peer < world_size; ++peer) {
    if (peer != self_rank) {
      auto rpc_channel = channel_manager_->Create(desc.parties[peer].host,
                                                  RemoteRole::PeerEngine);
      YACL_ENFORCE(rpc_channel, "create rpc channel failed for rank={}", peer);
      peer_channels[peer] = std::make_shared<MuxLinkChannel>(
          self_rank, peer, desc.recv_timeout_ms, desc.http_max_payload_size,
          desc.id, rpc_channel);
    }
  }

  // Step 2: hand the channels to a listener and register it under the link id.
  auto link_listener = std::make_shared<Listener>();
  for (size_t peer = 0; peer < world_size; ++peer) {
    if (peer != self_rank) {
      link_listener->AddChannel(peer, peer_channels[peer]);
    }
  }
  listener_manager_->AddListener(desc.id, link_listener);

  // Step 3: build the Context that owns the channels.
  return std::make_shared<yacl::link::Context>(
      desc, self_rank, std::move(peer_channels), nullptr, false);
}
void MuxLinkChannel::SendImpl(const std::string& key,
yacl::ByteContainerView value) {
if (value.size() > http_max_payload_size_) {
SendChunked(key, value);
return;
}
link::pb::MuxPushRequest request;
{
request.set_link_id(link_id_);
auto msg = request.mutable_msg();
msg->set_sender_rank(self_rank_);
msg->set_key(key);
msg->set_value(value.data(), value.size());
msg->set_trans_type(link::pb::TransType::MONO);
}
link::pb::MuxPushResponse response;
brpc::Controller cntl;
link::pb::MuxReceiverService::Stub stub(rpc_channel_.get());
stub.Push(&cntl, &request, &response, nullptr);
std::string request_info = fmt::format(
"link_id={} sender_rank={} send key={}", link_id_, self_rank_, key);
THROW_IF_RPC_NOT_OK(cntl, response, request_info);
return;
}
namespace {
class OnPushDone : public google::protobuf::Closure {
public:
OnPushDone(std::shared_ptr<MuxLinkChannel> channel, std::string request_info)
: channel_(std::move(channel)), request_info_(std::move(request_info)) {
channel_->AddAsyncCount();
}
~OnPushDone() {
try {
channel_->SubAsyncCount();
} catch (const std::exception& ex) {
SPDLOG_WARN(ex.what());
}
}
void Run() {
std::unique_ptr<OnPushDone> self_guard(this);
std::string error_msg;
if (cntl_.Failed()) {
SPDLOG_ERROR("async send failed: {}, rpc failed={}, message={}",
request_info_, cntl_.ErrorCode(), cntl_.ErrorText());
} else if (response_.error_code() != link::pb::ErrorCode::SUCCESS) {
SPDLOG_ERROR("async send failed: {}, peer failed, message={}",
request_info_, response_.error_code());
}
}
link::pb::MuxPushResponse response_;
brpc::Controller cntl_;
const std::shared_ptr<MuxLinkChannel> channel_;
const std::string request_info_;
};
struct SendChunckedBrpcTask {
std::shared_ptr<MuxLinkChannel> channel;
std::string key;
yacl::Buffer value;
SendChunckedBrpcTask(std::shared_ptr<MuxLinkChannel> _channel,
std::string _key, yacl::Buffer _value)
: channel(std::move(_channel)),
key(std::move(_key)),
value(std::move(_value)) {
channel->AddAsyncCount();
}
~SendChunckedBrpcTask() {
try {
channel->SubAsyncCount();
} catch (const std::exception& ex) {
SPDLOG_WARN(ex.what());
}
}
static void* Proc(void* args) {
// take ownership of task.
std::unique_ptr<SendChunckedBrpcTask> task(
static_cast<SendChunckedBrpcTask*>(args));
try {
task->channel->SendChunked(task->key, task->value);
} catch (const std::exception& e) {
SPDLOG_ERROR("chunked async send failed. key={}, failed={}", task->key,
e.what());
}
return nullptr;
}
};
} // namespace
template <class ValueType>
void MuxLinkChannel::SendAsyncInternal(const std::string& key,
ValueType&& value) {
if (static_cast<size_t>(value.size()) > http_max_payload_size_) {
auto btask = std::make_unique<SendChunckedBrpcTask>(
this->shared_from_this(), key,
yacl::Buffer(std::forward<ValueType>(value)));
// bthread run in 'detached' mode, we will never wait for it.
bthread_t tid;
if (bthread_start_background(&tid, nullptr, SendChunckedBrpcTask::Proc,
btask.get()) == 0) {
// bthread takes the ownership, release it.
static_cast<void>(btask.release());
} else {
YACL_THROW("failed to push async sending job to bthread");
}
return;
}
link::pb::MuxPushRequest request;
{
request.set_link_id(link_id_);
auto msg = request.mutable_msg();
msg->set_sender_rank(self_rank_);
msg->set_key(key);
msg->set_value(value.data(), value.size());
msg->set_trans_type(link::pb::TransType::MONO);
}
std::string request_info = fmt::format(
"link_id={} sender_rank={} send_key={}", link_id_, self_rank_, key);
auto* done = new OnPushDone(shared_from_this(), std::move(request_info));
link::pb::MuxReceiverService::Stub stub(rpc_channel_.get());
stub.Push(&done->cntl_, &request, &done->response_, done);
}
namespace {
// Describes one fixed-size batch cut out of `total_size` consecutive
// elements: batch `batch_idx` covers the half-open range [Begin(), End()).
//
// The last batch may be shorter than `batch_size`. A batch index past the
// last batch describes an empty range (Begin() == End() == total_size)
// instead of letting Size() wrap around.
class BatchDesc {
 public:
  // batch_idx  - zero-based index of this batch.
  // batch_size - number of elements in every batch except possibly the last.
  // total_size - total number of elements being split into batches.
  BatchDesc(size_t batch_idx, size_t batch_size, size_t total_size)
      : batch_idx_(batch_idx),
        batch_size_(batch_size),
        total_size_(total_size) {}

  // return the index of this batch.
  size_t Index() const { return batch_idx_; }

  // return the offset of the first element in this batch, never beyond
  // total_size.
  size_t Begin() const {
    return std::min(batch_idx_ * batch_size_, total_size_);
  }

  // return the offset after last element in this batch.
  // Written as Begin() + length so the sum cannot exceed total_size.
  size_t End() const {
    const size_t begin = Begin();
    return begin + std::min(batch_size_, total_size_ - begin);
  }

  // return the size of this batch.
  size_t Size() const { return End() - Begin(); }

  // return a short printable tag, e.g. "B:3" for batch index 3.
  std::string ToString() const { return "B:" + std::to_string(batch_idx_); }

 private:
  size_t batch_idx_;
  size_t batch_size_;
  size_t total_size_;
};
} // namespace
void MuxLinkChannel::SendChunked(const std::string& key,
yacl::ByteContainerView value) {
const size_t bytes_per_chunk = http_max_payload_size_;
const size_t num_bytes = value.size();
const size_t num_chunks = (num_bytes + bytes_per_chunk - 1) / bytes_per_chunk;
constexpr uint32_t kParallelSize = 10;
const size_t batch_size = kParallelSize;
const size_t num_batches = (num_chunks + batch_size - 1) / batch_size;
for (size_t batch_idx = 0; batch_idx < num_batches; batch_idx++) {
const BatchDesc batch(batch_idx, batch_size, num_chunks);
// See: "半同步“ from
// https://github.com/apache/incubator-brpc/blob/master/docs/cn/client.md
std::vector<brpc::Controller> cntls(batch.Size());
std::vector<link::pb::MuxPushResponse> responses(batch.Size());
// fire batched chunk requests.
for (size_t idx = 0; idx < batch.Size(); idx++) {
const size_t chunk_idx = batch.Begin() + idx;
const size_t chunk_offset = chunk_idx * bytes_per_chunk;
link::pb::MuxPushRequest request;
{
request.set_link_id(link_id_);
auto msg = request.mutable_msg();
msg->set_sender_rank(self_rank_);
msg->set_key(key);
msg->set_value(value.data() + chunk_offset,
std::min(bytes_per_chunk, value.size() - chunk_offset));
msg->set_trans_type(link::pb::TransType::CHUNKED);
msg->mutable_chunk_info()->set_chunk_offset(chunk_offset);
msg->mutable_chunk_info()->set_message_length(num_bytes);
}
auto& cntl = cntls[idx];
auto& response = responses[idx];
link::pb::MuxReceiverService::Stub stub(rpc_channel_.get());
stub.Push(&cntl, &request, &response, brpc::DoNothing());
}
for (size_t idx = 0; idx < batch.Size(); idx++) {
brpc::Join(cntls[idx].call_id());
}
for (size_t idx = 0; idx < batch.Size(); idx++) {
const size_t chunk_idx = batch.Begin() + idx;
std::string request_info = fmt::format(
"link_id={}, sender_rank={}, key={} (chunked {} out of {})", link_id_,
self_rank_, key, chunk_idx + 1, num_chunks);
THROW_IF_RPC_NOT_OK(cntls[idx], responses[idx], request_info);
}
}
}
} // namespace scql::engine