This commit is contained in:
manmay-nakhashi 2023-07-26 01:35:42 +05:30
parent 8b317ebedf
commit 1aabb3cec1

View file

@ -13,7 +13,6 @@ from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
from tortoise.models.diffusion_decoder import DiffusionTts
from tortoise.models.autoregressive import UnifiedVoice
from tqdm import tqdm
from transformers import TextStreamer
from tortoise.models.arch_util import TorchMelSpectrogram
from tortoise.models.clvp import CLVP
from tortoise.models.cvvp import CVVP
@ -393,7 +392,6 @@ class TextToSpeech:
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
streamer = TextStreamer(self.tokenizer)
auto_conds = None
if voice_samples is not None:
auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True)
@ -416,7 +414,7 @@ class TextToSpeech:
with self.temporary_cuda(self.autoregressive
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, streamer, text_tokens,
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
@ -521,23 +519,6 @@ class TextToSpeech:
return res, (deterministic_seed, text, voice_samples, conditioning_latents)
else:
return res
def tts_streamable(self, chunk_size, *args, **kwargs):
"""
A modified version of the tts function that yields the output in chunks.
:param chunk_size: The size of the chunks in which to split the output audio.
:param args: The original arguments of the tts function.
:param kwargs: The original keyword arguments of the tts function.
:yield: Chunks of the generated audio.
"""
# Call the original tts function and get the full audio
full_audio = self.tts(*args, **kwargs)
# Convert the audio tensor to a 1D numpy array
full_audio_np = full_audio.squeeze().cpu().numpy()
# Yield audio chunks
for i in range(0, len(full_audio_np), chunk_size):
yield full_audio_np[i:i+chunk_size]
def deterministic_state(self, seed=None):
"""
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be