From 5a9707d93cf09331cd324f90a1584813944493ab Mon Sep 17 00:00:00 2001
From: manmay-nakhashi
Date: Sun, 9 Jul 2023 18:40:10 +0530
Subject: [PATCH] added deepspeed inference

---
 requirements.txt                  |   6 +-
 tortoise/api.py                   |   5 +-
 tortoise/models/autoregressive.py |  47 ++--
 tortoise_tts.ipynb                | 453 ++++++++++++++++++------------
 4 files changed, 311 insertions(+), 200 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 79cc77e..fc29a98 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 tqdm
 rotary_embedding_torch
-transformers==4.29.2
+transformers==4.19
 tokenizers
 inflect
 progressbar
@@ -15,3 +15,7 @@ torchaudio
 threadpoolctl
 llvmlite
 appdirs
+nbconvert==5.3.1
+tornado==4.2
+pydantic==1.9.0
+deepspeed==0.9.0

diff --git a/tortoise/api.py b/tortoise/api.py
index 296ef14..631f3a5 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -198,7 +198,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, use_deepspeed=False, device=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -229,7 +229,8 @@ class TextToSpeech:
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False).cpu().eval()
         self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
-
+        self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed)
+
         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                       layer_drop=0, unconditioned_percentage=0).cpu().eval()
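The api.py change above only threads the new `use_deepspeed` flag through the `TextToSpeech` constructor; the DeepSpeed engine itself is built in `post_init_gpt2_config` in the next file. A minimal sketch of opting in from user code (the voice name, text, and output path are arbitrary examples, not part of the patch):

    import torchaudio
    from tortoise.api import TextToSpeech
    from tortoise.utils.audio import load_voice

    # use_deepspeed defaults to False, so existing callers are unaffected;
    # only the autoregressive GPT-2 decoder is wrapped by DeepSpeed.
    tts = TextToSpeech(use_deepspeed=True)

    voice_samples, conditioning_latents = load_voice('tom')
    gen = tts.tts_with_preset("Hello from DeepSpeed.",
                              voice_samples=voice_samples,
                              conditioning_latents=conditioning_latents,
                              preset="ultra_fast")
    torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)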
diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 757a7a8..03aa29f 100644
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -3,6 +3,7 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import deepspeed
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
@@ -340,7 +341,34 @@ class UnifiedVoice(nn.Module):
             embeddings.append(self.mel_embedding)
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
-
+
+    def post_init_gpt2_config(self, use_deepspeed=False):
+        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+        gpt_config = GPT2Config(
+            vocab_size=self.max_mel_tokens,
+            n_positions=seq_length,
+            n_ctx=seq_length,
+            n_embd=self.model_dim,
+            n_layer=self.layers,
+            n_head=self.heads,
+            gradient_checkpointing=False,
+            use_cache=True,
+        )
+        self.inference_model = GPT2InferenceModel(
+            gpt_config,
+            self.gpt,
+            self.mel_pos_embedding,
+            self.mel_embedding,
+            self.final_norm,
+            self.mel_head
+        )
+        if use_deepspeed:
+            self.ds_engine = deepspeed.init_inference(model=self.inference_model,
+                                                      mp_size=1,
+                                                      replace_with_kernel_inject=True,
+                                                      dtype=torch.float32)
+            self.inference_model = self.ds_engine.module.eval()
+        # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
+        self.gpt.wte = self.mel_embedding

     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1,0), value=start_token)
         tar = F.pad(input, (0,1), value=stop_token)
@@ -458,23 +486,10 @@ class UnifiedVoice(nn.Module):
         return loss_text.mean(), loss_mel.mean(), mel_logits

     def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
-                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
-        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
-        if not hasattr(self, 'inference_model'):
-            # TODO: Decouple gpt_config from this inference model.
-            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-                                    n_positions=seq_length,
-                                    n_ctx=seq_length,
-                                    n_embd=self.model_dim,
-                                    n_layer=self.layers,
-                                    n_head=self.heads,
-                                    gradient_checkpointing=False,
-                                    use_cache=True)
-            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
-            self.gpt.wte = self.mel_embedding
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
-        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
+        text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         conds = speech_conditioning_latent.unsqueeze(1)
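For context on the `deepspeed.init_inference` call: it returns an inference engine whose `.module` attribute is the kernel-injected model, which is why the patch swaps `self.inference_model` for `self.ds_engine.module`. A self-contained sketch of the same pattern on a stock Hugging Face GPT-2 (the model choice and prompt are illustrative, not part of this patch; assumes a CUDA GPU, since kernel injection targets GPU inference):

    import deepspeed
    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

    # Same arguments the patch uses: single GPU (mp_size=1), fused inference
    # kernels injected in place of the stock attention/MLP blocks.
    engine = deepspeed.init_inference(model=model,
                                      mp_size=1,
                                      replace_with_kernel_inject=True,
                                      dtype=torch.float32)
    model = engine.module.eval()  # generate() works exactly as before

    inputs = tokenizer("DeepSpeed makes the tortoise", return_tensors='pt')
    tokens = model.generate(inputs.input_ids.to(model.device), max_length=24)
    print(tokenizer.decode(tokens[0]))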
diff --git a/tortoise_tts.ipynb b/tortoise_tts.ipynb
index b0230e3..2882534 100644
--- a/tortoise_tts.ipynb
+++ b/tortoise_tts.ipynb
@@ -1,185 +1,276 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "tortoise-tts.ipynb",
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "_pIZ3ZXNp7cf"
+   },
+   "source": [
+    "Welcome to Tortoise! 🐒🐒🐒🐒\n",
+    "\n",
+    "Before you begin, I **strongly** recommend you turn on a GPU runtime.\n",
+    "\n",
+    "There's a reason this is called \"Tortoise\" - this model takes up to a minute to perform inference for a single sentence on a GPU. Expect waits on the order of hours on a CPU."
-   ],
-   "metadata": {
-    "id": "_pIZ3ZXNp7cf"
-   }
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "JrK20I32grP6"
-   },
-   "outputs": [],
-   "source": [
-    "!git clone https://github.com/neonbjb/tortoise-tts.git\n",
-    "%cd tortoise-tts\n",
-    "!pip3 install -r requirements.txt\n",
-    "!python3 setup.py install"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "# Imports used through the rest of the notebook.\n",
-    "import torch\n",
-    "import torchaudio\n",
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "\n",
-    "import IPython\n",
-    "\n",
-    "from tortoise.api import TextToSpeech\n",
-    "from tortoise.utils.audio import load_audio, load_voice, load_voices\n",
-    "\n",
-    "# This will download all the models used by Tortoise from the HF hub.\n",
-    "tts = TextToSpeech()"
-   ],
-   "metadata": {
-    "id": "Gen09NM4hONQ"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "# This is the text that will be spoken.\n",
-    "text = \"Joining two modalities results in a surprising increase in generalization! What would happen if we combined them all?\"\n",
-    "\n",
-    "# Here's something for the poetically inclined.. (set text=)\n",
-    "\"\"\"\n",
-    "Then took the other, as just as fair,\n",
-    "And having perhaps the better claim,\n",
-    "Because it was grassy and wanted wear;\n",
-    "Though as for that the passing there\n",
-    "Had worn them really about the same,\"\"\"\n",
-    "\n",
-    "# Pick a \"preset mode\" to determine quality. Options: {\"ultra_fast\", \"fast\" (default), \"standard\", \"high_quality\"}. See docs in api.py\n",
-    "preset = \"fast\""
-   ],
-   "metadata": {
-    "id": "bt_aoxONjfL2"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "# Tortoise will attempt to mimic voices you provide. It comes pre-packaged\n",
-    "# with some voices you might recognize.\n",
-    "\n",
-    "# Let's list all the voices available. These are just some random clips I've gathered\n",
-    "# from the internet as well as a few voices from the training dataset.\n",
-    "# Feel free to add your own clips to the voices/ folder.\n",
-    "%ls tortoise/voices\n",
-    "\n",
-    "IPython.display.Audio('tortoise/voices/tom/1.wav')"
-   ],
-   "metadata": {
-    "id": "SSleVnRAiEE2"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "# Pick one of the voices from the output above\n",
-    "voice = 'tom'\n",
-    "\n",
-    "# Load it and send it through Tortoise.\n",
-    "voice_samples, conditioning_latents = load_voice(voice)\n",
-    "gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, \n",
-    "                          preset=preset)\n",
-    "torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)\n",
-    "IPython.display.Audio('generated.wav')"
-   ],
-   "metadata": {
-    "id": "KEXOKjIvn6NW"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "# Tortoise can also generate speech using a random voice. The voice changes each time you execute this!\n",
-    "# (Note: random voices can be prone to strange utterances)\n",
-    "gen = tts.tts_with_preset(text, voice_samples=None, conditioning_latents=None, preset=preset)\n",
-    "torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)\n",
-    "IPython.display.Audio('generated.wav')"
-   ],
-   "metadata": {
-    "id": "16Xs2SSC3BXa"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "# You can also combine conditioning voices. Combining voices produces a new voice\n",
-    "# with traits from all the parents.\n",
-    "#\n",
-    "# Lets see what it would sound like if Picard and Kirk had a kid with a penchant for philosophy:\n",
-    "voice_samples, conditioning_latents = load_voices(['pat', 'william'])\n",
-    "\n",
-    "gen = tts.tts_with_preset(\"They used to say that if man was meant to fly, he’d have wings. But he did fly. He discovered he had to.\", \n",
-    "                          voice_samples=None, conditioning_latents=None, preset=preset)\n",
-    "torchaudio.save('captain_kirkard.wav', gen.squeeze(0).cpu(), 24000)\n",
-    "IPython.display.Audio('captain_kirkard.wav')"
-   ],
-   "metadata": {
-    "id": "fYTk8KUezUr5"
-   },
-   "execution_count": null,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "source": [
-    "del tts # Will break other cells, but necessary to conserve RAM if you want to run this cell.\n",
-    "\n",
-    "# Tortoise comes with some scripts that does a lot of the lifting for you. For example,\n",
-    "# read.py will read a text file for you.\n",
-    "!python3 tortoise/read.py --voice=train_atkins --textfile=tortoise/data/riding_hood.txt --preset=ultra_fast --output_path=.\n",
-    "\n",
-    "IPython.display.Audio('train_atkins/combined.wav')\n",
-    "# This will take awhile.."
-   ],
-   "metadata": {
-    "id": "t66yqWgu68KL"
-   },
-   "execution_count": null,
-   "outputs": []
-  }
- ]
-}
\ No newline at end of file
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JrK20I32grP6"
   },
   "outputs": [],
   "source": [
    "#first follow the instructions in the README.md file under Local Installation\n",
    "!pip3 install -r requirements.txt\n",
    "# !python3 setup.py install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Gen09NM4hONQ"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manmay/anaconda3/envs/tortoise/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Imports used through the rest of the notebook.\n",
    "import torch\n",
    "import torchaudio\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import IPython\n",
    "\n",
    "from tortoise.api import TextToSpeech\n",
    "from tortoise.utils.audio import load_audio, load_voice, load_voices\n",
    "\n",
    "# This will download all the models used by Tortoise from the HF hub.\n",
    "tts = TextToSpeech()\n",
    "# If you want to use DeepSpeed, pass use_deepspeed=True; it is nearly 2x faster than normal.\n",
    "# tts = TextToSpeech(use_deepspeed=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "bt_aoxONjfL2"
   },
   "outputs": [],
   "source": [
    "# This is the text that will be spoken.\n",
    "text = \"Joining two modalities results in a surprising increase in generalization! What would happen if we combined them all?\"\n",
    "\n",
    "# Here's something for the poetically inclined.. (set text=)\n",
    "\"\"\"\n",
    "Then took the other, as just as fair,\n",
    "And having perhaps the better claim,\n",
    "Because it was grassy and wanted wear;\n",
    "Though as for that the passing there\n",
    "Had worn them really about the same,\"\"\"\n",
    "\n",
    "# Pick a \"preset mode\" to determine quality. Options: {\"ultra_fast\", \"fast\" (default), \"standard\", \"high_quality\"}. 
See docs in api.py\n", + "preset = \"ultra_fast\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "SSleVnRAiEE2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0m\u001b[01;34mangie\u001b[0m/ \u001b[01;34mfreeman\u001b[0m/ \u001b[01;34mmyself\u001b[0m/ \u001b[01;34mtom\u001b[0m/ \u001b[01;34mtrain_grace\u001b[0m/\n", + "\u001b[01;34mapplejack\u001b[0m/ \u001b[01;34mgeralt\u001b[0m/ \u001b[01;34mpat\u001b[0m/ \u001b[01;34mtrain_atkins\u001b[0m/ \u001b[01;34mtrain_kennard\u001b[0m/\n", + "\u001b[01;34mcond_latent_example\u001b[0m/ \u001b[01;34mhalle\u001b[0m/ \u001b[01;34mpat2\u001b[0m/ \u001b[01;34mtrain_daws\u001b[0m/ \u001b[01;34mtrain_lescault\u001b[0m/\n", + "\u001b[01;34mdaniel\u001b[0m/ \u001b[01;34mjlaw\u001b[0m/ \u001b[01;34mrainbow\u001b[0m/ \u001b[01;34mtrain_dotrice\u001b[0m/ \u001b[01;34mtrain_mouse\u001b[0m/\n", + "\u001b[01;34mdeniro\u001b[0m/ \u001b[01;34mlj\u001b[0m/ \u001b[01;34msnakes\u001b[0m/ \u001b[01;34mtrain_dreams\u001b[0m/ \u001b[01;34mweaver\u001b[0m/\n", + "\u001b[01;34memma\u001b[0m/ \u001b[01;34mmol\u001b[0m/ \u001b[01;34mtim_reynolds\u001b[0m/ \u001b[01;34mtrain_empire\u001b[0m/ \u001b[01;34mwilliam\u001b[0m/\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Tortoise will attempt to mimic voices you provide. It comes pre-packaged\n", + "# with some voices you might recognize.\n", + "\n", + "# Let's list all the voices available. These are just some random clips I've gathered\n", + "# from the internet as well as a few voices from the training dataset.\n", + "# Feel free to add your own clips to the voices/ folder.\n", + "%ls tortoise/voices\n", + "\n", + "IPython.display.Audio('tortoise/voices/tom/1.wav')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "KEXOKjIvn6NW" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating autoregressive samples..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 38%|β–ˆβ–ˆβ–ˆβ–Š | 6/16 [00:31<00:52, 5.20s/it]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39m# Load it and send it through Tortoise.\u001b[39;00m\n\u001b[1;32m 5\u001b[0m voice_samples, conditioning_latents \u001b[39m=\u001b[39m load_voice(voice)\n\u001b[0;32m----> 6\u001b[0m gen \u001b[39m=\u001b[39m tts\u001b[39m.\u001b[39;49mtts_with_preset(text, voice_samples\u001b[39m=\u001b[39;49mvoice_samples, conditioning_latents\u001b[39m=\u001b[39;49mconditioning_latents, \n\u001b[1;32m 7\u001b[0m preset\u001b[39m=\u001b[39;49mpreset)\n\u001b[1;32m 8\u001b[0m torchaudio\u001b[39m.\u001b[39msave(\u001b[39m'\u001b[39m\u001b[39mgenerated.wav\u001b[39m\u001b[39m'\u001b[39m, gen\u001b[39m.\u001b[39msqueeze(\u001b[39m0\u001b[39m)\u001b[39m.\u001b[39mcpu(), \u001b[39m24000\u001b[39m)\n\u001b[1;32m 9\u001b[0m IPython\u001b[39m.\u001b[39mdisplay\u001b[39m.\u001b[39mAudio(\u001b[39m'\u001b[39m\u001b[39mgenerated.wav\u001b[39m\u001b[39m'\u001b[39m)\n", + "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/api.py:329\u001b[0m, 
in \u001b[0;36mTextToSpeech.tts_with_preset\u001b[0;34m(self, text, preset, **kwargs)\u001b[0m\n\u001b[1;32m 327\u001b[0m settings\u001b[39m.\u001b[39mupdate(presets[preset])\n\u001b[1;32m 328\u001b[0m settings\u001b[39m.\u001b[39mupdate(kwargs) \u001b[39m# allow overriding of preset settings with kwargs\u001b[39;00m\n\u001b[0;32m--> 329\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtts(text, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49msettings)\n", + "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/api.py:412\u001b[0m, in \u001b[0;36mTextToSpeech.tts\u001b[0;34m(self, text, voice_samples, conditioning_latents, k, verbose, use_deterministic_seed, return_deterministic_state, num_autoregressive_samples, temperature, length_penalty, repetition_penalty, top_p, max_mel_tokens, cvvp_amount, diffusion_iterations, cond_free, cond_free_k, diffusion_temperature, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 410\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mGenerating autoregressive samples..\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 411\u001b[0m \u001b[39mfor\u001b[39;00m b \u001b[39min\u001b[39;00m tqdm(\u001b[39mrange\u001b[39m(num_batches), disable\u001b[39m=\u001b[39m\u001b[39mnot\u001b[39;00m verbose):\n\u001b[0;32m--> 412\u001b[0m codes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mautoregressive\u001b[39m.\u001b[39;49minference_speech(auto_conditioning, text_tokens,\n\u001b[1;32m 413\u001b[0m do_sample\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 414\u001b[0m top_p\u001b[39m=\u001b[39;49mtop_p,\n\u001b[1;32m 415\u001b[0m temperature\u001b[39m=\u001b[39;49mtemperature,\n\u001b[1;32m 416\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mautoregressive_batch_size,\n\u001b[1;32m 417\u001b[0m length_penalty\u001b[39m=\u001b[39;49mlength_penalty,\n\u001b[1;32m 418\u001b[0m repetition_penalty\u001b[39m=\u001b[39;49mrepetition_penalty,\n\u001b[1;32m 419\u001b[0m max_generate_length\u001b[39m=\u001b[39;49mmax_mel_tokens,\n\u001b[1;32m 420\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 421\u001b[0m padding_needed \u001b[39m=\u001b[39m max_mel_tokens \u001b[39m-\u001b[39m codes\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]\n\u001b[1;32m 422\u001b[0m codes \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mpad(codes, (\u001b[39m0\u001b[39m, padding_needed), value\u001b[39m=\u001b[39mstop_mel_token)\n", + "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:513\u001b[0m, in \u001b[0;36mUnifiedVoice.inference_speech\u001b[0;34m(self, speech_conditioning_latent, text_inputs, input_tokens, num_return_sequences, max_generate_length, typical_sampling, typical_mass, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 511\u001b[0m logits_processor \u001b[39m=\u001b[39m LogitsProcessorList([TypicalLogitsWarper(mass\u001b[39m=\u001b[39mtypical_mass)]) \u001b[39mif\u001b[39;00m typical_sampling \u001b[39melse\u001b[39;00m LogitsProcessorList()\n\u001b[1;32m 512\u001b[0m max_length \u001b[39m=\u001b[39m trunc_index \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_mel_tokens \u001b[39m-\u001b[39m \u001b[39m1\u001b[39m \u001b[39mif\u001b[39;00m max_generate_length \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m trunc_index \u001b[39m+\u001b[39m max_generate_length\n\u001b[0;32m--> 513\u001b[0m gen 
\u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minference_model\u001b[39m.\u001b[39;49mgenerate(inputs, bos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstart_mel_token, pad_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token, eos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token,\n\u001b[1;32m 514\u001b[0m max_length\u001b[39m=\u001b[39;49mmax_length, logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 515\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49mnum_return_sequences, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 516\u001b[0m \u001b[39mreturn\u001b[39;00m gen[:, trunc_index:]\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1310\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, **model_kwargs)\u001b[0m\n\u001b[1;32m 1302\u001b[0m input_ids, model_kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m 1303\u001b[0m input_ids,\n\u001b[1;32m 1304\u001b[0m expand_size\u001b[39m=\u001b[39mnum_return_sequences,\n\u001b[1;32m 1305\u001b[0m is_encoder_decoder\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m 1306\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mmodel_kwargs,\n\u001b[1;32m 1307\u001b[0m )\n\u001b[1;32m 1309\u001b[0m \u001b[39m# 12. 
run sample\u001b[39;00m\n\u001b[0;32m-> 1310\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 1311\u001b[0m input_ids,\n\u001b[1;32m 1312\u001b[0m logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 1313\u001b[0m logits_warper\u001b[39m=\u001b[39;49mlogits_warper,\n\u001b[1;32m 1314\u001b[0m stopping_criteria\u001b[39m=\u001b[39;49mstopping_criteria,\n\u001b[1;32m 1315\u001b[0m pad_token_id\u001b[39m=\u001b[39;49mpad_token_id,\n\u001b[1;32m 1316\u001b[0m eos_token_id\u001b[39m=\u001b[39;49meos_token_id,\n\u001b[1;32m 1317\u001b[0m output_scores\u001b[39m=\u001b[39;49moutput_scores,\n\u001b[1;32m 1318\u001b[0m return_dict_in_generate\u001b[39m=\u001b[39;49mreturn_dict_in_generate,\n\u001b[1;32m 1319\u001b[0m synced_gpus\u001b[39m=\u001b[39;49msynced_gpus,\n\u001b[1;32m 1320\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_kwargs,\n\u001b[1;32m 1321\u001b[0m )\n\u001b[1;32m 1323\u001b[0m \u001b[39melif\u001b[39;00m is_beam_gen_mode:\n\u001b[1;32m 1324\u001b[0m \u001b[39mif\u001b[39;00m num_return_sequences \u001b[39m>\u001b[39m num_beams:\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1926\u001b[0m, in \u001b[0;36mGenerationMixin.sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m 1923\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mmodel_kwargs)\n\u001b[1;32m 1925\u001b[0m \u001b[39m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 1926\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m(\n\u001b[1;32m 1927\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_inputs,\n\u001b[1;32m 1928\u001b[0m return_dict\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 1929\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1930\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1931\u001b[0m )\n\u001b[1;32m 1933\u001b[0m \u001b[39mif\u001b[39;00m synced_gpus \u001b[39mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m 1934\u001b[0m cur_len \u001b[39m=\u001b[39m cur_len \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m 
\u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", + "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:142\u001b[0m, in \u001b[0;36mGPT2InferenceModel.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 139\u001b[0m emb \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(input_ids)\n\u001b[1;32m 140\u001b[0m emb \u001b[39m=\u001b[39m emb \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtext_pos_embedding\u001b[39m.\u001b[39mget_fixed_embedding(attention_mask\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]\u001b[39m-\u001b[39mmel_len, attention_mask\u001b[39m.\u001b[39mdevice)\n\u001b[0;32m--> 142\u001b[0m transformer_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtransformer(\n\u001b[1;32m 143\u001b[0m inputs_embeds\u001b[39m=\u001b[39;49memb,\n\u001b[1;32m 144\u001b[0m past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m 145\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 146\u001b[0m token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m 147\u001b[0m position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m 148\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 149\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 150\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_attention_mask,\n\u001b[1;32m 151\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 152\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 153\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 154\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 156\u001b[0m hidden_states \u001b[39m=\u001b[39m transformer_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 158\u001b[0m \u001b[39m# Set device for model parallelism\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m 
_global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:889\u001b[0m, in \u001b[0;36mGPT2Model.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 879\u001b[0m outputs \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39mcheckpoint\u001b[39m.\u001b[39mcheckpoint(\n\u001b[1;32m 880\u001b[0m create_custom_forward(block),\n\u001b[1;32m 881\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 886\u001b[0m encoder_attention_mask,\n\u001b[1;32m 887\u001b[0m )\n\u001b[1;32m 888\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 889\u001b[0m outputs \u001b[39m=\u001b[39m block(\n\u001b[1;32m 890\u001b[0m hidden_states,\n\u001b[1;32m 891\u001b[0m layer_past\u001b[39m=\u001b[39;49mlayer_past,\n\u001b[1;32m 892\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 893\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask[i],\n\u001b[1;32m 894\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 895\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_attention_mask,\n\u001b[1;32m 896\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 897\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 898\u001b[0m )\n\u001b[1;32m 900\u001b[0m hidden_states \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 901\u001b[0m \u001b[39mif\u001b[39;00m use_cache \u001b[39mis\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call 
functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:390\u001b[0m, in \u001b[0;36mGPT2Block.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 388\u001b[0m residual \u001b[39m=\u001b[39m hidden_states\n\u001b[1;32m 389\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mln_1(hidden_states)\n\u001b[0;32m--> 390\u001b[0m attn_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattn(\n\u001b[1;32m 391\u001b[0m hidden_states,\n\u001b[1;32m 392\u001b[0m layer_past\u001b[39m=\u001b[39;49mlayer_past,\n\u001b[1;32m 393\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 394\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 395\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 396\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 397\u001b[0m )\n\u001b[1;32m 398\u001b[0m attn_output \u001b[39m=\u001b[39m attn_outputs[\u001b[39m0\u001b[39m] \u001b[39m# output_attn: a, present, (attentions)\u001b[39;00m\n\u001b[1;32m 399\u001b[0m outputs \u001b[39m=\u001b[39m attn_outputs[\u001b[39m1\u001b[39m:]\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:290\u001b[0m, in \u001b[0;36mGPT2Attention.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 287\u001b[0m new_shape \u001b[39m=\u001b[39m tensor\u001b[39m.\u001b[39msize()[:\u001b[39m-\u001b[39m\u001b[39m2\u001b[39m] \u001b[39m+\u001b[39m (num_heads \u001b[39m*\u001b[39m attn_head_size,)\n\u001b[1;32m 288\u001b[0m \u001b[39mreturn\u001b[39;00m 
tensor\u001b[39m.\u001b[39mview(new_shape)\n\u001b[0;32m--> 290\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 291\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 292\u001b[0m hidden_states: Optional[Tuple[torch\u001b[39m.\u001b[39mFloatTensor]],\n\u001b[1;32m 293\u001b[0m layer_past: Optional[Tuple[torch\u001b[39m.\u001b[39mTensor]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 294\u001b[0m attention_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 295\u001b[0m head_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 296\u001b[0m encoder_hidden_states: Optional[torch\u001b[39m.\u001b[39mTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 297\u001b[0m encoder_attention_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 298\u001b[0m use_cache: Optional[\u001b[39mbool\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 299\u001b[0m output_attentions: Optional[\u001b[39mbool\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 300\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tuple[Union[torch\u001b[39m.\u001b[39mTensor, Tuple[torch\u001b[39m.\u001b[39mTensor]], \u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\u001b[39m.\u001b[39m]:\n\u001b[1;32m 301\u001b[0m \u001b[39mif\u001b[39;00m encoder_hidden_states \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 302\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mq_attn\u001b[39m\u001b[39m\"\u001b[39m):\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Pick one of the voices from the output above\n", + "voice = 'tom'\n", + "\n", + "# Load it and send it through Tortoise.\n", + "voice_samples, conditioning_latents = load_voice(voice)\n", + "gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, \n", + " preset=preset)\n", + "torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)\n", + "IPython.display.Audio('generated.wav')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "16Xs2SSC3BXa" + }, + "outputs": [], + "source": [ + "# Tortoise can also generate speech using a random voice. The voice changes each time you execute this!\n", + "# (Note: random voices can be prone to strange utterances)\n", + "gen = tts.tts_with_preset(text, voice_samples=None, conditioning_latents=None, preset=preset)\n", + "torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)\n", + "IPython.display.Audio('generated.wav')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fYTk8KUezUr5" + }, + "outputs": [], + "source": [ + "# You can also combine conditioning voices. Combining voices produces a new voice\n", + "# with traits from all the parents.\n", + "#\n", + "# Lets see what it would sound like if Picard and Kirk had a kid with a penchant for philosophy:\n", + "voice_samples, conditioning_latents = load_voices(['pat', 'william'])\n", + "\n", + "gen = tts.tts_with_preset(\"They used to say that if man was meant to fly, he’d have wings. But he did fly. 
He discovered he had to.\", \n",
    "                          voice_samples=voice_samples, conditioning_latents=conditioning_latents, preset=preset)\n",
    "torchaudio.save('captain_kirkard.wav', gen.squeeze(0).cpu(), 24000)\n",
    "IPython.display.Audio('captain_kirkard.wav')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "t66yqWgu68KL"
   },
   "outputs": [],
   "source": [
    "del tts # Will break other cells, but necessary to conserve RAM if you want to run this cell.\n",
    "\n",
    "# Tortoise comes with some scripts that do a lot of the lifting for you. For example,\n",
    "# read.py will read a text file for you.\n",
    "!python3 tortoise/read.py --voice=train_atkins --textfile=tortoise/data/riding_hood.txt --preset=ultra_fast --output_path=.\n",
    "\n",
    "IPython.display.Audio('train_atkins/combined.wav')\n",
    "# This will take a while..."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "tortoise-tts.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
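The notebook's "nearly 2x faster" claim is easy to check locally. A rough benchmark sketch (assumes a CUDA GPU and the bundled 'tom' voice; the sentence and preset are arbitrary choices, not part of the patch):

    import time
    import torch
    from tortoise.api import TextToSpeech
    from tortoise.utils.audio import load_voice

    def time_tts(use_deepspeed):
        # Build a fresh instance per run so the DeepSpeed engine is
        # constructed (or skipped) from scratch.
        tts = TextToSpeech(use_deepspeed=use_deepspeed)
        voice_samples, conditioning_latents = load_voice('tom')
        torch.cuda.synchronize()
        start = time.time()
        tts.tts_with_preset("A quick benchmark sentence.",
                            voice_samples=voice_samples,
                            conditioning_latents=conditioning_latents,
                            preset="ultra_fast")
        torch.cuda.synchronize()
        return time.time() - start

    baseline = time_tts(use_deepspeed=False)
    accelerated = time_tts(use_deepspeed=True)
    print(f"baseline {baseline:.1f}s, deepspeed {accelerated:.1f}s, "
          f"speedup {baseline / accelerated:.2f}x")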