Mirror of https://github.com/neonbjb/tortoise-tts.git (synced 2026-04-05 06:25:39 +00:00)
Add support for extracting and feeding conditioning latents directly into the model
- Adds a new script and API endpoints for doing this
- Reworks autoregressive and diffusion models so that the conditioning is computed separately (which will actually provide a mild performance boost)
- Updates README

This is untested. Need to do the following manual tests (and someday write unit tests for this behemoth before it becomes a problem..):
1) Does get_conditioning_latents.py work?
2) Can I feed those latents back into the model by creating a new voice?
3) Can I still mix and match voices (both with conditioning latents and normal voices) with read.py?
This commit is contained in:
parent a8264f5cef
commit 0ffc191408
8 changed files with 165 additions and 78 deletions
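For context, the intended round trip looks roughly like this. A minimal sketch assuming the `TextToSpeech` API from this diff; the import path and the `load_clip` helper are placeholders:

```python
import torch
from api import TextToSpeech  # assumption: module path may differ

tts = TextToSpeech()

# Hypothetical helper: two ~10 second reference clips as 22.05kHz waveform tensors.
voice_samples = [load_clip('clip1.wav'), load_clip('clip2.wav')]

# Extract the (autoregressive, diffusion) conditioning latents once...
conditioning_latents = tts.get_conditioning_latents(voice_samples)

# ...then reuse them for any number of generations, without the reference clips.
wav = tts.tts("Hello world.", voice_samples=None,
              conditioning_latents=conditioning_latents)
```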
@@ -121,23 +121,14 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
     return codes
 
 
-def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_samples, temperature=1, verbose=True):
+def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True):
     """
     Uses the specified diffusion model to convert discrete codes into a spectrogram.
     """
     with torch.no_grad():
-        cond_mels = []
-        for sample in conditioning_samples:
-            # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
-            sample = torchaudio.functional.resample(sample, 22050, 24000)
-            sample = pad_or_truncate(sample, 102400)
-            cond_mel = wav_to_univnet_mel(sample.to(latents.device), do_normalization=False)
-            cond_mels.append(cond_mel)
-        cond_mels = torch.stack(cond_mels, dim=1)
-
         output_seq_len = latents.shape[1] * 4 * 24000 // 22050  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
         output_shape = (latents.shape[0], 100, output_seq_len)
-        precomputed_embeddings = diffusion_model.timestep_independent(latents, cond_mels, output_seq_len, False)
+        precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
 
         noise = torch.randn(output_shape, device=latents.device) * temperature
         mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
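The output length arithmetic above is easy to sanity-check: each latent frame covers 4 mel frames at the codes' 22.05kHz rate, rescaled to the vocoder's 24kHz rate. A quick check of the formula with an assumed latent length:

```python
# Assumed example: 250 latent frames in.
latent_len = 250
output_seq_len = latent_len * 4 * 24000 // 22050
print(output_seq_len)  # 1088 output mel frames
```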
@@ -204,7 +195,7 @@ class TextToSpeech:
         self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
-    def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs):
+    def tts_with_preset(self, text, preset='fast', **kwargs):
         """
         Calls TTS with one of a set of preset generation parameters. Options:
             'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
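Since `voice_samples` is dropped from the preset wrapper's signature, voice arguments now travel through `**kwargs` into `tts()`. A sketch under that assumption:

```python
# Either form should work after this change; voice arguments are
# forwarded to tts() via **kwargs.
wav = tts.tts_with_preset("Hello world.", preset='fast',
                          voice_samples=voice_samples)
wav = tts.tts_with_preset("Hello world.", preset='fast',
                          conditioning_latents=conditioning_latents)
```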
@@ -225,9 +216,43 @@ class TextToSpeech:
             'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},
         }
         kwargs.update(presets[preset])
-        return self.tts(text, voice_samples, **kwargs)
+        return self.tts(text, **kwargs)
+
+    def get_conditioning_latents(self, voice_samples):
+        """
+        Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
+        These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
+        properties.
+        :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
+        """
+        # Normalize to a list before iterating or moving anything to the GPU.
+        if not isinstance(voice_samples, list):
+            voice_samples = [voice_samples]
+        voice_samples = [v.to('cuda') for v in voice_samples]
+
+        auto_conds = []
+        for vs in voice_samples:
+            auto_conds.append(format_conditioning(vs))
+        auto_conds = torch.stack(auto_conds, dim=1)
+        self.autoregressive = self.autoregressive.cuda()
+        auto_latent = self.autoregressive.get_conditioning(auto_conds)
+        self.autoregressive = self.autoregressive.cpu()
+
+        diffusion_conds = []
+        for sample in voice_samples:
+            # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
+            sample = torchaudio.functional.resample(sample, 22050, 24000)
+            sample = pad_or_truncate(sample, 102400)
+            cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
+            diffusion_conds.append(cond_mel)
+        diffusion_conds = torch.stack(diffusion_conds, dim=1)
+
+        self.diffusion = self.diffusion.cuda()
+        diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
+        self.diffusion = self.diffusion.cpu()
+
+        return auto_latent, diffusion_latent
 
-    def tts(self, text, voice_samples, k=1, verbose=True,
+    def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
             # autoregressive generation parameters follow
             num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
             typical_sampling=False, typical_mass=.9,
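The test plan above mentions creating a new voice from these latents. One plausible way to persist the returned tuple, assuming a hypothetical file layout (this diff doesn't specify an on-disk format):

```python
import torch

# Hypothetical path: cache the latents so the reference clips aren't needed again.
auto_latent, diffusion_latent = tts.get_conditioning_latents(voice_samples)
torch.save((auto_latent, diffusion_latent),
           'voices/myvoice/conditioning_latents.pth')
```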
@@ -240,6 +265,9 @@ class TextToSpeech:
         Produces an audio clip of the given text being spoken with the given reference voice.
         :param text: Text to be spoken.
         :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data.
+        :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
+                                     can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
+                                     Conditioning latents can be retrieved via get_conditioning_latents().
         :param k: The number of returned clips. The most likely (as determined by Tortoise's CLVP and CVVP models) clips are returned.
         :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
         ~~AUTOREGRESSIVE KNOBS~~
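Putting the two halves together; a sketch assuming the latents were cached as in the earlier snippet:

```python
# Re-create the voice from cached latents; no reference clips required.
conditioning_latents = torch.load('voices/myvoice/conditioning_latents.pth')

# Per the docstring, voice_samples must be None or the latents are ignored.
wav = tts.tts("Text to be spoken.", voice_samples=None,
              conditioning_latents=conditioning_latents, k=1)
```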
@@ -283,12 +311,10 @@ class TextToSpeech:
         text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
         text = F.pad(text, (0, 1))  # This may not be necessary.
 
-        conds = []
-        if not isinstance(voice_samples, list):
-            voice_samples = [voice_samples]
-        for vs in voice_samples:
-            conds.append(format_conditioning(vs))
-        conds = torch.stack(conds, dim=1)
+        if voice_samples is not None:
+            auto_conditioning, diffusion_conditioning = self.get_conditioning_latents(voice_samples)
+        else:
+            auto_conditioning, diffusion_conditioning = conditioning_latents
 
         diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
@@ -301,7 +327,7 @@ class TextToSpeech:
             if verbose:
                 print("Generating autoregressive samples..")
             for b in tqdm(range(num_batches), disable=not verbose):
-                codes = self.autoregressive.inference_speech(conds, text,
+                codes = self.autoregressive.inference_speech(auto_conditioning, text,
                                                              do_sample=True,
                                                              top_p=top_p,
                                                              temperature=temperature,
@@ -340,16 +366,18 @@ class TextToSpeech:
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
             self.autoregressive = self.autoregressive.cuda()
-            best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
-                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device),
+            best_latents = self.autoregressive(auto_conditioning, text, torch.tensor([text.shape[-1]], device=text.device), best_results,
+                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text.device),
                                                return_latent=True, clip_inputs=False)
             self.autoregressive = self.autoregressive.cpu()
+            del auto_conditioning
 
             if verbose:
                 print("Transforming autoregressive outputs into audio..")
             wav_candidates = []
             self.diffusion = self.diffusion.cuda()
             self.vocoder = self.vocoder.cuda()
+            diffusion_conds =
             for b in range(best_results.shape[0]):
                 codes = best_results[b].unsqueeze(0)
                 latents = best_latents[b].unsqueeze(0)
@@ -365,7 +393,8 @@ class TextToSpeech:
                     latents = latents[:, :k]
                     break
 
-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature, verbose=verbose)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
+                                               temperature=diffusion_temperature, verbose=verbose)
                 wav = self.vocoder.inference(mel)
                 wav_candidates.append(wav.cpu())
             self.diffusion = self.diffusion.cpu()