mirror of https://github.com/neonbjb/tortoise-tts.git
synced 2026-03-23 13:44:38 +01:00
Port do_tts.py and related logic for DirectML
This commit is contained in:
parent 80f89987a5
commit 6f291d35ea
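The core of the commit is a four-way device fallback that the hunks below repeat at each model stage. As a minimal standalone sketch (not part of the commit itself), the selection order is CUDA, then MPS, then DirectML, then CPU, assuming the torch-directml package is installed:

import torch
import torch_directml

def pick_device():
    # Preference order used throughout this commit.
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch_directml.is_available():
        return torch_directml.device(0)  # first DirectML adapter
    return torch.device('cpu')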
requirements.txt
@@ -23,3 +23,4 @@ py-cpuinfo
 hjson
 psutil
 sounddevice
+torch-directml
0   scripts/tortoise_tts.py (Executable file → Normal file)
129 tortoise/api.py
tortoise/api.py
@@ -4,7 +4,7 @@ import uuid
 from time import time
 from urllib import request
 
-import torch
+import torch, torch_directml
 import torch.nn.functional as F
 import progressbar
 import torchaudio
@@ -70,10 +70,18 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
                            conditioning_free=cond_free, conditioning_free_k=cond_free_k)
 
 
-def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'):
+def format_conditioning(clip, cond_length=132300, device=''):
     """
     Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
     """
+    if device == '':
+        device = 'cpu'
+        if torch.cuda.is_available():
+            device = 'cuda'
+        elif torch.backends.mps.is_available():
+            device = 'mps'
+
     gap = clip.shape[-1] - cond_length
     if gap < 0:
         clip = F.pad(clip, pad=(0, abs(gap)))
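Hypothetical call sites under the new signature (not part of this commit): an empty device triggers the CPU/CUDA/MPS auto-selection above, while an explicit name pins it.

mel = format_conditioning(clip)                    # auto-selects cuda/mps/cpu
mel_cpu = format_conditioning(clip, device='cpu')  # explicit override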
@@ -169,6 +177,11 @@ def pick_best_batch_size_for_gpu():
         return 8
     elif availableGb > 7:
         return 4
+
+    # DirectML is available, but we don't know how much memory is available.
+    if torch_directml.is_available():
+        return 16
+
     return 1
 
 
 class TextToSpeech:
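The hard-coded 16 exists because DirectML offers no free-memory query comparable to CUDA's. A sketch of that asymmetry, assuming a CUDA build is present (torch.cuda.mem_get_info is the standard API; the DirectML branch has no equivalent, hence the guess):

import torch
import torch_directml

def available_gb():
    if torch.cuda.is_available():
        free_bytes, _total = torch.cuda.mem_get_info()  # free VRAM on current device
        return free_bytes / 1024**3
    if torch_directml.is_available():
        return None  # no public free-memory query; caller must assume a batch size
    return 0.0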
@@ -194,9 +207,16 @@ class TextToSpeech:
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
         self.enable_redaction = enable_redaction
-        self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
-        if torch.backends.mps.is_available():
-            self.device = torch.device('mps')
+
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda')
+        elif torch.backends.mps.is_available():
+            self.device = torch.device('mps')
+        elif torch_directml.is_available():
+            self.device = torch_directml.device(0)
+        else:
+            self.device = torch.device('cpu')
+
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
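Unlike 'cuda' or 'mps', DirectML devices are not addressable by a backend string, so torch_directml.device(0) is stored once and reused for .to() calls. A quick check, assuming torch-directml is installed:

import torch
import torch_directml

dml = torch_directml.device(0)   # opaque torch.device, typically privateuseone:0
print(torch_directml.device_count(), torch_directml.device_name(0))
x = torch.randn(2, 2).to(dml)    # tensors move with .to() as with any device
print(x.device)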
@@ -240,7 +260,7 @@ class TextToSpeech:
     def temporary_cuda(self, model):
         m = model.to(self.device)
         yield m
-        m = model.cpu()
+        # m = model.cpu()
 
 
     def load_cvvp(self):
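With the offload commented out, models now stay resident on the accelerator between calls, trading device memory for reload time. A self-contained sketch of the context-manager shape (temporary_cuda is a @contextmanager in api.py; the names here are illustrative):

from contextlib import contextmanager
import torch

@contextmanager
def temporary_device(model, device):
    m = model.to(device)   # move in on entry
    yield m
    # m = model.cpu()      # offload on exit, disabled by this commit

lin = torch.nn.Linear(4, 4)
with temporary_device(lin, torch.device('cpu')) as m:
    y = m(torch.randn(1, 4))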
@@ -379,6 +399,7 @@ class TextToSpeech:
         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
+
         auto_conds = None
         if voice_samples is not None:
             auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True)
@@ -398,7 +419,13 @@ class TextToSpeech:
         calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
         if verbose:
             print("Generating autoregressive samples..")
-        if not torch.backends.mps.is_available():
+            print(f"CUDA Available: \t{torch.cuda.is_available()}")
+            print(f"MPS Available: \t\t{torch.backends.mps.is_available()}")
+            print(f"DirectML Available: \t{torch_directml.is_available()}")
+            print(f"Autoregressive Batch Size: {self.autoregressive_batch_size}")
+
+        # CUDA
+        if torch.cuda.is_available():
             with self.temporary_cuda(self.autoregressive
             ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
                 for b in tqdm(range(num_batches), disable=not verbose):
@@ -414,6 +441,27 @@ class TextToSpeech:
                     padding_needed = max_mel_tokens - codes.shape[1]
                     codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                     samples.append(codes)
+        # MPS
+        elif torch.backends.mps.is_available():
+            with self.temporary_cuda(self.autoregressive
+            ) as autoregressive, torch.autocast(device_type="mps", dtype=torch.float16, enabled=self.half):
+                for b in tqdm(range(num_batches), disable=not verbose):
+                    codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                            do_sample=True,
+                                                            top_p=top_p,
+                                                            temperature=temperature,
+                                                            num_return_sequences=self.autoregressive_batch_size,
+                                                            length_penalty=length_penalty,
+                                                            repetition_penalty=repetition_penalty,
+                                                            max_generate_length=max_mel_tokens,
+                                                            **hf_generate_kwargs)
+                    padding_needed = max_mel_tokens - codes.shape[1]
+                    codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                    samples.append(codes)
+
+        # CPU
+        # DirectML doesn't support autocast for now.
+        # https://github.com/microsoft/DirectML/issues/454#issuecomment-1703862192
+        else:
             with self.temporary_cuda(self.autoregressive) as autoregressive:
                 for b in tqdm(range(num_batches), disable=not verbose):
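Because DirectML lacks autocast support (see the linked issue), the CPU/DirectML branch runs at full precision. One way to avoid duplicating the loop body, sketched here as an alternative rather than what the commit does, is a conditional no-op context:

import contextlib
import torch

def maybe_autocast(device_type, enabled):
    # Autocast only where the backend supports it; otherwise a no-op context.
    if device_type in ('cuda', 'mps'):
        return torch.autocast(device_type=device_type, dtype=torch.float16, enabled=enabled)
    return contextlib.nullcontext()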
@@ -432,9 +480,10 @@ class TextToSpeech:
 
         clip_results = []
 
-        if not torch.backends.mps.is_available():
+        # CUDA
+        if torch.cuda.is_available():
             with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
-                device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                device_type="cuda", dtype=torch.float16, enabled=self.half
             ):
                 if cvvp_amount > 0:
                     if self.cvvp is None:
@@ -464,6 +513,43 @@ class TextToSpeech:
             clip_results = torch.cat(clip_results, dim=0)
             samples = torch.cat(samples, dim=0)
             best_results = samples[torch.topk(clip_results, k=k).indices]
+
+        # MPS
+        if torch.backends.mps.is_available():
+            with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
+                device_type="mps", dtype=torch.float16, enabled=self.half
+            ):
+                if cvvp_amount > 0:
+                    if self.cvvp is None:
+                        self.load_cvvp()
+                    self.cvvp = self.cvvp.to(self.device)
+                if verbose:
+                    if self.cvvp is None:
+                        print("Computing best candidates using CLVP")
+                    else:
+                        print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                for batch in tqdm(samples, disable=not verbose):
+                    for i in range(batch.shape[0]):
+                        batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                    if cvvp_amount != 1:
+                        clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                    if auto_conds is not None and cvvp_amount > 0:
+                        cvvp_accumulator = 0
+                        for cl in range(auto_conds.shape[1]):
+                            cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                        cvvp = cvvp_accumulator / auto_conds.shape[1]
+                        if cvvp_amount == 1:
+                            clip_results.append(cvvp)
+                        else:
+                            clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
+                    else:
+                        clip_results.append(clvp_out)
+                clip_results = torch.cat(clip_results, dim=0)
+                samples = torch.cat(samples, dim=0)
+                best_results = samples[torch.topk(clip_results, k=k).indices]
+
+        # CPU
+        # DirectML does not support autocast for now
+        else:
             with self.temporary_cuda(self.clvp) as clvp:
                 if cvvp_amount > 0:
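For reference, the candidate score in both branches is a linear blend of the two models, score = cvvp_amount * cvvp + (1 - cvvp_amount) * clvp. With illustrative numbers:

cvvp_amount = 0.25
clvp_out, cvvp = 0.8, 0.4  # hypothetical per-candidate scores
score = cvvp * cvvp_amount + clvp_out * (1 - cvvp_amount)  # 0.1 + 0.6 = 0.7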
@@ -501,21 +587,24 @@ class TextToSpeech:
         # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
         # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
         # results, but will increase memory usage.
-        if not torch.backends.mps.is_available():
-            with self.temporary_cuda(
-                self.autoregressive
-            ) as autoregressive, torch.autocast(
-                device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
-            ):
+        if torch.cuda.is_available():
+            with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
+                device_type="cuda", dtype=torch.float16, enabled=self.half):
                 best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
                                               torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
                                               return_latent=True, clip_inputs=False)
             del auto_conditioning
+        elif torch.backends.mps.is_available():
+            with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
+                device_type="mps", dtype=torch.float16, enabled=self.half):
+                best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                              torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                              torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                              return_latent=True, clip_inputs=False)
+            del auto_conditioning
         else:
-            with self.temporary_cuda(
-                self.autoregressive
-            ) as autoregressive:
+            with self.temporary_cuda(self.autoregressive) as autoregressive:
                 best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
                                               torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
@@ -525,10 +614,8 @@ class TextToSpeech:
         if verbose:
             print("Transforming autoregressive outputs into audio..")
         wav_candidates = []
-        if not torch.backends.mps.is_available():
-            with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
-                self.vocoder
-            ) as vocoder:
+        if torch.cuda.is_available() or torch.backends.mps.is_available() or torch_directml.is_available():
+            with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(self.vocoder) as vocoder:
                 for b in range(best_results.shape[0]):
                     codes = best_results[b].unsqueeze(0)
                     latents = best_latents[b].unsqueeze(0)
tortoise/models/diffusion_decoder.py
@@ -302,7 +302,7 @@ class DiffusionTts(nn.Module):
                 unused_params.extend(list(lyr.parameters()))
             else:
                 # First and last blocks will have autocast disabled for improved precision.
-                if not torch.backends.mps.is_available():
+                if torch.cuda.is_available():
                     with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
                         x = lyr(x, time_emb)
                 else:
tortoise/models/vocoder.py
@@ -1,4 +1,4 @@
-import torch
+import torch, torch_directml
 import torch.nn as nn
 import torch.nn.functional as F
@@ -208,10 +208,16 @@ class LVCBlock(torch.nn.Module):
         x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
 
         o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
-        o = o.to(memory_format=torch.channels_last_3d)
-        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
-        o = o + bias
-        o = o.contiguous().view(batch, out_channels, -1)
+
+        if torch_directml.is_available():
+            o = o + bias.unsqueeze(-1).unsqueeze(-1)
+            o = o.contiguous().view(batch, out_channels, -1)
+
+        else:
+            o = o.to(memory_format=torch.channels_last_3d)
+            bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
+            o = o + bias
+            o = o.contiguous().view(batch, out_channels, -1)
 
         return o
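The guarded branch sidesteps torch.channels_last_3d on DirectML; both paths compute the same values, since memory_format only changes the physical layout. A minimal CPU-side illustration (the DirectML failure itself is taken from this commit, not reproduced here):

import torch

o = torch.randn(2, 3, 4, 5, 6)  # 5-D, like the einsum output above
bias = torch.randn(2, 3, 4)

plain = o + bias.unsqueeze(-1).unsqueeze(-1)
cl3d = (o.to(memory_format=torch.channels_last_3d)
        + bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d))
assert torch.allclose(plain, cl3d)  # same values, different strides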