diff --git a/tortoise/api.py b/tortoise/api.py
index e5960a5..caeb7ac 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -23,7 +23,7 @@ from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
 from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
 from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
-
+from contextlib import contextmanager

 pbar = None
 DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
@@ -192,13 +192,13 @@ def pick_best_batch_size_for_gpu():
             return 4
     return 1

-
 class TextToSpeech:
     """
     Main entry point into Tortoise.
     """

-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False,use_deepspeed=False, device=None):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
+                 enable_redaction=True, kv_cache=False, use_deepspeed=False, half=True, device=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -250,7 +250,15 @@ class TextToSpeech:
         # Random latent generators (RLGs) are loaded lazily.
         self.rlg_auto = None
         self.rlg_diffusion = None
-
+    @contextmanager
+    def temporary_cuda(self, model):
+        if self.high_vram:
+            yield model
+        else:
+            m = model.to(self.device)
+            yield m
+            m = model.cpu()
+
     def load_cvvp(self):
         """Load CVVP model."""
         self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
@@ -336,6 +344,7 @@ class TextToSpeech:
             cvvp_amount=.0,
             # diffusion generation parameters follow
             diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
+            half=True,
             **hf_generate_kwargs):
         """
         Produces an audio clip of the given text being spoken with the given reference voice.
@@ -405,55 +414,56 @@ class TextToSpeech:
             num_batches = num_autoregressive_samples // self.autoregressive_batch_size
             stop_mel_token = self.autoregressive.stop_mel_token
             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
-            self.autoregressive = self.autoregressive.to(self.device)
             if verbose:
                 print("Generating autoregressive samples..")
-            for b in tqdm(range(num_batches), disable=not verbose):
-                codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
-                                                             do_sample=True,
-                                                             top_p=top_p,
-                                                             temperature=temperature,
-                                                             num_return_sequences=self.autoregressive_batch_size,
-                                                             length_penalty=length_penalty,
-                                                             repetition_penalty=repetition_penalty,
-                                                             max_generate_length=max_mel_tokens,
-                                                             **hf_generate_kwargs)
-                padding_needed = max_mel_tokens - codes.shape[1]
-                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
-                samples.append(codes)
-            self.autoregressive = self.autoregressive.cpu()
+            with self.temporary_cuda(self.autoregressive) as autoregressive, \
+                    torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half):
+                for b in tqdm(range(num_batches), disable=not verbose):
+                    codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                            do_sample=True,
+                                                            top_p=top_p,
+                                                            temperature=temperature,
+                                                            num_return_sequences=self.autoregressive_batch_size,
+                                                            length_penalty=length_penalty,
+                                                            repetition_penalty=repetition_penalty,
+                                                            max_generate_length=max_mel_tokens,
+                                                            **hf_generate_kwargs)
+                    padding_needed = max_mel_tokens - codes.shape[1]
+                    codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                    samples.append(codes)

             clip_results = []
-            self.clvp = self.clvp.to(self.device)
-            if cvvp_amount > 0:
-                if self.cvvp is None:
-                    self.load_cvvp()
-                self.cvvp = self.cvvp.to(self.device)
-            if verbose:
-                if self.cvvp is None:
-                    print("Computing best candidates using CLVP")
-                else:
-                    print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
-            for batch in tqdm(samples, disable=not verbose):
-                for i in range(batch.shape[0]):
-                    batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
-                if cvvp_amount != 1:
-                    clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
-                if auto_conds is not None and cvvp_amount > 0:
-                    cvvp_accumulator = 0
-                    for cl in range(auto_conds.shape[1]):
-                        cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
-                    cvvp = cvvp_accumulator / auto_conds.shape[1]
-                    if cvvp_amount == 1:
-                        clip_results.append(cvvp)
-                    else:
-                        clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
-                else:
-                    clip_results.append(clvp)
-            clip_results = torch.cat(clip_results, dim=0)
-            samples = torch.cat(samples, dim=0)
-            best_results = samples[torch.topk(clip_results, k=k).indices]
-            self.clvp = self.clvp.cpu()
+            with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
+                device_type="cuda", dtype=torch.float16, enabled=half
+            ):
+                if cvvp_amount > 0:
+                    if self.cvvp is None:
+                        self.load_cvvp()
+                    self.cvvp = self.cvvp.to(self.device)
+                if verbose:
+                    if self.cvvp is None:
+                        print("Computing best candidates using CLVP")
+                    else:
+                        print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                for batch in tqdm(samples, disable=not verbose):
+                    for i in range(batch.shape[0]):
+                        batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                    if cvvp_amount != 1:
+                        clvp_res = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)  # distinct name keeps the `clvp` model handle callable on later batches
+                    if auto_conds is not None and cvvp_amount > 0:
+                        cvvp_accumulator = 0
+                        for cl in range(auto_conds.shape[1]):
+                            cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                        cvvp = cvvp_accumulator / auto_conds.shape[1]
+                        if cvvp_amount == 1:
+                            clip_results.append(cvvp)
+                        else:
+                            clip_results.append(cvvp * cvvp_amount + clvp_res * (1-cvvp_amount))
+                    else:
+                        clip_results.append(clvp_res)
+                clip_results = torch.cat(clip_results, dim=0)
+                samples = torch.cat(samples, dim=0)
+                best_results = samples[torch.topk(clip_results, k=k).indices]
             if self.cvvp is not None:
                 self.cvvp = self.cvvp.cpu()
             del samples
@@ -461,40 +471,42 @@ class TextToSpeech:
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
-            self.autoregressive = self.autoregressive.to(self.device)
-            best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
-                                               torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
-                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
-                                               return_latent=True, clip_inputs=False)
-            self.autoregressive = self.autoregressive.cpu()
-            del auto_conditioning
+            with self.temporary_cuda(
+                self.autoregressive
+            ) as autoregressive, torch.autocast(
+                device_type="cuda", dtype=torch.float16, enabled=half
+            ):
+                best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                              torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                              torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                              return_latent=True, clip_inputs=False)
+            del auto_conditioning

             if verbose:
                 print("Transforming autoregressive outputs into audio..")
             wav_candidates = []
-            self.diffusion = self.diffusion.to(self.device)
-            self.vocoder = self.vocoder.to(self.device)
-            for b in range(best_results.shape[0]):
-                codes = best_results[b].unsqueeze(0)
-                latents = best_latents[b].unsqueeze(0)
+            with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
+                self.vocoder
+            ) as vocoder:
+                for b in range(best_results.shape[0]):
+                    codes = best_results[b].unsqueeze(0)
+                    latents = best_latents[b].unsqueeze(0)

-                # Find the first occurrence of the "calm" token and trim the codes to that.
-                ctokens = 0
-                for k in range(codes.shape[-1]):
-                    if codes[0, k] == calm_token:
-                        ctokens += 1
-                    else:
-                        ctokens = 0
-                    if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
-                        latents = latents[:, :k]
-                        break
+                    # Find the first occurrence of the "calm" token and trim the codes to that.
+                    ctokens = 0
+                    for k in range(codes.shape[-1]):
+                        if codes[0, k] == calm_token:
+                            ctokens += 1
+                        else:
+                            ctokens = 0
+                        if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                            latents = latents[:, :k]
+                            break

-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
-                                               temperature=diffusion_temperature, verbose=verbose)
-                wav = self.vocoder.inference(mel)
-                wav_candidates.append(wav.cpu())
-            self.diffusion = self.diffusion.cpu()
-            self.vocoder = self.vocoder.cpu()
+                    mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning,
+                                                   temperature=diffusion_temperature, verbose=verbose)
+                    wav = vocoder.inference(mel)
+                    wav_candidates.append(wav.cpu())

         def potentially_redact(clip, text):
             if self.enable_redaction:
diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 26cd16c..9c81667 100644
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -315,7 +315,7 @@ class UnifiedVoice(nn.Module):
             embeddings.append(self.mel_embedding)
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
-    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
+    def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=True):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         gpt_config = GPT2Config(
             vocab_size=self.max_mel_tokens,
@@ -336,13 +336,23 @@ class UnifiedVoice(nn.Module):
             self.mel_head,
             kv_cache=kv_cache,
         )
-        if use_deepspeed:
+        if use_deepspeed and half:
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(model=self.inference_model,
+                                                      mp_size=1,
+                                                      replace_with_kernel_inject=True,
+                                                      dtype=torch.float16)
+            self.inference_model = self.ds_engine.module.eval()
+        elif use_deepspeed:
             import deepspeed
             self.ds_engine = deepspeed.init_inference(model=self.inference_model,
                                                       mp_size=1,
                                                       replace_with_kernel_inject=True,
                                                       dtype=torch.float32)
             self.inference_model = self.ds_engine.module.eval()
+        else:
+            self.inference_model = self.inference_model.eval()
+        # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
         self.gpt.wte = self.mel_embedding

     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
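
A note on the `temporary_cuda` helper added in `api.py`: as written, the model is only moved back to the CPU if the `with` body completes without raising, since the restore step is not wrapped in `try`/`finally` (and `self.high_vram` is presumably set elsewhere in the constructor, outside the hunks shown). Below is a minimal sketch of a hardened variant of the same offload pattern; `temporary_device` and `keep_resident` are illustrative names, not part of the patch.

```python
from contextlib import contextmanager

@contextmanager
def temporary_device(model, device, keep_resident=False):
    """Move `model` to `device` for the duration of a `with` block.

    keep_resident mirrors the patch's high_vram path: when True, the
    model is assumed to live on the GPU permanently and is yielded as-is.
    """
    if keep_resident:
        yield model
        return
    model.to(device)   # move weights onto the accelerator on entry
    try:
        yield model
    finally:
        model.cpu()    # return weights to system RAM even if inference raised
```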
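In `post_init_gpt2_config`, the fp16 and fp32 DeepSpeed branches duplicate the `deepspeed.init_inference` call and differ only in `dtype`. A behavior-preserving alternative, under the same assumptions as the patch, would be to select the dtype in one call; the sketch below (with the hypothetical helper name `init_inference_model`) is offered as a design note, not what the diff does.

```python
import torch

def init_inference_model(inference_model, use_deepspeed=False, half=True):
    """Single-call equivalent of the patch's fp16/fp32 DeepSpeed branches.

    Returns (ds_engine_or_None, eval-mode inference model).
    """
    if not use_deepspeed:
        return None, inference_model.eval()
    import deepspeed  # imported lazily, as in the patch
    engine = deepspeed.init_inference(
        model=inference_model,
        mp_size=1,                        # single-GPU, no tensor parallelism
        replace_with_kernel_inject=True,  # swap in DeepSpeed inference kernels
        dtype=torch.float16 if half else torch.float32,
    )
    return engine, engine.module.eval()
```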
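For reference, an end-to-end usage sketch of the new flags, assuming only the signatures introduced by this diff (`half` on the `TextToSpeech` constructor and on `tts`); the voice name and output handling are illustrative, not part of the patch.

```python
import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice

# half=True enables fp16 autocast inference (and, combined with
# use_deepspeed=True, fp16 DeepSpeed kernel injection); kv_cache=True
# speeds up autoregressive decoding.
tts = TextToSpeech(kv_cache=True, use_deepspeed=True, half=True)

voice_samples, conditioning_latents = load_voice('tom')  # a stock voice
gen = tts.tts("Hello from low-VRAM Tortoise.",
              voice_samples=voice_samples,
              conditioning_latents=conditioning_latents,
              half=True)  # per-call override added by this diff
torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)
```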