Mirror of https://github.com/neonbjb/tortoise-tts.git (synced 2025-12-06 07:12:00 +01:00)
Added MPS support for do_tts
This commit is contained in: parent 3c4d9c5131, commit b4988c24b3
@@ -21,3 +21,4 @@ pydantic==1.9.1
 deepspeed==0.8.3
 py-cpuinfo
 hjson
+psutil
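(psutil is the new dependency behind the system-memory query used by the MPS batch-size heuristic added to tortoise/api.py below.)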
tortoise/api.py (216 changed lines)
@@ -189,6 +189,16 @@ def pick_best_batch_size_for_gpu():
             return 8
         elif availableGb > 7:
             return 4
+    if torch.backends.mps.is_available():
+        import psutil
+        available = psutil.virtual_memory().total
+        availableGb = available / (1024 ** 3)
+        if availableGb > 14:
+            return 16
+        elif availableGb > 10:
+            return 8
+        elif availableGb > 7:
+            return 4
     return 1
 
 class TextToSpeech:
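For context on the hunk above: the existing CUDA path appears to size the autoregressive batch from free GPU memory, but MPS has no equivalent query, so the new branch uses total system RAM (unified memory on Apple silicon) via psutil. A minimal standalone sketch of the same heuristic; the function name pick_batch_size_for_mps and its default argument are illustrative, not from the repo:

import psutil
import torch

def pick_batch_size_for_mps(default=1):
    """Rough batch-size heuristic for Apple-silicon machines.

    MPS has no torch.cuda.mem_get_info() equivalent, so total system RAM
    (shared with the GPU) is used as a proxy for available GPU memory.
    """
    if not torch.backends.mps.is_available():
        return default
    available_gb = psutil.virtual_memory().total / (1024 ** 3)
    if available_gb > 14:
        return 16
    elif available_gb > 10:
        return 8
    elif available_gb > 7:
        return 4
    return default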
@@ -212,7 +222,9 @@ class TextToSpeech:
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
         self.enable_redaction = enable_redaction
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
+        if torch.backends.mps.is_available():
+            self.device = torch.device('mps')
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
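The hunk above leaves the CUDA/CPU selection in place and then overrides it with MPS when available, so MPS takes precedence if both backends somehow report available. A small standalone sketch of the same selection order (pick_torch_device is my name for it, not the repo's):

import torch

def pick_torch_device():
    """Replicate the selection order above: CUDA/CPU first, then MPS overrides."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.backends.mps.is_available():
        device = torch.device('mps')
    return device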
@@ -254,6 +266,7 @@ class TextToSpeech:
             m = model.to(self.device)
             yield m
             m = model.cpu()
 
     def load_cvvp(self):
         """Load CVVP model."""
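For orientation: temporary_cuda appears to be a context manager that moves a model onto self.device for the duration of a with-block and back to CPU afterwards, which is why the MPS changes below mostly wrap calls to it. A minimal sketch of that pattern (temporary_device is my name; the repo's version may differ in details):

from contextlib import contextmanager

import torch

@contextmanager
def temporary_device(model, device):
    """Move a model to `device` for the duration of a with-block, then back to CPU."""
    model = model.to(device)
    try:
        yield model
    finally:
        model.cpu()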
@@ -410,54 +423,102 @@ class TextToSpeech:
             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
             if verbose:
                 print("Generating autoregressive samples..")
-            with self.temporary_cuda(self.autoregressive
-            ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
-                for b in tqdm(range(num_batches), disable=not verbose):
-                    codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
-                                                            do_sample=True,
-                                                            top_p=top_p,
-                                                            temperature=temperature,
-                                                            num_return_sequences=self.autoregressive_batch_size,
-                                                            length_penalty=length_penalty,
-                                                            repetition_penalty=repetition_penalty,
-                                                            max_generate_length=max_mel_tokens,
-                                                            **hf_generate_kwargs)
-                    padding_needed = max_mel_tokens - codes.shape[1]
-                    codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
-                    samples.append(codes)
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(self.autoregressive
+                ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
+                    for b in tqdm(range(num_batches), disable=not verbose):
+                        codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                                do_sample=True,
+                                                                top_p=top_p,
+                                                                temperature=temperature,
+                                                                num_return_sequences=self.autoregressive_batch_size,
+                                                                length_penalty=length_penalty,
+                                                                repetition_penalty=repetition_penalty,
+                                                                max_generate_length=max_mel_tokens,
+                                                                **hf_generate_kwargs)
+                        padding_needed = max_mel_tokens - codes.shape[1]
+                        codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                        samples.append(codes)
+            else:
+                with self.temporary_cuda(self.autoregressive) as autoregressive:
+                    for b in tqdm(range(num_batches), disable=not verbose):
+                        codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                                do_sample=True,
+                                                                top_p=top_p,
+                                                                temperature=temperature,
+                                                                num_return_sequences=self.autoregressive_batch_size,
+                                                                length_penalty=length_penalty,
+                                                                repetition_penalty=repetition_penalty,
+                                                                max_generate_length=max_mel_tokens,
+                                                                **hf_generate_kwargs)
+                        padding_needed = max_mel_tokens - codes.shape[1]
+                        codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                        samples.append(codes)
 
             clip_results = []
-            with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
-                device_type="cuda", dtype=torch.float16, enabled=self.half
-            ):
-                if cvvp_amount > 0:
-                    if self.cvvp is None:
-                        self.load_cvvp()
-                    self.cvvp = self.cvvp.to(self.device)
-                if verbose:
-                    if self.cvvp is None:
-                        print("Computing best candidates using CLVP")
-                    else:
-                        print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
-                for batch in tqdm(samples, disable=not verbose):
-                    for i in range(batch.shape[0]):
-                        batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
-                    if cvvp_amount != 1:
-                        clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
-                    if auto_conds is not None and cvvp_amount > 0:
-                        cvvp_accumulator = 0
-                        for cl in range(auto_conds.shape[1]):
-                            cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
-                        cvvp = cvvp_accumulator / auto_conds.shape[1]
-                        if cvvp_amount == 1:
-                            clip_results.append(cvvp)
-                        else:
-                            clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
-                    else:
-                        clip_results.append(clvp_out)
-                clip_results = torch.cat(clip_results, dim=0)
-                samples = torch.cat(samples, dim=0)
-                best_results = samples[torch.topk(clip_results, k=k).indices]
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
+                    device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                ):
+                    if cvvp_amount > 0:
+                        if self.cvvp is None:
+                            self.load_cvvp()
+                        self.cvvp = self.cvvp.to(self.device)
+                    if verbose:
+                        if self.cvvp is None:
+                            print("Computing best candidates using CLVP")
+                        else:
+                            print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                    for batch in tqdm(samples, disable=not verbose):
+                        for i in range(batch.shape[0]):
+                            batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                        if cvvp_amount != 1:
+                            clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                        if auto_conds is not None and cvvp_amount > 0:
+                            cvvp_accumulator = 0
+                            for cl in range(auto_conds.shape[1]):
+                                cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                            cvvp = cvvp_accumulator / auto_conds.shape[1]
+                            if cvvp_amount == 1:
+                                clip_results.append(cvvp)
+                            else:
+                                clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
+                        else:
+                            clip_results.append(clvp_out)
+                    clip_results = torch.cat(clip_results, dim=0)
+                    samples = torch.cat(samples, dim=0)
+                    best_results = samples[torch.topk(clip_results, k=k).indices]
+            else:
+                with self.temporary_cuda(self.clvp) as clvp:
+                    if cvvp_amount > 0:
+                        if self.cvvp is None:
+                            self.load_cvvp()
+                        self.cvvp = self.cvvp.to(self.device)
+                    if verbose:
+                        if self.cvvp is None:
+                            print("Computing best candidates using CLVP")
+                        else:
+                            print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                    for batch in tqdm(samples, disable=not verbose):
+                        for i in range(batch.shape[0]):
+                            batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                        if cvvp_amount != 1:
+                            clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                        if auto_conds is not None and cvvp_amount > 0:
+                            cvvp_accumulator = 0
+                            for cl in range(auto_conds.shape[1]):
+                                cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                            cvvp = cvvp_accumulator / auto_conds.shape[1]
+                            if cvvp_amount == 1:
+                                clip_results.append(cvvp)
+                            else:
+                                clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
+                        else:
+                            clip_results.append(clvp_out)
+                    clip_results = torch.cat(clip_results, dim=0)
+                    samples = torch.cat(samples, dim=0)
+                    best_results = samples[torch.topk(clip_results, k=k).indices]
             if self.cvvp is not None:
                 self.cvvp = self.cvvp.cpu()
             del samples
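The pattern in the hunk above: the CUDA branch keeps the fp16 torch.autocast wrapper, while the MPS branch runs the same loop without autocast, presumably because float16 autocast was not reliably supported on MPS at the time. A hedged sketch of the same conditional-autocast idea; maybe_autocast and the use of contextlib.nullcontext are mine, not the repo's:

import contextlib

import torch

def maybe_autocast(device, enabled):
    """Return a CUDA fp16 autocast context, or a no-op context on other backends (e.g. MPS)."""
    if device.type == "cuda" and enabled:
        return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True)
    return contextlib.nullcontext()

# Usage sketch:
# with maybe_autocast(device, enabled=half):
#     codes = autoregressive.inference_speech(...)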
@@ -465,26 +526,58 @@ class TextToSpeech:
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
-            with self.temporary_cuda(
-                self.autoregressive
-            ) as autoregressive, torch.autocast(
-                device_type="cuda", dtype=torch.float16, enabled=self.half
-            ):
-                best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
-                                              torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
-                                              torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
-                                              return_latent=True, clip_inputs=False)
-                del auto_conditioning
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(
+                    self.autoregressive
+                ) as autoregressive, torch.autocast(
+                    device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                ):
+                    best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                                  torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                                  torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                                  return_latent=True, clip_inputs=False)
+                    del auto_conditioning
+            else:
+                with self.temporary_cuda(
+                    self.autoregressive
+                ) as autoregressive:
+                    best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                                  torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                                  torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                                  return_latent=True, clip_inputs=False)
+                    del auto_conditioning
 
             if verbose:
                 print("Transforming autoregressive outputs into audio..")
             wav_candidates = []
-            with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
-                self.vocoder
-            ) as vocoder:
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
+                    self.vocoder
+                ) as vocoder:
+                    for b in range(best_results.shape[0]):
+                        codes = best_results[b].unsqueeze(0)
+                        latents = best_latents[b].unsqueeze(0)
+
+                        # Find the first occurrence of the "calm" token and trim the codes to that.
+                        ctokens = 0
+                        for k in range(codes.shape[-1]):
+                            if codes[0, k] == calm_token:
+                                ctokens += 1
+                            else:
+                                ctokens = 0
+                            if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                                latents = latents[:, :k]
+                                break
+                        mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
+                                                       verbose=verbose)
+                        wav = vocoder.inference(mel)
+                        wav_candidates.append(wav.cpu())
+            else:
+                diffusion, vocoder = self.diffusion, self.vocoder
+                diffusion_conditioning = diffusion_conditioning.cpu()
                 for b in range(best_results.shape[0]):
-                    codes = best_results[b].unsqueeze(0)
-                    latents = best_latents[b].unsqueeze(0)
+                    codes = best_results[b].unsqueeze(0).cpu()
+                    latents = best_latents[b].unsqueeze(0).cpu()
 
                     # Find the first occurrence of the "calm" token and trim the codes to that.
                     ctokens = 0
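Worth noting about the else branch above: it takes self.diffusion and self.vocoder directly (no temporary_cuda move) and pushes the codes, latents, and diffusion conditioning to CPU, so the diffusion/vocoder stage appears to run on CPU when MPS is selected. The explicit .cpu() calls avoid device-mismatch errors between MPS tensors and CPU modules; a tiny sketch of that co-location rule (call_on is my name, not from the repo):

import torch

def call_on(module, device, *tensors):
    """Move a module and its inputs to the same device before calling it.

    Feeding an MPS tensor to a CPU module (or vice versa) raises a
    device-mismatch error, hence the explicit .cpu() calls above.
    """
    module = module.to(device)
    return module(*(t.to(device) for t in tensors))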
@@ -496,7 +589,6 @@ class TextToSpeech:
                         if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                             latents = latents[:, :k]
                             break
-
                     mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
                                                    verbose=verbose)
                     wav = vocoder.inference(mel)
@@ -47,7 +47,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.cached_mel_emb = None
 
     def parallelize(self, device_map=None):
        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+            get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
            if device_map is None
            else device_map
        )
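The max(1, ...) change above guards the no-CUDA case: torch.cuda.device_count() returns 0 on a CUDA-less (e.g. MPS-only) machine, which would hand get_device_map an empty device range. A tiny illustration:

import torch

n_gpus = torch.cuda.device_count()      # 0 on a machine without CUDA
devices = list(range(max(1, n_gpus)))   # [0] instead of []
print(devices)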
@@ -62,6 +62,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = self.lm_head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
 
     def get_output_embeddings(self):
         return self.lm_head
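torch.mps.empty_cache() only exists in newer PyTorch releases, so code that must run across a range of versions may want to guard the attribute as well as backend availability. A defensive variant of the cleanup above (empty_device_caches is my name, and the version caveat is an assumption worth checking):

import torch

def empty_device_caches():
    """Release cached allocator memory on whichever backend is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available() and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()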
@@ -162,7 +164,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 
         # Set device for model parallelism
         if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
+            if torch.backends.mps.is_available():
+                self.to(self.transformer.first_device)
+            else:
+                torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)
 
         lm_logits = self.lm_head(hidden_states)
@@ -302,7 +302,10 @@ class DiffusionTts(nn.Module):
                 unused_params.extend(list(lyr.parameters()))
             else:
                 # First and last blocks will have autocast disabled for improved precision.
-                with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
-                    x = lyr(x, time_emb)
+                if not torch.backends.mps.is_available():
+                    with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
+                        x = lyr(x, time_emb)
+                else:
+                    x = lyr(x, time_emb)
 
         x = x.float()
@@ -180,7 +180,7 @@ class TacotronSTFT(torch.nn.Module):
         return mel_output
 
 
-def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
+def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
     stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
     stft = stft.to(device)
     mel = stft.mel_spectrogram(wav)
@@ -1244,7 +1244,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
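The rewrite above converts the numpy array to float32 before torch.from_numpy rather than calling .float() on the resulting tensor, presumably because numpy defaults to float64 and MPS does not support float64 tensors, so the cast has to happen before the data is moved to the MPS device. A short illustration:

import numpy as np
import torch

arr = np.linspace(0.0, 1.0, 10)                 # numpy defaults to float64
res = torch.from_numpy(arr.astype(np.float32))  # convert on the numpy side first
if torch.backends.mps.is_available():
    res = res.to('mps')                         # a float64 tensor could not be moved here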
@@ -49,7 +49,7 @@ class Wav2VecAlignment:
     """
     Uses wav2vec2 to perform audio<->text alignment.
     """
-    def __init__(self, device='cuda'):
+    def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
         self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
         self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
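Taken together, these changes let the do_tts path run on Apple silicon. A hedged end-to-end sketch of using the API on such a machine; the 'random' voice, output filename, preset, and the PYTORCH_ENABLE_MPS_FALLBACK variable (which lets operators without an MPS kernel fall back to CPU) are illustrative choices, not part of this commit:

import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")  # optional CPU fallback for missing MPS ops

import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voices

tts = TextToSpeech()  # now selects 'mps' automatically when available (see the __init__ hunk above)
voice_samples, conditioning_latents = load_voices(["random"])
gen = tts.tts_with_preset("Hello from Apple silicon.",
                          voice_samples=voice_samples,
                          conditioning_latents=conditioning_latents,
                          preset="fast")
torchaudio.save("generated.wav", gen.squeeze(0).cpu(), 24000)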