mirror of https://github.com/neonbjb/tortoise-tts.git
Modified how audio.wav_to_univnet_mel is called to prevent redundant init of stft
This commit is contained in:
parent 3eee92a4c8
commit f709ae7af9
tortoise/api.py
@@ -18,7 +18,7 @@ from tortoise.models.clvp import CLVP
 from tortoise.models.cvvp import CVVP
 from tortoise.models.random_latent_generator import RandomLatentConverter
 from tortoise.models.vocoder import UnivNetGenerator
-from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
+from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel, TacotronSTFT
 from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
 from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
@@ -232,6 +232,8 @@ class TextToSpeech:
         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
         self.vocoder.eval(inference=True)
 
+        self.stft = None # TacotronSTFT is only loaded if used.
+
         # Random latent generators (RLGs) are loaded lazily.
         self.rlg_auto = None
@@ -269,12 +271,17 @@ class TextToSpeech:
             auto_latent = self.autoregressive.get_conditioning(auto_conds)
             self.autoregressive = self.autoregressive.cpu()
 
+            if self.stft is None:
+                # Initialize STFT
+                self.stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000).to(self.device)
+
             diffusion_conds = []
             for sample in voice_samples:
                 # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
                 sample = torchaudio.functional.resample(sample, 22050, 24000)
                 sample = pad_or_truncate(sample, 102400)
-                cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
+                cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False,
+                                              device=self.device, stft=self.stft)
                 diffusion_conds.append(cond_mel)
             diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
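The api.py change above amounts to caching one expensive transform on the instance and building it lazily on first use, instead of rebuilding it for every conditioning clip. A minimal, self-contained sketch of that pattern, using an illustrative ConditioningEncoder class and torchaudio's MelSpectrogram as a stand-in for TacotronSTFT (these names are assumptions for the sketch, not tortoise's API):

import torch
import torchaudio


class ConditioningEncoder:
    # Illustrative class, not part of tortoise: caches the transform like self.stft above.
    def __init__(self, device="cpu"):
        self.device = device
        self.mel = None  # only built if actually used

    def encode(self, clips):
        if self.mel is None:
            # Constructed exactly once; MelSpectrogram stands in for TacotronSTFT here.
            self.mel = torchaudio.transforms.MelSpectrogram(
                sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100
            ).to(self.device)
        with torch.no_grad():
            return torch.stack([self.mel(c.to(self.device)) for c in clips], dim=1)


# The MelSpectrogram is created on the first call and reused for every later clip.
enc = ConditioningEncoder()
conds = enc.encode([torch.randn(1, 102400) for _ in range(3)])
print(conds.shape)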
tortoise/utils/audio.py
@@ -181,9 +181,13 @@ class TacotronSTFT(torch.nn.Module):
         return mel_output
 
 
-def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
-    stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
-    stft = stft.to(device)
+def wav_to_univnet_mel(wav, do_normalization=False,
+                       device='cuda' if not torch.backends.mps.is_available() else 'mps',
+                       stft=None):
+    # Don't require stft to be passed, but use it if it is.
+    if stft is None:
+        stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
+        stft = stft.to(device)
     mel = stft.mel_spectrogram(wav)
     if do_normalization:
         mel = normalize_tacotron_mel(mel)
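For context, the new signature allows the two call styles below. This is a quick usage sketch assuming the post-commit import paths shown in the diff; the random tensor merely stands in for a resampled 24 kHz clip padded to 102400 samples.

import torch
from tortoise.utils.audio import wav_to_univnet_mel, TacotronSTFT

device = 'cuda' if torch.cuda.is_available() else 'cpu'
wav = torch.rand(1, 102400) * 2 - 1  # stand-in for a padded 24 kHz clip in [-1, 1]

# Standalone call: the helper builds (and discards) its own TacotronSTFT.
mel = wav_to_univnet_mel(wav.to(device), do_normalization=False, device=device)

# Caller-managed call, as api.py now does: one STFT shared across every clip.
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000).to(device)
mels = [wav_to_univnet_mel(wav.to(device), do_normalization=False, device=device, stft=stft)
        for _ in range(3)]
print(mel.shape, mels[0].shape)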