diff --git a/tortoise/api.py b/tortoise/api.py
index 69807b1..a5b95dd 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -18,7 +18,7 @@ from tortoise.models.clvp import CLVP
 from tortoise.models.cvvp import CVVP
 from tortoise.models.random_latent_generator import RandomLatentConverter
 from tortoise.models.vocoder import UnivNetGenerator
-from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
+from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel, TacotronSTFT
 from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
 from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
@@ -232,6 +232,8 @@ class TextToSpeech:
         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
         self.vocoder.eval(inference=True)
+
+        self.stft = None # TacotronSTFT is only loaded if used.
 
         # Random latent generators (RLGs) are loaded lazily.
         self.rlg_auto = None
@@ -269,12 +271,17 @@ class TextToSpeech:
             auto_latent = self.autoregressive.get_conditioning(auto_conds)
             self.autoregressive = self.autoregressive.cpu()
 
+        if self.stft is None:
+            # Initialize STFT
+            self.stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000).to(self.device)
+
         diffusion_conds = []
         for sample in voice_samples:
             # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
             sample = torchaudio.functional.resample(sample, 22050, 24000)
             sample = pad_or_truncate(sample, 102400)
-            cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
+            cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False,
+                                          device=self.device, stft=self.stft)
             diffusion_conds.append(cond_mel)
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py
index e263435..98783ef 100644
--- a/tortoise/utils/audio.py
+++ b/tortoise/utils/audio.py
@@ -181,9 +181,13 @@ class TacotronSTFT(torch.nn.Module):
         return mel_output
 
 
-def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
-    stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
-    stft = stft.to(device)
+def wav_to_univnet_mel(wav, do_normalization=False,
+                       device='cuda' if not torch.backends.mps.is_available() else 'mps',
+                       stft=None):
+    # Don't require stft to be passed, but use it if it is.
+    if stft is None:
+        stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
+        stft = stft.to(device)
     mel = stft.mel_spectrogram(wav)
     if do_normalization:
         mel = normalize_tacotron_mel(mel)