Skip to content

Commit

Permalink
Update do work
Browse files Browse the repository at this point in the history
  • Loading branch information
SeyedMir committed Jan 29, 2025
1 parent 08a5126 commit abe5127
Showing 1 changed file with 83 additions and 75 deletions.
158 changes: 83 additions & 75 deletions examples/ucp_client_server_multi_dev.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "ucp_util.h"

#include <ucp/api/ucp.h>
#include <ucs/sys/uid.h>

#include <string.h> /* memset */
#include <arpa/inet.h> /* inet_addr */
Expand Down Expand Up @@ -101,9 +102,10 @@ static struct {
*/
typedef struct dev_ucp_ctx {
    /* UCP context handle for this device */
    ucp_context_h ucp_context;
    /* Worker created on ucp_context (see init_worker) */
    ucp_worker_h  ucp_worker;
    /* Endpoints owned by ucp_worker, indexed by remote device */
    ucp_ep_h      ucp_eps[MAX_DEV_COUNT];
    /* Number of leading ucp_eps entries that close_eps() closes */
    size_t        ep_count;
    /* Device id; NOTE(review): where it is assigned is not visible here */
    int           dev_id;
} dev_ucp_ctx_t;


Expand Down Expand Up @@ -207,7 +209,7 @@ static void err_cb(void *arg, ucp_ep_h ep, ucs_status_t status)
{
printf("error handling callback was invoked with status %d (%s)\n",
status, ucs_status_string(status));
connection_closed = 1;
connection_closed++;
}

/**
Expand Down Expand Up @@ -248,6 +250,21 @@ void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr)
}
}

/**
 * Close all the endpoints.
 *
 * For every device context, passes each recorded endpoint to ep_close()
 * together with the owning worker and UCP_EP_CLOSE_FLAG_FORCE, then resets
 * the context's endpoint count so a repeated call has nothing left to close.
 *
 * @param dev_ucp_contexts  Array of per-device contexts.
 * @param dev_count         Number of entries in dev_ucp_contexts to visit.
 */
static void close_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count)
{
    for (int ldev = 0; ldev < dev_count; ldev++) {
        dev_ucp_ctx_t *ctx = &dev_ucp_contexts[ldev];

        /* ep_count is a size_t; use the same type for the index so the
         * loop bound is not a signed/unsigned comparison. */
        for (size_t rdev = 0; rdev < ctx->ep_count; rdev++) {
            ep_close(ctx->ucp_worker, ctx->ucp_eps[rdev],
                     UCP_EP_CLOSE_FLAG_FORCE);
        }
        ctx->ep_count = 0;
    }
}

/**
* Create an endpoint from each client GPU-specific
* worker to each remote server GPU-specific worker (to the given IP).
Expand Down Expand Up @@ -784,7 +801,7 @@ static int client_server_communication(ucp_worker_h worker, ucp_ep_h ep,
* Create a ucp worker on the given ucp context.
*/
static int init_worker(ucp_context_h ucp_context, int client_id,
ucp_worker_h *ucp_worker);
ucp_worker_h *ucp_worker)
{
ucp_worker_params_t worker_params;
ucs_status_t status;
Expand Down Expand Up @@ -833,10 +850,9 @@ static void server_conn_handle_cb(ucp_conn_request_h conn_request, void *arg)
/* Accept the request only if we are not processing another client
* already, or it is coming from the same client as the one we're
* already processing. Otherwise, reject it. */
if ((context->client_id == 0) ||
(context->client_id == conn_attr.client_id)) {
if ((context->client_id == 0) || (context->client_id == attr.client_id)) {
context->conn_request = conn_request;
context->client_id = conn_request.client_id;
context->client_id = attr.client_id;
} else {
printf("Rejecting a connection request. "
"Only one client at a time is supported.\n");
Expand All @@ -848,24 +864,12 @@ static void server_conn_handle_cb(ucp_conn_request_h conn_request, void *arg)
}
}

/*
 * Calls ep_close() with UCP_EP_CLOSE_FLAG_FORCE on the first ep_count
 * endpoints of each of the dev_count device contexts. It does not pass the
 * worker and does not reset ep_count.
 *
 * NOTE(review): this listing also contains a `static` close_eps() that hands
 * the owning worker to ep_close() and zeroes ep_count afterwards. This
 * two-argument variant looks like the superseded one - confirm which
 * definition is live before relying on it.
 */
void close_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count)
{
for (int ldev = 0; ldev < dev_count; ldev++) {
for (int rdev = 0; rdev < dev_ucp_contexts[ldev].ep_count; rdev++) {
ep_close(dev_ucp_contexts[ldev].ucp_eps[rdev],
UCP_EP_CLOSE_FLAG_FORCE);
}
}
}

static ucs_status_t server_create_ep(ucp_worker_h ucp_worker,
ucp_conn_request_h conn_request,
ucp_ep_h *server_ep)
{
ucp_ep_params_t ep_params;
ucs_status_t status;
ucp_conn_request_attr_t conn_attr;
int clinet_dev_id;

/* Server creates an ep to the client for each of its GPU-specific workers.
* The client side should have initiated the connection (one for each of
Expand All @@ -885,8 +889,8 @@ static ucs_status_t server_create_ep(ucp_worker_h ucp_worker,
return status;
}

static ucs_status_t server_create_eps(ucp_dev_ctx_t *dev_ucp_contexts,
int dev_count)
static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
int dev_count, ucx_server_ctx_t *context)
{
/* Creating server-side eps. The eps are created upon receiving connection
* requests initiated by the client. The client must initiate one request
Expand All @@ -906,19 +910,20 @@ static ucs_status_t server_create_eps(ucp_dev_ctx_t *dev_ucp_contexts,
* Note that we assume the client and server use the same number of GPUs.
* Otherwise, they need to exchange an initial message to let each other
* know about the number of GPUs they use. */
ucs_status_t status;
for (int cdev = 0; cdev < dev_count; cdev++) { /* server GPUs */
for (int sdev = 0; sdev < dev_count; sdev++) { /* client GPUs */
/* Wait for the server to receive a connection request
* from the client. If there are multiple clients for
* which the server's connection request callback is invoked,
* i.e. several clients are trying to connect in parallel, the
* server will handle only the first one and reject the rest. */
while (context.conn_request == NULL) {
while (context->conn_request == NULL) {
ucp_worker_progress(dev_ucp_contexts[0].ucp_worker);
}

status = server_create_ep(dev_ucp_contexts[sdev].ucp_worker,
context.conn_request,
context->conn_request,
&dev_ucp_contexts[sdev].ucp_eps[cdev]);
if (status != UCS_OK) {
close_eps(dev_ucp_contexts, dev_count);
Expand All @@ -929,7 +934,7 @@ static ucs_status_t server_create_eps(ucp_dev_ctx_t *dev_ucp_contexts,

/* Now we are ready to accept the next request, but only
* for the rest of the GPUs from the same client. */
context.conn_request = NULL;
context->conn_request = NULL;
}
}

Expand Down Expand Up @@ -1000,47 +1005,52 @@ ucs_status_t register_am_recv_callback(ucp_worker_h worker)
return ucp_worker_set_am_recv_handler(worker, &param);
}

static int client_server_do_work(ucp_worker_h ucp_worker, ucp_ep_h ep,
static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
send_recv_type_t send_recv_type, int is_server)
{
int i, ret = 0;
ucs_status_t status;
ucp_worker_h ucp_worker;
ucp_ep_h ucp_ep;

connection_closed = 0;

for (i = 0; i < num_iterations; i++) {
ret = client_server_communication(ucp_worker, ep, send_recv_type,
is_server, i);
if (ret != 0) {
fprintf(stderr, "%s failed on iteration #%d\n",
(is_server ? "server": "client"), i + 1);
goto out;
}
}
for (int cdev = 0; cdev < dev_count; cdev++) {
for (int sdev = 0; sdev < dev_count; sdev++) {
int ldev = is_server ? sdev : cdev;
int rdev = is_server ? cdev : sdev;
ucp_worker = dev_ucp_contexts[ldev].ucp_worker;
ucp_ep = dev_ucp_contexts[ldev].ucp_eps[rdev];
for (i = 0; i < num_iterations; i++) {
ret = client_server_communication(ucp_worker, ucp_ep,
send_recv_type, is_server, i);
if (ret != 0) {
fprintf(stderr, "%s failed on iteration #%d ldev %d rdev %d\n",
(is_server ? "server": "client"), i + 1, ldev, rdev);
goto out;
}
}

/* Register recv callback on the client side to receive FIN message */
if (!is_server && (send_recv_type == CLIENT_SERVER_SEND_RECV_AM)) {
status = register_am_recv_callback(ucp_worker);
if (status != UCS_OK) {
ret = -1;
goto out;
}
}
/* FIN message in reverse direction to acknowledge delivery */
ret = client_server_communication(ucp_worker, ucp_ep,
send_recv_type, !is_server,
i + 1);
if (ret != 0) {
fprintf(stderr, "%s failed on FIN message ldev %d rdev %d\n",
(is_server ? "server": "client"), ldev, rdev);
goto out;
}

/* FIN message in reverse direction to acknowledge delivery */
ret = client_server_communication(ucp_worker, ep, send_recv_type,
!is_server, i + 1);
if (ret != 0) {
fprintf(stderr, "%s failed on FIN message\n",
(is_server ? "server": "client"));
goto out;
printf("%s FIN message local_dev %d remote_dev %d\n",
is_server ? "sent" : "received", ldev, rdev);
}
}

printf("%s FIN message\n", is_server ? "sent" : "received");

/* Server waits until the client closed the connection after receiving FIN */
while (is_server && !connection_closed) {
ucp_worker_progress(ucp_worker);
/* Server waits until the client closed all
* the connections after receiving all FIN */
while (is_server && connection_closed < dev_count * dev_count) {
for (int sdev = 0; sdev < dev_count; sdev++) {
ucp_worker_progress(dev_ucp_contexts[sdev].ucp_worker);
}
}

out:
Expand All @@ -1054,16 +1064,6 @@ static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
ucs_status_t status;
int ret;

if (send_recv_type == CLIENT_SERVER_SEND_RECV_AM) {
for (int dev_id = 0; dev_id < dev_count; dev_id++) {
status = register_am_recv_callback(dev_ucp_contexts[dev_id].ucp_worker);
if (status != UCS_OK) {
ret = -1;
goto err;
}
}
}

/* Initialize the server's context. */
context.conn_request = NULL;
context.client_id = 0;
Expand All @@ -1083,7 +1083,7 @@ static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,

/* Server is always up listening */
while (1) {
ret = server_create_eps(dev_ucp_contexts, dev_count);
ret = server_create_eps(dev_ucp_contexts, dev_count, &context);

if (ret != 0) {
goto err_listener;
Expand Down Expand Up @@ -1115,24 +1115,20 @@ static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
return ret;
}

/**
 * Run the client side: create the endpoints to the server, perform the
 * message exchange, and close the endpoints again.
 *
 * @param dev_ucp_contexts  Array of per-device contexts (worker + endpoints).
 * @param dev_count         Number of entries in dev_ucp_contexts.
 * @param server_addr       Server address handed to client_create_eps().
 * @param send_recv_type    Communication API selector, passed through to
 *                          client_server_do_work().
 *
 * @return 0 on success; otherwise the non-zero value returned by
 *         client_create_eps() or client_server_do_work().
 */
static int run_client(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
                      char *server_addr, send_recv_type_t send_recv_type)
{
    int ret = client_create_eps(dev_ucp_contexts, dev_count, server_addr);
    if (ret != 0) {
        /* client_create_eps() reports an int, so there is no ucs_status_t
         * to turn into a string here. */
        fprintf(stderr, "failed to create client eps\n");
        goto ep_close;
    }

    /* Last argument is is_server: 0 selects the client role */
    ret = client_server_do_work(dev_ucp_contexts, dev_count, send_recv_type, 0);

ep_close:
    /* Close all the endpoints to the server. Also reached when endpoint
     * creation failed, so any endpoints already recorded are closed. */
    close_eps(dev_ucp_contexts, dev_count);
    return ret;
}
Expand Down Expand Up @@ -1179,7 +1175,7 @@ static int init_context(dev_ucp_ctx_t *dev_ucp_ctx,
* multiple GPUs, the server will receive multiple connection requests per
* client. Thus, we need a way to distinguish the requests that belong to
* different clients on the server side. */
client_id = ucs_generate_uuid((uintptr_t)&ucp_context);
client_id = ucs_get_system_id();
ret = init_worker(dev_ucp_ctx->ucp_context, client_id,
&dev_ucp_ctx->ucp_worker);
if (ret != 0) {
Expand Down Expand Up @@ -1211,6 +1207,7 @@ int main(int argc, char **argv)
parse_cmd_status_t parse_cmd_status;
int ret;
int dev_id, dev_count;
ucs_status_t status;

/* GPU-specific UCP context */
dev_ucp_ctx_t dev_ucp_contexts[MAX_DEV_COUNT];
Expand All @@ -1234,6 +1231,16 @@ int main(int argc, char **argv)
if (ret != 0) {
goto err;
}

/* Register recv callback (client needs it to receive FIN message) */
if (send_recv_type == CLIENT_SERVER_SEND_RECV_AM) {
status = register_am_recv_callback(dev_ucp_contexts[dev_id].ucp_worker);
if (status != UCS_OK) {
ret = -1;
goto fin_context;
}
}

}

printf("created %d ucp_contexts and ucp_workers\n", dev_id);
Expand All @@ -1249,6 +1256,7 @@ int main(int argc, char **argv)
server_addr, send_recv_type);
}

fin_context:
for (int i = 0; i < dev_id; i++) {
finalize_context(&dev_ucp_contexts[i]);
}
Expand Down

0 comments on commit abe5127

Please sign in to comment.