diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 32f593972..54d3bc732 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -502,9 +502,6 @@ struct mbedtls_ssl_handshake_params void (*update_checksum)(mbedtls_ssl_context *, const unsigned char *, size_t); void (*calc_verify)(const mbedtls_ssl_context *, unsigned char *, size_t *); void (*calc_finished)(mbedtls_ssl_context *, unsigned char *, int); - int (*tls_prf)(const unsigned char *, size_t, const char *, - const unsigned char *, size_t, - unsigned char *, size_t); #if !defined(MBEDTLS_SSL_CONF_SINGLE_CIPHERSUITE) mbedtls_ssl_ciphersuite_handle_t ciphersuite_info; diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 96276c2b6..8bcad1b55 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -865,10 +865,52 @@ static void ssl_calc_finished_tls_sha384( mbedtls_ssl_context *, unsigned char * #endif #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ -/* Type for the TLS PRF */ -typedef int ssl_tls_prf_t(const unsigned char *, size_t, const char *, - const unsigned char *, size_t, - unsigned char *, size_t); +/* + * Call the appropriate PRF function + */ +MBEDTLS_ALWAYS_INLINE +static inline int ssl_prf( int minor_ver, + mbedtls_md_type_t hash, + const unsigned char *secret, size_t slen, + const char *label, + const unsigned char *random, size_t rlen, + unsigned char *dstbuf, size_t dlen ) +{ +#if !defined(MBEDTLS_SSL_PROTO_TLS1_2) || !defined(MBEDTLS_SHA512_C) + (void) hash; +#endif + +#if defined(MBEDTLS_SSL_PROTO_SSL3) + if( minor_ver == MBEDTLS_SSL_MINOR_VERSION_0 ) + return( ssl3_prf( secret, slen, label, random, rlen, dstbuf, dlen ) ); + else +#endif +#if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1) + if( minor_ver < MBEDTLS_SSL_MINOR_VERSION_3 ) + return( tls1_prf( secret, slen, label, random, rlen, dstbuf, dlen ) ); + else +#endif +#if defined(MBEDTLS_SSL_PROTO_TLS1_2) +#if defined(MBEDTLS_SHA512_C) + if( minor_ver == MBEDTLS_SSL_MINOR_VERSION_3 && + hash == MBEDTLS_MD_SHA384 ) + { + return( tls_prf_sha384( secret, slen, label, random, rlen, + dstbuf, dlen ) ); + } + else +#endif +#if defined(MBEDTLS_SHA256_C) + if( minor_ver == MBEDTLS_SSL_MINOR_VERSION_3 ) + { + return( tls_prf_sha256( secret, slen, label, random, rlen, + dstbuf, dlen ) ); + } +#endif +#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ + + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); +} /* * Populate a transform structure with session keys and all the other @@ -906,7 +948,6 @@ static int ssl_populate_transform( mbedtls_ssl_transform *transform, #if defined(MBEDTLS_ZLIB_SUPPORT) int compression, #endif - ssl_tls_prf_t tls_prf, const unsigned char randbytes[64], int minor_ver, unsigned endpoint, @@ -1002,7 +1043,10 @@ static int ssl_populate_transform( mbedtls_ssl_transform *transform, /* * Compute key block using the PRF */ - ret = tls_prf( master, 48, "key expansion", randbytes, 64, keyblk, 256 ); + ret = ssl_prf( minor_ver, + mbedtls_ssl_suite_get_mac( ciphersuite_info ), + master, 48, "key expansion", randbytes, 64, + keyblk, 256 ); if( ret != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "prf", ret ); @@ -1304,7 +1348,6 @@ static int ssl_set_handshake_prfs( mbedtls_ssl_handshake_params *handshake, #if defined(MBEDTLS_SSL_PROTO_SSL3) if( minor_ver == MBEDTLS_SSL_MINOR_VERSION_0 ) { - handshake->tls_prf = ssl3_prf; handshake->calc_verify = ssl_calc_verify_ssl; handshake->calc_finished = ssl_calc_finished_ssl; } @@ -1313,7 +1356,6 @@ static int ssl_set_handshake_prfs( mbedtls_ssl_handshake_params *handshake, #if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1) if( minor_ver < MBEDTLS_SSL_MINOR_VERSION_3 ) { - handshake->tls_prf = tls1_prf; handshake->calc_verify = ssl_calc_verify_tls; handshake->calc_finished = ssl_calc_finished_tls; } @@ -1324,7 +1366,6 @@ static int ssl_set_handshake_prfs( mbedtls_ssl_handshake_params *handshake, if( minor_ver == MBEDTLS_SSL_MINOR_VERSION_3 && hash == MBEDTLS_MD_SHA384 ) { - handshake->tls_prf = tls_prf_sha384; handshake->calc_verify = ssl_calc_verify_tls_sha384; handshake->calc_finished = ssl_calc_finished_tls_sha384; } @@ -1333,7 +1374,6 @@ static int ssl_set_handshake_prfs( mbedtls_ssl_handshake_params *handshake, #if defined(MBEDTLS_SHA256_C) if( minor_ver == MBEDTLS_SSL_MINOR_VERSION_3 ) { - handshake->tls_prf = tls_prf_sha256; handshake->calc_verify = ssl_calc_verify_tls_sha256; handshake->calc_finished = ssl_calc_finished_tls_sha256; } @@ -1363,10 +1403,13 @@ static int ssl_compute_master( mbedtls_ssl_handshake_params *handshake, { int ret; -#if !defined(MBEDTLS_DEBUG_C) && !defined(MBEDTLS_SSL_EXTENDED_MASTER_SECRET) - ssl = NULL; /* make sure we don't use it except for debug and EMS */ - (void) ssl; -#endif +/* #if !defined(MBEDTLS_DEBUG_C) && !defined(MBEDTLS_SSL_EXTENDED_MASTER_SECRET) */ +/* ssl = NULL; /\* make sure we don't use it except for debug and EMS *\/ */ +/* (void) ssl; */ +/* #endif */ + + mbedtls_ssl_ciphersuite_handle_t const ciphersuite = + mbedtls_ssl_handshake_get_ciphersuite( handshake ); #if !defined(MBEDTLS_SSL_NO_SESSION_RESUMPTION) if( handshake->resume != 0 ) @@ -1391,18 +1434,22 @@ static int ssl_compute_master( mbedtls_ssl_handshake_params *handshake, MBEDTLS_SSL_DEBUG_BUF( 3, "session hash for extended master secret", session_hash, hash_len ); - ret = handshake->tls_prf( handshake->premaster, handshake->pmslen, - "extended master secret", - session_hash, hash_len, - master, 48 ); + ret = ssl_prf( mbedtls_ssl_get_minor_ver( ssl ), + mbedtls_ssl_suite_get_mac( ciphersuite ), + handshake->premaster, handshake->pmslen, + "extended master secret", + session_hash, hash_len, + master, 48 ); } else #endif { - ret = handshake->tls_prf( handshake->premaster, handshake->pmslen, - "master secret", - handshake->randbytes, 64, - master, 48 ); + ret = ssl_prf( mbedtls_ssl_get_minor_ver( ssl ), + mbedtls_ssl_suite_get_mac( ciphersuite ), + handshake->premaster, handshake->pmslen, + "master secret", + handshake->randbytes, 64, + master, 48 ); } if( ret != 0 ) { @@ -1470,7 +1517,6 @@ int mbedtls_ssl_derive_keys( mbedtls_ssl_context *ssl ) #if defined(MBEDTLS_ZLIB_SUPPORT) ssl->session_negotiate->compression, #endif - ssl->handshake->tls_prf, ssl->handshake->randbytes, mbedtls_ssl_get_minor_ver( ssl ), mbedtls_ssl_conf_get_endpoint( ssl->conf ), @@ -7517,8 +7563,12 @@ static void ssl_calc_finished_tls( mbedtls_md5_finish_ret( &md5, padbuf ); mbedtls_sha1_finish_ret( &sha1, padbuf + 16 ); - ssl->handshake->tls_prf( session->master, 48, sender, - padbuf, 36, buf, len ); + ssl_prf( mbedtls_ssl_get_minor_ver( ssl ), + mbedtls_ssl_suite_get_mac( + mbedtls_ssl_ciphersuite_from_id( + mbedtls_ssl_session_get_ciphersuite( session ) ) ), + session->master, 48, sender, + padbuf, 36, buf, len ); MBEDTLS_SSL_DEBUG_BUF( 3, "calc finished result", buf, len ); @@ -7568,8 +7618,12 @@ static void ssl_calc_finished_tls_sha256( mbedtls_sha256_finish_ret( &sha256, padbuf ); - ssl->handshake->tls_prf( session->master, 48, sender, - padbuf, 32, buf, len ); + ssl_prf( mbedtls_ssl_get_minor_ver( ssl ), + mbedtls_ssl_suite_get_mac( + mbedtls_ssl_ciphersuite_from_id( + mbedtls_ssl_session_get_ciphersuite( session ) ) ), + session->master, 48, sender, + padbuf, 32, buf, len ); MBEDTLS_SSL_DEBUG_BUF( 3, "calc finished result", buf, len ); @@ -7617,8 +7671,12 @@ static void ssl_calc_finished_tls_sha384( mbedtls_sha512_finish_ret( &sha512, padbuf ); - ssl->handshake->tls_prf( session->master, 48, sender, - padbuf, 48, buf, len ); + ssl_prf( mbedtls_ssl_get_minor_ver( ssl ), + mbedtls_ssl_suite_get_mac( + mbedtls_ssl_ciphersuite_from_id( + mbedtls_ssl_session_get_ciphersuite( session ) ) ), + session->master, 48, sender, + padbuf, 48, buf, len ); MBEDTLS_SSL_DEBUG_BUF( 3, "calc finished result", buf, len ); @@ -11386,29 +11444,6 @@ int mbedtls_ssl_context_save( mbedtls_ssl_context *ssl, return( ssl_session_reset_int( ssl, 0 ) ); } -/* - * Helper to get TLS 1.2 PRF from ciphersuite - * (Duplicates bits of logic from ssl_set_handshake_prfs().) - */ -typedef int (*tls_prf_fn)( const unsigned char *secret, size_t slen, - const char *label, - const unsigned char *random, size_t rlen, - unsigned char *dstbuf, size_t dlen ); -static tls_prf_fn ssl_tls12prf_from_cs( int ciphersuite_id ) -{ - mbedtls_ssl_ciphersuite_handle_t const info = - mbedtls_ssl_ciphersuite_from_id( ciphersuite_id ); - const mbedtls_md_type_t hash = mbedtls_ssl_suite_get_mac( info ); - -#if defined(MBEDTLS_SHA512_C) - if( hash == MBEDTLS_MD_SHA384 ) - return( tls_prf_sha384 ); -#else - (void) hash; -#endif - return( tls_prf_sha256 ); -} - /* * Deserialize context, see mbedtls_ssl_context_save() for format. * @@ -11529,8 +11564,6 @@ static int ssl_context_load( mbedtls_ssl_context *ssl, #if defined(MBEDTLS_ZLIB_SUPPORT) ssl->session->compression, #endif - ssl_tls12prf_from_cs( - mbedtls_ssl_session_get_ciphersuite( ssl->session) ), p, /* currently pointing to randbytes */ MBEDTLS_SSL_MINOR_VERSION_3, /* (D)TLS 1.2 is forced */ mbedtls_ssl_conf_get_endpoint( ssl->conf ),