[tanakamura] Code-reading notes #9
At MPI initialization time, the ncclUniqueId returned by ncclGetUniqueId contains the address and port of the bootstrap server. (So NCCL itself does not depend on MPI; how the result of ncclGetUniqueId gets broadcast to the other ranks is left to the user.)
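To make that division of labor concrete, the usual pattern (a minimal sketch of the standard MPI bootstrap shown in the NCCL docs, not code from this repo; error handling omitted) is to create the ID on rank 0 and push it to everyone else with MPI_Bcast before ncclCommInitRank:

#include <mpi.h>
#include <nccl.h>
#include <cuda_runtime.h>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, nranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);

  ncclUniqueId id;                      // carries the bootstrap server address/port
  if (rank == 0) ncclGetUniqueId(&id);
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);  // any transport would do here

  cudaSetDevice(rank % 2);              // simplification: assumes 2 GPUs per node, ranks packed per node
  ncclComm_t comm;
  ncclCommInitRank(&comm, nranks, id, rank);

  // ... collectives ...
  ncclCommDestroy(comm);
  MPI_Finalize();
  return 0;
}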
How the ring topology is decided, for:
4 node x 2 GPU
2 node x 2 GPU
1 node x 2 GPU
4 node x 1 GPU
2 node x 1 GPU
Behavior when taking a backtrace at ncclBarrierEnqueueWait (B, C):
=== trace (b enqueue) sta === include/debug.h:56 /home/tanakamura/shared/nccl/src/enqueue.cc:154 misc/group.cc:141 /home/tanakamura/shared/test/nccl/a.cu:115 ??:0 ??:0 === trace (b enqueue) end ===
=== trace (b enqueue) sta === include/debug.h:56 /home/tanakamura/shared/nccl/src/enqueue.cc:154 misc/group.cc:141 /home/tanakamura/shared/test/nccl/a.cu:115 ??:0 ??:0 === trace (b enqueue) end ===
a.cu:115 is ncclGroupEnd. Looking at the code, ncclGroupStart just counts how many times start was called.
Just learned that there is a godsend tool called eu-addr2line that does addr2line while also accounting for the runtime load offsets of shared libraries.
Trace of CUDA calls when running ncclGroupStart, ncclAllReduce, ncclGroupEnd — iter 0:
=== trace (cudaGetDevice(&savedDev),tid=13019) sta === dump_trace include/debug.h:57 ncclGroupEnd misc/group.cc:119 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaGetDevice(&savedDev),tid=13019) end ===
=== trace (b enqueue,tid=13019) sta === dump_trace include/debug.h:57 ncclBarrierEnqueue(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:156 ncclGroupEnd misc/group.cc:141 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (b enqueue,tid=13019) end ===
=== trace (cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0),tid=13019) sta === dump_trace include/debug.h:57 ncclBarrierEnqueue(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:168 ncclGroupEnd misc/group.cc:141 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0),tid=13019) end ===
=== trace (b enqueue,tid=13019) sta === dump_trace include/debug.h:57 ncclBarrierEnqueue(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:156 ncclGroupEnd misc/group.cc:141 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (b enqueue,tid=13019) end ===
=== trace (cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0),tid=13019) sta === dump_trace include/debug.h:57 ncclBarrierEnqueue(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:168 ncclGroupEnd misc/group.cc:141 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0),tid=13019) end ===
=== trace (cudaLaunchCooperativeKernelMultiDevice(paramsList, numDevices, cudaCooperativeLaunchMultiDeviceNoPreSync|cudaCooperativeLaunchMultiDeviceNoPostSync),tid=13019) sta === dump_trace include/debug.h:57 ncclLaunchCooperativeKernelMultiDevice(cudaLaunchParams*, int*, int, int) /home/tanakamura/shared/nccl/src/enqueue.cc:72 ncclBarrierEnqueue(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:179 ncclGroupEnd misc/group.cc:141 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaLaunchCooperativeKernelMultiDevice(paramsList, numDevices, cudaCooperativeLaunchMultiDeviceNoPreSync|cudaCooperativeLaunchMultiDeviceNoPostSync),tid=13019) end ===
nccl-study01:13019:13019 [0] NCCL INFO Launch mode Group/CGMD
=== trace (cudaEventRecord(comm->doneEvent, params->stream),tid=13019) sta === dump_trace include/debug.h:57 ncclEnqueueEvents(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:223 ncclGroupEnd misc/group.cc:156 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaEventRecord(comm->doneEvent, params->stream),tid=13019) end ===
=== trace (cudaEventRecord(comm->doneEvent, params->stream),tid=13019) sta === dump_trace include/debug.h:57 ncclEnqueueEvents(ncclComm*) /home/tanakamura/shared/nccl/src/enqueue.cc:223 ncclGroupEnd misc/group.cc:156 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaEventRecord(comm->doneEvent, params->stream),tid=13019) end ===
=== trace (cudaSetDevice(savedDev),tid=13019) sta === dump_trace include/debug.h:57 ncclGroupEnd misc/group.cc:196 main /home/tanakamura/shared/test/nccl/a.cu:109 __libc_start_main ??:0 _start ??:0 === trace (cudaSetDevice(savedDev),tid=13019) end ===
The same trace, filtered to just the sta markers (numbers are line numbers within the raw trace output):
10:=== trace (cudaGetDevice(&savedDev),tid=13019) sta ===
22:=== trace (b enqueue,tid=13019) sta ===
36:=== trace (cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0),tid=13019) sta ===
50:=== trace (b enqueue,tid=13019) sta ===
64:=== trace (cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0),tid=13019) sta ===
78:=== trace (cudaLaunchCooperativeKernelMultiDevice(paramsList, numDevices, cudaCooperativeLaunchMultiDeviceNoPreSync|cudaCooperativeLaunchMultiDeviceNoPostSync),tid=13019) sta ===
95:=== trace (cudaEventRecord(comm->doneEvent, params->stream),tid=13019) sta ===
109:=== trace (cudaEventRecord(comm->doneEvent, params->stream),tid=13019) sta ===
123:=== trace (cudaSetDevice(savedDev),tid=13019) sta ===
IB receive (wc recv):
=== 0x7f85a2028620 trace (ibv wc recv,host=nccl-study02,tid=1151,pid=1083) sta === dump_trace(char const*) misc/utils.cc:239 ncclIbTest(void*, int*, int*) transport/net_ib.cc:797 ncclNetTest include/net.h:29 netRecvProxy(ncclProxyArgs*) transport/net.cc:533 persistentThread(void*) /home/tanakamura/shared/nccl/src/transport.cc:162 start_thread ??:0 __clone ??:0 === trace (ibv wc recv,tid=2329) end ===
IB send:
=== 0x7f85a2028620 trace (wrap ibv post send,host=nccl-study02,tid=1151,pid=1083) sta === dump_trace(char const*) misc/utils.cc:239 wrap_ibv_post_send include/ibvwrap.h:1094 ncclIbIsend(void*, void*, int, void*, void**) transport/net_ib.cc:668 ncclNetIsend include/net.h:26 netSendProxy(ncclProxyArgs*) transport/net.cc:470 persistentThread(void*) /home/tanakamura/shared/nccl/src/transport.cc:162 start_thread ??:0 __clone ??:0 === trace (wrap ibv post send,tid=2329) end ===
A separate thread receives ncclProxyArgs and does the IB transfer from there.
With a single node, persistentThread isn't used, so something similar must be happening inside the GPU.
That said, even when persistentThread isn't used, NeedProxy still returns true. I thought it might come down to whether cooperative launch is supported, but that seems unrelated.
In the single-node case it bails out here: if (connector->transportComm->proxy == NULL) return ncclSuccess; As far as I can tell it always bails out there, i.e. SaveProxy does nothing.
state->ops holds the tasks to be done. For IB, two ops are queued: send and recv. Once idle = 0, the next op gets run; after all ops have made progress, it starts over from the head. When all packets have been transferred, op->state becomes ncclProxyOpNone and the op is unlinked from the state->ops chain. The TCP socket transport polls with the same kind of progress loop.
netSendProxy
netRecvProxy
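Pulling the above together, a rough hypothetical sketch of that progress loop (the struct and field names here are made up, not the real NCCL definitions; netSendProxy/netRecvProxy would be the progress callbacks):

#include <stddef.h>

/* Hypothetical sketch of the proxy progress loop described above:
 * send and recv ops sit on a chain and each gets polled in turn. */
enum opState { OP_PROGRESS, OP_NONE };

struct proxyOp {
  enum opState state;
  struct proxyOp *next;
  void (*progress)(struct proxyOp *op, int *idle);  /* e.g. netSendProxy / netRecvProxy */
};

static void progressLoop(struct proxyOp **ops) {
  while (*ops != NULL) {                 /* keep going until every op has finished */
    struct proxyOp **p = ops;
    while (*p != NULL) {
      int idle = 1;
      (*p)->progress(*p, &idle);         /* idle == 0 means it made progress this pass */
      if ((*p)->state == OP_NONE)
        *p = (*p)->next;                 /* all packets transferred: unlink from the chain */
      else
        p = &(*p)->next;
    }
    /* then start over from the head of the chain (same idea for IB and TCP sockets) */
  }
}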
ncclConnector's conn is what connects the transport and the GPU. The transport only holds a single ncclConnInfo, yet ncclPrimitives looked like it holds NRECV of them — but that is for the tree; for a plain single ring, NRECV = 1.
loadSendConn is where the queue addresses get decided. The conn arriving here is send->conn, which was initialized by netSendConnect etc. devHostRecvMem is allocated with cudaHostAlloc, so it is host memory the GPU can access. When netSendProxy etc. update hostSendMem->head, the GPU sees the change. That is how the IB → GPU notification works.
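The mechanism being described — the proxy writes to pinned, mapped host memory and the GPU sees the update — can be reproduced with a few lines of plain CUDA. This is a standalone sketch, not NCCL code; error checks omitted:

#include <cstdio>
#include <cuda_runtime.h>

// GPU side: spin on a flag that lives in host memory (each read goes over PCIe).
__global__ void waitForHead(volatile unsigned long long *head, unsigned long long target) {
  while (*head < target) { /* spin */ }
  printf("GPU saw head=%llu\n", *head);
}

int main() {
  unsigned long long *hostHead, *devHead;
  cudaHostAlloc(&hostHead, sizeof(*hostHead), cudaHostAllocMapped);   // like devHostRecvMem
  cudaHostGetDevicePointer((void **)&devHead, (void *)hostHead, 0);
  *hostHead = 0;

  waitForHead<<<1, 1>>>(devHead, 1);
  *hostHead = 1;              // host-side update, like netSendProxy bumping hostSendMem->head
  cudaDeviceSynchronize();
  return 0;
}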
For p2p.cc (I haven't looked closely), the conn seems to get hooked up by p2pSendConnect and friends. I'm not sure what happens within a single node when P2P can't be used.
When it falls back to shm, POSIX shared memory is allocated and then cudaHostRegister'd so the GPU can see it. That is how host memory gets shared across multiple GPUs.
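A minimal sketch of that shm path (hypothetical, not the actual shm transport code; error handling and cleanup such as cudaHostUnregister/munmap/shm_unlink omitted): allocate a POSIX shared memory segment, map it, and register the mapping so the GPU can address it.

#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <cuda_runtime.h>

// Sketch: share a host buffer between processes via POSIX shm, then expose it to the GPU.
void *mapSharedForGpu(const char *name, size_t size, void **devPtr) {
  int fd = shm_open(name, O_CREAT | O_RDWR, 0600);   // both ranks open the same name
  ftruncate(fd, size);
  void *hostPtr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  close(fd);

  // Pin + map the range so the local GPU can read/write it directly.
  cudaHostRegister(hostPtr, size, cudaHostRegisterMapped);
  cudaHostGetDevicePointer(devPtr, hostPtr, 0);
  return hostPtr;
}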
LL mode seems to pack the receive flag and the data into the same 8 bytes, so that IB reception can be confirmed without host-side polling — something along those lines, I think.
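The idea as I understand it (a hypothetical sketch, not NCCL's actual LL protocol code, which interleaves data/flag pairs per 16-byte line): the sender packs a 4-byte payload and a 4-byte flag into one 64-bit word and writes it with a single store; the receiver spins until the flag half matches the step it expects, so data and "ready" arrive together and no separate notification is needed.

#include <cstdint>

// Hypothetical LL-style line: 4 bytes of data + 4 bytes of flag in one 64-bit word.
__device__ inline void llSend(volatile uint64_t *slot, uint32_t data, uint32_t flag) {
  uint64_t line = ((uint64_t)flag << 32) | data;   // flag and data travel together
  *slot = line;                                    // single 8-byte store
}

__device__ inline uint32_t llRecv(volatile uint64_t *slot, uint32_t expectedFlag) {
  uint64_t line;
  do {
    line = *slot;                                  // re-read until the flag half matches
  } while ((uint32_t)(line >> 32) != expectedFlag);
  return (uint32_t)line;                           // data half is valid once the flag is seen
}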
Next time
Also, the polling of host memory over PCIe bothers me; it looks like it would eat into IB bandwidth. It might be worth checking whether there is a cache or some other mechanism involved.
netSendProxy (GPU -> net) writes to host-side memory, so the polling only happens locally. Line 439 in 9db4b1d
hostRecvMem is on the host side. In that case, while waiting for space to free up in the send buffer, I think host memory is being polled over PCIe. nccl/src/collectives/device/primitives.h Line 105 in 9db4b1d
Isn't this waitPtr pointing at host memory? Line 538 in 9db4b1d
The netRecvProxy (net -> GPU) side is the other way round: the GPU seems to poll host memory over PCIe for the data-ready flag. The queue head and tail live in host memory. The relevant primitives.h code:
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) {
recvConn[i] = conn;
recvBuff[i] = (const T*)recvConn[i]->buff;
recvStep[i] = recvConn[i]->step;
recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
// Return credits in case we rounded up.
if (tid == nthreads) *recvConn[i]->head = recvStep[i];
if (tid == i) {
waitPtr = recvConn[i]->tail;
*(recvConn[i]->opCountLoc) = opCount;
}
recvDirectBuff[i] = NULL;
if (directBuff && recvConn[i]->direct) {
recvDirectBuff[i] = directBuff;
if (tid == 0) *recvConn[i]->ptrExchange = directBuff;
}
nrecv++;
}
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i, T* directBuff) {
sendConn[i] = conn;
sendBuff[i] = (T*)sendConn[i]->buff;
sendStep[i] = sendConn[i]->step;
sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
if (tid == WARP_SIZE+i) {
waitPtr = sendConn[i]->head;
sendConnHead[i] = *waitPtr;
*(sendConn[i]->opCountLoc) = opCount;
}
sendDirectBuff[i] = NULL;
if (directBuff && sendConn[i]->direct) {
void* volatile* ptr = sendConn[i]->ptrExchange;
while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL);
__syncthreads();
if (tid == 0) *ptr = NULL;
}
nsend++;
}
// LL (low-latency) variants
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
sendConn[i] = conn;
sendBuff[i] = sendConn[i]->llBuff;
sendStep[i] = sendConn[i]->step;
if (tid == WARP_SIZE+i) {
waitPtr = sendConn[i]->head;
fifoPtr = sendConn[i]->fifo;
sendConnHead = *waitPtr;
*(sendConn[i]->opCountLoc) = opCount;
}
nsend++;
}
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
recvConn[i] = conn;
recvBuff[i] = recvConn[i]->llBuff;
recvStep[i] = recvConn[i]->step;
if (tid == i) {
postPtr = recvConn[i]->head;
*(recvConn[i]->opCountLoc) = opCount;
}
nrecv++;
}
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), stepSize(stepSize), opCount(opCount) {
// Make sure step is updated before we read it
__syncthreads();
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i, directBuff);
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i, directBuff);
}
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), opCount(opCount) {
// Make sure step is updated before we read it.
barrier();
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i);
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
}
// ncclAllReduceRingKernel
__device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
struct ncclChannel* channel = comm->channels+blockIdx.x;
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, args->opCount);
// ...
}
// all_reduce.cu
IMPL_COLL_R(ncclAllReduce, ncclCollAllReduce);
// common.h
// Reduction define all functions
#if NCCL_OP == 0
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL2(collf, sum, FuncSum, colln, ncclSum);
#elif NCCL_OP == 1
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL2(collf, prod, FuncProd, colln, ncclProd);
#elif NCCL_OP == 2
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL2(collf, min, FuncMin, colln, ncclMin);
#elif NCCL_OP == 3
#define IMPL_COLL_R(collf, colln) \
IMPL_COLL2(collf, max, FuncMax, colln, ncclMax);
#endif
// Only generate inline kernels for LL
#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \
IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \
IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \
IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1)
#if NCCL_OP == 0
/* Kernels with the first operation inlined */
#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
__launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \
__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
int tid = threadIdx.x; \
int bid = blockIdx.x; \
__shared__ struct ncclColl localColl; \
\
struct ncclDevComm* comm = firstColl.args.comm; \
struct ncclChannel* channel = comm->channels+bid; \
struct ncclColl* c; \
if (bid == 0) { \
/* To optimize for latency, (only) the first operation is passed as argument.*/ \
c = &firstColl; \
} else { \
c = &localColl; \
load_coll(c, channel->devCollectives+channel->collFifoHead, tid); \
} \
while (1) { \
if (tid < c->args.nThreads) { \
if (c->funcIndex == fIndex) { \
coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(&c->args); \
} else { \
ncclFuncs[c->funcIndex](&c->args); \
} \
} \
int nextIndex = c->nextIndex; \
if (tid == 0) channel->collFifoHead = nextIndex; \
\
if (c->active == 2) { \
return; \
} \
\
/* Load next collective operation*/ \
c = &localColl; /* for bid 0 */ \
load_coll(c, channel->devCollectives+nextIndex, tid); \
} \
}
#else
#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
#endif
waitPtr ultimately comes from the kernel argument firstColl or from localColl. (What exactly is localColl?)
I feel like I've looked this up a hundred times already, but the kernel launch happens in ncclLaunchCooperativeKernelMultiDevice in enqueue.cc. (Which path it takes depends on whether cooperative launch is supported.)
ncclResult_t ncclLaunchCooperativeKernelMultiDevice(struct cudaLaunchParams *paramsList, int* cudaDevs, int numDevices, int cgMode) {
#if CUDART_VERSION >= 9000
if (cgMode & 0x01) {
CUDACHECK(cudaLaunchCooperativeKernelMultiDevice(paramsList, numDevices,
// These flags are to reduce the latency of using this API
cudaCooperativeLaunchMultiDeviceNoPreSync|cudaCooperativeLaunchMultiDeviceNoPostSync));
return ncclSuccess;
}
#endif
//...
}
ncclResult_t ncclBarrierEnqueue(struct ncclComm* comm) {
if (comm->nRanks == 1) return ncclSuccess;
struct cudaLaunchParams* params = comm->myParams;
dump_trace("b enqueue");
NCCLCHECK(setupLaunch(comm, params));
// Use internal NCCL stream for CGMD/GROUP launch if required or if the user stream is NULL
if (comm->launchMode == ncclComm::GROUP && (comm->groupCudaStream || comm->userStream == NULL)) {
// Enqueue event in user stream
CUDACHECK(cudaEventRecord(comm->doneEvent, comm->userStream));
// Create dependency between user stream and internal NCCL stream
CUDACHECK(cudaStreamWaitEvent(comm->groupStream, comm->doneEvent, 0));
params->stream = comm->groupStream;
} else {
if (comm->userStream != params->stream) {
// Stream changed from last call, create dependency against last NCCL kernel launch
CUDACHECK(cudaStreamWaitEvent(comm->userStream, comm->doneEvent, 0));
}
params->stream = comm->userStream;
}
int isLast = 0;
NCCLCHECK(ncclCpuBarrierIn(comm, &isLast));
if (isLast) {
if (comm->launchMode == ncclComm::GROUP) {
// I'm the last. Launch all operations.
NCCLCHECK(ncclLaunchCooperativeKernelMultiDevice(comm->intraParams, comm->intraCudaDevs, comm->intraRanks, *comm->intraCGMode));
}
NCCLCHECK(ncclCpuBarrierLast(comm));
}
return ncclSuccess;
}
ncclResult_t ncclGroupEnd() {
ncclGroupMode--;
if (ncclGroupMode > 0) return ncclSuccess;
int savedDev;
CUDACHECK(cudaGetDevice(&savedDev));
int done = ncclGroupIndex;
int doneArray[MAX_ASYNC_OPS];
for (int i=0; i<ncclGroupIndex; i++) doneArray[i] = 0;
ncclResult_t ret = ncclGroupError;
if (ret != ncclSuccess) goto group_cleanup;
/* Collectives are done in three steps :
* 1. Barrier Check In. Only the last call may call cudaLaunchKernel[cooperative]
* 2. Barrier Wait. No CUDA call is permitted
* 3. Enqueue Events. CUDA event wait/enqueue.
* This is needed because step 2 cannot call any CUDA primitive, otherwise if
* cudaFree happens between 1 and 3, it could block that CUDA call and
* prevent some ranks from launching their network threads, which would
* prevent the NCCL call from completing, blocking the cudaFree call.
*/
for (int i=0; i<ncclGroupIndex; i++) {
struct ncclAsyncArgs* args = ncclGroupArgs+i;
if (args->funcType == ASYNC_FUNC_COLL) {
if (args->coll.comm->userStream == NULL)
CUDACHECKGOTO(cudaSetDevice(args->coll.comm->cudaDev), ret, end);
NCCLCHECKGOTO(ncclBarrierEnqueue(args->coll.comm), ret, end);
}
}
// ..
}
The ncclGroupArgs set up somewhere earlier remain in a global variable, and ncclGroupEnd uses them to launch the kernel.
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
struct ncclInfo info = { ncclCollAllReduce, "AllReduce",
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
return ncclEnqueueCheck(&info);
}
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
dump_trace("enqueueCheck");
if (info->comm == NULL) return ncclInvalidArgument;
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count,
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream);
// Launch asynchronously if needed
if (ncclAsyncMode()) {
ncclResult_t ret = ncclSuccess;
int savedDev = -1;
if (info->comm->checkPointers) {
CUDACHECKGOTO(cudaGetDevice(&savedDev), ret, end);
CUDACHECKGOTO(cudaSetDevice(info->comm->cudaDev), ret, end);
}
// Check arguments
NCCLCHECKGOTO(ArgsCheck(info), ret, end);
// Always register comm even in case of error to make sure ncclGroupEnd
// cleans it up.
NCCLCHECKGOTO(ncclAsyncColl(info->comm), ret, end);
NCCLCHECKGOTO(saveKernel(info), ret, end);
end:
if (savedDev != -1) CUDACHECK(cudaSetDevice(savedDev));
ncclAsyncErrCheck(ret);
return ret;
} else {
NCCLCHECK(ArgsCheck(info));
NCCLCHECK(saveKernel(info));
NCCLCHECK(ncclBarrierEnqueue(info->comm));
NCCLCHECK(ncclBarrierEnqueueWait(info->comm));
NCCLCHECK(ncclEnqueueEvents(info->comm));
return ncclSuccess;
}
}
If not in ncclAsyncMode, groups aren't used and the enqueue seems to happen on the spot.
Turns out there is a godsend API called cudaPointerGetAttributes.
For now I instrumented the ready-notify spot in netSendProxy:
if (args->head < args->tail) {
int done;
int buffSlot = args->head%NCCL_STEPS;
NCCLCHECK(ncclNetTest(args->requests[buffSlot], &done, NULL));
if (done) {
args->head += args->sliceSteps;
resources->hostSendMem->head = args->head;
args->idle = 0;
cudaPointerAttributes attr;
cudaPointerGetAttributes(&attr, resources->hostSendMem);
printf("send q ready notify : ptr=%p, attr=%d\n", resources->hostSendMem, (int)attr.memoryType);
}
}
attr=1 comes out, i.e. it really does look like the GPU is polling host memory.
enum __device_builtin__ cudaMemoryType
{
cudaMemoryTypeHost = 1, /**< Host memory */
cudaMemoryTypeDevice = 2 /**< Device memory */
};
And the corresponding spot on the receive side:
if (done) {
args->head += args->sliceSteps;
if (args->llMode == 0) {
if (resources->useGdr) ncclNetFlush(resources->netRecvComm, localBuff+buffSlot*stepSize, size, mhandle);
resources->hostRecvMem->tail = args->head;
cudaPointerAttributes attr;
cudaPointerGetAttributes(&attr, resources->hostRecvMem);
printf("recv data ready notify : ptr=%p, attr=%d\n", resources->hostRecvMem, (int)attr.memoryType);
}
args->idle = 0;
}
The data-ready notification in netRecvProxy is also in host memory.
Which one is it?
https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#zero-copy "But since any repeated access to such memory areas causes repeated PCIe transfers" — yeah, I figured…
A quick experiment: have the GPU spin-read mapped host memory and watch how much PCIe bandwidth that consumes.
#include <stdio.h>
#include <stdint.h>
struct x{
double a;
double b;
x(x const &rhs)
:a(rhs.a), b(rhs.b)
{
}
__device__ __host__ x(x volatile const &rhs)
:a(rhs.a), b(rhs.b)
{
}
x() {
}
};
//typedef struct x v_t;
//typedef int v_t;
typedef uint64_t v_t;
__global__ void f(volatile v_t *p) {
int off = blockIdx.x * blockDim.x + threadIdx.x;
long count = 0;
while (count<1024*1024*16) {
v_t x = p[off];
count++;
}
}
int main(){
v_t *p, *hp;
v_t *x;
size_t sz = 1024*1024*4*sizeof(v_t);
cudaMalloc(&p, sz);
cudaHostAlloc(&hp, sz, cudaHostAllocMapped);
cudaMallocHost(&x, sz);
while (1) {
f<<<8,64>>>(hp);
cudaMemcpy(x, hp, 64, cudaMemcpyDeviceToHost);
//cudaMemcpy(hp, x, sz, cudaMemcpyHostToDevice);
puts("xx");
}
}
With something like this, watching with pcm (https://github.com/opcm/pcm) shows only about 3.6 GB/s in use, so maybe the reasoning is that it doesn't saturate the link, so it's fine (?)
With cudaMemcpy, of course, it uses around 11 GB/s.
__global__ void f(volatile v_t *p) {
//int off = blockIdx.x * blockDim.x + threadIdx.x;
int off = blockIdx.x;
if (threadIdx.x == 0) {
long count = 0;
while (count<1024*1024) {
v_t x = p[off];
count++;
}
}
__syncthreads();
}
int main(){
v_t *p, *hp;
v_t *x;
size_t sz = 1024*1024*4*sizeof(v_t);
cudaMalloc(&p, sz);
cudaHostAlloc(&hp, sz, cudaHostAllocMapped);
cudaMallocHost(&x, sz);
while (1) {
f<<<32,64>>>(hp);
cudaMemcpy(x, hp, 64, cudaMemcpyDeviceToHost);
//cudaMemcpy(hp, x, sz, cudaMemcpyHostToDevice);
puts("xx");
}
}
Actually this version is closer to what NCCL does (only one thread per block polls). With this it's about 350 MB/s.
PCIe reads are also rough on latency, so with some effort it feels like one could write a patch that puts this in GPU-side memory and improves both latency and throughput a bit.
Except CUDA has no way to map device memory into the host address space in the first place. (So that's a non-starter.)
A ping-pong round-trip latency test between the host and the GPU:
#include <unistd.h>
#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>
#include <x86intrin.h>
#include <getopt.h>
#define NTEST 4
struct result {
long long clk_sum[NTEST];
};
template <typename CLK>
__device__ __host__
long long
other_self_other(volatile int *other2self,
volatile int *self2other,
int nloop,
CLK clk)
{
long long sum = 0;
/* other->self->other */
for (int i=0; i<nloop; i++) {
long long t0 = clk();
while (1) {
if (*other2self == i) {
break;
}
asm volatile ("" ::: "memory");
}
long long t1 = clk();
sum += t1-t0;
*self2other = i;
}
return sum;
}
template <typename CLK>
__device__ __host__
long long
self_other_self(volatile int *other2self,
volatile int *self2other,
int nloop,
CLK clk)
{
long long sum = 0;
/* self->other->self */
for (int i=0; i<nloop; i++) {
long long t0 = clk();
*self2other = i;
while (1) {
if (*other2self == i) {
break;
}
asm volatile ("" ::: "memory");
}
long long t1 = clk();
sum += t1-t0;
}
return sum;
}
__global__ void ping_pong_device(volatile int *h2d,
volatile int *d2h,
struct result *r,
int nloop,
int dhd)
{
auto dev_clk = [] __device__ () { return clock64(); };
for (int ti=0; ti<NTEST; ti++) {
long long t;
if (dhd) {
t = self_other_self(h2d, d2h, nloop, dev_clk);
} else {
t = other_self_other(h2d, d2h, nloop, dev_clk);
}
long long t1 = clock64();
r->clk_sum[ti] = t;
}
}
void ping_pong_host(volatile int *h2d,
volatile int *d2h,
struct result *r,
int nloop,
int dhd)
{
auto host_clk = [] __host__ () { return __rdtsc(); };
for (int ti=0; ti<NTEST; ti++) {
long long t;
if (dhd) {
t = other_self_other(d2h, h2d, nloop, host_clk);
} else {
t = self_other_self(d2h, h2d, nloop, host_clk);
}
r->clk_sum[ti] = t;
}
}
enum op {
/* h2d d2h */
MEM_DD, /* dev dev */
MEM_DH, /* dev host */
MEM_HD, /* host dev */
MEM_HH, /* host host = NCCL compatible */
};
int main(int argc, char **argv)
{
int *d_h2d, *d_d2h;
int *h_h2d, *h_d2h;
struct result *r, h_r, r_copy;
enum op o = MEM_DD;
int opt;
int nloop = 1024;
int dhd = 0;
while ((opt = getopt(argc, argv, "dn:o:")) != -1) {
switch (opt) {
case 'd':
dhd = true;
break;
case 'n':
nloop = atoi(optarg);
break;
case 'o':
o = (enum op)atoi(optarg);
break;
default:
puts("usage : xx");
return 1;
}
}
if (argc > 1) {
o = (enum op)atoi(argv[1]);
}
printf("op = %d\n", (int)o);
cudaMalloc(&r, sizeof(*r));
switch (o) {
default:
case MEM_DD:
/* h2d/d2h
* D D */
cudaHostAlloc(&h_h2d, sizeof(int), cudaHostAllocMapped);
cudaHostAlloc(&h_d2h, sizeof(int), cudaHostAllocMapped);
d_h2d = h_h2d;
d_d2h = h_d2h;
*h_h2d = 9999;
*h_d2h = 9999;
break;
}
printf("h_h2d:%p, h_d2h:%p, d_h2d:%p, d_d2h:%p\n",
h_h2d, h_d2h,
d_h2d, d_d2h);
ping_pong_device<<<1,1>>>(d_h2d, d_d2h, r, nloop, dhd);
ping_pong_host(h_h2d, h_d2h, &h_r, nloop, dhd);
cudaMemcpy(&r_copy, r, sizeof(r_copy), cudaMemcpyDeviceToHost);
printf("TAT : hdh:%f[cycle/iter]\n",
h_r.clk_sum[2]/(double)nloop);
cudaMemcpy(&r_copy, r, sizeof(r_copy), cudaMemcpyDeviceToHost);
int *tt, *ttt;
cudaMalloc(&tt, sizeof(int));
cudaMallocHost(&ttt, sizeof(int));
for (int i=0; i<4; i++) {
long long t0 = __rdtsc();
cudaMemcpy(tt, ttt, sizeof(int), cudaMemcpyHostToDevice);
long long t1 = __rdtsc();
printf("%lld\n", t1-t0);
}
for (int i=0; i<4; i++) {
long long t0 = __rdtsc();
cudaMemcpy(ttt, tt, sizeof(int), cudaMemcpyDeviceToHost);
long long t1 = __rdtsc();
printf("%lld\n", t1-t0);
}
}
I only measured part of the way through, so just jotting down the results I have so far:
Switching to cudaMallocManaged didn't change the performance either (which surprised me).