From a88534adb2440829235928e449bf43d2aaec5432 Mon Sep 17 00:00:00 2001
From: manmay-nakhashi
Date: Sat, 15 Jul 2023 23:00:19 +0530
Subject: [PATCH] Add kv_cache support for autoregressive inference

Thread a kv_cache flag from TextToSpeech through post_init_gpt2_config into
GPT2InferenceModel so past_key_values can be reused between decoding steps
(when the flag is off, prepare_inputs_for_generation drops them), and remove
the unused model-parallel plumbing from the inference model.

---
 tortoise/api.py                   |   6 +-
 tortoise/models/autoregressive.py |  76 +++++-------
 tortoise_tts.ipynb                | 185 +++++++++++++++++++++++++++---
 3 files changed, 201 insertions(+), 66 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index 631f3a5..e5960a5 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -27,7 +27,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 pbar = None
 DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
-MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
+MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
 MODELS = {
     'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
     'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth',
@@ -198,7 +198,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """

-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, use_deepspeed=False, device=None):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False, use_deepspeed=False, device=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -229,7 +229,7 @@ class TextToSpeech:
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False).cpu().eval()
         self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
-        self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed)
+        self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache)

         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 4d04908..26cd16c 100644
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -33,50 +33,23 @@ class ResBlock(nn.Module):


 class GPT2InferenceModel(GPT2PreTrainedModel):
-    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
+    def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
         super().__init__(config)
         self.transformer = gpt
         self.text_pos_embedding = text_pos_emb
         self.embeddings = embeddings
         self.lm_head = nn.Sequential(norm, linear)
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-        self.cached_mel_emb = None
-
-    def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.transformer.h))
-        self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
-        self.model_parallel = True
-
-    def deparallelize(self):
-        self.transformer.deparallelize()
-        self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
-        self.model_parallel = False
-        torch.cuda.empty_cache()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, 
new_embeddings):
-        self.lm_head = new_embeddings
+        self.kv_cache = kv_cache

     def store_mel_emb(self, mel_emb):
         self.cached_mel_emb = mel_emb

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-
-        token_type_ids = kwargs.get("token_type_ids", None)
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None) # usually None
+        if not self.kv_cache:
+            past_key_values = None
-        # only keep the last input_ids token if past is defined in kwargs
-        if past:
+        # only keep the last input_ids token if past_key_values is defined in kwargs
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -88,13 +61,13 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
             "attention_mask": attention_mask,
@@ -121,7 +94,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         assert self.cached_mel_emb is not None
         assert inputs_embeds is None # Not supported by this inference model.
         assert labels is None # Training not supported by this inference model.
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         # Create embedding
         mel_len = self.cached_mel_emb.shape[1]
@@ -130,14 +105,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
             text_emb = self.embeddings(text_inputs)
             text_emb = text_emb + self.text_pos_embedding(text_emb)
             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
-            else:
+                mel_emb = self.cached_mel_emb.repeat_interleave(
+                    text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
+                )
+            else: # this branch is usually taken only once per generation loop
                 mel_emb = self.cached_mel_emb
             emb = torch.cat([mel_emb, text_emb], dim=1)
         else:
             emb = self.embeddings(input_ids)
-            emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
-
+            emb = emb + self.text_pos_embedding.get_fixed_embedding(
+                attention_mask.shape[1] - mel_len, attention_mask.device
+            )
         transformer_outputs = self.transformer(
             inputs_embeds=emb,
             past_key_values=past_key_values,
@@ -153,12 +131,6 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
             return_dict=return_dict,
         )
         hidden_states = transformer_outputs[0]
-
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
-
         lm_logits = self.lm_head(hidden_states)

         if not return_dict:
@@ -181,7 +153,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         called. This is required to match :obj:`past_key_values` with the correct beam_idx at every
         generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) for layer_past in past ) @@ -340,7 +315,7 @@ class UnifiedVoice(nn.Module): embeddings.append(self.mel_embedding) for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) - def post_init_gpt2_config(self, use_deepspeed=False): + def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False): seq_length = self.max_mel_tokens + self.max_text_tokens + 2 gpt_config = GPT2Config( vocab_size=self.max_mel_tokens, @@ -358,7 +333,8 @@ class UnifiedVoice(nn.Module): self.mel_pos_embedding, self.mel_embedding, self.final_norm, - self.mel_head + self.mel_head, + kv_cache=kv_cache, ) if use_deepspeed: import deepspeed diff --git a/tortoise_tts.ipynb b/tortoise_tts.ipynb index 2882534..fef4e88 100644 --- a/tortoise_tts.ipynb +++ b/tortoise_tts.ipynb @@ -40,6 +40,53 @@ "/home/manmay/anaconda3/envs/tortoise/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-07-15 10:55:28,559] [INFO] [logging.py:93:log_dist] [Rank -1] DeepSpeed info: version=0.8.3, git-hash=unknown, git-branch=unknown\n", + "[2023-07-15 10:55:28,603] [WARNING] [config_utils.py:75:_process_deprecated_field] Config parameter mp_size is deprecated use tensor_parallel.tp_size instead\n", + "[2023-07-15 10:55:28,605] [INFO] [logging.py:93:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1\n", + "WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: module 'transformers.models' has no attribute 'bloom'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /home/manmay/.cache/torch_extensions/py39_cu117 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/manmay/.cache/torch_extensions/py39_cu117/transformer_inference/build.ninja...\n", + "Building extension module transformer_inference...\n", + "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n",
+      "Loading extension module transformer_inference...\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ninja: no work to do.\n",
+      "Time to load transformer_inference op: 0.9313881397247314 seconds\n",
+      "[2023-07-15 10:55:34,938] [INFO] [logging.py:93:log_dist] [Rank -1] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 1024, 'intermediate_size': 4096, 'heads': 16, 'num_hidden_layers': -1, 'fp16': False, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'mlp_act_func_type': , 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False, 'max_out_tokens': 1024, 'scale_attn_by_inverse_layer_idx': False, 'enable_qkv_quantization': False, 'use_mup': False, 'return_single_tuple': False}\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using /home/manmay/.cache/torch_extensions/py39_cu117 as PyTorch extensions root...\n",
+      "No modifications detected for re-loaded extension module transformer_inference, skipping build step...\n",
+      "Loading extension module transformer_inference...\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Time to load transformer_inference op: 0.15709829330444336 seconds\n"
+     ]
     }
    ],
    "source": [
@@ -55,9 +102,9 @@
     "from tortoise.utils.audio import load_audio, load_voice, load_voices\n",
     "\n",
     "# This will download all the models used by Tortoise from the HF hub.\n",
-    "tts = TextToSpeech()\n",
-    "# If you want to use deepspeed the pass use_deepspeed=True nearly 2x faster than normal\n",
-    "# tts = TextToSpeech(use_deepspeed=True)"
+    "# tts = TextToSpeech()\n",
+    "# If you want to use deepspeed, pass use_deepspeed=True; it is nearly 2x faster than normal\n",
+    "tts = TextToSpeech(use_deepspeed=True, kv_cache=True)"
    ]
   },
   {
@@ -151,7 +198,127 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-     " 38%|███▊      | 6/16 [00:31<00:52, 5.20s/it]\n"
+     " 0%|          | 0/16 [00:00 6\u001b[0m gen \u001b[39m=\u001b[39m tts\u001b[39m.\u001b[39;49mtts_with_preset(text, voice_samples\u001b[39m=\u001b[39;49mvoice_samples, conditioning_latents\u001b[39m=\u001b[39;49mconditioning_latents, \n\u001b[1;32m 7\u001b[0m preset\u001b[39m=\u001b[39;49mpreset)\n\u001b[1;32m 8\u001b[0m torchaudio\u001b[39m.\u001b[39msave(\u001b[39m'\u001b[39m\u001b[39mgenerated.wav\u001b[39m\u001b[39m'\u001b[39m, gen\u001b[39m.\u001b[39msqueeze(\u001b[39m0\u001b[39m)\u001b[39m.\u001b[39mcpu(), \u001b[39m24000\u001b[39m)\n\u001b[1;32m 9\u001b[0m IPython\u001b[39m.\u001b[39mdisplay\u001b[39m.\u001b[39mAudio(\u001b[39m'\u001b[39m\u001b[39mgenerated.wav\u001b[39m\u001b[39m'\u001b[39m)\n",
      "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/api.py:329\u001b[0m, in \u001b[0;36mTextToSpeech.tts_with_preset\u001b[0;34m(self, text, preset, **kwargs)\u001b[0m\n\u001b[1;32m 327\u001b[0m settings\u001b[39m.\u001b[39mupdate(presets[preset])\n\u001b[1;32m 328\u001b[0m settings\u001b[39m.\u001b[39mupdate(kwargs) \u001b[39m# allow overriding of preset settings with kwargs\u001b[39;00m\n\u001b[0;32m--> 329\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtts(text, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49msettings)\n",
      "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/api.py:412\u001b[0m, in 
\u001b[0;36mTextToSpeech.tts\u001b[0;34m(self, text, voice_samples, conditioning_latents, k, verbose, use_deterministic_seed, return_deterministic_state, num_autoregressive_samples, temperature, length_penalty, repetition_penalty, top_p, max_mel_tokens, cvvp_amount, diffusion_iterations, cond_free, cond_free_k, diffusion_temperature, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 410\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mGenerating autoregressive samples..\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 411\u001b[0m \u001b[39mfor\u001b[39;00m b \u001b[39min\u001b[39;00m tqdm(\u001b[39mrange\u001b[39m(num_batches), disable\u001b[39m=\u001b[39m\u001b[39mnot\u001b[39;00m verbose):\n\u001b[0;32m--> 412\u001b[0m codes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mautoregressive\u001b[39m.\u001b[39;49minference_speech(auto_conditioning, text_tokens,\n\u001b[1;32m 413\u001b[0m do_sample\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 414\u001b[0m top_p\u001b[39m=\u001b[39;49mtop_p,\n\u001b[1;32m 415\u001b[0m temperature\u001b[39m=\u001b[39;49mtemperature,\n\u001b[1;32m 416\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mautoregressive_batch_size,\n\u001b[1;32m 417\u001b[0m length_penalty\u001b[39m=\u001b[39;49mlength_penalty,\n\u001b[1;32m 418\u001b[0m repetition_penalty\u001b[39m=\u001b[39;49mrepetition_penalty,\n\u001b[1;32m 419\u001b[0m max_generate_length\u001b[39m=\u001b[39;49mmax_mel_tokens,\n\u001b[1;32m 420\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 421\u001b[0m padding_needed \u001b[39m=\u001b[39m max_mel_tokens \u001b[39m-\u001b[39m codes\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]\n\u001b[1;32m 422\u001b[0m codes \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mpad(codes, (\u001b[39m0\u001b[39m, padding_needed), value\u001b[39m=\u001b[39mstop_mel_token)\n", - "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:513\u001b[0m, in \u001b[0;36mUnifiedVoice.inference_speech\u001b[0;34m(self, speech_conditioning_latent, text_inputs, input_tokens, num_return_sequences, max_generate_length, typical_sampling, typical_mass, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 511\u001b[0m logits_processor \u001b[39m=\u001b[39m LogitsProcessorList([TypicalLogitsWarper(mass\u001b[39m=\u001b[39mtypical_mass)]) \u001b[39mif\u001b[39;00m typical_sampling \u001b[39melse\u001b[39;00m LogitsProcessorList()\n\u001b[1;32m 512\u001b[0m max_length \u001b[39m=\u001b[39m trunc_index \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_mel_tokens \u001b[39m-\u001b[39m \u001b[39m1\u001b[39m \u001b[39mif\u001b[39;00m max_generate_length \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m trunc_index \u001b[39m+\u001b[39m max_generate_length\n\u001b[0;32m--> 513\u001b[0m gen \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minference_model\u001b[39m.\u001b[39;49mgenerate(inputs, bos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstart_mel_token, pad_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token, eos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token,\n\u001b[1;32m 514\u001b[0m max_length\u001b[39m=\u001b[39;49mmax_length, logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 515\u001b[0m 
num_return_sequences\u001b[39m=\u001b[39;49mnum_return_sequences, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 516\u001b[0m \u001b[39mreturn\u001b[39;00m gen[:, trunc_index:]\n", + "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:490\u001b[0m, in \u001b[0;36mUnifiedVoice.inference_speech\u001b[0;34m(self, speech_conditioning_latent, text_inputs, input_tokens, num_return_sequences, max_generate_length, typical_sampling, typical_mass, **hf_generate_kwargs)\u001b[0m\n\u001b[1;32m 488\u001b[0m logits_processor \u001b[39m=\u001b[39m LogitsProcessorList([TypicalLogitsWarper(mass\u001b[39m=\u001b[39mtypical_mass)]) \u001b[39mif\u001b[39;00m typical_sampling \u001b[39melse\u001b[39;00m LogitsProcessorList()\n\u001b[1;32m 489\u001b[0m max_length \u001b[39m=\u001b[39m trunc_index \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_mel_tokens \u001b[39m-\u001b[39m \u001b[39m1\u001b[39m \u001b[39mif\u001b[39;00m max_generate_length \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m trunc_index \u001b[39m+\u001b[39m max_generate_length\n\u001b[0;32m--> 490\u001b[0m gen \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minference_model\u001b[39m.\u001b[39;49mgenerate(inputs, bos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstart_mel_token, pad_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token, eos_token_id\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstop_mel_token,\n\u001b[1;32m 491\u001b[0m max_length\u001b[39m=\u001b[39;49mmax_length, logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 492\u001b[0m num_return_sequences\u001b[39m=\u001b[39;49mnum_return_sequences, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhf_generate_kwargs)\n\u001b[1;32m 493\u001b[0m \u001b[39mreturn\u001b[39;00m gen[:, trunc_index:]\n", "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1310\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, 
**model_kwargs)\u001b[0m\n\u001b[1;32m 1302\u001b[0m input_ids, model_kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m 1303\u001b[0m input_ids,\n\u001b[1;32m 1304\u001b[0m expand_size\u001b[39m=\u001b[39mnum_return_sequences,\n\u001b[1;32m 1305\u001b[0m is_encoder_decoder\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m 1306\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mmodel_kwargs,\n\u001b[1;32m 1307\u001b[0m )\n\u001b[1;32m 1309\u001b[0m \u001b[39m# 12. run sample\u001b[39;00m\n\u001b[0;32m-> 1310\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 1311\u001b[0m input_ids,\n\u001b[1;32m 1312\u001b[0m logits_processor\u001b[39m=\u001b[39;49mlogits_processor,\n\u001b[1;32m 1313\u001b[0m logits_warper\u001b[39m=\u001b[39;49mlogits_warper,\n\u001b[1;32m 1314\u001b[0m stopping_criteria\u001b[39m=\u001b[39;49mstopping_criteria,\n\u001b[1;32m 1315\u001b[0m pad_token_id\u001b[39m=\u001b[39;49mpad_token_id,\n\u001b[1;32m 1316\u001b[0m eos_token_id\u001b[39m=\u001b[39;49meos_token_id,\n\u001b[1;32m 1317\u001b[0m output_scores\u001b[39m=\u001b[39;49moutput_scores,\n\u001b[1;32m 1318\u001b[0m return_dict_in_generate\u001b[39m=\u001b[39;49mreturn_dict_in_generate,\n\u001b[1;32m 1319\u001b[0m synced_gpus\u001b[39m=\u001b[39;49msynced_gpus,\n\u001b[1;32m 1320\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_kwargs,\n\u001b[1;32m 1321\u001b[0m )\n\u001b[1;32m 1323\u001b[0m \u001b[39melif\u001b[39;00m is_beam_gen_mode:\n\u001b[1;32m 1324\u001b[0m \u001b[39mif\u001b[39;00m num_return_sequences \u001b[39m>\u001b[39m num_beams:\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1926\u001b[0m, in \u001b[0;36mGenerationMixin.sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m 1923\u001b[0m model_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mmodel_kwargs)\n\u001b[1;32m 1925\u001b[0m \u001b[39m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 1926\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m(\n\u001b[1;32m 1927\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmodel_inputs,\n\u001b[1;32m 1928\u001b[0m return_dict\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 1929\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 1930\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 1931\u001b[0m )\n\u001b[1;32m 1933\u001b[0m \u001b[39mif\u001b[39;00m synced_gpus \u001b[39mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m 1934\u001b[0m cur_len \u001b[39m=\u001b[39m cur_len \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call 
forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File \u001b[0;32m/data/speech_synth/tortoise-tts/tortoise/models/autoregressive.py:142\u001b[0m, in \u001b[0;36mGPT2InferenceModel.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 139\u001b[0m emb \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(input_ids)\n\u001b[1;32m 140\u001b[0m emb \u001b[39m=\u001b[39m emb \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtext_pos_embedding\u001b[39m.\u001b[39mget_fixed_embedding(attention_mask\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]\u001b[39m-\u001b[39mmel_len, attention_mask\u001b[39m.\u001b[39mdevice)\n\u001b[0;32m--> 142\u001b[0m transformer_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtransformer(\n\u001b[1;32m 143\u001b[0m inputs_embeds\u001b[39m=\u001b[39;49memb,\n\u001b[1;32m 144\u001b[0m past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m 145\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 146\u001b[0m token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m 147\u001b[0m position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m 148\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 149\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 150\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_attention_mask,\n\u001b[1;32m 151\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 152\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 153\u001b[0m output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m 154\u001b[0m return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 156\u001b[0m hidden_states \u001b[39m=\u001b[39m transformer_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 158\u001b[0m \u001b[39m# Set device for model parallelism\u001b[39;00m\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 
1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:889\u001b[0m, in \u001b[0;36mGPT2Model.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 879\u001b[0m outputs \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39mcheckpoint\u001b[39m.\u001b[39mcheckpoint(\n\u001b[1;32m 880\u001b[0m create_custom_forward(block),\n\u001b[1;32m 881\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 886\u001b[0m encoder_attention_mask,\n\u001b[1;32m 887\u001b[0m )\n\u001b[1;32m 888\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 889\u001b[0m outputs \u001b[39m=\u001b[39m block(\n\u001b[1;32m 890\u001b[0m hidden_states,\n\u001b[1;32m 891\u001b[0m layer_past\u001b[39m=\u001b[39;49mlayer_past,\n\u001b[1;32m 892\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 893\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask[i],\n\u001b[1;32m 894\u001b[0m encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m 895\u001b[0m encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_attention_mask,\n\u001b[1;32m 896\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 897\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 898\u001b[0m )\n\u001b[1;32m 900\u001b[0m hidden_states \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m 901\u001b[0m \u001b[39mif\u001b[39;00m use_cache \u001b[39mis\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:390\u001b[0m, in \u001b[0;36mGPT2Block.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 388\u001b[0m residual \u001b[39m=\u001b[39m hidden_states\n\u001b[1;32m 389\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mln_1(hidden_states)\n\u001b[0;32m--> 390\u001b[0m attn_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattn(\n\u001b[1;32m 391\u001b[0m hidden_states,\n\u001b[1;32m 392\u001b[0m layer_past\u001b[39m=\u001b[39;49mlayer_past,\n\u001b[1;32m 393\u001b[0m attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m 394\u001b[0m head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m 395\u001b[0m use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m 396\u001b[0m output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m 397\u001b[0m )\n\u001b[1;32m 398\u001b[0m attn_output \u001b[39m=\u001b[39m attn_outputs[\u001b[39m0\u001b[39m] \u001b[39m# output_attn: a, present, (attentions)\u001b[39;00m\n\u001b[1;32m 399\u001b[0m outputs \u001b[39m=\u001b[39m attn_outputs[\u001b[39m1\u001b[39m:]\n", - "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", 
- "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:290\u001b[0m, in \u001b[0;36mGPT2Attention.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 287\u001b[0m new_shape \u001b[39m=\u001b[39m tensor\u001b[39m.\u001b[39msize()[:\u001b[39m-\u001b[39m\u001b[39m2\u001b[39m] \u001b[39m+\u001b[39m (num_heads \u001b[39m*\u001b[39m attn_head_size,)\n\u001b[1;32m 288\u001b[0m \u001b[39mreturn\u001b[39;00m tensor\u001b[39m.\u001b[39mview(new_shape)\n\u001b[0;32m--> 290\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m 291\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 292\u001b[0m hidden_states: Optional[Tuple[torch\u001b[39m.\u001b[39mFloatTensor]],\n\u001b[1;32m 293\u001b[0m layer_past: Optional[Tuple[torch\u001b[39m.\u001b[39mTensor]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 294\u001b[0m attention_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 295\u001b[0m head_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 296\u001b[0m encoder_hidden_states: Optional[torch\u001b[39m.\u001b[39mTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 297\u001b[0m encoder_attention_mask: Optional[torch\u001b[39m.\u001b[39mFloatTensor] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 298\u001b[0m use_cache: Optional[\u001b[39mbool\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 299\u001b[0m output_attentions: Optional[\u001b[39mbool\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 300\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tuple[Union[torch\u001b[39m.\u001b[39mTensor, Tuple[torch\u001b[39m.\u001b[39mTensor]], \u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\u001b[39m.\u001b[39m]:\n\u001b[1;32m 301\u001b[0m \u001b[39mif\u001b[39;00m encoder_hidden_states \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 302\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mq_attn\u001b[39m\u001b[39m\"\u001b[39m):\n", + "File \u001b[0;32m~/anaconda3/envs/tortoise/lib/python3.9/site-packages/transformers/generation_utils.py:1963\u001b[0m, in \u001b[0;36mGenerationMixin.sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)\u001b[0m\n\u001b[1;32m 1961\u001b[0m \u001b[39m# sample\u001b[39;00m\n\u001b[1;32m 1962\u001b[0m probs \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mfunctional\u001b[39m.\u001b[39msoftmax(next_token_scores, dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m-> 1963\u001b[0m next_tokens \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mmultinomial(probs, num_samples\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\u001b[39m.\u001b[39msqueeze(\u001b[39m1\u001b[39m)\n\u001b[1;32m 1965\u001b[0m \u001b[39m# finished sentences should have their next token be a padding token\u001b[39;00m\n\u001b[1;32m 1966\u001b[0m \u001b[39mif\u001b[39;00m eos_token_id \u001b[39mis\u001b[39;00m 
\u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] }