add half because kv_cache increases memory footprint

This commit is contained in:
manmay-nakhashi 2023-07-16 00:49:17 +05:30
parent a88534adb2
commit 19f5250454
2 changed files with 100 additions and 78 deletions

View file

@ -23,7 +23,7 @@ from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from tortoise.utils.tokenizer import VoiceBpeTokenizer
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
from contextlib import contextmanager
pbar = None
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
@ -192,13 +192,13 @@ def pick_best_batch_size_for_gpu():
return 4
return 1
class TextToSpeech:
"""
Main entry point into Tortoise.
"""
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False,use_deepspeed=False, device=None):
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
enable_redaction=True, kv_cache=False, use_deepspeed=False, half=True, device=None):
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -250,7 +250,15 @@ class TextToSpeech:
# Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None
self.rlg_diffusion = None
@contextmanager
def temporary_cuda(self, model):
if self.high_vram:
yield model
else:
m = model.to(self.device)
yield m
m = model.cpu()
def load_cvvp(self):
"""Load CVVP model."""
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
@ -336,6 +344,7 @@ class TextToSpeech:
cvvp_amount=.0,
# diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
half=True,
**hf_generate_kwargs):
"""
Produces an audio clip of the given text being spoken with the given reference voice.
@ -405,55 +414,56 @@ class TextToSpeech:
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
for b in tqdm(range(num_batches), disable=not verbose):
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens,
**hf_generate_kwargs)
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
self.autoregressive = self.autoregressive.cpu()
with self.temporary_cuda(self.autoregressive
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens,
**hf_generate_kwargs)
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
clip_results = []
self.clvp = self.clvp.to(self.device)
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.to(self.device)
if verbose:
if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
if cvvp_amount != 1:
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=half
):
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
self.cvvp = self.cvvp.to(self.device)
if verbose:
if self.cvvp is None:
print("Computing best candidates using CLVP")
else:
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
else:
clip_results.append(clvp)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
self.clvp = self.clvp.cpu()
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
if cvvp_amount != 1:
clvp = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
if auto_conds is not None and cvvp_amount > 0:
cvvp_accumulator = 0
for cl in range(auto_conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
cvvp = cvvp_accumulator / auto_conds.shape[1]
if cvvp_amount == 1:
clip_results.append(cvvp)
else:
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
else:
clip_results.append(clvp)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
if self.cvvp is not None:
self.cvvp = self.cvvp.cpu()
del samples
@ -461,40 +471,42 @@ class TextToSpeech:
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
self.autoregressive = self.autoregressive.to(self.device)
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False)
self.autoregressive = self.autoregressive.cpu()
del auto_conditioning
with self.temporary_cuda(
self.autoregressive
) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=half
):
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False)
del auto_conditioning
if verbose:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
self.diffusion = self.diffusion.to(self.device)
self.vocoder = self.vocoder.to(self.device)
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
self.vocoder
) as vocoder:
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
# Find the first occurrence of the "calm" token and trim the codes to that.
ctokens = 0
for k in range(codes.shape[-1]):
if codes[0, k] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
break
# Find the first occurrence of the "calm" token and trim the codes to that.
ctokens = 0
for k in range(codes.shape[-1]):
if codes[0, k] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
break
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
temperature=diffusion_temperature, verbose=verbose)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav.cpu())
self.diffusion = self.diffusion.cpu()
self.vocoder = self.vocoder.cpu()
mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning,
temperature=diffusion_temperature, verbose=verbose)
wav = vocoder.inference(mel)
wav_candidates.append(wav.cpu())
def potentially_redact(clip, text):
if self.enable_redaction:

View file

@ -315,7 +315,7 @@ class UnifiedVoice(nn.Module):
embeddings.append(self.mel_embedding)
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=True):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
@ -336,13 +336,23 @@ class UnifiedVoice(nn.Module):
self.mel_head,
kv_cache=kv_cache,
)
if use_deepspeed:
if use_deepspeed and half:
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float16)
self.inference_model = self.ds_engine.module.eval()
elif use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float32)
self.inference_model = self.ds_engine.module.eval()
else:
self.inference_model = self.inference_model.eval()
# self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):