Skip to content

Commit

Permalink
Track device ids in request
Browse files Browse the repository at this point in the history
  • Loading branch information
SeyedMir committed Jan 30, 2025
1 parent abe5127 commit c18ea67
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions examples/ucp_client_server_multi_dev.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ typedef struct ucx_server_ctx {
*/
typedef struct test_req {
int complete;
int cdev;
int sdev;
} test_req_t;


Expand Down Expand Up @@ -319,6 +321,8 @@ static int client_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
return -1;
}

printf("Client created ep for sdev %d cdev %d\n", sdev, cdev);

dev_ucp_contexts[cdev].ep_count++;
}
}
Expand All @@ -343,14 +347,17 @@ static void print_iov(const ucp_dt_iov_t *iov)
* side.
*/
static
void print_result(int is_server, const ucp_dt_iov_t *iov, int current_iter)
void print_result(int is_server, const ucp_dt_iov_t *iov, int current_iter,
int sdev, int cdev)
{
if (is_server) {
printf("Server: iteration #%d\n", (current_iter + 1));
printf("Server: iteration #%d, sdev %d, cdev %d\n",
(current_iter + 1), sdev, cdev);
printf("UCX data message was received\n");
printf("\n\n----- UCP TEST SUCCESS -------\n\n");
} else {
printf("Client: iteration #%d\n", (current_iter + 1));
printf("Client: iteration #%d, sdev %d, cdev %d\n",
(current_iter + 1), sdev, cdev);
printf("\n\n------------------------------\n\n");
}

Expand Down Expand Up @@ -404,7 +411,7 @@ static int request_finalize(ucp_worker_h ucp_worker, test_req_t *request,
/* Print the output of the first, last and every PRINT_INTERVAL iteration */
if ((current_iter == 0) || (current_iter == (num_iterations - 1)) ||
!((current_iter + 1) % (PRINT_INTERVAL))) {
print_result(is_server, iov, current_iter);
print_result(is_server, iov, current_iter, ctx->sdev, ctx->cdev);
}

release_iov:
Expand All @@ -415,6 +422,7 @@ static int request_finalize(ucp_worker_h ucp_worker, test_req_t *request,
static int
fill_request_param(ucp_dt_iov_t *iov, int is_client,
void **msg, size_t *msg_length,
int sdev, int cdev,
test_req_t *ctx, ucp_request_param_t *param)
{
CHKERR_ACTION(buffer_malloc(iov) != 0, "allocate memory", return -1;);
Expand All @@ -428,6 +436,8 @@ fill_request_param(ucp_dt_iov_t *iov, int is_client,
*msg_length = (iov_cnt == 1) ? iov[0].length : iov_cnt;

ctx->complete = 0;
ctx->sdev = sdev;
ctx->cdev = cdev;
param->op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_DATATYPE |
UCP_OP_ATTR_FIELD_USER_DATA;
Expand All @@ -444,7 +454,7 @@ fill_request_param(ucp_dt_iov_t *iov, int is_client,
* The server receives a message from the client and waits for its completion.
*/
static int send_recv_stream(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
int current_iter)
int current_iter, int sdev, int cdev)
{
ucp_dt_iov_t *iov = alloca(iov_cnt * sizeof(ucp_dt_iov_t));
ucp_request_param_t param;
Expand All @@ -456,6 +466,7 @@ static int send_recv_stream(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
memset(iov, 0, iov_cnt * sizeof(*iov));

if (fill_request_param(iov, !is_server, &msg, &msg_length,
sdev, cdev,
&ctx, &param) != 0) {
return -1;
}
Expand Down Expand Up @@ -483,7 +494,7 @@ static int send_recv_stream(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
* The server receives a message from the client and waits for its completion.
*/
static int send_recv_tag(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
int current_iter)
int current_iter, int sdev, int cdev)
{
ucp_dt_iov_t *iov = alloca(iov_cnt * sizeof(ucp_dt_iov_t));
ucp_request_param_t param;
Expand All @@ -495,6 +506,7 @@ static int send_recv_tag(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
memset(iov, 0, iov_cnt * sizeof(*iov));

if (fill_request_param(iov, !is_server, &msg, &msg_length,
sdev, cdev,
&ctx, &param) != 0) {
return -1;
}
Expand Down Expand Up @@ -567,7 +579,7 @@ ucs_status_t ucp_am_data_cb(void *arg, const void *header, size_t header_length,
* initiates receive operation.
*/
static int send_recv_am(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
int current_iter)
int current_iter, int sdev, int cdev)
{
static int last = 0;
ucp_dt_iov_t *iov = alloca(iov_cnt * sizeof(ucp_dt_iov_t));
Expand All @@ -580,6 +592,7 @@ static int send_recv_am(ucp_worker_h ucp_worker, ucp_ep_h ep, int is_server,
memset(iov, 0, iov_cnt * sizeof(*iov));

if (fill_request_param(iov, !is_server, &msg, &msg_length,
sdev, cdev,
&ctx, &params) != 0) {
return -1;
}
Expand Down Expand Up @@ -772,22 +785,23 @@ static char* sockaddr_get_port_str(const struct sockaddr_storage *sock_addr,

static int client_server_communication(ucp_worker_h worker, ucp_ep_h ep,
send_recv_type_t send_recv_type,
int is_server, int current_iter)
int is_server, int current_iter,
int sdev, int cdev)
{
int ret;

switch (send_recv_type) {
case CLIENT_SERVER_SEND_RECV_STREAM:
/* Client-Server communication via Stream API */
ret = send_recv_stream(worker, ep, is_server, current_iter);
ret = send_recv_stream(worker, ep, is_server, current_iter, sdev, cdev);
break;
case CLIENT_SERVER_SEND_RECV_TAG:
/* Client-Server communication via Tag-Matching API */
ret = send_recv_tag(worker, ep, is_server, current_iter);
ret = send_recv_tag(worker, ep, is_server, current_iter, sdev, cdev);
break;
case CLIENT_SERVER_SEND_RECV_AM:
/* Client-Server communication via AM API. */
ret = send_recv_am(worker, ep, is_server, current_iter);
ret = send_recv_am(worker, ep, is_server, current_iter, sdev, cdev);
break;
default:
fprintf(stderr, "unknown send-recv type %d\n", send_recv_type);
Expand Down Expand Up @@ -930,6 +944,8 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
return -1;
}

printf("Server created ep for sdev %d cdev %d\n", sdev, cdev);

dev_ucp_contexts[sdev].ep_count++;

/* Now we are ready to accept the next request, but only
Expand Down Expand Up @@ -1022,7 +1038,8 @@ static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
ucp_ep = dev_ucp_contexts[ldev].ucp_eps[rdev];
for (i = 0; i < num_iterations; i++) {
ret = client_server_communication(ucp_worker, ucp_ep,
send_recv_type, is_server, i);
send_recv_type, is_server, i,
sdev, cdev);
if (ret != 0) {
fprintf(stderr, "%s failed on iteration #%d ldev %d rdev %d\n",
(is_server ? "server": "client"), i + 1, ldev, rdev);
Expand All @@ -1032,8 +1049,8 @@ static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,

/* FIN message in reverse direction to acknowledge delivery */
ret = client_server_communication(ucp_worker, ucp_ep,
send_recv_type, !is_server,
i + 1);
send_recv_type, !is_server, i + 1,
sdev, cdev);
if (ret != 0) {
fprintf(stderr, "%s failed on FIN message ldev %d rdev %d\n",
(is_server ? "server": "client"), ldev, rdev);
Expand Down

0 comments on commit c18ea67

Please sign in to comment.