diff --git a/control.c b/control.c index ba49105..ad13471 100644 --- a/control.c +++ b/control.c @@ -753,46 +753,107 @@ send_msg_frp_server(struct bufferevent *bev, free(req_msg); } -void -send_enc_msg_frp_server(struct bufferevent *bev, - const enum msg_type type, - const char *msg, - const size_t msg_len, - struct tmux_stream *stream) +static int prepare_encrypted_message(const enum msg_type type, + const char *msg, + const size_t msg_len, + uint8_t **enc_msg_out, + size_t *enc_len_out) { - struct bufferevent *bout = NULL; - if (bev) { - bout = bev; - } else { - bout = main_ctl->connect_bev; + // Validate inputs + if (!msg || !enc_msg_out || !enc_len_out) { + debug(LOG_ERR, "Invalid input parameters"); + return -1; + } + + // Prepare message header and content + size_t total_len = msg_len + sizeof(struct msg_hdr); + struct msg_hdr *req_msg = calloc(total_len, 1); + if (!req_msg) { + debug(LOG_ERR, "Failed to allocate memory for message"); + return -1; } - assert(bout); - struct msg_hdr *req_msg = calloc(msg_len+sizeof(struct msg_hdr), 1); - assert(req_msg); req_msg->type = type; req_msg->length = msg_hton((uint64_t)msg_len); memcpy(req_msg->data, msg, msg_len); + // Encrypt message + uint8_t *enc_msg = NULL; + size_t enc_len = encrypt_data((uint8_t *)req_msg, total_len, + get_main_encoder(), &enc_msg); + free(req_msg); + + if (enc_len <= 0 || !enc_msg) { + debug(LOG_ERR, "Encryption failed"); + return -1; + } + + *enc_msg_out = enc_msg; + *enc_len_out = enc_len; + return 0; +} + +static int initialize_encoder(struct bufferevent *bout, struct tmux_stream *stream) +{ + struct frp_coder *coder = init_main_encoder(); + if (!coder) { + debug(LOG_ERR, "Failed to initialize encoder"); + return -1; + } + struct common_conf *c_conf = get_common_config(); - if (get_main_encoder() == NULL) { - struct frp_coder *coder = init_main_encoder(); - if (c_conf->tcp_mux) - tmux_stream_write(bout, coder->iv, 16, stream); - else - bufferevent_write(bout, coder->iv, 16); + if (c_conf->tcp_mux) { + if (tmux_stream_write(bout, coder->iv, 16, stream) < 0) { + debug(LOG_ERR, "Failed to write IV through TCP mux"); + return -1; + } + } else { + if (bufferevent_write(bout, coder->iv, 16) < 0) { + debug(LOG_ERR, "Failed to write IV directly"); + return -1; + } + } + return 0; +} + +void send_enc_msg_frp_server(struct bufferevent *bev, + const enum msg_type type, + const char *msg, + const size_t msg_len, + struct tmux_stream *stream) +{ + // Get output bufferevent + struct bufferevent *bout = bev ? bev : main_ctl->connect_bev; + if (!bout) { + debug(LOG_ERR, "No valid bufferevent"); + return; } + // Initialize encoder if needed + if (!get_main_encoder() && initialize_encoder(bout, stream) != 0) { + return; + } + + // Prepare and encrypt message uint8_t *enc_msg = NULL; - size_t olen = encrypt_data((uint8_t *)req_msg, msg_len+sizeof(struct msg_hdr), get_main_encoder(), &enc_msg); - assert(olen > 0); - if (c_conf->tcp_mux) - tmux_stream_write(bout, enc_msg, olen, stream); - else - bufferevent_write(bout, enc_msg, olen); + size_t enc_len = 0; + if (prepare_encrypted_message(type, msg, msg_len, &enc_msg, &enc_len) != 0) { + return; + } - free(enc_msg); - free(req_msg); + // Send encrypted message + struct common_conf *c_conf = get_common_config(); + if (c_conf->tcp_mux) { + if (tmux_stream_write(bout, enc_msg, enc_len, stream) < 0) { + debug(LOG_ERR, "Failed to write encrypted message through TCP mux"); + } + } else { + if (bufferevent_write(bout, enc_msg, enc_len) < 0) { + debug(LOG_ERR, "Failed to write encrypted message directly"); + } + } + + free(enc_msg); } struct control *