From 38c7f2d32f36d4b43d53b9eb71dd5de896a74f9a Mon Sep 17 00:00:00 2001 From: Andrzej Kurek Date: Tue, 15 Dec 2020 05:46:54 -0500 Subject: [PATCH] Refactor the immediate transmission feature The original way or handling it did not cover message fragmentation or retransmission. Now, the messages are always appended to the flight and sent immediately, using the same function as normal flight transmission. Moreover, epoch handling is different for this feature, with a possibility to perform the usual retransmission using previous methods. Signed-off-by: Andrzej Kurek --- include/mbedtls/ssl_internal.h | 3 + library/ssl_cli.c | 14 +- library/ssl_srv.c | 28 ++- library/ssl_tls.c | 307 ++++++++++++++++++--------------- 4 files changed, 205 insertions(+), 147 deletions(-) diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 97c00e256..80da3ac1a 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -1203,6 +1203,9 @@ void mbedtls_ssl_send_flight_completed( mbedtls_ssl_context *ssl ); void mbedtls_ssl_recv_flight_completed( mbedtls_ssl_context *ssl ); int mbedtls_ssl_resend( mbedtls_ssl_context *ssl ); int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) +void mbedtls_ssl_immediate_flight_done( mbedtls_ssl_context *ssl ); +#endif #endif /* Visible for testing purposes only */ diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 08d5a7117..7f69b6242 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1141,11 +1141,17 @@ static int ssl_write_client_hello( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } +#endif } #endif /* MBEDTLS_SSL_PROTO_DTLS */ diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 389a24e48..ce92f98dc 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2743,11 +2743,17 @@ static int ssl_write_hello_verify_request( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } +#endif } #endif /* MBEDTLS_SSL_PROTO_DTLS */ @@ -3802,11 +3808,17 @@ static int ssl_write_server_hello_done( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } +#endif } #endif /* MBEDTLS_SSL_PROTO_DTLS */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 038e581d5..f20faf92a 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -4360,6 +4360,131 @@ int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl ) * Functions to handle the DTLS retransmission state machine */ #if defined(MBEDTLS_SSL_PROTO_DTLS) +static int ssl_swap_epochs( mbedtls_ssl_context *ssl ); + +static int mbedtls_ssl_flight_transmit_msg( mbedtls_ssl_context *ssl, mbedtls_ssl_flight_item *msg ) +{ + size_t max_frag_len; + int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED; + int const is_retransmitting = + ( ssl->handshake->retransmit_state == MBEDTLS_SSL_RETRANS_SENDING ); + int const is_finished = + ( msg->type == MBEDTLS_SSL_MSG_HANDSHAKE && + msg->p[0] == MBEDTLS_SSL_HS_FINISHED ); + + uint8_t const force_flush = ssl->disable_datagram_packing == 1 ? + SSL_FORCE_FLUSH : SSL_DONT_FORCE_FLUSH; + + /* Swap epochs before sending Finished: we can't do it after + * sending ChangeCipherSpec, in case write returns WANT_READ. + * Must be done before copying, may change out_msg pointer */ + if( is_retransmitting && is_finished && ssl->handshake->cur_msg_p == ( msg->p + 12 ) ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) ); + if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) + return( ret ); + } + + ret = ssl_get_remaining_payload_in_datagram( ssl ); + if( ret < 0 ) + return( ret ); + max_frag_len = (size_t) ret; + + /* CCS is copied as is, while HS messages may need fragmentation */ + if( msg->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC ) + { + if( max_frag_len == 0 ) + { + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + return( 0 ); + } + + mbedtls_platform_memcpy( ssl->out_msg, msg->p, msg->len ); + ssl->out_msglen = msg->len; + ssl->out_msgtype = msg->type; + + /* Update position inside current message */ + ssl->handshake->cur_msg_p += msg->len; + } + else + { + const unsigned char * const p = ssl->handshake->cur_msg_p; + const size_t hs_len = msg->len - 12; + const size_t frag_off = p - ( msg->p + 12 ); + const size_t rem_len = hs_len - frag_off; + size_t cur_hs_frag_len, max_hs_frag_len; + + if( ( max_frag_len < 12 ) || ( max_frag_len == 12 && hs_len != 0 ) ) + { + if( is_finished && is_retransmitting ) + { + if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) + return( ret ); + } + + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + return( 0 ); + } + max_hs_frag_len = max_frag_len - 12; + + cur_hs_frag_len = rem_len > max_hs_frag_len ? + max_hs_frag_len : rem_len; + + if( frag_off == 0 && cur_hs_frag_len != hs_len ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)", + (unsigned) cur_hs_frag_len, + (unsigned) max_hs_frag_len ) ); + } + + /* Messages are stored with handshake headers as if not fragmented, + * copy beginning of headers then fill fragmentation fields. + * Handshake headers: type(1) len(3) seq(2) f_off(3) f_len(3) */ + mbedtls_platform_memcpy( ssl->out_msg, msg->p, 6 ); + + (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[6], frag_off ); + (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[9], + cur_hs_frag_len ); + + MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 ); + + /* Copy the handshake message content and set records fields */ + mbedtls_platform_memcpy( ssl->out_msg + 12, p, cur_hs_frag_len ); + ssl->out_msglen = cur_hs_frag_len + 12; + ssl->out_msgtype = msg->type; + + /* Update position inside current message */ + ssl->handshake->cur_msg_p += cur_hs_frag_len; + } + + /* If done with the current message move to the next one if any */ + if( ssl->handshake->cur_msg_p >= msg->p + msg->len ) + { + if( msg->next != NULL ) + { + ssl->handshake->cur_msg = msg->next; + ssl->handshake->cur_msg_p = msg->next->p + 12; + } + else + { + ssl->handshake->cur_msg = NULL; + ssl->handshake->cur_msg_p = NULL; + } + } + + /* Actually send the message out */ + if( ( ret = mbedtls_ssl_write_record( ssl, force_flush ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); + return( ret ); + } + return( ret ); +} + /* * Append current handshake message to current outgoing flight */ @@ -4402,6 +4527,21 @@ static int ssl_flight_append( mbedtls_ssl_context *ssl ) cur->next = msg; } +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + ssl->handshake->cur_msg = msg; + ssl->handshake->cur_msg_p = msg->p + 12; + { + int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED; + while( ssl->handshake->cur_msg != NULL ) + { + if( ( ret = mbedtls_ssl_flight_transmit_msg( ssl, ssl->handshake->cur_msg ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit_msg", ret ); + return( ret ); + } + } + } +#endif MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= ssl_flight_append" ) ); return( 0 ); } @@ -4491,6 +4631,24 @@ int mbedtls_ssl_resend( mbedtls_ssl_context *ssl ) return( ret ); } +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) +void mbedtls_ssl_immediate_flight_done( mbedtls_ssl_context *ssl ) +{ + MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> mbedtls_ssl_immediate_flight_done" ) ); + + /* Update state and set timer */ + if( ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER ) + ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_FINISHED; + else + { + ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_WAITING; + ssl_set_timer( ssl, ssl->handshake->retransmit_timeout ); + } + + MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= mbedtls_ssl_immediate_flight_done" ) ); +} +#endif + /* * Transmit or retransmit the current flight of messages. * @@ -4507,138 +4665,19 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) { MBEDTLS_SSL_DEBUG_MSG( 2, ( "initialise flight transmission" ) ); -#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) - ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_SENDING; - - return( 0 ); -#else - ssl->handshake->cur_msg = ssl->handshake->flight; ssl->handshake->cur_msg_p = ssl->handshake->flight->p + 12; if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) return( ret ); ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_SENDING; -#endif /* MBEDTLS_IMMEDIATE_TRANSMISSION */ } while( ssl->handshake->cur_msg != NULL ) { - size_t max_frag_len; - const mbedtls_ssl_flight_item * const cur = ssl->handshake->cur_msg; - - int const is_finished = - ( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE && - cur->p[0] == MBEDTLS_SSL_HS_FINISHED ); - - uint8_t const force_flush = ssl->disable_datagram_packing == 1 ? - SSL_FORCE_FLUSH : SSL_DONT_FORCE_FLUSH; - - /* Swap epochs before sending Finished: we can't do it after - * sending ChangeCipherSpec, in case write returns WANT_READ. - * Must be done before copying, may change out_msg pointer */ - if( is_finished && ssl->handshake->cur_msg_p == ( cur->p + 12 ) ) + if( ( ret = mbedtls_ssl_flight_transmit_msg( ssl, ssl->handshake->cur_msg ) ) != 0 ) { - MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) ); - if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) - return( ret ); - } - - ret = ssl_get_remaining_payload_in_datagram( ssl ); - if( ret < 0 ) - return( ret ); - max_frag_len = (size_t) ret; - - /* CCS is copied as is, while HS messages may need fragmentation */ - if( cur->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC ) - { - if( max_frag_len == 0 ) - { - if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) - return( ret ); - - continue; - } - - mbedtls_platform_memcpy( ssl->out_msg, cur->p, cur->len ); - ssl->out_msglen = cur->len; - ssl->out_msgtype = cur->type; - - /* Update position inside current message */ - ssl->handshake->cur_msg_p += cur->len; - } - else - { - const unsigned char * const p = ssl->handshake->cur_msg_p; - const size_t hs_len = cur->len - 12; - const size_t frag_off = p - ( cur->p + 12 ); - const size_t rem_len = hs_len - frag_off; - size_t cur_hs_frag_len, max_hs_frag_len; - - if( ( max_frag_len < 12 ) || ( max_frag_len == 12 && hs_len != 0 ) ) - { - if( is_finished ) - { - if( ( ret = ssl_swap_epochs( ssl ) ) != 0 ) - return( ret ); - } - - if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) - return( ret ); - - continue; - } - max_hs_frag_len = max_frag_len - 12; - - cur_hs_frag_len = rem_len > max_hs_frag_len ? - max_hs_frag_len : rem_len; - - if( frag_off == 0 && cur_hs_frag_len != hs_len ) - { - MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)", - (unsigned) cur_hs_frag_len, - (unsigned) max_hs_frag_len ) ); - } - - /* Messages are stored with handshake headers as if not fragmented, - * copy beginning of headers then fill fragmentation fields. - * Handshake headers: type(1) len(3) seq(2) f_off(3) f_len(3) */ - mbedtls_platform_memcpy( ssl->out_msg, cur->p, 6 ); - - (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[6], frag_off ); - (void)mbedtls_platform_put_uint24_be( &ssl->out_msg[9], - cur_hs_frag_len ); - - MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 ); - - /* Copy the handshake message content and set records fields */ - mbedtls_platform_memcpy( ssl->out_msg + 12, p, cur_hs_frag_len ); - ssl->out_msglen = cur_hs_frag_len + 12; - ssl->out_msgtype = cur->type; - - /* Update position inside current message */ - ssl->handshake->cur_msg_p += cur_hs_frag_len; - } - - /* If done with the current message move to the next one if any */ - if( ssl->handshake->cur_msg_p >= cur->p + cur->len ) - { - if( cur->next != NULL ) - { - ssl->handshake->cur_msg = cur->next; - ssl->handshake->cur_msg_p = cur->next->p + 12; - } - else - { - ssl->handshake->cur_msg = NULL; - ssl->handshake->cur_msg_p = NULL; - } - } - - /* Actually send the message out */ - if( ( ret = mbedtls_ssl_write_record( ssl, force_flush ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit_msg", ret ); return( ret ); } } @@ -4657,7 +4696,7 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= mbedtls_ssl_flight_transmit" ) ); - return( 0 ); + return( ret ); } /* @@ -4868,14 +4907,6 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ) ! ( ssl->out_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE && hs_type == MBEDTLS_SSL_HS_HELLO_REQUEST ) ) { -#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) - if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret ); - return( ret ); - } -#endif /* MBEDTLS_IMMEDIATE_TRANSMISSION */ - if( ( ret = ssl_flight_append( ssl ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "ssl_flight_append", ret ); @@ -8707,13 +8738,19 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_PROTO_DTLS) - if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) ) { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); - return( ret ); - } +#if defined(MBEDTLS_IMMEDIATE_TRANSMISSION) + mbedtls_ssl_immediate_flight_done( ssl ); +#else + if( ( ret = mbedtls_ssl_flight_transmit( ssl ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flight_transmit", ret ); + return( ret ); + } #endif + } +#endif /* MBEDTLS_SSL_PROTO_DTLS */ MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write finished" ) );