diff --git a/setup.py b/setup.py index 22ae20a..6170cfe 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ setuptools.setup( 'librosa', 'transformers==4.31.0', 'tokenizers', - 'deepspeed==0.8.3', + # 'deepspeed==0.8.3', ], classifiers=[ "Programming Language :: Python :: 3", diff --git a/tortoise/api.py b/tortoise/api.py index 57d0159..69807b1 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -23,6 +23,7 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named from tortoise.utils.tokenizer import VoiceBpeTokenizer from tortoise.utils.wav2vec_alignment import Wav2VecAlignment from contextlib import contextmanager +from huggingface_hub import hf_hub_download pbar = None DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models') @@ -38,44 +39,13 @@ MODELS = { 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', } -def download_models(specific_models=None): - """ - Call to download all the models that Tortoise uses. - """ - os.makedirs(MODELS_DIR, exist_ok=True) - - def show_progress(block_num, block_size, total_size): - global pbar - if pbar is None: - pbar = progressbar.ProgressBar(maxval=total_size) - pbar.start() - - downloaded = block_num * block_size - if downloaded < total_size: - pbar.update(downloaded) - else: - pbar.finish() - pbar = None - for model_name, url in MODELS.items(): - if specific_models is not None and model_name not in specific_models: - continue - model_path = os.path.join(MODELS_DIR, model_name) - if os.path.exists(model_path): - continue - print(f'Downloading {model_name} from {url}...') - request.urlretrieve(url, model_path, show_progress) - print('Done.') - - def get_model_path(model_name, models_dir=MODELS_DIR): """ Get path to given model, download it if it doesn't exist. """ if model_name not in MODELS: raise ValueError(f'Model {model_name} not found in available models.') - model_path = os.path.join(models_dir, model_name) - if not os.path.exists(model_path) and models_dir == MODELS_DIR: - download_models([model_name]) + model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) return model_path diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py new file mode 100644 index 0000000..3e96851 --- /dev/null +++ b/tortoise/api_fast.py @@ -0,0 +1,500 @@ +import os +import random +import uuid +from time import time +from urllib import request + +import torch +import torch.nn.functional as F +import progressbar +import torchaudio +import numpy as np +from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead +from tortoise.models.diffusion_decoder import DiffusionTts +from tortoise.models.autoregressive import UnifiedVoice +from tqdm import tqdm +from tortoise.models.arch_util import TorchMelSpectrogram +from tortoise.models.clvp import CLVP +from tortoise.models.cvvp import CVVP +from tortoise.models.hifigan_decoder import HifiganGenerator +from tortoise.models.random_latent_generator import RandomLatentConverter +from tortoise.models.vocoder import UnivNetGenerator +from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel +from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule +from tortoise.utils.tokenizer import VoiceBpeTokenizer +from tortoise.utils.wav2vec_alignment import Wav2VecAlignment +from contextlib import contextmanager +from tortoise.models.stream_generator import init_stream_support +from huggingface_hub import hf_hub_download +pbar = None +init_stream_support() +DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models') +MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR) + +MODELS = { + 'autoregressive.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/autoregressive.pth', + 'classifier.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/classifier.pth', + 'rlg_auto.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/rlg_auto.pth', + 'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth', +} + +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. + """ + if model_name not in MODELS: + raise ValueError(f'Model {model_name} not found in available models.') + model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + return model_path + + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + if t.shape[-1] == length: + return t + elif t.shape[-1] < length: + return F.pad(t, (0, length-t.shape[-1])) + else: + return t[..., :length] + + +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. + """ + return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), + conditioning_free=cond_free, conditioning_free_k=cond_free_k) + + +def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start:rand_start + cond_length] + mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text.") + return codes + else: + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + + return codes + + +def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. + """ + with torch.no_grad(): + output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise, + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=verbose) + return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] + + +def classify_audio_clip(clip): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. + """ + classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, + resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, + dropout=0, kernel_size=5, distribute_zero_label=False) + classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu'))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. + """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + if torch.backends.mps.is_available(): + import psutil + available = psutil.virtual_memory().total + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + return 1 + +class TextToSpeech: + """ + Main entry point into Tortoise. + """ + + def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, + enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None, + tokenizer_vocab_file=None, tokenizer_basic=False): + + """ + Constructor + :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing + GPU OOM errors. Larger numbers generates slightly faster. + :param models_dir: Where model weights are stored. This should only be specified if you are providing your own + models, otherwise use the defaults. + :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output + (but are still rendered by the model). This can be used for prompt engineering. + Default is true. + :param device: Device to use when running the model. If omitted, the device will be automatically chosen. + """ + self.models_dir = models_dir + self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size + self.enable_redaction = enable_redaction + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer( + vocab_file=tokenizer_vocab_file, + use_basic_cleaners=tokenizer_basic, + ) + self.half = half + if os.path.exists(f'{models_dir}/autoregressive.ptt'): + # Assume this is a traced directory. + self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') + else: + self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + train_solo_embeddings=False).to(self.device).eval() + self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False) + self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half) + + self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1", + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11], + upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2], + cond_channels=1024).to(self.device).eval() + hifi_model = torch.load(get_model_path('hifidecoder.pth')) + self.hifi_decoder.load_state_dict(hifi_model, strict=False) + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + def get_conditioning_latents(self, voice_samples, return_mels=False): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). + These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. + """ + with torch.no_grad(): + voice_samples = [v.to(self.device) for v in voice_samples] + + auto_conds = [] + if not isinstance(voice_samples, list): + voice_samples = [voice_samples] + for vs in voice_samples: + auto_conds.append(format_conditioning(vs, device=self.device)) + auto_conds = torch.stack(auto_conds, dim=1) + auto_latent = self.autoregressive.get_conditioning(auto_conds) + + if return_mels: + return auto_latent + else: + return auto_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])) + + def tts_with_preset(self, text, preset='fast', **kwargs): + """ + Calls TTS with one of a set of preset generation parameters. Options: + 'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest). + 'fast': Decent quality speech at a decent inference rate. A good choice for mass inference. + 'standard': Very good quality. This is generally about as good as you are going to get. + 'high_quality': Use if you want the absolute best. This is not really worth the compute, though. + """ + # Use generally found best tuning knobs for generation. + settings = {'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0, + 'top_p': .8, + 'cond_free_k': 2.0, 'diffusion_temperature': 1.0} + # Presets are defined here. + presets = { + 'ultra_fast': {'num_autoregressive_samples': 1, 'diffusion_iterations': 10}, + 'fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 50}, + 'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200}, + 'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}, + } + settings.update(presets[preset]) + settings.update(kwargs) # allow overriding of preset settings with kwargs + for audio_frame in self.tts(text, **settings): + yield audio_frame + # taken from here https://github.com/coqui-ai/TTS/blob/d21f15cc850788f9cdf93dac0321395138665287/TTS/tts/models/xtts.py#L666 + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + + def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, + return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40, + # autoregressive generation parameters follow + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + # diffusion generation parameters follow + diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. + :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' + if voice_samples is not None: + auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False) + else: + auto_conditioning = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + + with torch.no_grad(): + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + with torch.autocast( + device_type="cuda" , dtype=torch.float16, enabled=self.half + ): + fake_inputs = self.autoregressive.compute_embeddings( + auto_conditioning, + text_tokens, + ) + gpt_generator = self.autoregressive.get_generator( + fake_inputs=fake_inputs, + top_k=50, + top_p=top_p, + temperature=temperature, + do_sample=True, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + all_latents = [] + codes_ = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + first_buffer = 60 + while not is_end: + try: + with torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=self.half + ): + codes, latent = next(gpt_generator) + all_latents += [latent] + codes_ += [codes] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)): + first_buffer = 0 + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning) + wav_gen = wav_gen.squeeze() + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + codes_ = [] + yield wav_chunk + def tts(self, text, voice_samples=None, k=1, verbose=True, use_deterministic_seed=None, + # autoregressive generation parameters follow + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, + top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. + :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' + if voice_samples is not None: + auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False) + else: + auto_conditioning = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + + with torch.no_grad(): + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + with torch.autocast( + device_type="cuda" , dtype=torch.float16, enabled=self.half + ): + codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, + top_k=50, + top_p=top_p, + temperature=temperature, + do_sample=True, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs) + gpt_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + torch.tensor([codes.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), + return_latent=True, clip_inputs=False) + if verbose: + print("generating audio..") + wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning) + return wav_gen + def deterministic_state(self, seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. + # torch.use_deterministic_algorithms(True) + + return seed diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 7483d45..fcd1a94 100644 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -38,6 +38,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): self.transformer = gpt self.text_pos_embedding = text_pos_emb self.embeddings = embeddings + self.final_norm = norm self.lm_head = nn.Sequential(norm, linear) self.kv_cache = kv_cache @@ -509,7 +510,28 @@ class UnifiedVoice(nn.Module): loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits - + def compute_embeddings( + self, + cond_latents, + text_inputs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + conds = cond_latents.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + self.inference_model.store_mel_emb(emb) + gpt_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + gpt_inputs[:, -1] = self.start_mel_token + return gpt_inputs def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): @@ -540,7 +562,16 @@ class UnifiedVoice(nn.Module): num_return_sequences=num_return_sequences, **hf_generate_kwargs) return gen[:, trunc_index:] - + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.inference_model.generate_stream( + fake_inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=500, + do_stream=True, + **hf_generate_kwargs, + ) if __name__ == '__main__': gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4) l = gpt(torch.randn(2, 3, 80, 800), diff --git a/tortoise/models/hifigan_decoder.py b/tortoise/models/hifigan_decoder.py new file mode 100644 index 0000000..1a428d1 --- /dev/null +++ b/tortoise/models/hifigan_decoder.py @@ -0,0 +1,299 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c, g=None): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + # c = c.to(self.conv_pre.weight.device) + # c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + up_1 = torch.nn.functional.interpolate( + c.transpose(1,2), + scale_factor=[1024 / 256], + mode="linear", + ) + up_2 = torch.nn.functional.interpolate( + up_1, + scale_factor=[24000 / 22050], + mode="linear", + ) + g = g.unsqueeze(0) + return self.forward(up_2.to("cuda"), g.transpose(1,2)) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/tortoise/models/stream_generator.py b/tortoise/models/stream_generator.py new file mode 100644 index 0000000..a8dd07b --- /dev/null +++ b/tortoise/models/stream_generator.py @@ -0,0 +1,1057 @@ +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +from transformers import ( + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + StoppingCriteriaList, + DisjunctiveConstraint, + BeamSearchScorer, + PhrasalConstraint, + ConstrainedBeamSearchScorer, + PreTrainedModel, +) +import numpy as np +import random +import warnings +import inspect +from transformers.generation.utils import GenerateOutput, SampleOutput, logger +import torch +from typing import Callable, List, Optional, Union +from torch import nn +import torch.distributed as dist +import copy + + +def setup_seed(seed): + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = False, + seed=0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + setup_seed(seed) + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config( + self.config + ) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + elif ( + not has_default_max_length and generation_config.max_new_tokens is not None + ): + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None + or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`" + ) + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams + * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + has_default_typical_p = ( + kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + ) + if not has_default_typical_p: + raise ValueError( + "Decoder argument `typical_p` is not supported with beam groups." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + if generation_config.num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained generation." + ) + + if generation_config.do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation." + ) + + if ( + generation_config.num_beam_groups is not None + and generation_config.num_beam_groups > 1 + ): + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation." + ) + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long() + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import PreTrainedModel + from transformers import AutoTokenizer, AutoModelForCausalLM + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-560m", torch_dtype=torch.float16 + ) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? \n" + input_ids = tokenizer( + prompt_text, return_tensors="pt", add_special_tokens=False + ).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) diff --git a/tortoise/read_fast.py b/tortoise/read_fast.py new file mode 100644 index 0000000..f2778d4 --- /dev/null +++ b/tortoise/read_fast.py @@ -0,0 +1,77 @@ +import argparse +import os +from time import time + +import torch +import torchaudio + +from api_fast import TextToSpeech, MODELS_DIR +from utils.audio import load_audio, load_voices +from utils.text import split_and_recombine_text + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="tortoise/data/riding_hood.txt") + parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' + 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='lj') + parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') + parser.add_argument('--output_name', type=str, help='How to name the output file', default='combined.wav') + parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') + parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) + parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' + 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) + parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) + parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=False) + parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True) + parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True) + + + args = parser.parse_args() + if torch.backends.mps.is_available(): + args.use_deepspeed = False + tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half) + + outpath = args.output_path + outname = args.output_name + selected_voices = args.voice.split(',') + regenerate = args.regenerate + if regenerate is not None: + regenerate = [int(e) for e in regenerate.split(',')] + + # Process text + with open(args.textfile, 'r', encoding='utf-8') as f: + text = ' '.join([l for l in f.readlines()]) + if '|' in text: + print("Found the '|' character in your text, which I will use as a cue for where to split it up. If this was not" + "your intent, please remove all '|' characters from the input.") + texts = text.split('|') + else: + texts = split_and_recombine_text(text) + + seed = int(time()) if args.seed is None else args.seed + for selected_voice in selected_voices: + voice_outpath = os.path.join(outpath, selected_voice) + os.makedirs(voice_outpath, exist_ok=True) + + if '&' in selected_voice: + voice_sel = selected_voice.split('&') + else: + voice_sel = [selected_voice] + + voice_samples, conditioning_latents = load_voices(voice_sel) + all_parts = [] + for j, text in enumerate(texts): + if regenerate is not None and j not in regenerate: + all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000)) + continue + start_time = time() + gen = tts.tts(text, voice_samples=voice_samples, use_deterministic_seed=seed) + end_time = time() + audio_ = gen.squeeze(0).cpu() + print("Time taken to generate the audio: ", end_time - start_time, "seconds") + print("RTF: ", (end_time - start_time) / (audio_.shape[1] / 24000)) + torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), audio_, 24000) + all_parts.append(audio_) + full_audio = torch.cat(all_parts, dim=-1) + torchaudio.save(os.path.join(voice_outpath, f"{outname}.wav"), full_audio, 24000)