Added MPS support for do_tts

This commit is contained in:
Jerry-Master 2023-08-06 17:41:30 +02:00
parent 3c4d9c5131
commit b4988c24b3
7 changed files with 169 additions and 68 deletions

View file

@ -21,3 +21,4 @@ pydantic==1.9.1
deepspeed==0.8.3
py-cpuinfo
hjson
psutil

View file

@ -189,6 +189,16 @@ def pick_best_batch_size_for_gpu():
return 8
elif availableGb > 7:
return 4
if torch.backends.mps.is_available():
import psutil
available = psutil.virtual_memory().total
availableGb = available / (1024 ** 3)
if availableGb > 14:
return 16
elif availableGb > 10:
return 8
elif availableGb > 7:
return 4
return 1
class TextToSpeech:
@ -212,7 +222,9 @@ class TextToSpeech:
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
if torch.backends.mps.is_available():
self.device = torch.device('mps')
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
@ -254,6 +266,7 @@ class TextToSpeech:
m = model.to(self.device)
yield m
m = model.cpu()
def load_cvvp(self):
"""Load CVVP model."""
@ -410,54 +423,102 @@ class TextToSpeech:
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
if verbose:
print("Generating autoregressive samples..")
with self.temporary_cuda(self.autoregressive
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens,
**hf_generate_kwargs)
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.autoregressive
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens,
**hf_generate_kwargs)
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
else:
with self.temporary_cuda(self.autoregressive) as autoregressive:
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens,
**hf_generate_kwargs)
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
clip_results = []
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=self.half
):
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.to(self.device)
if verbose:
if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
if cvvp_amount != 1:
clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
):
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.to(self.device)
if verbose:
if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
else:
clip_results.append(clvp_out)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
if cvvp_amount != 1:
clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
else:
clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
else:
clip_results.append(clvp_out)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
else:
with self.temporary_cuda(self.clvp) as clvp:
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.to(self.device)
if verbose:
if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
if cvvp_amount != 1:
clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
else:
clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
else:
clip_results.append(clvp_out)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
if self.cvvp is not None:
self.cvvp = self.cvvp.cpu()
del samples
@ -465,26 +526,58 @@ class TextToSpeech:
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
with self.temporary_cuda(
self.autoregressive
) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=self.half
):
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False)
del auto_conditioning
if not torch.backends.mps.is_available():
with self.temporary_cuda(
self.autoregressive
) as autoregressive, torch.autocast(
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
):
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False)
del auto_conditioning
else:
with self.temporary_cuda(
self.autoregressive
) as autoregressive:
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False)
del auto_conditioning
if verbose:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
self.vocoder
) as vocoder:
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
self.vocoder
) as vocoder:
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
# Find the first occurrence of the "calm" token and trim the codes to that.
ctokens = 0
for k in range(codes.shape[-1]):
if codes[0, k] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
break
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
verbose=verbose)
wav = vocoder.inference(mel)
wav_candidates.append(wav.cpu())
else:
diffusion, vocoder = self.diffusion, self.vocoder
diffusion_conditioning = diffusion_conditioning.cpu()
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
codes = best_results[b].unsqueeze(0).cpu()
latents = best_latents[b].unsqueeze(0).cpu()
# Find the first occurrence of the "calm" token and trim the codes to that.
ctokens = 0
@ -496,7 +589,6 @@ class TextToSpeech:
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
break
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
verbose=verbose)
wav = vocoder.inference(mel)

View file

@ -47,7 +47,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
self.cached_mel_emb = None
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
if device_map is None
else device_map
)
@ -62,6 +62,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
def get_output_embeddings(self):
return self.lm_head
@ -162,7 +164,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
if torch.backends.mps.is_available():
self.to(self.transformer.first_device)
else:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)

View file

@ -302,7 +302,10 @@ class DiffusionTts(nn.Module):
unused_params.extend(list(lyr.parameters()))
else:
# First and last blocks will have autocast disabled for improved precision.
with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
if not torch.backends.mps.is_available():
with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
x = lyr(x, time_emb)
else:
x = lyr(x, time_emb)
x = x.float()

View file

@ -180,7 +180,7 @@ class TacotronSTFT(torch.nn.Module):
return mel_output
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
stft = stft.to(device)
mel = stft.mel_spectrogram(wav)

View file

@ -1244,7 +1244,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)

View file

@ -49,7 +49,7 @@ class Wav2VecAlignment:
"""
Uses wav2vec2 to perform audio<->text alignment.
"""
def __init__(self, device='cuda'):
def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')