diff --git a/library/ssl_ticket.c b/library/ssl_ticket.c index 7f658497e..39f120995 100644 --- a/library/ssl_ticket.c +++ b/library/ssl_ticket.c @@ -216,19 +216,39 @@ int mbedtls_ssl_ticket_setup( mbedtls_ssl_ticket_context *ctx, uint32_t lifetime ) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - const mbedtls_cipher_info_t *cipher_info; #if defined(MBEDTLS_USE_PSA_CRYPTO) psa_algorithm_t alg; psa_key_type_t key_type; size_t key_bits; -#endif +#else + const mbedtls_cipher_info_t *cipher_info; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ ctx->f_rng = f_rng; ctx->p_rng = p_rng; ctx->ticket_lifetime = lifetime; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( mbedtls_ssl_cipher_to_psa( cipher, TICKET_AUTH_TAG_BYTES, + &alg, &key_type, &key_bits ) != PSA_SUCCESS ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( PSA_ALG_IS_AEAD( alg ) == 0 ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( key_bits > PSA_BYTES_TO_BITS( MAX_KEY_BYTES ) ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + ctx->keys[0].alg = alg; + ctx->keys[0].key_type = key_type; + ctx->keys[0].key_bits = key_bits; + + ctx->keys[1].alg = alg; + ctx->keys[1].key_type = key_type; + ctx->keys[1].key_bits = key_bits; +#else cipher_info = mbedtls_cipher_info_from_type( cipher ); if( mbedtls_cipher_info_get_mode( cipher_info ) != MBEDTLS_MODE_GCM && @@ -241,19 +261,6 @@ int mbedtls_ssl_ticket_setup( mbedtls_ssl_ticket_context *ctx, if( mbedtls_cipher_info_get_key_bitlen( cipher_info ) > 8 * MAX_KEY_BYTES ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); -#if defined(MBEDTLS_USE_PSA_CRYPTO) - if( mbedtls_ssl_cipher_to_psa( cipher_info->type, TICKET_AUTH_TAG_BYTES, - &alg, &key_type, &key_bits ) != PSA_SUCCESS ) - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - - ctx->keys[0].alg = alg; - ctx->keys[0].key_type = key_type; - ctx->keys[0].key_bits = key_bits; - - ctx->keys[1].alg = alg; - ctx->keys[1].key_type = key_type; - ctx->keys[1].key_bits = key_bits; -#else if( ( ret = mbedtls_cipher_setup( &ctx->keys[0].ctx, cipher_info ) ) != 0 ) return( ret );