added kv_cache

This commit is contained in:
manmay-nakhashi 2023-07-15 23:00:19 +05:30
parent 82724cca54
commit a88534adb2
3 changed files with 201 additions and 66 deletions

View file

@ -27,7 +27,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
pbar = None
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
MODELS_DIR = MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
@ -198,7 +198,7 @@ class TextToSpeech:
Main entry point into Tortoise.
"""
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, use_deepspeed=False, device=None):
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False,use_deepspeed=False, device=None):
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -229,7 +229,7 @@ class TextToSpeech:
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed)
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache)
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,

View file

@ -33,50 +33,23 @@ class ResBlock(nn.Module):
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
self.lm_head = nn.Sequential(norm, linear)
# Model parallel
self.model_parallel = False
self.device_map = None
self.cached_mel_emb = None
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
self.kv_cache = kv_cache
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) # usually None
if not self.kv_cache:
past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
if past:
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -88,13 +61,13 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
@ -121,7 +94,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
@ -130,14 +105,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
text_emb = self.embeddings(text_inputs)
text_emb = text_emb + self.text_pos_embedding(text_emb)
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb.repeat_interleave(
text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
)
else: # this outcome only occurs once per loop in most cases
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
emb = emb + self.text_pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - mel_len, attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
@ -153,12 +131,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
@ -181,7 +153,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past
)
@ -340,7 +315,7 @@ class UnifiedVoice(nn.Module):
embeddings.append(self.mel_embedding)
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, use_deepspeed=False):
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
@ -358,7 +333,8 @@ class UnifiedVoice(nn.Module):
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head
self.mel_head,
kv_cache=kv_cache,
)
if use_deepspeed:
import deepspeed

View file

@ -40,6 +40,53 @@
"/home/manmay/anaconda3/envs/tortoise/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2023-07-15 10:55:28,559] [INFO] [logging.py:93:log_dist] [Rank -1] DeepSpeed info: version=0.8.3, git-hash=unknown, git-branch=unknown\n",
"[2023-07-15 10:55:28,603] [WARNING] [config_utils.py:75:_process_deprecated_field] Config parameter mp_size is deprecated use tensor_parallel.tp_size instead\n",
"[2023-07-15 10:55:28,605] [INFO] [logging.py:93:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1\n",
"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: module 'transformers.models' has no attribute 'bloom'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using /home/manmay/.cache/torch_extensions/py39_cu117 as PyTorch extensions root...\n",
"Detected CUDA files, patching ldflags\n",
"Emitting ninja build file /home/manmay/.cache/torch_extensions/py39_cu117/transformer_inference/build.ninja...\n",
"Building extension module transformer_inference...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"Loading extension module transformer_inference...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ninja: no work to do.\n",
"Time to load transformer_inference op: 0.9313881397247314 seconds\n",
"[2023-07-15 10:55:34,938] [INFO] [logging.py:93:log_dist] [Rank -1] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 1024, 'intermediate_size': 4096, 'heads': 16, 'num_hidden_layers': -1, 'fp16': False, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'mlp_act_func_type': <ActivationFuncType.GELU: 1>, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False, 'max_out_tokens': 1024, 'scale_attn_by_inverse_layer_idx': False, 'enable_qkv_quantization': False, 'use_mup': False, 'return_single_tuple': False}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using /home/manmay/.cache/torch_extensions/py39_cu117 as PyTorch extensions root...\n",
"No modifications detected for re-loaded extension module transformer_inference, skipping build step...\n",
"Loading extension module transformer_inference...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time to load transformer_inference op: 0.15709829330444336 seconds\n"
]
}
],
"source": [
@ -55,9 +102,9 @@
"from tortoise.utils.audio import load_audio, load_voice, load_voices\n",
"\n",
"# This will download all the models used by Tortoise from the HF hub.\n",
"tts = TextToSpeech()\n",
"# tts = TextToSpeech()\n",
"# If you want to use deepspeed the pass use_deepspeed=True nearly 2x faster than normal\n",
"# tts = TextToSpeech(use_deepspeed=True)"
"tts = TextToSpeech(use_deepspeed=True, kv_cache=True)"
]
},
{
@ -151,7 +198,127 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 38%|███▊ | 6/16 [00:31<00:52, 5.20s/it]\n"
" 0%| | 0/16 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 74])\n",
"------------------------------------------------------\n",
"Free memory : 2.864563 (GigaBytes) \n",
"Total memory: 5.805298 (GigaBytes) \n",
"Requested memory: 0.421875 (GigaBytes) \n",
"Setting maximum total tokens (input + output) to 1024 \n",
"------------------------------------------------------\n",
"torch.Size([1, 75])\n",
"torch.Size([1, 76])\n",
"torch.Size([1, 77])\n",
"torch.Size([1, 78])\n",
"torch.Size([1, 79])\n",
"torch.Size([1, 80])\n",
"torch.Size([1, 81])\n",
"torch.Size([1, 82])\n",
"torch.Size([1, 83])\n",
"torch.Size([1, 84])\n",
"torch.Size([1, 85])\n",
"torch.Size([1, 86])\n",
"torch.Size([1, 87])\n",
"torch.Size([1, 88])\n",
"torch.Size([1, 89])\n",
"torch.Size([1, 90])\n",
"torch.Size([1, 91])\n",
"torch.Size([1, 92])\n",
"torch.Size([1, 93])\n",
"torch.Size([1, 94])\n",
"torch.Size([1, 95])\n",
"torch.Size([1, 96])\n",
"torch.Size([1, 97])\n",
"torch.Size([1, 98])\n",
"torch.Size([1, 99])\n",
"torch.Size([1, 100])\n",
"torch.Size([1, 101])\n",
"torch.Size([1, 102])\n",
"torch.Size([1, 103])\n",
"torch.Size([1, 104])\n",
"torch.Size([1, 105])\n",
"torch.Size([1, 106])\n",
"torch.Size([1, 107])\n",
"torch.Size([1, 108])\n",
"torch.Size([1, 109])\n",
"torch.Size([1, 110])\n",
"torch.Size([1, 111])\n",
"torch.Size([1, 112])\n",
"torch.Size([1, 113])\n",
"torch.Size([1, 114])\n",
"torch.Size([1, 115])\n",
"torch.Size([1, 116])\n",
"torch.Size([1, 117])\n",
"torch.Size([1, 118])\n",
"torch.Size([1, 119])\n",
"torch.Size([1, 120])\n",
"torch.Size([1, 121])\n",
"torch.Size([1, 122])\n",
"torch.Size([1, 123])\n",
"torch.Size([1, 124])\n",
"torch.Size([1, 125])\n",
"torch.Size([1, 126])\n",
"torch.Size([1, 127])\n",
"torch.Size([1, 128])\n",
"torch.Size([1, 129])\n",
"torch.Size([1, 130])\n",
"torch.Size([1, 131])\n",
"torch.Size([1, 132])\n",
"torch.Size([1, 133])\n",
"torch.Size([1, 134])\n",
"torch.Size([1, 135])\n",
"torch.Size([1, 136])\n",
"torch.Size([1, 137])\n",
"torch.Size([1, 138])\n",
"torch.Size([1, 139])\n",
"torch.Size([1, 140])\n",
"torch.Size([1, 141])\n",
"torch.Size([1, 142])\n",
"torch.Size([1, 143])\n",
"torch.Size([1, 144])\n",
"torch.Size([1, 145])\n",
"torch.Size([1, 146])\n",
"torch.Size([1, 147])\n",
"torch.Size([1, 148])\n",
"torch.Size([1, 149])\n",
"torch.Size([1, 150])\n",
"torch.Size([1, 151])\n",
"torch.Size([1, 152])\n",
"torch.Size([1, 153])\n",
"torch.Size([1, 154])\n",
"torch.Size([1, 155])\n",
"torch.Size([1, 156])\n",
"torch.Size([1, 157])\n",
"torch.Size([1, 158])\n",
"torch.Size([1, 159])\n",
"torch.Size([1, 160])\n",
"torch.Size([1, 161])\n",
"torch.Size([1, 162])\n",
"torch.Size([1, 163])\n",
"torch.Size([1, 164])\n",
"torch.Size([1, 165])\n",
"torch.Size([1, 166])\n",
"torch.Size([1, 167])\n",
"torch.Size([1, 168])\n",
"torch.Size([1, 169])\n",
"torch.Size([1, 170])\n",
"torch.Size([1, 171])\n",
"torch.Size([1, 172])\n",
"torch.Size([1, 173])\n",
"torch.Size([1, 174])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/16 [00:30<?, ?it/s]\n"
]
},
{
@ -164,18 +331,10 @@
"Cell \u001b[0;32mIn[4], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39m# Load it and send it through Tortoise.\u001b[39;00m\n\u001b[1;32m 5\u001b[0m voice_samples, conditioning_latents \u001b[39m=\u001b[39m load_voice(voice)\n\u001b[0;32m----> 6\u001b[0m gen \u001b[39m=\u001b[39m tts\u001b[39m.\u001b[39;49mtts_with_preset(text, voice_samples\u001b[39m=\u001b[39;49mvoice_samples, conditioning_latents\u001b[39m=\u001b[39;49mconditioning_latents, \n\u001b[1;32m 7\u001b[0m preset\u001b[39m=\u001b[39;49mpreset)\n\u001b[1;32m 8\u001b[0m torchaudio\u001b[39m.\u001b[39msave(\u001b[39m'\u001b[39m\u001b[39mgenerated.wav\u001b[39m\u001b[39m'\u001b[39m, gen\u001b[39m.\u001b[39msqueeze(\u001b[39m0\u001b[39m)\u001b[39m.\u001b[39mcpu(), \u001b[39m24000\u001b[39m)\n\u001b[1;32m 9\u001b[0m IPython\u001b[39m.\u001b[39mdisplay\u001b[39m.\u001b[39mAudio(\u001b[39m'\u001b[39m\u001b[39mgenerated.wav\u001b[39m\u001b[39m'\u001b[39m)\n",
"File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/api.py:329\u001b[0m, in \u001b[0;36mTextToSpeech.tts_with_preset\u001b[0;34m(self, text, preset, **kwargs)\u001b[0m\n\u001b[1;32m 327\u001b[0m settings\u001b[39m.\u001b[39mupdate(presets[preset])\n\u001b[1;32m 328\u001b[0m settings\u001b[39m.\u001b[39mupdate(kwargs) \u001b[39m# allow overriding of preset settings with kwargs\u001b[39;00m\n\u001b[0;32m--> 329\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtts(text, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49msettings)\n",
"File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/api.py:412\u001b[0m, in \u001b[0;36mTextToSpeech.tts\u001b[0;34m(self, text, voice_samples, conditioning_latents, k, verbose, use_deterministic_seed, return_deterministic_state, num_autoregressive_samples, temperature, length_penalty, repetition_penalty, top_p, max_mel_tokens, cvvp_amount, diffusion_iterations, cond_free, cond_free_k, diffusion_temperature, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 410\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mGenerating autoregressive samples..\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 411\u001b[0m \u001b[39mfor\u001b[39;00m b \u001b[39min\u001b[39;00m tqdm(\u001b[39mrange\u001b[39m(num_batches), disable\u001b[39m=\u001b[39m\u001b[39mnot\u001b[39;00m verbose):\n\u001b[0;32m--> 412\u001b[0m codes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mautoregressive\u001b[39m.\u001b[39;49minference_speech(auto_conditioning, text_tokens,\n\u001b[1;32m 413\u001b[0m do_sample\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 414\u001b[0m top_p\u001b[39m=\u001b[39;49mtop_p,\n\u001b[1;32m 415\u001b[0m temperature\u001b[39m=\u001b[39;49mtemperature,\n\u001b[1;32m 416\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mautoregressive_batch_size,\n\u001b[1;32m 417\u001b[0m length_penalty\u001b[39m=\u001b[39;49mlength_penalty,\n\u001b[1;32m 418\u001b[0m repetition_penalty\u001b[39m=\u001b[39;49mrepetition_penalty,\n\u001b[1;32m 419\u001b[0m max_generate_length\u001b[39m=\u001b[39;49mmax_mel_tokens,\n\u001b[1;32m 420\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 421\u001b[0m padding_needed \u001b[39m=\u001b[39m max_mel_tokens \u001b[39m-\u001b[39m codes\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]\n\u001b[1;32m 422\u001b[0m codes \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mpad(codes, (\u001b[39m0\u001b[39m, padding_needed), value\u001b[39m=\u001b[39mstop_mel_token)\n",
"File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:513\u001b[0m, in \u001b[0;36mUnifiedVoice.inference_speech\u001b[0;34m(self, speech_conditioning_latent, text_inputs, input_tokens, num_return_sequences, max_generate_length, typical_sampling, typical_mass, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 511\u001b[0m logits_processor \u001b[39m=\u001b[39m LogitsProcessorList([TypicalLogitsWarper(mass\u001b[39m=\u001b[39mtypical_mass)]) \u001b[39mif\u001b[39;00m typical_sampling \u001b[39melse\u001b[39;00m LogitsProcessorList()\n\u001b[1;32m 512\u001b[0m max_length \u001b[39m=\u001b[39m trunc_index \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_mel_tokens \u001b[39m-\u001b[39m \u001b[39m1\u001b[39m \u001b[39mif\u001b[39;00m max_generate_length \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m trunc_index \u001b[39m+\u001b[39m max_generate_length\n\u001b[0;32m--> 513\u001b[0m gen \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minference_model\u001b[39m.\u001b[39;49mgenerate(inputs, bos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstart_mel_token, pad_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token, eos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token,\n\u001b[1;32m 514\u001b[0m max_length\u001b[39m=\u001b[39;49mmax_length, logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 515\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49mnum_return_sequences, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 516\u001b[0m \u001b[39mreturn\u001b[39;00m gen[:, trunc_index:]\n",
"File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:490\u001b[0m, in \u001b[0;36mUnifiedVoice.inference_speech\u001b[0;34m(self, speech_conditioning_latent, text_inputs, input_tokens, num_return_sequences, max_generate_length, typical_sampling, typical_mass, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 488\u001b[0m logits_processor \u001b[39m=\u001b[39m LogitsProcessorList([TypicalLogitsWarper(mass\u001b[39m=\u001b[39mtypical_mass)]) \u001b[39mif\u001b[39;00m typical_sampling \u001b[39melse\u001b[39;00m LogitsProcessorList()\n\u001b[1;32m 489\u001b[0m max_length \u001b[39m=\u001b[39m trunc_index \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_mel_tokens \u001b[39m-\u001b[39m \u001b[39m1\u001b[39m \u001b[39mif\u001b[39;00m max_generate_length \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m trunc_index \u001b[39m+\u001b[39m max_generate_length\n\u001b[0;32m--> 490\u001b[0m gen \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minference_model\u001b[39m.\u001b[39;49mgenerate(inputs, bos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstart_mel_token, pad_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token, eos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token,\n\u001b[1;32m 491\u001b[0m max_length\u001b[39m=\u001b[39;49mmax_length, logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 492\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49mnum_return_sequences, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 493\u001b[0m \u001b[39mreturn\u001b[39;00m gen[:, trunc_index:]\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1310\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, **model_kwargs)\u001b[0m\n\u001b[1;32m 1302\u001b[0m input_ids, model_kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m 1303\u001b[0m input_ids,\n\u001b[1;32m 1304\u001b[0m expand_size\u001b[39m=\u001b[39mnum_return_sequences,\n\u001b[1;32m 1305\u001b[0m is_encoder_decoder\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m 1306\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mmodel_kwargs,\n\u001b[1;32m 1307\u001b[0m )\n\u001b[1;32m 1309\u001b[0m \u001b[39m# 12. run sample\u001b[39;00m\n\u001b[0;32m-> 1310\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 1311\u001b[0m input_ids,\n\u001b[1;32m 1312\u001b[0m logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 1313\u001b[0m logits_warper\u001b[39m=\u001b[39;49mlogits_warper,\n\u001b[1;32m 1314\u001b[0m stopping_criteria\u001b[39m=\u001b[39;49mstopping_criteria,\n\u001b[1;32m 1315\u001b[0m pad_token_id\u001b[39m=\u001b[39;49mpad_token_id,\n\u001b[1;32m 1316\u001b[0m eos_token_id\u001b[39m=\u001b[39;49meos_token_id,\n\u001b[1;32m 1317\u001b[0m output_scores\u001b[39m=\u001b[39;49moutput_scores,\n\u001b[1;32m 1318\u001b[0m return_dict_in_generate\u001b[39m=\u001b[39;49mreturn_dict_in_generate,\n\u001b[1;32m 1319\u001b[0m synced_gpus\u001b[39m=\u001b[39;49msynced_gpus,\n\u001b[1;32m 1320\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_kwargs,\n\u001b[1;32m 1321\u001b[0m )\n\u001b[1;32m 1323\u001b[0m \u001b[39melif\u001b[39;00m is_beam_gen_mode:\n\u001b[1;32m 1324\u001b[0m \u001b[39mif\u001b[39;00m num_return_sequences \u001b[39m>\u001b[39m num_beams:\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1926\u001b[0m, in \u001b[0;36mGenerationMixin.sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m 1923\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mmodel_kwargs)\n\u001b[1;32m 1925\u001b[0m \u001b[39m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 1926\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m(\n\u001b[1;32m 1927\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_inputs,\n\u001b[1;32m 1928\u001b[0m return_dict\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 1929\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1930\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1931\u001b[0m )\n\u001b[1;32m 1933\u001b[0m \u001b[39mif\u001b[39;00m synced_gpus \u001b[39mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m 1934\u001b[0m cur_len \u001b[39m=\u001b[39m cur_len \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:142\u001b[0m, in \u001b[0;36mGPT2InferenceModel.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 139\u001b[0m emb \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(input_ids)\n\u001b[1;32m 140\u001b[0m emb \u001b[39m=\u001b[39m emb \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtext_pos_embedding\u001b[39m.\u001b[39mget_fixed_embedding(attention_mask\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]\u001b[39m-\u001b[39mmel_len, attention_mask\u001b[39m.\u001b[39mdevice)\n\u001b[0;32m--> 142\u001b[0m transformer_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtransformer(\n\u001b[1;32m 143\u001b[0m inputs_embeds\u001b[39m=\u001b[39;49memb,\n\u001b[1;32m 144\u001b[0m past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m 145\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 146\u001b[0m token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m 147\u001b[0m position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m 148\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 149\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 150\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_attention_mask,\n\u001b[1;32m 151\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 152\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 153\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 154\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 156\u001b[0m hidden_states \u001b[39m=\u001b[39m transformer_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 158\u001b[0m \u001b[39m# Set device for model parallelism\u001b[39;00m\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:889\u001b[0m, in \u001b[0;36mGPT2Model.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 879\u001b[0m outputs \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39mcheckpoint\u001b[39m.\u001b[39mcheckpoint(\n\u001b[1;32m 880\u001b[0m create_custom_forward(block),\n\u001b[1;32m 881\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 886\u001b[0m encoder_attention_mask,\n\u001b[1;32m 887\u001b[0m )\n\u001b[1;32m 888\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 889\u001b[0m outputs \u001b[39m=\u001b[39m block(\n\u001b[1;32m 890\u001b[0m hidden_states,\n\u001b[1;32m 891\u001b[0m layer_past\u001b[39m=\u001b[39;49mlayer_past,\n\u001b[1;32m 892\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 893\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask[i],\n\u001b[1;32m 894\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 895\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_attention_mask,\n\u001b[1;32m 896\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 897\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 898\u001b[0m )\n\u001b[1;32m 900\u001b[0m hidden_states \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 901\u001b[0m \u001b[39mif\u001b[39;00m use_cache \u001b[39mis\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:390\u001b[0m, in \u001b[0;36mGPT2Block.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 388\u001b[0m residual \u001b[39m=\u001b[39m hidden_states\n\u001b[1;32m 389\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mln_1(hidden_states)\n\u001b[0;32m--> 390\u001b[0m attn_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattn(\n\u001b[1;32m 391\u001b[0m hidden_states,\n\u001b[1;32m 392\u001b[0m layer_past\u001b[39m=\u001b[39;49mlayer_past,\n\u001b[1;32m 393\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 394\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 395\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 396\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 397\u001b[0m )\n\u001b[1;32m 398\u001b[0m attn_output \u001b[39m=\u001b[39m attn_outputs[\u001b[39m0\u001b[39m] \u001b[39m# output_attn: a, present, (attentions)\u001b[39;00m\n\u001b[1;32m 399\u001b[0m outputs \u001b[39m=\u001b[39m attn_outputs[\u001b[39m1\u001b[39m:]\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:290\u001b[0m, in \u001b[0;36mGPT2Attention.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 287\u001b[0m new_shape \u001b[39m=\u001b[39m tensor\u001b[39m.\u001b[39msize()[:\u001b[39m-\u001b[39m\u001b[39m2\u001b[39m] \u001b[39m+\u001b[39m (num_heads \u001b[39m*\u001b[39m attn_head_size,)\n\u001b[1;32m 288\u001b[0m \u001b[39mreturn\u001b[39;00m tensor\u001b[39m.\u001b[39mview(new_shape)\n\u001b[0;32m--> 290\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 291\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 292\u001b[0m hidden_states: Optional[Tuple[torch\u001b[39m.\u001b[39mFloatTensor]],\n\u001b[1;32m 293\u001b[0m layer_past: Optional[Tuple[torch\u001b[39m.\u001b[39mTensor]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 294\u001b[0m attention_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 295\u001b[0m head_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 296\u001b[0m encoder_hidden_states: Optional[torch\u001b[39m.\u001b[39mTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 297\u001b[0m encoder_attention_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 298\u001b[0m use_cache: Optional[\u001b[39mbool\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 299\u001b[0m output_attentions: Optional[\u001b[39mbool\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 300\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tuple[Union[torch\u001b[39m.\u001b[39mTensor, Tuple[torch\u001b[39m.\u001b[39mTensor]], \u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\u001b[39m.\u001b[39m]:\n\u001b[1;32m 301\u001b[0m \u001b[39mif\u001b[39;00m encoder_hidden_states \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 302\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mq_attn\u001b[39m\u001b[39m\"\u001b[39m):\n",
"File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1963\u001b[0m, in \u001b[0;36mGenerationMixin.sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m 1961\u001b[0m \u001b[39m# sample\u001b[39;00m\n\u001b[1;32m 1962\u001b[0m probs \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mfunctional\u001b[39m.\u001b[39msoftmax(next_token_scores, dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m-> 1963\u001b[0m next_tokens \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mmultinomial(probs, num_samples\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\u001b[39m.\u001b[39msqueeze(\u001b[39m1\u001b[39m)\n\u001b[1;32m 1965\u001b[0m \u001b[39m# finished sentences should have their next token be a padding token\u001b[39;00m\n\u001b[1;32m 1966\u001b[0m \u001b[39mif\u001b[39;00m eos_token_id \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}