Added MPS support for do_tts

Jerry-Master 2023-08-06 17:41:30 +02:00
parent 3c4d9c5131
commit b4988c24b3
7 changed files with 169 additions and 68 deletions

View file

@@ -21,3 +21,4 @@ pydantic==1.9.1
 deepspeed==0.8.3
 py-cpuinfo
 hjson
+psutil

View file

@@ -189,6 +189,16 @@ def pick_best_batch_size_for_gpu():
             return 8
         elif availableGb > 7:
             return 4
+    if torch.backends.mps.is_available():
+        import psutil
+        available = psutil.virtual_memory().total
+        availableGb = available / (1024 ** 3)
+        if availableGb > 14:
+            return 16
+        elif availableGb > 10:
+            return 8
+        elif availableGb > 7:
+            return 4
     return 1

 class TextToSpeech:
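Note: the CUDA path of this function sizes the batch off free VRAM (torch.cuda.mem_get_info()), but there is no equivalent query for MPS, so the new branch reads total system RAM via psutil instead; memory on Apple Silicon is unified, so total RAM is the practical ceiling. A standalone sketch of the same heuristic; the function name is illustrative:

    import psutil
    import torch

    def pick_batch_size_sketch() -> int:
        # Apple Silicon GPUs share memory with the CPU, so total RAM is a
        # reasonable proxy for what the MPS backend can allocate.
        if torch.backends.mps.is_available():
            available_gb = psutil.virtual_memory().total / (1024 ** 3)
            for threshold_gb, batch in ((14, 16), (10, 8), (7, 4)):
                if available_gb > threshold_gb:
                    return batch
        return 1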
@@ -212,7 +222,9 @@ class TextToSpeech:
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
         self.enable_redaction = enable_redaction
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if torch.backends.mps.is_available():
+            self.device = torch.device('mps')
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
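Note: as written, the MPS check runs after the CUDA/CPU choice, so MPS takes precedence whenever it is available (in practice CUDA and MPS never coexist on one machine, so this is harmless). A minimal sketch of the resulting selection order; the helper name is illustrative:

    import torch

    def resolve_device() -> torch.device:
        # Mirrors the patched constructor: pick CUDA or CPU first, then let
        # MPS override if the Metal backend is usable.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch.backends.mps.is_available():
            device = torch.device('mps')
        return device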
@@ -255,6 +267,7 @@ class TextToSpeech:
         yield m
         m = model.cpu()

     def load_cvvp(self):
         """Load CVVP model."""
         self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
@@ -410,6 +423,7 @@ class TextToSpeech:
         calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
         if verbose:
             print("Generating autoregressive samples..")
+        if not torch.backends.mps.is_available():
             with self.temporary_cuda(self.autoregressive
             ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
                 for b in tqdm(range(num_batches), disable=not verbose):
@@ -425,10 +439,27 @@ class TextToSpeech:
                     padding_needed = max_mel_tokens - codes.shape[1]
                     codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                     samples.append(codes)
+        else:
+            with self.temporary_cuda(self.autoregressive) as autoregressive:
+                for b in tqdm(range(num_batches), disable=not verbose):
+                    codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                            do_sample=True,
+                                                            top_p=top_p,
+                                                            temperature=temperature,
+                                                            num_return_sequences=self.autoregressive_batch_size,
+                                                            length_penalty=length_penalty,
+                                                            repetition_penalty=repetition_penalty,
+                                                            max_generate_length=max_mel_tokens,
+                                                            **hf_generate_kwargs)
+                    padding_needed = max_mel_tokens - codes.shape[1]
+                    codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                    samples.append(codes)

         clip_results = []
+        if not torch.backends.mps.is_available():
             with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
-                device_type="cuda", dtype=torch.float16, enabled=self.half
+                device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
             ):
                 if cvvp_amount > 0:
                     if self.cvvp is None:
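Note: the pattern introduced here recurs throughout the file: the float16 torch.autocast context is only entered on the non-MPS path, since autocast coverage on MPS was limited when this was written, and the MPS branch runs an identical body without it. An equivalent, non-duplicating formulation (a sketch, not what the patch does) uses contextlib.nullcontext:

    import contextlib
    import torch

    def maybe_autocast(device: torch.device, half: bool):
        # CUDA: mixed-precision autocast; MPS/CPU: a no-op context manager.
        if device.type == 'cuda':
            return torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half)
        return contextlib.nullcontext()

    # usage sketch:
    #     with maybe_autocast(self.device, self.half):
    #         codes = autoregressive.inference_speech(...)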
@@ -458,6 +489,36 @@ class TextToSpeech:
                 clip_results = torch.cat(clip_results, dim=0)
                 samples = torch.cat(samples, dim=0)
                 best_results = samples[torch.topk(clip_results, k=k).indices]
+        else:
+            with self.temporary_cuda(self.clvp) as clvp:
+                if cvvp_amount > 0:
+                    if self.cvvp is None:
+                        self.load_cvvp()
+                    self.cvvp = self.cvvp.to(self.device)
+                if verbose:
+                    if self.cvvp is None:
+                        print("Computing best candidates using CLVP")
+                    else:
+                        print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                for batch in tqdm(samples, disable=not verbose):
+                    for i in range(batch.shape[0]):
+                        batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                    if cvvp_amount != 1:
+                        clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                    if auto_conds is not None and cvvp_amount > 0:
+                        cvvp_accumulator = 0
+                        for cl in range(auto_conds.shape[1]):
+                            cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                        cvvp = cvvp_accumulator / auto_conds.shape[1]
+                        if cvvp_amount == 1:
+                            clip_results.append(cvvp)
+                        else:
+                            clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
+                    else:
+                        clip_results.append(clvp_out)
+                clip_results = torch.cat(clip_results, dim=0)
+                samples = torch.cat(samples, dim=0)
+                best_results = samples[torch.topk(clip_results, k=k).indices]

         if self.cvvp is not None:
             self.cvvp = self.cvvp.cpu()
         del samples
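Note: both branches score candidates identically. With cvvp_amount = α, each batch gets α·CVVP + (1−α)·CLVP (pure CLVP at α = 0, pure CVVP at α = 1), and torch.topk keeps the k best. A toy illustration with stand-in scores:

    import torch

    alpha = 0.3                    # cvvp_amount
    clvp_scores = torch.rand(16)   # stand-ins for the model outputs
    cvvp_scores = torch.rand(16)

    combined = cvvp_scores * alpha + clvp_scores * (1 - alpha)
    best = torch.topk(combined, k=4).indices  # indices of the winning samples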
@@ -465,20 +526,31 @@ class TextToSpeech:
         # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
         # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
         # results, but will increase memory usage.
+        if not torch.backends.mps.is_available():
             with self.temporary_cuda(
                 self.autoregressive
             ) as autoregressive, torch.autocast(
-                device_type="cuda", dtype=torch.float16, enabled=self.half
+                device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
             ):
                 best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
                                               torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
                                               return_latent=True, clip_inputs=False)
             del auto_conditioning
+        else:
+            with self.temporary_cuda(
+                self.autoregressive
+            ) as autoregressive:
+                best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                              torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                              torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                              return_latent=True, clip_inputs=False)
+            del auto_conditioning

         if verbose:
             print("Transforming autoregressive outputs into audio..")
         wav_candidates = []
+        if not torch.backends.mps.is_available():
             with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
                 self.vocoder
             ) as vocoder:
@@ -496,7 +568,27 @@ class TextToSpeech:
                     if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                         latents = latents[:, :k]
                         break
+                    mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
+                                                   verbose=verbose)
+                    wav = vocoder.inference(mel)
+                    wav_candidates.append(wav.cpu())
+        else:
+            diffusion, vocoder = self.diffusion, self.vocoder
+            diffusion_conditioning = diffusion_conditioning.cpu()
+            for b in range(best_results.shape[0]):
+                codes = best_results[b].unsqueeze(0).cpu()
+                latents = best_latents[b].unsqueeze(0).cpu()
+                # Find the first occurrence of the "calm" token and trim the codes to that.
+                ctokens = 0
+                for k in range(codes.shape[-1]):
+                    if codes[0, k] == calm_token:
+                        ctokens += 1
+                    else:
+                        ctokens = 0
+                    if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                        latents = latents[:, :k]
+                        break
                 mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
                                                verbose=verbose)
                 wav = vocoder.inference(mel)
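Note: on the MPS path the diffusion model and vocoder are fed CPU tensors (the conditioning, codes and latents are all moved with .cpu()) rather than running on the MPS device itself. The calm-token scan in both branches trims the latents once more than 8 consecutive silence tokens (token 83) appear, leaving the diffusion model a short tail to terminate on. A self-contained sketch of that scan; the helper name is illustrative:

    import torch

    CALM_TOKEN = 83  # token for coded silence, per the comment above

    def trim_at_calm_run(codes: torch.Tensor, latents: torch.Tensor, breathing_room: int = 8) -> torch.Tensor:
        # codes: (1, T) mel codes; latents: (1, T, ...) aligned latents.
        run = 0
        for k in range(codes.shape[-1]):
            run = run + 1 if codes[0, k] == CALM_TOKEN else 0
            if run > breathing_room:
                return latents[:, :k]  # keep a few calm tokens as the tail
        return latents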

View file

@@ -47,7 +47,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.cached_mel_emb = None

     def parallelize(self, device_map=None):
         self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+            get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
             if device_map is None
             else device_map
         )
@@ -62,6 +62,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = self.lm_head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()

     def get_output_embeddings(self):
         return self.lm_head
@@ -162,6 +164,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         # Set device for model parallelism
         if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
+            if torch.backends.mps.is_available():
+                self.to(self.transformer.first_device)
+            else:
+                torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)
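Note: torch.mps.empty_cache() (used in the deparallelize hunk above) is the MPS counterpart of torch.cuda.empty_cache() in recent PyTorch releases. A backend-agnostic sketch; the helper name is illustrative:

    import torch

    def empty_device_cache() -> None:
        # Return cached allocator blocks to the OS on whichever backend is live.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()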

View file

@@ -302,8 +302,11 @@ class DiffusionTts(nn.Module):
                     unused_params.extend(list(lyr.parameters()))
                 else:
                     # First and last blocks will have autocast disabled for improved precision.
-                    with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
-                        x = lyr(x, time_emb)
+                    if not torch.backends.mps.is_available():
+                        with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
+                            x = lyr(x, time_emb)
+                    else:
+                        x = lyr(x, time_emb)

         x = x.float()
         out = self.out(x)
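Note: the surrounding code already gates precision per block (enabled=self.enable_fp16 and i != 0 keeps the first block in full precision); the patch adds a second gate that skips autocast entirely on MPS. A compact sketch of the combined behaviour under those assumptions:

    import torch

    def run_blocks(blocks, x: torch.Tensor, time_emb: torch.Tensor, enable_fp16: bool) -> torch.Tensor:
        for i, lyr in enumerate(blocks):
            if torch.backends.mps.is_available():
                # MPS: run in default precision, no autocast.
                x = lyr(x, time_emb)
            else:
                # Elsewhere: fp16 autocast for every block except the first.
                with torch.autocast(x.device.type, enabled=enable_fp16 and i != 0):
                    x = lyr(x, time_emb)
        return x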

View file

@@ -180,7 +180,7 @@ class TacotronSTFT(torch.nn.Module):
         return mel_output

-def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
+def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
     stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
     stft = stft.to(device)
     mel = stft.mel_spectrogram(wav)

View file

@@ -1244,7 +1244,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
                        dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
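Note: this change matters on MPS because the backend does not support float64 tensors, and NumPy arrays default to float64, so th.from_numpy(arr) would produce a double tensor that cannot be moved to an MPS device. Casting on the NumPy side also makes the trailing .float() redundant. A minimal reproduction of the pattern:

    import numpy as np
    import torch

    arr = np.linspace(0.0, 1.0, 10)     # float64 by default
    timesteps = torch.tensor([0, 3, 7])

    # Old: torch.from_numpy(arr) -> float64 tensor, rejected by MPS.
    # New: cast first, index after; no .float() needed.
    res = torch.from_numpy(arr.astype(np.float32))[timesteps]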

View file

@@ -49,7 +49,7 @@ class Wav2VecAlignment:
     """
     Uses wav2vec2 to perform audio<->text alignment.
     """
-    def __init__(self, device='cuda'):
+    def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
         self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
         self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')