From f1bc9e1c6968facfde877678761cfd599643abe3 Mon Sep 17 00:00:00 2001 From: Hanno Becker Date: Wed, 19 Jun 2019 16:23:21 +0100 Subject: [PATCH] Introduce helper functions to traverse signature hashes --- include/mbedtls/ssl_internal.h | 25 +++++++++++++++++++++++++ library/ssl_cli.c | 24 +++++++++++------------- library/ssl_srv.c | 20 +++++++------------- library/ssl_tls.c | 9 ++++----- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 40391d581..29729d495 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -1676,4 +1676,29 @@ static inline unsigned int mbedtls_ssl_conf_get_ems_enforced( #endif /* MBEDTLS_SSL_CONF_SINGLE_EC */ +#define MBEDTLS_SSL_BEGIN_FOR_EACH_SIG_HASH( MD_VAR ) \ + { \ + int const *__md; \ + for( __md = ssl->conf->sig_hashes; \ + *__md != MBEDTLS_MD_NONE; __md++ ) \ + { \ + mbedtls_md_type_t MD_VAR = (mbedtls_md_type_t) *__md; \ + + #define MBEDTLS_SSL_END_FOR_EACH_SIG_HASH \ + } \ + } + +#define MBEDTLS_SSL_BEGIN_FOR_EACH_SIG_HASH_TLS( HASH_VAR ) \ + { \ + int const *__md; \ + for( __md = ssl->conf->sig_hashes; \ + *__md != MBEDTLS_MD_NONE; __md++ ) \ + { \ + unsigned char HASH_VAR; \ + HASH_VAR = mbedtls_ssl_hash_from_md_alg( *__md ); + +#define MBEDTLS_SSL_END_FOR_EACH_SIG_HASH_TLS \ + } \ + } + #endif /* ssl_internal.h */ diff --git a/library/ssl_cli.c b/library/ssl_cli.c index b0c0403fd..ee50b4d9a 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -173,7 +173,6 @@ static void ssl_write_signature_algorithms_ext( mbedtls_ssl_context *ssl, unsigned char *p = buf; const unsigned char *end = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN; size_t sig_alg_len = 0; - const int *md; #if defined(MBEDTLS_RSA_C) || defined(MBEDTLS_ECDSA_C) unsigned char *sig_alg_list = buf + 6; #endif @@ -188,15 +187,15 @@ static void ssl_write_signature_algorithms_ext( mbedtls_ssl_context *ssl, MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding signature_algorithms extension" ) ); - for( md = ssl->conf->sig_hashes; *md != MBEDTLS_MD_NONE; md++ ) - { + MBEDTLS_SSL_BEGIN_FOR_EACH_SIG_HASH_TLS( hash ) + ((void) hash); #if defined(MBEDTLS_ECDSA_C) - sig_alg_len += 2; + sig_alg_len += 2; #endif #if defined(MBEDTLS_RSA_C) - sig_alg_len += 2; + sig_alg_len += 2; #endif - } + MBEDTLS_SSL_END_FOR_EACH_SIG_HASH_TLS if( end < p || (size_t)( end - p ) < sig_alg_len + 6 ) { @@ -209,17 +208,16 @@ static void ssl_write_signature_algorithms_ext( mbedtls_ssl_context *ssl, */ sig_alg_len = 0; - for( md = ssl->conf->sig_hashes; *md != MBEDTLS_MD_NONE; md++ ) - { + MBEDTLS_SSL_BEGIN_FOR_EACH_SIG_HASH_TLS( hash ) #if defined(MBEDTLS_ECDSA_C) - sig_alg_list[sig_alg_len++] = mbedtls_ssl_hash_from_md_alg( *md ); - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; + sig_alg_list[sig_alg_len++] = hash; + sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; #endif #if defined(MBEDTLS_RSA_C) - sig_alg_list[sig_alg_len++] = mbedtls_ssl_hash_from_md_alg( *md ); - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; + sig_alg_list[sig_alg_len++] = hash; + sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; #endif - } + MBEDTLS_SSL_END_FOR_EACH_SIG_HASH_TLS /* * enum { diff --git a/library/ssl_srv.c b/library/ssl_srv.c index f7ab70c47..e69c51790 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -3074,26 +3074,19 @@ static int ssl_write_certificate_request( mbedtls_ssl_context *ssl ) */ if( mbedtls_ssl_get_minor_ver( ssl ) == MBEDTLS_SSL_MINOR_VERSION_3 ) { - const int *cur; - /* * Supported signature algorithms */ - for( cur = ssl->conf->sig_hashes; *cur != MBEDTLS_MD_NONE; cur++ ) - { - unsigned char hash = mbedtls_ssl_hash_from_md_alg( *cur ); - if( !( 0 + MBEDTLS_SSL_BEGIN_FOR_EACH_SIG_HASH_TLS( hash ) + if( 0 #if defined(MBEDTLS_SHA512_C) - || hash == MBEDTLS_SSL_HASH_SHA384 + || hash == MBEDTLS_SSL_HASH_SHA384 #endif #if defined(MBEDTLS_SHA256_C) - || hash == MBEDTLS_SSL_HASH_SHA256 + || hash == MBEDTLS_SSL_HASH_SHA256 #endif - ) ) - { - continue; - } - + ) + { #if defined(MBEDTLS_RSA_C) p[2 + sa_len++] = hash; p[2 + sa_len++] = MBEDTLS_SSL_SIG_RSA; @@ -3103,6 +3096,7 @@ static int ssl_write_certificate_request( mbedtls_ssl_context *ssl ) p[2 + sa_len++] = MBEDTLS_SSL_SIG_ECDSA; #endif } + MBEDTLS_SSL_END_FOR_EACH_SIG_HASH_TLS p[0] = (unsigned char)( sa_len >> 8 ); p[1] = (unsigned char)( sa_len ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index ae6c282d2..9359be6d9 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -11308,14 +11308,13 @@ int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_i int mbedtls_ssl_check_sig_hash( const mbedtls_ssl_context *ssl, mbedtls_md_type_t md ) { - const int *cur; - if( ssl->conf->sig_hashes == NULL ) return( -1 ); - for( cur = ssl->conf->sig_hashes; *cur != MBEDTLS_MD_NONE; cur++ ) - if( *cur == (int) md ) - return( 0 ); + MBEDTLS_SSL_BEGIN_FOR_EACH_SIG_HASH( md_alg ) + if( md_alg == md ) + return( 0 ); + MBEDTLS_SSL_END_FOR_EACH_SIG_HASH return( -1 ); }