diff --git a/include/internal/common.h b/include/internal/common.h index 2d26194c153e4a..e6e9627423f7dd 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -180,10 +180,6 @@ __owur static ossl_inline int ossl_assert_int(int expr, const char *exprstr, (((unsigned long)((c)[1]))<< 8)| \ (((unsigned long)((c)[2])) )),(c)+=3) -# define l2n3(l,c) (((c)[0]=(unsigned char)(((l)>>16)&0xff), \ - (c)[1]=(unsigned char)(((l)>> 8)&0xff), \ - (c)[2]=(unsigned char)(((l) )&0xff)),(c)+=3) - #define l3n2(c,l) (l =((uint64_t)(*((c)++)))<<16, \ l|=((uint64_t)(*((c)++)))<< 8, \ l|=((uint64_t)(*((c)++)))) diff --git a/ssl/record/rec_layer_d1.c b/ssl/record/rec_layer_d1.c index 480d0d9cdc2ba0..e245792269a328 100644 --- a/ssl/record/rec_layer_d1.c +++ b/ssl/record/rec_layer_d1.c @@ -490,7 +490,7 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type, * Unexpected handshake message (Client Hello, or protocol violation) */ if (rr->type == SSL3_RT_HANDSHAKE && !ossl_statem_get_in_handshake(sc)) { - struct hm_header_st msg_hdr; + unsigned char msg_type; /* * This may just be a stale retransmit. Also sanity check that we have @@ -503,13 +503,13 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type, goto start; } - dtls1_get_message_header(rr->data, &msg_hdr); + msg_type = *rr->data; /* * If we are server, we may have a repeated FINISHED of the client * here, then retransmit our CCS and FINISHED. */ - if (msg_hdr.type == SSL3_MT_FINISHED) { + if (msg_type == SSL3_MT_FINISHED) { if (dtls1_check_timeout_num(sc) < 0) { /* SSLfatal) already called */ return -1; diff --git a/ssl/ssl_local.h b/ssl/ssl_local.h index 31d92f473e09f3..55b849c3a4b9a4 100644 --- a/ssl/ssl_local.h +++ b/ssl/ssl_local.h @@ -1943,10 +1943,10 @@ struct dtls1_retransmit_state { struct hm_header_st { unsigned char type; - size_t msg_len; + unsigned long msg_len; unsigned short seq; - size_t frag_off; - size_t frag_len; + unsigned long frag_off; + unsigned long frag_len; unsigned int is_ccs; struct dtls1_retransmit_state saved_retransmit_state; }; @@ -1995,7 +1995,7 @@ typedef struct dtls1_state_st { size_t link_mtu; /* max on-the-wire DTLS packet size */ size_t mtu; /* max DTLS packet size */ struct hm_header_st w_msg_hdr; - struct hm_header_st r_msg_hdr; + unsigned short r_msg_seq; /* Number of alerts received so far */ unsigned int timeout_num_alerts; /* @@ -2706,11 +2706,7 @@ __owur int ssl_get_min_max_version(const SSL_CONNECTION *s, int *min_version, int *max_version, int *real_max); __owur OSSL_TIME tls1_default_timeout(void); -__owur int dtls1_do_write(SSL_CONNECTION *s, uint8_t type); -void dtls1_set_message_header(SSL_CONNECTION *s, - unsigned char mt, - size_t len, - size_t frag_off, size_t frag_len); +__owur int dtls1_do_write(SSL_CONNECTION *s, uint8_t recordtype); int dtls1_write_app_data_bytes(SSL *s, uint8_t type, const void *buf_, size_t len, size_t *written); @@ -2723,8 +2719,6 @@ __owur int dtls1_get_queue_priority(unsigned short seq, int is_ccs); int dtls1_retransmit_buffered_messages(SSL_CONNECTION *s); void dtls1_clear_received_buffer(SSL_CONNECTION *s); void dtls1_clear_sent_buffer(SSL_CONNECTION *s); -void dtls1_get_message_header(const unsigned char *data, - struct hm_header_st *msg_hdr); __owur OSSL_TIME dtls1_default_timeout(void); __owur int dtls1_get_timeout(const SSL_CONNECTION *s, OSSL_TIME *timeleft); __owur int dtls1_check_timeout_num(SSL_CONNECTION *s); diff --git a/ssl/statem/statem_dtls.c b/ssl/statem/statem_dtls.c index 61051c34b3cf83..f2e267d8cf405a 100644 --- a/ssl/statem/statem_dtls.c +++ b/ssl/statem/statem_dtls.c @@ -44,11 +44,6 @@ static const unsigned char bitmask_start_values[] = static const unsigned char bitmask_end_values[] = { 0xff, 0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f }; -static void dtls1_set_message_header_int(SSL_CONNECTION *s, unsigned char mt, - size_t len, - unsigned short seq_num, - size_t frag_off, - size_t frag_len); static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, size_t *len); @@ -134,7 +129,7 @@ static int dtls1_write_hm_header(unsigned char *msgheaderstart, * |-- header3 --||-- fragment3 --| * ......... */ -int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) +int dtls1_do_write(SSL_CONNECTION *s, uint8_t recordtype) { int ret; size_t written; @@ -146,10 +141,10 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) unsigned short msg_seq = s->d1->w_msg_hdr.seq; unsigned char msg_type = 0; - if (type == SSL3_RT_HANDSHAKE) { + if (recordtype == SSL3_RT_HANDSHAKE) { msg_type = *data++; l3n2(data, msg_len); - } else if (ossl_assert(type == SSL3_RT_CHANGE_CIPHER_SPEC)) { + } else if (ossl_assert(recordtype == SSL3_RT_CHANGE_CIPHER_SPEC)) { msg_type = SSL3_MT_CCS; msg_len = 0; /* SSL3_RT_CHANGE_CIPHER_SPEC */ } else { @@ -164,7 +159,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) /* should have something reasonable now */ return -1; - if (s->init_off == 0 && type == SSL3_RT_HANDSHAKE) { + if (s->init_off == 0 && recordtype == SSL3_RT_HANDSHAKE) { if (!ossl_assert(s->init_num == msg_len + DTLS1_HM_HEADER_LENGTH)) return -1; } @@ -177,7 +172,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) while (s->init_num > 0) { unsigned char *msgstart; - if (type == SSL3_RT_HANDSHAKE && s->init_off > 0) { + if (recordtype == SSL3_RT_HANDSHAKE && s->init_off > 0) { /* * We must be writing a fragment other than the first one * and this is the first attempt at writing out this fragment @@ -234,7 +229,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) msgstart = (unsigned char *)&s->init_buf->data[s->init_off]; - if (type == SSL3_RT_HANDSHAKE) { + if (recordtype == SSL3_RT_HANDSHAKE) { const size_t fragoff = s->init_off; const size_t fraglen = len - DTLS1_HM_HEADER_LENGTH; @@ -248,7 +243,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) return -1; } - ret = dtls1_write_bytes(s, type, msgstart, len, &written); + ret = dtls1_write_bytes(s, recordtype, msgstart, len, &written); if (ret <= 0) { /* @@ -284,7 +279,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) assert(s->s3.tmp.new_compression != NULL || BIO_wpending(s->wbio) <= (int)s->d1->mtu); - if (type == SSL3_RT_HANDSHAKE && !s->d1->retransmitting) { + if (recordtype == SSL3_RT_HANDSHAKE && !s->d1->retransmitting) { /* * should not be done for 'Hello Request's, but in that case * we'll ignore the result anyway @@ -322,7 +317,7 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) if (written == s->init_num) { if (s->msg_callback) - s->msg_callback(1, s->version, type, s->init_buf->data, + s->msg_callback(1, s->version, recordtype, s->init_buf->data, s->init_off + s->init_num, ssl, s->msg_callback_arg); @@ -341,14 +336,11 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t type) int dtls_get_message(SSL_CONNECTION *s, int *mt) { - struct hm_header_st *msg_hdr; - unsigned char *p; - size_t msg_len; + unsigned char *rec_data; size_t tmplen; int errtype; - msg_hdr = &s->d1->r_msg_hdr; - memset(msg_hdr, 0, sizeof(*msg_hdr)); + s->d1->r_msg_seq = 0; again: if (!dtls_get_reassembled_message(s, &errtype, &tmplen)) { @@ -362,12 +354,12 @@ int dtls_get_message(SSL_CONNECTION *s, int *mt) *mt = s->s3.tmp.message_type; - p = (unsigned char *)s->init_buf->data; + rec_data = (unsigned char *)s->init_buf->data; if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) { if (s->msg_callback) { s->msg_callback(0, s->version, SSL3_RT_CHANGE_CIPHER_SPEC, - p, 1, SSL_CONNECTION_GET_SSL(s), + rec_data, 1, SSL_CONNECTION_GET_SSL(s), s->msg_callback_arg); } /* @@ -376,16 +368,11 @@ int dtls_get_message(SSL_CONNECTION *s, int *mt) return 1; } - msg_len = msg_hdr->msg_len; - /* reconstruct message header */ - *(p++) = msg_hdr->type; - l2n3(msg_len, p); - s2n(msg_hdr->seq, p); - l2n3(0, p); - l2n3(msg_len, p); + dtls1_write_hm_header(rec_data, s->s3.tmp.message_type, s->s3.tmp.message_size, + s->d1->r_msg_seq, 0, s->s3.tmp.message_size); - memset(msg_hdr, 0, sizeof(*msg_hdr)); + s->d1->r_msg_seq = 0; s->d1->handshake_read_seq++; @@ -449,7 +436,7 @@ static size_t dtls1_max_handshake_message_len(const SSL_CONNECTION *s) } static int dtls1_preprocess_fragment(SSL_CONNECTION *s, - struct hm_header_st *msg_hdr) + const struct hm_header_st * const msg_hdr) { size_t frag_off, frag_len, msg_len; @@ -464,30 +451,19 @@ static int dtls1_preprocess_fragment(SSL_CONNECTION *s, return 0; } - if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */ - /* - * msg_len is limited to 2^24, but is effectively checked against - * dtls_max_handshake_message_len(s) above - */ - if (!BUF_MEM_grow_clean(s->init_buf, msg_len + DTLS1_HM_HEADER_LENGTH)) { - SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB); - return 0; - } - - s->s3.tmp.message_size = msg_len; - s->d1->r_msg_hdr.msg_len = msg_len; - s->s3.tmp.message_type = msg_hdr->type; - s->d1->r_msg_hdr.type = msg_hdr->type; - s->d1->r_msg_hdr.seq = msg_hdr->seq; - } else if (msg_len != s->d1->r_msg_hdr.msg_len) { - /* - * They must be playing with us! BTW, failure to enforce upper limit - * would open possibility for buffer overrun. - */ - SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_EXCESSIVE_MESSAGE_SIZE); + /* + * msg_len is limited to 2^24, but is effectively checked against + * dtls_max_handshake_message_len(s) above + */ + if (!BUF_MEM_grow_clean(s->init_buf, msg_len + DTLS1_HM_HEADER_LENGTH)) { + SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB); return 0; } + s->s3.tmp.message_size = msg_len; + s->s3.tmp.message_type = msg_hdr->type; + s->d1->r_msg_seq = msg_hdr->seq; + return 1; } @@ -818,10 +794,26 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s, return 0; } +static int dtls1_read_hm_header(unsigned char *msgheaderstart, struct hm_header_st *msg_hdr) +{ + PACKET msgheader; + + if (!PACKET_buf_init(&msgheader, msgheaderstart, DTLS1_HM_HEADER_LENGTH) + || !PACKET_get_1(&msgheader, (unsigned int *)&msg_hdr->type) + || !PACKET_get_net_3(&msgheader, &msg_hdr->msg_len) + || !PACKET_get_net_2(&msgheader, (unsigned int *)&msg_hdr->seq) + || !PACKET_get_net_3(&msgheader, &msg_hdr->frag_off) + || !PACKET_get_net_3(&msgheader, &msg_hdr->frag_len) + || PACKET_remaining(&msgheader) != 0) { + return 0; + } + + return 1; +} + static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, size_t *len) { - size_t mlen, frag_off, frag_len; int i, ret; uint8_t recvd_type; struct hm_header_st msg_hdr; @@ -836,14 +828,14 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, redo: /* see if we have the required fragment already */ - ret = dtls1_retrieve_buffered_fragment(s, &frag_len); + ret = dtls1_retrieve_buffered_fragment(s, &msg_hdr.frag_len); if (ret < 0) { /* SSLfatal() already called */ return 0; } if (ret > 0) { - s->init_num = frag_len; - *len = frag_len; + s->init_num = msg_hdr.frag_len; + *len = msg_hdr.frag_len; return 1; } @@ -877,17 +869,16 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, } /* parse the message fragment header */ - dtls1_get_message_header(p, &msg_hdr); - - mlen = msg_hdr.msg_len; - frag_off = msg_hdr.frag_off; - frag_len = msg_hdr.frag_len; + if (!dtls1_read_hm_header(p, &msg_hdr)) { + SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH); + goto f_err; + } /* * We must have at least frag_len bytes left in the record to be read. * Fragments must not span records. */ - if (frag_len > s->rlayer.tlsrecs[s->rlayer.curr_rec].length) { + if (msg_hdr.frag_len > s->rlayer.tlsrecs[s->rlayer.curr_rec].length) { SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH); goto f_err; } @@ -902,7 +893,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, if (!s->server || msg_hdr.seq != 0 || s->d1->handshake_read_seq != 1 - || p[0] != SSL3_MT_CLIENT_HELLO + || msg_hdr.type != SSL3_MT_CLIENT_HELLO || s->statem.hand_state != DTLS_ST_SW_HELLO_VERIFY_REQUEST) { *errtype = dtls1_process_out_of_seq_message(s, &msg_hdr); return 0; @@ -915,21 +906,20 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, chretran = 1; } - if (frag_len && frag_len < mlen) { + if (msg_hdr.frag_len && msg_hdr.frag_len < msg_hdr.msg_len) { *errtype = dtls1_reassemble_fragment(s, &msg_hdr); return 0; } if (!s->server - && s->d1->r_msg_hdr.frag_off == 0 && s->statem.hand_state != TLS_ST_OK - && p[0] == SSL3_MT_HELLO_REQUEST) { + && msg_hdr.type == SSL3_MT_HELLO_REQUEST) { /* * The server may always send 'Hello Request' messages -- we are * doing a handshake anyway now, so ignore them if their format is * correct. Does not count for 'Finished' MAC. */ - if (p[1] == 0 && p[2] == 0 && p[3] == 0) { + if (msg_hdr.msg_len == 0) { if (s->msg_callback) s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE, p, DTLS1_HM_HEADER_LENGTH, ssl, @@ -937,8 +927,8 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, s->init_num = 0; goto redo; - } else { /* Incorrectly formatted Hello request */ - + } else { + /* Incorrectly formatted Hello request */ SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE); goto f_err; } @@ -949,11 +939,11 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, goto f_err; } - if (frag_len > 0) { - p += DTLS1_HM_HEADER_LENGTH; + if (msg_hdr.frag_len > 0) { + p += DTLS1_HM_HEADER_LENGTH + msg_hdr.frag_off; i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, NULL, - &p[frag_off], frag_len, 0, &readbytes); + p, msg_hdr.frag_len, 0, &readbytes); /* * This shouldn't ever fail due to NBIO because we already checked @@ -972,7 +962,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, * XDTLS: an incorrectly formatted fragment should cause the handshake * to fail */ - if (readbytes != frag_len) { + if (readbytes != msg_hdr.frag_len) { SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_BAD_LENGTH); goto f_err; } @@ -994,7 +984,7 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype, * soon as they sum up to handshake packet length, we assume we have got * all the fragments. */ - *len = s->init_num = frag_len; + *len = s->init_num = msg_hdr.frag_len; return 1; f_err: @@ -1133,6 +1123,7 @@ int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs) pitem *item; hm_fragment *frag; unsigned char seq64be[8]; + size_t headerlen; /* * this function is called immediately after a message has been @@ -1147,21 +1138,15 @@ int dtls1_buffer_message(SSL_CONNECTION *s, int is_ccs) memcpy(frag->fragment, s->init_buf->data, s->init_num); - if (is_ccs) { + if (is_ccs) /* For DTLS1_BAD_VER the header length is non-standard */ - if (!ossl_assert(s->d1->w_msg_hdr.msg_len + - ((s->version == - DTLS1_BAD_VER) ? 3 : DTLS1_CCS_HEADER_LENGTH) - == (unsigned int)s->init_num)) { - dtls1_hm_fragment_free(frag); - return 0; - } - } else { - if (!ossl_assert(s->d1->w_msg_hdr.msg_len + - DTLS1_HM_HEADER_LENGTH == (unsigned int)s->init_num)) { - dtls1_hm_fragment_free(frag); - return 0; - } + headerlen = (s->version == DTLS1_BAD_VER) ? 3 : DTLS1_CCS_HEADER_LENGTH; + else + headerlen = DTLS1_HM_HEADER_LENGTH; + + if (!ossl_assert(s->d1->w_msg_hdr.msg_len + headerlen == s->init_num)) { + dtls1_hm_fragment_free(frag); + return 0; } frag->msg_header.msg_len = s->d1->w_msg_hdr.msg_len; @@ -1230,10 +1215,9 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) frag->msg_header.msg_len + header_length); s->init_num = frag->msg_header.msg_len + header_length; - dtls1_set_message_header_int(s, frag->msg_header.type, - frag->msg_header.msg_len, - frag->msg_header.seq, 0, - frag->msg_header.frag_len); + s->d1->w_msg_hdr.type = frag->msg_header.type; + s->d1->w_msg_hdr.msg_len = frag->msg_header.msg_len; + s->d1->w_msg_hdr.seq = frag->msg_header.seq; /* save current state */ saved_state.wrlmethod = s->rlayer.wrlmethod; @@ -1264,58 +1248,26 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found) return ret; } -void dtls1_set_message_header(SSL_CONNECTION *s, - unsigned char mt, size_t len, - size_t frag_off, size_t frag_len) -{ - if (frag_off == 0) { - s->d1->handshake_write_seq = s->d1->next_handshake_write_seq; - s->d1->next_handshake_write_seq++; - } - - dtls1_set_message_header_int(s, mt, len, s->d1->handshake_write_seq, - frag_off, frag_len); -} - -/* don't actually do the writing, wait till the MTU has been retrieved */ -static void -dtls1_set_message_header_int(SSL_CONNECTION *s, unsigned char mt, - size_t len, unsigned short seq_num, - size_t frag_off, size_t frag_len) -{ - struct hm_header_st *msg_hdr = &s->d1->w_msg_hdr; - - msg_hdr->type = mt; - msg_hdr->msg_len = len; - msg_hdr->seq = seq_num; - msg_hdr->frag_off = frag_off; - msg_hdr->frag_len = frag_len; -} - -void dtls1_get_message_header(const unsigned char *data, struct - hm_header_st *msg_hdr) -{ - memset(msg_hdr, 0, sizeof(*msg_hdr)); - msg_hdr->type = *(data++); - n2l3(data, msg_hdr->msg_len); - - n2s(data, msg_hdr->seq); - n2l3(data, msg_hdr->frag_off); - n2l3(data, msg_hdr->frag_len); -} - int dtls1_set_handshake_header(SSL_CONNECTION *s, WPACKET *pkt, int htype) { if (htype == SSL3_MT_CHANGE_CIPHER_SPEC) { s->d1->handshake_write_seq = s->d1->next_handshake_write_seq; - dtls1_set_message_header_int(s, SSL3_MT_CCS, 0, - s->d1->handshake_write_seq, 0, 0); + + s->d1->w_msg_hdr.type = SSL3_MT_CCS; + s->d1->w_msg_hdr.msg_len = 0; + s->d1->w_msg_hdr.seq = s->d1->handshake_write_seq; + if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS)) return 0; } else { size_t subpacket_offset = DTLS1_HM_HEADER_LENGTH - SSL3_HM_HEADER_LENGTH; - dtls1_set_message_header(s, htype, 0, 0, 0); + s->d1->handshake_write_seq = s->d1->next_handshake_write_seq; + s->d1->next_handshake_write_seq++; + + s->d1->w_msg_hdr.type = htype; + s->d1->w_msg_hdr.msg_len = 0; + s->d1->w_msg_hdr.seq = s->d1->handshake_write_seq; /* Set the content type and 3 bytes for the message len */ if (!WPACKET_put_bytes_u8(pkt, htype) @@ -1341,7 +1293,6 @@ int dtls1_close_construct_packet(SSL_CONNECTION *s, WPACKET *pkt, int htype) if (htype != SSL3_MT_CHANGE_CIPHER_SPEC) { s->d1->w_msg_hdr.msg_len = msglen - DTLS1_HM_HEADER_LENGTH; - s->d1->w_msg_hdr.frag_len = msglen - DTLS1_HM_HEADER_LENGTH; } s->init_num = msglen; s->init_off = 0;