mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-29 20:04:20 +01:00
Merge 6c826929e1 into 572bdf3d24
This commit is contained in:
commit
e486ffb087
BIN
clipped_audio.wav
Normal file
BIN
clipped_audio.wav
Normal file
Binary file not shown.
BIN
test_sig.npy
Normal file
BIN
test_sig.npy
Normal file
Binary file not shown.
339
tortoise/api.py
339
tortoise/api.py
|
|
@ -6,6 +6,7 @@ from urllib import request
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import torchaudio
|
||||
|
||||
|
|
@ -21,9 +22,11 @@ from tortoise.models.vocoder import UnivNetGenerator
|
|||
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel, TacotronSTFT
|
||||
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
||||
from tortoise.utils.tokenizer import VoiceBpeTokenizer
|
||||
from tortoise.utils.misc_helpers import Timer
|
||||
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
|
||||
from contextlib import contextmanager
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
pbar = None
|
||||
|
||||
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
|
||||
|
|
@ -39,13 +42,19 @@ MODELS = {
|
|||
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
|
||||
}
|
||||
|
||||
|
||||
def get_model_path(model_name, models_dir=MODELS_DIR):
|
||||
"""
|
||||
Get path to given model, download it if it doesn't exist.
|
||||
"""
|
||||
if model_name not in MODELS:
|
||||
raise ValueError(f'Model {model_name} not found in available models.')
|
||||
model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir)
|
||||
try:
|
||||
model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir)
|
||||
except:
|
||||
# CVVP not found in Manmay tortoise-tts.
|
||||
model_path = hf_hub_download(repo_id="jbetker/tortoise-tts-v2", subfolder=".models", filename=model_name, cache_dir=models_dir)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
|
|
@ -56,17 +65,20 @@ def pad_or_truncate(t, length):
|
|||
if t.shape[-1] == length:
|
||||
return t
|
||||
elif t.shape[-1] < length:
|
||||
return F.pad(t, (0, length-t.shape[-1]))
|
||||
return F.pad(t, (0, length - t.shape[-1]))
|
||||
else:
|
||||
return t[..., :length]
|
||||
|
||||
|
||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
|
||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True,
|
||||
cond_free_k=1):
|
||||
"""
|
||||
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
||||
"""
|
||||
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
||||
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
|
||||
model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse',
|
||||
betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
||||
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
|
||||
|
||||
|
||||
|
|
@ -119,15 +131,17 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_la
|
|||
Uses the specified diffusion model to convert discrete codes into a spectrogram.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
||||
output_seq_len = latents.shape[
|
||||
1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
||||
output_shape = (latents.shape[0], 100, output_seq_len)
|
||||
precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
|
||||
precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len,
|
||||
False)
|
||||
|
||||
noise = torch.randn(output_shape, device=latents.device) * temperature
|
||||
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
||||
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
|
||||
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
|
||||
progress=verbose)
|
||||
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
||||
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
|
||||
|
||||
|
||||
def classify_audio_clip(clip):
|
||||
|
|
@ -171,12 +185,13 @@ def pick_best_batch_size_for_gpu():
|
|||
return 4
|
||||
return 1
|
||||
|
||||
|
||||
class TextToSpeech:
|
||||
"""
|
||||
Main entry point into Tortoise.
|
||||
"""
|
||||
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
|
||||
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
|
||||
enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None,
|
||||
tokenizer_vocab_file=None, tokenizer_basic=False):
|
||||
|
||||
|
|
@ -194,7 +209,7 @@ class TextToSpeech:
|
|||
self.models_dir = models_dir
|
||||
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
|
||||
self.enable_redaction = enable_redaction
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = torch.device('mps')
|
||||
if self.enable_redaction:
|
||||
|
|
@ -210,15 +225,19 @@ class TextToSpeech:
|
|||
self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
|
||||
self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
|
||||
else:
|
||||
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
||||
model_dim=1024,
|
||||
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||
train_solo_embeddings=False).cpu().eval()
|
||||
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
|
||||
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2,
|
||||
layers=30,
|
||||
model_dim=1024,
|
||||
heads=16, number_text_tokens=255, start_text_token=255,
|
||||
checkpointing=False,
|
||||
train_solo_embeddings=False).cpu().eval()
|
||||
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)),
|
||||
strict=False)
|
||||
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
|
||||
|
||||
|
||||
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
|
||||
num_heads=16,
|
||||
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
||||
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir)))
|
||||
|
||||
|
|
@ -227,31 +246,37 @@ class TextToSpeech:
|
|||
num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430,
|
||||
use_xformers=True).cpu().eval()
|
||||
self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
|
||||
self.cvvp = None # CVVP model is only loaded if used.
|
||||
self.cvvp = None # CVVP model is only loaded if used.
|
||||
|
||||
self.vocoder = UnivNetGenerator().cpu()
|
||||
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
|
||||
self.vocoder.load_state_dict(
|
||||
torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
|
||||
self.vocoder.eval(inference=True)
|
||||
|
||||
self.stft = None # TacotronSTFT is only loaded if used.
|
||||
|
||||
self.stft = None # TacotronSTFT is only loaded if used.
|
||||
|
||||
# Random latent generators (RLGs) are loaded lazily.
|
||||
self.rlg_auto = None
|
||||
self.rlg_diffusion = None
|
||||
|
||||
@contextmanager
|
||||
def temporary_cuda(self, model):
|
||||
m = model.to(self.device)
|
||||
yield m
|
||||
m = model.cpu()
|
||||
|
||||
|
||||
def load_cvvp(self):
|
||||
"""Load CVVP model."""
|
||||
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
|
||||
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8,
|
||||
cond_mask_percentage=0,
|
||||
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
|
||||
#self.cvvp.to(self.device).eval()
|
||||
|
||||
|
||||
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
|
||||
|
||||
def get_conditioning_latents(self, voice_samples, return_mels=False):
|
||||
|
||||
def get_conditioning_latents(self, voice_samples, return_mels=False, return_average=True):
|
||||
"""
|
||||
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
|
||||
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
|
||||
|
|
@ -268,7 +293,7 @@ class TextToSpeech:
|
|||
auto_conds.append(format_conditioning(vs, device=self.device))
|
||||
auto_conds = torch.stack(auto_conds, dim=1)
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
||||
auto_latent = self.autoregressive.get_conditioning(auto_conds, return_average=return_average)
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
|
||||
if self.stft is None:
|
||||
|
|
@ -283,10 +308,11 @@ class TextToSpeech:
|
|||
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False,
|
||||
device=self.device, stft=self.stft)
|
||||
diffusion_conds.append(cond_mel)
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
|
||||
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
self.diffusion = self.diffusion.to(self.device)
|
||||
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
||||
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds, return_average=return_average)
|
||||
self.diffusion = self.diffusion.cpu()
|
||||
|
||||
if return_mels:
|
||||
|
|
@ -298,9 +324,11 @@ class TextToSpeech:
|
|||
# Lazy-load the RLG models.
|
||||
if self.rlg_auto is None:
|
||||
self.rlg_auto = RandomLatentConverter(1024).eval()
|
||||
self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
|
||||
self.rlg_auto.load_state_dict(
|
||||
torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
|
||||
self.rlg_diffusion = RandomLatentConverter(2048).eval()
|
||||
self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
|
||||
self.rlg_diffusion.load_state_dict(
|
||||
torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu')))
|
||||
with torch.no_grad():
|
||||
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
||||
|
||||
|
|
@ -324,17 +352,20 @@ class TextToSpeech:
|
|||
'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
|
||||
}
|
||||
settings.update(presets[preset])
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.tts(text, **settings)
|
||||
|
||||
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
|
||||
return_deterministic_state=False,
|
||||
# autoregressive generation parameters follow
|
||||
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
||||
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8,
|
||||
max_mel_tokens=500,
|
||||
# CVVP parameters follow
|
||||
cvvp_amount=.0,
|
||||
# diffusion generation parameters follow
|
||||
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
|
||||
use_averaged_latents=True,
|
||||
auto_conds=None,
|
||||
**hf_generate_kwargs):
|
||||
"""
|
||||
Produces an audio clip of the given text being spoken with the given reference voice.
|
||||
|
|
@ -381,22 +412,38 @@ class TextToSpeech:
|
|||
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
|
||||
if not use_averaged_latents:
|
||||
assert k==1, "Non-averaged latents currently only support single sample generation"
|
||||
|
||||
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
|
||||
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
|
||||
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
||||
auto_conds = None
|
||||
assert text_tokens.shape[
|
||||
-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
||||
if voice_samples is not None:
|
||||
auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True)
|
||||
auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples,
|
||||
return_mels=True,
|
||||
return_average=use_averaged_latents)
|
||||
elif conditioning_latents is not None:
|
||||
auto_conditioning, diffusion_conditioning = conditioning_latents
|
||||
if use_averaged_latents:
|
||||
# Average across second axis
|
||||
if auto_conditioning.dim() > 2:
|
||||
auto_conditioning = torch.mean(auto_conditioning,axis=1)
|
||||
if diffusion_conditioning.dim() > 2:
|
||||
diffusion_conditioning = torch.mean(diffusion_conditioning,axis=1)
|
||||
else:
|
||||
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
|
||||
|
||||
|
||||
|
||||
auto_conditioning = auto_conditioning.to(self.device)
|
||||
diffusion_conditioning = diffusion_conditioning.to(self.device)
|
||||
|
||||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
||||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free,
|
||||
cond_free_k=cond_free_k)
|
||||
|
||||
with torch.no_grad():
|
||||
samples = []
|
||||
|
|
@ -407,41 +454,94 @@ class TextToSpeech:
|
|||
print("Generating autoregressive samples..")
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(self.autoregressive
|
||||
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
|
||||
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16,
|
||||
enabled=self.half):
|
||||
# Store the latent indices for alignment with the diffusion conditions
|
||||
batched_latent_indices = []
|
||||
for b in tqdm(range(num_batches), disable=not verbose):
|
||||
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens,
|
||||
**hf_generate_kwargs)
|
||||
# Case where we're returning non-averaged latents
|
||||
if auto_conditioning.dim() > 2:
|
||||
# If the number of candidate speech conditioning latents are not equal to (greater)
|
||||
# num_return_sequences (batch), randomly select an equal number of candidate latents.
|
||||
if auto_conditioning.shape[1] >= self.autoregressive_batch_size:
|
||||
latent_indices = torch.randperm(auto_conditioning.shape[1])[
|
||||
:self.autoregressive_batch_size]
|
||||
batched_latent_indices.append(latent_indices)
|
||||
else:
|
||||
# If there are less candidate speech conditioning latents, replicate the
|
||||
# latents to meet the autoregressive batch size
|
||||
replications = np.ceil(self.autoregressive_batch_size /
|
||||
auto_conditioning.shape[1]).astype(int)
|
||||
latent_indices = (torch.arange(0, auto_conditioning.shape[1], dtype=torch.int32).
|
||||
repeat(replications))[:self.autoregressive_batch_size]
|
||||
batched_latent_indices.append(latent_indices)
|
||||
auto_conditioning_ = auto_conditioning[0, latent_indices].unsqueeze(0)
|
||||
|
||||
else:
|
||||
auto_conditioning_ = auto_conditioning
|
||||
|
||||
codes = autoregressive.inference_speech(auto_conditioning_, text_tokens,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens,
|
||||
**hf_generate_kwargs)
|
||||
padding_needed = max_mel_tokens - codes.shape[1]
|
||||
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
||||
samples.append(codes)
|
||||
else:
|
||||
with self.temporary_cuda(self.autoregressive) as autoregressive:
|
||||
batched_latent_indices = []
|
||||
for b in tqdm(range(num_batches), disable=not verbose):
|
||||
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens,
|
||||
**hf_generate_kwargs)
|
||||
|
||||
# Case where we're returning non-averaged latents
|
||||
if auto_conditioning.dim() > 2:
|
||||
# If the number of candidate speech conditioning latents are not equal to (greater)
|
||||
# num_return_sequences (batch), randomly select an equal number of candidate latents.
|
||||
if auto_conditioning.shape[1] >= self.autoregressive_batch_size:
|
||||
latent_indices = torch.randperm(auto_conditioning.shape[1])[
|
||||
:self.autoregressive_batch_size]
|
||||
batched_latent_indices.append(latent_indices)
|
||||
else:
|
||||
# If there are less candidate speech conditioning latents, replicate the
|
||||
# latents to meet the autoregressive batch size
|
||||
replications = np.ceil(self.autoregressive_batch_size /
|
||||
auto_conditioning.shape[1]).astype(int)
|
||||
latent_indices = (torch.arange(0, auto_conditioning.shape[1], dtype=torch.int32).
|
||||
repeat(replications))[:self.autoregressive_batch_size]
|
||||
batched_latent_indices.append(latent_indices)
|
||||
auto_conditioning_ = auto_conditioning[0, latent_indices].unsqueeze(0)
|
||||
else:
|
||||
auto_conditioning_ = auto_conditioning
|
||||
|
||||
|
||||
codes = autoregressive.inference_speech(auto_conditioning_, text_tokens,
|
||||
do_sample=True,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.autoregressive_batch_size,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
max_generate_length=max_mel_tokens,
|
||||
**hf_generate_kwargs)
|
||||
padding_needed = max_mel_tokens - codes.shape[1]
|
||||
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
||||
samples.append(codes)
|
||||
|
||||
|
||||
# Flatten batched latents (if not using averaged latents)
|
||||
if len(batched_latent_indices):
|
||||
batched_latent_indices_flattened = torch.cat(batched_latent_indices, dim=0)
|
||||
clip_results = []
|
||||
|
||||
cvvp_results = []
|
||||
clvp_results = []
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16,
|
||||
enabled=self.half
|
||||
):
|
||||
if cvvp_amount > 0:
|
||||
if self.cvvp is None:
|
||||
|
|
@ -451,7 +551,8 @@ class TextToSpeech:
|
|||
if self.cvvp is None:
|
||||
print("Computing best candidates using CLVP")
|
||||
else:
|
||||
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
|
||||
print(
|
||||
f"Computing best candidates using CLVP {((1 - cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
|
||||
for batch in tqdm(samples, disable=not verbose):
|
||||
for i in range(batch.shape[0]):
|
||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||
|
|
@ -460,17 +561,44 @@ class TextToSpeech:
|
|||
if auto_conds is not None and cvvp_amount > 0:
|
||||
cvvp_accumulator = 0
|
||||
for cl in range(auto_conds.shape[1]):
|
||||
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
||||
cvvp_accumulator = cvvp_accumulator + self.cvvp(
|
||||
auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
||||
cvvp = cvvp_accumulator / auto_conds.shape[1]
|
||||
if cvvp_amount == 1:
|
||||
# Voice only based selection (CVVP - how well do the VQ mels align with the voice prompt(s)?)
|
||||
clip_results.append(cvvp)
|
||||
else:
|
||||
clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
|
||||
# Hybrid Voice-Text based selection
|
||||
# We append to clvp and cvvp lists separately such that norm equalization occurs
|
||||
# on total set.
|
||||
clvp_results.append(clvp_out)
|
||||
cvvp_results.append(cvvp)
|
||||
else:
|
||||
# Text based selection (CLVP - how well do the VQ mels align with the text prompt?)
|
||||
clip_results.append(clvp_out)
|
||||
clip_results = torch.cat(clip_results, dim=0)
|
||||
|
||||
if len(clvp_results):
|
||||
# cvvp and clvp_out have dramatically different scales. Equalize the norms such that
|
||||
# weighting value has more intuitive, linear meaning.
|
||||
clvp_results = torch.cat(clvp_results,dim=0)
|
||||
cvvp_results = torch.cat(cvvp_results,dim=0)
|
||||
norm_clvp = torch.linalg.norm(clvp_results)
|
||||
norm_cvvp = torch.linalg.norm(cvvp_results)
|
||||
norm_scale_cvvp = norm_clvp/norm_cvvp
|
||||
cvvp_results *= norm_scale_cvvp
|
||||
|
||||
# Calculate weighted clip results
|
||||
clip_results = cvvp * cvvp_amount + clvp_out * (1 - cvvp_amount)
|
||||
|
||||
|
||||
else:
|
||||
clip_results = torch.cat(clip_results, dim=0)
|
||||
samples = torch.cat(samples, dim=0)
|
||||
best_results = samples[torch.topk(clip_results, k=k).indices]
|
||||
top_k_ = torch.topk(clip_results, k=k).indices
|
||||
if len(batched_latent_indices):
|
||||
# map top_k_ back to samples to reference proper diffusion conditions
|
||||
mapped_top_k = batched_latent_indices_flattened[top_k_.cpu()]
|
||||
best_results = samples[top_k_]
|
||||
else:
|
||||
with self.temporary_cuda(self.clvp) as clvp:
|
||||
if cvvp_amount > 0:
|
||||
|
|
@ -481,7 +609,8 @@ class TextToSpeech:
|
|||
if self.cvvp is None:
|
||||
print("Computing best candidates using CLVP")
|
||||
else:
|
||||
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
|
||||
print(
|
||||
f"Computing best candidates using CLVP {((1 - cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
|
||||
for batch in tqdm(samples, disable=not verbose):
|
||||
for i in range(batch.shape[0]):
|
||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||
|
|
@ -490,13 +619,17 @@ class TextToSpeech:
|
|||
if auto_conds is not None and cvvp_amount > 0:
|
||||
cvvp_accumulator = 0
|
||||
for cl in range(auto_conds.shape[1]):
|
||||
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
||||
cvvp_accumulator = cvvp_accumulator + self.cvvp(
|
||||
auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
||||
cvvp = cvvp_accumulator / auto_conds.shape[1]
|
||||
if cvvp_amount == 1:
|
||||
# Voice only based selection (CVVP - how well do the VQ mels align with the voice prompt(s)?)
|
||||
clip_results.append(cvvp)
|
||||
else:
|
||||
clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
|
||||
# Hybrid Voice-Text based selection
|
||||
clip_results.append(cvvp * cvvp_amount + clvp_out * (1 - cvvp_amount))
|
||||
else:
|
||||
# Text based selection (CLVP - how well do the VQ mels align with the text prompt?)
|
||||
clip_results.append(clvp_out)
|
||||
clip_results = torch.cat(clip_results, dim=0)
|
||||
samples = torch.cat(samples, dim=0)
|
||||
|
|
@ -510,23 +643,44 @@ class TextToSpeech:
|
|||
# results, but will increase memory usage.
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(
|
||||
self.autoregressive
|
||||
self.autoregressive
|
||||
) as autoregressive, torch.autocast(
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16,
|
||||
enabled=self.half
|
||||
):
|
||||
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
||||
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
# Allow for non-averaged latents
|
||||
if auto_conditioning.dim() == 2:
|
||||
auto_conditioning = auto_conditioning.repeat(k, 1)
|
||||
else:
|
||||
# Select the best condition
|
||||
auto_conditioning = (auto_conditioning[0,mapped_top_k]).repeat(k,1)
|
||||
|
||||
best_latents = autoregressive(auto_conditioning, text_tokens.repeat(k, 1),
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
|
||||
best_results,
|
||||
torch.tensor([best_results.shape[
|
||||
-1] * self.autoregressive.mel_length_compression],
|
||||
device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
del auto_conditioning
|
||||
else:
|
||||
with self.temporary_cuda(
|
||||
self.autoregressive
|
||||
self.autoregressive
|
||||
) as autoregressive:
|
||||
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
||||
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
# Allow for non-averaged latents
|
||||
if auto_conditioning.dim() == 2:
|
||||
auto_conditioning = auto_conditioning.repeat(k, 1)
|
||||
else:
|
||||
# Select the best condition
|
||||
auto_conditioning = (auto_conditioning[0,mapped_top_k]).repeat(k,1)
|
||||
|
||||
best_latents = autoregressive(auto_conditioning, text_tokens.repeat(k, 1),
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
|
||||
best_results,
|
||||
torch.tensor([best_results.shape[
|
||||
-1] * self.autoregressive.mel_length_compression],
|
||||
device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
del auto_conditioning
|
||||
|
||||
if verbose:
|
||||
|
|
@ -534,7 +688,7 @@ class TextToSpeech:
|
|||
wav_candidates = []
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
|
||||
self.vocoder
|
||||
self.vocoder
|
||||
) as vocoder:
|
||||
for b in range(best_results.shape[0]):
|
||||
codes = best_results[b].unsqueeze(0)
|
||||
|
|
@ -550,8 +704,15 @@ class TextToSpeech:
|
|||
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
|
||||
latents = latents[:, :k]
|
||||
break
|
||||
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
|
||||
verbose=verbose)
|
||||
|
||||
# Get top selection for diffusion conditioning
|
||||
if diffusion_conditioning.dim() > 2:
|
||||
diffusion_conditioning_ = diffusion_conditioning[0,mapped_top_k]
|
||||
else:
|
||||
diffusion_conditioning_ = diffusion_conditioning
|
||||
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning_,
|
||||
temperature=diffusion_temperature,
|
||||
verbose=verbose)
|
||||
wav = vocoder.inference(mel)
|
||||
wav_candidates.append(wav.cpu())
|
||||
else:
|
||||
|
|
@ -571,15 +732,27 @@ class TextToSpeech:
|
|||
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
|
||||
latents = latents[:, :k]
|
||||
break
|
||||
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
|
||||
verbose=verbose)
|
||||
|
||||
# Get top selection for diffusion conditioning
|
||||
if diffusion_conditioning.dim() > 2:
|
||||
diffusion_conditioning_ = diffusion_conditioning[0, mapped_top_k]
|
||||
else:
|
||||
diffusion_conditioning_ = diffusion_conditioning
|
||||
|
||||
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning_,
|
||||
temperature=diffusion_temperature,
|
||||
verbose=verbose)
|
||||
wav = vocoder.inference(mel)
|
||||
wav_candidates.append(wav.cpu())
|
||||
|
||||
if verbose:
|
||||
print('Finished getting wav candidates')
|
||||
|
||||
def potentially_redact(clip, text):
|
||||
if self.enable_redaction:
|
||||
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
|
||||
return clip
|
||||
|
||||
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
|
||||
|
||||
if len(wav_candidates) > 1:
|
||||
|
|
@ -587,10 +760,14 @@ class TextToSpeech:
|
|||
else:
|
||||
res = wav_candidates[0]
|
||||
|
||||
if verbose:
|
||||
print("Returning result")
|
||||
|
||||
if return_deterministic_state:
|
||||
return res, (deterministic_seed, text, voice_samples, conditioning_latents)
|
||||
else:
|
||||
return res
|
||||
|
||||
def deterministic_state(self, seed=None):
|
||||
"""
|
||||
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
|
||||
|
|
|
|||
|
|
@ -129,8 +129,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
|
||||
# Create embedding
|
||||
mel_len = self.cached_mel_emb.shape[1]
|
||||
|
||||
if input_ids.shape[1] != 1:
|
||||
text_inputs = input_ids[:, mel_len:]
|
||||
text_emb = self.embeddings(text_inputs)
|
||||
|
|
@ -147,6 +149,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
emb = emb + self.text_pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - mel_len, attention_mask.device
|
||||
)
|
||||
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
past_key_values=past_key_values,
|
||||
|
|
@ -159,8 +163,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=return_dict
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
# Set device for model parallelism
|
||||
|
|
@ -441,16 +446,18 @@ class UnifiedVoice(nn.Module):
|
|||
else:
|
||||
return first_logits
|
||||
|
||||
def get_conditioning(self, speech_conditioning_input):
|
||||
def get_conditioning(self, speech_conditioning_input, return_average=True):
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1)
|
||||
if return_average:
|
||||
conds = conds.mean(dim=1)
|
||||
return conds
|
||||
|
||||
|
||||
def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
|
||||
return_latent=False, clip_inputs=True):
|
||||
"""
|
||||
|
|
@ -485,7 +492,11 @@ class UnifiedVoice(nn.Module):
|
|||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
|
||||
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
if speech_conditioning_latent.dim() == 2:
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
else:
|
||||
conds = speech_conditioning_latent
|
||||
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
|
|
@ -532,18 +543,38 @@ class UnifiedVoice(nn.Module):
|
|||
)
|
||||
gpt_inputs[:, -1] = self.start_mel_token
|
||||
return gpt_inputs
|
||||
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
|
||||
def inference_speech(self, speech_conditioning_latents, text_inputs, input_tokens=None, num_return_sequences=1,
|
||||
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
# Optionally expand the speech conditioning latent for concatenation to text embedding.
|
||||
# Allow for different speech conditioning latents to be passed per sample.
|
||||
if speech_conditioning_latents.dim() == 2:
|
||||
conds = speech_conditioning_latents.unsqueeze(1)
|
||||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
|
||||
else:
|
||||
assert speech_conditioning_latents.shape[1] == num_return_sequences, \
|
||||
("If the number of speech conditioning latents passed is > 1, they must be equal to the "
|
||||
"autoregressive batch size")
|
||||
conds = speech_conditioning_latents
|
||||
|
||||
# Here, we have num_return_sequences unique VQ mels (different conditional VQ mel for each)
|
||||
emb = torch.cat([torch.swapaxes(conds,0,1),
|
||||
text_emb.repeat(num_return_sequences,1,1)], dim=1)
|
||||
|
||||
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
|
||||
|
||||
# TODO: Resolve below adjustment. When might emb.shape != 1?
|
||||
# fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
|
||||
# device=text_inputs.device)
|
||||
|
||||
fake_inputs = torch.full((1, 1 + emb.shape[1],), fill_value=1, dtype=torch.long,
|
||||
device=text_inputs.device)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
trunc_index = fake_inputs.shape[1]
|
||||
|
|
@ -554,12 +585,19 @@ class UnifiedVoice(nn.Module):
|
|||
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
|
||||
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
|
||||
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
|
||||
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
|
||||
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
|
||||
|
||||
# Pre-expansion of inputs - this simply expands inputs into the number of input sequences
|
||||
inputs = self.inference_model._expand_inputs_for_generation(expand_size=num_return_sequences,
|
||||
input_ids=inputs,
|
||||
**hf_generate_kwargs)[0]
|
||||
|
||||
#print(f'GENERATE KWARGS: {hf_generate_kwargs}')
|
||||
gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||
max_length=max_length, logits_processor=logits_processor,
|
||||
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
||||
num_return_sequences=1, **hf_generate_kwargs)
|
||||
|
||||
return gen[:, trunc_index:]
|
||||
|
||||
def get_generator(self, fake_inputs, **hf_generate_kwargs):
|
||||
|
|
|
|||
|
|
@ -108,6 +108,9 @@ class CVVP(nn.Module):
|
|||
mel_input,
|
||||
return_loss=False
|
||||
):
|
||||
|
||||
#print(f'MEL COND SHAPE:{mel_cond.shape}')
|
||||
#print(f'MEL IN SHAPE: {mel_input.shape}')
|
||||
cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
|
||||
enc_cond = self.conditioning_transformer(cond_emb)
|
||||
cond_latents = self.to_conditioning_latent(enc_cond)
|
||||
|
|
|
|||
|
|
@ -219,14 +219,21 @@ class DiffusionTts(nn.Module):
|
|||
}
|
||||
return groups
|
||||
|
||||
def get_conditioning(self, conditioning_input):
|
||||
def get_conditioning(self, conditioning_input, return_average=True):
|
||||
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
|
||||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
conds = conds.mean(dim=-1)
|
||||
diff_context = self.contextual_embedder(speech_conditioning_input[:, j])
|
||||
if not return_average:
|
||||
# We must still average across the last dim per sample (we don't average cross all samples)
|
||||
diff_context = diff_context.mean(dim=-1)
|
||||
conds.append(diff_context)
|
||||
if return_average:
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
conds = conds.mean(dim=-1)
|
||||
else:
|
||||
conds = torch.stack(conds,dim=1)
|
||||
return conds
|
||||
|
||||
def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
from scipy.io.wavfile import read
|
||||
|
||||
from tortoise.utils.stft import STFT
|
||||
|
||||
from tortoise.utils.misc_helpers import Timer
|
||||
|
||||
BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../voices')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue