From 22cbc389922f69c242120b1319b9010ebd0cc479 Mon Sep 17 00:00:00 2001 From: Frederik Wedel-Heinen Date: Thu, 7 Mar 2024 09:56:50 +0100 Subject: [PATCH] Cleanup messages for retransmission when receiving ack --- include/internal/statem.h | 2 ++ ssl/d1_lib.c | 9 ++++++++ ssl/ssl_local.h | 5 ++-- ssl/statem/statem_clnt.c | 5 ---- ssl/statem/statem_dtls.c | 48 +++++++++++++++++++++++++++++++++++++++ ssl/statem/statem_srvr.c | 9 ++++---- 6 files changed, 66 insertions(+), 12 deletions(-) diff --git a/include/internal/statem.h b/include/internal/statem.h index 136e6523660a53..2672f0c0798a57 100644 --- a/include/internal/statem.h +++ b/include/internal/statem.h @@ -104,6 +104,8 @@ struct ossl_statem_st { OSSL_HANDSHAKE_STATE hand_state; /* The handshake state requested by an API call (e.g. HelloRequest) */ OSSL_HANDSHAKE_STATE request_state; + /* The handshake state waiting for acknowledge */ + OSSL_HANDSHAKE_STATE ack_state; int in_init; int read_state_first_init; /* true when we are actually in SSL_accept() or SSL_connect() */ diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c index e781d1bc17e050..af6de6afd821bb 100644 --- a/ssl/d1_lib.c +++ b/ssl/d1_lib.c @@ -139,6 +139,15 @@ void dtls1_clear_received_buffer(SSL_CONNECTION *s) } } +void dtls1_remove_sent_buffer_item(struct pqueue_st *pq, unsigned char *prio64be) { + pitem *item = NULL; + + while ((item = pqueue_find(pq, prio64be)) != NULL) { + dtls1_hm_fragment_free((hm_fragment *)item->data); + pitem_free(item); + } +} + void dtls1_clear_sent_buffer(SSL_CONNECTION *s) { pitem *item = NULL; diff --git a/ssl/ssl_local.h b/ssl/ssl_local.h index c19351761871ca..d8c2dc5fca9ac7 100644 --- a/ssl/ssl_local.h +++ b/ssl/ssl_local.h @@ -1960,10 +1960,9 @@ typedef struct dtls1_state_st { int shutdown_received; # endif - /* Sequence numbers that should be acknowledged */ + /* Sequence numbers that are to be acknowledged */ uint16_t ack_seq_num[DTLS_ACK_SEQ_NUM_LEN]; size_t ack_seq_num_count; - int msg_being_acked; DTLS_timer_cb timer_cb; @@ -2679,6 +2678,8 @@ __owur void dtls1_get_queue_priority(unsigned char *prio64be, int dtls1_retransmit_buffered_messages(SSL_CONNECTION *s); void dtls1_clear_received_buffer(SSL_CONNECTION *s); void dtls1_clear_sent_buffer(SSL_CONNECTION *s); +void dtls1_remove_sent_buffer_item(struct pqueue_st *pq, + unsigned char *prio64be); void dtls1_get_message_header(const unsigned char *data, struct hm_header_st *msg_hdr); __owur OSSL_TIME dtls1_default_timeout(void); diff --git a/ssl/statem/statem_clnt.c b/ssl/statem/statem_clnt.c index 3a639d86fcd557..9aef2ccaacb895 100644 --- a/ssl/statem/statem_clnt.c +++ b/ssl/statem/statem_clnt.c @@ -1410,11 +1410,6 @@ MSG_PROCESS_RETURN dtls_process_hello_verify(SSL_CONNECTION *s, PACKET *pkt) return MSG_PROCESS_FINISHED_READING; } -MSG_PROCESS_RETURN dtls_process_ack(SSL_CONNECTION *s, PACKET *pkt) -{ - - return MSG_PROCESS_FINISHED_READING; -} static int set_client_ciphersuite(SSL_CONNECTION *s, const unsigned char *cipherchars) { diff --git a/ssl/statem/statem_dtls.c b/ssl/statem/statem_dtls.c index 138bd21ab3658a..c81cbba2fb9ddf 100644 --- a/ssl/statem/statem_dtls.c +++ b/ssl/statem/statem_dtls.c @@ -1030,6 +1030,15 @@ CON_FUNC_RETURN dtls_construct_ack(SSL_CONNECTION *s, WPACKET *pkt) { } for (size_t i = 0; i < s->d1->ack_seq_num_count; ++i) { + /* + * rfc9147: section 4. + * + * Record numbers are encoded as + * struct { + * uint64 epoch; + * uint64 sequence_number; + * } RecordNumber; + */ if (!WPACKET_put_bytes_u16(pkt, s->d1->ack_seq_num[i] & 0xffff)) { SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR); @@ -1044,6 +1053,45 @@ CON_FUNC_RETURN dtls_construct_ack(SSL_CONNECTION *s, WPACKET *pkt) { return CON_FUNC_SUCCESS; } +MSG_PROCESS_RETURN dtls_process_ack(SSL_CONNECTION *s, PACKET *pkt) +{ + PACKET record_numbers; + + if (!PACKET_get_length_prefixed_2(pkt, &record_numbers)) { + SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_LENGTH_TOO_LONG); + return MSG_PROCESS_ERROR; + } + while (PACKET_remaining(&record_numbers) > 0) { + unsigned char prio64be[8]; + uint64_t epoch; + uint64_t sequence_number; + + /* + * rfc9147: section 4. + * + * Record numbers are encoded as + * struct { + * uint64 epoch; + * uint64 sequence_number; + * } RecordNumber; + */ + + if (!PACKET_get_net_8(&record_numbers, &epoch) + || !PACKET_get_net_8(&record_numbers, &sequence_number)) { + SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_LENGTH_TOO_LONG); + return MSG_PROCESS_ERROR; + } + + if (dtls1_get_epoch(s, SSL3_CC_WRITE) == epoch) { + dtls1_get_queue_priority(prio64be, sequence_number, 0); + dtls1_remove_sent_buffer_item(s->d1->sent_messages, prio64be); + } + } + + return MSG_PROCESS_FINISHED_READING; +} + + #ifndef OPENSSL_NO_SCTP /* * Wait for a dry event. Should only be called at a point in the handshake diff --git a/ssl/statem/statem_srvr.c b/ssl/statem/statem_srvr.c index fd037c54111e3d..c7d79e64bd9e9a 100644 --- a/ssl/statem/statem_srvr.c +++ b/ssl/statem/statem_srvr.c @@ -569,7 +569,7 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL_CONNECTION *s) s->post_handshake_auth = SSL_PHA_EXT_RECEIVED; if (SSL_CONNECTION_IS_DTLS13(s)) { - s->d1->msg_being_acked = SSL3_MT_FINISHED; + st->ack_state = TLS_ST_SR_FINISHED; st->hand_state = TLS_ST_SW_ACK; } else { /* Check if we are expected to deliver a new session ticket */ @@ -582,7 +582,7 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL_CONNECTION *s) case TLS_ST_SR_KEY_UPDATE: if (SSL_CONNECTION_IS_DTLS13(s)) { - s->d1->msg_being_acked = SSL3_MT_KEY_UPDATE; + st->ack_state = TLS_ST_SR_KEY_UPDATE; st->hand_state = TLS_ST_SW_ACK; return WRITE_TRAN_CONTINUE; } @@ -616,15 +616,14 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL_CONNECTION *s) st->hand_state = TLS_ST_OK; return WRITE_TRAN_CONTINUE; case TLS_ST_SW_ACK: - if (s->d1->msg_being_acked == SSL3_MT_FINISHED) { + if (st->ack_state == TLS_ST_SR_FINISHED) { if (s->ext.ticket_expected && s->num_tickets > s->sent_tickets) st->hand_state = TLS_ST_SW_SESSION_TICKET; else st->hand_state = TLS_ST_OK; - } else if (s->d1->msg_being_acked == SSL3_MT_KEY_UPDATE) + } else if (st->ack_state == TLS_ST_SR_KEY_UPDATE) st->hand_state = TLS_ST_OK; - s->d1->msg_being_acked = -1; // Clear state return WRITE_TRAN_CONTINUE; } }