diff --git a/ssl/statem/statem_dtls.c b/ssl/statem/statem_dtls.c index 5fef6648bc08a..61051c34b3cf8 100644 --- a/ssl/statem/statem_dtls.c +++ b/ssl/statem/statem_dtls.c @@ -96,6 +96,29 @@ void dtls1_hm_fragment_free(hm_fragment *frag) OPENSSL_free(frag); } +static int dtls1_write_hm_header(unsigned char *msgheaderstart, + unsigned char msg_type, size_t msg_len, + unsigned short msg_seq, size_t fragoff, + size_t fraglen) +{ + WPACKET msgheader; + size_t msgheaderlen; + + if (!WPACKET_init_static_len(&msgheader, msgheaderstart, + DTLS1_HM_HEADER_LENGTH, 0) + || !WPACKET_put_bytes_u8(&msgheader, msg_type) + || !WPACKET_put_bytes_u24(&msgheader, msg_len) + || !WPACKET_put_bytes_u16(&msgheader, msg_seq) + || !WPACKET_put_bytes_u24(&msgheader, fragoff) + || !WPACKET_put_bytes_u24(&msgheader, fraglen) + || !WPACKET_get_total_written(&msgheader, &msgheaderlen) + || msgheaderlen != DTLS1_HM_HEADER_LENGTH + || !WPACKET_finish(&msgheader)) + return 0; + + return 1; +} + /* * send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or * SSL3_RT_CHANGE_CIPHER_SPEC) @@ -117,11 +140,11 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) size_t written; size_t curr_mtu; int retry = 1; - size_t len, overhead, used_len, msg_len; + size_t len, overhead, used_len, msg_len = 0; SSL *ssl = SSL_CONNECTION_GET_SSL(s); unsigned char *data = (unsigned char *)s->init_buf->data; unsigned short msg_seq = s->d1->w_msg_hdr.seq; - unsigned char msg_type; + unsigned char msg_type = 0; if (type == SSL3_RT_HANDSHAKE) { msg_type = *data++; @@ -152,6 +175,8 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) /* s->init_num shouldn't ever be < 0...but just in case */ while (s->init_num > 0) { + unsigned char *msgstart; + if (type == SSL3_RT_HANDSHAKE && s->init_off > 0) { /* * We must be writing a fragment other than the first one @@ -207,29 +232,25 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) if (len > ssl_get_max_send_fragment(s)) len = ssl_get_max_send_fragment(s); - /* - * XDTLS: this function is too long. split out the CCS part - */ + msgstart = (unsigned char *)&s->init_buf->data[s->init_off]; + if (type == SSL3_RT_HANDSHAKE) { - unsigned char *p = (unsigned char *)&s->init_buf->data[s->init_off]; + const size_t fragoff = s->init_off; + const size_t fraglen = len - DTLS1_HM_HEADER_LENGTH; - if (len < DTLS1_HM_HEADER_LENGTH) { + if (len < DTLS1_HM_HEADER_LENGTH + || !dtls1_write_hm_header(msgstart, msg_type, msg_len, + msg_seq, fragoff, fraglen)) /* * len is so small that we really can't do anything sensible * so fail */ return -1; - } - - *p++ = msg_type; - l2n3(msg_len, p); - s2n(msg_seq, p); - l2n3(s->init_off, p); - l2n3(len - DTLS1_HM_HEADER_LENGTH, p); } - unsigned char *msgstart = (unsigned char *)&s->init_buf->data[s->init_off]; - if (dtls1_write_bytes(s, type, msgstart, len, &written) <= 0) { + ret = dtls1_write_bytes(s, type, msgstart, len, &written); + + if (ret <= 0) { /* * might need to update MTU here, but we don't know which * previous packet caused the failure -- so can't really @@ -275,12 +296,10 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) * reconstruct message header is if it is being sent in * single fragment */ - *msgstart++ = msg_type; - l2n3(msg_len, msgstart); - s2n(msg_seq, msgstart); - l2n3(0, msgstart); - l2n3(msg_len, msgstart); - msgstart -= DTLS1_HM_HEADER_LENGTH; + if (!dtls1_write_hm_header(msgstart, msg_type, msg_len, + msg_seq, s->init_off, msg_len)) + return -1; + xlen = written; } else { msgstart += DTLS1_HM_HEADER_LENGTH;