Modified how audio.wav_to_univnet_mel is called to prevent redundant init of stft

This commit is contained in:
Michael 2024-02-12 15:46:55 -05:00
parent 3eee92a4c8
commit f709ae7af9
2 changed files with 16 additions and 5 deletions

View file

@ -18,7 +18,7 @@ from tortoise.models.clvp import CLVP
from tortoise.models.cvvp import CVVP
from tortoise.models.random_latent_generator import RandomLatentConverter
from tortoise.models.vocoder import UnivNetGenerator
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel, TacotronSTFT
from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from tortoise.utils.tokenizer import VoiceBpeTokenizer
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
@ -232,6 +232,8 @@ class TextToSpeech:
self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
self.vocoder.eval(inference=True)
self.stft = None # TacotronSTFT is only loaded if used.
# Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None
@ -269,12 +271,17 @@ class TextToSpeech:
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()
if self.stft is None:
# Initialize STFT
self.stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000).to(self.device)
diffusion_conds = []
for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False,
device=self.device, stft=self.stft)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)

View file

@ -181,9 +181,13 @@ class TacotronSTFT(torch.nn.Module):
return mel_output
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
stft = stft.to(device)
def wav_to_univnet_mel(wav, do_normalization=False,
device='cuda' if not torch.backends.mps.is_available() else 'mps',
stft=None):
# Don't require stft to be passed, but use it if it is.
if stft is None:
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
stft = stft.to(device)
mel = stft.mel_spectrogram(wav)
if do_normalization:
mel = normalize_tacotron_mel(mel)